diff --git a/astrapy/db.py b/astrapy/db.py index 7360e407..f99b6540 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -104,6 +104,13 @@ def __init__( caller_name=caller_name, caller_version=caller_version, ) + else: + # if astra_db passed, copy and apply possible overrides + astra_db = astra_db.copy( + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + ) # Set the remaining instance attributes self.astra_db = astra_db @@ -128,12 +135,31 @@ def __eq__(self, other: Any) -> bool: else: return False - def copy(self) -> AstraDBCollection: + def copy( + self, + *, + collection_name: Optional[str] = None, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AstraDBCollection: return AstraDBCollection( - collection_name=self.collection_name, - astra_db=self.astra_db.copy(), - caller_name=self.caller_name, - caller_version=self.caller_version, + collection_name=collection_name or self.collection_name, + astra_db=self.astra_db.copy( + token=token, + api_endpoint=api_endpoint, + api_path=api_path, + api_version=api_version, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ), + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, ) def to_async(self) -> AsyncAstraDBCollection: @@ -1092,6 +1118,13 @@ def __init__( caller_name=caller_name, caller_version=caller_version, ) + else: + # if astra_db passed, copy and apply possible overrides + astra_db = astra_db.copy( + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + ) # Set the remaining instance attributes self.astra_db: AsyncAstraDB = astra_db @@ -1117,12 +1150,31 @@ def __eq__(self, other: Any) -> bool: else: return False - def copy(self) -> AsyncAstraDBCollection: + def copy( + self, + *, + collection_name: Optional[str] = None, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AsyncAstraDBCollection: return AsyncAstraDBCollection( - collection_name=self.collection_name, - astra_db=self.astra_db.copy(), - caller_name=self.caller_name, - caller_version=self.caller_version, + collection_name=collection_name or self.collection_name, + astra_db=self.astra_db.copy( + token=token, + api_endpoint=api_endpoint, + api_path=api_path, + api_version=api_version, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ), + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, ) def set_caller( @@ -2063,15 +2115,25 @@ def __eq__(self, other: Any) -> bool: else: return False - def copy(self) -> AstraDB: + def copy( + self, + *, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AstraDB: return AstraDB( - token=self.token, - api_endpoint=self.base_url, - api_path=self.api_path, - api_version=self.api_version, - namespace=self.namespace, - caller_name=self.caller_name, - caller_version=self.caller_version, + token=token or self.token, + api_endpoint=api_endpoint or self.base_url, + api_path=api_path or self.api_path, + api_version=api_version or self.api_version, + namespace=namespace or self.namespace, + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, ) def to_async(self) -> AsyncAstraDB: @@ -2349,15 +2411,25 @@ async def __aexit__( ) -> None: await self.client.aclose() - def copy(self) -> AsyncAstraDB: + def copy( + self, + *, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AsyncAstraDB: return AsyncAstraDB( - token=self.token, - api_endpoint=self.base_url, - api_path=self.api_path, - api_version=self.api_version, - namespace=self.namespace, - caller_name=self.caller_name, - caller_version=self.caller_version, + token=token or self.token, + api_endpoint=api_endpoint or self.base_url, + api_path=api_path or self.api_path, + api_version=api_version or self.api_version, + namespace=namespace or self.namespace, + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, ) def to_sync(self) -> AstraDB: diff --git a/astrapy/ops.py b/astrapy/ops.py index ef23dd95..934cecee 100644 --- a/astrapy/ops.py +++ b/astrapy/ops.py @@ -84,8 +84,23 @@ def __eq__(self, other: Any) -> bool: else: return False - def copy(self) -> AstraDBOps: - return AstraDBOps(**self.constructor_params) + def copy( + self, + *, + token: Optional[str] = None, + dev_ops_url: Optional[str] = None, + dev_ops_api_version: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AstraDBOps: + return AstraDBOps( + token=token or self.constructor_params["token"], + dev_ops_url=dev_ops_url or self.constructor_params["dev_ops_url"], + dev_ops_api_version=dev_ops_api_version + or self.constructor_params["dev_ops_api_version"], + caller_name=caller_name or self.constructor_params["caller_name"], + caller_version=caller_version or self.constructor_params["caller_version"], + ) def set_caller( self, diff --git a/tests/astrapy/test_conversions.py b/tests/astrapy/test_conversions.py index 96203b95..bc6e0a3e 100644 --- a/tests/astrapy/test_conversions.py +++ b/tests/astrapy/test_conversions.py @@ -174,6 +174,173 @@ def test_copy_methods() -> None: assert c_adb_ops is not adb_ops +@pytest.mark.describe("test parameter override in copy methods") +def test_parameter_override_copy_methods() -> None: + sync_astradb = AstraDB( + token="token", + api_endpoint="api_endpoint", + api_path="api_path", + api_version="api_version", + namespace="namespace", + caller_name="caller_name", + caller_version="caller_version", + ) + sync_astradb2 = AstraDB( + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_sync_astradb = sync_astradb.copy( + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_sync_astradb == sync_astradb2 + + async_astradb = AsyncAstraDB( + token="token", + api_endpoint="api_endpoint", + api_path="api_path", + api_version="api_version", + namespace="namespace", + caller_name="caller_name", + caller_version="caller_version", + ) + async_astradb2 = AsyncAstraDB( + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_async_astradb = async_astradb.copy( + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_async_astradb == async_astradb2 + + sync_adbcollection = AstraDBCollection( + collection_name="collection_name", + astra_db=sync_astradb, + caller_name="caller_name", + caller_version="caller_version", + ) + sync_adbcollection2 = AstraDBCollection( + collection_name="collection_name2", + astra_db=sync_astradb2, + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_sync_adbcollection = sync_adbcollection.copy( + collection_name="collection_name2", + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_sync_adbcollection == sync_adbcollection2 + + async_adbcollection = AsyncAstraDBCollection( + collection_name="collection_name", + astra_db=async_astradb, + caller_name="caller_name", + caller_version="caller_version", + ) + async_adbcollection2 = AsyncAstraDBCollection( + collection_name="collection_name2", + astra_db=async_astradb2, + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_async_adbcollection = async_adbcollection.copy( + collection_name="collection_name2", + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_async_adbcollection == async_adbcollection2 + + adb_ops = AstraDBOps( + token="token", + dev_ops_url="dev_ops_url", + dev_ops_api_version="dev_ops_api_version", + caller_name="caller_name", + caller_version="caller_version", + ) + adb_ops2 = AstraDBOps( + token="token2", + dev_ops_url="dev_ops_url2", + dev_ops_api_version="dev_ops_api_version2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_adb_ops = adb_ops.copy( + token="token2", + dev_ops_url="dev_ops_url2", + dev_ops_api_version="dev_ops_api_version2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_adb_ops == adb_ops2 + + +@pytest.mark.describe("test parameter override when instantiating collections") +def test_parameter_override_collection_instances() -> None: + astradb0 = AstraDB(token="t0", api_endpoint="a0") + astradb1 = AstraDB(token="t1", api_endpoint="a1", namespace="n1") + col0 = AstraDBCollection( + collection_name="col0", + astra_db=astradb0, + ) + col1 = AstraDBCollection( + collection_name="col0", + astra_db=astradb0, + token="t1", + api_endpoint="a1", + namespace="n1", + ) + assert col0 != col1 + assert col1 == AstraDBCollection(collection_name="col0", astra_db=astradb1) + + a_astradb0 = AsyncAstraDB(token="t0", api_endpoint="a0") + a_astradb1 = AsyncAstraDB(token="t1", api_endpoint="a1", namespace="n1") + a_col0 = AsyncAstraDBCollection( + collection_name="col0", + astra_db=a_astradb0, + ) + a_col1 = AsyncAstraDBCollection( + collection_name="col0", + astra_db=a_astradb0, + token="t1", + api_endpoint="a1", + namespace="n1", + ) + assert a_col0 != a_col1 + assert a_col1 == AsyncAstraDBCollection(collection_name="col0", astra_db=a_astradb1) + + @pytest.mark.describe("test set_caller works in place for clients") def test_set_caller_clients() -> None: astradb0 = AstraDB(token="t1", api_endpoint="a1")