Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/175-warning-ordered
Browse files Browse the repository at this point in the history
  • Loading branch information
hemidactylus authored Feb 8, 2024
2 parents 9321ac0 + da06717 commit 3c2a726
Show file tree
Hide file tree
Showing 3 changed files with 526 additions and 15 deletions.
210 changes: 198 additions & 12 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def __init__(
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> None:
"""
Initialize an AstraDBCollection instance.
Expand All @@ -85,29 +87,75 @@ def __init__(
token (str, optional): Authentication token for Astra DB.
api_endpoint (str, optional): API endpoint URL.
namespace (str, optional): Namespace for the database.
caller_name (str, optional): identity of the caller ("my_framework")
If passing a client, its caller is used as fallback
caller_version (str, optional): version of the caller code ("1.0.3")
If passing a client, its caller is used as fallback
"""
# Check for presence of the Astra DB object
if astra_db is None:
if token is None or api_endpoint is None:
raise AssertionError("Must provide token and api_endpoint")

astra_db = AstraDB(
token=token, api_endpoint=api_endpoint, namespace=namespace
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
)

# Set the remaining instance attributes
self.astra_db = astra_db
self.caller_name = caller_name or self.astra_db.caller_name
self.caller_version = caller_version or self.astra_db.caller_version
self.collection_name = collection_name
self.base_path = f"{self.astra_db.base_path}/{self.collection_name}"

def __repr__(self) -> str:
return f'AstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]'

def __eq__(self, other: Any) -> bool:
if isinstance(other, AstraDBCollection):
return all(
[
self.collection_name == other.collection_name,
self.astra_db == other.astra_db,
self.caller_name == other.caller_name,
self.caller_version == other.caller_version,
]
)
else:
return False

def copy(self) -> AstraDBCollection:
return AstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db.copy(),
caller_name=self.caller_name,
caller_version=self.caller_version,
)

def to_async(self) -> AsyncAstraDBCollection:
return AsyncAstraDBCollection(
astra_db=self.astra_db.to_async(), collection_name=self.collection_name
collection_name=self.collection_name,
astra_db=self.astra_db.to_async(),
caller_name=self.caller_name,
caller_version=self.caller_version,
)

def set_caller(
self,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> None:
self.astra_db.set_caller(
caller_name=caller_name,
caller_version=caller_version,
)
self.caller_name = caller_name
self.caller_version = caller_version

def _request(
self,
method: str = http_methods.POST,
Expand All @@ -126,8 +174,8 @@ def _request(
url_params=url_params,
path=path,
skip_error_check=skip_error_check,
caller_name=None,
caller_version=None,
caller_name=self.caller_name,
caller_version=self.caller_version,
)
response = restore_from_api(direct_response)
return response
Expand Down Expand Up @@ -995,6 +1043,8 @@ def __init__(
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> None:
"""
Initialize an AstraDBCollection instance.
Expand All @@ -1004,28 +1054,74 @@ def __init__(
token (str, optional): Authentication token for Astra DB.
api_endpoint (str, optional): API endpoint URL.
namespace (str, optional): Namespace for the database.
caller_name (str, optional): identity of the caller ("my_framework")
If passing a client, its caller is used as fallback
caller_version (str, optional): version of the caller code ("1.0.3")
If passing a client, its caller is used as fallback
"""
# Check for presence of the Astra DB object
if astra_db is None:
if token is None or api_endpoint is None:
raise AssertionError("Must provide token and api_endpoint")

astra_db = AsyncAstraDB(
token=token, api_endpoint=api_endpoint, namespace=namespace
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
caller_name=caller_name,
caller_version=caller_version,
)

# Set the remaining instance attributes
self.astra_db: AsyncAstraDB = astra_db
self.caller_name = caller_name or self.astra_db.caller_name
self.caller_version = caller_version or self.astra_db.caller_version
self.client = astra_db.client
self.collection_name = collection_name
self.base_path = f"{self.astra_db.base_path}/{self.collection_name}"

def __repr__(self) -> str:
return f'AsyncAstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]'

def __eq__(self, other: Any) -> bool:
if isinstance(other, AsyncAstraDBCollection):
return all(
[
self.collection_name == other.collection_name,
self.astra_db == other.astra_db,
self.caller_name == other.caller_name,
self.caller_version == other.caller_version,
]
)
else:
return False

def copy(self) -> AsyncAstraDBCollection:
return AsyncAstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db.copy(),
caller_name=self.caller_name,
caller_version=self.caller_version,
)

def set_caller(
self,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> None:
self.astra_db.set_caller(
caller_name=caller_name,
caller_version=caller_version,
)
self.caller_name = caller_name
self.caller_version = caller_version

def to_sync(self) -> AstraDBCollection:
return AstraDBCollection(
astra_db=self.astra_db.to_sync(), collection_name=self.collection_name
collection_name=self.collection_name,
astra_db=self.astra_db.to_sync(),
caller_name=self.caller_name,
caller_version=self.caller_version,
)

async def _request(
Expand All @@ -1047,8 +1143,8 @@ async def _request(
url_params=url_params,
path=path,
skip_error_check=skip_error_check,
caller_name=None,
caller_version=None,
caller_name=self.caller_name,
caller_version=self.caller_version,
)
response = restore_from_api(adirect_response)
return response
Expand Down Expand Up @@ -1865,14 +1961,23 @@ def __init__(
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> None:
"""
Initialize an Astra DB instance.
Args:
token (str): Authentication token for Astra DB.
api_endpoint (str): API endpoint URL.
api_path (str, optional): used to override default URI construction
api_version (str, optional): to override default URI construction
namespace (str, optional): Namespace for the database.
caller_name (str, optional): identity of the caller ("my_framework")
caller_version (str, optional): version of the caller code ("1.0.3")
"""
self.caller_name = caller_name
self.caller_version = caller_version

if token is None or api_endpoint is None:
raise AssertionError("Must provide token and api_endpoint")

Expand Down Expand Up @@ -1901,15 +2006,51 @@ def __init__(
def __repr__(self) -> str:
return f'AstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]'

def __eq__(self, other: Any) -> bool:
if isinstance(other, AstraDB):
# work on the "normalized" quantities (stripped, etc)
return all(
[
self.token == other.token,
self.base_url == other.base_url,
self.base_path == other.base_path,
self.caller_name == other.caller_name,
self.caller_version == other.caller_version,
]
)
else:
return False

def copy(self) -> AstraDB:
return AstraDB(
token=self.token,
api_endpoint=self.base_url,
api_path=self.api_path,
api_version=self.api_version,
namespace=self.namespace,
caller_name=self.caller_name,
caller_version=self.caller_version,
)

def to_async(self) -> AsyncAstraDB:
return AsyncAstraDB(
token=self.token,
api_endpoint=self.base_url,
api_path=self.api_path,
api_version=self.api_version,
namespace=self.namespace,
caller_name=self.caller_name,
caller_version=self.caller_version,
)

def set_caller(
self,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> None:
self.caller_name = caller_name
self.caller_version = caller_version

def _request(
self,
method: str = http_methods.POST,
Expand All @@ -1928,8 +2069,8 @@ def _request(
url_params=url_params,
path=path,
skip_error_check=skip_error_check,
caller_name=None,
caller_version=None,
caller_name=self.caller_name,
caller_version=self.caller_version,
)
response = restore_from_api(direct_response)
return response
Expand Down Expand Up @@ -2094,14 +2235,23 @@ def __init__(
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> None:
"""
Initialize an Astra DB instance.
Args:
token (str): Authentication token for Astra DB.
api_endpoint (str): API endpoint URL.
api_path (str, optional): used to override default URI construction
api_version (str, optional): to override default URI construction
namespace (str, optional): Namespace for the database.
caller_name (str, optional): identity of the caller ("my_framework")
caller_version (str, optional): version of the caller code ("1.0.3")
"""
self.caller_name = caller_name
self.caller_version = caller_version

self.client = httpx.AsyncClient()
if token is None or api_endpoint is None:
raise AssertionError("Must provide token and api_endpoint")
Expand Down Expand Up @@ -2131,6 +2281,21 @@ def __init__(
def __repr__(self) -> str:
return f'AsyncAstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]'

def __eq__(self, other: Any) -> bool:
if isinstance(other, AsyncAstraDB):
# work on the "normalized" quantities (stripped, etc)
return all(
[
self.token == other.token,
self.base_url == other.base_url,
self.base_path == other.base_path,
self.caller_name == other.caller_name,
self.caller_version == other.caller_version,
]
)
else:
return False

async def __aenter__(self) -> AsyncAstraDB:
return self

Expand All @@ -2142,15 +2307,36 @@ async def __aexit__(
) -> None:
await self.client.aclose()

def copy(self) -> AsyncAstraDB:
return AsyncAstraDB(
token=self.token,
api_endpoint=self.base_url,
api_path=self.api_path,
api_version=self.api_version,
namespace=self.namespace,
caller_name=self.caller_name,
caller_version=self.caller_version,
)

def to_sync(self) -> AstraDB:
return AstraDB(
token=self.token,
api_endpoint=self.base_url,
api_path=self.api_path,
api_version=self.api_version,
namespace=self.namespace,
caller_name=self.caller_name,
caller_version=self.caller_version,
)

def set_caller(
self,
caller_name: Optional[str] = None,
caller_version: Optional[str] = None,
) -> None:
self.caller_name = caller_name
self.caller_version = caller_version

async def _request(
self,
method: str = http_methods.POST,
Expand All @@ -2169,8 +2355,8 @@ async def _request(
url_params=url_params,
path=path,
skip_error_check=skip_error_check,
caller_name=None,
caller_version=None,
caller_name=self.caller_name,
caller_version=self.caller_version,
)
response = restore_from_api(adirect_response)
return response
Expand Down
Loading

0 comments on commit 3c2a726

Please sign in to comment.