diff --git a/docs/features.md b/docs/features.md index 74abb3c6..91b6ff6a 100644 --- a/docs/features.md +++ b/docs/features.md @@ -397,6 +397,35 @@ async def hello_after_request(response: Response): print("This won't be executed if user isn't logged in") ``` +## Authentication + +Robyn provides an easy way to add an authentication middleware to your application. You can then specify `auth_required=True` in your routes to make them accessible only to authenticated users. + +```python +@app.get("/auth", auth_required=True) +async def auth(request: Request): + # This route method will only be executed if the user is authenticated + # Otherwise, a 401 response will be returned + return "Hello, world" +``` + +To add an authentication middleware, you can use the `configure_authentication` method. This method requires an `AuthenticationHandler` object as an argument. This object specifies how to authenticate a user, and uses a `TokenGetter` object to retrieve the token from the request. Robyn does currently provide a `BearerGetter` class that gets the token from the `Authorization` header, using the `Bearer` scheme. Here is an example of a basic authentication handler: + +```python +class BasicAuthHandler(AuthenticationHandler): + def authenticate(self, request: Request) -> Optional[Identity]: + token = self.token_getter.get_token(request) + if token == "valid": + return Identity(claims={}) + return None + +app.configure_authentication(BasicAuthHandler(token_getter=BearerGetter())) +``` + +Your `authenticate` method should return an `Identity` object if the user is authenticated, or `None` otherwise. The `Identity` object can contain any data you want, and will be accessible in your route methods using the `request.identity` attribute. + +Note that this authentication system is basically only using a "before request" middleware under the hood. This means you can overlook it and create your own authentication system using middlewares if you want to. However, Robyn still provide this easy to implement solution that should suit most use cases. + ## MultiCore Scaling To run Robyn across multiple cores, you can use the following command: diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 690f2290..235f2d90 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -2,8 +2,10 @@ import pathlib from collections import defaultdict +from typing import Optional from robyn import WS, Robyn, Request, Response, jsonify, serve_file, serve_html +from robyn.authentication import AuthenticationHandler, BearerGetter, Identity from robyn.templating import JinjaTemplate from views import SyncView, AsyncView @@ -692,6 +694,23 @@ async def async_exception_post(_: Request): raise ValueError("value error") +# ===== Authentication ===== + + +@app.get("/sync/auth", auth_required=True) +def sync_auth(request: Request): + assert request.identity is not None + assert request.identity.claims == {"key": "value"} + return "authenticated" + + +@app.get("/async/auth", auth_required=True) +async def async_auth(request: Request): + assert request.identity is not None + assert request.identity.claims == {"key": "value"} + return "authenticated" + + # ===== Main ===== @@ -706,4 +725,13 @@ async def async_exception_post(_: Request): app.add_view("/sync/view", SyncView) app.add_view("/async/view", AsyncView) app.include_router(sub_router) + + class BasicAuthHandler(AuthenticationHandler): + def authenticate(self, request: Request) -> Optional[Identity]: + token = self.token_getter.get_token(request) + if token == "valid": + return Identity(claims={"key": "value"}) + return None + + app.configure_authentication(BasicAuthHandler(token_getter=BearerGetter())) app.start(port=8080) diff --git a/integration_tests/test_authentication.py b/integration_tests/test_authentication.py new file mode 100644 index 00000000..55b77c53 --- /dev/null +++ b/integration_tests/test_authentication.py @@ -0,0 +1,42 @@ +import pytest + +from helpers.http_methods_helpers import get + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_valid_authentication(session, function_type: str): + r = get(f"/{function_type}/auth", headers={"Authorization": "Bearer valid"}) + assert r.text == "authenticated" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_invalid_authentication_token(session, function_type: str): + r = get( + f"/{function_type}/auth", + headers={"Authorization": "Bearer invalid"}, + should_check_response=False, + ) + assert r.status_code == 401 + assert r.headers["WWW-Authenticate"] == "BearerGetter" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_invalid_authentication_header(session, function_type: str): + r = get( + f"/{function_type}/auth", + headers={"Authorization": "Bear valid"}, + should_check_response=False, + ) + assert r.status_code == 401 + assert r.headers["WWW-Authenticate"] == "BearerGetter" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_invalid_authentication_no_token(session, function_type: str): + r = get(f"/{function_type}/auth", should_check_response=False) + assert r.status_code == 401 + assert r.headers["WWW-Authenticate"] == "BearerGetter" diff --git a/robyn/__init__.py b/robyn/__init__.py index c3158c41..b5f6b458 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -6,6 +6,7 @@ from nestd import get_all_nested from robyn.argument_parser import Config +from robyn.authentication import AuthenticationHandler from robyn.logger import Colors from robyn.reloader import setup_reloader from robyn.env_populator import load_vars @@ -54,20 +55,31 @@ def __init__(self, file_object: str, config: Config = Config()) -> None: self.directories: List[Directory] = [] self.event_handlers = {} self.exception_handler: Optional[Callable] = None + self.authentication_handler: Optional[AuthenticationHandler] = None def _add_route( - self, route_type: HttpMethod, endpoint: str, handler: Callable, is_const=False + self, + route_type: HttpMethod, + endpoint: str, + handler: Callable, + is_const: bool = False, + auth_required: bool = False, ): """ - This is base handler for all the decorators + This is base handler for all the route decorators - :param route_type str: route type between GET/POST/PUT/DELETE/PATCH + :param route_type str: route type between GET/POST/PUT/DELETE/PATCH/HEAD/OPTIONS/TRACE :param endpoint str: endpoint for the route added :param handler function: represents the sync or async function passed as a handler for the route + :param is_const bool: represents if the handler is a const function or not + :param auth_required bool: represents if the route needs authentication or not """ """ We will add the status code here only """ + if auth_required: + self.middleware_router.add_auth_middleware(endpoint)(handler) + return self.router.add_route( route_type, endpoint, handler, is_const, self.exception_handler ) @@ -206,7 +218,7 @@ def inner(handler): return inner - def get(self, endpoint: str, const: bool = False): + def get(self, endpoint: str, const: bool = False, auth_required: bool = False): """ The @app.get decorator to add a route with the GET method @@ -214,11 +226,13 @@ def get(self, endpoint: str, const: bool = False): """ def inner(handler): - return self._add_route(HttpMethod.GET, endpoint, handler, const) + return self._add_route( + HttpMethod.GET, endpoint, handler, const, auth_required + ) return inner - def post(self, endpoint: str): + def post(self, endpoint: str, auth_required: bool = False): """ The @app.post decorator to add a route with POST method @@ -226,11 +240,13 @@ def post(self, endpoint: str): """ def inner(handler): - return self._add_route(HttpMethod.POST, endpoint, handler) + return self._add_route( + HttpMethod.POST, endpoint, handler, auth_required=auth_required + ) return inner - def put(self, endpoint: str): + def put(self, endpoint: str, auth_required: bool = False): """ The @app.put decorator to add a get route with PUT method @@ -238,11 +254,13 @@ def put(self, endpoint: str): """ def inner(handler): - return self._add_route(HttpMethod.PUT, endpoint, handler) + return self._add_route( + HttpMethod.PUT, endpoint, handler, auth_required=auth_required + ) return inner - def delete(self, endpoint: str): + def delete(self, endpoint: str, auth_required: bool = False): """ The @app.delete decorator to add a route with DELETE method @@ -250,11 +268,13 @@ def delete(self, endpoint: str): """ def inner(handler): - return self._add_route(HttpMethod.DELETE, endpoint, handler) + return self._add_route( + HttpMethod.DELETE, endpoint, handler, auth_required=auth_required + ) return inner - def patch(self, endpoint: str): + def patch(self, endpoint: str, auth_required: bool = False): """ The @app.patch decorator to add a route with PATCH method @@ -262,11 +282,13 @@ def patch(self, endpoint: str): """ def inner(handler): - return self._add_route(HttpMethod.PATCH, endpoint, handler) + return self._add_route( + HttpMethod.PATCH, endpoint, handler, auth_required=auth_required + ) return inner - def head(self, endpoint: str): + def head(self, endpoint: str, auth_required: bool = False): """ The @app.head decorator to add a route with HEAD method @@ -274,11 +296,13 @@ def head(self, endpoint: str): """ def inner(handler): - return self._add_route(HttpMethod.HEAD, endpoint, handler) + return self._add_route( + HttpMethod.HEAD, endpoint, handler, auth_required=auth_required + ) return inner - def options(self, endpoint: str): + def options(self, endpoint: str, auth_required: bool = False): """ The @app.options decorator to add a route with OPTIONS method @@ -286,11 +310,13 @@ def options(self, endpoint: str): """ def inner(handler): - return self._add_route(HttpMethod.OPTIONS, endpoint, handler) + return self._add_route( + HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required + ) return inner - def connect(self, endpoint: str): + def connect(self, endpoint: str, auth_required: bool = False): """ The @app.connect decorator to add a route with CONNECT method @@ -298,11 +324,13 @@ def connect(self, endpoint: str): """ def inner(handler): - return self._add_route(HttpMethod.CONNECT, endpoint, handler) + return self._add_route( + HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required + ) return inner - def trace(self, endpoint: str): + def trace(self, endpoint: str, auth_required: bool = False): """ The @app.trace decorator to add a route with TRACE method @@ -310,7 +338,9 @@ def trace(self, endpoint: str): """ def inner(handler): - return self._add_route(HttpMethod.TRACE, endpoint, handler) + return self._add_route( + HttpMethod.TRACE, endpoint, handler, auth_required=auth_required + ) return inner @@ -336,6 +366,15 @@ def include_router(self, router): new_endpoint ] = router.web_socket_router.routes[route] + def configure_authentication(self, authentication_handler: AuthenticationHandler): + """ + Configures the authentication handler for the application. + + :param authentication_handler: the instance of a class inheriting the AuthenticationHandler base class + """ + self.authentication_handler = authentication_handler + self.middleware_router.set_authentication_handler(authentication_handler) + class SubRouter(Robyn): def __init__( diff --git a/robyn/authentication.py b/robyn/authentication.py new file mode 100644 index 00000000..0e110d0b --- /dev/null +++ b/robyn/authentication.py @@ -0,0 +1,91 @@ +from abc import ABC, abstractclassmethod, abstractmethod +from typing import Optional + +from robyn.robyn import Identity, Request, Response +from robyn.status_codes import HTTP_401_UNAUTHORIZED + + +class AuthenticationNotConfiguredError(Exception): + """ + This exception is raised when the authentication is not configured. + """ + + def __str__(self): + return "Authentication is not configured. Use app.configure_authentication() to configure it." + + +class TokenGetter(ABC): + @property + def scheme(self) -> str: + """ + Gets the scheme of the token. + :return: The scheme of the token. + """ + return self.__class__.__name__ + + @abstractclassmethod + def get_token(cls, request: Request) -> Optional[str]: + """ + Gets the token from the request. + This method should not decode the token. Decoding is the role of the authentication handler. + :param request: The request object. + :return: The encoded token. + """ + raise NotImplementedError() + + @abstractclassmethod + def set_token(cls, request: Request, token: str): + """ + Sets the token in the request. + This method should not encode the token. Encoding is the role of the authentication handler. + :param request: The request object. + :param token: The encoded token. + """ + raise NotImplementedError() + + +class AuthenticationHandler(ABC): + def __init__(self, token_getter: TokenGetter): + """ + Creates a new instance of the AuthenticationHandler class. + This class is an abstract class used to authenticate a user. + :param token_getter: The token getter used to get the token from the request. + """ + self.token_getter = token_getter + + @property + def unauthorized_response(self) -> Response: + return Response( + headers={"WWW-Authenticate": self.token_getter.scheme}, + body="Unauthorized", + status_code=HTTP_401_UNAUTHORIZED, + ) + + @abstractmethod + def authenticate(self, request: Request) -> Optional[Identity]: + """ + Authenticates the user. + :param request: The request object. + :return: The identity of the user. + """ + raise NotImplementedError() + + +class BearerGetter(TokenGetter): + """ + This class is used to get the token from the Authorization header. + The scheme of the header must be Bearer. + """ + + @classmethod + def get_token(cls, request: Request) -> Optional[str]: + authorization_header = request.headers.get("authorization") + + if not authorization_header or not authorization_header.startswith("Bearer "): + return None + + return authorization_header[7:] + + @classmethod + def set_token(cls, request: Request, token: str): + request.headers["Authorization"] = f"Bearer {token}" diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index 0de9d689..d332917f 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -49,6 +49,10 @@ class Url: host: str path: str +@dataclass +class Identity: + claims: dict[str, str] + @dataclass class Request: """ @@ -70,6 +74,7 @@ class Request: method: str url: Url ip_addr: Optional[str] + identity: Optional[Identity] @dataclass class Response: diff --git a/robyn/router.py b/robyn/router.py index e4d71861..cfa0fd78 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -4,8 +4,9 @@ from inspect import signature from types import CoroutineType from typing import Callable, Dict, List, NamedTuple, Union, Optional +from robyn.authentication import AuthenticationHandler, AuthenticationNotConfiguredError -from robyn.robyn import FunctionInfo, HttpMethod, MiddlewareType, Response +from robyn.robyn import FunctionInfo, HttpMethod, MiddlewareType, Request, Response from robyn import status_codes from robyn.ws import WS @@ -116,6 +117,10 @@ def __init__(self) -> None: super().__init__() self.global_middlewares: List[GlobalMiddleware] = [] self.route_middlewares: List[RouteMiddleware] = [] + self.authentication_handler: Optional[AuthenticationHandler] = None + + def set_authentication_handler(self, authentication_handler: AuthenticationHandler): + self.authentication_handler = authentication_handler def add_route( self, middleware_type: MiddlewareType, endpoint: str, handler: Callable @@ -127,6 +132,26 @@ def add_route( ) return handler + def add_auth_middleware(self, endpoint: str): + """ + This method adds an authentication middleware to the specified endpoint. + """ + + def inner(handler): + def inner_handler(request: Request, *args): + if not self.authentication_handler: + raise AuthenticationNotConfiguredError() + identity = self.authentication_handler.authenticate(request) + if identity is None: + return self.authentication_handler.unauthorized_response + request.identity = identity + return request + + self.add_route(MiddlewareType.BEFORE_REQUEST, endpoint, inner_handler) + return inner_handler + + return inner + # These inner functions are basically a wrapper around the closure(decorator) being returned. # They take a handler, convert it into a closure and return the arguments. # Arguments are returned as they could be modified by the middlewares. diff --git a/src/lib.rs b/src/lib.rs index fcf7e00a..03fad164 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ use shared_socket::SocketHeld; use pyo3::prelude::*; use types::{ function_info::{FunctionInfo, MiddlewareType}, + identity::Identity, request::PyRequest, response::PyResponse, HttpMethod, @@ -30,6 +31,7 @@ pub fn robyn(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/server.rs b/src/server.rs index d5a7f198..49a4af2d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -401,7 +401,7 @@ async fn index( } Err(e) => { error!( - "Error while executing after middleware function for endpoint `{}`: {}", + "Error while executing before middleware function for endpoint `{}`: {}", req.uri().path(), get_traceback(e.downcast_ref::().unwrap()) ); diff --git a/src/types/identity.rs b/src/types/identity.rs new file mode 100644 index 00000000..68babde6 --- /dev/null +++ b/src/types/identity.rs @@ -0,0 +1,18 @@ +use std::collections::HashMap; + +use pyo3::{pyclass, pymethods}; + +#[pyclass] +#[derive(Debug, Clone)] +pub struct Identity { + #[pyo3(get, set)] + claims: HashMap, +} + +#[pymethods] +impl Identity { + #[new] + pub fn new(claims: HashMap) -> Self { + Self { claims } + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index b1e27e15..74bbec7b 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -5,9 +5,11 @@ use pyo3::{ }; pub mod function_info; +pub mod identity; pub mod request; pub mod response; +#[allow(clippy::large_enum_variant)] pub enum MiddlewareReturn { Request(request::Request), Response(response::Response), diff --git a/src/types/request.rs b/src/types/request.rs index 335faf1f..70f41cbf 100644 --- a/src/types/request.rs +++ b/src/types/request.rs @@ -5,6 +5,8 @@ use std::collections::HashMap; use crate::types::{check_body_type, get_body_from_pyobject, Url}; +use super::identity::Identity; + #[derive(Default, Debug, Clone, FromPyObject)] pub struct Request { pub queries: HashMap, @@ -15,6 +17,7 @@ pub struct Request { pub body: Vec, pub url: Url, pub ip_addr: Option, + pub identity: Option, } impl ToPyObject for Request { @@ -35,6 +38,7 @@ impl ToPyObject for Request { method: self.method.clone(), url: self.url.clone(), ip_addr: self.ip_addr.clone(), + identity: self.identity.clone(), }; Py::new(py, request).unwrap().as_ref(py).into() } @@ -80,6 +84,7 @@ impl Request { body: body.to_vec(), url, ip_addr, + identity: None, } } } @@ -93,6 +98,8 @@ pub struct PyRequest { pub headers: Py, #[pyo3(get, set)] pub path_params: Py, + #[pyo3(get, set)] + pub identity: Option, #[pyo3(get)] pub body: Py, #[pyo3(get)]