Skip to content

Commit

Permalink
Refactor to address PR feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
barborico committed Sep 12, 2024
1 parent b0ab002 commit ad10c14
Showing 1 changed file with 67 additions and 87 deletions.
154 changes: 67 additions & 87 deletions api/services/okta_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def initialize(
)
self.use_group_owners_api = use_group_owners_api

def _get_sessioned_okta_request_executor(self) -> SessionedOktaRequestExecutor:
"""Establishes an Okta client session to pool connections"""
return SessionedOktaRequestExecutor(self.okta_client.get_request_executor())

@staticmethod
async def _retry(func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Any:
"""Retry Okta API requests with specific status codes using exponential backoff."""
Expand Down Expand Up @@ -88,10 +92,7 @@ async def _retry(func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Any:

def get_user(self, userId: str) -> User:
async def _get_user(userId: str) -> User:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
user, _, error = await OktaService._retry(self.okta_client.get_user, userId)

if error is not None:
Expand All @@ -105,10 +106,7 @@ async def _get_user(userId: str) -> User:

def get_user_schema(self, userTypeId: str) -> UserSchema:
async def _get_user_schema(userTypeId: str) -> UserSchema:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
userType, _, error = await OktaService._retry(self.okta_client.get_user_type, userTypeId)

if error is not None:
Expand All @@ -129,10 +127,7 @@ async def _get_user_schema(userTypeId: str) -> UserSchema:

def list_users(self) -> list[User]:
async def _list_users() -> list[User]:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
users, resp, error = await OktaService._retry(self.okta_client.list_users)

if error is not None:
Expand All @@ -150,10 +145,7 @@ async def _list_users() -> list[User]:

def create_group(self, name: str, description: str) -> Group:
async def _create_group(name: str, description: str) -> Group:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
group, _, error = await OktaService._retry(
self.okta_client.create_group,
OktaGroupType({"profile": {"name": name, "description": description}}),
Expand All @@ -170,10 +162,7 @@ async def _create_group(name: str, description: str) -> Group:

def update_group(self, groupId: str, name: str, description: str) -> Group:
async def _update_group(groupId: str, name: str, description: str) -> Group:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
group, _, error = await OktaService._retry(
self.okta_client.update_group,
groupId,
Expand All @@ -198,10 +187,7 @@ async def async_add_user_to_group(self, groupId: str, userId: str) -> None:
logger.warning(f"cannot add user with userId of {userId}")
return

request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
_, error = await OktaService._retry(self.okta_client.add_user_to_group, groupId, userId)

if error is not None:
Expand All @@ -219,10 +205,7 @@ async def async_remove_user_from_group(self, groupId: str, userId: str) -> None:
logger.warning(f"cannot remove user with userId of {userId}")
return

request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
_, error = await OktaService._retry(self.okta_client.remove_user_from_group, groupId, userId)

if error is not None:
Expand All @@ -235,10 +218,7 @@ def remove_user_from_group(self, groupId: str, userId: str) -> None:
# GET https://{yourOktaDomain}.com/api/v1/groups/<group_id>?expand=app,stats
def get_group(self, groupId: str) -> Group:
async def _get_group(groupId: str) -> Group:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
group, _, error = await OktaService._retry(self.okta_client.get_group, groupId)

if error is not None:
Expand All @@ -254,10 +234,7 @@ async def _get_group(groupId: str) -> Group:

def list_groups(self, *, query_params: dict[str, str] = DEFAULT_QUERY_PARAMS) -> list[Group]:
async def _list_groups(query_params: dict[str, str]) -> list[Group]:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
groups, resp, error = await OktaService._retry(self.okta_client.list_groups, query_params=query_params)

if error is not None:
Expand All @@ -283,10 +260,7 @@ def list_groups_with_active_rules(self) -> dict[str, list[OktaGroupRuleType]]:

def list_group_rules(self, *, query_params: dict[str, str] = {}) -> list[OktaGroupRuleType]:
async def _list_group_rules(query_params: dict[str, str]) -> list[OktaGroupRuleType]:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
group_rules, resp, error = await OktaService._retry(
self.okta_client.list_group_rules, query_params=query_params
)
Expand All @@ -306,10 +280,7 @@ async def _list_group_rules(query_params: dict[str, str]) -> list[OktaGroupRuleT

def list_users_for_group(self, groupId: str) -> list[User]:
async def _list_users_for_group(groupId: str) -> list[User]:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
users, resp, error = await OktaService._retry(self.okta_client.list_group_users, groupId)

if error is not None:
Expand All @@ -325,10 +296,7 @@ async def _list_users_for_group(groupId: str) -> list[User]:
return asyncio.run(_list_users_for_group(groupId))

async def async_delete_group(self, groupId: str) -> None:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()
async with aiohttp.ClientSession() as session:
request_executor.set_session(session)

async with self._get_sessioned_okta_request_executor() as _:
_, error = await OktaService._retry(self.okta_client.delete_group, groupId)

if error is not None:
Expand All @@ -354,21 +322,18 @@ async def async_add_owner_to_group(self, groupId: str, userId: str) -> None:
if not self.use_group_owners_api:
return

request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()

request, error = await request_executor.create_request(
method="POST",
url="/api/v1/groups/{groupId}/owners".format(groupId=groupId),
body={"id": userId, "type": "USER"},
headers={},
oauth=False,
)
async with self._get_sessioned_okta_request_executor() as request_executor:
request, error = await request_executor.create_request(
method="POST",
url="/api/v1/groups/{groupId}/owners".format(groupId=groupId),
body={"id": userId, "type": "USER"},
headers={},
oauth=False,
)

if error is not None:
raise Exception(error)
if error is not None:
raise Exception(error)

async with aiohttp.ClientSession() as session:
request_executor.set_session(session)
_, error = await OktaService._retry(request_executor.execute, request)

# Ignore error if owner is already assigned to group
Expand All @@ -391,21 +356,18 @@ async def async_remove_owner_from_group(self, groupId: str, userId: str) -> None
if not self.use_group_owners_api:
return

request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()

request, error = await request_executor.create_request(
method="DELETE",
url="/api/v1/groups/{groupId}/owners/{userId}".format(groupId=groupId, userId=userId),
body={},
headers={},
oauth=False,
)
async with self._get_sessioned_okta_request_executor() as request_executor:
request, error = await request_executor.create_request(
method="DELETE",
url="/api/v1/groups/{groupId}/owners/{userId}".format(groupId=groupId, userId=userId),
body={},
headers={},
oauth=False,
)

if error is not None:
raise Exception(error)
if error is not None:
raise Exception(error)

async with aiohttp.ClientSession() as session:
request_executor.set_session(session)
_, error = await OktaService._retry(request_executor.execute, request)

if error is not None:
Expand All @@ -422,21 +384,18 @@ def list_owners_for_group(self, groupId: str) -> list[User]:
return []

async def _list_owners_for_group(groupId: str) -> list[User]:
request_executor: OktaRequestExecutor = self.okta_client.get_request_executor()

request, error = await request_executor.create_request(
method="GET",
url="/api/v1/groups/{groupId}/owners".format(groupId=groupId),
body={},
headers={},
oauth=False,
)
async with self._get_sessioned_okta_request_executor() as request_executor:
request, error = await request_executor.create_request(
method="GET",
url="/api/v1/groups/{groupId}/owners".format(groupId=groupId),
body={},
headers={},
oauth=False,
)

if error is not None:
raise Exception(error)
if error is not None:
raise Exception(error)

async with aiohttp.ClientSession() as session:
request_executor.set_session(session)
response, error = await OktaService._retry(request_executor.execute, request, OktaUserType)

if error is not None:
Expand All @@ -451,6 +410,27 @@ async def _list_owners_for_group(groupId: str) -> list[User]:
return asyncio.run(_list_owners_for_group(groupId))


class SessionedOktaRequestExecutor:
"""
Context manager for Okta's RequestExecutor that manages an aiohttp ClientSession to enable connection pooling.
"""

def __init__(self, request_executor: OktaRequestExecutor):
self._request_executor = request_executor
self._session: Optional[aiohttp.ClientSession] = None

async def __aenter__(self) -> OktaRequestExecutor:
self._session = aiohttp.ClientSession()
self._request_executor.set_session(self._session)
return self._request_executor

async def __aexit__(self, *args: Any) -> None:
if self._session:
await self._session.close()
self._request_executor.set_session(None)
self._session = None


# Wrapper class for the Okta API user model
class User:
def __init__(self, user: OktaUserType):
Expand Down

0 comments on commit ad10c14

Please sign in to comment.