Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add rate limit throttling #572

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clippy.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
too-many-arguments-threshold = 8
7 changes: 7 additions & 0 deletions docs/env-variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,10 @@ To configure the max payload size, you can set the `ROBYN_MAX_PAYLOAD_SIZE` envi
ROBYN_MAX_PAYLOAD_SIZE=1000000
```

To configure the cache retention period, you can set the `ROBYN_CACHE_RETENTION` environment variable. The default value is `60` seconds.

```bash
#robyn.env
ROBYN_MAX_PAYLOAD_SIZE=60
```

22 changes: 22 additions & 0 deletions docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -603,3 +603,25 @@ def hello():

app.include_router(sub_router)
```

## Rate limiting

Robyn provides built-in rate limiting functionality, allowing you to control the frequency of requests made to the API.

By implementing rate limiting, you can manage resource consumption, prevent abuse, and maintain system stability, ensuring a smooth and reliable user experience.

In order to the rate limiter to properly identify users, a `Request` object must be defined as an input parameter for the route.

```python
from robyn import Robyn, RateLimiter

app = Robyn(__file__)

rate_limiter = RateLimiter(calls_limit=3, limit_ttl=60)

@app.get("/throttled_route", rate_limiter=rate_limiter)
def throttled_route(request: Request):
return "OK"
```

In this example the limit for `/throttled_route` is 3 calls per 1 minutes (60 seconds)
36 changes: 36 additions & 0 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from robyn import WS, Robyn, Request, Response, jsonify, serve_file, serve_html
from robyn.authentication import AuthenticationHandler, BearerGetter, Identity
from robyn.templating import JinjaTemplate
from robyn.throttling import RateLimiter

from integration_tests.views import SyncView, AsyncView
from integration_tests.subroutes import sub_router
Expand Down Expand Up @@ -711,6 +712,41 @@ async def async_auth(request: Request):
return "authenticated"


# ===== Rate Limiting ====

rate_limiter = RateLimiter(calls_limit=3, limit_ttl=60)


@app.get("/sync/rate/get", rate_limiter=rate_limiter)
def sync_rate_get(request: Request):
return "OK"


@app.get("/async/rate/get", rate_limiter=rate_limiter)
async def async_rate_get(request: Request):
return "OK"


@app.put("/sync/rate/put", rate_limiter=rate_limiter)
def sync_rate_put(request: Request):
return "OK"


@app.put("/async/rate/put", rate_limiter=rate_limiter)
async def async_rate_put(request: Request):
return "OK"


@app.post("/sync/rate/post", rate_limiter=rate_limiter)
def sync_rate_post(request: Request):
return "OK"


@app.post("/async/rate/post", rate_limiter=rate_limiter)
async def async_rate_post(request: Request):
return "OK"


# ===== Main =====


Expand Down
11 changes: 11 additions & 0 deletions integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ def test_session():
kill_process(process)


@pytest.fixture
def test_rate_limiting_session():
domain = "127.0.0.1"
port = 8082
os.environ["ROBYN_URL"] = domain
os.environ["ROBYN_PORT"] = str(port)
process = start_server(domain, port, is_dev=True)
yield
kill_process(process)


# create robyn.env before test and delete it after test
@pytest.fixture
def env_file():
Expand Down
30 changes: 22 additions & 8 deletions integration_tests/helpers/http_methods_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def get(
expected_status_code: int = 200,
headers: dict = {},
should_check_response: bool = True,
*,
base_url: str = BASE_URL,
) -> requests.Response:
"""
Makes a GET request to the given endpoint and checks the response.
Expand All @@ -32,7 +34,7 @@ def get(
should_check_response bool: A boolean to indicate if the status code and headers should be checked.
"""
endpoint = endpoint.strip("/")
response = requests.get(f"{BASE_URL}/{endpoint}", headers=headers)
response = requests.get(f"{base_url}/{endpoint}", headers=headers)
if should_check_response:
check_response(response, expected_status_code)
return response
Expand All @@ -44,6 +46,8 @@ def post(
expected_status_code: int = 200,
headers: dict = {},
should_check_response: bool = True,
*,
base_url: str = BASE_URL,
) -> requests.Response:
"""
Makes a POST request to the given endpoint and checks the response.
Expand All @@ -55,7 +59,7 @@ def post(
"""

endpoint = endpoint.strip("/")
response = requests.post(f"{BASE_URL}/{endpoint}", data=data, headers=headers)
response = requests.post(f"{base_url}/{endpoint}", data=data, headers=headers)
if should_check_response:
check_response(response, expected_status_code)
return response
Expand All @@ -67,6 +71,8 @@ def put(
expected_status_code: int = 200,
headers: dict = {},
should_check_response: bool = True,
*,
base_url: str = BASE_URL,
) -> requests.Response:
"""
Makes a PUT request to the given endpoint and checks the response.
Expand All @@ -78,7 +84,7 @@ def put(
"""

endpoint = endpoint.strip("/")
response = requests.put(f"{BASE_URL}/{endpoint}", data=data, headers=headers)
response = requests.put(f"{base_url}/{endpoint}", data=data, headers=headers)
if should_check_response:
check_response(response, expected_status_code)
return response
Expand All @@ -90,6 +96,8 @@ def patch(
expected_status_code: int = 200,
headers: dict = {},
should_check_response: bool = True,
*,
base_url: str = BASE_URL,
) -> requests.Response:
"""
Makes a PATCH request to the given endpoint and checks the response.
Expand All @@ -101,7 +109,7 @@ def patch(
"""

endpoint = endpoint.strip("/")
response = requests.patch(f"{BASE_URL}/{endpoint}", data=data, headers=headers)
response = requests.patch(f"{base_url}/{endpoint}", data=data, headers=headers)
if should_check_response:
check_response(response, expected_status_code)
return response
Expand All @@ -113,6 +121,8 @@ def delete(
expected_status_code: int = 200,
headers: dict = {},
should_check_response: bool = True,
*,
base_url: str = BASE_URL,
) -> requests.Response:
"""
Makes a DELETE request to the given endpoint and checks the response.
Expand All @@ -124,7 +134,7 @@ def delete(
"""

endpoint = endpoint.strip("/")
response = requests.delete(f"{BASE_URL}/{endpoint}", data=data, headers=headers)
response = requests.delete(f"{base_url}/{endpoint}", data=data, headers=headers)
if should_check_response:
check_response(response, expected_status_code)
return response
Expand All @@ -136,6 +146,8 @@ def head(
expected_status_code: int = 200,
headers: dict = {},
should_check_response: bool = True,
*,
base_url: str = BASE_URL,
) -> requests.Response:
"""
Makes a HEAD request to the given endpoint and checks the response.
Expand All @@ -147,7 +159,7 @@ def head(
"""

endpoint = endpoint.strip("/")
response = requests.head(f"{BASE_URL}/{endpoint}", data=data, headers=headers)
response = requests.head(f"{base_url}/{endpoint}", data=data, headers=headers)
if should_check_response:
check_response(response, expected_status_code)
return response
Expand All @@ -162,6 +174,8 @@ def generic_http_helper(
expected_status_code: int = 200,
headers: dict = {},
should_check_response: bool = True,
*,
base_url: str = BASE_URL,
) -> requests.Response:
"""
Makes a request to the given endpoint and checks the response.
Expand All @@ -178,10 +192,10 @@ def generic_http_helper(
f"{method} method must be one of get, post, put, patch, delete"
)
if method == "get":
response = requests.get(f"{BASE_URL}/{endpoint}", headers=headers)
response = requests.get(f"{base_url}/{endpoint}", headers=headers)
else:
response = requests.request(
method, f"{BASE_URL}/{endpoint}", data=data, headers=headers
method, f"{base_url}/{endpoint}", data=data, headers=headers
)
if should_check_response:
check_response(response, expected_status_code)
Expand Down
31 changes: 31 additions & 0 deletions integration_tests/test_rate_limiting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from collections.abc import Callable
import pytest
from helpers.http_methods_helpers import get, post, put


@pytest.mark.benchmark
@pytest.mark.parametrize(
"route,method",
[
("/sync/rate/get", get),
("/async/rate/get", get),
("/sync/rate/put", put),
("/async/rate/put", put),
("/sync/rate/post", post),
("/async/rate/post", post),
],
)
def test_throttling(
route: str,
method: Callable,
test_rate_limiting_session,
):
BASE_URL = "http://127.0.0.1:8082"
r = method(route, expected_status_code=200, base_url=BASE_URL)
assert r.text == "OK"
r = method(route, expected_status_code=200, base_url=BASE_URL)
assert r.text == "OK"
r = method(route, expected_status_code=200, base_url=BASE_URL)
assert r.text == "OK"
r = method(route, expected_status_code=429, base_url=BASE_URL)
assert r.text == "Rate limit exceeded"
Loading
Loading