From da067177a4d25595dc9d5447778428972a64bed2 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Thu, 8 Feb 2024 10:59:19 +0100 Subject: [PATCH] Full API to manage the caller identity (#198) * caller management throughout * copy and __eq__ methods, caller inheritance in all directions, and comprehensive test thereof --- astrapy/db.py | 210 ++++++++++++++++++++-- astrapy/ops.py | 53 +++++- tests/astrapy/test_conversions.py | 278 ++++++++++++++++++++++++++++++ 3 files changed, 526 insertions(+), 15 deletions(-) create mode 100644 tests/astrapy/test_conversions.py diff --git a/astrapy/db.py b/astrapy/db.py index e4ea0eb5..3c28b147 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -76,6 +76,8 @@ def __init__( token: Optional[str] = None, api_endpoint: Optional[str] = None, namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, ) -> None: """ Initialize an AstraDBCollection instance. @@ -85,6 +87,10 @@ def __init__( token (str, optional): Authentication token for Astra DB. api_endpoint (str, optional): API endpoint URL. namespace (str, optional): Namespace for the database. + caller_name (str, optional): identity of the caller ("my_framework") + If passing a client, its caller is used as fallback + caller_version (str, optional): version of the caller code ("1.0.3") + If passing a client, its caller is used as fallback """ # Check for presence of the Astra DB object if astra_db is None: @@ -92,22 +98,64 @@ def __init__( raise AssertionError("Must provide token and api_endpoint") astra_db = AstraDB( - token=token, api_endpoint=api_endpoint, namespace=namespace + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, ) # Set the remaining instance attributes self.astra_db = astra_db + self.caller_name = caller_name or self.astra_db.caller_name + self.caller_version = caller_version or self.astra_db.caller_version self.collection_name = collection_name self.base_path = f"{self.astra_db.base_path}/{self.collection_name}" def __repr__(self) -> str: return f'AstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]' + def __eq__(self, other: Any) -> bool: + if isinstance(other, AstraDBCollection): + return all( + [ + self.collection_name == other.collection_name, + self.astra_db == other.astra_db, + self.caller_name == other.caller_name, + self.caller_version == other.caller_version, + ] + ) + else: + return False + + def copy(self) -> AstraDBCollection: + return AstraDBCollection( + collection_name=self.collection_name, + astra_db=self.astra_db.copy(), + caller_name=self.caller_name, + caller_version=self.caller_version, + ) + def to_async(self) -> AsyncAstraDBCollection: return AsyncAstraDBCollection( - astra_db=self.astra_db.to_async(), collection_name=self.collection_name + collection_name=self.collection_name, + astra_db=self.astra_db.to_async(), + caller_name=self.caller_name, + caller_version=self.caller_version, ) + def set_caller( + self, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self.astra_db.set_caller( + caller_name=caller_name, + caller_version=caller_version, + ) + self.caller_name = caller_name + self.caller_version = caller_version + def _request( self, method: str = http_methods.POST, @@ -126,8 +174,8 @@ def _request( url_params=url_params, path=path, skip_error_check=skip_error_check, - caller_name=None, - caller_version=None, + caller_name=self.caller_name, + caller_version=self.caller_version, ) response = restore_from_api(direct_response) return response @@ -989,6 +1037,8 @@ def __init__( token: Optional[str] = None, api_endpoint: Optional[str] = None, namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, ) -> None: """ Initialize an AstraDBCollection instance. @@ -998,6 +1048,10 @@ def __init__( token (str, optional): Authentication token for Astra DB. api_endpoint (str, optional): API endpoint URL. namespace (str, optional): Namespace for the database. + caller_name (str, optional): identity of the caller ("my_framework") + If passing a client, its caller is used as fallback + caller_version (str, optional): version of the caller code ("1.0.3") + If passing a client, its caller is used as fallback """ # Check for presence of the Astra DB object if astra_db is None: @@ -1005,11 +1059,17 @@ def __init__( raise AssertionError("Must provide token and api_endpoint") astra_db = AsyncAstraDB( - token=token, api_endpoint=api_endpoint, namespace=namespace + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, ) # Set the remaining instance attributes self.astra_db: AsyncAstraDB = astra_db + self.caller_name = caller_name or self.astra_db.caller_name + self.caller_version = caller_version or self.astra_db.caller_version self.client = astra_db.client self.collection_name = collection_name self.base_path = f"{self.astra_db.base_path}/{self.collection_name}" @@ -1017,9 +1077,45 @@ def __init__( def __repr__(self) -> str: return f'AsyncAstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]' + def __eq__(self, other: Any) -> bool: + if isinstance(other, AsyncAstraDBCollection): + return all( + [ + self.collection_name == other.collection_name, + self.astra_db == other.astra_db, + self.caller_name == other.caller_name, + self.caller_version == other.caller_version, + ] + ) + else: + return False + + def copy(self) -> AsyncAstraDBCollection: + return AsyncAstraDBCollection( + collection_name=self.collection_name, + astra_db=self.astra_db.copy(), + caller_name=self.caller_name, + caller_version=self.caller_version, + ) + + def set_caller( + self, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self.astra_db.set_caller( + caller_name=caller_name, + caller_version=caller_version, + ) + self.caller_name = caller_name + self.caller_version = caller_version + def to_sync(self) -> AstraDBCollection: return AstraDBCollection( - astra_db=self.astra_db.to_sync(), collection_name=self.collection_name + collection_name=self.collection_name, + astra_db=self.astra_db.to_sync(), + caller_name=self.caller_name, + caller_version=self.caller_version, ) async def _request( @@ -1041,8 +1137,8 @@ async def _request( url_params=url_params, path=path, skip_error_check=skip_error_check, - caller_name=None, - caller_version=None, + caller_name=self.caller_name, + caller_version=self.caller_version, ) response = restore_from_api(adirect_response) return response @@ -1859,14 +1955,23 @@ def __init__( api_path: Optional[str] = None, api_version: Optional[str] = None, namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, ) -> None: """ Initialize an Astra DB instance. Args: token (str): Authentication token for Astra DB. api_endpoint (str): API endpoint URL. + api_path (str, optional): used to override default URI construction + api_version (str, optional): to override default URI construction namespace (str, optional): Namespace for the database. + caller_name (str, optional): identity of the caller ("my_framework") + caller_version (str, optional): version of the caller code ("1.0.3") """ + self.caller_name = caller_name + self.caller_version = caller_version + if token is None or api_endpoint is None: raise AssertionError("Must provide token and api_endpoint") @@ -1895,6 +2000,32 @@ def __init__( def __repr__(self) -> str: return f'AstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]' + def __eq__(self, other: Any) -> bool: + if isinstance(other, AstraDB): + # work on the "normalized" quantities (stripped, etc) + return all( + [ + self.token == other.token, + self.base_url == other.base_url, + self.base_path == other.base_path, + self.caller_name == other.caller_name, + self.caller_version == other.caller_version, + ] + ) + else: + return False + + def copy(self) -> AstraDB: + return AstraDB( + token=self.token, + api_endpoint=self.base_url, + api_path=self.api_path, + api_version=self.api_version, + namespace=self.namespace, + caller_name=self.caller_name, + caller_version=self.caller_version, + ) + def to_async(self) -> AsyncAstraDB: return AsyncAstraDB( token=self.token, @@ -1902,8 +2033,18 @@ def to_async(self) -> AsyncAstraDB: api_path=self.api_path, api_version=self.api_version, namespace=self.namespace, + caller_name=self.caller_name, + caller_version=self.caller_version, ) + def set_caller( + self, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self.caller_name = caller_name + self.caller_version = caller_version + def _request( self, method: str = http_methods.POST, @@ -1922,8 +2063,8 @@ def _request( url_params=url_params, path=path, skip_error_check=skip_error_check, - caller_name=None, - caller_version=None, + caller_name=self.caller_name, + caller_version=self.caller_version, ) response = restore_from_api(direct_response) return response @@ -2088,14 +2229,23 @@ def __init__( api_path: Optional[str] = None, api_version: Optional[str] = None, namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, ) -> None: """ Initialize an Astra DB instance. Args: token (str): Authentication token for Astra DB. api_endpoint (str): API endpoint URL. + api_path (str, optional): used to override default URI construction + api_version (str, optional): to override default URI construction namespace (str, optional): Namespace for the database. + caller_name (str, optional): identity of the caller ("my_framework") + caller_version (str, optional): version of the caller code ("1.0.3") """ + self.caller_name = caller_name + self.caller_version = caller_version + self.client = httpx.AsyncClient() if token is None or api_endpoint is None: raise AssertionError("Must provide token and api_endpoint") @@ -2125,6 +2275,21 @@ def __init__( def __repr__(self) -> str: return f'AsyncAstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]' + def __eq__(self, other: Any) -> bool: + if isinstance(other, AsyncAstraDB): + # work on the "normalized" quantities (stripped, etc) + return all( + [ + self.token == other.token, + self.base_url == other.base_url, + self.base_path == other.base_path, + self.caller_name == other.caller_name, + self.caller_version == other.caller_version, + ] + ) + else: + return False + async def __aenter__(self) -> AsyncAstraDB: return self @@ -2136,6 +2301,17 @@ async def __aexit__( ) -> None: await self.client.aclose() + def copy(self) -> AsyncAstraDB: + return AsyncAstraDB( + token=self.token, + api_endpoint=self.base_url, + api_path=self.api_path, + api_version=self.api_version, + namespace=self.namespace, + caller_name=self.caller_name, + caller_version=self.caller_version, + ) + def to_sync(self) -> AstraDB: return AstraDB( token=self.token, @@ -2143,8 +2319,18 @@ def to_sync(self) -> AstraDB: api_path=self.api_path, api_version=self.api_version, namespace=self.namespace, + caller_name=self.caller_name, + caller_version=self.caller_version, ) + def set_caller( + self, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self.caller_name = caller_name + self.caller_version = caller_version + async def _request( self, method: str = http_methods.POST, @@ -2163,8 +2349,8 @@ async def _request( url_params=url_params, path=path, skip_error_check=skip_error_check, - caller_name=None, - caller_version=None, + caller_name=self.caller_name, + caller_version=self.caller_version, ) response = restore_from_api(adirect_response) return response diff --git a/astrapy/ops.py b/astrapy/ops.py index 4c9ea5ff..33295952 100644 --- a/astrapy/ops.py +++ b/astrapy/ops.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import logging -from typing import Any, cast, Dict, Optional +from typing import Any, cast, Dict, Optional, TypedDict import httpx from astrapy.api import api_request, raw_api_request @@ -30,6 +31,14 @@ logger = logging.getLogger(__name__) +class AstraDBOpsConstructorParams(TypedDict): + token: str + dev_ops_url: Optional[str] + dev_ops_api_version: Optional[str] + caller_name: Optional[str] + caller_version: Optional[str] + + class AstraDBOps: # Initialize the shared httpx client as a class attribute client = httpx.Client() @@ -39,7 +48,20 @@ def __init__( token: str, dev_ops_url: Optional[str] = None, dev_ops_api_version: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, ) -> None: + self.caller_name = caller_name + self.caller_version = caller_version + # constructor params (for the copy() method): + self.constructor_params: AstraDBOpsConstructorParams = { + "token": token, + "dev_ops_url": dev_ops_url, + "dev_ops_api_version": dev_ops_api_version, + "caller_name": caller_name, + "caller_version": caller_version, + } + # dev_ops_url = (dev_ops_url or DEFAULT_DEV_OPS_URL).strip("/") dev_ops_api_version = ( dev_ops_api_version or DEFAULT_DEV_OPS_API_VERSION @@ -48,6 +70,31 @@ def __init__( self.token = "Bearer " + token self.base_url = f"https://{dev_ops_url}/{dev_ops_api_version}" + def __eq__(self, other: Any) -> bool: + if isinstance(other, AstraDBOps): + # work on the "normalized" quantities (stripped, etc) + return all( + [ + self.token == other.token, + self.base_url == other.base_url, + self.caller_name == other.caller_name, + self.caller_version == other.caller_version, + ] + ) + else: + return False + + def copy(self) -> AstraDBOps: + return AstraDBOps(**self.constructor_params) + + def set_caller( + self, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self.caller_name = caller_name + self.caller_version = caller_version + def _ops_request( self, method: str, @@ -66,8 +113,8 @@ def _ops_request( json_data=json_data, url_params=_options, path=path, - caller_name=None, - caller_version=None, + caller_name=self.caller_name, + caller_version=self.caller_version, ) return raw_response diff --git a/tests/astrapy/test_conversions.py b/tests/astrapy/test_conversions.py new file mode 100644 index 00000000..96203b95 --- /dev/null +++ b/tests/astrapy/test_conversions.py @@ -0,0 +1,278 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the User-Agent customization logic +""" + +import logging +import pytest + +from astrapy.db import AstraDB, AstraDBCollection, AsyncAstraDB, AsyncAstraDBCollection +from astrapy.ops import AstraDBOps + + +logger = logging.getLogger(__name__) + + +@pytest.mark.describe("test basic equality between instances") +def test_instance_equality() -> None: + astradb_a = AstraDB(token="t1", api_endpoint="a1") + astradb_b = AstraDB(token="t1", api_endpoint="a1") + astradb_c = AstraDB(token="t2", api_endpoint="a2") + + assert astradb_a == astradb_b + assert astradb_a != astradb_c + + astradb_coll_a = AstraDBCollection("c1", token="t1", api_endpoint="a1") + + assert astradb_a != astradb_coll_a + + astradb_coll_b = AstraDBCollection("c1", astra_db=astradb_a) + astradb_coll_c = AstraDBCollection("c3", token="t3", api_endpoint="a3") + + assert astradb_coll_a == astradb_coll_b + assert astradb_coll_a != astradb_coll_c + + astradbops_o1 = AstraDBOps(token="t1") + astradbops_o2 = AstraDBOps(token="t1") + astradbops_o3 = AstraDBOps(token="t3") + + assert astradbops_o1 == astradbops_o2 + assert astradbops_o1 != astradbops_o3 + + +@pytest.mark.describe("test basic equality between async instances") +def test_instance_equality_async() -> None: + astradb_a = AsyncAstraDB(token="t1", api_endpoint="a1") + astradb_b = AsyncAstraDB(token="t1", api_endpoint="a1") + astradb_c = AsyncAstraDB(token="t2", api_endpoint="a2") + + assert astradb_a == astradb_b + assert astradb_a != astradb_c + + astradb_coll_a = AsyncAstraDBCollection("c1", token="t1", api_endpoint="a1") + + assert astradb_a != astradb_coll_a + + astradb_coll_b = AsyncAstraDBCollection("c1", astra_db=astradb_a) + astradb_coll_c = AsyncAstraDBCollection("c3", token="t3", api_endpoint="a3") + + assert astradb_coll_a == astradb_coll_b + assert astradb_coll_a != astradb_coll_c + + +@pytest.mark.describe("test to_sync and to_async methods combine to identity") +def test_round_conversion_is_noop() -> None: + sync_astradb = AstraDB( + token="token", + api_endpoint="api_endpoint", + api_path="api_path", + api_version="api_version", + namespace="namespace", + caller_name="caller_name", + caller_version="caller_version", + ) + assert sync_astradb.to_async().to_sync() == sync_astradb + + async_astradb = AsyncAstraDB( + token="token", + api_endpoint="api_endpoint", + api_path="api_path", + api_version="api_version", + namespace="namespace", + caller_name="caller_name", + caller_version="caller_version", + ) + assert async_astradb.to_sync().to_async() == async_astradb + + sync_adbcollection = AstraDBCollection( + collection_name="collection_name", + astra_db=sync_astradb, + caller_name="caller_name", + caller_version="caller_version", + ) + assert sync_adbcollection.to_async().to_sync() == sync_adbcollection + + async_adbcollection = AsyncAstraDBCollection( + collection_name="collection_name", + astra_db=async_astradb, + caller_name="caller_name", + caller_version="caller_version", + ) + assert async_adbcollection.to_sync().to_async() == async_adbcollection + + +@pytest.mark.describe("test copy methods create identical objects") +def test_copy_methods() -> None: + sync_astradb = AstraDB( + token="token", + api_endpoint="api_endpoint", + api_path="api_path", + api_version="api_version", + namespace="namespace", + caller_name="caller_name", + caller_version="caller_version", + ) + c_sync_astradb = sync_astradb.copy() + assert c_sync_astradb == sync_astradb + assert c_sync_astradb is not sync_astradb + + async_astradb = AsyncAstraDB( + token="token", + api_endpoint="api_endpoint", + api_path="api_path", + api_version="api_version", + namespace="namespace", + caller_name="caller_name", + caller_version="caller_version", + ) + c_async_astradb = async_astradb.copy() + assert c_async_astradb == async_astradb + assert c_async_astradb is not async_astradb + + sync_adbcollection = AstraDBCollection( + collection_name="collection_name", + astra_db=sync_astradb, + caller_name="caller_name", + caller_version="caller_version", + ) + c_sync_adbcollection = sync_adbcollection.copy() + assert c_sync_adbcollection == sync_adbcollection + assert c_sync_adbcollection is not sync_adbcollection + + async_adbcollection = AsyncAstraDBCollection( + collection_name="collection_name", + astra_db=async_astradb, + caller_name="caller_name", + caller_version="caller_version", + ) + c_async_adbcollection = async_adbcollection.copy() + assert c_async_adbcollection == async_adbcollection + assert c_async_adbcollection is not async_adbcollection + + adb_ops = AstraDBOps( + token="token", + dev_ops_url="dev_ops_url", + dev_ops_api_version="dev_ops_api_version", + caller_name="caller_name", + caller_version="caller_version", + ) + c_adb_ops = adb_ops.copy() + assert c_adb_ops == adb_ops + assert c_adb_ops is not adb_ops + + +@pytest.mark.describe("test set_caller works in place for clients") +def test_set_caller_clients() -> None: + astradb0 = AstraDB(token="t1", api_endpoint="a1") + astradbops0 = AstraDBOps(token="t1") + async_astradb0 = AsyncAstraDB(token="t1", api_endpoint="a1") + # + astradb0.set_caller(caller_name="CN", caller_version="CV") + astradbops0.set_caller(caller_name="CN", caller_version="CV") + async_astradb0.set_caller(caller_name="CN", caller_version="CV") + # + astradb = AstraDB( + token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + astradbops = AstraDBOps(token="t1", caller_name="CN", caller_version="CV") + async_astradb = AsyncAstraDB( + token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + # + assert astradb0 == astradb + assert astradbops0 == astradbops + assert async_astradb0 == async_astradb + + +@pytest.mark.describe("test set_caller works in place for collections") +def test_set_caller_collections() -> None: + """ + This tests (1) the collection set_caller, (2) the fact that it is propagated + to the client, and (3) the propagation of the caller info to the astra_db + being created if not passed to the collection constructor. + """ + adb_collection0 = AstraDBCollection("c1", token="t1", api_endpoint="a1") + async_adb_collection0 = AsyncAstraDBCollection("c1", token="t1", api_endpoint="a1") + # + adb_collection0.set_caller(caller_name="CN", caller_version="CV") + async_adb_collection0.set_caller(caller_name="CN", caller_version="CV") + # + adb_collection = AstraDBCollection( + "c1", token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + async_adb_collection = AsyncAstraDBCollection( + "c1", token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + # + assert adb_collection0 == adb_collection + assert async_adb_collection0 == async_adb_collection + + +@pytest.mark.describe("test caller inheritance from client to collection") +def test_caller_inheritance_from_clients() -> None: + """ + This tests the fact that when passing a client in collection creation + the caller is acquired by default from said client. + """ + astradb = AstraDB( + token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + async_astradb = AsyncAstraDB( + token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + + adb_collection = AstraDBCollection("c1", astra_db=astradb) + async_adb_collection = AsyncAstraDBCollection("c1", astra_db=async_astradb) + + ref_adb_collection = AstraDBCollection( + "c1", token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + async_ref_adb_collection = AsyncAstraDBCollection( + "c1", token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + + assert ref_adb_collection == adb_collection + assert async_ref_adb_collection == async_adb_collection + + +@pytest.mark.describe("test caller inheritance when spawning a collection") +async def test_caller_inheritance_spawning() -> None: + """ + This tests that the caller is retained with the clients' .collection() + method. + As this module is for lightweight tests, no actual API operations involved, + this single test will be enough: create_collection and truncate_collection + are not covered (they work identically to this one though). + """ + astradb = AstraDB( + token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + async_astradb = AsyncAstraDB( + token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + + spawned_collection = astradb.collection("c1") + async_spawned_collection = await async_astradb.collection("c1") + + ref_spawned_collection = AstraDBCollection( + "c1", token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + async_ref_spawned_collection = AsyncAstraDBCollection( + "c1", token="t1", api_endpoint="a1", caller_name="CN", caller_version="CV" + ) + + assert spawned_collection == ref_spawned_collection + assert async_spawned_collection == async_ref_spawned_collection