Skip to content

Commit

Permalink
Split classes, modules, tests to keep the "idiomatic" layer well sepa…
Browse files Browse the repository at this point in the history
…rate and not touch the "astrapy" layer (#222)

* astrapy/idiomatic split structure: modules, imports, tests

* completed refactor into new structure

* complete refactoring into split classes, incl. tests. Not many methods implemented yet

* astrapy classes allow parameter override in their copy methods

* collection constructors does apply parameter overrides even when astra_db passed

* wip for the optional namespace when creating Collections

* create_collection, get_collection, drop_collection, Collection constructor
  • Loading branch information
hemidactylus authored Feb 28, 2024
1 parent 84e5e03 commit ea77276
Show file tree
Hide file tree
Showing 27 changed files with 1,952 additions and 502 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -414,3 +414,14 @@ To enable the `AstraDBOps` testing (off by default):
```bash
TEST_ASTRADBOPS=1 poetry run pytest [...]
```

To separately test the "astrapy proper" vs. the "idiomatic" part, and/or only the unit/integration part of the latter:

```
poetry run pytest tests/astrapy
poetry run pytest tests/idiomatic
poetry run pytest tests/idiomatic/unit
poetry run pytest tests/idiomatic/integration
```

(the above can be combined with the options seen earlier, where it makes sense).
24 changes: 19 additions & 5 deletions astrapy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
import os
import importlib.metadata

from typing import Any


def get_version() -> Any:
def get_version() -> str:
try:
# Poetry will create a __version__ attribute in the package's __init__.py file
return importlib.metadata.version(__package__)
Expand All @@ -38,11 +36,27 @@ def get_version() -> Any:
pyproject_data = toml.loads(file_contents)

# Return the version from the poetry section
return pyproject_data["tool"]["poetry"]["version"]
return str(pyproject_data["tool"]["poetry"]["version"])

# If the pyproject.toml file does not exist or the version is not found, return unknown
except (FileNotFoundError, KeyError):
return "unknown"


__version__ = get_version()
__version__: str = get_version()

# There's a circular-import issue to heal here to bring this to top
from astrapy.idiomatic import ( # noqa: E402
AsyncCollection,
AsyncDatabase,
Collection,
Database,
)

__all__ = [
"AsyncCollection",
"AsyncDatabase",
"Collection",
"Database",
"__version__",
]
203 changes: 106 additions & 97 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
http_methods,
normalize_for_api,
restore_from_api,
return_unsupported_error
)
from astrapy.types import (
API_DOC,
Expand Down Expand Up @@ -105,11 +104,20 @@ def __init__(
caller_name=caller_name,
caller_version=caller_version,
)
else:
# if astra_db passed, copy and apply possible overrides
astra_db = astra_db.copy(
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
)

# Set the remaining instance attributes
self.astra_db = astra_db
self.caller_name = caller_name or self.astra_db.caller_name
self.caller_version = caller_version or self.astra_db.caller_version
self.caller_name = self.astra_db.caller_name
self.caller_version = self.astra_db.caller_version
self.collection_name = collection_name
self.base_path = f"{self.astra_db.base_path}/{self.collection_name}"

Expand All @@ -129,12 +137,31 @@ def __eq__(self, other: Any) -> bool:
else:
return False

def copy(self) -> AstraDBCollection:
def copy(
self,
*,
collection_name: Optional[str] = None,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AstraDBCollection:
return AstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db.copy(),
caller_name=self.caller_name,
caller_version=self.caller_version,
collection_name=collection_name or self.collection_name,
astra_db=self.astra_db.copy(
token=token,
api_endpoint=api_endpoint,
api_path=api_path,
api_version=api_version,
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
),
caller_name=caller_name or self.caller_name,
caller_version=caller_version or self.caller_version,
)

def to_async(self) -> AsyncAstraDBCollection:
Expand Down Expand Up @@ -1093,11 +1120,20 @@ def __init__(
caller_name=caller_name,
caller_version=caller_version,
)
else:
# if astra_db passed, copy and apply possible overrides
astra_db = astra_db.copy(
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
)

# Set the remaining instance attributes
self.astra_db: AsyncAstraDB = astra_db
self.caller_name = caller_name or self.astra_db.caller_name
self.caller_version = caller_version or self.astra_db.caller_version
self.caller_name = self.astra_db.caller_name
self.caller_version = self.astra_db.caller_version
self.client = astra_db.client
self.collection_name = collection_name
self.base_path = f"{self.astra_db.base_path}/{self.collection_name}"
Expand All @@ -1118,12 +1154,31 @@ def __eq__(self, other: Any) -> bool:
else:
return False

def copy(self) -> AsyncAstraDBCollection:
def copy(
self,
*,
collection_name: Optional[str] = None,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AsyncAstraDBCollection:
return AsyncAstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db.copy(),
caller_name=self.caller_name,
caller_version=self.caller_version,
collection_name=collection_name or self.collection_name,
astra_db=self.astra_db.copy(
token=token,
api_endpoint=api_endpoint,
api_path=api_path,
api_version=api_version,
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
),
caller_name=caller_name or self.caller_name,
caller_version=caller_version or self.caller_version,
)

def set_caller(
Expand Down Expand Up @@ -1991,58 +2046,6 @@ async def concurrent_upsert(doc: API_DOC) -> str:
if isinstance(result, BaseException) and not isinstance(result, Exception):
raise result
return results # type: ignore

# Mongodb calls not supported by the API
async def find_raw_batches():
return_unsupported_error()

async def aggregate():
return_unsupported_error()

async def aggregate_raw_batches():
return_unsupported_error()

async def watch():
return_unsupported_error()

async def rename():
return_unsupported_error()

async def create_index():
return_unsupported_error()

async def create_indexes():
return_unsupported_error()

async def drop_index():
return_unsupported_error()

async def drop_indexes():
return_unsupported_error()

async def list_indexes():
return_unsupported_error()

async def index_information():
return_unsupported_error()

async def create_search_index():
return_unsupported_error()

async def create_search_indexes():
return_unsupported_error()

async def drop_search_index():
return_unsupported_error()

async def list_search_indexes():
return_unsupported_error()

async def update_search_index():
return_unsupported_error()

async def distinct():
return_unsupported_error()


class AstraDB:
Expand Down Expand Up @@ -2116,15 +2119,25 @@ def __eq__(self, other: Any) -> bool:
else:
return False

def copy(self) -> AstraDB:
def copy(
self,
*,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AstraDB:
return AstraDB(
token=self.token,
api_endpoint=self.base_url,
api_path=self.api_path,
api_version=self.api_version,
namespace=self.namespace,
caller_name=self.caller_name,
caller_version=self.caller_version,
token=token or self.token,
api_endpoint=api_endpoint or self.base_url,
api_path=api_path or self.api_path,
api_version=api_version or self.api_version,
namespace=namespace or self.namespace,
caller_name=caller_name or self.caller_name,
caller_version=caller_version or self.caller_version,
)

def to_async(self) -> AsyncAstraDB:
Expand Down Expand Up @@ -2321,20 +2334,6 @@ def truncate_collection(self, collection_name: str) -> AstraDBCollection:
# return the collection itself
return collection

def aggregate():
return_unsupported_error()

def cursor_command():
return_unsupported_error()

def dereference():
return_unsupported_error()

def watch():
return_unsupported_error()

def validate_collection():
return_unsupported_error()

class AsyncAstraDB:
def __init__(
Expand Down Expand Up @@ -2416,15 +2415,25 @@ async def __aexit__(
) -> None:
await self.client.aclose()

def copy(self) -> AsyncAstraDB:
def copy(
self,
*,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> AsyncAstraDB:
return AsyncAstraDB(
token=self.token,
api_endpoint=self.base_url,
api_path=self.api_path,
api_version=self.api_version,
namespace=self.namespace,
caller_name=self.caller_name,
caller_version=self.caller_version,
token=token or self.token,
api_endpoint=api_endpoint or self.base_url,
api_path=api_path or self.api_path,
api_version=api_version or self.api_version,
namespace=namespace or self.namespace,
caller_name=caller_name or self.caller_name,
caller_version=caller_version or self.caller_version,
)

def to_sync(self) -> AstraDB:
Expand Down
23 changes: 23 additions & 0 deletions astrapy/idiomatic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from astrapy.idiomatic.collection import AsyncCollection, Collection
from astrapy.idiomatic.database import AsyncDatabase, Database

__all__ = [
"AsyncCollection",
"Collection",
"AsyncDatabase",
"Database",
]
Loading

0 comments on commit ea77276

Please sign in to comment.