Skip to content

Commit

Permalink
feat: add authentication support (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoineRR committed Jun 28, 2023
1 parent 4d7c6c2 commit 625da80
Show file tree
Hide file tree
Showing 12 changed files with 311 additions and 23 deletions.
29 changes: 29 additions & 0 deletions docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,35 @@ async def hello_after_request(response: Response):
print("This won't be executed if user isn't logged in")
```

## Authentication

Robyn provides an easy way to add an authentication middleware to your application. You can then specify `auth_required=True` in your routes to make them accessible only to authenticated users.

```python
@app.get("/auth", auth_required=True)
async def auth(request: Request):
# This route method will only be executed if the user is authenticated
# Otherwise, a 401 response will be returned
return "Hello, world"
```

To add an authentication middleware, you can use the `configure_authentication` method. This method requires an `AuthenticationHandler` object as an argument. This object specifies how to authenticate a user, and uses a `TokenGetter` object to retrieve the token from the request. Robyn does currently provide a `BearerGetter` class that gets the token from the `Authorization` header, using the `Bearer` scheme. Here is an example of a basic authentication handler:

```python
class BasicAuthHandler(AuthenticationHandler):
def authenticate(self, request: Request) -> Optional[Identity]:
token = self.token_getter.get_token(request)
if token == "valid":
return Identity(claims={})
return None

app.configure_authentication(BasicAuthHandler(token_getter=BearerGetter()))
```

Your `authenticate` method should return an `Identity` object if the user is authenticated, or `None` otherwise. The `Identity` object can contain any data you want, and will be accessible in your route methods using the `request.identity` attribute.

Note that this authentication system is basically only using a "before request" middleware under the hood. This means you can overlook it and create your own authentication system using middlewares if you want to. However, Robyn still provide this easy to implement solution that should suit most use cases.

## MultiCore Scaling

To run Robyn across multiple cores, you can use the following command:
Expand Down
28 changes: 28 additions & 0 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import pathlib
from collections import defaultdict
from typing import Optional

from robyn import WS, Robyn, Request, Response, jsonify, serve_file, serve_html
from robyn.authentication import AuthenticationHandler, BearerGetter, Identity
from robyn.templating import JinjaTemplate

from views import SyncView, AsyncView
Expand Down Expand Up @@ -692,6 +694,23 @@ async def async_exception_post(_: Request):
raise ValueError("value error")


# ===== Authentication =====


@app.get("/sync/auth", auth_required=True)
def sync_auth(request: Request):
assert request.identity is not None
assert request.identity.claims == {"key": "value"}
return "authenticated"


@app.get("/async/auth", auth_required=True)
async def async_auth(request: Request):
assert request.identity is not None
assert request.identity.claims == {"key": "value"}
return "authenticated"


# ===== Main =====


Expand All @@ -706,4 +725,13 @@ async def async_exception_post(_: Request):
app.add_view("/sync/view", SyncView)
app.add_view("/async/view", AsyncView)
app.include_router(sub_router)

class BasicAuthHandler(AuthenticationHandler):
def authenticate(self, request: Request) -> Optional[Identity]:
token = self.token_getter.get_token(request)
if token == "valid":
return Identity(claims={"key": "value"})
return None

app.configure_authentication(BasicAuthHandler(token_getter=BearerGetter()))
app.start(port=8080)
42 changes: 42 additions & 0 deletions integration_tests/test_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from helpers.http_methods_helpers import get


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_valid_authentication(session, function_type: str):
r = get(f"/{function_type}/auth", headers={"Authorization": "Bearer valid"})
assert r.text == "authenticated"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_token(session, function_type: str):
r = get(
f"/{function_type}/auth",
headers={"Authorization": "Bearer invalid"},
should_check_response=False,
)
assert r.status_code == 401
assert r.headers["WWW-Authenticate"] == "BearerGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_header(session, function_type: str):
r = get(
f"/{function_type}/auth",
headers={"Authorization": "Bear valid"},
should_check_response=False,
)
assert r.status_code == 401
assert r.headers["WWW-Authenticate"] == "BearerGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_no_token(session, function_type: str):
r = get(f"/{function_type}/auth", should_check_response=False)
assert r.status_code == 401
assert r.headers["WWW-Authenticate"] == "BearerGetter"
81 changes: 60 additions & 21 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from nestd import get_all_nested

from robyn.argument_parser import Config
from robyn.authentication import AuthenticationHandler
from robyn.logger import Colors
from robyn.reloader import setup_reloader
from robyn.env_populator import load_vars
Expand Down Expand Up @@ -54,20 +55,31 @@ def __init__(self, file_object: str, config: Config = Config()) -> None:
self.directories: List[Directory] = []
self.event_handlers = {}
self.exception_handler: Optional[Callable] = None
self.authentication_handler: Optional[AuthenticationHandler] = None

def _add_route(
self, route_type: HttpMethod, endpoint: str, handler: Callable, is_const=False
self,
route_type: HttpMethod,
endpoint: str,
handler: Callable,
is_const: bool = False,
auth_required: bool = False,
):
"""
This is base handler for all the decorators
This is base handler for all the route decorators
:param route_type str: route type between GET/POST/PUT/DELETE/PATCH
:param route_type str: route type between GET/POST/PUT/DELETE/PATCH/HEAD/OPTIONS/TRACE
:param endpoint str: endpoint for the route added
:param handler function: represents the sync or async function passed as a handler for the route
:param is_const bool: represents if the handler is a const function or not
:param auth_required bool: represents if the route needs authentication or not
"""

""" We will add the status code here only
"""
if auth_required:
self.middleware_router.add_auth_middleware(endpoint)(handler)

return self.router.add_route(
route_type, endpoint, handler, is_const, self.exception_handler
)
Expand Down Expand Up @@ -206,111 +218,129 @@ def inner(handler):

return inner

def get(self, endpoint: str, const: bool = False):
def get(self, endpoint: str, const: bool = False, auth_required: bool = False):
"""
The @app.get decorator to add a route with the GET method
:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self._add_route(HttpMethod.GET, endpoint, handler, const)
return self._add_route(
HttpMethod.GET, endpoint, handler, const, auth_required
)

return inner

def post(self, endpoint: str):
def post(self, endpoint: str, auth_required: bool = False):
"""
The @app.post decorator to add a route with POST method
:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self._add_route(HttpMethod.POST, endpoint, handler)
return self._add_route(
HttpMethod.POST, endpoint, handler, auth_required=auth_required
)

return inner

def put(self, endpoint: str):
def put(self, endpoint: str, auth_required: bool = False):
"""
The @app.put decorator to add a get route with PUT method
:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self._add_route(HttpMethod.PUT, endpoint, handler)
return self._add_route(
HttpMethod.PUT, endpoint, handler, auth_required=auth_required
)

return inner

def delete(self, endpoint: str):
def delete(self, endpoint: str, auth_required: bool = False):
"""
The @app.delete decorator to add a route with DELETE method
:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self._add_route(HttpMethod.DELETE, endpoint, handler)
return self._add_route(
HttpMethod.DELETE, endpoint, handler, auth_required=auth_required
)

return inner

def patch(self, endpoint: str):
def patch(self, endpoint: str, auth_required: bool = False):
"""
The @app.patch decorator to add a route with PATCH method
:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
return self._add_route(HttpMethod.PATCH, endpoint, handler)
return self._add_route(
HttpMethod.PATCH, endpoint, handler, auth_required=auth_required
)

return inner

def head(self, endpoint: str):
def head(self, endpoint: str, auth_required: bool = False):
"""
The @app.head decorator to add a route with HEAD method
:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self._add_route(HttpMethod.HEAD, endpoint, handler)
return self._add_route(
HttpMethod.HEAD, endpoint, handler, auth_required=auth_required
)

return inner

def options(self, endpoint: str):
def options(self, endpoint: str, auth_required: bool = False):
"""
The @app.options decorator to add a route with OPTIONS method
:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self._add_route(HttpMethod.OPTIONS, endpoint, handler)
return self._add_route(
HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required
)

return inner

def connect(self, endpoint: str):
def connect(self, endpoint: str, auth_required: bool = False):
"""
The @app.connect decorator to add a route with CONNECT method
:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self._add_route(HttpMethod.CONNECT, endpoint, handler)
return self._add_route(
HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required
)

return inner

def trace(self, endpoint: str):
def trace(self, endpoint: str, auth_required: bool = False):
"""
The @app.trace decorator to add a route with TRACE method
:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self._add_route(HttpMethod.TRACE, endpoint, handler)
return self._add_route(
HttpMethod.TRACE, endpoint, handler, auth_required=auth_required
)

return inner

Expand All @@ -336,6 +366,15 @@ def include_router(self, router):
new_endpoint
] = router.web_socket_router.routes[route]

def configure_authentication(self, authentication_handler: AuthenticationHandler):
"""
Configures the authentication handler for the application.
:param authentication_handler: the instance of a class inheriting the AuthenticationHandler base class
"""
self.authentication_handler = authentication_handler
self.middleware_router.set_authentication_handler(authentication_handler)


class SubRouter(Robyn):
def __init__(
Expand Down
Loading

0 comments on commit 625da80

Please sign in to comment.