diff --git a/api/services/okta_service.py b/api/services/okta_service.py index ce807af..d5a9bb6 100644 --- a/api/services/okta_service.py +++ b/api/services/okta_service.py @@ -46,6 +46,10 @@ def initialize( ) self.use_group_owners_api = use_group_owners_api + def _get_sessioned_okta_request_executor(self) -> SessionedOktaRequestExecutor: + """Establishes an Okta client session to pool connections""" + return SessionedOktaRequestExecutor(self.okta_client.get_request_executor()) + @staticmethod async def _retry(func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Any: """Retry Okta API requests with specific status codes using exponential backoff.""" @@ -88,10 +92,7 @@ async def _retry(func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Any: def get_user(self, userId: str) -> User: async def _get_user(userId: str) -> User: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: user, _, error = await OktaService._retry(self.okta_client.get_user, userId) if error is not None: @@ -105,10 +106,7 @@ async def _get_user(userId: str) -> User: def get_user_schema(self, userTypeId: str) -> UserSchema: async def _get_user_schema(userTypeId: str) -> UserSchema: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: userType, _, error = await OktaService._retry(self.okta_client.get_user_type, userTypeId) if error is not None: @@ -129,10 +127,7 @@ async def _get_user_schema(userTypeId: str) -> UserSchema: def list_users(self) -> list[User]: async def _list_users() -> list[User]: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: users, resp, error = await OktaService._retry(self.okta_client.list_users) if error is not None: @@ -150,10 +145,7 @@ async def _list_users() -> list[User]: def create_group(self, name: str, description: str) -> Group: async def _create_group(name: str, description: str) -> Group: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: group, _, error = await OktaService._retry( self.okta_client.create_group, OktaGroupType({"profile": {"name": name, "description": description}}), @@ -170,10 +162,7 @@ async def _create_group(name: str, description: str) -> Group: def update_group(self, groupId: str, name: str, description: str) -> Group: async def _update_group(groupId: str, name: str, description: str) -> Group: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: group, _, error = await OktaService._retry( self.okta_client.update_group, groupId, @@ -198,10 +187,7 @@ async def async_add_user_to_group(self, groupId: str, userId: str) -> None: logger.warning(f"cannot add user with userId of {userId}") return - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: _, error = await OktaService._retry(self.okta_client.add_user_to_group, groupId, userId) if error is not None: @@ -219,10 +205,7 @@ async def async_remove_user_from_group(self, groupId: str, userId: str) -> None: logger.warning(f"cannot remove user with userId of {userId}") return - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: _, error = await OktaService._retry(self.okta_client.remove_user_from_group, groupId, userId) if error is not None: @@ -235,10 +218,7 @@ def remove_user_from_group(self, groupId: str, userId: str) -> None: # GET https://{yourOktaDomain}.com/api/v1/groups/?expand=app,stats def get_group(self, groupId: str) -> Group: async def _get_group(groupId: str) -> Group: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: group, _, error = await OktaService._retry(self.okta_client.get_group, groupId) if error is not None: @@ -254,10 +234,7 @@ async def _get_group(groupId: str) -> Group: def list_groups(self, *, query_params: dict[str, str] = DEFAULT_QUERY_PARAMS) -> list[Group]: async def _list_groups(query_params: dict[str, str]) -> list[Group]: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: groups, resp, error = await OktaService._retry(self.okta_client.list_groups, query_params=query_params) if error is not None: @@ -283,10 +260,7 @@ def list_groups_with_active_rules(self) -> dict[str, list[OktaGroupRuleType]]: def list_group_rules(self, *, query_params: dict[str, str] = {}) -> list[OktaGroupRuleType]: async def _list_group_rules(query_params: dict[str, str]) -> list[OktaGroupRuleType]: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: group_rules, resp, error = await OktaService._retry( self.okta_client.list_group_rules, query_params=query_params ) @@ -306,10 +280,7 @@ async def _list_group_rules(query_params: dict[str, str]) -> list[OktaGroupRuleT def list_users_for_group(self, groupId: str) -> list[User]: async def _list_users_for_group(groupId: str) -> list[User]: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: users, resp, error = await OktaService._retry(self.okta_client.list_group_users, groupId) if error is not None: @@ -325,10 +296,7 @@ async def _list_users_for_group(groupId: str) -> list[User]: return asyncio.run(_list_users_for_group(groupId)) async def async_delete_group(self, groupId: str) -> None: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) - + async with self._get_sessioned_okta_request_executor() as _: _, error = await OktaService._retry(self.okta_client.delete_group, groupId) if error is not None: @@ -354,21 +322,18 @@ async def async_add_owner_to_group(self, groupId: str, userId: str) -> None: if not self.use_group_owners_api: return - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - - request, error = await request_executor.create_request( - method="POST", - url="/api/v1/groups/{groupId}/owners".format(groupId=groupId), - body={"id": userId, "type": "USER"}, - headers={}, - oauth=False, - ) + async with self._get_sessioned_okta_request_executor() as request_executor: + request, error = await request_executor.create_request( + method="POST", + url="/api/v1/groups/{groupId}/owners".format(groupId=groupId), + body={"id": userId, "type": "USER"}, + headers={}, + oauth=False, + ) - if error is not None: - raise Exception(error) + if error is not None: + raise Exception(error) - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) _, error = await OktaService._retry(request_executor.execute, request) # Ignore error if owner is already assigned to group @@ -391,21 +356,18 @@ async def async_remove_owner_from_group(self, groupId: str, userId: str) -> None if not self.use_group_owners_api: return - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - - request, error = await request_executor.create_request( - method="DELETE", - url="/api/v1/groups/{groupId}/owners/{userId}".format(groupId=groupId, userId=userId), - body={}, - headers={}, - oauth=False, - ) + async with self._get_sessioned_okta_request_executor() as request_executor: + request, error = await request_executor.create_request( + method="DELETE", + url="/api/v1/groups/{groupId}/owners/{userId}".format(groupId=groupId, userId=userId), + body={}, + headers={}, + oauth=False, + ) - if error is not None: - raise Exception(error) + if error is not None: + raise Exception(error) - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) _, error = await OktaService._retry(request_executor.execute, request) if error is not None: @@ -422,21 +384,18 @@ def list_owners_for_group(self, groupId: str) -> list[User]: return [] async def _list_owners_for_group(groupId: str) -> list[User]: - request_executor: OktaRequestExecutor = self.okta_client.get_request_executor() - - request, error = await request_executor.create_request( - method="GET", - url="/api/v1/groups/{groupId}/owners".format(groupId=groupId), - body={}, - headers={}, - oauth=False, - ) + async with self._get_sessioned_okta_request_executor() as request_executor: + request, error = await request_executor.create_request( + method="GET", + url="/api/v1/groups/{groupId}/owners".format(groupId=groupId), + body={}, + headers={}, + oauth=False, + ) - if error is not None: - raise Exception(error) + if error is not None: + raise Exception(error) - async with aiohttp.ClientSession() as session: - request_executor.set_session(session) response, error = await OktaService._retry(request_executor.execute, request, OktaUserType) if error is not None: @@ -451,6 +410,27 @@ async def _list_owners_for_group(groupId: str) -> list[User]: return asyncio.run(_list_owners_for_group(groupId)) +class SessionedOktaRequestExecutor: + """ + Context manager for Okta's RequestExecutor that manages an aiohttp ClientSession to enable connection pooling. + """ + + def __init__(self, request_executor: OktaRequestExecutor): + self._request_executor = request_executor + self._session: Optional[aiohttp.ClientSession] = None + + async def __aenter__(self) -> OktaRequestExecutor: + self._session = aiohttp.ClientSession() + self._request_executor.set_session(self._session) + return self._request_executor + + async def __aexit__(self, *args: Any) -> None: + if self._session: + await self._session.close() + self._request_executor.set_session(None) + self._session = None + + # Wrapper class for the Okta API user model class User: def __init__(self, user: OktaUserType):