From 784409d47243b4053c0af6eebe9f944c04c79382 Mon Sep 17 00:00:00 2001 From: Canicula98 Date: Fri, 13 Sep 2024 10:26:17 +0200 Subject: [PATCH] proxy & middleware update --- examples/keycloak-auth/middleware.py | 15 ------ src/dataclay/backend/client.py | 17 +++---- src/dataclay/client/api.py | 11 +--- src/dataclay/proxy/__init__.py | 75 +++++++++++++--------------- 4 files changed, 42 insertions(+), 76 deletions(-) diff --git a/examples/keycloak-auth/middleware.py b/examples/keycloak-auth/middleware.py index 122db17..17c9b22 100644 --- a/examples/keycloak-auth/middleware.py +++ b/examples/keycloak-auth/middleware.py @@ -22,13 +22,7 @@ def CallActiveMethod(self, request, context): if request.method_name in self._method_names: # Method in whitelist - logger.info("EEEEEEEEEEEEEEEEEEEEEEPAAAAAAAAAAAAA") - logger.info(metadata.get("username")) - logger.info(metadata.get("password")) - logger.info(metadata.get("dataset-name")) - #logger.info(self._role_dataset()) if metadata.get("dataset-name") in self._role_dataset: - #logger.info(self._role_dataset[metadata.get("dataset-name")]) jwt_validation(metadata.get("username"), metadata.get("password"), self._role_dataset[metadata.get("dataset-name")]) #User has the necessary role to access the dataset return @@ -45,10 +39,7 @@ def GetObjectAttribute(self, request, context): for method in gets: if method in self._method_names: # Method in whitelist - logger.info(metadata.get("dataset-name")) - #logger.info(self._role_dataset()) if metadata.get("dataset-name") in self._role_dataset: - #logger.info(self._role_dataset[metadata.get("dataset-name")]) jwt_validation(metadata.get("username"), metadata.get("password"), self._role_dataset[metadata.get("dataset-name")]) #User has the necessary role to access the dataset return @@ -65,10 +56,7 @@ def SetObjectAttribute(self, request, context): for method in sets: if method in self._method_names: # Method in whitelist - logger.info(metadata.get("dataset-name")) - #logger.info(self._role_dataset()) if metadata.get("dataset-name") in self._role_dataset: - #logger.info(self._role_dataset[metadata.get("dataset-name")]) jwt_validation(metadata.get("username"), metadata.get("password"), self._role_dataset[metadata.get("dataset-name")]) #User has the necessary role to access the dataset return @@ -85,10 +73,7 @@ def DelObjectAttribute(self, request, context): for method in dels: if method in self._method_names: # Method in whitelist - logger.info(metadata.get("dataset-name")) - #logger.info(self._role_dataset()) if metadata.get("dataset-name") in self._role_dataset: - #logger.info(self._role_dataset[metadata.get("dataset-name")]) jwt_validation(metadata.get("username"), metadata.get("password"), self._role_dataset[metadata.get("dataset-name")]) #User has the necessary role to access the dataset return diff --git a/src/dataclay/backend/client.py b/src/dataclay/backend/client.py index 963652f..b03e7ac 100644 --- a/src/dataclay/backend/client.py +++ b/src/dataclay/backend/client.py @@ -177,7 +177,7 @@ async def call_active_method( metadata = self.metadata_call + [ ("dataset-name", current_context["dataset_name"]), ("username", current_context["username"]), - ("authorization", current_context["token"]), + ("password", current_context["password"]), ] response = await self.stub.CallActiveMethod(request, metadata=metadata) @@ -195,14 +195,9 @@ async def get_object_attribute(self, object_id: UUID, attribute: str) -> tuple[b ) current_context = session_var.get() metadata = self.metadata_call + [ + ("dataset-name", current_context["dataset_name"]), ("username", current_context["username"]), - ("authorization", current_context["token"]), - ] - response = await self.stub.GetObjectAttribute(request, metadata=metadata) - current_context = session_var.get() - metadata = self.metadata_call + [ - ("username", current_context["username"]), - ("authorization", current_context["token"]), + ("password", current_context["password"]), ] response = await self.stub.GetObjectAttribute(request, metadata=metadata) return response.value, response.is_exception @@ -218,8 +213,9 @@ async def set_object_attribute( ) current_context = session_var.get() metadata = self.metadata_call + [ + ("dataset-name", current_context["dataset_name"]), ("username", current_context["username"]), - ("authorization", current_context["token"]), + ("password", current_context["password"]), ] response = await self.stub.SetObjectAttribute(request, metadata=metadata) return response.value, response.is_exception @@ -232,8 +228,9 @@ async def del_object_attribute(self, object_id: UUID, attribute: str) -> tuple[b ) current_context = session_var.get() metadata = self.metadata_call + [ + ("dataset-name", current_context["dataset_name"]), ("username", current_context["username"]), - ("authorization", current_context["token"]), + ("password", current_context["password"]), ] response = await self.stub.DelObjectAttribute(request, metadata=metadata) return response.value, response.is_exception diff --git a/src/dataclay/client/api.py b/src/dataclay/client/api.py index ff21246..375e163 100644 --- a/src/dataclay/client/api.py +++ b/src/dataclay/client/api.py @@ -23,7 +23,6 @@ settings, ) from dataclay.event_loop import EventLoopThread, get_dc_event_loop, set_dc_event_loop -from dataclay.proxy import generate_jwt from dataclay.runtime import ClientRuntime from dataclay.utils.telemetry import trace @@ -115,9 +114,6 @@ class Client: is_active: bool = False - _token: bytes - _TOKEN_EXPIRATION = 24 * 30 - def __init__( self, host: Optional[str] = None, @@ -150,7 +146,6 @@ def __init__( settings_kwargs["proxy_port"] = proxy_port settings_kwargs["proxy_enabled"] = True - self._token = b"" self.settings = ClientSettings(**settings_kwargs) start_telemetry() @@ -201,10 +196,6 @@ def start(self): settings.client.proxy_port, ) self.runtime = ClientRuntime(settings.client.proxy_host, settings.client.proxy_port) - # Generate the JWT(JSON web token) - self._token = generate_jwt( - settings.client.password, settings.client.username, self._TOKEN_EXPIRATION - ) else: self.runtime = ClientRuntime( settings.client.dataclay_host, settings.client.dataclay_port @@ -219,7 +210,7 @@ def start(self): { "dataset_name": settings.client.dataset, "username": settings.client.username, - "token": self._token, + "password": settings.client.password, } ) diff --git a/src/dataclay/proxy/__init__.py b/src/dataclay/proxy/__init__.py index 7a1fa52..b8398f1 100644 --- a/src/dataclay/proxy/__init__.py +++ b/src/dataclay/proxy/__init__.py @@ -1,56 +1,49 @@ -import datetime import logging -from uuid import UUID - +import requests import jwt +from dataclay.exceptions import DataClayException +from dataclay.proxy.middleware import MiddlewareBase, MiddlewareException -from . import servicer logger = logging.getLogger(__name__) -def get_session(request): - raise Exception("This method should be replaced by jwt token validation in the future.") - """Retrieve Session information from the request.session_id field of the gRPC method.""" - if servicer.global_metadata_api is None: - raise SystemError("get_session is only available from within a Proxy running environment") +def jwt_validation(username, password, roles): + + from base64 import b64decode + from cryptography.hazmat.primitives import serialization - try: - session_id = UUID(request.session_id) - except AttributeError: - raise ValueError("This method did not have a syntactically valid SessionID") + USER_AUTH = { + "client_id": "direct-access-demo", + "username":username, + "password":password, + "grant_type": "password", + } - return servicer.global_metadata_api.get_session(session_id) + try: + r = requests.post( + "http://keycloak:8080/realms/dataclay/protocol/openid-connect/token", data=USER_AUTH + ) + r.raise_for_status() + except requests.exceptions.RequestException as e: + raise e + logger.info(r.json()) + token = r.json()["access_token"] + r = requests.get("http://keycloak:8080/realms/dataclay/") + r.raise_for_status() + key_der_base64 = r.json()["public_key"] -def generate_jwt(secret_key: str = "", user: str = "dataclay", TOKEN_EXPIRATION: int = 24 * 30): - # TODO: Store the username & password in a database in order to check it later - payload = { - "username": user, - "exp": datetime.datetime.now() + datetime.timedelta(hours=TOKEN_EXPIRATION), - } - token = jwt.encode(payload, secret_key, algorithm="HS256") - return token + key_der = b64decode(key_der_base64.encode()) + public_key = serialization.load_der_public_key(key_der) -def jwt_validation(username, token): - # TODO: Username & password should br required to validate the token - password = "s3cret" - try: - decoded_payload = jwt.decode(token, password, algorithms=["HS256"]) - if decoded_payload.get("username") != username: - raise Exception("Wrong username") - except jwt.ExpiredSignatureError as e: - raise e - except jwt.InvalidTokenError as e: - raise e + decoded_payload = jwt.decode(token, public_key, algorithms=["RS256"]) + if "realm_access" in decoded_payload: + for role in roles: + if role in decoded_payload["realm_access"]["roles"]: + return + raise MiddlewareException(f"The user '{username}' does not have the required role to access the database") + -def generate_jwt(secret_key: str = "", user: str = "dataclay", TOKEN_EXPIRATION: int = 24 * 30): - # TODO: Store the username & password in a database in order to check it later - payload = { - "username": user, - "exp": datetime.datetime.now() + datetime.timedelta(hours=TOKEN_EXPIRATION), - } - token = jwt.encode(payload, secret_key, algorithm="HS256") - return token