Skip to content

Commit

Permalink
Merge pull request #227 from chaen/policyMechanism
Browse files Browse the repository at this point in the history
Policy Mechanism
  • Loading branch information
chrisburr committed Jun 5, 2024
2 parents 51e0fba + 4cfe53b commit 47037a3
Show file tree
Hide file tree
Showing 29 changed files with 1,621 additions and 176 deletions.
21 changes: 12 additions & 9 deletions diracx-routers/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ description = "TODO"
readme = "README.md"
requires-python = ">=3.10"
keywords = []
license = {text = "GPL-3.0-only"}
license = { text = "GPL-3.0-only" }
classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
Expand All @@ -20,7 +20,7 @@ dependencies = [
"dirac",
"diracx-core",
"diracx-db",
"python-dotenv", # TODO: We might not need this
"python-dotenv", # TODO: We might not need this
"python-multipart",
"fastapi",
"httpx",
Expand All @@ -35,11 +35,7 @@ dependencies = [
dynamic = ["version"]

[project.optional-dependencies]
testing = [
"diracx-testing",
"moto[server]",
"pytest-httpx",
]
testing = ["diracx-testing", "moto[server]", "pytest-httpx"]
types = [
"boto3-stubs",
"types-aiobotocore[essential]",
Expand All @@ -56,6 +52,11 @@ config = "diracx.routers.configuration:router"
auth = "diracx.routers.auth:router"
".well-known" = "diracx.routers.auth.well_known:router"

[project.entry-points."diracx.access_policies"]
WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy"
SandboxAccessPolicy = "diracx.routers.job_manager.access_policies:SandboxAccessPolicy"


[tool.setuptools.packages.find]
where = ["src"]

Expand All @@ -70,8 +71,10 @@ root = ".."
testpaths = ["tests"]
addopts = [
"-v",
"--cov=diracx.routers", "--cov-report=term-missing",
"-pdiracx.testing", "-pdiracx.testing.osdb",
"--cov=diracx.routers",
"--cov-report=term-missing",
"-pdiracx.testing",
"-pdiracx.testing.osdb",
"--import-mode=importlib",
]
asyncio_mode = "auto"
Expand Down
114 changes: 110 additions & 4 deletions diracx-routers/src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""
# Startup sequence
uvicorn is called with `create_app` as a factory
create_app loads the environment configuration
"""

from __future__ import annotations

import inspect
Expand All @@ -6,7 +14,7 @@
from collections.abc import AsyncGenerator
from functools import partial
from logging import Formatter, StreamHandler
from typing import Any, Awaitable, Callable, Iterable, TypeVar, cast
from typing import Any, Awaitable, Callable, Iterable, Sequence, TypeVar, cast

import dotenv
from cachetools import TTLCache
Expand All @@ -28,10 +36,11 @@
from diracx.db.exceptions import DBUnavailable
from diracx.db.os.utils import BaseOSDB
from diracx.db.sql.utils import BaseSQLDB
from diracx.routers.access_policies import BaseAccessPolicy, check_permissions

from .auth import verify_dirac_access_token
from .fastapi_classes import DiracFastAPI, DiracxRouter
from .otel import instrument_otel
from .utils.users import verify_dirac_access_token

T = TypeVar("T")
T2 = TypeVar("T2", bound=BaseSQLDB | BaseOSDB)
Expand Down Expand Up @@ -83,6 +92,7 @@ def configure_logger():
# All routes must have tags (needed for auto gen of client)
# Form headers must have a description (autogen)
# methods name should follow the generate_unique_id_function pattern
# All routes should have a policy mechanism


def create_app_inner(
Expand All @@ -92,21 +102,83 @@ def create_app_inner(
database_urls: dict[str, str],
os_database_conn_kwargs: dict[str, Any],
config_source: ConfigSource,
all_access_policies: dict[str, Sequence[BaseAccessPolicy]],
) -> DiracFastAPI:
"""
This method does the heavy lifting work of putting all the pieces together.
When starting the application normally, this method is called by create_app,
and the values of the parameters are taken from environment variables or
entrypoints.
When running tests, the parameters are mocks or test settings.
We rely on the dependency_override mechanism to implement
the actual behavior we are interested in for settings, DBs or policy.
This allows an extension to override any of these components
:param enabled_systems:
this contains the names of all the routers we have to load
:param all_service_settings:
list of instances of each Settings type required
:param database_urls:
dict <db_name: url>. When testing, sqlite urls are used
:param os_database_conn_kwargs:
<db_name:dict> containing all the parameters the OpenSearch client takes
:param config_source:
Source of the configuration to use
:param all_access_policies:
<policy_name: [implementations]>
"""

app = DiracFastAPI()

# Find which settings classes are available and add them to dependency_overrides
# We use a single instance of each Setting classes for performance reasons,
# since it avoids recreating a pydantic model every time
# We add the Settings lifetime_function to the application lifetime_function,
# Please see ServiceSettingsBase for more details

available_settings_classes: set[type[ServiceSettingsBase]] = set()
for service_settings in all_service_settings:
cls = type(service_settings)
assert cls not in available_settings_classes
available_settings_classes.add(cls)
app.lifetime_functions.append(service_settings.lifetime_function)
# We always return the same setting instance for perf reasons
app.dependency_overrides[cls.create] = partial(lambda x: x, service_settings)

# Override the configuration source
# Override the ConfigSource.create by the actual reading of the config
app.dependency_overrides[ConfigSource.create] = config_source.read_config

all_access_policies_used = {}

for access_policy_name, access_policy_classes in all_access_policies.items():

# The first AccessPolicy is the highest priority one
access_policy_used = access_policy_classes[0].policy
all_access_policies_used[access_policy_name] = access_policy_classes[0]

# app.lifetime_functions.append(access_policy.lifetime_function)
# Add overrides for all the AccessPolicy classes, including those from extensions
# This means vanilla DiracX routers get an instance of the extension's AccessPolicy
for access_policy_class in access_policy_classes:
# Here we do not check that access_policy_class.check is
# not already in the dependency_overrides because the same
# policy could be used for multiple purposes
# (e.g. open access)
# assert access_policy_class.check not in app.dependency_overrides
app.dependency_overrides[access_policy_class.check] = partial(
check_permissions, access_policy_used, access_policy_name
)

app.dependency_overrides[BaseAccessPolicy.all_used_access_policies] = (
lambda: all_access_policies_used
)

fail_startup = True
# Add the SQL DBs to the application
available_sql_db_classes: set[type[BaseSQLDB]] = set()
Expand Down Expand Up @@ -237,7 +309,22 @@ def create_app_inner(


def create_app() -> DiracFastAPI:
"""Load settings from the environment and create the application object"""
"""Load settings from the environment and create the application object
The configuration may be placed in .env files pointed to by
environment variables DIRACX_SERVICE_DOTENV.
They can be followed by "_X" where X is a number, and the order
is respected.
We then loop over all the diracx.services definitions.
A specific route can be disabled with an environment variable
DIRACX_SERVICE_<name>_ENABLED=false
For each of the enabled routes, we inspect which Settings classes
are needed.
We attempt to load each Settings class to make sure that the
settings are correctly defined.
"""
for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"):
logger.debug("Loading dotenv file: %s", env_file)
if not dotenv.load_dotenv(env_file):
Expand All @@ -261,12 +348,31 @@ def create_app() -> DiracFastAPI:
# Load settings classes required by the routers
all_service_settings = [settings_class() for settings_class in settings_classes]

# Find all the access policies

available_access_policy_names = set(
[
entry_point.name
for entry_point in select_from_extension(group="diracx.access_policies")
]
)

all_access_policies = {}

for access_policy_name in available_access_policy_names:

access_policy_classes = BaseAccessPolicy.available_implementations(
access_policy_name
)
all_access_policies[access_policy_name] = access_policy_classes

return create_app_inner(
enabled_systems=enabled_systems,
all_service_settings=all_service_settings,
database_urls=BaseSQLDB.available_urls(),
os_database_conn_kwargs=BaseOSDB.available_urls(),
config_source=ConfigSource.create(),
all_access_policies=all_access_policies,
)


Expand Down
159 changes: 159 additions & 0 deletions diracx-routers/src/diracx/routers/access_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""
AccessPolicy
We define a set of Policy classes (WMS, DFC, etc).
They have a default implementation in diracx.
If an extension wants to change it, it can be overridden in the entry point
diracx.access_policies
Each route should either:
* have the open_access decorator to make explicit that it does not implement policy
* depend on an access policy callable and call it to perform the access check
Adding a new policy:
1. Create a class that inherits from BaseAccessPolicy and implement the ``policy`` and ``enrich_tokens`` methods
2. Create an entry in the diracx.access_policies entrypoints
3. Create a dependency such as CheckMyPolicyCallable = Annotated[Callable, Depends(MyAccessPolicy.check)]
"""

import functools
import os
import time
from abc import ABCMeta, abstractmethod
from typing import Annotated, Callable, Self

from fastapi import Depends

from diracx.core.extensions import select_from_extension
from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token

# FastAPI bug:
# We normally would use `from __future__ import annotations`
# but a bug in FastAPI prevents us from doing so
# https://github.com/tiangolo/fastapi/pull/11355
# Until it is merged, we can work around it by using strings.


class BaseAccessPolicy(metaclass=ABCMeta):
    """
    Common ancestor of every Access Policy.

    Concrete policies subclass this and provide the ``policy`` static method;
    ``enrich_tokens`` may be overridden as well.
    """

    @classmethod
    def check(cls) -> Self:
        """
        Placeholder dependency.

        It is not meant to run: it only serves as the key under which the
        real check is registered in the application's dependency overrides.
        """
        raise NotImplementedError("This should never be called")

    @classmethod
    def all_used_access_policies(cls) -> dict[str, "BaseAccessPolicy"]:
        """
        Return the policies that are actually in use, keyed by policy name
        (i.e. taking extensions into account).

        This is a placeholder which is expected to be replaced through the
        dependency overrides.
        """
        raise NotImplementedError("This should never be called")

    @classmethod
    def available_implementations(cls, access_policy_name: str):
        """
        Load every implementation registered for ``access_policy_name`` in the
        ``diracx.access_policies`` entry point group.

        The classes are returned in the order yielded by
        ``select_from_extension``; the application factory treats the first
        one as the highest priority implementation.

        :param access_policy_name: name of the entry point to look up
        :returns: non-empty list of the loaded policy classes
        :raises NotImplementedError: if no entry point matches the name
        """
        matching_entry_points = select_from_extension(
            group="diracx.access_policies", name=access_policy_name
        )
        implementations: list[type["BaseAccessPolicy"]] = []
        for entry_point in matching_entry_points:
            implementations.append(entry_point.load())

        if implementations:
            return implementations

        raise NotImplementedError(
            f"Could not find any matches for {access_policy_name=}"
        )

    @staticmethod
    @abstractmethod
    async def policy(policy_name: str, user_info: AuthorizedUserInfo, /):
        """
        The actual access check, to be provided by each child class.

        ``policy_name`` and ``user_info`` are always supplied positionally by
        check_permissions. Anything else is up to the policy, within the
        following rules:

        * This method must be static and async
        * All additional parameters must be kw only arguments
        * All additional parameters must have a default value
          (Liskov Substitution principle)

        A policy denying the access is expected to raise
        HTTPException(status.HTTP_403_FORBIDDEN)
        """
        return None

    @staticmethod
    def enrich_tokens(access_payload: dict, refresh_payload: dict) -> tuple[dict, dict]:
        """
        Hook intended to be called when a token is issued, allowing a policy
        to contribute extra content to the access or refresh payload.

        The base implementation contributes nothing.

        :param access_payload: access token payload
        :param refresh_payload: refresh token payload
        :returns: extra content for the access payload and for the refresh payload
        """
        extra_access_content: dict = {}
        extra_refresh_content: dict = {}
        return extra_access_content, extra_refresh_content


def check_permissions(
    policy: Callable,
    policy_name: str,
    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
):
    """
    Generator dependency handing the route a callable that runs ``policy``.

    The yielded callable forwards its keyword arguments to ``policy``, after
    ``policy_name`` and ``user_info``, and records that it has been invoked.
    When the generator is finalised without the callable ever having been
    invoked, the whole process is terminated on purpose: it is violent, but
    it guarantees that a route forgetting its permission check gets noticed.

    This function is never called directly; it is registered (through
    ``functools.partial``) in the dependency overrides at startup.

    :param policy: the policy implementation to call
    :param policy_name: name under which the policy is registered
    :param user_info: information about the user, resolved from the access token
    """
    policy_was_invoked = False

    def _abort_process():
        """Report the missing permission check and kill the process."""
        # TODO nice error message with inspect
        # That should really not happen
        print(
            "THIS SHOULD NOT HAPPEN, ALWAYS VERIFY PERMISSION",
            "(PS: I hope you are in a CI)",
            flush=True,
        )
        # Give the flushed message a moment to get out before the hard exit
        time.sleep(1)
        # os._exit terminates immediately, without running cleanup handlers
        os._exit(1)

    @functools.wraps(policy)
    async def wrapped_policy(**kwargs):
        """Run the policy and remember that it has been run"""
        nonlocal policy_was_invoked
        policy_was_invoked = True
        return await policy(policy_name, user_info, **kwargs)

    try:
        yield wrapped_policy
    finally:
        # Runs on normal completion as well as when an exception went through
        if not policy_was_invoked:
            _abort_process()


def open_access(f):
    """
    Decorator for the routes of a DiracxRouter that are expected
    not to perform any access policy check.

    The presence of a token will still be checked if the router has require_auth to True.
    Marking such routes explicitly allows the CI to detect routes which
    may have forgotten to have an access check.

    :param f: the route function to mark
    :returns: a wrapper that forwards every call to ``f`` unchanged
    """
    # Flag the original callable first: the marker lives on ``f`` itself and
    # is then also copied onto the wrapper together with the rest of __dict__
    f.diracx_open_access = True

    def inner(*args, **kwargs):
        return f(*args, **kwargs)

    return functools.update_wrapper(inner, f)
Loading

0 comments on commit 47037a3

Please sign in to comment.