Skip to content

Commit

Permalink
save before flight (@chrisburr, if my plane crash, remember to fight …
Browse files Browse the repository at this point in the history
…for it to be recognized an occupational accident :-) )
  • Loading branch information
chaen committed Mar 29, 2024
1 parent f928602 commit 7b81cad
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 89 deletions.
2 changes: 2 additions & 0 deletions diracx-routers/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ auth = "diracx.routers.auth:router"
".well-known" = "diracx.routers.auth.well_known:router"

[project.entry-points."diracx.access_policies"]
# OpenAccessPolicy = "diracx.routers.access_policies:OpenAccessPolicy"
WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy"


[tool.setuptools.packages.find]
where = ["src"]

Expand Down
80 changes: 77 additions & 3 deletions diracx-routers/src/diracx/routers/access_policies.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,38 @@
"""
AccessPolicy
We define a set of Policy classes (WMS, DFC, etc).
They have a default implementation in diracx.
If an extension wants to change it, it can be overwriten in the entry point
diracx.access_policies
Each route should either:
* have the open_access decorator to make explicit that it does not implement policy
* have a callable and call it that will perform the access policy
"""

import functools
import os
from typing import Annotated, Self
from abc import ABCMeta, abstractmethod
from typing import Annotated, Callable, Self

from fastapi import Depends

from diracx.core.extensions import select_from_extension
from diracx.routers.auth import AuthorizedUserInfo, verify_dirac_access_token


class BaseAccessPolicy:
class BaseAccessPolicy(metaclass=ABCMeta):
"""
Base class to be used by all the other Access Policy.
Each child class should implement the policy staticmethod.
"""

@classmethod
def check(cls) -> Self:
raise NotImplementedError("This should never be called")
Expand All @@ -28,11 +52,29 @@ def available_implementations(cls, access_policy_name: str):
)
return policy_classes

@staticmethod
@abstractmethod
async def policy(user_info: AuthorizedUserInfo, /, **kwargs):
"""
This is the method to be implemented in child classes.
It should always take an AuthorizedUserInfo parameter, which
is passed by check_permissions.
The rest is whatever the policy actually needs
This method must be static and async
"""
pass


def check_permissions(
policy,
policy: Callable,
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
):
"""
This wrapper just calls the actual implementation, but also makes sure
that the policy has been called.
If not, diracx will abruptly crash. It is violent, but necessary to make
sure that it gets noticed :-)
"""

has_been_called = False

Expand All @@ -55,3 +97,35 @@ async def wrapped_policy(**kwargs):
flush=True,
)
os._exit(1)


def open_access(f):
"""
Decorator to put around the route that are part of a DiracxRouter
that are expected not to do any access policy check.
The presence of a token will still be checked if the router has require_auth to True.
This is useful to allow the CI to detect routes which may have forgotten
to have an access check
"""
f.diracx_open_access = True

@functools.wraps(f)
def inner(*args, **kwargs):
return f(*args, **kwargs)

return inner


# class OpenAccessPolicy(BaseAccessPolicy):
# """Open access, always allowed"""

# @classmethod
# def check(cls):
# pass

# @staticmethod
# async def policy(user_info: AuthorizedUserInfo, /):
# pass


# OpenAccessPolicyCallable = Annotated[Callable, Depends(OpenAccessPolicy.check)]
10 changes: 9 additions & 1 deletion diracx-routers/src/diracx/routers/auth/well_known.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from ..fastapi_classes import DiracxRouter
from .utils import AuthSettings

# from ..access_policies import OpenAccessPolicyCallable

router = DiracxRouter(require_auth=False, path_root="")


Expand All @@ -16,8 +18,10 @@ async def openid_configuration(
request: Request,
config: Config,
settings: AuthSettings,
# check_permissions: OpenAccessPolicyCallable,
):
"""OpenID Connect discovery endpoint."""
# await check_permissions()
scopes_supported = []
for vo in config.Registry:
scopes_supported.append(f"vo:{vo}")
Expand Down Expand Up @@ -65,8 +69,12 @@ class Metadata(TypedDict):


@router.get("/dirac-metadata")
async def installation_metadata(config: Config) -> Metadata:
async def installation_metadata(
config: Config,
# check_permissions: OpenAccessPolicyCallable,
) -> Metadata:
"""Get metadata about the dirac installation."""
# await check_permissions()
metadata: Metadata = {
"virtual_organizations": {},
}
Expand Down
5 changes: 5 additions & 0 deletions diracx-routers/src/diracx/routers/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
status,
)

# from .access_policies import OpenAccessPolicyCallable
from .access_policies import open_access
from .dependencies import Config
from .fastapi_classes import DiracxRouter

Expand All @@ -18,10 +20,12 @@
router = DiracxRouter()


@open_access
@router.get("/")
async def serve_config(
config: Config,
response: Response,
# check_permissions: OpenAccessPolicyCallable,
if_none_match: Annotated[str | None, Header()] = None,
if_modified_since: Annotated[str | None, Header()] = None,
):
Expand All @@ -34,6 +38,7 @@ async def serve_config(
If If-Modified-Since is given and is newer than latest,
return 304: this is to avoid flip/flopping
"""
# await check_permissions()
headers = {
"ETag": config._hexsha,
"Last-Modified": config._modified.strftime(LAST_MODIFIED_FORMAT),
Expand Down
40 changes: 20 additions & 20 deletions diracx-routers/src/diracx/routers/job_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ..auth import AuthorizedUserInfo, verify_dirac_access_token
from ..dependencies import JobDB, JobLoggingDB, SandboxMetadataDB, TaskQueueDB
from ..fastapi_classes import DiracxRouter
from .access_policies import ActionType, CheckPermissionsCallable
from .access_policies import ActionType, CheckWMSPolicyCallable
from .sandboxes import router as sandboxes_router

MAX_PARAMETRIC_JOBS = 20
Expand Down Expand Up @@ -116,7 +116,7 @@ async def submit_bulk_jobs(
job_db: JobDB,
job_logging_db: JobLoggingDB,
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
) -> list[InsertedJob]:
await check_permissions(action=ActionType.CREATE, job_db=job_db)

Expand Down Expand Up @@ -253,7 +253,7 @@ async def delete_bulk_jobs(
job_logging_db: JobLoggingDB,
task_queue_db: TaskQueueDB,
background_task: BackgroundTasks,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):

await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
Expand Down Expand Up @@ -291,7 +291,7 @@ async def kill_bulk_jobs(
job_logging_db: JobLoggingDB,
task_queue_db: TaskQueueDB,
background_task: BackgroundTasks,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
# TODO: implement job policy
Expand Down Expand Up @@ -328,7 +328,7 @@ async def remove_bulk_jobs(
sandbox_metadata_db: SandboxMetadataDB,
task_queue_db: TaskQueueDB,
background_task: BackgroundTasks,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):
"""
Fully remove a list of jobs from the WMS databases.
Expand Down Expand Up @@ -362,7 +362,7 @@ async def remove_bulk_jobs(
async def get_job_status_bulk(
job_ids: Annotated[list[int], Query()],
job_db: JobDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
) -> dict[int, LimitedJobStatusReturn]:
await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=job_ids)
try:
Expand All @@ -379,7 +379,7 @@ async def set_job_status_bulk(
job_update: dict[int, dict[datetime, JobStatusUpdate]],
job_db: JobDB,
job_logging_db: JobLoggingDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
force: bool = False,
) -> dict[int, SetJobStatusReturn]:
await check_permissions(
Expand Down Expand Up @@ -408,7 +408,7 @@ async def get_job_status_history_bulk(
job_ids: Annotated[list[int], Query()],
job_logging_db: JobLoggingDB,
job_db: JobDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
) -> dict[int, list[JobStatusReturn]]:
await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=job_ids)
result = await asyncio.gather(
Expand All @@ -422,7 +422,7 @@ async def reschedule_bulk_jobs(
job_ids: Annotated[list[int], Query()],
job_db: JobDB,
job_logging_db: JobLoggingDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids)
rescheduled_jobs = []
Expand Down Expand Up @@ -473,7 +473,7 @@ async def reschedule_bulk_jobs(
async def reschedule_single_job(
job_id: int,
job_db: JobDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
try:
Expand Down Expand Up @@ -545,7 +545,7 @@ async def search(
config: Annotated[Config, Depends(ConfigSource.create)],
job_db: JobDB,
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
page: int = 0,
per_page: int = 100,
body: Annotated[
Expand Down Expand Up @@ -585,7 +585,7 @@ async def summary(
job_db: JobDB,
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
body: JobSummaryParams,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):
"""Show information suitable for plotting"""
await check_permissions(action=ActionType.QUERY, job_db=job_db)
Expand All @@ -605,7 +605,7 @@ async def summary(
async def get_single_job(
job_id: int,
job_db: JobDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):
await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
return f"This job {job_id}"
Expand All @@ -619,7 +619,7 @@ async def delete_single_job(
job_logging_db: JobLoggingDB,
task_queue_db: TaskQueueDB,
background_task: BackgroundTasks,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):
"""
Delete a job by killing and setting the job status to DELETED.
Expand Down Expand Up @@ -652,7 +652,7 @@ async def kill_single_job(
job_logging_db: JobLoggingDB,
task_queue_db: TaskQueueDB,
background_task: BackgroundTasks,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):
"""
Kill a job.
Expand Down Expand Up @@ -682,7 +682,7 @@ async def remove_single_job(
sandbox_metadata_db: SandboxMetadataDB,
task_queue_db: TaskQueueDB,
background_task: BackgroundTasks,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
):
"""
Fully remove a job from the WMS databases.
Expand Down Expand Up @@ -714,7 +714,7 @@ async def remove_single_job(
async def get_single_job_status(
job_id: int,
job_db: JobDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
) -> dict[int, LimitedJobStatusReturn]:
await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
try:
Expand All @@ -732,7 +732,7 @@ async def set_single_job_status(
status: Annotated[dict[datetime, JobStatusUpdate], Body()],
job_db: JobDB,
job_logging_db: JobLoggingDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
force: bool = False,
) -> dict[int, SetJobStatusReturn]:
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
Expand All @@ -758,7 +758,7 @@ async def get_single_job_status_history(
job_id: int,
job_db: JobDB,
job_logging_db: JobLoggingDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
) -> dict[int, list[JobStatusReturn]]:
await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
try:
Expand All @@ -775,7 +775,7 @@ async def set_single_job_properties(
job_id: int,
job_properties: Annotated[dict[str, Any], Body()],
job_db: JobDB,
check_permissions: CheckPermissionsCallable,
check_permissions: CheckWMSPolicyCallable,
update_timestamp: bool = False,
):
"""
Expand Down
Loading

0 comments on commit 7b81cad

Please sign in to comment.