From e450256aa89fde3ffba040702690905c6e33dcad Mon Sep 17 00:00:00 2001 From: Adam Dangoor Date: Sat, 28 Sep 2024 11:02:14 +0100 Subject: [PATCH] Make target databases an Iterable --- src/mock_vws/_flask_server/target_manager.py | 27 +++++++++++--- src/mock_vws/_flask_server/vws.py | 2 +- .../mock_web_services_api.py | 36 +++++++++++++++---- src/mock_vws/database.py | 3 +- tests/mock_vws/test_requests_mock_usage.py | 4 +-- 5 files changed, 57 insertions(+), 15 deletions(-) diff --git a/src/mock_vws/_flask_server/target_manager.py b/src/mock_vws/_flask_server/target_manager.py index 83afda03..38c76efc 100644 --- a/src/mock_vws/_flask_server/target_manager.py +++ b/src/mock_vws/_flask_server/target_manager.py @@ -207,7 +207,10 @@ def create_target(database_name: str) -> Response: target_id=request_json["target_id"], target_tracking_rater=target_tracking_rater, ) - database.targets.add(target) + new_database_targets = {*database.targets, target} + new_database = dataclasses.replace(database, targets=new_database_targets) + TARGET_MANAGER.remove_database(database=database) + TARGET_MANAGER.add_database(database=new_database) return Response( response=json.dumps(obj=target.to_dict()), @@ -232,8 +235,15 @@ def delete_target(database_name: str, target_id: str) -> Response: target = database.get_target(target_id=target_id) now = datetime.datetime.now(tz=target.upload_date.tzinfo) new_target = dataclasses.replace(target, delete_date=now) - database.targets.remove(target) - database.targets.add(new_target) + new_database_targets = { + database_target + for database_target in database.targets + if database_target != target + } + new_database_targets = {*new_database_targets, new_target} + new_database = dataclasses.replace(database, targets=new_database_targets) + TARGET_MANAGER.remove_database(database=database) + TARGET_MANAGER.add_database(database=new_database) return Response( response=json.dumps(obj=new_target.to_dict()), status=HTTPStatus.OK, @@ -282,8 +292,15 @@ def update_target(database_name: str, target_id: str) -> Response: last_modified_date=last_modified_date, ) - database.targets.remove(target) - database.targets.add(new_target) + new_database_targets = { + database_target + for database_target in database.targets + if database_target != target + } + new_database_targets = {*new_database_targets, new_target} + new_database = dataclasses.replace(database, targets=new_database_targets) + TARGET_MANAGER.remove_database(database=database) + TARGET_MANAGER.add_database(database=new_database) return Response( response=json.dumps(obj=new_target.to_dict()), diff --git a/src/mock_vws/_flask_server/vws.py b/src/mock_vws/_flask_server/vws.py index 5a909cd4..57c0f692 100644 --- a/src/mock_vws/_flask_server/vws.py +++ b/src/mock_vws/_flask_server/vws.py @@ -467,7 +467,7 @@ def get_duplicates(target_id: str) -> Response: (target,) = ( target for target in database.targets if target.target_id == target_id ) - other_targets = database.targets - {target} + other_targets = set(database.targets) - {target} similar_targets = [ other.target_id diff --git a/src/mock_vws/_requests_mock_server/mock_web_services_api.py b/src/mock_vws/_requests_mock_server/mock_web_services_api.py index 07d6c719..2a8a82b0 100644 --- a/src/mock_vws/_requests_mock_server/mock_web_services_api.py +++ b/src/mock_vws/_requests_mock_server/mock_web_services_api.py @@ -181,7 +181,13 @@ def add_target(self, request: PreparedRequest) -> _ResponseType: application_metadata=application_metadata, target_tracking_rater=self._target_tracking_rater, ) - database.targets.add(new_target) + new_database_targets = {*database.targets, new_target} + new_database = dataclasses.replace( + database, + targets=new_database_targets, + ) + self._target_manager.remove_database(database=database) + self._target_manager.add_database(database=new_database) date = email.utils.formatdate( timeval=None, @@ -251,8 +257,19 @@ def delete_target(self, request: PreparedRequest) -> _ResponseType: now = datetime.datetime.now(tz=target.upload_date.tzinfo) new_target = dataclasses.replace(target, delete_date=now) - database.targets.remove(target) - database.targets.add(new_target) + new_database_targets = { + database_target + for database_target in database.targets + if database_target != target + } + new_database_targets = {*new_database_targets, new_target} + new_database = dataclasses.replace( + database, + targets=new_database_targets, + ) + self._target_manager.remove_database(database=database) + self._target_manager.add_database(database=new_database) + date = email.utils.formatdate( timeval=None, localtime=False, @@ -492,7 +509,7 @@ def get_duplicates(self, request: PreparedRequest) -> _ResponseType: target_id = request.path_url.split(sep="/")[-1] target = database.get_target(target_id=target_id) - other_targets = database.targets - {target} + other_targets = set(database.targets) - {target} similar_targets = [ other.target_id @@ -624,8 +641,15 @@ def update_target(self, request: PreparedRequest) -> _ResponseType: last_modified_date=last_modified_date, ) - database.targets.remove(target) - database.targets.add(new_target) + new_targets = { + database_target + for database_target in database.targets + if database_target != target + } + new_targets = {*new_targets, new_target} + new_database = dataclasses.replace(database, targets=new_targets) + self._target_manager.remove_database(database=database) + self._target_manager.add_database(database=new_database) body = { "result_code": ResultCodes.SUCCESS.value, diff --git a/src/mock_vws/database.py b/src/mock_vws/database.py index 153c8f58..cea93ef7 100644 --- a/src/mock_vws/database.py +++ b/src/mock_vws/database.py @@ -3,6 +3,7 @@ """ import uuid +from collections.abc import Iterable from dataclasses import dataclass, field from typing import Self, TypedDict @@ -67,7 +68,7 @@ class VuforiaDatabase: # ``frozen=True`` while still being able to keep the interface we want. # In particular, we might want to inspect the ``database`` object's targets # as they change via API requests. - targets: set[Target] = field(default_factory=set, hash=False) + targets: Iterable[Target] = field(default_factory=set, hash=False) state: States = States.WORKING request_quota: int = 100000 diff --git a/tests/mock_vws/test_requests_mock_usage.py b/tests/mock_vws/test_requests_mock_usage.py index 5c3fb892..4af31561 100644 --- a/tests/mock_vws/test_requests_mock_usage.py +++ b/tests/mock_vws/test_requests_mock_usage.py @@ -283,7 +283,7 @@ def test_to_dict(high_quality_image: io.BytesIO) -> None: application_metadata=None, ) - assert len(database.targets) == 1 + assert len(list(database.targets)) == 1 target = next(iter(database.targets)) target_dict = target.to_dict() @@ -318,7 +318,7 @@ def test_to_dict_deleted(high_quality_image: io.BytesIO) -> None: vws_client.wait_for_target_processed(target_id=target_id) vws_client.delete_target(target_id=target_id) - assert len(database.targets) == 1 + assert len(list(database.targets)) == 1 target = next(iter(database.targets)) target_dict = target.to_dict()