Skip to content

Commit

Permalink
proxy & middleware update
Browse files Browse the repository at this point in the history
  • Loading branch information
Canicula98 committed Sep 13, 2024
1 parent aab5e5a commit 784409d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 76 deletions.
15 changes: 0 additions & 15 deletions examples/keycloak-auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,7 @@ def CallActiveMethod(self, request, context):

if request.method_name in self._method_names:
# Method in whitelist
logger.info("EEEEEEEEEEEEEEEEEEEEEEPAAAAAAAAAAAAA")
logger.info(metadata.get("username"))
logger.info(metadata.get("password"))
logger.info(metadata.get("dataset-name"))
#logger.info(self._role_dataset())
if metadata.get("dataset-name") in self._role_dataset:
#logger.info(self._role_dataset[metadata.get("dataset-name")])
jwt_validation(metadata.get("username"), metadata.get("password"), self._role_dataset[metadata.get("dataset-name")])
#User has the necessary role to access the dataset
return
Expand All @@ -45,10 +39,7 @@ def GetObjectAttribute(self, request, context):
for method in gets:
if method in self._method_names:
# Method in whitelist
logger.info(metadata.get("dataset-name"))
#logger.info(self._role_dataset())
if metadata.get("dataset-name") in self._role_dataset:
#logger.info(self._role_dataset[metadata.get("dataset-name")])
jwt_validation(metadata.get("username"), metadata.get("password"), self._role_dataset[metadata.get("dataset-name")])
#User has the necessary role to access the dataset
return
Expand All @@ -65,10 +56,7 @@ def SetObjectAttribute(self, request, context):
for method in sets:
if method in self._method_names:
# Method in whitelist
logger.info(metadata.get("dataset-name"))
#logger.info(self._role_dataset())
if metadata.get("dataset-name") in self._role_dataset:
#logger.info(self._role_dataset[metadata.get("dataset-name")])
jwt_validation(metadata.get("username"), metadata.get("password"), self._role_dataset[metadata.get("dataset-name")])
#User has the necessary role to access the dataset
return
Expand All @@ -85,10 +73,7 @@ def DelObjectAttribute(self, request, context):
for method in dels:
if method in self._method_names:
# Method in whitelist
logger.info(metadata.get("dataset-name"))
#logger.info(self._role_dataset())
if metadata.get("dataset-name") in self._role_dataset:
#logger.info(self._role_dataset[metadata.get("dataset-name")])
jwt_validation(metadata.get("username"), metadata.get("password"), self._role_dataset[metadata.get("dataset-name")])
#User has the necessary role to access the dataset
return
Expand Down
17 changes: 7 additions & 10 deletions src/dataclay/backend/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ async def call_active_method(
metadata = self.metadata_call + [
("dataset-name", current_context["dataset_name"]),
("username", current_context["username"]),
("authorization", current_context["token"]),
("password", current_context["password"]),
]

response = await self.stub.CallActiveMethod(request, metadata=metadata)
Expand All @@ -195,14 +195,9 @@ async def get_object_attribute(self, object_id: UUID, attribute: str) -> tuple[b
)
current_context = session_var.get()
metadata = self.metadata_call + [
("dataset-name", current_context["dataset_name"]),
("username", current_context["username"]),
("authorization", current_context["token"]),
]
response = await self.stub.GetObjectAttribute(request, metadata=metadata)
current_context = session_var.get()
metadata = self.metadata_call + [
("username", current_context["username"]),
("authorization", current_context["token"]),
("password", current_context["password"]),
]
response = await self.stub.GetObjectAttribute(request, metadata=metadata)
return response.value, response.is_exception
Expand All @@ -218,8 +213,9 @@ async def set_object_attribute(
)
current_context = session_var.get()
metadata = self.metadata_call + [
("dataset-name", current_context["dataset_name"]),
("username", current_context["username"]),
("authorization", current_context["token"]),
("password", current_context["password"]),
]
response = await self.stub.SetObjectAttribute(request, metadata=metadata)
return response.value, response.is_exception
Expand All @@ -232,8 +228,9 @@ async def del_object_attribute(self, object_id: UUID, attribute: str) -> tuple[b
)
current_context = session_var.get()
metadata = self.metadata_call + [
("dataset-name", current_context["dataset_name"]),
("username", current_context["username"]),
("authorization", current_context["token"]),
("password", current_context["password"]),
]
response = await self.stub.DelObjectAttribute(request, metadata=metadata)
return response.value, response.is_exception
Expand Down
11 changes: 1 addition & 10 deletions src/dataclay/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
settings,
)
from dataclay.event_loop import EventLoopThread, get_dc_event_loop, set_dc_event_loop
from dataclay.proxy import generate_jwt
from dataclay.runtime import ClientRuntime
from dataclay.utils.telemetry import trace

Expand Down Expand Up @@ -115,9 +114,6 @@ class Client:

is_active: bool = False

_token: bytes
_TOKEN_EXPIRATION = 24 * 30

def __init__(
self,
host: Optional[str] = None,
Expand Down Expand Up @@ -150,7 +146,6 @@ def __init__(
settings_kwargs["proxy_port"] = proxy_port
settings_kwargs["proxy_enabled"] = True

self._token = b""
self.settings = ClientSettings(**settings_kwargs)

start_telemetry()
Expand Down Expand Up @@ -201,10 +196,6 @@ def start(self):
settings.client.proxy_port,
)
self.runtime = ClientRuntime(settings.client.proxy_host, settings.client.proxy_port)
# Generate the JWT(JSON web token)
self._token = generate_jwt(
settings.client.password, settings.client.username, self._TOKEN_EXPIRATION
)
else:
self.runtime = ClientRuntime(
settings.client.dataclay_host, settings.client.dataclay_port
Expand All @@ -219,7 +210,7 @@ def start(self):
{
"dataset_name": settings.client.dataset,
"username": settings.client.username,
"token": self._token,
"password": settings.client.password,
}
)

Expand Down
75 changes: 34 additions & 41 deletions src/dataclay/proxy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,49 @@
import datetime
import logging
from uuid import UUID

import requests
import jwt
from dataclay.exceptions import DataClayException
from dataclay.proxy.middleware import MiddlewareBase, MiddlewareException

from . import servicer

logger = logging.getLogger(__name__)


def get_session(request):
raise Exception("This method should be replaced by jwt token validation in the future.")
"""Retrieve Session information from the request.session_id field of the gRPC method."""
if servicer.global_metadata_api is None:
raise SystemError("get_session is only available from within a Proxy running environment")
def jwt_validation(username, password, roles):

from base64 import b64decode
from cryptography.hazmat.primitives import serialization

try:
session_id = UUID(request.session_id)
except AttributeError:
raise ValueError("This method did not have a syntactically valid SessionID")
USER_AUTH = {
"client_id": "direct-access-demo",
"username":username,
"password":password,
"grant_type": "password",
}

return servicer.global_metadata_api.get_session(session_id)
try:
r = requests.post(
"http://keycloak:8080/realms/dataclay/protocol/openid-connect/token", data=USER_AUTH
)
r.raise_for_status()
except requests.exceptions.RequestException as e:
raise e
logger.info(r.json())
token = r.json()["access_token"]

r = requests.get("http://keycloak:8080/realms/dataclay/")
r.raise_for_status()
key_der_base64 = r.json()["public_key"]

def generate_jwt(secret_key: str = "", user: str = "dataclay", TOKEN_EXPIRATION: int = 24 * 30):
# TODO: Store the username & password in a database in order to check it later
payload = {
"username": user,
"exp": datetime.datetime.now() + datetime.timedelta(hours=TOKEN_EXPIRATION),
}
token = jwt.encode(payload, secret_key, algorithm="HS256")
return token
key_der = b64decode(key_der_base64.encode())

public_key = serialization.load_der_public_key(key_der)

def jwt_validation(username, token):
# TODO: Username & password should br required to validate the token
password = "s3cret"
try:
decoded_payload = jwt.decode(token, password, algorithms=["HS256"])
if decoded_payload.get("username") != username:
raise Exception("Wrong username")
except jwt.ExpiredSignatureError as e:
raise e
except jwt.InvalidTokenError as e:
raise e
decoded_payload = jwt.decode(token, public_key, algorithms=["RS256"])

if "realm_access" in decoded_payload:
for role in roles:
if role in decoded_payload["realm_access"]["roles"]:
return
raise MiddlewareException(f"The user '{username}' does not have the required role to access the database")


def generate_jwt(secret_key: str = "", user: str = "dataclay", TOKEN_EXPIRATION: int = 24 * 30):
# TODO: Store the username & password in a database in order to check it later
payload = {
"username": user,
"exp": datetime.datetime.now() + datetime.timedelta(hours=TOKEN_EXPIRATION),
}
token = jwt.encode(payload, secret_key, algorithm="HS256")
return token

0 comments on commit 784409d

Please sign in to comment.