Skip to content

Commit

Permalink
Emb.providers dataclasses know about parameters' hint and displayName (
Browse files Browse the repository at this point in the history
…#295)

emb.providers dataclasses know about parameters' hint and displayName
  • Loading branch information
hemidactylus authored Jul 22, 2024
1 parent a4885a4 commit 079fd0d
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 25 deletions.
6 changes: 6 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
master
======
FindEmbeddingProvidersResult and descendant dataclasses:
- knowledge of optional-as-null vs optional-as-possibly-absent ancillary fields
- add handling of optional 'hint' and 'displayName' fields for parameters

v. 1.4.0
========
DatabaseAdmin classes retain a reference to the Async/Database instance that spawned it, if any
Expand Down
26 changes: 19 additions & 7 deletions astrapy/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,9 @@ class EmbeddingProviderParameter:
"""

default_value: Any
display_name: Optional[str]
help: Optional[str]
hint: Optional[str]
name: str
required: bool
parameter_type: str
Expand All @@ -467,12 +469,18 @@ def as_dict(self) -> Dict[str, Any]:
"""Recast this object into a dictionary."""

return {
"defaultValue": self.default_value,
"help": self.help,
"name": self.name,
"required": self.required,
"type": self.parameter_type,
"validation": self.validation,
k: v
for k, v in {
"defaultValue": self.default_value,
"displayName": self.display_name,
"help": self.help,
"hint": self.hint,
"name": self.name,
"required": self.required,
"type": self.parameter_type,
"validation": self.validation,
}.items()
if v is not None
}

@staticmethod
Expand All @@ -484,7 +492,9 @@ def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderParameter:

residual_keys = raw_dict.keys() - {
"defaultValue",
"displayName",
"help",
"hint",
"name",
"required",
"type",
Expand All @@ -497,7 +507,9 @@ def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderParameter:
)
return EmbeddingProviderParameter(
default_value=raw_dict["defaultValue"],
help=raw_dict["help"],
display_name=raw_dict.get("displayName"),
help=raw_dict.get("help"),
hint=raw_dict.get("hint"),
name=raw_dict["name"],
required=raw_dict["required"],
parameter_type=raw_dict["type"],
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "astrapy"
version = "1.4.0"
version = "1.4.1"
description = "AstraPy is a Pythonic SDK for DataStax Astra and its Data API"
authors = [
"Stefano Lottini <[email protected]>",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import os
import sys
from typing import Any, Dict, List, Union

import pytest
Expand All @@ -24,6 +25,8 @@
from astrapy.exceptions import DataAPIResponseException, InsertManyException
from astrapy.info import CollectionVectorServiceOptions

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from ..conftest import IS_ASTRA_DB
from ..vectorize_models import live_test_models

Expand Down
8 changes: 3 additions & 5 deletions tests/vectorize_idiomatic/live_provider_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from __future__ import annotations

from typing import Dict

from preprocess_env import (
ASTRA_DB_API_ENDPOINT,
ASTRA_DB_KEYSPACE,
Expand All @@ -29,10 +27,10 @@
from astrapy import DataAPIClient, Database
from astrapy.admin import parse_api_endpoint
from astrapy.constants import Environment
from astrapy.info import EmbeddingProvider
from astrapy.info import FindEmbeddingProvidersResult


def live_provider_info() -> Dict[str, EmbeddingProvider]:
def live_provider_info() -> FindEmbeddingProvidersResult:
"""
Query the API's `findEmbeddingProviders` endpoint
for the latest information.
Expand Down Expand Up @@ -64,4 +62,4 @@ def live_provider_info() -> Dict[str, EmbeddingProvider]:

database_admin = database.get_database_admin()
response = database_admin.find_embedding_providers()
return response.embedding_providers
return response
37 changes: 30 additions & 7 deletions tests/vectorize_idiomatic/query_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
import json
import os
import sys
from typing import List

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from astrapy.info import EmbeddingProviderParameter, FindEmbeddingProvidersResult

from typing import Dict
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from live_provider_info import live_provider_info

from astrapy.info import EmbeddingProvider, EmbeddingProviderParameter
from vectorize_models import live_test_models


def desc_param(param_data: EmbeddingProviderParameter) -> str:
Expand Down Expand Up @@ -56,11 +56,15 @@ def desc_param(param_data: EmbeddingProviderParameter) -> str:


if __name__ == "__main__":
providers: Dict[str, EmbeddingProvider] = live_provider_info()
providers_json = {ep_name: ep.as_dict() for ep_name, ep in providers.items()}
provider_info: FindEmbeddingProvidersResult = live_provider_info()
providers_json = (provider_info.raw_info or {}).get("embeddingProviders")
if not providers_json:
raise ValueError(
"raw info from embedding providers lacks `embeddingProviders` content."
)
json.dump(providers_json, open("_providers.json", "w"), indent=2, sort_keys=True)

for provider, provider_data in sorted(providers.items()):
for provider, provider_data in sorted(provider_info.embedding_providers.items()):
print(f"{provider} ({len(provider_data.models)} models)")
print(" auth:")
for auth_type, auth_data in sorted(
Expand Down Expand Up @@ -98,3 +102,22 @@ def desc_param(param_data: EmbeddingProviderParameter) -> str:
param_display_name = f"({param_name})"
param_desc = desc_param(param_data)
print(f" - {param_display_name}: {param_desc}")

print("\n" * 2)
all_test_models = list(live_test_models())
for auth_type in ["HEADER", "NONE", "SHARED_SECRET"]:
print(f"Tags for auth type {auth_type}:", end="")
#
at_test_models = [
test_model
for test_model in all_test_models
if test_model["auth_type_name"] == auth_type
]
at_model_ids: List[str] = sorted(
[str(model_desc["model_tag"]) for model_desc in at_test_models]
)
if at_model_ids:
print("")
print("\n".join(f" {ami}" for ami in at_model_ids))
else:
print(" (no tags)")
12 changes: 7 additions & 5 deletions tests/vectorize_idiomatic/vectorize_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
import sys
from typing import Any, Dict, Iterable, List, Tuple

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from astrapy.authentication import (
EMBEDDING_HEADER_API_KEY,
EMBEDDING_HEADER_AWS_ACCESS_ID,
EMBEDDING_HEADER_AWS_SECRET_ID,
)
from astrapy.info import CollectionVectorServiceOptions, EmbeddingProviderParameter

from .live_provider_info import live_provider_info
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from live_provider_info import live_provider_info

alphanum = set("qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM1234567890")

Expand Down Expand Up @@ -196,8 +196,10 @@ def _collapse(longt: str) -> str:
return f"{longt[:30]}_{longt[-5:]}"

# generate the full list of models based on the live provider endpoint
providers = live_provider_info()
for provider_name, provider_desc in sorted(providers.items()):
provider_info = live_provider_info()
for provider_name, provider_desc in sorted(
provider_info.embedding_providers.items()
):
for model in provider_desc.models:
for auth_type_name, auth_type_desc in sorted(
provider_desc.supported_authentication.items()
Expand Down

0 comments on commit 079fd0d

Please sign in to comment.