Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for refresh token absolute expiration time #604

Merged
merged 10 commits into from
Aug 15, 2024
15 changes: 14 additions & 1 deletion cli/medperf/account_management/account_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,29 @@ def set_credentials(
id_token_payload,
token_issued_at,
token_expires_in,
login_event=False,
):
email = id_token_payload["email"]
TokenStore().set_tokens(email, access_token, refresh_token)
config_p = read_config()

if login_event:
# Set the time the user logged in, so that we can track the lifetime of
# the refresh token
logged_in_at = token_issued_at
else:
# This means this is a refresh event. Preserve the logged_in_at timestamp.
logged_in_at = config_p.active_profile[config.credentials_keyword][
"logged_in_at"
]

account_info = {
"email": email,
"token_issued_at": token_issued_at,
"token_expires_in": token_expires_in,
"logged_in_at": logged_in_at,
}
config_p = read_config()

config_p.active_profile[config.credentials_keyword] = account_info
write_config(config_p)

Expand Down
37 changes: 31 additions & 6 deletions cli/medperf/comms/auth/auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sqlite3
from medperf.comms.auth.interface import Auth
from medperf.comms.auth.token_verifier import verify_token
from medperf.exceptions import CommunicationError
from medperf.exceptions import CommunicationError, AuthenticationError
import requests
import medperf.config as config
from medperf.utils import log_response_error
Expand Down Expand Up @@ -66,6 +66,7 @@ def login(self, email):
id_token_payload,
token_issued_at,
token_expires_in,
login_event=True,
)

def __request_device_code(self):
Expand Down Expand Up @@ -191,11 +192,35 @@ def _access_token(self):
refresh_token = creds["refresh_token"]
token_expires_in = creds["token_expires_in"]
token_issued_at = creds["token_issued_at"]
if (
time.time()
> token_issued_at + token_expires_in - config.token_expiration_leeway
):
access_token = self.__refresh_access_token(refresh_token)
logged_in_at = creds["logged_in_at"]

# token_issued_at and expires_in are for the access token
sliding_expiration_time = (
token_issued_at + token_expires_in - config.token_expiration_leeway
)
absolute_expiration_time = (
logged_in_at
+ config.token_absolute_expiry
- config.refresh_token_expiration_leeway
)
current_time = time.time()

if current_time < sliding_expiration_time:
# Access token not expired. No need to refresh.
return access_token

# So we need to refresh.
if current_time > absolute_expiration_time:
# Expired refresh token. Force logout and ask the user to re-authenticate
logging.debug(
f"Refresh token expired: {absolute_expiration_time=} <> {current_time=}"
)
self.logout()
raise AuthenticationError("Token expired. Please re-authenticate")

# Expired access token and not expired refresh token. Refresh.
access_token = self.__refresh_access_token(refresh_token)

return access_token

def __refresh_access_token(self, refresh_token):
Expand Down
1 change: 1 addition & 0 deletions cli/medperf/comms/auth/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def login(self, email):
id_token_payload,
token_issued_at,
token_expires_in,
login_event=True,
)

def logout(self):
Expand Down
2 changes: 2 additions & 0 deletions cli/medperf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
auth_jwks_cache_ttl = 600 # fetch jwks every 10 mins. Default value in auth0 python SDK

token_expiration_leeway = 10 # Refresh tokens 10 seconds before expiration
refresh_token_expiration_leeway = 10 # Logout users 10 seconds before absolute token expiration.
token_absolute_expiry = 2592000 # Refresh token absolute expiration time (seconds). This value is set on auth0's configuration
access_token_storage_id = "medperf_access_token"
refresh_token_storage_id = "medperf_refresh_token"

Expand Down
4 changes: 4 additions & 0 deletions cli/medperf/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class ExecutionError(MedperfException):
"""Raised when an execution component fails"""


class AuthenticationError(MedperfException):
"""Raised when authentication can't be processed"""


class CleanExit(MedperfException):
"""Raised when Medperf needs to stop for non erroneous reasons"""

Expand Down
14 changes: 14 additions & 0 deletions cli/medperf/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import shutil
import time

from medperf import config
from medperf.config_management import read_config, write_config
Expand All @@ -25,6 +26,7 @@ def apply_configuration_migrations():

config_p = read_config()

# Migration for moving the logs folder to a new location
if "logs_folder" not in config_p.storage:
return

Expand All @@ -35,4 +37,16 @@ def apply_configuration_migrations():

del config_p.storage["logs_folder"]

# Migration for tracking the login timestamp (i.e., refresh token issuance timestamp)
if config.credentials_keyword in config_p.active_profile:
# So the user is logged in
if "logged_in_at" not in config_p.active_profile[config.credentials_keyword]:
# Apply migration. We will set it to the current time, since this
# will make sure they will not be logged out before the actual refresh
# token expiration (for a better user experience). However, currently logged
# in users will still face a confusing error when the refresh token expires.
config_p.active_profile[config.credentials_keyword][
"logged_in_at"
] = time.time()

write_config(config_p)
35 changes: 33 additions & 2 deletions cli/medperf/tests/comms/test_auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from unittest.mock import ANY
from medperf.tests.mocks import MockResponse
from medperf.comms.auth.auth0 import Auth0
from medperf import config
from medperf.exceptions import AuthenticationError
import sqlite3
import pytest


PATCH_AUTH = "medperf.comms.auth.auth0.{}"


Expand Down Expand Up @@ -35,6 +38,7 @@ def test_token_is_not_refreshed_if_not_expired(mocker, setup):
"access_token": "",
"token_expires_in": 900,
"token_issued_at": time.time(),
"logged_in_at": time.time(),
}
mocker.patch(PATCH_AUTH.format("read_credentials"), return_value=creds)
spy = mocker.patch(PATCH_AUTH.format("Auth0._Auth0__refresh_access_token"))
Expand All @@ -48,11 +52,14 @@ def test_token_is_not_refreshed_if_not_expired(mocker, setup):

def test_token_is_refreshed_if_expired(mocker, setup):
# Arrange
expiration_time = 900
mocked_issued_at = time.time() - expiration_time
creds = {
"refresh_token": "",
"access_token": "",
"token_expires_in": 900,
"token_issued_at": time.time() - 1000,
"token_expires_in": expiration_time,
"token_issued_at": mocked_issued_at,
"logged_in_at": time.time(),
}
mocker.patch(PATCH_AUTH.format("read_credentials"), return_value=creds)
spy = mocker.patch(PATCH_AUTH.format("Auth0._Auth0__refresh_access_token"))
Expand All @@ -64,6 +71,30 @@ def test_token_is_refreshed_if_expired(mocker, setup):
spy.assert_called_once()


def test_logs_out_if_session_reaches_token_absolute_expiration_time(mocker, setup):
# Arrange
expiration_time = 900
absolute_expiration_time = config.token_absolute_expiry
mocked_logged_in_at = time.time() - absolute_expiration_time
mocked_issued_at = time.time() - expiration_time
creds = {
"refresh_token": "",
"access_token": "",
"token_expires_in": expiration_time,
"token_issued_at": mocked_issued_at,
"logged_in_at": mocked_logged_in_at,
}
mocker.patch(PATCH_AUTH.format("read_credentials"), return_value=creds)
spy = mocker.patch(PATCH_AUTH.format("Auth0.logout"))

# Act
with pytest.raises(AuthenticationError):
Auth0().access_token

# Assert
spy.assert_called_once()


def test_refresh_token_sets_new_tokens(mocker, setup):
# Arrange
access_token = "access_token"
Expand Down
Loading