Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pass request to AXES_COOLOFF_TIME callback #1222

Merged
11 changes: 6 additions & 5 deletions axes/attempts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
log = getLogger(__name__)


def get_cool_off_threshold(attempt_time: Optional[datetime] = None) -> datetime:
def get_cool_off_threshold(request: Optional[HttpRequest] = None) -> datetime:
browniebroke marked this conversation as resolved.
Show resolved Hide resolved
"""
Get threshold for fetching access attempts from the database.
"""

cool_off = get_cool_off()
cool_off = get_cool_off(request)
if cool_off is None:
raise TypeError(
"Cool off threshold can not be calculated with settings.AXES_COOLOFF_TIME set to None"
)

attempt_time = request.axes_attempt_time
if attempt_time is None:
return now() - cool_off
return attempt_time - cool_off
Expand Down Expand Up @@ -62,12 +63,12 @@ def get_user_attempts(
)
return attempts_list

threshold = get_cool_off_threshold(request.axes_attempt_time)
threshold = get_cool_off_threshold(request)
log.debug("AXES: Getting access attempts that are newer than %s", threshold)
return [attempts.filter(attempt_time__gte=threshold) for attempts in attempts_list]


def clean_expired_user_attempts(attempt_time: Optional[datetime] = None) -> int:
def clean_expired_user_attempts(request: Optional[HttpRequest] = None) -> int:
browniebroke marked this conversation as resolved.
Show resolved Hide resolved
"""
Clean expired user attempts from the database.
"""
Expand All @@ -78,7 +79,7 @@ def clean_expired_user_attempts(attempt_time: Optional[datetime] = None) -> int:
)
return 0

threshold = get_cool_off_threshold(attempt_time)
threshold = get_cool_off_threshold(request)
count, _ = AccessAttempt.objects.filter(attempt_time__lt=threshold).delete()
log.info(
"AXES: Cleaned up %s expired access attempts from database that were older than %s",
Expand Down
2 changes: 1 addition & 1 deletion axes/handlers/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def user_login_failed(self, sender, credentials: dict, request=None, **kwargs):
return

cache_keys = get_client_cache_keys(request, credentials)
cache_timeout = get_cache_timeout()
cache_timeout = get_cache_timeout(request)
failures = []
for cache_key in cache_keys:
added = self.cache.add(key=cache_key, value=1, timeout=cache_timeout)
Expand Down
6 changes: 3 additions & 3 deletions axes/handlers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def user_login_failed(self, sender, credentials: dict, request=None, **kwargs):
return

# 1. database query: Clean up expired user attempts from the database before logging new attempts
clean_expired_user_attempts(request.axes_attempt_time)
clean_expired_user_attempts(request)

username = get_client_username(request, credentials)
client_str = get_client_str(
Expand Down Expand Up @@ -262,7 +262,7 @@ def user_logged_in(self, sender, request, user, **kwargs):
"""

# 1. database query: Clean up expired user attempts from the database
clean_expired_user_attempts(request.axes_attempt_time)
clean_expired_user_attempts(request)

username = user.get_username()
credentials = get_credentials(username)
Expand Down Expand Up @@ -305,7 +305,7 @@ def user_logged_out(self, sender, request, user, **kwargs):
"""

# 1. database query: Clean up expired user attempts from the database
clean_expired_user_attempts(request.axes_attempt_time)
clean_expired_user_attempts(request)

username = user.get_username() if user else None
client_str = get_client_str(
Expand Down
20 changes: 11 additions & 9 deletions axes/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_cache() -> BaseCache:
return caches[getattr(settings, "AXES_CACHE", "default")]


def get_cache_timeout() -> Optional[int]:
def get_cache_timeout(request: Optional[HttpRequest] = None) -> Optional[int]:
"""
Return the cache timeout interpreted from settings.AXES_COOLOFF_TIME.

Expand All @@ -43,21 +43,22 @@ def get_cache_timeout() -> Optional[int]:
for use with the Django cache backends.
"""

cool_off = get_cool_off()
cool_off = get_cool_off(request)
if cool_off is None:
return None
return int(cool_off.total_seconds())


def get_cool_off() -> Optional[timedelta]:
def get_cool_off(request: Optional[HttpRequest] = None) -> Optional[timedelta]:
"""
Return the login cool off time interpreted from settings.AXES_COOLOFF_TIME.

The return value is either None or timedelta.

Notice that the settings.AXES_COOLOFF_TIME is either None, timedelta, or integer/float of hours,
and this function offers a unified _timedelta or None_ representation of that configuration
for use with the Axes internal implementations.
Notice that the settings.AXES_COOLOFF_TIME is either None, timedelta, integer/float of hours,
a path to a callable or a callable taking 1 argument (the request). This function
offers a unified _timedelta or None_ representation of that configuration for use with the
Axes internal implementations.

:exception TypeError: if settings.AXES_COOLOFF_TIME is of wrong type.
"""
Expand All @@ -69,9 +70,10 @@ def get_cool_off() -> Optional[timedelta]:
if isinstance(cool_off, float):
return timedelta(minutes=cool_off * 60)
if isinstance(cool_off, str):
return import_string(cool_off)()
cool_off_func = import_string(cool_off)
return cool_off_func(request)
if callable(cool_off):
return cool_off() # pylint: disable=not-callable
return cool_off(request) # pylint: disable=not-callable

return cool_off

Expand Down Expand Up @@ -462,7 +464,7 @@ def get_lockout_response(
"username": get_client_username(request, credentials) or "",
}

cool_off = get_cool_off()
cool_off = get_cool_off(request)
if cool_off:
context.update(
{
Expand Down
2 changes: 1 addition & 1 deletion docs/4_configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ The following ``settings.py`` options are available for customizing Axes behavio
+------------------------------------------------------+----------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| AXES_LOCK_OUT_AT_FAILURE | True | After the number of allowed login attempts are exceeded, should we lock out this IP (and optional user agent)? |
+------------------------------------------------------+----------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| AXES_COOLOFF_TIME | None | If set, defines a period of inactivity after which old failed login attempts will be cleared. Can be set to a Python timedelta object, an integer, a float, a callable, or a string path to a callable which takes no arguments. If an integer or float, will be interpreted as a number of hours: ``AXES_COOLOFF_TIME = 2`` 2 hours, ``AXES_COOLOFF_TIME = 2.0`` 2 hours, 120 minutes, ``AXES_COOLOFF_TIME = 1.7`` 1.7 hours, 102 minutes, 6120 seconds |
| AXES_COOLOFF_TIME | None | If set, defines a period of inactivity after which old failed login attempts will be cleared. Can be set to a Python timedelta object, an integer, a float, a callable, or a string path to a callable which takes the request as argument. If an integer or float, will be interpreted as a number of hours: ``AXES_COOLOFF_TIME = 2`` 2 hours, ``AXES_COOLOFF_TIME = 2.0`` 2 hours, 120 minutes, ``AXES_COOLOFF_TIME = 1.7`` 1.7 hours, 102 minutes, 6120 seconds |
+------------------------------------------------------+----------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| AXES_ONLY_ADMIN_SITE | False | If ``True``, lock is only enabled for admin site. Admin site is determined by checking request path against the path of ``"admin:index"`` view. If admin urls are not registered in current urlconf, all requests will not be locked. |
+------------------------------------------------------+----------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
Expand Down
11 changes: 6 additions & 5 deletions tests/test_attempts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import patch

from django.http import HttpRequest
from django.test import override_settings
from django.test import override_settings, RequestFactory
from django.utils.timezone import now

from axes.attempts import get_cool_off_threshold
Expand All @@ -15,12 +15,13 @@ class GetCoolOffThresholdTestCase(AxesTestCase):
def test_get_cool_off_threshold(self):
timestamp = now()

request = RequestFactory().post("/")
with patch("axes.attempts.now", return_value=timestamp):
attempt_time = timestamp
threshold_now = get_cool_off_threshold(attempt_time)
request.axes_attempt_time = timestamp
threshold_now = get_cool_off_threshold(request)

attempt_time = None
threshold_none = get_cool_off_threshold(attempt_time)
request.axes_attempt_time = None
threshold_none = get_cool_off_threshold(request)

self.assertEqual(threshold_now, threshold_none)

Expand Down
36 changes: 34 additions & 2 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,38 @@ def test_get_cache_timeout_timedelta(self):
def test_get_cache_timeout_none(self):
self.assertEqual(get_cache_timeout(), None)

def test_get_increasing_cache_timeout_by_username(self):
user_durations = {
"ben": timedelta(minutes=5),
"jen": timedelta(minutes=10),
}

def _callback(request):
username = request.POST["username"] if request else object()
previous_duration = user_durations.get(username, timedelta())
user_durations[username] = previous_duration + timedelta(minutes=5)
return user_durations[username]

rf = RequestFactory()
ben_req = rf.post("/", data={"username": "ben"})
jen_req = rf.post("/", data={"username": "jen"})
james_req = rf.post("/", data={"username": "james"})

with override_settings(AXES_COOLOFF_TIME=_callback):
with self.subTest("no username"):
self.assertEqual(get_cache_timeout(), 300)

with self.subTest("ben"):
self.assertEqual(get_cache_timeout(ben_req), 600)
self.assertEqual(get_cache_timeout(ben_req), 900)
self.assertEqual(get_cache_timeout(ben_req), 1200)

with self.subTest("jen"):
self.assertEqual(get_cache_timeout(jen_req), 900)

with self.subTest("james"):
self.assertEqual(get_cache_timeout(james_req), 300)


class TimestampTestCase(AxesTestCase):
def test_iso8601(self):
Expand Down Expand Up @@ -915,7 +947,7 @@ def test_get_lockout_response_lockout_response(self):
self.assertEqual(type(response), HttpResponse)


def mock_get_cool_off_str():
def mock_get_cool_off_str(req):
return timedelta(seconds=30)


Expand All @@ -940,7 +972,7 @@ def test_get_cool_off_float_lt_0(self):
def test_get_cool_off_float_gt_0(self):
self.assertEqual(get_cool_off(), timedelta(seconds=6120))

@override_settings(AXES_COOLOFF_TIME=lambda: timedelta(seconds=30))
@override_settings(AXES_COOLOFF_TIME=lambda r: timedelta(seconds=30))
def test_get_cool_off_callable(self):
self.assertEqual(get_cool_off(), timedelta(seconds=30))

Expand Down
Loading