diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index e31edbca..58c9ff20 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -4,7 +4,7 @@ description = "TODO" readme = "README.md" requires-python = ">=3.10" keywords = [] -license = {text = "GPL-3.0-only"} +license = { text = "GPL-3.0-only" } classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", @@ -20,7 +20,7 @@ dependencies = [ "dirac", "diracx-core", "diracx-db", - "python-dotenv", # TODO: We might not need this + "python-dotenv", # TODO: We might not need this "python-multipart", "fastapi", "httpx", @@ -35,11 +35,7 @@ dependencies = [ dynamic = ["version"] [project.optional-dependencies] -testing = [ - "diracx-testing", - "moto[server]", - "pytest-httpx", -] +testing = ["diracx-testing", "moto[server]", "pytest-httpx"] types = [ "boto3-stubs", "types-aiobotocore[essential]", @@ -56,6 +52,11 @@ config = "diracx.routers.configuration:router" auth = "diracx.routers.auth:router" ".well-known" = "diracx.routers.auth.well_known:router" +[project.entry-points."diracx.access_policies"] +WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy" +SandboxAccessPolicy = "diracx.routers.job_manager.access_policies:SandboxAccessPolicy" + + [tool.setuptools.packages.find] where = ["src"] @@ -70,8 +71,10 @@ root = ".." testpaths = ["tests"] addopts = [ "-v", - "--cov=diracx.routers", "--cov-report=term-missing", - "-pdiracx.testing", "-pdiracx.testing.osdb", + "--cov=diracx.routers", + "--cov-report=term-missing", + "-pdiracx.testing", + "-pdiracx.testing.osdb", "--import-mode=importlib", ] asyncio_mode = "auto" diff --git a/diracx-routers/src/diracx/routers/__init__.py b/diracx-routers/src/diracx/routers/__init__.py index e39305c4..70a042dc 100644 --- a/diracx-routers/src/diracx/routers/__init__.py +++ b/diracx-routers/src/diracx/routers/__init__.py @@ -1,3 +1,11 @@ +""" +# Startup sequence + +uvicorn is called with `create_app` as a factory + +create_app loads the environment configuration +""" + from __future__ import annotations import inspect @@ -6,7 +14,7 @@ from collections.abc import AsyncGenerator from functools import partial from logging import Formatter, StreamHandler -from typing import Any, Awaitable, Callable, Iterable, TypeVar, cast +from typing import Any, Awaitable, Callable, Iterable, Sequence, TypeVar, cast import dotenv from cachetools import TTLCache @@ -28,10 +36,11 @@ from diracx.db.exceptions import DBUnavailable from diracx.db.os.utils import BaseOSDB from diracx.db.sql.utils import BaseSQLDB +from diracx.routers.access_policies import BaseAccessPolicy, check_permissions -from .auth import verify_dirac_access_token from .fastapi_classes import DiracFastAPI, DiracxRouter from .otel import instrument_otel +from .utils.users import verify_dirac_access_token T = TypeVar("T") T2 = TypeVar("T2", bound=BaseSQLDB | BaseOSDB) @@ -83,6 +92,7 @@ def configure_logger(): # All routes must have tags (needed for auto gen of client) # Form headers must have a description (autogen) # methods name should follow the generate_unique_id_function pattern +# All routes should have a policy mechanism def create_app_inner( @@ -92,21 +102,83 @@ def create_app_inner( database_urls: dict[str, str], os_database_conn_kwargs: dict[str, Any], config_source: ConfigSource, + all_access_policies: dict[str, Sequence[BaseAccessPolicy]], ) -> DiracFastAPI: + """ + This method does the heavy lifting work of putting all 
the pieces together. + + When starting the application normally, this method is called by create_app, + and the values of the parameters are taken from environment variables or + entrypoints. + + When running tests, the parameters are mocks or test settings. + + We rely on the dependency_override mechanism to implement + the actual behavior we are interested in for settings, DBs or policy. + This allows an extension to override any of these components. + + + :param enabled_systems: + this contains the names of all the routers we have to load + :param all_service_settings: + list of instances of each required Settings type + :param database_urls: + dict of database connection URLs. When testing, sqlite URLs are used + :param os_database_conn_kwargs: + dict containing all the parameters the OpenSearch client takes + :param config_source: + Source of the configuration to use + :param all_access_policies: + dict mapping each policy name to its available implementations + + + """ + + app = DiracFastAPI() # Find which settings classes are available and add them to dependency_overrides + # We use a single instance of each Settings class for performance reasons, + # since it avoids recreating a pydantic model every time + # We add the Settings lifetime_function to the application lifetime_function, + # Please see ServiceSettingsBase for more details + available_settings_classes: set[type[ServiceSettingsBase]] = set() for service_settings in all_service_settings: cls = type(service_settings) assert cls not in available_settings_classes available_settings_classes.add(cls) app.lifetime_functions.append(service_settings.lifetime_function) + # We always return the same settings instance for perf reasons app.dependency_overrides[cls.create] = partial(lambda x: x, service_settings) - # Override the configuration source + # Override ConfigSource.create with the actual reading of the config app.dependency_overrides[ConfigSource.create] = config_source.read_config + all_access_policies_used = {} + + for access_policy_name, access_policy_classes in all_access_policies.items(): + + # The first AccessPolicy is the highest priority one + access_policy_used = access_policy_classes[0].policy + all_access_policies_used[access_policy_name] = access_policy_classes[0] + + # app.lifetime_functions.append(access_policy.lifetime_function) + # Add overrides for all the AccessPolicy classes, including those from extensions + # This means vanilla DiracX routers get an instance of the extension's AccessPolicy + for access_policy_class in access_policy_classes: + # Here we do not check that access_policy_class.check is + # not already in the dependency_overrides because the same + # policy could be used for multiple purposes + # (e.g. open access) + # assert access_policy_class.check not in app.dependency_overrides + app.dependency_overrides[access_policy_class.check] = partial( + check_permissions, access_policy_used, access_policy_name + ) + + app.dependency_overrides[BaseAccessPolicy.all_used_access_policies] = ( + lambda: all_access_policies_used + ) + fail_startup = True # Add the SQL DBs to the application available_sql_db_classes: set[type[BaseSQLDB]] = set() @@ -237,7 +309,22 @@ def create_app() -> DiracFastAPI: - """Load settings from the environment and create the application object""" + """Load settings from the environment and create the application object + + The configuration may be placed in .env files pointed to by + environment variables DIRACX_SERVICE_DOTENV. + They can be followed by "_X" where X is a number, and the order + is respected.
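+ For example (the paths below are purely illustrative, not taken from this PR): + + DIRACX_SERVICE_DOTENV=/etc/diracx/service.env + DIRACX_SERVICE_DOTENV_1=/etc/diracx/service-override.env +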
+ + We then loop over all the diracx.services definitions. + A specific route can be disabled with an environment variable + DIRACX_SERVICE__ENABLED=false + For each of the enabled routes, we inspect which Settings classes + are needed. + + We attempt to load each settings class to make sure that the + settings are correctly defined. + """ for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"): logger.debug("Loading dotenv file: %s", env_file) if not dotenv.load_dotenv(env_file): @@ -261,12 +348,31 @@ def create_app() -> DiracFastAPI: # Load settings classes required by the routers all_service_settings = [settings_class() for settings_class in settings_classes] + # Find all the access policies + + available_access_policy_names = set( + [ + entry_point.name + for entry_point in select_from_extension(group="diracx.access_policies") + ] + ) + + all_access_policies = {} + + for access_policy_name in available_access_policy_names: + + access_policy_classes = BaseAccessPolicy.available_implementations( + access_policy_name + ) + all_access_policies[access_policy_name] = access_policy_classes + return create_app_inner( enabled_systems=enabled_systems, all_service_settings=all_service_settings, database_urls=BaseSQLDB.available_urls(), os_database_conn_kwargs=BaseOSDB.available_urls(), config_source=ConfigSource.create(), + all_access_policies=all_access_policies, ) diff --git a/diracx-routers/src/diracx/routers/access_policies.py b/diracx-routers/src/diracx/routers/access_policies.py new file mode 100644 index 00000000..48e050a3 --- /dev/null +++ b/diracx-routers/src/diracx/routers/access_policies.py @@ -0,0 +1,159 @@ +""" + +AccessPolicy + +We define a set of Policy classes (WMS, DFC, etc). +They have a default implementation in diracx. +If an extension wants to change it, it can be overwritten in the entry point +diracx.access_policies + +Each route should either: +* have the open_access decorator to make explicit that it does not implement a policy +* have a policy callable as a dependency and call it to perform the access policy check + + +Adding a new policy: +1. Create a class that inherits from BaseAccessPolicy and implement the ``policy`` and ``enrich_tokens`` methods +2. Create an entry in the diracx.access_policies entry points +3. Create a dependency such as CheckMyPolicyCallable = Annotated[Callable, Depends(MyAccessPolicy.check)] + +""" + +import functools +import os +import time +from abc import ABCMeta, abstractmethod +from typing import Annotated, Callable, Self + +from fastapi import Depends + +from diracx.core.extensions import select_from_extension +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +# FastAPI bug: +# We normally would use `from __future__ import annotations` +# but a bug in FastAPI prevents us from doing so +# https://github.com/tiangolo/fastapi/pull/11355 +# Until it is merged, we can work around it by using strings. + + +class BaseAccessPolicy(metaclass=ABCMeta): + """ + Base class to be used by all the other Access Policies. + + Each child class should implement the policy staticmethod. + """ + + @classmethod + def check(cls) -> Self: + """ + Placeholder which is in the dependency override + """ + raise NotImplementedError("This should never be called") + + @classmethod + def all_used_access_policies(cls) -> dict[str, "BaseAccessPolicy"]: + """Returns the list of classes that are actually called + (i.e.
taking into account extensions) + This should be overridden by the dependency_override + """ + raise NotImplementedError("This should never be called") + + @classmethod + def available_implementations(cls, access_policy_name: str): + """Return the available implementations of the AccessPolicy in reverse priority order.""" + policy_classes: list[type["BaseAccessPolicy"]] = [ + entry_point.load() + for entry_point in select_from_extension( + group="diracx.access_policies", name=access_policy_name + ) + ] + if not policy_classes: + raise NotImplementedError( + f"Could not find any matches for {access_policy_name=}" + ) + return policy_classes + + @staticmethod + @abstractmethod + async def policy(policy_name: str, user_info: AuthorizedUserInfo, /): + """ + This is the method to be implemented in child classes. + It should always take an AuthorizedUserInfo parameter, which + is passed by check_permissions. + The rest is whatever the policy actually needs. There are rules for writing it: + * This method must be static and async + * All parameters must be keyword-only arguments + * All parameters must have a default value (Liskov Substitution principle) + It is expected that a policy denying access raises HTTPException(status.HTTP_403_FORBIDDEN) + """ + return + + @staticmethod + def enrich_tokens(access_payload: dict, refresh_payload: dict) -> tuple[dict, dict]: + """ + This method is called when issuing a token, and can add whatever + content it wants inside the access or refresh payload + + :param access_payload: access token payload + :param refresh_payload: refresh token payload + :returns: extra content for both payloads + """ + return {}, {} + + +def check_permissions( + policy: Callable, + policy_name: str, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +): + """ + This wrapper just calls the actual implementation, but also makes sure + that the policy has been called. + If not, diracx will abruptly crash. It is violent, but necessary to make + sure that it gets noticed :-) + + This method is never called directly, but used in the dependency_override + at startup. + """ + + has_been_called = False + + @functools.wraps(policy) + async def wrapped_policy(**kwargs): + """This wrapper is just to update the has_been_called flag""" + nonlocal has_been_called + has_been_called = True + return await policy(policy_name, user_info, **kwargs) + + try: + yield wrapped_policy + finally: + if not has_been_called: + # TODO nice error message with inspect + # That should really not happen + print( + "THIS SHOULD NOT HAPPEN, ALWAYS VERIFY PERMISSION", + "(PS: I hope you are in a CI)", + flush=True, + ) + # Sleep a bit to make sure the flush happened + time.sleep(1) + os._exit(1) + + +def open_access(f): + """ + Decorator to put around the routes that are part of a DiracxRouter + and that are expected not to do any access policy check. + The presence of a token will still be checked if the router has require_auth set to True.
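+ + A minimal usage sketch (the route name and return value are illustrative only, + not part of this PR): + + @open_access + @router.get("/some-public-info") + async def some_public_info(): + return "no access policy check is performed here" +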
+ This is useful to allow the CI to detect routes which may have forgotten + to have an access check + """ + f.diracx_open_access = True + + @functools.wraps(f) + def inner(*args, **kwargs): + return f(*args, **kwargs) + + return inner diff --git a/diracx-routers/src/diracx/routers/auth/__init__.py b/diracx-routers/src/diracx/routers/auth/__init__.py index 92ebf4e8..7d2a93a0 100644 --- a/diracx-routers/src/diracx/routers/auth/__init__.py +++ b/diracx-routers/src/diracx/routers/auth/__init__.py @@ -1,11 +1,12 @@ from __future__ import annotations from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token from .authorize_code_flow import router as authorize_code_flow_router from .device_flow import router as device_flow_router from .management import router as management_router from .token import router as token_router -from .utils import AuthorizedUserInfo, has_properties, verify_dirac_access_token +from .utils import has_properties router = DiracxRouter(require_auth=False) router.include_router(device_flow_router) diff --git a/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py b/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py index 4b9f3f76..3082b41a 100644 --- a/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py +++ b/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py @@ -48,8 +48,8 @@ Config, ) from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthSettings from .utils import ( - AuthSettings, GrantType, decrypt_state, get_token_from_iam, diff --git a/diracx-routers/src/diracx/routers/auth/device_flow.py b/diracx-routers/src/diracx/routers/auth/device_flow.py index 84c07270..c2ee729b 100644 --- a/diracx-routers/src/diracx/routers/auth/device_flow.py +++ b/diracx-routers/src/diracx/routers/auth/device_flow.py @@ -70,8 +70,8 @@ Config, ) from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthSettings from .utils import ( - AuthSettings, GrantType, decrypt_state, get_token_from_iam, diff --git a/diracx-routers/src/diracx/routers/auth/management.py b/diracx-routers/src/diracx/routers/auth/management.py index df66e7c6..a6d0a204 100644 --- a/diracx-routers/src/diracx/routers/auth/management.py +++ b/diracx-routers/src/diracx/routers/auth/management.py @@ -4,7 +4,7 @@ to get information about the user's identity. 
""" -from typing import Annotated, TypedDict +from typing import Annotated, Any, TypedDict from fastapi import ( Depends, @@ -21,7 +21,7 @@ AuthDB, ) from ..fastapi_classes import DiracxRouter -from .utils import AuthorizedUserInfo, verify_dirac_access_token +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token router = DiracxRouter(require_auth=False) @@ -32,6 +32,7 @@ class UserInfoResponse(TypedDict): sub: str vo: str dirac_group: str + policies: dict[str, Any] properties: list[SecurityProperty] preferred_username: str @@ -84,5 +85,6 @@ async def userinfo( "vo": user_info.vo, "dirac_group": user_info.dirac_group, "properties": user_info.properties, + "policies": user_info.policies, "preferred_username": user_info.preferred_username, } diff --git a/diracx-routers/src/diracx/routers/auth/token.py b/diracx-routers/src/diracx/routers/auth/token.py index e1f5d0a0..282cef28 100644 --- a/diracx-routers/src/diracx/routers/auth/token.py +++ b/diracx-routers/src/diracx/routers/auth/token.py @@ -10,7 +10,7 @@ from uuid import uuid4 from authlib.jose import JsonWebToken -from fastapi import Form, Header, HTTPException, status +from fastapi import Depends, Form, Header, HTTPException, status from diracx.core.exceptions import ( DiracHttpResponse, @@ -19,16 +19,13 @@ ) from diracx.core.models import TokenResponse from diracx.db.sql.auth.schema import FlowStatus, RefreshTokenStatus +from diracx.routers.access_policies import BaseAccessPolicy from diracx.routers.auth.utils import GrantType -from ..dependencies import ( - AuthDB, - AvailableSecurityProperties, - Config, -) +from ..dependencies import AuthDB, AvailableSecurityProperties, Config from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthSettings from .utils import ( - AuthSettings, parse_and_validate_scope, verify_dirac_refresh_token, ) @@ -51,6 +48,9 @@ async def token( config: Config, settings: AuthSettings, available_properties: AvailableSecurityProperties, + all_access_policies: Annotated[ + dict[str, BaseAccessPolicy], Depends(BaseAccessPolicy.all_used_access_policies) + ], device_code: Annotated[ str | None, Form(description="device code for OAuth2 device flow") ] = None, @@ -99,6 +99,7 @@ async def token( raise NotImplementedError(f"Grant type not implemented {grant_type}") # Get a TokenResponse to return to the user + return await exchange_token( auth_db, scope, @@ -106,6 +107,7 @@ async def token( config, settings, available_properties, + all_access_policies=all_access_policies, legacy_exchange=legacy_exchange, ) @@ -267,6 +269,9 @@ async def legacy_exchange( available_properties: AvailableSecurityProperties, settings: AuthSettings, config: Config, + all_access_policies: Annotated[ + dict[str, BaseAccessPolicy], Depends(BaseAccessPolicy.all_used_access_policies) + ], expires_minutes: int | None = None, ): """Endpoint used by legacy DIRAC to mint tokens for proxy -> token exchange. 
@@ -334,6 +339,7 @@ async def legacy_exchange( config, settings, available_properties, + all_access_policies=all_access_policies, refresh_token_expire_minutes=expires_minutes, legacy_exchange=True, ) @@ -346,6 +352,9 @@ async def exchange_token( config: Config, settings: AuthSettings, available_properties: AvailableSecurityProperties, + all_access_policies: Annotated[ + dict[str, BaseAccessPolicy], Depends(BaseAccessPolicy.all_used_access_policies) + ], *, refresh_token_expire_minutes: int | None = None, legacy_exchange: bool = False, @@ -411,6 +420,22 @@ async def exchange_token( "exp": creation_time + timedelta(minutes=settings.access_token_expire_minutes), } + # Enrich the token payload with policy specific content + dirac_access_policies = {} + dirac_refresh_policies = {} + for policy_name, policy in all_access_policies.items(): + + access_extra, refresh_extra = policy.enrich_tokens( + access_payload, refresh_payload + ) + if access_extra: + dirac_access_policies[policy_name] = access_extra + if refresh_extra: + dirac_refresh_policies[policy_name] = refresh_extra + + access_payload["dirac_policies"] = dirac_access_policies + refresh_payload["dirac_policies"] = dirac_refresh_policies + # Generate the token: encode the payloads access_token = create_token(access_payload, settings) refresh_token = create_token(refresh_payload, settings) diff --git a/diracx-routers/src/diracx/routers/auth/utils.py b/diracx-routers/src/diracx/routers/auth/utils.py index 14629b3d..531b2a32 100644 --- a/diracx-routers/src/diracx/routers/auth/utils.py +++ b/diracx-routers/src/diracx/routers/auth/utils.py @@ -1,11 +1,9 @@ import base64 import hashlib import json -import re import secrets from enum import StrEnum from typing import Annotated, TypedDict -from uuid import UUID import httpx from authlib.integrations.starlette_client import OAuthError @@ -14,58 +12,18 @@ from cachetools import TTLCache from cryptography.fernet import Fernet from fastapi import Depends, HTTPException, status -from fastapi.security import OpenIdConnect -from pydantic import BaseModel, Field -from diracx.core.models import UserInfo from diracx.core.properties import ( SecurityProperty, UnevaluatedProperty, ) -from diracx.core.settings import FernetKey, ServiceSettingsBase, TokenSigningKey - -from ..dependencies import Config, add_settings_annotation - - -@add_settings_annotation -class AuthSettings(ServiceSettingsBase, env_prefix="DIRACX_SERVICE_AUTH_"): - """Settings for the authentication service.""" - - dirac_client_id: str = "myDIRACClientID" - # TODO: This should be taken dynamically - # ["http://pclhcb211:8000/docs/oauth2-redirect"] - allowed_redirects: list[str] = [] - device_flow_expiration_seconds: int = 600 - authorization_flow_expiration_seconds: int = 300 - - # State key is used to encrypt/decrypt the state dict passed to the IAM - state_key: FernetKey - - token_issuer: str = "http://lhcbdirac.cern.ch/" - token_key: TokenSigningKey - token_algorithm: str = "RS256" - access_token_expire_minutes: int = 20 - refresh_token_expire_minutes: int = 60 - - available_properties: set[SecurityProperty] = Field( - default_factory=SecurityProperty.available_properties - ) - - -class AuthInfo(BaseModel): - # raw token for propagation - bearer_token: str - - # token ID in the DB for Component - # unique jwt identifier for user - token_id: UUID - - # list of DIRAC properties - properties: list[SecurityProperty] - +from diracx.routers.utils.users import ( + AuthorizedUserInfo, + AuthSettings, + verify_dirac_access_token, +) -class 
AuthorizedUserInfo(AuthInfo, UserInfo): - pass +from ..dependencies import Config class GrantType(StrEnum): @@ -177,65 +135,6 @@ def decrypt_state(state: str, cipher_suite: Fernet) -> dict[str, str]: ) from e -# auto_error=False is used to avoid raising the wrong exception when the token is missing -# The error is handled in the verify_dirac_access_token function -# More info: -# - https://github.com/tiangolo/fastapi/issues/10177 -# - https://datatracker.ietf.org/doc/html/rfc6750#section-3.1 -oidc_scheme = OpenIdConnect( - openIdConnectUrl="/.well-known/openid-configuration", auto_error=False -) - - -async def verify_dirac_access_token( - authorization: Annotated[str, Depends(oidc_scheme)], - settings: AuthSettings, -) -> AuthorizedUserInfo: - """Verify dirac user token and return a UserInfo class - Used for each API endpoint - """ - if not authorization: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authorization header is missing", - headers={"WWW-Authenticate": "Bearer"}, - ) - if match := re.fullmatch(r"Bearer (.+)", authorization): - raw_token = match.group(1) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid authorization header", - ) - - try: - jwt = JsonWebToken(settings.token_algorithm) - token = jwt.decode( - raw_token, - key=settings.token_key.jwk, - claims_options={ - "iss": {"values": [settings.token_issuer]}, - }, - ) - token.validate() - except JoseError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid JWT", - headers={"WWW-Authenticate": "Bearer"}, - ) from None - - return AuthorizedUserInfo( - bearer_token=raw_token, - token_id=token["jti"], - properties=token["dirac_properties"], - sub=token["sub"], - preferred_username=token["preferred_username"], - dirac_group=token["dirac_group"], - vo=token["vo"], - ) - - async def verify_dirac_refresh_token( refresh_token: str, settings: AuthSettings, diff --git a/diracx-routers/src/diracx/routers/auth/utils.py.orig b/diracx-routers/src/diracx/routers/auth/utils.py.orig new file mode 100644 index 00000000..17cf5e12 --- /dev/null +++ b/diracx-routers/src/diracx/routers/auth/utils.py.orig @@ -0,0 +1,391 @@ +import base64 +import hashlib +import json +import secrets +from enum import StrEnum +from typing import Annotated, TypedDict + +import httpx +from authlib.integrations.starlette_client import OAuthError +from authlib.jose import JoseError, JsonWebKey, JsonWebToken +from authlib.oidc.core import IDToken +from cachetools import TTLCache +from cryptography.fernet import Fernet +from fastapi import Depends, HTTPException, status + +from diracx.core.properties import ( + SecurityProperty, + UnevaluatedProperty, +) +from diracx.routers.utils.users import ( + AuthorizedUserInfo, + AuthSettings, + verify_dirac_access_token, +) + +from ..dependencies import Config + + +class GrantType(StrEnum): + """Grant types for OAuth2.""" + + authorization_code = "authorization_code" + device_code = "urn:ietf:params:oauth:grant-type:device_code" + refresh_token = "refresh_token" # noqa: S105 # False positive of Bandit about hard coded password + + +class ScopeInfoDict(TypedDict): + group: str + properties: list[str] + vo: str + + +def has_properties(expression: UnevaluatedProperty | SecurityProperty): + """Check if the user has the given properties.""" + evaluator = ( + expression + if isinstance(expression, UnevaluatedProperty) + else UnevaluatedProperty(expression) + ) + + async def require_property( + user: Annotated[AuthorizedUserInfo, 
Depends(verify_dirac_access_token)] + ): + if not evaluator(user.properties): + raise HTTPException(status.HTTP_403_FORBIDDEN) + + return Depends(require_property) + + +_server_metadata_cache: TTLCache = TTLCache(maxsize=1024, ttl=3600) + + +async def get_server_metadata(url: str): + """Get the server metadata from the IAM.""" + server_metadata = _server_metadata_cache.get(url) + if server_metadata is None: + async with httpx.AsyncClient() as c: + res = await c.get(url) + if res.status_code != 200: + # TODO: Better error handling + raise NotImplementedError(res) + server_metadata = res.json() + _server_metadata_cache[url] = server_metadata + return server_metadata + + +async def fetch_jwk_set(url: str): + """Fetch the JWK set from the IAM.""" + server_metadata = await get_server_metadata(url) + + jwks_uri = server_metadata.get("jwks_uri") + if not jwks_uri: + raise RuntimeError('Missing "jwks_uri" in metadata') + + async with httpx.AsyncClient() as c: + res = await c.get(jwks_uri) + if res.status_code != 200: + # TODO: Better error handling + raise NotImplementedError(res) + jwk_set = res.json() + + # self.server_metadata['jwks'] = jwk_set + return JsonWebKey.import_key_set(jwk_set) + + +async def parse_id_token(config, vo, raw_id_token: str): + """Parse and validate the ID token from IAM.""" + server_metadata = await get_server_metadata( + config.Registry[vo].IdP.server_metadata_url + ) + alg_values = server_metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + jwk_set = await fetch_jwk_set(config.Registry[vo].IdP.server_metadata_url) + + token = JsonWebToken(alg_values).decode( + raw_id_token, + key=jwk_set, + claims_cls=IDToken, + claims_options={ + "iss": {"values": [server_metadata["issuer"]]}, + # The audience is a required parameter and is the client ID of the application + # https://openid.net/specs/openid-connect-core-1_0.html#IDToken + "aud": {"values": [config.Registry[vo].IdP.ClientID]}, + }, + ) + token.validate() + return token + + +def encrypt_state(state_dict: dict[str, str], cipher_suite: Fernet) -> str: + """Encrypt the state dict and return it as a string""" + return cipher_suite.encrypt( + base64.urlsafe_b64encode(json.dumps(state_dict).encode()) + ).decode() + + +def decrypt_state(state: str, cipher_suite: Fernet) -> dict[str, str]: + """Decrypt the state string and return it as a dict""" + try: + return json.loads( + base64.urlsafe_b64decode(cipher_suite.decrypt(state.encode())).decode() + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state" + ) from e + + +<<<<<<< HEAD +# auto_error=False is used to avoid raising the wrong exception when the token is missing +# The error is handled in the verify_dirac_access_token function +# More info: +# - https://github.com/tiangolo/fastapi/issues/10177 +# - https://datatracker.ietf.org/doc/html/rfc6750#section-3.1 +oidc_scheme = OpenIdConnect( + openIdConnectUrl="/.well-known/openid-configuration", auto_error=False +) + + +async def verify_dirac_access_token( + authorization: Annotated[str, Depends(oidc_scheme)], + settings: AuthSettings, +) -> AuthorizedUserInfo: + """Verify dirac user token and return a UserInfo class + Used for each API endpoint + """ + if not authorization: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authorization header is missing", + headers={"WWW-Authenticate": "Bearer"}, + ) + if match := re.fullmatch(r"Bearer (.+)", authorization): + raw_token = match.group(1) + else: + raise HTTPException( + 
status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid authorization header", + ) + + try: + jwt = JsonWebToken(settings.token_algorithm) + token = jwt.decode( + raw_token, + key=settings.token_key.jwk, + claims_options={ + "iss": {"values": [settings.token_issuer]}, + }, + ) + token.validate() + except JoseError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid JWT", + headers={"WWW-Authenticate": "Bearer"}, + ) from None + + return AuthorizedUserInfo( + bearer_token=raw_token, + token_id=token["jti"], + properties=token["dirac_properties"], + sub=token["sub"], + preferred_username=token["preferred_username"], + dirac_group=token["dirac_group"], + vo=token["vo"], + ) + + +======= +>>>>>>> cacc536 (Implement Policy Mechanism) +async def verify_dirac_refresh_token( + refresh_token: str, + settings: AuthSettings, +) -> tuple[str, float, bool]: + """Verify dirac user token and return a UserInfo class + Used for each API endpoint + """ + try: + jwt = JsonWebToken(settings.token_algorithm) + token = jwt.decode( + refresh_token, + key=settings.token_key.jwk, + ) + token.validate() + # Handle problematic tokens such as: + # - tokens signed with an invalid JWK + # - expired tokens + except JoseError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Invalid JWT: {e.args[0]}", + headers={"WWW-Authenticate": "Bearer"}, + ) from e + + return (token["jti"], float(token["exp"]), token["legacy_exchange"]) + + +def parse_and_validate_scope( + scope: str, config: Config, available_properties: set[SecurityProperty] +) -> ScopeInfoDict: + """ + Check: + * At most one VO + * At most one group + * group belongs to VO + * properties are known + return dict with group and properties + + :raises: + * ValueError in case the scope isn't valide + """ + scopes = set(scope.split(" ")) + + groups = [] + properties = [] + vos = [] + unrecognised = [] + for scope in scopes: + if scope.startswith("group:"): + groups.append(scope.split(":", 1)[1]) + elif scope.startswith("property:"): + properties.append(scope.split(":", 1)[1]) + elif scope.startswith("vo:"): + vos.append(scope.split(":", 1)[1]) + else: + unrecognised.append(scope) + if unrecognised: + raise ValueError(f"Unrecognised scopes: {unrecognised}") + + if not vos: + available_vo_scopes = [repr(f"vo:{vo}") for vo in config.Registry] + raise ValueError( + "No vo scope requested, available values: " + f"{' '.join(available_vo_scopes)}" + ) + elif len(vos) > 1: + raise ValueError(f"Only one vo is allowed but got {vos}") + else: + vo = vos[0] + if vo not in config.Registry: + raise ValueError(f"VO {vo} is not known to this installation") + + if not groups: + # TODO: Handle multiple groups correctly + group = config.Registry[vo].DefaultGroup + elif len(groups) > 1: + raise ValueError(f"Only one DIRAC group allowed but got {groups}") + else: + group = groups[0] + if group not in config.Registry[vo].Groups: + raise ValueError(f"{group} not in {vo} groups") + + allowed_properties = config.Registry[vo].Groups[group].Properties + if not properties: + # If there are no properties set get the defaults from the CS + properties = [str(p) for p in allowed_properties] + + if not set(properties).issubset(available_properties): + raise ValueError( + f"{set(properties)-set(available_properties)} are not valid properties" + ) + + if not set(properties).issubset(allowed_properties): + raise PermissionError( + f"Attempted to access properties {set(properties)-set(allowed_properties)} which are not allowed." 
+ f" Allowed properties are: {allowed_properties}" + ) + + return { + "group": group, + "properties": sorted(properties), + "vo": vo, + } + + +async def initiate_authorization_flow_with_iam( + config, vo: str, redirect_uri: str, state: dict[str, str], cipher_suite: Fernet +): + """Initiate the authorization flow with the IAM. Return the URL to redirect the user to. + + The state dict is encrypted and passed to the IAM. + It is then decrypted when the user is redirected back to the redirect_uri. + """ + # code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1 + code_verifier = secrets.token_hex() + + # code_challenge: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .replace("=", "") + ) + + server_metadata = await get_server_metadata( + config.Registry[vo].IdP.server_metadata_url + ) + + # Take these two from CS/.well-known + authorization_endpoint = server_metadata["authorization_endpoint"] + + # Encrypt the state and pass it to the IAM + # Needed to retrieve the original flow details when the user is redirected back to the redirect_uri + encrypted_state = encrypt_state( + state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite + ) + + urlParams = [ + "response_type=code", + f"code_challenge={code_challenge}", + "code_challenge_method=S256", + f"client_id={config.Registry[vo].IdP.ClientID}", + f"redirect_uri={redirect_uri}", + "scope=openid%20profile", + f"state={encrypted_state}", + ] + authorization_flow_url = f"{authorization_endpoint}?{'&'.join(urlParams)}" + return authorization_flow_url + + +async def get_token_from_iam( + config, vo: str, code: str, state: dict[str, str], redirect_uri: str +) -> dict[str, str]: + """Get the token from the IAM using the code and state. 
Return the ID token.""" + server_metadata = await get_server_metadata( + config.Registry[vo].IdP.server_metadata_url + ) + + # Take these two from CS/.well-known + token_endpoint = server_metadata["token_endpoint"] + + data = { + "grant_type": GrantType.authorization_code.value, + "client_id": config.Registry[vo].IdP.ClientID, + "code": code, + "code_verifier": state["code_verifier"], + "redirect_uri": redirect_uri, + } + + async with httpx.AsyncClient() as c: + res = await c.post( + token_endpoint, + data=data, + ) + if res.status_code >= 500: + raise HTTPException( + status.HTTP_502_BAD_GATEWAY, "Failed to contact token endpoint" + ) + elif res.status_code >= 400: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid code") + + raw_id_token = res.json()["id_token"] + # Extract the payload and verify it + try: + id_token = await parse_id_token( + config=config, + vo=vo, + raw_id_token=raw_id_token, + ) + except OAuthError: + raise + + return id_token diff --git a/diracx-routers/src/diracx/routers/auth/well_known.py b/diracx-routers/src/diracx/routers/auth/well_known.py index 02486959..9582b1c5 100644 --- a/diracx-routers/src/diracx/routers/auth/well_known.py +++ b/diracx-routers/src/diracx/routers/auth/well_known.py @@ -6,7 +6,9 @@ from ..dependencies import Config from ..fastapi_classes import DiracxRouter -from .utils import AuthSettings +from ..utils.users import AuthSettings + +# from ..access_policies import OpenAccessPolicyCallable router = DiracxRouter(require_auth=False, path_root="") @@ -16,8 +18,10 @@ async def openid_configuration( request: Request, config: Config, settings: AuthSettings, + # check_permissions: OpenAccessPolicyCallable, ): """OpenID Connect discovery endpoint.""" + # await check_permissions() scopes_supported = [] for vo in config.Registry: scopes_supported.append(f"vo:{vo}") @@ -65,8 +69,12 @@ class Metadata(TypedDict): @router.get("/dirac-metadata") -async def installation_metadata(config: Config) -> Metadata: +async def installation_metadata( + config: Config, + # check_permissions: OpenAccessPolicyCallable, +) -> Metadata: """Get metadata about the dirac installation.""" + # await check_permissions() metadata: Metadata = { "virtual_organizations": {}, } diff --git a/diracx-routers/src/diracx/routers/configuration.py b/diracx-routers/src/diracx/routers/configuration.py index 412ed9ff..c550501f 100644 --- a/diracx-routers/src/diracx/routers/configuration.py +++ b/diracx-routers/src/diracx/routers/configuration.py @@ -10,6 +10,7 @@ status, ) +from .access_policies import open_access from .dependencies import Config from .fastapi_classes import DiracxRouter @@ -18,10 +19,12 @@ router = DiracxRouter() +@open_access @router.get("/") async def serve_config( config: Config, response: Response, + # check_permissions: OpenAccessPolicyCallable, if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, ): @@ -34,6 +37,7 @@ async def serve_config( If If-Modified-Since is given and is newer than latest, return 304: this is to avoid flip/flopping """ + # await check_permissions() headers = { "ETag": config._hexsha, "Last-Modified": config._modified.strftime(LAST_MODIFIED_FORMAT), diff --git a/diracx-routers/src/diracx/routers/job_manager/__init__.py b/diracx-routers/src/diracx/routers/job_manager/__init__.py index 2794f2fc..6c42ead4 100644 --- a/diracx-routers/src/diracx/routers/job_manager/__init__.py +++ b/diracx-routers/src/diracx/routers/job_manager/__init__.py @@ -22,7 +22,6 @@ SetJobStatusReturn, 
SortSpec, ) -from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.db.sql.jobs.status_utility import ( delete_jobs, kill_jobs, @@ -30,16 +29,17 @@ set_job_status, ) -from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token from ..dependencies import JobDB, JobLoggingDB, SandboxMetadataDB, TaskQueueDB from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token +from .access_policies import ActionType, CheckWMSPolicyCallable from .sandboxes import router as sandboxes_router MAX_PARAMETRIC_JOBS = 20 logger = logging.getLogger(__name__) -router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)]) +router = DiracxRouter() router.include_router(sandboxes_router) @@ -116,7 +116,10 @@ async def submit_bulk_jobs( job_db: JobDB, job_logging_db: JobLoggingDB, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckWMSPolicyCallable, ) -> list[InsertedJob]: + await check_permissions(action=ActionType.CREATE, job_db=job_db) + from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise from DIRAC.WorkloadManagementSystem.Service.JobPolicy import RIGHT_SUBMIT, JobPolicy @@ -250,7 +253,10 @@ async def delete_bulk_jobs( job_logging_db: JobLoggingDB, task_queue_db: TaskQueueDB, background_task: BackgroundTasks, + check_permissions: CheckWMSPolicyCallable, ): + + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids) # TODO: implement job policy try: @@ -285,7 +291,9 @@ async def kill_bulk_jobs( job_logging_db: JobLoggingDB, task_queue_db: TaskQueueDB, background_task: BackgroundTasks, + check_permissions: CheckWMSPolicyCallable, ): + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids) # TODO: implement job policy try: await kill_jobs( @@ -320,6 +328,7 @@ async def remove_bulk_jobs( sandbox_metadata_db: SandboxMetadataDB, task_queue_db: TaskQueueDB, background_task: BackgroundTasks, + check_permissions: CheckWMSPolicyCallable, ): """ Fully remove a list of jobs from the WMS databases. @@ -329,6 +338,7 @@ async def remove_bulk_jobs( and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should be removed, and the delete endpoint should be used instead for any other purpose. 
""" + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids) # TODO: Remove once legacy DIRAC no longer needs this # TODO: implement job policy @@ -350,8 +360,11 @@ async def remove_bulk_jobs( @router.get("/status") async def get_job_status_bulk( - job_ids: Annotated[list[int], Query()], job_db: JobDB + job_ids: Annotated[list[int], Query()], + job_db: JobDB, + check_permissions: CheckWMSPolicyCallable, ) -> dict[int, LimitedJobStatusReturn]: + await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=job_ids) try: result = await asyncio.gather( *(job_db.get_job_status(job_id) for job_id in job_ids) @@ -366,8 +379,12 @@ async def set_job_status_bulk( job_update: dict[int, dict[datetime, JobStatusUpdate]], job_db: JobDB, job_logging_db: JobLoggingDB, + check_permissions: CheckWMSPolicyCallable, force: bool = False, ) -> dict[int, SetJobStatusReturn]: + await check_permissions( + action=ActionType.MANAGE, job_db=job_db, job_ids=list(job_update) + ) # check that the datetime contains timezone info for job_id, status in job_update.items(): for dt in status: @@ -388,8 +405,12 @@ async def set_job_status_bulk( @router.get("/status/history") async def get_job_status_history_bulk( - job_ids: Annotated[list[int], Query()], job_logging_db: JobLoggingDB + job_ids: Annotated[list[int], Query()], + job_logging_db: JobLoggingDB, + job_db: JobDB, + check_permissions: CheckWMSPolicyCallable, ) -> dict[int, list[JobStatusReturn]]: + await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=job_ids) result = await asyncio.gather( *(job_logging_db.get_records(job_id) for job_id in job_ids) ) @@ -401,8 +422,9 @@ async def reschedule_bulk_jobs( job_ids: Annotated[list[int], Query()], job_db: JobDB, job_logging_db: JobLoggingDB, - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckWMSPolicyCallable, ): + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids) rescheduled_jobs = [] # TODO: Joblist Policy: # validJobList, invalidJobList, nonauthJobList, ownerJobList = self.jobPolicy.evaluateJobRights( @@ -451,8 +473,9 @@ async def reschedule_bulk_jobs( async def reschedule_single_job( job_id: int, job_db: JobDB, - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckWMSPolicyCallable, ): + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) try: result = await job_db.rescheduleJob(job_id) except ValueError as e: @@ -522,6 +545,7 @@ async def search( config: Annotated[Config, Depends(ConfigSource.create)], job_db: JobDB, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckWMSPolicyCallable, page: int = 0, per_page: int = 100, body: Annotated[ @@ -532,6 +556,7 @@ async def search( **TODO: Add more docs** """ + await check_permissions(action=ActionType.QUERY, job_db=job_db) if body is None: body = JobSearchParams() # TODO: Apply all the job policy stuff properly using user_info @@ -560,8 +585,10 @@ async def summary( job_db: JobDB, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], body: JobSummaryParams, + check_permissions: CheckWMSPolicyCallable, ): """Show information suitable for plotting""" + await check_permissions(action=ActionType.QUERY, job_db=job_db) # TODO: Apply all the job policy stuff properly using user_info if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo: body.search.append( @@ -575,7 
+602,12 @@ async def summary( @router.get("/{job_id}") -async def get_single_job(job_id: int): +async def get_single_job( + job_id: int, + job_db: JobDB, + check_permissions: CheckWMSPolicyCallable, +): + await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id]) return f"This job {job_id}" @@ -587,10 +619,12 @@ async def delete_single_job( job_logging_db: JobLoggingDB, task_queue_db: TaskQueueDB, background_task: BackgroundTasks, + check_permissions: CheckWMSPolicyCallable, ): """ Delete a job by killing and setting the job status to DELETED. """ + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) # TODO: implement job policy try: @@ -618,10 +652,12 @@ async def kill_single_job( job_logging_db: JobLoggingDB, task_queue_db: TaskQueueDB, background_task: BackgroundTasks, + check_permissions: CheckWMSPolicyCallable, ): """ Kill a job. """ + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) # TODO: implement job policy @@ -646,6 +682,7 @@ async def remove_single_job( sandbox_metadata_db: SandboxMetadataDB, task_queue_db: TaskQueueDB, background_task: BackgroundTasks, + check_permissions: CheckWMSPolicyCallable, ): """ Fully remove a job from the WMS databases. @@ -654,6 +691,8 @@ async def remove_single_job( and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should be removed, and the delete endpoint should be used instead. """ + + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) # TODO: Remove once legacy DIRAC no longer needs this # TODO: implement job policy @@ -673,8 +712,11 @@ async def remove_single_job( @router.get("/{job_id}/status") async def get_single_job_status( - job_id: int, job_db: JobDB + job_id: int, + job_db: JobDB, + check_permissions: CheckWMSPolicyCallable, ) -> dict[int, LimitedJobStatusReturn]: + await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id]) try: status = await job_db.get_job_status(job_id) except JobNotFound as e: @@ -690,8 +732,10 @@ async def set_single_job_status( status: Annotated[dict[datetime, JobStatusUpdate], Body()], job_db: JobDB, job_logging_db: JobLoggingDB, + check_permissions: CheckWMSPolicyCallable, force: bool = False, ) -> dict[int, SetJobStatusReturn]: + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) # check that the datetime contains timezone info for dt in status: if dt.tzinfo is None: @@ -712,8 +756,11 @@ async def set_single_job_status( @router.get("/{job_id}/status/history") async def get_single_job_status_history( job_id: int, + job_db: JobDB, job_logging_db: JobLoggingDB, + check_permissions: CheckWMSPolicyCallable, ) -> dict[int, list[JobStatusReturn]]: + await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id]) try: status = await job_logging_db.get_records(job_id) except JobNotFound as e: @@ -728,12 +775,15 @@ async def set_single_job_properties( job_id: int, job_properties: Annotated[dict[str, Any], Body()], job_db: JobDB, + check_permissions: CheckWMSPolicyCallable, update_timestamp: bool = False, ): """ Update the given job properties (MinorStatus, ApplicationStatus, etc) """ + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) + rowcount = await job_db.set_properties( {job_id: job_properties}, update_timestamp=update_timestamp ) diff --git a/diracx-routers/src/diracx/routers/job_manager/access_policies.py 
b/diracx-routers/src/diracx/routers/job_manager/access_policies.py new file mode 100644 index 00000000..80f545a2 --- /dev/null +++ b/diracx-routers/src/diracx/routers/job_manager/access_policies.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from enum import StrEnum, auto +from typing import Annotated, Callable + +from fastapi import Depends, HTTPException, status + +from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER +from diracx.db.sql import JobDB, SandboxMetadataDB +from diracx.routers.access_policies import BaseAccessPolicy + +from ..utils.users import AuthorizedUserInfo + + +class ActionType(StrEnum): + #: Create a job or a sandbox + CREATE = auto() + #: Check job status, download a sandbox + READ = auto() + #: delete, kill, remove, set status, etc of a job + #: delete or assign a sandbox + MANAGE = auto() + #: Search + QUERY = auto() + + +class WMSAccessPolicy(BaseAccessPolicy): + """ + Rules: + * You need either NORMAL_USER or JOB_ADMINISTRATOR in your properties + * An admin cannot create any resource but can read everything and modify everything + * A NORMAL_USER can create + * a NORMAL_USER can query and read only his own jobs + """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + *, + action: ActionType | None = None, + job_db: JobDB | None = None, + job_ids: list[int] | None = None, + ): + assert action, "action is a mandatory parameter" + assert job_db, "job_db is a mandatory parameter" + + if action == ActionType.CREATE: + if job_ids is not None: + raise NotImplementedError( + "job_ids is not None with ActionType.CREATE. This shouldn't happen" + ) + if NORMAL_USER not in user_info.properties: + raise HTTPException(status.HTTP_403_FORBIDDEN) + return + + if JOB_ADMINISTRATOR in user_info.properties: + return + + if NORMAL_USER not in user_info.properties: + raise HTTPException(status.HTTP_403_FORBIDDEN) + + if action == ActionType.QUERY: + if job_ids is not None: + raise NotImplementedError( + "job_ids is not None with ActionType.QUERY. This shouldn't happen" + ) + return + + if job_ids is None: + raise NotImplementedError("job_ids is None. 
This shouldn't happen") + + # TODO: check the CS global job monitoring flag + + # Now we know we are either in READ/MODIFY for a NORMAL_USER + # so just make sure that whatever job_id was given belongs + # to the current user + job_owners = await job_db.summary( + ["Owner", "VO"], + [{"parameter": "JobID", "operator": "in", "values": job_ids}], + ) + + expected_owner = { + "Owner": user_info.preferred_username, + "VO": user_info.vo, + "count": len(set(job_ids)), + } + # All the jobs belong to the user doing the query + # and all of them are present + if job_owners == [expected_owner]: + return + + raise HTTPException(status.HTTP_403_FORBIDDEN) + + +CheckWMSPolicyCallable = Annotated[Callable, Depends(WMSAccessPolicy.check)] + + +class SandboxAccessPolicy(BaseAccessPolicy): + """ + Policy for the sandbox. + It delegates most of it to the WMSAccessPolicy. + """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + *, + action: ActionType | None = None, + job_db: JobDB | None = None, + sandbox_metadata_db: SandboxMetadataDB | None = None, + pfns: list[str] | None = None, + required_prefix: str | None = None, + job_ids: list[int] | None = None, + check_wms_permissions: CheckWMSPolicyCallable | None = None, + ): + + assert action, "action is a mandatory parameter" + + # if we pass the job_db or job_ids, + # delegate the check to the WMSAccessPolicy + if job_db or job_ids: + # Make sure that check_wms_permissions is set + # It should always be set by fastapi Depends, + # but not when we test the policy in itself + assert check_wms_permissions + return check_wms_permissions(action=action, job_db=job_db, job_ids=job_ids) + + assert sandbox_metadata_db, "sandbox_metadata_db is a mandatory parameter" + assert pfns, "pfns is a mandatory parameter" + + if action == ActionType.CREATE: + + if NORMAL_USER not in user_info.properties: + raise HTTPException(status.HTTP_403_FORBIDDEN) + return + + if JOB_ADMINISTRATOR in user_info.properties: + return + + if NORMAL_USER not in user_info.properties: + raise HTTPException(status.HTTP_403_FORBIDDEN) + + # Getting a sandbox or modifying it + if required_prefix is None: + raise NotImplementedError("required_prefix is None. This shouldn't happen") + for pfn in pfns: + if not pfn.startswith(required_prefix): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Invalid PFN.
PFN must start with {required_prefix}", + ) + + +CheckSandboxPolicyCallable = Annotated[Callable, Depends(SandboxAccessPolicy.check)] diff --git a/diracx-routers/src/diracx/routers/job_manager/sandboxes.py b/diracx-routers/src/diracx/routers/job_manager/sandboxes.py index 6e1c837e..45b77adb 100644 --- a/diracx-routers/src/diracx/routers/job_manager/sandboxes.py +++ b/diracx-routers/src/diracx/routers/job_manager/sandboxes.py @@ -24,11 +24,14 @@ ) from diracx.core.settings import ServiceSettingsBase +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token +from .access_policies import ActionType, CheckSandboxPolicyCallable + if TYPE_CHECKING: from types_aiobotocore_s3.client import S3Client -from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token -from ..dependencies import SandboxMetadataDB, add_settings_annotation +from ..auth import has_properties +from ..dependencies import JobDB, SandboxMetadataDB, add_settings_annotation from ..fastapi_classes import DiracxRouter MAX_SANDBOX_SIZE_BYTES = 100 * 1024 * 1024 @@ -86,6 +89,7 @@ async def initiate_sandbox_upload( sandbox_info: SandboxInfo, sandbox_metadata_db: SandboxMetadataDB, settings: SandboxStoreSettings, + check_permissions: CheckSandboxPolicyCallable, ) -> SandboxUploadResponse: """Get the PFN for the given sandbox, initiate an upload as required. @@ -95,15 +99,23 @@ async def initiate_sandbox_upload( If the sandbox does not exist in the database then the "url" and "fields" should be used to upload the sandbox to the storage backend. """ + + pfn = sandbox_metadata_db.get_pfn(settings.bucket_name, user_info, sandbox_info) + full_pfn = f"SB:{settings.se_name}|{pfn}" + await check_permissions( + action=ActionType.CREATE, sandbox_metadata_db=sandbox_metadata_db, pfns=[pfn] + ) + + # TODO: THis test should come first, but if we do + # the access policy will crash for not having been called + # so we need to find a way to ackownledge that + if sandbox_info.size > MAX_SANDBOX_SIZE_BYTES: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail=f"Sandbox too large. Max size is {MAX_SANDBOX_SIZE_BYTES} bytes", ) - pfn = sandbox_metadata_db.get_pfn(settings.bucket_name, user_info, sandbox_info) - full_pfn = f"SB:{settings.se_name}|{pfn}" - try: exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned( pfn, settings.se_name @@ -113,7 +125,7 @@ async def initiate_sandbox_upload( pass else: # As sandboxes are registered in the DB before uploading to the storage - # backend we can't on their existence in the database to determine if + # backend we can't rely on their existence in the database to determine if # they have been uploaded. Instead we check if the sandbox has been # assigned to a job. If it has then we know it has been uploaded and we # can avoid communicating with the storage backend. @@ -167,6 +179,8 @@ async def get_sandbox_file( pfn: Annotated[str, Query(max_length=256, pattern=SANDBOX_PFN_REGEX)], settings: SandboxStoreSettings, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + sandbox_metadata_db: SandboxMetadataDB, + check_permissions: CheckSandboxPolicyCallable, ) -> SandboxDownloadResponse: """Get a presigned URL to download a sandbox file @@ -176,17 +190,20 @@ async def get_sandbox_file( most storage backends return an error when they receive an authorization header for a presigned URL. 
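For illustration (all placeholders are made up, not taken from this PR), the PFN passed by the client is expected to start with a prefix of the form /S3/<bucket_name>/<vo>/<dirac_group>/<preferred_username>/ which is exactly the required_prefix built from the requesting user's identity in the code below.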
""" + pfn = pfn.split("|", 1)[-1] required_prefix = ( "/" + f"S3/{settings.bucket_name}/{user_info.vo}/{user_info.dirac_group}/{user_info.preferred_username}" + "/" ) - if not pfn.startswith(required_prefix): - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail=f"Invalid PFN. PFN must start with {required_prefix}", - ) + await check_permissions( + action=ActionType.READ, + sandbox_metadata_db=sandbox_metadata_db, + pfns=[pfn], + required_prefix=required_prefix, + ) + # TODO: Support by name and by job id? presigned_url = await settings.s3_client.generate_presigned_url( ClientMethod="get_object", @@ -202,9 +219,12 @@ async def get_sandbox_file( async def get_job_sandboxes( job_id: int, sandbox_metadata_db: SandboxMetadataDB, + job_db: JobDB, + check_permissions: CheckSandboxPolicyCallable, ) -> dict[str, list[Any]]: """Get input and output sandboxes of given job""" - # TODO: check that user as created the job or is admin + await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id]) + input_sb = await sandbox_metadata_db.get_sandbox_assigned_to_job( job_id, SandboxType.Input ) @@ -218,10 +238,13 @@ async def get_job_sandboxes( async def get_job_sandbox( job_id: int, sandbox_metadata_db: SandboxMetadataDB, + job_db: JobDB, sandbox_type: Literal["input", "output"], + check_permissions: CheckSandboxPolicyCallable, ) -> list[Any]: """Get input or output sandbox of given job""" - # TODO: check that user has created the job or is admin + + await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id]) job_sb_pfns = await sandbox_metadata_db.get_sandbox_assigned_to_job( job_id, SandboxType(sandbox_type.capitalize()) ) @@ -234,10 +257,13 @@ async def assign_sandbox_to_job( job_id: int, pfn: Annotated[str, Body(max_length=256, pattern=SANDBOX_PFN_REGEX)], sandbox_metadata_db: SandboxMetadataDB, + job_db: JobDB, settings: SandboxStoreSettings, + check_permissions: CheckSandboxPolicyCallable, ): - """Mapp the pfn as output sandbox to job""" - # TODO: check that user has created the job or is admin + """Map the pfn as output sandbox to job""" + + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) short_pfn = pfn.split("|", 1)[-1] await sandbox_metadata_db.assign_sandbox_to_jobs( jobs_ids=[job_id], @@ -251,9 +277,11 @@ async def assign_sandbox_to_job( async def unassign_job_sandboxes( job_id: int, sandbox_metadata_db: SandboxMetadataDB, + job_db: JobDB, + check_permissions: CheckSandboxPolicyCallable, ): """Delete single job sandbox mapping""" - # TODO: check that user has created the job or is admin + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) await sandbox_metadata_db.unassign_sandboxes_to_jobs([job_id]) @@ -261,7 +289,10 @@ async def unassign_job_sandboxes( async def unassign_bulk_jobs_sandboxes( jobs_ids: Annotated[list[int], Query()], sandbox_metadata_db: SandboxMetadataDB, + job_db: JobDB, + check_permissions: CheckSandboxPolicyCallable, ): """Delete bulk jobs sandbox mapping""" - # TODO: check that user has created the job or is admin + + await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=jobs_ids) await sandbox_metadata_db.unassign_sandboxes_to_jobs(jobs_ids) diff --git a/diracx-routers/src/diracx/routers/utils/__init__.py b/diracx-routers/src/diracx/routers/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/diracx-routers/src/diracx/routers/utils/users.py b/diracx-routers/src/diracx/routers/utils/users.py new file mode 
100644 index 00000000..5485fb21 --- /dev/null +++ b/diracx-routers/src/diracx/routers/utils/users.py @@ -0,0 +1,116 @@ +import re +from typing import Annotated, Any +from uuid import UUID + +from authlib.jose import JoseError, JsonWebToken +from fastapi import Depends, HTTPException, status +from fastapi.security import OpenIdConnect +from pydantic import BaseModel, Field + +from diracx.core.models import UserInfo +from diracx.core.properties import SecurityProperty +from diracx.core.settings import FernetKey, ServiceSettingsBase, TokenSigningKey +from diracx.routers.dependencies import add_settings_annotation + +# auto_error=False is used to avoid raising the wrong exception when the token is missing +# The error is handled in the verify_dirac_access_token function +# More info: +# - https://github.com/tiangolo/fastapi/issues/10177 +# - https://datatracker.ietf.org/doc/html/rfc6750#section-3.1 +oidc_scheme = OpenIdConnect( + openIdConnectUrl="/.well-known/openid-configuration", auto_error=False +) + + +class AuthInfo(BaseModel): + # raw token for propagation + bearer_token: str + + # token ID in the DB for Component + # unique jwt identifier for user + token_id: UUID + + # list of DIRAC properties + properties: list[SecurityProperty] + + policies: dict[str, Any] = {} + + +class AuthorizedUserInfo(AuthInfo, UserInfo): + pass + + +@add_settings_annotation +class AuthSettings(ServiceSettingsBase, env_prefix="DIRACX_SERVICE_AUTH_"): + """Settings for the authentication service.""" + + dirac_client_id: str = "myDIRACClientID" + # TODO: This should be taken dynamically + # ["http://pclhcb211:8000/docs/oauth2-redirect"] + allowed_redirects: list[str] = [] + device_flow_expiration_seconds: int = 600 + authorization_flow_expiration_seconds: int = 300 + + # State key is used to encrypt/decrypt the state dict passed to the IAM + state_key: FernetKey + + # TODO: this should probably be something mandatory + # to set by the user + token_issuer: str = "http://lhcbdirac.cern.ch/" + token_key: TokenSigningKey + token_algorithm: str = "RS256" + access_token_expire_minutes: int = 20 + refresh_token_expire_minutes: int = 60 + + available_properties: set[SecurityProperty] = Field( + default_factory=SecurityProperty.available_properties + ) + + +async def verify_dirac_access_token( + authorization: Annotated[str, Depends(oidc_scheme)], + settings: AuthSettings, +) -> AuthorizedUserInfo: + """Verify dirac user token and return a UserInfo class + Used for each API endpoint + """ + if not authorization: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authorization header is missing", + headers={"WWW-Authenticate": "Bearer"}, + ) + if match := re.fullmatch(r"Bearer (.+)", authorization): + raw_token = match.group(1) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid authorization header", + ) + + try: + jwt = JsonWebToken(settings.token_algorithm) + token = jwt.decode( + raw_token, + key=settings.token_key.jwk, + claims_options={ + "iss": {"values": [settings.token_issuer]}, + }, + ) + token.validate() + except JoseError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid JWT", + ) from None + + return AuthorizedUserInfo( + bearer_token=raw_token, + token_id=token["jti"], + properties=token["dirac_properties"], + sub=token["sub"], + preferred_username=token["preferred_username"], + dirac_group=token["dirac_group"], + vo=token["vo"], + policies=token.get("dirac_policies", {}), + ) diff --git 
a/diracx-routers/tests/auth/test_legacy_exchange.py b/diracx-routers/tests/auth/test_legacy_exchange.py index 0472c4cf..867823b8 100644 --- a/diracx-routers/tests/auth/test_legacy_exchange.py +++ b/diracx-routers/tests/auth/test_legacy_exchange.py @@ -9,7 +9,7 @@ DIRAC_CLIENT_ID = "myDIRACClientID" pytestmark = pytest.mark.enabled_dependencies( - ["AuthDB", "AuthSettings", "ConfigSource"] + ["AuthDB", "AuthSettings", "ConfigSource", "BaseAccessPolicy"] ) diff --git a/diracx-routers/tests/auth/test_standard.py b/diracx-routers/tests/auth/test_standard.py index 1258559d..211ddcd6 100644 --- a/diracx-routers/tests/auth/test_standard.py +++ b/diracx-routers/tests/auth/test_standard.py @@ -18,7 +18,6 @@ from diracx.core.properties import NORMAL_USER, PROXY_MANAGEMENT, SecurityProperty from diracx.routers.auth.token import create_token from diracx.routers.auth.utils import ( - AuthSettings, GrantType, _server_metadata_cache, decrypt_state, @@ -26,10 +25,11 @@ get_server_metadata, parse_and_validate_scope, ) +from diracx.routers.utils.users import AuthSettings DIRAC_CLIENT_ID = "myDIRACClientID" pytestmark = pytest.mark.enabled_dependencies( - ["AuthDB", "AuthSettings", "ConfigSource"] + ["AuthDB", "AuthSettings", "ConfigSource", "BaseAccessPolicy"] ) @@ -765,7 +765,8 @@ def _get_tokens( def _get_and_check_token_response(test_client, request_data): - """Get a token and check that mandatory fields are present""" + """Get a token and check that mandatory fields are present and that the userinfo endpoint returns + something sensible""" # Check that token request now works r = test_client.post("/api/auth/token", data=request_data) assert r.status_code == 200, r.json() @@ -775,6 +776,12 @@ def _get_and_check_token_response(test_client, request_data): assert response_data["expires_in"] assert response_data["token_type"] + r = test_client.get( + "/api/auth/userinfo", + headers={"authorization": f"Bearer {response_data['access_token']}"}, + ) + assert r.status_code == 200, r.json() + return response_data diff --git a/diracx-routers/tests/jobs/test_sandboxes.py b/diracx-routers/tests/jobs/test_sandboxes.py index ca8a74bd..2422d900 100644 --- a/diracx-routers/tests/jobs/test_sandboxes.py +++ b/diracx-routers/tests/jobs/test_sandboxes.py @@ -10,7 +10,7 @@ from fastapi.testclient import TestClient from diracx.routers.auth.token import create_token -from diracx.routers.auth.utils import AuthSettings +from diracx.routers.utils.users import AuthSettings pytestmark = pytest.mark.enabled_dependencies( [ @@ -19,6 +19,8 @@ "JobLoggingDB", "SandboxMetadataDB", "SandboxStoreSettings", + "WMSAccessPolicy", + "SandboxAccessPolicy", ] ) @@ -71,13 +73,15 @@ def test_upload_then_download( other_user_token = create_token(other_user_payload, test_auth_settings) # Make sure another user can't download the sandbox + # The fact that another user cannot download the sandbox + # is enforced at the policy level, so since in this test + # we use the AlwaysAllowAccessPolicy, it will actually work ! r = normal_user_client.get( "/api/jobs/sandbox", params={"pfn": sandbox_pfn}, headers={"Authorization": f"Bearer {other_user_token}"}, ) - assert r.status_code == 400, r.text - assert "Invalid PFN. 
PFN must start with" in r.json()["detail"] + assert r.status_code == 200, r.text def test_upload_oversized(normal_user_client: TestClient): diff --git a/diracx-routers/tests/jobs/test_wms_access_policy.py b/diracx-routers/tests/jobs/test_wms_access_policy.py new file mode 100644 index 00000000..17c12e41 --- /dev/null +++ b/diracx-routers/tests/jobs/test_wms_access_policy.py @@ -0,0 +1,333 @@ +from uuid import uuid4 + +import pytest +from fastapi import HTTPException, status + +from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER +from diracx.routers.job_manager.access_policies import ( + ActionType, + SandboxAccessPolicy, + WMSAccessPolicy, +) +from diracx.routers.utils.users import AuthorizedUserInfo + +base_payload = { + "sub": "testingVO:yellow-sub", + "preferred_username": "preferred_username", + "dirac_group": "test_group", + "vo": "lhcb", + "token_id": str(uuid4()), + "bearer_token": "my_secret_token", +} + + +class FakeDB: + async def summary(self, *args): ... + + +@pytest.fixture +def job_db(): + yield FakeDB() + + +@pytest.fixture +def sandbox_db(): + yield FakeDB() + + +WMS_POLICY_NAME = "WMSAccessPolicy_AlthoughItDoesNotMatter" +SANDBOX_POLICY_NAME = "SandboxAccessPolicy_AlthoughItDoesNotMatter" + + +async def test_wms_access_policy_weird_user(job_db): + """USer without NORMAL_USER or JOB_ADMINISTRATION can't do anything""" + weird_user = AuthorizedUserInfo(properties=[], **base_payload) + with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, weird_user, action=ActionType.CREATE, job_db=job_db + ) + + with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, weird_user, action=ActionType.QUERY, job_db=job_db + ) + + with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + weird_user, + action=ActionType.READ, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + +async def test_wms_access_policy_create(job_db): + + admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload) + normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload) + + # You can't create and give job_ids at the same time + with pytest.raises(NotImplementedError): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + normal_user, + action=ActionType.CREATE, + job_db=job_db, + job_ids=[1, 2, 3], + ) + with pytest.raises(NotImplementedError): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + admin_user, + action=ActionType.CREATE, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + # An admin cannot create any resource + with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, admin_user, action=ActionType.CREATE, job_db=job_db + ) + + # A normal user should be able to create jobs + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, normal_user, action=ActionType.CREATE, job_db=job_db + ) + + ############## + + +async def test_wms_access_policy_query(job_db): + admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload) + normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload) + + # You can't create and give job_ids at the same time + with pytest.raises(NotImplementedError): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + normal_user, + action=ActionType.QUERY, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + # this does not trigger because the admin can do anything + 
await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + admin_user, + action=ActionType.QUERY, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, normal_user, action=ActionType.QUERY, job_db=job_db + ) + + +async def test_wms_access_policy_read_modify(job_db, monkeypatch): + admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload) + normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload) + + for tested_policy in (ActionType.READ, ActionType.MANAGE): + # The admin can do anything + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + admin_user, + action=tested_policy, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + # We must give job ids + with pytest.raises(NotImplementedError): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + normal_user, + action=tested_policy, + job_db=job_db, + ) + + # Standard case, querying for one's own jobs + async def summary_matching(*args): + return [{"Owner": "preferred_username", "VO": "lhcb", "count": 3}] + + monkeypatch.setattr(job_db, "summary", summary_matching) + + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + normal_user, + action=tested_policy, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + # The admin can do anything + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + admin_user, + action=tested_policy, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + # Jobs belong to somebody else + async def summary_other_owner(*args): + return [{"Owner": "other_owner", "VO": "lhcb", "count": 3}] + + monkeypatch.setattr(job_db, "summary", summary_other_owner) + with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + normal_user, + action=tested_policy, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + # Jobs belong to a different VO + async def summary_other_vo(*args): + return [{"Owner": "preferred_username", "VO": "gridpp", "count": 3}] + + monkeypatch.setattr(job_db, "summary", summary_other_vo) + with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + normal_user, + action=tested_policy, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + # Wrong job count + async def summary_other_vo(*args): + return [{"Owner": "preferred_username", "VO": "lhcb", "count": 2}] + + monkeypatch.setattr(job_db, "summary", summary_other_vo) + with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): + await WMSAccessPolicy.policy( + WMS_POLICY_NAME, + normal_user, + action=tested_policy, + job_db=job_db, + job_ids=[1, 2, 3], + ) + + +SANDBOX_PREFIX = "/S3/bucket_name/myvo/mygroup/mypreferred_username" +USER_SANDBOX_PFN = f"{SANDBOX_PREFIX}/mysandbox.tar.gz" +OTHER_USER_SANDBOX_PFN = ( + "/S3/bucket_name/myothervo/myothergroup/myotherusername/mysandbox.tar.gz" +) + + +async def test_sandbox_access_policy_delegate_to_wms(job_db): + """We expect that the policy delegates to the WMS policy when given job info + This will trigger an AssertionError as the WMSAccessPolicy is None + in these tests + """ + normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload) + with pytest.raises(AssertionError): + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, normal_user, action=ActionType.CREATE, job_db=job_db + ) + + +async def test_sandbox_access_policy_create(sandbox_db): + + admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload) + normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload) + + # sandbox_metadata_db and 
pfns are mandatory parameters + with pytest.raises(AssertionError): + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, + normal_user, + action=ActionType.CREATE, + sandbox_metadata_db=sandbox_db, + ) + with pytest.raises(AssertionError): + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, + normal_user, + action=ActionType.CREATE, + pfns=[USER_SANDBOX_PFN], + ) + + # An admin cannot create any resource + with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, + admin_user, + action=ActionType.CREATE, + sandbox_metadata_db=sandbox_db, + pfns=[USER_SANDBOX_PFN], + ) + + # A normal user should be able to create a sandbox + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, + normal_user, + action=ActionType.CREATE, + sandbox_metadata_db=sandbox_db, + pfns=[USER_SANDBOX_PFN], + ) + + ############## + + +async def test_sandbox_access_policy_read(sandbox_db): + + admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload) + normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload) + + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, + admin_user, + action=ActionType.READ, + sandbox_metadata_db=sandbox_db, + pfns=[USER_SANDBOX_PFN], + required_prefix=SANDBOX_PREFIX, + ) + + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, + admin_user, + action=ActionType.READ, + sandbox_metadata_db=sandbox_db, + pfns=[OTHER_USER_SANDBOX_PFN], + required_prefix=SANDBOX_PREFIX, + ) + + # need required_prefix for READ + with pytest.raises(NotImplementedError): + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, + normal_user, + action=ActionType.READ, + sandbox_metadata_db=sandbox_db, + pfns=[USER_SANDBOX_PFN], + ) + + # User can act on his own sandbox + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, + normal_user, + action=ActionType.READ, + sandbox_metadata_db=sandbox_db, + pfns=[USER_SANDBOX_PFN], + required_prefix=SANDBOX_PREFIX, + ) + + # User cannot act on others + with pytest.raises(HTTPException, match=f"{status.HTTP_403_FORBIDDEN}"): + await SandboxAccessPolicy.policy( + SANDBOX_POLICY_NAME, + normal_user, + action=ActionType.READ, + sandbox_metadata_db=sandbox_db, + pfns=[OTHER_USER_SANDBOX_PFN], + required_prefix=SANDBOX_PREFIX, + ) diff --git a/diracx-routers/tests/test_config_manager.py b/diracx-routers/tests/test_config_manager.py index 42ac4932..dbd69c7e 100644 --- a/diracx-routers/tests/test_config_manager.py +++ b/diracx-routers/tests/test_config_manager.py @@ -1,7 +1,9 @@ import pytest from fastapi import status -pytestmark = pytest.mark.enabled_dependencies(["AuthSettings", "ConfigSource"]) +pytestmark = pytest.mark.enabled_dependencies( + ["AuthSettings", "ConfigSource", "OpenAccessPolicy"] +) @pytest.fixture diff --git a/diracx-routers/tests/test_generic.py b/diracx-routers/tests/test_generic.py index 2b212a74..9f31064d 100644 --- a/diracx-routers/tests/test_generic.py +++ b/diracx-routers/tests/test_generic.py @@ -1,6 +1,8 @@ import pytest -pytestmark = pytest.mark.enabled_dependencies(["ConfigSource", "AuthSettings"]) +pytestmark = pytest.mark.enabled_dependencies( + ["ConfigSource", "AuthSettings", "OpenAccessPolicy"] +) @pytest.fixture diff --git a/diracx-routers/tests/test_job_manager.py b/diracx-routers/tests/test_job_manager.py index 4adff76c..4b16ea1b 100644 --- a/diracx-routers/tests/test_job_manager.py +++ b/diracx-routers/tests/test_job_manager.py @@ -77,6 +77,7 @@ "ConfigSource", "TaskQueueDB", "SandboxMetadataDB", + 
"WMSAccessPolicy", ] ) diff --git a/diracx-routers/tests/test_policy.py b/diracx-routers/tests/test_policy.py new file mode 100644 index 00000000..bb684be0 --- /dev/null +++ b/diracx-routers/tests/test_policy.py @@ -0,0 +1,50 @@ +import inspect +from collections import defaultdict +from typing import TYPE_CHECKING + +from diracx.core.extensions import select_from_extension +from diracx.routers.access_policies import ( + BaseAccessPolicy, +) + +if TYPE_CHECKING: + from diracx.routers.fastapi_classes import DiracxRouter + + +def test_all_routes_have_policy(): + """ + Loop over all the routers, loop over every route, + and make sure there is a dependency on a BaseAccessPolicy class + + If the router is created with "require_auth=False", we skip it. + We also skip routes that have the "diracx_open_access" decorator + + """ + missing_security: defaultdict[list[str]] = defaultdict(list) + for entry_point in select_from_extension(group="diracx.services"): + router: DiracxRouter = entry_point.load() + + # If the router was created with the + # require_auth = False, skip it + if not router.diracx_require_auth: + continue + + for route in router.routes: + + # If the route is decorated with the diracx_open_access + # decorator, we skip it + if getattr(route.endpoint, "diracx_open_access", False): + continue + + for dependency in route.dependant.dependencies: + if inspect.ismethod(dependency.call) and issubclass( + dependency.call.__self__, BaseAccessPolicy + ): + # We found a dependency on check_permissions + break + else: + # We looked at all dependency without finding + # check_permission + missing_security[entry_point.name].append(route.name) + + assert not missing_security diff --git a/diracx-testing/src/diracx/testing/__init__.py b/diracx-testing/src/diracx/testing/__init__.py index aacee590..16cbe4ee 100644 --- a/diracx-testing/src/diracx/testing/__init__.py +++ b/diracx-testing/src/diracx/testing/__init__.py @@ -1,5 +1,7 @@ from __future__ import annotations +# TODO: this needs a lot of documentation, in particular what will matter for users +# are the enabled_dependencies markers import asyncio import contextlib import os @@ -16,8 +18,9 @@ import requests if TYPE_CHECKING: - from diracx.routers.auth.utils import AuthSettings from diracx.routers.job_manager.sandboxes import SandboxStoreSettings + from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings + # to get a string like this run: # openssl rand -hex 32 @@ -77,7 +80,7 @@ def fernet_key() -> str: @pytest.fixture(scope="session") def test_auth_settings(rsa_private_key_pem, fernet_key) -> AuthSettings: - from diracx.routers.auth.utils import AuthSettings + from diracx.routers.utils.users import AuthSettings yield AuthSettings( token_key=rsa_private_key_pem, @@ -132,6 +135,7 @@ def __call__(self): class ClientFactory: + def __init__( self, tmp_path_factory, @@ -144,6 +148,21 @@ def __init__( from diracx.core.settings import ServiceSettingsBase from diracx.db.sql.utils import BaseSQLDB from diracx.routers import create_app_inner + from diracx.routers.access_policies import BaseAccessPolicy + + class AlwaysAllowAccessPolicy(BaseAccessPolicy): + """ + Dummy access policy + """ + + async def policy( + policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs + ): + pass + + def enrich_tokens(access_payload: dict, refresh_payload: dict): + + return {"PolicySpecific": "OpenAccessForTest"}, {} enabled_systems = { e.name for e in select_from_extension(group="diracx.services") @@ -156,6 +175,12 @@ def __init__( 
self.test_auth_settings = test_auth_settings + all_access_policies = { + e.name: [AlwaysAllowAccessPolicy] + + BaseAccessPolicy.available_implementations(e.name) + for e in select_from_extension(group="diracx.access_policies") + } + self.app = create_app_inner( enabled_systems=enabled_systems, all_service_settings=[ @@ -169,13 +194,15 @@ config_source=ConfigSource.create_from_url( backend_url=f"git+file://{with_config_repo}" ), + all_access_policies=all_access_policies, ) self.all_dependency_overrides = self.app.dependency_overrides.copy() self.app.dependency_overrides = {} for obj in self.all_dependency_overrides: assert issubclass( - obj.__self__, (ServiceSettingsBase, BaseSQLDB, ConfigSource) + obj.__self__, + (ServiceSettingsBase, BaseSQLDB, ConfigSource, BaseAccessPolicy), ), obj self.all_lifetime_functions = self.app.lifetime_functions[:] @@ -190,9 +217,10 @@ def configure(self, enabled_dependencies): assert ( self.app.dependency_overrides == {} and self.app.lifetime_functions == [] ), "configure cannot be nested" - for k, v in self.all_dependency_overrides.items(): + class_name = k.__self__.__name__ + if class_name in enabled_dependencies: self.app.dependency_overrides[k] = v else: @@ -317,7 +345,10 @@ def admin_user(self): @pytest.fixture(scope="session") def session_client_factory( - test_auth_settings, test_sandbox_settings, with_config_repo, tmp_path_factory + test_auth_settings, + test_sandbox_settings, + with_config_repo, + tmp_path_factory, ): """ TODO diff --git a/docs/README.md b/docs/README.md index cfc5037f..f7605451 100644 --- a/docs/README.md +++ b/docs/README.md @@ -46,7 +46,7 @@ conda activate diracx-dev # Make an editable installation of diracx -pip install -e . +pip install -r requirements-dev.txt # Install the patched DIRAC version pip install git+https://github.com/DIRACGrid/DIRAC.git@integration @@ -139,6 +139,7 @@ To add a router there are two steps: 1. Create a module in `diracx.routers` for the given service. 2. Add an entry to the `diracx.services` entrypoint. +3. Do not forget the Access Policy (see the Access Policy section below) We'll now make a `/parking/` router which contains information store in the `DummyDB`. @@ -190,3 +191,54 @@ This is for advanced users only as it is currently an unstable feature When a new client generation is needed, a CI job called `client-generation` will fail, and one of the repo admin will regenerate the client for you. If you anyway want to try, the best up to date documentation is to look at the [client-generation CI job](https://github.com/DIRACGrid/diracx/blob/main/.github/workflows/main.yml) + + +## Access Policy + +Permission management in ``diracx`` is handled by ``AccessPolicy`` classes. The idea is that each policy can inject data upon token issuance, and every route will rely on a given policy to check permissions. + +The various policies are defined in `diracx-routers/pyproject.toml`: + +```toml +[project.entry-points."diracx.access_policies"] +WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy" +SandboxAccessPolicy = "diracx.routers.job_manager.access_policies:SandboxAccessPolicy" +``` + +Each route must take a policy as an argument and call it: + + +```python +from .access_policies import ActionType, CheckWMSPolicyCallable + +@router.post("/") +async def submit_bulk_jobs( + job_definitions: Annotated[list[str], Body()], + job_db: JobDB, + check_permissions: CheckWMSPolicyCallable, +) -> list[InsertedJob]: + await check_permissions(action=ActionType.CREATE, job_db=job_db) + ... 
+``` + +Failing to do so will result in a CI failure in ``test_all_routes_have_policy`` + +Some routes, such as the authorization ones, do not need access permissions; they can be marked as such: + +```python +from .access_policies import open_access + +@open_access +@router.get("/") +async def serve_config( +``` + +Implementing a new ``AccessPolicy`` is done by: +1. Create a module in `diracx.routers.access_policies.py` +2. Create a new class inheriting from ``BaseAccessPolicy`` +3. For specific instructions, see ``diracx-routers/src/diracx/routers/access_policies.py`` +4. Add an entry to the `diracx.access_policies` entrypoint. + + +> [!WARNING] +> When running tests, no permission is checked. This is to allow testing the router behavior independently of the policy behavior. For testing a policy, see for example ``diracx-routers/tests/jobs/test_wms_access_policy.py`` diff --git a/docs/TESTING.md b/docs/TESTING.md new file mode 100644 index 00000000..7171c970 --- /dev/null +++ b/docs/TESTING.md @@ -0,0 +1,10 @@ +Where we want to go, not where we are + +* each package runs unit tests in different jobs to ensure that there are no hidden dependencies: pytest and mypy +* run the integration tests (against the demo) in a single job + +For the unit tests, we start with a crude conda environment, and do a pip install of the package. + +For the integration tests, we always use the [services|tasks|client]-base image and do a pip install directly. + +Same for the unit tests (router tests use `services-base`, etc.) diff --git a/tests/make-token-local.py b/tests/make-token-local.py index 96d9bc14..bcbc4a07 100755 --- a/tests/make-token-local.py +++ b/tests/make-token-local.py @@ -8,7 +8,7 @@ from diracx.core.properties import NORMAL_USER from diracx.core.utils import write_credentials from diracx.routers.auth.token import create_token -from diracx.routers.auth.utils import AuthSettings +from diracx.routers.utils.users import AuthSettings def parse_args():
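As a complement to the "Implementing a new `AccessPolicy`" steps in docs/README.md above, here is a minimal sketch of what such a policy could look like. It mirrors the `AlwaysAllowAccessPolicy` defined in `diracx-testing` and the `CheckSandboxPolicyCallable` pattern from `access_policies.py`; the `ParkingAccessPolicy` name, its module path, and the property check are illustrative assumptions, not part of this changeset.

```python
# Hypothetical module, e.g. diracx-routers/src/diracx/routers/parking/access_policies.py
from typing import Annotated, Callable

from fastapi import Depends, HTTPException, status

from diracx.core.properties import NORMAL_USER
from diracx.routers.access_policies import BaseAccessPolicy
from diracx.routers.utils.users import AuthorizedUserInfo


class ParkingAccessPolicy(BaseAccessPolicy):
    """Illustrative policy for the /parking/ example router from docs/README.md."""

    # Defined without `self`, mirroring AlwaysAllowAccessPolicy in diracx-testing
    async def policy(policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs):
        # Deny anybody who does not carry the NORMAL_USER property;
        # a real policy would also inspect kwargs (action, db, ids, ...)
        if NORMAL_USER not in user_info.properties:
            raise HTTPException(status.HTTP_403_FORBIDDEN)


# Routes would depend on this callable and await it, which is what
# test_all_routes_have_policy checks for
CheckParkingPolicyCallable = Annotated[Callable, Depends(ParkingAccessPolicy.check)]
```

Such a policy would then be registered under the `diracx.access_policies` entry point in `diracx-routers/pyproject.toml`, following the same pattern as `WMSAccessPolicy` and `SandboxAccessPolicy`.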