Skip to content

Commit

Permalink
Check for refresh token absolute expiration time (#604)
Browse files Browse the repository at this point in the history
* Handle expired tokens

* Implement tests for token expiration

* Use absolute expiration time instead of expiration time

* Fix style issues

* refresh token expiration: bugfix and migration

* fix local auth bug

* don't logout if access token not expired

---------

Co-authored-by: hasan7n <[email protected]>
Co-authored-by: hasan7n <[email protected]>
  • Loading branch information
3 people committed Aug 15, 2024
1 parent 37073cf commit 6186316
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 9 deletions.
15 changes: 14 additions & 1 deletion cli/medperf/account_management/account_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,29 @@ def set_credentials(
id_token_payload,
token_issued_at,
token_expires_in,
login_event=False,
):
email = id_token_payload["email"]
TokenStore().set_tokens(email, access_token, refresh_token)
config_p = read_config()

if login_event:
# Set the time the user logged in, so that we can track the lifetime of
# the refresh token
logged_in_at = token_issued_at
else:
# This means this is a refresh event. Preserve the logged_in_at timestamp.
logged_in_at = config_p.active_profile[config.credentials_keyword][
"logged_in_at"
]

account_info = {
"email": email,
"token_issued_at": token_issued_at,
"token_expires_in": token_expires_in,
"logged_in_at": logged_in_at,
}
config_p = read_config()

config_p.active_profile[config.credentials_keyword] = account_info
write_config(config_p)

Expand Down
37 changes: 31 additions & 6 deletions cli/medperf/comms/auth/auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sqlite3
from medperf.comms.auth.interface import Auth
from medperf.comms.auth.token_verifier import verify_token
from medperf.exceptions import CommunicationError
from medperf.exceptions import CommunicationError, AuthenticationError
import requests
import medperf.config as config
from medperf.utils import log_response_error
Expand Down Expand Up @@ -66,6 +66,7 @@ def login(self, email):
id_token_payload,
token_issued_at,
token_expires_in,
login_event=True,
)

def __request_device_code(self):
Expand Down Expand Up @@ -191,11 +192,35 @@ def _access_token(self):
refresh_token = creds["refresh_token"]
token_expires_in = creds["token_expires_in"]
token_issued_at = creds["token_issued_at"]
if (
time.time()
> token_issued_at + token_expires_in - config.token_expiration_leeway
):
access_token = self.__refresh_access_token(refresh_token)
logged_in_at = creds["logged_in_at"]

# token_issued_at and expires_in are for the access token
sliding_expiration_time = (
token_issued_at + token_expires_in - config.token_expiration_leeway
)
absolute_expiration_time = (
logged_in_at
+ config.token_absolute_expiry
- config.refresh_token_expiration_leeway
)
current_time = time.time()

if current_time < sliding_expiration_time:
# Access token not expired. No need to refresh.
return access_token

# So we need to refresh.
if current_time > absolute_expiration_time:
# Expired refresh token. Force logout and ask the user to re-authenticate
logging.debug(
f"Refresh token expired: {absolute_expiration_time=} <> {current_time=}"
)
self.logout()
raise AuthenticationError("Token expired. Please re-authenticate")

# Expired access token and not expired refresh token. Refresh.
access_token = self.__refresh_access_token(refresh_token)

return access_token

def __refresh_access_token(self, refresh_token):
Expand Down
1 change: 1 addition & 0 deletions cli/medperf/comms/auth/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def login(self, email):
id_token_payload,
token_issued_at,
token_expires_in,
login_event=True,
)

def logout(self):
Expand Down
2 changes: 2 additions & 0 deletions cli/medperf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
auth_jwks_cache_ttl = 600 # fetch jwks every 10 mins. Default value in auth0 python SDK

token_expiration_leeway = 10 # Refresh tokens 10 seconds before expiration
refresh_token_expiration_leeway = 10 # Logout users 10 seconds before absolute token expiration.
token_absolute_expiry = 2592000 # Refresh token absolute expiration time (seconds). This value is set on auth0's configuration
access_token_storage_id = "medperf_access_token"
refresh_token_storage_id = "medperf_refresh_token"

Expand Down
4 changes: 4 additions & 0 deletions cli/medperf/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class ExecutionError(MedperfException):
"""Raised when an execution component fails"""


class AuthenticationError(MedperfException):
"""Raised when authentication can't be processed"""


class CleanExit(MedperfException):
"""Raised when Medperf needs to stop for non erroneous reasons"""

Expand Down
14 changes: 14 additions & 0 deletions cli/medperf/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import shutil
import time

from medperf import config
from medperf.config_management import read_config, write_config
Expand All @@ -25,6 +26,7 @@ def apply_configuration_migrations():

config_p = read_config()

# Migration for moving the logs folder to a new location
if "logs_folder" not in config_p.storage:
return

Expand All @@ -35,4 +37,16 @@ def apply_configuration_migrations():

del config_p.storage["logs_folder"]

# Migration for tracking the login timestamp (i.e., refresh token issuance timestamp)
if config.credentials_keyword in config_p.active_profile:
# So the user is logged in
if "logged_in_at" not in config_p.active_profile[config.credentials_keyword]:
# Apply migration. We will set it to the current time, since this
# will make sure they will not be logged out before the actual refresh
# token expiration (for a better user experience). However, currently logged
# in users will still face a confusing error when the refresh token expires.
config_p.active_profile[config.credentials_keyword][
"logged_in_at"
] = time.time()

write_config(config_p)
35 changes: 33 additions & 2 deletions cli/medperf/tests/comms/test_auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from unittest.mock import ANY
from medperf.tests.mocks import MockResponse
from medperf.comms.auth.auth0 import Auth0
from medperf import config
from medperf.exceptions import AuthenticationError
import sqlite3
import pytest


PATCH_AUTH = "medperf.comms.auth.auth0.{}"


Expand Down Expand Up @@ -35,6 +38,7 @@ def test_token_is_not_refreshed_if_not_expired(mocker, setup):
"access_token": "",
"token_expires_in": 900,
"token_issued_at": time.time(),
"logged_in_at": time.time(),
}
mocker.patch(PATCH_AUTH.format("read_credentials"), return_value=creds)
spy = mocker.patch(PATCH_AUTH.format("Auth0._Auth0__refresh_access_token"))
Expand All @@ -48,11 +52,14 @@ def test_token_is_not_refreshed_if_not_expired(mocker, setup):

def test_token_is_refreshed_if_expired(mocker, setup):
# Arrange
expiration_time = 900
mocked_issued_at = time.time() - expiration_time
creds = {
"refresh_token": "",
"access_token": "",
"token_expires_in": 900,
"token_issued_at": time.time() - 1000,
"token_expires_in": expiration_time,
"token_issued_at": mocked_issued_at,
"logged_in_at": time.time(),
}
mocker.patch(PATCH_AUTH.format("read_credentials"), return_value=creds)
spy = mocker.patch(PATCH_AUTH.format("Auth0._Auth0__refresh_access_token"))
Expand All @@ -64,6 +71,30 @@ def test_token_is_refreshed_if_expired(mocker, setup):
spy.assert_called_once()


def test_logs_out_if_session_reaches_token_absolute_expiration_time(mocker, setup):
# Arrange
expiration_time = 900
absolute_expiration_time = config.token_absolute_expiry
mocked_logged_in_at = time.time() - absolute_expiration_time
mocked_issued_at = time.time() - expiration_time
creds = {
"refresh_token": "",
"access_token": "",
"token_expires_in": expiration_time,
"token_issued_at": mocked_issued_at,
"logged_in_at": mocked_logged_in_at,
}
mocker.patch(PATCH_AUTH.format("read_credentials"), return_value=creds)
spy = mocker.patch(PATCH_AUTH.format("Auth0.logout"))

# Act
with pytest.raises(AuthenticationError):
Auth0().access_token

# Assert
spy.assert_called_once()


def test_refresh_token_sets_new_tokens(mocker, setup):
# Arrange
access_token = "access_token"
Expand Down

0 comments on commit 6186316

Please sign in to comment.