From ea772763f9e9f241f9d8efb7a02e65f9e53a177c Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Wed, 28 Feb 2024 02:37:50 +0100 Subject: [PATCH] Split classes, modules, tests to keep the "idiomatic" layer well separate and not touch the "astrapy" layer (#222) * astrapy/idiomatic split structure: modules, imports, tests * completed refactor into new structure * complete refactoring into split classes, incl. tests. Not many methods implemented yet * astrapy classes allow parameter override in their copy methods * collection constructors does apply parameter overrides even when astra_db passed * wip for the optional namespace when creating Collections * create_collection, get_collection, drop_collection, Collection constructor --- README.md | 11 + astrapy/__init__.py | 24 +- astrapy/db.py | 203 ++++----- astrapy/idiomatic/__init__.py | 23 + astrapy/idiomatic/collection.py | 249 +++++++++++ astrapy/idiomatic/database.py | 329 +++++++++++++++ astrapy/idiomatic/utils.py | 26 ++ astrapy/ops.py | 19 +- astrapy/utils.py | 4 +- tests/astrapy/conftest.py | 385 +++++++++++++++++ tests/astrapy/test_async_db_ddl.py | 4 +- tests/astrapy/test_conversions.py | 167 ++++++++ tests/astrapy/test_db_ddl.py | 4 +- tests/astrapy/test_db_dml.py | 7 +- tests/conftest.py | 392 +----------------- tests/idiomatic/__init__.py | 0 tests/idiomatic/conftest.py | 54 +++ tests/idiomatic/integration/__init__.py | 0 .../integration/test_collections_async.py | 115 +++++ .../integration/test_collections_sync.py | 115 +++++ .../integration/test_databases_async.py | 105 +++++ .../integration/test_databases_sync.py | 106 +++++ tests/idiomatic/integration/test_ddl_async.py | 35 ++ tests/idiomatic/integration/test_ddl_sync.py | 35 ++ tests/idiomatic/unit/__init__.py | 0 tests/idiomatic/unit/test_collections.py | 21 + tests/idiomatic/unit/test_databases.py | 21 + 27 files changed, 1952 insertions(+), 502 deletions(-) create mode 100644 astrapy/idiomatic/__init__.py create mode 100644 astrapy/idiomatic/collection.py create mode 100644 astrapy/idiomatic/database.py create mode 100644 astrapy/idiomatic/utils.py create mode 100644 tests/astrapy/conftest.py create mode 100644 tests/idiomatic/__init__.py create mode 100644 tests/idiomatic/conftest.py create mode 100644 tests/idiomatic/integration/__init__.py create mode 100644 tests/idiomatic/integration/test_collections_async.py create mode 100644 tests/idiomatic/integration/test_collections_sync.py create mode 100644 tests/idiomatic/integration/test_databases_async.py create mode 100644 tests/idiomatic/integration/test_databases_sync.py create mode 100644 tests/idiomatic/integration/test_ddl_async.py create mode 100644 tests/idiomatic/integration/test_ddl_sync.py create mode 100644 tests/idiomatic/unit/__init__.py create mode 100644 tests/idiomatic/unit/test_collections.py create mode 100644 tests/idiomatic/unit/test_databases.py diff --git a/README.md b/README.md index 93e91fe9..338abd0b 100644 --- a/README.md +++ b/README.md @@ -414,3 +414,14 @@ To enable the `AstraDBOps` testing (off by default): ```bash TEST_ASTRADBOPS=1 poetry run pytest [...] ``` + +To separately test the "astrapy proper" vs. the "idiomatic" part, and/or only the unit/integration part of the latter: + +``` +poetry run pytest tests/astrapy +poetry run pytest tests/idiomatic +poetry run pytest tests/idiomatic/unit +poetry run pytest tests/idiomatic/integration +``` + +(the above can be combined with the options seen earlier, where it makes sense). diff --git a/astrapy/__init__.py b/astrapy/__init__.py index 055922b0..7b6b29ea 100644 --- a/astrapy/__init__.py +++ b/astrapy/__init__.py @@ -16,10 +16,8 @@ import os import importlib.metadata -from typing import Any - -def get_version() -> Any: +def get_version() -> str: try: # Poetry will create a __version__ attribute in the package's __init__.py file return importlib.metadata.version(__package__) @@ -38,11 +36,27 @@ def get_version() -> Any: pyproject_data = toml.loads(file_contents) # Return the version from the poetry section - return pyproject_data["tool"]["poetry"]["version"] + return str(pyproject_data["tool"]["poetry"]["version"]) # If the pyproject.toml file does not exist or the version is not found, return unknown except (FileNotFoundError, KeyError): return "unknown" -__version__ = get_version() +__version__: str = get_version() + +# There's a circular-import issue to heal here to bring this to top +from astrapy.idiomatic import ( # noqa: E402 + AsyncCollection, + AsyncDatabase, + Collection, + Database, +) + +__all__ = [ + "AsyncCollection", + "AsyncDatabase", + "Collection", + "Database", + "__version__", +] diff --git a/astrapy/db.py b/astrapy/db.py index 6523c0d0..e1461451 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -53,7 +53,6 @@ http_methods, normalize_for_api, restore_from_api, - return_unsupported_error ) from astrapy.types import ( API_DOC, @@ -105,11 +104,20 @@ def __init__( caller_name=caller_name, caller_version=caller_version, ) + else: + # if astra_db passed, copy and apply possible overrides + astra_db = astra_db.copy( + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ) # Set the remaining instance attributes self.astra_db = astra_db - self.caller_name = caller_name or self.astra_db.caller_name - self.caller_version = caller_version or self.astra_db.caller_version + self.caller_name = self.astra_db.caller_name + self.caller_version = self.astra_db.caller_version self.collection_name = collection_name self.base_path = f"{self.astra_db.base_path}/{self.collection_name}" @@ -129,12 +137,31 @@ def __eq__(self, other: Any) -> bool: else: return False - def copy(self) -> AstraDBCollection: + def copy( + self, + *, + collection_name: Optional[str] = None, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AstraDBCollection: return AstraDBCollection( - collection_name=self.collection_name, - astra_db=self.astra_db.copy(), - caller_name=self.caller_name, - caller_version=self.caller_version, + collection_name=collection_name or self.collection_name, + astra_db=self.astra_db.copy( + token=token, + api_endpoint=api_endpoint, + api_path=api_path, + api_version=api_version, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ), + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, ) def to_async(self) -> AsyncAstraDBCollection: @@ -1093,11 +1120,20 @@ def __init__( caller_name=caller_name, caller_version=caller_version, ) + else: + # if astra_db passed, copy and apply possible overrides + astra_db = astra_db.copy( + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ) # Set the remaining instance attributes self.astra_db: AsyncAstraDB = astra_db - self.caller_name = caller_name or self.astra_db.caller_name - self.caller_version = caller_version or self.astra_db.caller_version + self.caller_name = self.astra_db.caller_name + self.caller_version = self.astra_db.caller_version self.client = astra_db.client self.collection_name = collection_name self.base_path = f"{self.astra_db.base_path}/{self.collection_name}" @@ -1118,12 +1154,31 @@ def __eq__(self, other: Any) -> bool: else: return False - def copy(self) -> AsyncAstraDBCollection: + def copy( + self, + *, + collection_name: Optional[str] = None, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AsyncAstraDBCollection: return AsyncAstraDBCollection( - collection_name=self.collection_name, - astra_db=self.astra_db.copy(), - caller_name=self.caller_name, - caller_version=self.caller_version, + collection_name=collection_name or self.collection_name, + astra_db=self.astra_db.copy( + token=token, + api_endpoint=api_endpoint, + api_path=api_path, + api_version=api_version, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ), + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, ) def set_caller( @@ -1991,58 +2046,6 @@ async def concurrent_upsert(doc: API_DOC) -> str: if isinstance(result, BaseException) and not isinstance(result, Exception): raise result return results # type: ignore - - # Mongodb calls not supported by the API - async def find_raw_batches(): - return_unsupported_error() - - async def aggregate(): - return_unsupported_error() - - async def aggregate_raw_batches(): - return_unsupported_error() - - async def watch(): - return_unsupported_error() - - async def rename(): - return_unsupported_error() - - async def create_index(): - return_unsupported_error() - - async def create_indexes(): - return_unsupported_error() - - async def drop_index(): - return_unsupported_error() - - async def drop_indexes(): - return_unsupported_error() - - async def list_indexes(): - return_unsupported_error() - - async def index_information(): - return_unsupported_error() - - async def create_search_index(): - return_unsupported_error() - - async def create_search_indexes(): - return_unsupported_error() - - async def drop_search_index(): - return_unsupported_error() - - async def list_search_indexes(): - return_unsupported_error() - - async def update_search_index(): - return_unsupported_error() - - async def distinct(): - return_unsupported_error() class AstraDB: @@ -2116,15 +2119,25 @@ def __eq__(self, other: Any) -> bool: else: return False - def copy(self) -> AstraDB: + def copy( + self, + *, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AstraDB: return AstraDB( - token=self.token, - api_endpoint=self.base_url, - api_path=self.api_path, - api_version=self.api_version, - namespace=self.namespace, - caller_name=self.caller_name, - caller_version=self.caller_version, + token=token or self.token, + api_endpoint=api_endpoint or self.base_url, + api_path=api_path or self.api_path, + api_version=api_version or self.api_version, + namespace=namespace or self.namespace, + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, ) def to_async(self) -> AsyncAstraDB: @@ -2321,20 +2334,6 @@ def truncate_collection(self, collection_name: str) -> AstraDBCollection: # return the collection itself return collection - def aggregate(): - return_unsupported_error() - - def cursor_command(): - return_unsupported_error() - - def dereference(): - return_unsupported_error() - - def watch(): - return_unsupported_error() - - def validate_collection(): - return_unsupported_error() class AsyncAstraDB: def __init__( @@ -2416,15 +2415,25 @@ async def __aexit__( ) -> None: await self.client.aclose() - def copy(self) -> AsyncAstraDB: + def copy( + self, + *, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AsyncAstraDB: return AsyncAstraDB( - token=self.token, - api_endpoint=self.base_url, - api_path=self.api_path, - api_version=self.api_version, - namespace=self.namespace, - caller_name=self.caller_name, - caller_version=self.caller_version, + token=token or self.token, + api_endpoint=api_endpoint or self.base_url, + api_path=api_path or self.api_path, + api_version=api_version or self.api_version, + namespace=namespace or self.namespace, + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, ) def to_sync(self) -> AstraDB: diff --git a/astrapy/idiomatic/__init__.py b/astrapy/idiomatic/__init__.py new file mode 100644 index 00000000..595e60c6 --- /dev/null +++ b/astrapy/idiomatic/__init__.py @@ -0,0 +1,23 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from astrapy.idiomatic.collection import AsyncCollection, Collection +from astrapy.idiomatic.database import AsyncDatabase, Database + +__all__ = [ + "AsyncCollection", + "Collection", + "AsyncDatabase", + "Database", +] diff --git a/astrapy/idiomatic/collection.py b/astrapy/idiomatic/collection.py new file mode 100644 index 00000000..862db9ed --- /dev/null +++ b/astrapy/idiomatic/collection.py @@ -0,0 +1,249 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, Optional, TypedDict +from astrapy.db import AstraDBCollection, AsyncAstraDBCollection +from astrapy.idiomatic.utils import unsupported +from astrapy.idiomatic.database import AsyncDatabase, Database + + +class CollectionConstructorParams(TypedDict): + database: Database + name: str + namespace: Optional[str] + caller_name: Optional[str] + caller_version: Optional[str] + + +class AsyncCollectionConstructorParams(TypedDict): + database: AsyncDatabase + name: str + namespace: Optional[str] + caller_name: Optional[str] + caller_version: Optional[str] + + +class Collection: + def __init__( + self, + database: Database, + name: str, + *, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self._constructor_params: CollectionConstructorParams = { + "database": database, + "name": name, + "namespace": namespace, + "caller_name": caller_name, + "caller_version": caller_version, + } + self._astra_db_collection = AstraDBCollection( + collection_name=name, + astra_db=database._astra_db, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}[_astra_db_collection="{self._astra_db_collection}"]' + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Collection): + return self._astra_db_collection == other._astra_db_collection + else: + return False + + def copy(self) -> Collection: + return Collection(**self._constructor_params) + + def to_async(self) -> AsyncCollection: + return AsyncCollection( + **{ # type: ignore[arg-type] + **self._constructor_params, + **{"database": self._constructor_params["database"].to_async()}, + } + ) + + def set_caller( + self, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self._astra_db_collection.set_caller( + caller_name=caller_name, + caller_version=caller_version, + ) + + @unsupported + def find_raw_batches(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def aggregate(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def aggregate_raw_batches(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def watch(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def rename(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def create_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def create_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def drop_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def drop_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def list_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def index_information(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def create_search_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def create_search_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def drop_search_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def list_search_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def update_search_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def distinct(*pargs: Any, **kwargs: Any) -> Any: ... + + +class AsyncCollection: + def __init__( + self, + database: AsyncDatabase, + name: str, + *, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self._constructor_params: AsyncCollectionConstructorParams = { + "database": database, + "name": name, + "namespace": namespace, + "caller_name": caller_name, + "caller_version": caller_version, + } + self._astra_db_collection = AsyncAstraDBCollection( + collection_name=name, + astra_db=database._astra_db, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}[_astra_db_collection="{self._astra_db_collection}"]' + + def __eq__(self, other: Any) -> bool: + if isinstance(other, AsyncCollection): + return self._astra_db_collection == other._astra_db_collection + else: + return False + + def copy(self) -> AsyncCollection: + return AsyncCollection(**self._constructor_params) + + def to_sync(self) -> Collection: + return Collection( + **{ # type: ignore[arg-type] + **self._constructor_params, + **{"database": self._constructor_params["database"].to_sync()}, + } + ) + + def set_caller( + self, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self._astra_db_collection.set_caller( + caller_name=caller_name, + caller_version=caller_version, + ) + + @unsupported + async def find_raw_batches(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def aggregate(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def aggregate_raw_batches(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def watch(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def rename(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def create_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def create_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def drop_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def drop_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def list_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def index_information(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def create_search_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def create_search_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def drop_search_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def list_search_indexes(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def update_search_index(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def distinct(*pargs: Any, **kwargs: Any) -> Any: ... diff --git a/astrapy/idiomatic/database.py b/astrapy/idiomatic/database.py new file mode 100644 index 00000000..aa16e516 --- /dev/null +++ b/astrapy/idiomatic/database.py @@ -0,0 +1,329 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from types import TracebackType +from typing import Any, Dict, Optional, Type, TypedDict, Union, TYPE_CHECKING +from astrapy.db import AstraDB, AsyncAstraDB +from astrapy.idiomatic.utils import unsupported + +if TYPE_CHECKING: + from astrapy.idiomatic.collection import AsyncCollection, Collection + + +def _validate_create_collection_options( + dimension: Optional[int] = None, + metric: Optional[str] = None, + indexing: Optional[Dict[str, Any]] = None, + additional_options: Optional[Dict[str, Any]] = None, +) -> None: + if additional_options: + if "vector" in additional_options: + raise ValueError( + "`additional_options` dict parameter to create_collection " + "cannot have a `vector` key. Please use the specific " + "method parameter." + ) + if "indexing" in additional_options: + raise ValueError( + "`additional_options` dict parameter to create_collection " + "cannot have a `indexing` key. Please use the specific " + "method parameter." + ) + if dimension is None and metric is not None: + raise ValueError( + "Cannot specify `metric` and not `vector_dimension` in the " + "create_collection method." + ) + + +class DatabaseConstructorParams(TypedDict): + api_endpoint: str + token: str + namespace: Optional[str] + caller_name: Optional[str] + caller_version: Optional[str] + api_path: Optional[str] + api_version: Optional[str] + + +class Database: + def __init__( + self, + api_endpoint: str, + token: str, + *, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + ) -> None: + self._constructor_params: DatabaseConstructorParams = { + "api_endpoint": api_endpoint, + "token": token, + "namespace": namespace, + "caller_name": caller_name, + "caller_version": caller_version, + "api_path": api_path, + "api_version": api_version, + } + self._astra_db = AstraDB( + token=token, + api_endpoint=api_endpoint, + api_path=api_path, + api_version=api_version, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}[_astra_db={self._astra_db}"]' + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Database): + return self._astra_db == other._astra_db + else: + return False + + def copy(self) -> Database: + return Database(**self._constructor_params) + + def to_async(self) -> AsyncDatabase: + return AsyncDatabase(**self._constructor_params) + + def set_caller( + self, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self._astra_db.caller_name = caller_name + self._astra_db.caller_version = caller_version + + def get_collection( + self, name: str, *, namespace: Optional[str] = None + ) -> Collection: + # lazy importing here against circular-import error + from astrapy.idiomatic.collection import Collection + + _namespace = namespace or self._constructor_params["namespace"] + return Collection(self, name, namespace=_namespace) + + def create_collection( + self, + name: str, + *, + namespace: Optional[str] = None, + dimension: Optional[int] = None, + metric: Optional[str] = None, + indexing: Optional[Dict[str, Any]] = None, + additional_options: Optional[Dict[str, Any]] = None, + ) -> Collection: + _validate_create_collection_options( + dimension=dimension, + metric=metric, + indexing=indexing, + additional_options=additional_options, + ) + _options = { + **(additional_options or {}), + **({"indexing": indexing} if indexing else {}), + } + if namespace is not None: + self._astra_db.copy(namespace=namespace).create_collection( + name, + options=_options, + dimension=dimension, + metric=metric, + ) + else: + self._astra_db.create_collection( + name, + options=_options, + dimension=dimension, + metric=metric, + ) + return self.get_collection(name, namespace=namespace) + + # TODO, the return type should be a Dict[str, Any] (investigate what) + def drop_collection(self, name_or_collection: Union[str, Collection]) -> None: + # lazy importing here against circular-import error + from astrapy.idiomatic.collection import Collection + + _name: str + if isinstance(name_or_collection, Collection): + _name = name_or_collection._astra_db_collection.collection_name + else: + _name = name_or_collection + self._astra_db.delete_collection(_name) + + @unsupported + def aggregate(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def cursor_command(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def dereference(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def watch(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + def validate_collection(*pargs: Any, **kwargs: Any) -> Any: ... + + +class AsyncDatabase: + def __init__( + self, + api_endpoint: str, + token: str, + *, + namespace: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + ) -> None: + self._constructor_params: DatabaseConstructorParams = { + "api_endpoint": api_endpoint, + "token": token, + "namespace": namespace, + "caller_name": caller_name, + "caller_version": caller_version, + "api_path": api_path, + "api_version": api_version, + } + self._astra_db = AsyncAstraDB( + token=token, + api_endpoint=api_endpoint, + api_path=api_path, + api_version=api_version, + namespace=namespace, + caller_name=caller_name, + caller_version=caller_version, + ) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}[_astra_db={self._astra_db}"]' + + def __eq__(self, other: Any) -> bool: + if isinstance(other, AsyncDatabase): + return self._astra_db == other._astra_db + else: + return False + + async def __aenter__(self) -> AsyncDatabase: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self._astra_db.__aexit__( + exc_type=exc_type, + exc_value=exc_value, + traceback=traceback, + ) + + def copy(self) -> AsyncDatabase: + return AsyncDatabase(**self._constructor_params) + + def to_sync(self) -> Database: + return Database(**self._constructor_params) + + def set_caller( + self, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> None: + self._astra_db.caller_name = caller_name + self._astra_db.caller_version = caller_version + + async def get_collection( + self, name: str, *, namespace: Optional[str] = None + ) -> AsyncCollection: + # lazy importing here against circular-import error + from astrapy.idiomatic.collection import AsyncCollection + + _namespace = namespace or self._constructor_params["namespace"] + return AsyncCollection(self, name, namespace=_namespace) + + async def create_collection( + self, + name: str, + *, + namespace: Optional[str] = None, + dimension: Optional[int] = None, + metric: Optional[str] = None, + indexing: Optional[Dict[str, Any]] = None, + additional_options: Optional[Dict[str, Any]] = None, + ) -> AsyncCollection: + _validate_create_collection_options( + dimension=dimension, + metric=metric, + indexing=indexing, + additional_options=additional_options, + ) + _options = { + **(additional_options or {}), + **({"indexing": indexing} if indexing else {}), + } + if namespace is not None: + await self._astra_db.copy(namespace=namespace).create_collection( + name, + options=_options, + dimension=dimension, + metric=metric, + ) + else: + await self._astra_db.create_collection( + name, + options=_options, + dimension=dimension, + metric=metric, + ) + return await self.get_collection(name, namespace=namespace) + + # TODO, the return type should be a Dict[str, Any] (investigate what) + async def drop_collection( + self, name_or_collection: Union[str, AsyncCollection] + ) -> None: + # lazy importing here against circular-import error + from astrapy.idiomatic.collection import AsyncCollection + + _name: str + if isinstance(name_or_collection, AsyncCollection): + _name = name_or_collection._astra_db_collection.collection_name + else: + _name = name_or_collection + await self._astra_db.delete_collection(_name) + + @unsupported + async def aggregate(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def cursor_command(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def dereference(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def watch(*pargs: Any, **kwargs: Any) -> Any: ... + + @unsupported + async def validate_collection(*pargs: Any, **kwargs: Any) -> Any: ... diff --git a/astrapy/idiomatic/utils.py b/astrapy/idiomatic/utils.py new file mode 100644 index 00000000..994997ac --- /dev/null +++ b/astrapy/idiomatic/utils.py @@ -0,0 +1,26 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations +from functools import wraps +from typing import Any, Callable + +DEFAULT_NOT_SUPPORTED_MESSAGE = "Operation not supported." + + +def unsupported(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + def unsupported_func(*args: Any, **kwargs: Any) -> Any: + raise TypeError(DEFAULT_NOT_SUPPORTED_MESSAGE) + + return unsupported_func diff --git a/astrapy/ops.py b/astrapy/ops.py index ef23dd95..934cecee 100644 --- a/astrapy/ops.py +++ b/astrapy/ops.py @@ -84,8 +84,23 @@ def __eq__(self, other: Any) -> bool: else: return False - def copy(self) -> AstraDBOps: - return AstraDBOps(**self.constructor_params) + def copy( + self, + *, + token: Optional[str] = None, + dev_ops_url: Optional[str] = None, + dev_ops_api_version: Optional[str] = None, + caller_name: Optional[str] = None, + caller_version: Optional[str] = None, + ) -> AstraDBOps: + return AstraDBOps( + token=token or self.constructor_params["token"], + dev_ops_url=dev_ops_url or self.constructor_params["dev_ops_url"], + dev_ops_api_version=dev_ops_api_version + or self.constructor_params["dev_ops_api_version"], + caller_name=caller_name or self.constructor_params["caller_name"], + caller_version=caller_version or self.constructor_params["caller_version"], + ) def set_caller( self, diff --git a/astrapy/utils.py b/astrapy/utils.py index ca0c5d54..23c2d062 100644 --- a/astrapy/utils.py +++ b/astrapy/utils.py @@ -221,9 +221,7 @@ def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: return json_query -def return_unsupported_error(): - raise Exception("Unsupported operation") - + def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]: """ Convert a vector of strings to a vector of floats. diff --git a/tests/astrapy/conftest.py b/tests/astrapy/conftest.py new file mode 100644 index 00000000..fc897992 --- /dev/null +++ b/tests/astrapy/conftest.py @@ -0,0 +1,385 @@ +""" +Test fixtures +""" + +import os +import math + +import pytest +from typing import ( + AsyncIterable, + Dict, + Iterable, + List, + Optional, + Set, + TypeVar, +) + +import pytest_asyncio + +from ..conftest import AstraDBCredentials +from astrapy.db import AstraDB, AstraDBCollection, AsyncAstraDB, AsyncAstraDBCollection + + +T = TypeVar("T") + + +# fixed +TEST_WRITABLE_VECTOR_COLLECTION = "writable_v_col" +TEST_READONLY_VECTOR_COLLECTION = "readonly_v_col" +TEST_WRITABLE_NONVECTOR_COLLECTION = "writable_nonv_col" +TEST_WRITABLE_ALLOWINDEX_NONVECTOR_COLLECTION = "writable_allowindex_nonv_col" +TEST_WRITABLE_DENYINDEX_NONVECTOR_COLLECTION = "writable_denyindex_nonv_col" + +VECTOR_DOCUMENTS = [ + { + "_id": "1", + "text": "Sample entry number <1>", + "otherfield": {"subfield": "x1y"}, + "anotherfield": "alpha", + "$vector": [0.1, 0.9], + }, + { + "_id": "2", + "text": "Sample entry number <2>", + "otherfield": {"subfield": "x2y"}, + "anotherfield": "alpha", + "$vector": [0.5, 0.5], + }, + { + "_id": "3", + "text": "Sample entry number <3>", + "otherfield": {"subfield": "x3y"}, + "anotherfield": "omega", + "$vector": [0.9, 0.1], + }, +] + +INDEXING_SAMPLE_DOCUMENT = { + "_id": "0", + "A": { + "a": "A.a", + "b": "A.b", + }, + "B": { + "a": "B.a", + "b": "B.b", + }, + "C": { + "a": "C.a", + "b": "C.b", + }, +} + + +def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]: + this_batch = [] + for entry in iterable: + this_batch.append(entry) + if len(this_batch) == batch_size: + yield this_batch + this_batch = [] + if this_batch: + yield this_batch + + +@pytest.fixture(scope="session") +def db(astra_db_credentials_kwargs: AstraDBCredentials) -> AstraDB: + token = astra_db_credentials_kwargs["token"] + api_endpoint = astra_db_credentials_kwargs["api_endpoint"] + namespace = astra_db_credentials_kwargs.get("namespace") + + if token is None or api_endpoint is None: + raise ValueError("Required ASTRA DB configuration is missing") + + return AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace) + + +@pytest_asyncio.fixture(scope="function") +async def async_db( + astra_db_credentials_kwargs: AstraDBCredentials, +) -> AsyncIterable[AsyncAstraDB]: + token = astra_db_credentials_kwargs["token"] + api_endpoint = astra_db_credentials_kwargs["api_endpoint"] + namespace = astra_db_credentials_kwargs.get("namespace") + + if token is None or api_endpoint is None: + raise ValueError("Required ASTRA DB configuration is missing") + + async with AsyncAstraDB( + token=token, api_endpoint=api_endpoint, namespace=namespace + ) as db: + yield db + + +@pytest.fixture(scope="module") +def invalid_db( + astra_invalid_db_credentials_kwargs: Dict[str, Optional[str]] +) -> AstraDB: + token = astra_invalid_db_credentials_kwargs["token"] + api_endpoint = astra_invalid_db_credentials_kwargs["api_endpoint"] + namespace = astra_invalid_db_credentials_kwargs.get("namespace") + + if token is None or api_endpoint is None: + raise ValueError("Required ASTRA DB configuration is missing") + + return AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace) + + +@pytest.fixture(scope="session") +def readonly_v_collection(db: AstraDB) -> Iterable[AstraDBCollection]: + collection = db.create_collection( + TEST_READONLY_VECTOR_COLLECTION, + dimension=2, + ) + + collection.clear() + collection.insert_many(VECTOR_DOCUMENTS) + + yield collection + + if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: + db.delete_collection(TEST_READONLY_VECTOR_COLLECTION) + + +@pytest.fixture(scope="session") +def writable_v_collection(db: AstraDB) -> Iterable[AstraDBCollection]: + """ + This is lasting for the whole test. Functions can write to it, + no guarantee (i.e. each test should use a different ID... + """ + collection = db.create_collection( + TEST_WRITABLE_VECTOR_COLLECTION, + dimension=2, + ) + + yield collection + + if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: + db.delete_collection(TEST_WRITABLE_VECTOR_COLLECTION) + + +@pytest.fixture(scope="function") +def empty_v_collection( + writable_v_collection: AstraDBCollection, +) -> Iterable[AstraDBCollection]: + """available empty to each test function.""" + writable_v_collection.clear() + yield writable_v_collection + + +@pytest.fixture(scope="function") +def disposable_v_collection( + writable_v_collection: AstraDBCollection, +) -> Iterable[AstraDBCollection]: + """available prepopulated to each test function.""" + writable_v_collection.clear() + writable_v_collection.insert_many(VECTOR_DOCUMENTS) + yield writable_v_collection + + +@pytest.fixture(scope="session") +def writable_nonv_collection(db: AstraDB) -> Iterable[AstraDBCollection]: + """ + This is lasting for the whole test. Functions can write to it, + no guarantee (i.e. each test should use a different ID... + """ + collection = db.create_collection(TEST_WRITABLE_NONVECTOR_COLLECTION) + + yield collection + + if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: + db.delete_collection(TEST_WRITABLE_NONVECTOR_COLLECTION) + + +@pytest.fixture(scope="function") +def allowindex_nonv_collection(db: AstraDB) -> Iterable[AstraDBCollection]: + """ + This is lasting for the whole test. Functions can write to it, + no guarantee (i.e. each test should use a different ID... + """ + collection = db.create_collection( + TEST_WRITABLE_ALLOWINDEX_NONVECTOR_COLLECTION, + options={ + "indexing": { + "allow": [ + "A", + "C.a", + ], + }, + }, + ) + collection.upsert(INDEXING_SAMPLE_DOCUMENT) + + yield collection + + if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: + db.delete_collection(TEST_WRITABLE_ALLOWINDEX_NONVECTOR_COLLECTION) + + +@pytest.fixture(scope="function") +def denyindex_nonv_collection(db: AstraDB) -> Iterable[AstraDBCollection]: + """ + This is lasting for the whole test. Functions can write to it, + no guarantee (i.e. each test should use a different ID... + + Note in light of the sample document this almost results in the same + filtering paths being available ... if one remembers to deny _id here. + """ + collection = db.create_collection( + TEST_WRITABLE_DENYINDEX_NONVECTOR_COLLECTION, + options={ + "indexing": { + "deny": [ + "B", + "C.b", + "_id", + ], + }, + }, + ) + collection.upsert(INDEXING_SAMPLE_DOCUMENT) + + yield collection + + if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: + db.delete_collection(TEST_WRITABLE_DENYINDEX_NONVECTOR_COLLECTION) + + +@pytest.fixture(scope="function") +def empty_nonv_collection( + writable_nonv_collection: AstraDBCollection, +) -> Iterable[AstraDBCollection]: + """available empty to each test function.""" + writable_nonv_collection.clear() + yield writable_nonv_collection + + +@pytest.fixture(scope="module") +def invalid_writable_v_collection( + invalid_db: AstraDB, +) -> Iterable[AstraDBCollection]: + collection = invalid_db.collection( + TEST_WRITABLE_VECTOR_COLLECTION, + ) + + yield collection + + +@pytest.fixture(scope="function") +def pagination_v_collection( + empty_v_collection: AstraDBCollection, +) -> Iterable[AstraDBCollection]: + INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints + N = 200 # must be EVEN + + def _mk_vector(index: int, n_total_steps: int) -> List[float]: + angle = 2 * math.pi * index / n_total_steps + return [math.cos(angle), math.sin(angle)] + + inserted_ids: Set[str] = set() + for i_batch in _batch_iterable(range(N), INSERT_BATCH_SIZE): + batch_ids = empty_v_collection.insert_many( + documents=[{"_id": str(i), "$vector": _mk_vector(i, N)} for i in i_batch] + )["status"]["insertedIds"] + inserted_ids = inserted_ids | set(batch_ids) + assert inserted_ids == {str(i) for i in range(N)} + + yield empty_v_collection + + +@pytest_asyncio.fixture(scope="function") +async def async_readonly_v_collection( + async_db: AsyncAstraDB, + readonly_v_collection: AstraDBCollection, +) -> AsyncIterable[AsyncAstraDBCollection]: + """ + This fixture piggybacks on its sync counterpart (and depends on it): + it must not actually do anything to the collection + """ + collection = await async_db.collection(TEST_READONLY_VECTOR_COLLECTION) + + yield collection + + +@pytest_asyncio.fixture(scope="function") +async def async_writable_v_collection( + async_db: AsyncAstraDB, + writable_v_collection: AstraDBCollection, +) -> AsyncIterable[AsyncAstraDBCollection]: + """ + This fixture piggybacks on its sync counterpart (and depends on it): + it must not actually do anything to the collection + """ + collection = await async_db.collection(TEST_WRITABLE_VECTOR_COLLECTION) + + yield collection + + +@pytest_asyncio.fixture(scope="function") +async def async_empty_v_collection( + async_writable_v_collection: AsyncAstraDBCollection, + empty_v_collection: AstraDBCollection, +) -> AsyncIterable[AsyncAstraDBCollection]: + """ + available empty to each test function. + + This fixture piggybacks on its sync counterpart (and depends on it): + it must not actually do anything to the collection + """ + yield async_writable_v_collection + + +@pytest_asyncio.fixture(scope="function") +async def async_disposable_v_collection( + async_writable_v_collection: AsyncAstraDBCollection, + disposable_v_collection: AstraDBCollection, +) -> AsyncIterable[AsyncAstraDBCollection]: + """ + available prepopulated to each test function. + + This fixture piggybacks on its sync counterpart (and depends on it): + it must not actually do anything to the collection + """ + yield async_writable_v_collection + + +@pytest_asyncio.fixture(scope="function") +async def async_writable_nonv_collection( + async_db: AsyncAstraDB, + writable_nonv_collection: AstraDBCollection, +) -> AsyncIterable[AsyncAstraDBCollection]: + """ + This fixture piggybacks on its sync counterpart (and depends on it): + it must not actually do anything to the collection + """ + collection = await async_db.collection(TEST_WRITABLE_NONVECTOR_COLLECTION) + + yield collection + + +@pytest_asyncio.fixture(scope="function") +async def async_empty_nonv_collection( + async_writable_nonv_collection: AsyncAstraDBCollection, + empty_nonv_collection: AstraDBCollection, +) -> AsyncIterable[AsyncAstraDBCollection]: + """ + available empty to each test function. + + This fixture piggybacks on its sync counterpart (and depends on it): + it must not actually do anything to the collection + """ + yield async_writable_nonv_collection + + +@pytest_asyncio.fixture(scope="function") +async def async_pagination_v_collection( + async_empty_v_collection: AsyncAstraDBCollection, + pagination_v_collection: AstraDBCollection, +) -> AsyncIterable[AsyncAstraDBCollection]: + """ + This fixture piggybacks on its sync counterpart (and depends on it): + it must not actually do anything to the collection + """ + yield async_empty_v_collection diff --git a/tests/astrapy/test_async_db_ddl.py b/tests/astrapy/test_async_db_ddl.py index 1bdb6f0e..6fd0f033 100644 --- a/tests/astrapy/test_async_db_ddl.py +++ b/tests/astrapy/test_async_db_ddl.py @@ -18,10 +18,10 @@ import os import logging -from typing import Dict, Optional import pytest +from ..conftest import AstraDBCredentials from astrapy.db import AsyncAstraDB, AsyncAstraDBCollection from astrapy.defaults import DEFAULT_KEYSPACE_NAME @@ -33,7 +33,7 @@ @pytest.mark.describe("should confirm path handling in constructor (async)") async def test_path_handling( - astra_db_credentials_kwargs: Dict[str, Optional[str]] + astra_db_credentials_kwargs: AstraDBCredentials, ) -> None: token = astra_db_credentials_kwargs["token"] api_endpoint = astra_db_credentials_kwargs["api_endpoint"] diff --git a/tests/astrapy/test_conversions.py b/tests/astrapy/test_conversions.py index 96203b95..bc6e0a3e 100644 --- a/tests/astrapy/test_conversions.py +++ b/tests/astrapy/test_conversions.py @@ -174,6 +174,173 @@ def test_copy_methods() -> None: assert c_adb_ops is not adb_ops +@pytest.mark.describe("test parameter override in copy methods") +def test_parameter_override_copy_methods() -> None: + sync_astradb = AstraDB( + token="token", + api_endpoint="api_endpoint", + api_path="api_path", + api_version="api_version", + namespace="namespace", + caller_name="caller_name", + caller_version="caller_version", + ) + sync_astradb2 = AstraDB( + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_sync_astradb = sync_astradb.copy( + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_sync_astradb == sync_astradb2 + + async_astradb = AsyncAstraDB( + token="token", + api_endpoint="api_endpoint", + api_path="api_path", + api_version="api_version", + namespace="namespace", + caller_name="caller_name", + caller_version="caller_version", + ) + async_astradb2 = AsyncAstraDB( + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_async_astradb = async_astradb.copy( + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_async_astradb == async_astradb2 + + sync_adbcollection = AstraDBCollection( + collection_name="collection_name", + astra_db=sync_astradb, + caller_name="caller_name", + caller_version="caller_version", + ) + sync_adbcollection2 = AstraDBCollection( + collection_name="collection_name2", + astra_db=sync_astradb2, + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_sync_adbcollection = sync_adbcollection.copy( + collection_name="collection_name2", + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_sync_adbcollection == sync_adbcollection2 + + async_adbcollection = AsyncAstraDBCollection( + collection_name="collection_name", + astra_db=async_astradb, + caller_name="caller_name", + caller_version="caller_version", + ) + async_adbcollection2 = AsyncAstraDBCollection( + collection_name="collection_name2", + astra_db=async_astradb2, + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_async_adbcollection = async_adbcollection.copy( + collection_name="collection_name2", + token="token2", + api_endpoint="api_endpoint2", + api_path="api_path2", + api_version="api_version2", + namespace="namespace2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_async_adbcollection == async_adbcollection2 + + adb_ops = AstraDBOps( + token="token", + dev_ops_url="dev_ops_url", + dev_ops_api_version="dev_ops_api_version", + caller_name="caller_name", + caller_version="caller_version", + ) + adb_ops2 = AstraDBOps( + token="token2", + dev_ops_url="dev_ops_url2", + dev_ops_api_version="dev_ops_api_version2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + c_adb_ops = adb_ops.copy( + token="token2", + dev_ops_url="dev_ops_url2", + dev_ops_api_version="dev_ops_api_version2", + caller_name="caller_name2", + caller_version="caller_version2", + ) + assert c_adb_ops == adb_ops2 + + +@pytest.mark.describe("test parameter override when instantiating collections") +def test_parameter_override_collection_instances() -> None: + astradb0 = AstraDB(token="t0", api_endpoint="a0") + astradb1 = AstraDB(token="t1", api_endpoint="a1", namespace="n1") + col0 = AstraDBCollection( + collection_name="col0", + astra_db=astradb0, + ) + col1 = AstraDBCollection( + collection_name="col0", + astra_db=astradb0, + token="t1", + api_endpoint="a1", + namespace="n1", + ) + assert col0 != col1 + assert col1 == AstraDBCollection(collection_name="col0", astra_db=astradb1) + + a_astradb0 = AsyncAstraDB(token="t0", api_endpoint="a0") + a_astradb1 = AsyncAstraDB(token="t1", api_endpoint="a1", namespace="n1") + a_col0 = AsyncAstraDBCollection( + collection_name="col0", + astra_db=a_astradb0, + ) + a_col1 = AsyncAstraDBCollection( + collection_name="col0", + astra_db=a_astradb0, + token="t1", + api_endpoint="a1", + namespace="n1", + ) + assert a_col0 != a_col1 + assert a_col1 == AsyncAstraDBCollection(collection_name="col0", astra_db=a_astradb1) + + @pytest.mark.describe("test set_caller works in place for clients") def test_set_caller_clients() -> None: astradb0 = AstraDB(token="t1", api_endpoint="a1") diff --git a/tests/astrapy/test_db_ddl.py b/tests/astrapy/test_db_ddl.py index 0d752e0e..75b4a690 100644 --- a/tests/astrapy/test_db_ddl.py +++ b/tests/astrapy/test_db_ddl.py @@ -18,10 +18,10 @@ import os import logging -from typing import Dict, Optional import pytest +from ..conftest import AstraDBCredentials from astrapy.db import AstraDB, AstraDBCollection from astrapy.defaults import DEFAULT_KEYSPACE_NAME @@ -32,7 +32,7 @@ @pytest.mark.describe("should confirm path handling in constructor") -def test_path_handling(astra_db_credentials_kwargs: Dict[str, Optional[str]]) -> None: +def test_path_handling(astra_db_credentials_kwargs: AstraDBCredentials) -> None: token = astra_db_credentials_kwargs["token"] api_endpoint = astra_db_credentials_kwargs["api_endpoint"] namespace = astra_db_credentials_kwargs.get("namespace") diff --git a/tests/astrapy/test_db_dml.py b/tests/astrapy/test_db_dml.py index f519dc22..174a0ee2 100644 --- a/tests/astrapy/test_db_dml.py +++ b/tests/astrapy/test_db_dml.py @@ -39,6 +39,7 @@ def test_clear_collection_fail(db: AstraDB) -> None: with pytest.raises(APIRequestError): db.collection("this$does%not exist!!!").clear() + @pytest.mark.describe("should truncate a nonvector collection through AstraDB") def test_truncate_nonvector_collection_through_astradb( db: AstraDB, empty_nonv_collection: AstraDBCollection @@ -1267,12 +1268,6 @@ def test_find_find_one_non_equality_operators( ) assert resp8["data"]["documents"][0]["marker"] == "abc" -@pytest.mark.describe("test unsupported operation") -def test_unsupported_operation( - writable_v_collection: AstraDBCollection, -) -> None: - with pytest.raises(Exception): - writeable_v_collection.aggregate() @pytest.mark.describe("store and retrieve dates and datetimes correctly") def test_insert_find_with_dates( diff --git a/tests/conftest.py b/tests/conftest.py index a2948a42..812c32e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,98 +1,21 @@ -""" -Test fixtures -""" - +# main conftest for shared fixtures (if any). import os -import math - import pytest -from typing import ( - AsyncIterable, - Dict, - Iterable, - List, - Optional, - Set, - TypeVar, - TypedDict, -) - -import pytest_asyncio +from typing import Optional, TypedDict from astrapy.defaults import DEFAULT_KEYSPACE_NAME -from astrapy.db import AstraDB, AstraDBCollection, AsyncAstraDB, AsyncAstraDBCollection - -T = TypeVar("T") - - -ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"] -ASTRA_DB_API_ENDPOINT = os.environ["ASTRA_DB_API_ENDPOINT"] - -ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE", DEFAULT_KEYSPACE_NAME) - -# fixed -TEST_WRITABLE_VECTOR_COLLECTION = "writable_v_col" -TEST_READONLY_VECTOR_COLLECTION = "readonly_v_col" -TEST_WRITABLE_NONVECTOR_COLLECTION = "writable_nonv_col" -TEST_WRITABLE_ALLOWINDEX_NONVECTOR_COLLECTION = "writable_allowindex_nonv_col" -TEST_WRITABLE_DENYINDEX_NONVECTOR_COLLECTION = "writable_denyindex_nonv_col" - -VECTOR_DOCUMENTS = [ - { - "_id": "1", - "text": "Sample entry number <1>", - "otherfield": {"subfield": "x1y"}, - "anotherfield": "alpha", - "$vector": [0.1, 0.9], - }, - { - "_id": "2", - "text": "Sample entry number <2>", - "otherfield": {"subfield": "x2y"}, - "anotherfield": "alpha", - "$vector": [0.5, 0.5], - }, - { - "_id": "3", - "text": "Sample entry number <3>", - "otherfield": {"subfield": "x3y"}, - "anotherfield": "omega", - "$vector": [0.9, 0.1], - }, -] -INDEXING_SAMPLE_DOCUMENT = { - "_id": "0", - "A": { - "a": "A.a", - "b": "A.b", - }, - "B": { - "a": "B.a", - "b": "B.b", - }, - "C": { - "a": "C.a", - "b": "C.b", - }, -} - -class AstraDBCredentials(TypedDict, total=False): +class AstraDBCredentials(TypedDict): token: str api_endpoint: str namespace: Optional[str] -def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]: - this_batch = [] - for entry in iterable: - this_batch.append(entry) - if len(this_batch) == batch_size: - yield this_batch - this_batch = [] - if this_batch: - yield this_batch +ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"] +ASTRA_DB_API_ENDPOINT = os.environ["ASTRA_DB_API_ENDPOINT"] + +ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE", DEFAULT_KEYSPACE_NAME) @pytest.fixture(scope="session") @@ -115,304 +38,3 @@ def astra_invalid_db_credentials_kwargs() -> AstraDBCredentials: } return astra_db_creds - - -@pytest.fixture(scope="session") -def db(astra_db_credentials_kwargs: Dict[str, Optional[str]]) -> AstraDB: - token = astra_db_credentials_kwargs["token"] - api_endpoint = astra_db_credentials_kwargs["api_endpoint"] - namespace = astra_db_credentials_kwargs.get("namespace") - - if token is None or api_endpoint is None: - raise ValueError("Required ASTRA DB configuration is missing") - - return AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace) - - -@pytest_asyncio.fixture(scope="function") -async def async_db( - astra_db_credentials_kwargs: Dict[str, Optional[str]] -) -> AsyncIterable[AsyncAstraDB]: - token = astra_db_credentials_kwargs["token"] - api_endpoint = astra_db_credentials_kwargs["api_endpoint"] - namespace = astra_db_credentials_kwargs.get("namespace") - - if token is None or api_endpoint is None: - raise ValueError("Required ASTRA DB configuration is missing") - - async with AsyncAstraDB( - token=token, api_endpoint=api_endpoint, namespace=namespace - ) as db: - yield db - - -@pytest.fixture(scope="module") -def invalid_db( - astra_invalid_db_credentials_kwargs: Dict[str, Optional[str]] -) -> AstraDB: - token = astra_invalid_db_credentials_kwargs["token"] - api_endpoint = astra_invalid_db_credentials_kwargs["api_endpoint"] - namespace = astra_invalid_db_credentials_kwargs.get("namespace") - - if token is None or api_endpoint is None: - raise ValueError("Required ASTRA DB configuration is missing") - - return AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace) - - -@pytest.fixture(scope="session") -def readonly_v_collection(db: AstraDB) -> Iterable[AstraDBCollection]: - collection = db.create_collection( - TEST_READONLY_VECTOR_COLLECTION, - dimension=2, - ) - - collection.clear() - collection.insert_many(VECTOR_DOCUMENTS) - - yield collection - - if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: - db.delete_collection(TEST_READONLY_VECTOR_COLLECTION) - - -@pytest.fixture(scope="session") -def writable_v_collection(db: AstraDB) -> Iterable[AstraDBCollection]: - """ - This is lasting for the whole test. Functions can write to it, - no guarantee (i.e. each test should use a different ID... - """ - collection = db.create_collection( - TEST_WRITABLE_VECTOR_COLLECTION, - dimension=2, - ) - - yield collection - - if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: - db.delete_collection(TEST_WRITABLE_VECTOR_COLLECTION) - - -@pytest.fixture(scope="function") -def empty_v_collection( - writable_v_collection: AstraDBCollection, -) -> Iterable[AstraDBCollection]: - """available empty to each test function.""" - writable_v_collection.clear() - yield writable_v_collection - - -@pytest.fixture(scope="function") -def disposable_v_collection( - writable_v_collection: AstraDBCollection, -) -> Iterable[AstraDBCollection]: - """available prepopulated to each test function.""" - writable_v_collection.clear() - writable_v_collection.insert_many(VECTOR_DOCUMENTS) - yield writable_v_collection - - -@pytest.fixture(scope="session") -def writable_nonv_collection(db: AstraDB) -> Iterable[AstraDBCollection]: - """ - This is lasting for the whole test. Functions can write to it, - no guarantee (i.e. each test should use a different ID... - """ - collection = db.create_collection(TEST_WRITABLE_NONVECTOR_COLLECTION) - - yield collection - - if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: - db.delete_collection(TEST_WRITABLE_NONVECTOR_COLLECTION) - - -@pytest.fixture(scope="function") -def allowindex_nonv_collection(db: AstraDB) -> Iterable[AstraDBCollection]: - """ - This is lasting for the whole test. Functions can write to it, - no guarantee (i.e. each test should use a different ID... - """ - collection = db.create_collection( - TEST_WRITABLE_ALLOWINDEX_NONVECTOR_COLLECTION, - options={ - "indexing": { - "allow": [ - "A", - "C.a", - ], - }, - }, - ) - collection.upsert(INDEXING_SAMPLE_DOCUMENT) - - yield collection - - if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: - db.delete_collection(TEST_WRITABLE_ALLOWINDEX_NONVECTOR_COLLECTION) - - -@pytest.fixture(scope="function") -def denyindex_nonv_collection(db: AstraDB) -> Iterable[AstraDBCollection]: - """ - This is lasting for the whole test. Functions can write to it, - no guarantee (i.e. each test should use a different ID... - - Note in light of the sample document this almost results in the same - filtering paths being available ... if one remembers to deny _id here. - """ - collection = db.create_collection( - TEST_WRITABLE_DENYINDEX_NONVECTOR_COLLECTION, - options={ - "indexing": { - "deny": [ - "B", - "C.b", - "_id", - ], - }, - }, - ) - collection.upsert(INDEXING_SAMPLE_DOCUMENT) - - yield collection - - if int(os.getenv("TEST_SKIP_COLLECTION_DELETE", "0")) == 0: - db.delete_collection(TEST_WRITABLE_DENYINDEX_NONVECTOR_COLLECTION) - - -@pytest.fixture(scope="function") -def empty_nonv_collection( - writable_nonv_collection: AstraDBCollection, -) -> Iterable[AstraDBCollection]: - """available empty to each test function.""" - writable_nonv_collection.clear() - yield writable_nonv_collection - - -@pytest.fixture(scope="module") -def invalid_writable_v_collection( - invalid_db: AstraDB, -) -> Iterable[AstraDBCollection]: - collection = invalid_db.collection( - TEST_WRITABLE_VECTOR_COLLECTION, - ) - - yield collection - - -@pytest.fixture(scope="function") -def pagination_v_collection( - empty_v_collection: AstraDBCollection, -) -> Iterable[AstraDBCollection]: - INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints - N = 200 # must be EVEN - - def _mk_vector(index: int, n_total_steps: int) -> List[float]: - angle = 2 * math.pi * index / n_total_steps - return [math.cos(angle), math.sin(angle)] - - inserted_ids: Set[str] = set() - for i_batch in _batch_iterable(range(N), INSERT_BATCH_SIZE): - batch_ids = empty_v_collection.insert_many( - documents=[{"_id": str(i), "$vector": _mk_vector(i, N)} for i in i_batch] - )["status"]["insertedIds"] - inserted_ids = inserted_ids | set(batch_ids) - assert inserted_ids == {str(i) for i in range(N)} - - yield empty_v_collection - - -@pytest_asyncio.fixture(scope="function") -async def async_readonly_v_collection( - async_db: AsyncAstraDB, - readonly_v_collection: AstraDBCollection, -) -> AsyncIterable[AsyncAstraDBCollection]: - """ - This fixture piggybacks on its sync counterpart (and depends on it): - it must not actually do anything to the collection - """ - collection = await async_db.collection(TEST_READONLY_VECTOR_COLLECTION) - - yield collection - - -@pytest_asyncio.fixture(scope="function") -async def async_writable_v_collection( - async_db: AsyncAstraDB, - writable_v_collection: AstraDBCollection, -) -> AsyncIterable[AsyncAstraDBCollection]: - """ - This fixture piggybacks on its sync counterpart (and depends on it): - it must not actually do anything to the collection - """ - collection = await async_db.collection(TEST_WRITABLE_VECTOR_COLLECTION) - - yield collection - - -@pytest_asyncio.fixture(scope="function") -async def async_empty_v_collection( - async_writable_v_collection: AsyncAstraDBCollection, - empty_v_collection: AstraDBCollection, -) -> AsyncIterable[AsyncAstraDBCollection]: - """ - available empty to each test function. - - This fixture piggybacks on its sync counterpart (and depends on it): - it must not actually do anything to the collection - """ - yield async_writable_v_collection - - -@pytest_asyncio.fixture(scope="function") -async def async_disposable_v_collection( - async_writable_v_collection: AsyncAstraDBCollection, - disposable_v_collection: AstraDBCollection, -) -> AsyncIterable[AsyncAstraDBCollection]: - """ - available prepopulated to each test function. - - This fixture piggybacks on its sync counterpart (and depends on it): - it must not actually do anything to the collection - """ - yield async_writable_v_collection - - -@pytest_asyncio.fixture(scope="function") -async def async_writable_nonv_collection( - async_db: AsyncAstraDB, - writable_nonv_collection: AstraDBCollection, -) -> AsyncIterable[AsyncAstraDBCollection]: - """ - This fixture piggybacks on its sync counterpart (and depends on it): - it must not actually do anything to the collection - """ - collection = await async_db.collection(TEST_WRITABLE_NONVECTOR_COLLECTION) - - yield collection - - -@pytest_asyncio.fixture(scope="function") -async def async_empty_nonv_collection( - async_writable_nonv_collection: AsyncAstraDBCollection, - empty_nonv_collection: AstraDBCollection, -) -> AsyncIterable[AsyncAstraDBCollection]: - """ - available empty to each test function. - - This fixture piggybacks on its sync counterpart (and depends on it): - it must not actually do anything to the collection - """ - yield async_writable_nonv_collection - - -@pytest_asyncio.fixture(scope="function") -async def async_pagination_v_collection( - async_empty_v_collection: AsyncAstraDBCollection, - pagination_v_collection: AstraDBCollection, -) -> AsyncIterable[AsyncAstraDBCollection]: - """ - This fixture piggybacks on its sync counterpart (and depends on it): - it must not actually do anything to the collection - """ - yield async_empty_v_collection diff --git a/tests/idiomatic/__init__.py b/tests/idiomatic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/idiomatic/conftest.py b/tests/idiomatic/conftest.py new file mode 100644 index 00000000..50dec1d1 --- /dev/null +++ b/tests/idiomatic/conftest.py @@ -0,0 +1,54 @@ +"""Fixtures specific to the idiomatic-side testing, if any.""" + +from typing import Iterable +import pytest + +from ..conftest import AstraDBCredentials +from astrapy import AsyncCollection, AsyncDatabase, Collection, Database + +TEST_COLLECTION_NAME = "test_coll_sync" + + +@pytest.fixture(scope="session") +def sync_database( + astra_db_credentials_kwargs: AstraDBCredentials, +) -> Iterable[Database]: + yield Database(**astra_db_credentials_kwargs) + + +@pytest.fixture(scope="session") +def async_database( + astra_db_credentials_kwargs: AstraDBCredentials, +) -> Iterable[AsyncDatabase]: + yield AsyncDatabase(**astra_db_credentials_kwargs) + + +@pytest.fixture(scope="session") +def sync_collection( + astra_db_credentials_kwargs: AstraDBCredentials, + sync_database: Database, +) -> Iterable[Collection]: + yield Collection( + sync_database, + TEST_COLLECTION_NAME, + namespace=astra_db_credentials_kwargs["namespace"], + ) + + +@pytest.fixture(scope="session") +def async_collection( + astra_db_credentials_kwargs: AstraDBCredentials, + async_database: AsyncDatabase, +) -> Iterable[AsyncCollection]: + yield AsyncCollection( + async_database, + TEST_COLLECTION_NAME, + namespace=astra_db_credentials_kwargs["namespace"], + ) + + +__all__ = [ + "AstraDBCredentials", + "sync_database", + "async_database", +] diff --git a/tests/idiomatic/integration/__init__.py b/tests/idiomatic/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/idiomatic/integration/test_collections_async.py b/tests/idiomatic/integration/test_collections_async.py new file mode 100644 index 00000000..b547408e --- /dev/null +++ b/tests/idiomatic/integration/test_collections_async.py @@ -0,0 +1,115 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from astrapy import AsyncCollection, AsyncDatabase + + +class TestCollectionsAsync: + @pytest.mark.describe("test of instantiating Collection, async") + async def test_instantiate_collection_async( + self, + async_database: AsyncDatabase, + ) -> None: + col1 = AsyncCollection( + async_database, + "id_test_collection", + caller_name="c_n", + caller_version="c_v", + ) + col2 = AsyncCollection( + async_database, + "id_test_collection", + caller_name="c_n", + caller_version="c_v", + ) + assert col1 == col2 + + @pytest.mark.describe("test of Collection conversions, async") + async def test_convert_collection_async( + self, + async_database: AsyncDatabase, + ) -> None: + col1 = AsyncCollection( + async_database, + "id_test_collection", + caller_name="c_n", + caller_version="c_v", + ) + assert col1 == col1.copy() + assert col1 == col1.to_sync().to_async() + + @pytest.mark.describe("test of Collection set_caller, async") + async def test_collection_set_caller_async( + self, + async_database: AsyncDatabase, + ) -> None: + col1 = AsyncCollection( + async_database, + "id_test_collection", + caller_name="c_n1", + caller_version="c_v1", + ) + col2 = AsyncCollection( + async_database, + "id_test_collection", + caller_name="c_n2", + caller_version="c_v2", + ) + col2.set_caller( + caller_name="c_n1", + caller_version="c_v1", + ) + assert col1 == col2 + + @pytest.mark.describe("test errors for unsupported Collection methods, async") + async def test_collection_unsupported_methods_async( + self, + async_collection: AsyncCollection, + ) -> None: + with pytest.raises(TypeError): + await async_collection.find_raw_batches(1, "x") + with pytest.raises(TypeError): + await async_collection.aggregate(1, "x") + with pytest.raises(TypeError): + await async_collection.aggregate_raw_batches(1, "x") + with pytest.raises(TypeError): + await async_collection.watch(1, "x") + with pytest.raises(TypeError): + await async_collection.rename(1, "x") + with pytest.raises(TypeError): + await async_collection.create_index(1, "x") + with pytest.raises(TypeError): + await async_collection.create_indexes(1, "x") + with pytest.raises(TypeError): + await async_collection.drop_index(1, "x") + with pytest.raises(TypeError): + await async_collection.drop_indexes(1, "x") + with pytest.raises(TypeError): + await async_collection.list_indexes(1, "x") + with pytest.raises(TypeError): + await async_collection.index_information(1, "x") + with pytest.raises(TypeError): + await async_collection.create_search_index(1, "x") + with pytest.raises(TypeError): + await async_collection.create_search_indexes(1, "x") + with pytest.raises(TypeError): + await async_collection.drop_search_index(1, "x") + with pytest.raises(TypeError): + await async_collection.list_search_indexes(1, "x") + with pytest.raises(TypeError): + await async_collection.update_search_index(1, "x") + with pytest.raises(TypeError): + await async_collection.distinct(1, "x") diff --git a/tests/idiomatic/integration/test_collections_sync.py b/tests/idiomatic/integration/test_collections_sync.py new file mode 100644 index 00000000..7eb5266f --- /dev/null +++ b/tests/idiomatic/integration/test_collections_sync.py @@ -0,0 +1,115 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from astrapy import Collection, Database + + +class TestCollectionsSync: + @pytest.mark.describe("test of instantiating Collection, sync") + def test_instantiate_collection_sync( + self, + sync_database: Database, + ) -> None: + col1 = Collection( + sync_database, + "id_test_collection", + caller_name="c_n", + caller_version="c_v", + ) + col2 = Collection( + sync_database, + "id_test_collection", + caller_name="c_n", + caller_version="c_v", + ) + assert col1 == col2 + + @pytest.mark.describe("test of Collection conversions, sync") + def test_convert_collection_sync( + self, + sync_database: Database, + ) -> None: + col1 = Collection( + sync_database, + "id_test_collection", + caller_name="c_n", + caller_version="c_v", + ) + assert col1 == col1.copy() + assert col1 == col1.to_async().to_sync() + + @pytest.mark.describe("test of Collection set_caller, sync") + def test_collection_set_caller_sync( + self, + sync_database: Database, + ) -> None: + col1 = Collection( + sync_database, + "id_test_collection", + caller_name="c_n1", + caller_version="c_v1", + ) + col2 = Collection( + sync_database, + "id_test_collection", + caller_name="c_n2", + caller_version="c_v2", + ) + col2.set_caller( + caller_name="c_n1", + caller_version="c_v1", + ) + assert col1 == col2 + + @pytest.mark.describe("test errors for unsupported Collection methods, sync") + def test_collection_unsupported_methods_sync( + self, + sync_collection: Collection, + ) -> None: + with pytest.raises(TypeError): + sync_collection.find_raw_batches(1, "x") + with pytest.raises(TypeError): + sync_collection.aggregate(1, "x") + with pytest.raises(TypeError): + sync_collection.aggregate_raw_batches(1, "x") + with pytest.raises(TypeError): + sync_collection.watch(1, "x") + with pytest.raises(TypeError): + sync_collection.rename(1, "x") + with pytest.raises(TypeError): + sync_collection.create_index(1, "x") + with pytest.raises(TypeError): + sync_collection.create_indexes(1, "x") + with pytest.raises(TypeError): + sync_collection.drop_index(1, "x") + with pytest.raises(TypeError): + sync_collection.drop_indexes(1, "x") + with pytest.raises(TypeError): + sync_collection.list_indexes(1, "x") + with pytest.raises(TypeError): + sync_collection.index_information(1, "x") + with pytest.raises(TypeError): + sync_collection.create_search_index(1, "x") + with pytest.raises(TypeError): + sync_collection.create_search_indexes(1, "x") + with pytest.raises(TypeError): + sync_collection.drop_search_index(1, "x") + with pytest.raises(TypeError): + sync_collection.list_search_indexes(1, "x") + with pytest.raises(TypeError): + sync_collection.update_search_index(1, "x") + with pytest.raises(TypeError): + sync_collection.distinct(1, "x") diff --git a/tests/idiomatic/integration/test_databases_async.py b/tests/idiomatic/integration/test_databases_async.py new file mode 100644 index 00000000..700b799f --- /dev/null +++ b/tests/idiomatic/integration/test_databases_async.py @@ -0,0 +1,105 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from ..conftest import AstraDBCredentials, TEST_COLLECTION_NAME +from astrapy import AsyncCollection, AsyncDatabase + + +class TestDatabasesAsync: + @pytest.mark.describe("test of instantiating Database, async") + async def test_instantiate_database_async( + self, + astra_db_credentials_kwargs: AstraDBCredentials, + ) -> None: + db1 = AsyncDatabase( + caller_name="c_n", + caller_version="c_v", + **astra_db_credentials_kwargs, + ) + db2 = AsyncDatabase( + caller_name="c_n", + caller_version="c_v", + **astra_db_credentials_kwargs, + ) + assert db1 == db2 + + @pytest.mark.describe("test of Database conversions, async") + async def test_convert_database_async( + self, + astra_db_credentials_kwargs: AstraDBCredentials, + ) -> None: + db1 = AsyncDatabase( + caller_name="c_n", + caller_version="c_v", + **astra_db_credentials_kwargs, + ) + assert db1 == db1.copy() + assert db1 == db1.to_sync().to_async() + + @pytest.mark.describe("test of Database set_caller, async") + async def test_database_set_caller_async( + self, + astra_db_credentials_kwargs: AstraDBCredentials, + ) -> None: + db1 = AsyncDatabase( + caller_name="c_n1", + caller_version="c_v1", + **astra_db_credentials_kwargs, + ) + db2 = AsyncDatabase( + caller_name="c_n2", + caller_version="c_v2", + **astra_db_credentials_kwargs, + ) + db2.set_caller( + caller_name="c_n1", + caller_version="c_v1", + ) + assert db1 == db2 + + @pytest.mark.describe("test errors for unsupported Database methods, async") + async def test_database_unsupported_methods_async( + self, + async_database: AsyncDatabase, + ) -> None: + with pytest.raises(TypeError): + await async_database.aggregate(1, "x") + with pytest.raises(TypeError): + await async_database.cursor_command(1, "x") + with pytest.raises(TypeError): + await async_database.dereference(1, "x") + with pytest.raises(TypeError): + await async_database.watch(1, "x") + with pytest.raises(TypeError): + await async_database.validate_collection(1, "x") + + @pytest.mark.describe("test get_collection method, async") + async def test_database_get_collection_async( + self, + async_database: AsyncDatabase, + async_collection: AsyncCollection, + ) -> None: + collection = await async_database.get_collection(TEST_COLLECTION_NAME) + assert collection == async_collection + + NAMESPACE_2 = "other_namespace" + collection_ns2 = await async_database.get_collection( + TEST_COLLECTION_NAME, namespace=NAMESPACE_2 + ) + assert collection_ns2 == AsyncCollection( + async_database, TEST_COLLECTION_NAME, namespace=NAMESPACE_2 + ) + assert collection_ns2._astra_db_collection.astra_db.namespace == NAMESPACE_2 diff --git a/tests/idiomatic/integration/test_databases_sync.py b/tests/idiomatic/integration/test_databases_sync.py new file mode 100644 index 00000000..e300186a --- /dev/null +++ b/tests/idiomatic/integration/test_databases_sync.py @@ -0,0 +1,106 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from ..conftest import AstraDBCredentials, TEST_COLLECTION_NAME +from astrapy import Collection, Database + + +class TestDatabasesSync: + @pytest.mark.describe("test of instantiating Database, sync") + def test_instantiate_database_sync( + self, + astra_db_credentials_kwargs: AstraDBCredentials, + ) -> None: + db1 = Database( + caller_name="c_n", + caller_version="c_v", + **astra_db_credentials_kwargs, + ) + db2 = Database( + caller_name="c_n", + caller_version="c_v", + **astra_db_credentials_kwargs, + ) + assert db1 == db2 + + @pytest.mark.describe("test of Database conversions, sync") + def test_convert_database_sync( + self, + astra_db_credentials_kwargs: AstraDBCredentials, + ) -> None: + db1 = Database( + caller_name="c_n", + caller_version="c_v", + **astra_db_credentials_kwargs, + ) + assert db1 == db1.copy() + assert db1 == db1.to_async().to_sync() + + @pytest.mark.describe("test of Database set_caller, sync") + def test_database_set_caller_sync( + self, + astra_db_credentials_kwargs: AstraDBCredentials, + ) -> None: + db1 = Database( + caller_name="c_n1", + caller_version="c_v1", + **astra_db_credentials_kwargs, + ) + db2 = Database( + caller_name="c_n2", + caller_version="c_v2", + **astra_db_credentials_kwargs, + ) + db2.set_caller( + caller_name="c_n1", + caller_version="c_v1", + ) + assert db1 == db2 + + @pytest.mark.describe("test errors for unsupported Database methods, sync") + def test_database_unsupported_methods_sync( + self, + sync_database: Database, + ) -> None: + with pytest.raises(TypeError): + sync_database.aggregate(1, "x") + with pytest.raises(TypeError): + sync_database.cursor_command(1, "x") + with pytest.raises(TypeError): + sync_database.dereference(1, "x") + with pytest.raises(TypeError): + sync_database.watch(1, "x") + with pytest.raises(TypeError): + sync_database.validate_collection(1, "x") + + @pytest.mark.describe("test get_collection method, sync") + def test_database_get_collection_sync( + self, + sync_database: Database, + sync_collection: Collection, + astra_db_credentials_kwargs: AstraDBCredentials, + ) -> None: + collection = sync_database.get_collection(TEST_COLLECTION_NAME) + assert collection == sync_collection + + NAMESPACE_2 = "other_namespace" + collection_ns2 = sync_database.get_collection( + TEST_COLLECTION_NAME, namespace=NAMESPACE_2 + ) + assert collection_ns2 == Collection( + sync_database, TEST_COLLECTION_NAME, namespace=NAMESPACE_2 + ) + assert collection_ns2._astra_db_collection.astra_db.namespace == NAMESPACE_2 diff --git a/tests/idiomatic/integration/test_ddl_async.py b/tests/idiomatic/integration/test_ddl_async.py new file mode 100644 index 00000000..69f67631 --- /dev/null +++ b/tests/idiomatic/integration/test_ddl_async.py @@ -0,0 +1,35 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from astrapy import AsyncDatabase + + +class TestDDLAsync: + @pytest.mark.describe("test of collection creation, get, and then drop, async") + async def test_collection_lifecycle_async( + self, + async_database: AsyncDatabase, + ) -> None: + TEST_COLLECTION_NAME = "test_coll" + col1 = await async_database.create_collection( + TEST_COLLECTION_NAME, + dimension=123, + metric="euclidean", + indexing={"deny": ["a", "b", "c"]}, + ) + col2 = await async_database.get_collection(TEST_COLLECTION_NAME) + assert col1 == col2 + await async_database.drop_collection(TEST_COLLECTION_NAME) diff --git a/tests/idiomatic/integration/test_ddl_sync.py b/tests/idiomatic/integration/test_ddl_sync.py new file mode 100644 index 00000000..94937dcd --- /dev/null +++ b/tests/idiomatic/integration/test_ddl_sync.py @@ -0,0 +1,35 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from astrapy import Database + + +class TestDDLSync: + @pytest.mark.describe("test of collection creation, get, and then drop, sync") + def test_collection_lifecycle_sync( + self, + sync_database: Database, + ) -> None: + TEST_COLLECTION_NAME = "test_coll" + col1 = sync_database.create_collection( + TEST_COLLECTION_NAME, + dimension=123, + metric="euclidean", + indexing={"deny": ["a", "b", "c"]}, + ) + col2 = sync_database.get_collection(TEST_COLLECTION_NAME) + assert col1 == col2 + sync_database.drop_collection(TEST_COLLECTION_NAME) diff --git a/tests/idiomatic/unit/__init__.py b/tests/idiomatic/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/idiomatic/unit/test_collections.py b/tests/idiomatic/unit/test_collections.py new file mode 100644 index 00000000..9db7514f --- /dev/null +++ b/tests/idiomatic/unit/test_collections.py @@ -0,0 +1,21 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +class TestCollectionsUnit: + @pytest.mark.describe("test placeholder, unit tests on Collection") + def test_collection_unit_placeholder(self) -> None: + assert True diff --git a/tests/idiomatic/unit/test_databases.py b/tests/idiomatic/unit/test_databases.py new file mode 100644 index 00000000..8abf66b7 --- /dev/null +++ b/tests/idiomatic/unit/test_databases.py @@ -0,0 +1,21 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +class TestDatabasesUnit: + @pytest.mark.describe("test placeholder, unit tests on Database") + def test_database_unit_placeholder(self) -> None: + assert True