From af509a3958a89d6abd50a321aa7dbb1d0ce91710 Mon Sep 17 00:00:00 2001 From: caufieldjh Date: Tue, 20 Aug 2024 12:28:52 -0400 Subject: [PATCH 1/3] Add config to each pydantic BaseModel --- src/curate_gpt/agents/bootstrap_agent.py | 3 ++- src/curate_gpt/agents/chat_agent.py | 4 +++- src/curate_gpt/agents/concept_recognition_agent.py | 8 +++++++- src/curate_gpt/agents/dase_agent.py | 4 +++- src/curate_gpt/agents/dragon_agent.py | 4 +++- src/curate_gpt/agents/mapping_agent.py | 5 ++++- src/curate_gpt/app/cart.py | 4 +++- src/curate_gpt/evaluation/evaluation_datamodel.py | 6 +++++- src/curate_gpt/extract/extractor.py | 3 +++ src/curate_gpt/store/db_metadata.py | 5 ++++- src/curate_gpt/store/duckdb_result.py | 5 ++++- src/curate_gpt/store/in_memory_adapter.py | 8 +++++++- src/curate_gpt/store/metadata.py | 4 +++- src/curate_gpt/store/schema_proxy.py | 4 +++- src/curate_gpt/utils/eval_utils.py | 5 ++++- src/curate_gpt/wrappers/general/github_wrapper.py | 4 +++- src/curate_gpt/wrappers/ontology/ontology.py | 5 ++++- tests/extract/test_extractor.py | 4 +++- 18 files changed, 68 insertions(+), 17 deletions(-) diff --git a/src/curate_gpt/agents/bootstrap_agent.py b/src/curate_gpt/agents/bootstrap_agent.py index bc990ad..40c5189 100644 --- a/src/curate_gpt/agents/bootstrap_agent.py +++ b/src/curate_gpt/agents/bootstrap_agent.py @@ -3,7 +3,7 @@ import yaml from jinja2 import Template -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from curate_gpt.agents.base_agent import BaseAgent from curate_gpt.conf.prompts import PROMPTS_DIR @@ -11,6 +11,7 @@ class KnowledgeBaseSpecification(BaseModel): + model_config = ConfigDict(protected_namespaces=()) kb_name: str description: str attributes: str diff --git a/src/curate_gpt/agents/chat_agent.py b/src/curate_gpt/agents/chat_agent.py index 6724cce..e0e6bd1 100644 --- a/src/curate_gpt/agents/chat_agent.py +++ b/src/curate_gpt/agents/chat_agent.py @@ -7,7 +7,7 @@ import yaml from llm import Conversation -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from curate_gpt.agents.base_agent import BaseAgent from curate_gpt.utils.tokens import estimate_num_tokens, max_tokens_by_model @@ -23,6 +23,8 @@ class ChatResponse(BaseModel): TODO: Rename class to indicate that it is provenance-enabled chat """ + model_config = ConfigDict(protected_namespaces=()) + body: str """Text of response.""" diff --git a/src/curate_gpt/agents/concept_recognition_agent.py b/src/curate_gpt/agents/concept_recognition_agent.py index d39d763..7556f33 100644 --- a/src/curate_gpt/agents/concept_recognition_agent.py +++ b/src/curate_gpt/agents/concept_recognition_agent.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Dict, List, Optional, Tuple -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from curate_gpt.agents.base_agent import BaseAgent @@ -18,6 +18,8 @@ class Span(BaseModel): """An individual span of text containing a single concept.""" + model_config = ConfigDict(protected_namespaces=()) + text: str start: Optional[int] = None @@ -36,6 +38,8 @@ class Span(BaseModel): class GroundingResult(BaseModel): """Result of grounding text.""" + model_config = ConfigDict(protected_namespaces=()) + input_text: str """Text that is supplied for grounding, assumed to contain a single context.""" @@ -62,6 +66,8 @@ class AnnotationMethod(str, Enum): class AnnotatedText(BaseModel): """In input text annotated with concept instances.""" + model_config = ConfigDict(protected_namespaces=()) + input_text: str """Text that is 
supplied for annotation.""" diff --git a/src/curate_gpt/agents/dase_agent.py b/src/curate_gpt/agents/dase_agent.py index c2ab1cd..b4bf1f2 100644 --- a/src/curate_gpt/agents/dase_agent.py +++ b/src/curate_gpt/agents/dase_agent.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from typing import Any, ClassVar, Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from curate_gpt.agents.base_agent import BaseAgent from curate_gpt.agents.chat_agent import ChatResponse @@ -17,6 +17,8 @@ class PredictedFieldValue(BaseModel): + + model_config = ConfigDict(protected_namespaces=()) id: str original_id: Optional[str] = None predicted_value: Optional[str] = None diff --git a/src/curate_gpt/agents/dragon_agent.py b/src/curate_gpt/agents/dragon_agent.py index 4b5d9b3..afb28a6 100644 --- a/src/curate_gpt/agents/dragon_agent.py +++ b/src/curate_gpt/agents/dragon_agent.py @@ -5,7 +5,7 @@ from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Union import yaml -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from curate_gpt.agents.base_agent import BaseAgent from curate_gpt.extract import AnnotatedObject @@ -17,6 +17,8 @@ class PredictedFieldValue(BaseModel): + + model_config = ConfigDict(protected_namespaces=()) id: str original_id: Optional[str] = None predicted_value: Optional[str] = None diff --git a/src/curate_gpt/agents/mapping_agent.py b/src/curate_gpt/agents/mapping_agent.py index 51db165..c48e943 100644 --- a/src/curate_gpt/agents/mapping_agent.py +++ b/src/curate_gpt/agents/mapping_agent.py @@ -10,7 +10,7 @@ import inflection import yaml -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from curate_gpt.agents.base_agent import BaseAgent from curate_gpt.store.db_adapter import SEARCH_RESULT @@ -46,12 +46,15 @@ class MappingPredicate(str, Enum): class Mapping(BaseModel): """Response from chat engine.""" + model_config = ConfigDict(protected_namespaces=()) subject_id: str object_id: str predicate_id: Optional[MappingPredicate] = None class MappingSet(BaseModel): + + model_config = ConfigDict(protected_namespaces=()) mappings: List[Mapping] prompt: str = None response_text: str = None diff --git a/src/curate_gpt/app/cart.py b/src/curate_gpt/app/cart.py index f66d8e6..4de7fff 100644 --- a/src/curate_gpt/app/cart.py +++ b/src/curate_gpt/app/cart.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class CartItem(BaseModel): @@ -8,6 +8,7 @@ class CartItem(BaseModel): A cart item is a single item in a cart """ + model_config = ConfigDict(protected_namespaces=()) object: Union[Dict, BaseModel] object_type: Optional[str] = None source: Optional[str] = None @@ -19,6 +20,7 @@ class Cart(BaseModel): A cart is a list of items that can be added to or removed from """ + model_config = ConfigDict(protected_namespaces=()) items: List[CartItem] = [] @property diff --git a/src/curate_gpt/evaluation/evaluation_datamodel.py b/src/curate_gpt/evaluation/evaluation_datamodel.py index c87df13..d864f62 100644 --- a/src/curate_gpt/evaluation/evaluation_datamodel.py +++ b/src/curate_gpt/evaluation/evaluation_datamodel.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class StratifiedCollection(BaseModel): @@ -10,6 +10,7 @@ class StratifiedCollection(BaseModel): A collection of objects that have 
been split into training, test, and validation sets. """ + model_config = ConfigDict(protected_namespaces=()) source: str = None training_set: List[Dict] = None testing_set: List[Dict] = None @@ -33,6 +34,8 @@ class AggregationMethod(str, Enum): class ClassificationMetrics(BaseModel): + + model_config = ConfigDict(protected_namespaces=()) precision: float recall: float f1_score: float @@ -49,6 +52,7 @@ class Task(BaseModel): A task to be run by the evaluation runner. """ + model_config = ConfigDict(protected_namespaces=()) model_name: str = "gpt-3.5-turbo" embedding_model_name: str = "openai:" generate_background: bool = False diff --git a/src/curate_gpt/extract/extractor.py b/src/curate_gpt/extract/extractor.py index f644104..7e5dc65 100644 --- a/src/curate_gpt/extract/extractor.py +++ b/src/curate_gpt/extract/extractor.py @@ -9,6 +9,7 @@ import llm from linkml_runtime import SchemaView from pydantic import BaseModel as BaseModel +from pydantic import ConfigDict from curate_gpt.store.schema_proxy import SchemaProxy @@ -20,6 +21,8 @@ class AnnotatedObject(BaseModel): Annotated object shadows a basic dictionary object """ + model_config = ConfigDict(protected_namespaces=()) + # object: Union[Dict[str, Any], List[Dict[str, Any]]] = {} - TODO: pydantic bug? object: Any = {} annotations: Dict[str, Any] = {} diff --git a/src/curate_gpt/store/db_metadata.py b/src/curate_gpt/store/db_metadata.py index 4f39de4..d68d19f 100644 --- a/src/curate_gpt/store/db_metadata.py +++ b/src/curate_gpt/store/db_metadata.py @@ -1,10 +1,13 @@ from pathlib import Path -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict import yaml class DBSettings(BaseModel): + + model_config = ConfigDict(protected_namespaces=()) + name: str = "duckdb" """Name of the database.""" diff --git a/src/curate_gpt/store/duckdb_result.py b/src/curate_gpt/store/duckdb_result.py index bd18466..e49b702 100644 --- a/src/curate_gpt/store/duckdb_result.py +++ b/src/curate_gpt/store/duckdb_result.py @@ -3,12 +3,15 @@ import jsonlines import yaml -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict SEARCH_RESULT = Tuple[Dict[str, Any], Dict, float, Optional[Dict]] class DuckDBSearchResult(BaseModel): + + model_config = ConfigDict(protected_namespaces=()) + ids: Optional[str] = None metadatas: Optional[Dict[str, Any]] = None embeddings: Optional[List[float]] = None diff --git a/src/curate_gpt/store/in_memory_adapter.py b/src/curate_gpt/store/in_memory_adapter.py index f6e1bed..0c0cd70 100644 --- a/src/curate_gpt/store/in_memory_adapter.py +++ b/src/curate_gpt/store/in_memory_adapter.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from typing import ClassVar, Dict, Iterable, Iterator, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from curate_gpt import DBAdapter from curate_gpt.store.db_adapter import OBJECT, PROJECTION, QUERY, SEARCH_RESULT @@ -14,6 +14,9 @@ class Collection(BaseModel): + + model_config = ConfigDict(protected_namespaces=()) + objects: List[Dict] = [] metadata: Dict = {} @@ -25,6 +28,9 @@ def delete(self, key_value: str, key: str) -> None: class CollectionIndex(BaseModel): + + model_config = ConfigDict(protected_namespaces=()) + collections: Dict[str, Collection] = {} def get_collection(self, name: str) -> Collection: diff --git a/src/curate_gpt/store/metadata.py b/src/curate_gpt/store/metadata.py index 97909e6..df3433b 100644 --- a/src/curate_gpt/store/metadata.py +++ b/src/curate_gpt/store/metadata.py @@ -1,6 +1,6 @@ from 
typing import Dict, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class CollectionMetadata(BaseModel): @@ -10,6 +10,8 @@ class CollectionMetadata(BaseModel): This is an open class, so additional metadata can be added. """ + model_config = ConfigDict(protected_namespaces=()) + name: Optional[str] = None """Name of the collection""" diff --git a/src/curate_gpt/store/schema_proxy.py b/src/curate_gpt/store/schema_proxy.py index b18815d..5b65a31 100644 --- a/src/curate_gpt/store/schema_proxy.py +++ b/src/curate_gpt/store/schema_proxy.py @@ -4,7 +4,7 @@ from linkml_runtime import SchemaView from linkml_runtime.linkml_model import SchemaDefinition -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict @dataclass @@ -13,6 +13,8 @@ class SchemaProxy: Manage connection to a schema """ + model_config = ConfigDict(protected_namespaces=()) + schema_source: Union[str, Path, SchemaDefinition] = None _pydantic_root_model: BaseModel = None _schemaview: SchemaView = None diff --git a/src/curate_gpt/utils/eval_utils.py b/src/curate_gpt/utils/eval_utils.py index 416281c..c1d3ffa 100644 --- a/src/curate_gpt/utils/eval_utils.py +++ b/src/curate_gpt/utils/eval_utils.py @@ -3,10 +3,13 @@ from copy import copy from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class Outcome(BaseModel): + + model_config = ConfigDict(protected_namespaces=()) + prediction: Union[Dict[str, Any], List[Dict[str, Any]]] = {} expected: Union[Dict[str, Any], List[Dict[str, Any]]] = {} parameters: Dict[str, Any] = {} diff --git a/src/curate_gpt/wrappers/general/github_wrapper.py b/src/curate_gpt/wrappers/general/github_wrapper.py index 4b6375f..7f8191c 100644 --- a/src/curate_gpt/wrappers/general/github_wrapper.py +++ b/src/curate_gpt/wrappers/general/github_wrapper.py @@ -8,7 +8,7 @@ import requests import requests_cache -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from curate_gpt.wrappers.base_wrapper import BaseWrapper @@ -16,12 +16,14 @@ class Comment(BaseModel): + model_config = ConfigDict(protected_namespaces=()) id: str user: str = None body: str = None class Issue(BaseModel): + model_config = ConfigDict(protected_namespaces=()) id: str number: int = None type: str = None diff --git a/src/curate_gpt/wrappers/ontology/ontology.py b/src/curate_gpt/wrappers/ontology/ontology.py index 1472cbe..008b428 100644 --- a/src/curate_gpt/wrappers/ontology/ontology.py +++ b/src/curate_gpt/wrappers/ontology/ontology.py @@ -1,6 +1,6 @@ from typing import List -from pydantic import BaseModel, Extra +from pydantic import BaseModel, ConfigDict, Extra class Relationship(BaseModel): @@ -10,6 +10,7 @@ class Relationship(BaseModel): Corresponds to an edge in an OBO graph. """ + model_config = ConfigDict(protected_namespaces=()) predicate: str target: str @@ -21,6 +22,7 @@ class OntologyClass(BaseModel, extra=Extra.allow): Corresponds to a node in an OBO graph. """ + model_config = ConfigDict(protected_namespaces=()) id: str label: str = None definition: str = None @@ -37,4 +39,5 @@ class Ontology(BaseModel): Corresponds to an OBO graph. 
""" + model_config = ConfigDict(protected_namespaces=()) elements: List[OntologyClass] = None diff --git a/tests/extract/test_extractor.py b/tests/extract/test_extractor.py index 17c93ee..452741d 100644 --- a/tests/extract/test_extractor.py +++ b/tests/extract/test_extractor.py @@ -7,15 +7,17 @@ from curate_gpt.extract.recursive_extractor import RecursiveExtractor from curate_gpt.store.schema_proxy import SchemaProxy from linkml_runtime.utils.schema_builder import SchemaBuilder -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class Occupation(BaseModel): + model_config = ConfigDict(protected_namespaces=()) category: str current: bool class Person(BaseModel): + model_config = ConfigDict(protected_namespaces=()) name: str age: int occupations: List[Occupation] From e5f6faa83bc70715445ed429d704f76574d8ec74 Mon Sep 17 00:00:00 2001 From: caufieldjh Date: Tue, 20 Aug 2024 12:43:19 -0400 Subject: [PATCH 2/3] Linting 1 --- src/curate_gpt/cli.py | 43 +++++-- src/curate_gpt/store/db_adapter.py | 15 ++- src/curate_gpt/store/duckdb_adapter.py | 159 ++++++++++++++----------- src/curate_gpt/store/duckdb_result.py | 16 ++- src/curate_gpt/store/vocab.py | 2 +- tests/store/test_chromadb_adapter.py | 18 ++- tests/store/test_duckdb_adapter.py | 22 ++-- 7 files changed, 156 insertions(+), 119 deletions(-) diff --git a/src/curate_gpt/cli.py b/src/curate_gpt/cli.py index 5d96bbf..9b36a78 100644 --- a/src/curate_gpt/cli.py +++ b/src/curate_gpt/cli.py @@ -107,7 +107,8 @@ def dump( "--docstore_database_type", default="chromadb", show_default=True, - help="Docstore database type.") + help="Docstore database type.", +) model_option = click.option( "-m", "--model", help="Model to use for generation or embedding, e.g. gpt-4." @@ -485,6 +486,7 @@ def all_by_all( if other_path is None: other_path = path results = match_collections(db, collection, other_collection, other_db) + def _obj(obj: Dict, is_left=False) -> Any: if ids_only: obj = {"id": obj["id"]} @@ -525,7 +527,7 @@ def _obj(obj: Dict, is_left=False) -> Any: def matches(id, path, collection, database_type): """Find matches for an ID. 
- curategpt matches "Continuant" -p duckdb/objects.duckdb -c objects_a -D duckdb + curategpt matches "Continuant" -p duckdb/objects.duckdb -c objects_a -D duckdb """ db = get_store(database_type, path) @@ -972,7 +974,7 @@ def complete( extractor.schema_proxy = schema_manager dac = DragonAgent(knowledge_source=db, extractor=extractor) if docstore_path or docstore_collection: - dac.document_adapter = get_store(docstore_database_type,docstore_path) + dac.document_adapter = get_store(docstore_database_type, docstore_path) dac.document_adapter_collection = docstore_collection if ":" in query: query = yaml.safe_load(query) @@ -1064,7 +1066,7 @@ def update( extractor.schema_proxy = schema_manager dac = DragonAgent(knowledge_source=db, extractor=extractor) if docstore_path or docstore_collection: - dac.document_adapter = get_store(docstore_database_type,docstore_path) + dac.document_adapter = get_store(docstore_database_type, docstore_path) dac.document_adapter_collection = docstore_collection for obj, _s, _meta in db.find(where_q, collection=collection): # click.echo(f"{obj}") @@ -1262,7 +1264,7 @@ def complete_multiple( extractor.schema_proxy = schema_manager dac = DragonAgent(knowledge_source=db, extractor=extractor) if docstore_path or docstore_collection: - dac.document_adapter = get_store(docstore_database_type,docstore_path) + dac.document_adapter = get_store(docstore_database_type, docstore_path) dac.document_adapter_collection = docstore_collection with open(input_file) as f: queries = [l.strip() for l in f.readlines()] @@ -1454,7 +1456,7 @@ def complete_all( extractor.schema_proxy = schema_manager dae = DragonAgent(knowledge_source=db, extractor=extractor) if docstore_path or docstore_collection: - dae.document_adapter = get_store(docstore_database_type,docstore_path) + dae.document_adapter = get_store(docstore_database_type, docstore_path) dae.document_adapter_collection = docstore_collection object_ids = None if id_file: @@ -1541,7 +1543,7 @@ def generate_evaluate( ------- curategpt -v generate-evaluate -c cdr_training -T cdr_test -F statements -m gpt-4 """ - db = get_store(database_type,path) + db = get_store(database_type, path) if schema: schema_manager = SchemaProxy(schema) else: @@ -1556,7 +1558,7 @@ def generate_evaluate( extractor.schema_proxy = schema_manager rage = DragonAgent(knowledge_source=db, extractor=extractor) if docstore_path or docstore_collection: - rage.document_adapter = get_store(docstore_database_type,docstore_path) + rage.document_adapter = get_store(docstore_database_type, docstore_path) rage.document_adapter_collection = docstore_collection hold_back_fields = hold_back_fields.split(",") mask_fields = mask_fields.split(",") if mask_fields else [] @@ -1884,7 +1886,17 @@ def apply_patch(input_file, patch, primary_key): help="jsonpath expression to select objects from the input file.", ) @click.argument("query") -def citeseek(query, path, collection, model, show_references, _continue, select, conversation_id, database_type): +def citeseek( + query, + path, + collection, + model, + show_references, + _continue, + select, + conversation_id, + database_type, +): """Find citations for an object or statement. 
You can pass in a statement directly as an argument @@ -2063,7 +2075,7 @@ def list_collections(database_type, path, peek: bool, minimal: bool, derived: bo # making sure if o[id] finds nothing we get the full obj r = list(db.peek(cn)) for o, _, _ in r: - if 'id' in o: + if "id" in o: print(f" - {o['id']}") else: print(f" - {o}") @@ -2202,7 +2214,14 @@ def copy_collection(path, collection, target_path, database_type, **kwargs): ) @path_option def split_collection( - path, collection, derived_collection_base, output_path, model, test_id_file, database_type, **kwargs + path, + collection, + derived_collection_base, + output_path, + model, + test_id_file, + database_type, + **kwargs, ): """ Split a collection into test/train/validation. @@ -2224,7 +2243,7 @@ def split_collection( ) logging.info(f"First 10: {kwargs['testing_identifiers'][:10]}") sc = stratify_collection(db, collection, **kwargs) - output_db = get_store(database_type ,output_path) + output_db = get_store(database_type, output_path) if not derived_collection_base: derived_collection_base = collection for sn in ["training", "testing", "validation"]: diff --git a/src/curate_gpt/store/db_adapter.py b/src/curate_gpt/store/db_adapter.py index 6ea7f99..d06c527 100644 --- a/src/curate_gpt/store/db_adapter.py +++ b/src/curate_gpt/store/db_adapter.py @@ -14,8 +14,17 @@ from curate_gpt.store.metadata import CollectionMetadata from curate_gpt.store.schema_proxy import SchemaProxy -from curate_gpt.store.vocab import OBJECT, SEARCH_RESULT, QUERY, PROJECTION, FILE_LIKE, EMBEDDINGS, DOCUMENTS, \ - METADATAS, DEFAULT_COLLECTION +from curate_gpt.store.vocab import ( + OBJECT, + SEARCH_RESULT, + QUERY, + PROJECTION, + FILE_LIKE, + EMBEDDINGS, + DOCUMENTS, + METADATAS, + DEFAULT_COLLECTION, +) logger = logging.getLogger(__name__) @@ -468,4 +477,4 @@ def dump_then_load(self, collection: str = None, target: "DBAdapter" = None): :param target: :return: """ - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/curate_gpt/store/duckdb_adapter.py b/src/curate_gpt/store/duckdb_adapter.py index ef98151..a0b96f7 100644 --- a/src/curate_gpt/store/duckdb_adapter.py +++ b/src/curate_gpt/store/duckdb_adapter.py @@ -28,8 +28,20 @@ from curate_gpt.store.duckdb_result import DuckDBSearchResult from curate_gpt.store.metadata import CollectionMetadata from curate_gpt.utils.vector_algorithms import mmr_diversified_search -from curate_gpt.store.vocab import OBJECT, QUERY, PROJECTION, EMBEDDINGS, DOCUMENTS, \ - METADATAS, MODEL_DIMENSIONS, MODELS, OPENAI_MODEL_DIMENSIONS, IDS, SEARCH_RESULT, DISTANCES +from curate_gpt.store.vocab import ( + OBJECT, + QUERY, + PROJECTION, + EMBEDDINGS, + DOCUMENTS, + METADATAS, + MODEL_DIMENSIONS, + MODELS, + OPENAI_MODEL_DIMENSIONS, + IDS, + SEARCH_RESULT, + DISTANCES, +) logger = logging.getLogger(__name__) @@ -99,14 +111,18 @@ def _get_collection_name(self, collection: Optional[str] = None) -> str: """ return self._get_collection(collection) - def _create_table_if_not_exists(self, collection: str, vec_dimension: int, distance: str, model: str = None): + def _create_table_if_not_exists( + self, collection: str, vec_dimension: int, distance: str, model: str = None + ): """ Create a table for the given collection if it does not exist :param collection: :return: """ - logger.info(f"Table {collection} does not exist, creating ...: PARAMS: model: {model}, distance: {distance},\ - vec_dimension: {vec_dimension}") + logger.info( + f"Table {collection} does not exist, creating ...: PARAMS: model: 
{model}, distance: {distance},\ + vec_dimension: {vec_dimension}" + ) if model is None: model = self.default_model logger.info(f"Model in create_table_if_not_exists: {model}") @@ -178,8 +194,10 @@ def _embedding_function(self, texts: Union[str, List[str]], model: str = None) - self._initialize_openai_client() openai_model = model.split(":", 1)[1] if openai_model == "" or openai_model not in MODELS: - logger.info(f"The model {openai_model} is not " - f"one of {MODELS}. Defaulting to {MODELS[1]}") + logger.info( + f"The model {openai_model} is not " + f"one of {MODELS}. Defaulting to {MODELS[1]}" + ) openai_model = MODELS[1] responses = [ @@ -233,7 +251,9 @@ def upsert(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs): logger.info(f"model in upsert: {kwargs.get('model')}, distance: {self.distance_metric}") if collection not in self.list_collection_names(): vec_dimension = self._get_embedding_dimension(kwargs.get("model")) - self._create_table_if_not_exists(collection, vec_dimension, model=kwargs.get("model"), distance=self.distance_metric) + self._create_table_if_not_exists( + collection, vec_dimension, model=kwargs.get("model"), distance=self.distance_metric + ) ids = [self._id(o, self.id_field) for o in objs] existing_ids = set() for id_ in ids: @@ -254,16 +274,16 @@ def upsert(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs): self.insert(objs_to_insert, **kwargs) def _process_objects( - self, - objs: Union[OBJECT, Iterable[OBJECT]], - collection: str = None, - batch_size: int = None, - object_type: str = None, - model: str = None, - distance: str = None, - text_field: Union[str, Callable] = None, - method: str = "insert", - **kwargs, + self, + objs: Union[OBJECT, Iterable[OBJECT]], + collection: str = None, + batch_size: int = None, + object_type: str = None, + model: str = None, + distance: str = None, + text_field: Union[str, Callable] = None, + method: str = "insert", + **kwargs, ): """ Process objects by inserting, updating or upserting them into the collection @@ -283,7 +303,9 @@ def _process_objects( logger.info(f"(process_objects: Model: {model}, vec_dimension: {self.vec_dimension}") if collection not in self.list_collection_names(): logger.info(f"(process)Creating table for collection {collection}") - self._create_table_if_not_exists(collection, self.vec_dimension, model=model, distance=distance) + self._create_table_if_not_exists( + collection, self.vec_dimension, model=model, distance=distance + ) if isinstance(objs, Iterable) and not isinstance(objs, str): objs = list(objs) else: @@ -320,7 +342,9 @@ def _process_objects( self.conn.execute("COMMIT;") except Exception as e: self.conn.execute("ROLLBACK;") - logger.error(f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}") + logger.error( + f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}" + ) raise finally: self.create_index(collection) @@ -342,14 +366,14 @@ def remove_collection(self, collection: str = None, exists_ok=False, **kwargs): self.conn.execute(f"DROP TABLE IF EXISTS {safe_collection_name}") def search( - self, - text: str, - where: QUERY = None, - collection: str = None, - limit: int = 10, - relevance_factor: float = None, - include=None, - **kwargs, + self, + text: str, + where: QUERY = None, + collection: str = None, + limit: int = 10, + relevance_factor: float = None, + include=None, + **kwargs, ) -> Iterator[SEARCH_RESULT]: """ Search for objects in the 
collection that match the given text @@ -369,18 +393,19 @@ def search( limit=limit, relevance_factor=relevance_factor, include=include, - **kwargs) + **kwargs, + ) def _search( - self, - text: str, - where: QUERY = None, - collection: str = None, - limit: int = 10, - relevance_factor: float = None, - model: str = None, - include=None, - **kwargs, + self, + text: str, + where: QUERY = None, + collection: str = None, + limit: int = 10, + relevance_factor: float = None, + model: str = None, + include=None, + **kwargs, ) -> Iterator[SEARCH_RESULT]: if relevance_factor is not None and relevance_factor < 1.0: yield from self._diversified_search( @@ -442,14 +467,14 @@ def _search( yield from self.parse_duckdb_result(results, include) def _diversified_search( - self, - text: str, - where: QUERY = None, - collection: str = None, - limit: int = 10, - relevance_factor: float = 0.5, - include=None, - **kwargs, + self, + text: str, + where: QUERY = None, + collection: str = None, + limit: int = 10, + relevance_factor: float = 0.5, + include=None, + **kwargs, ) -> Iterator[SEARCH_RESULT]: if limit is None: limit = 10 @@ -489,7 +514,6 @@ def _diversified_search( for i in reranked_indices: yield results[i] - def list_collection_names(self): """ List the names of all collections in the database @@ -499,7 +523,7 @@ def list_collection_names(self): return [row[0] for row in result] def collection_metadata( - self, collection_name: Optional[str] = None, include_derived=False, **kwargs + self, collection_name: Optional[str] = None, include_derived=False, **kwargs ) -> Optional[CollectionMetadata]: """ Get the metadata for the collection @@ -552,12 +576,12 @@ def update_collection_metadata(self, collection: str, **kwargs): UPDATE {safe_collection_name} SET metadata = ? WHERE id = '__metadata__' """, - [metadata_json] + [metadata_json], ) return current_metadata def set_collection_metadata( - self, collection_name: Optional[str], metadata: CollectionMetadata, **kwargs + self, collection_name: Optional[str], metadata: CollectionMetadata, **kwargs ): """ Set the metadata for the collection @@ -581,13 +605,13 @@ def set_collection_metadata( ) def find( - self, - where: QUERY = None, - projection: PROJECTION = None, - collection: str = None, - include=None, - limit: int = 10, - **kwargs, + self, + where: QUERY = None, + projection: PROJECTION = None, + collection: str = None, + include=None, + limit: int = 10, + **kwargs, ) -> Iterator[SEARCH_RESULT]: """ Find objects in the collection that match the given query and projection @@ -669,7 +693,9 @@ def lookup(self, id: str, collection: str = None, include=None, **kwargs) -> OBJ ) return search_result.to_dict().get(METADATAS) - def peek(self, collection: str = None, limit=5, include=None, **kwargs) -> Iterator[SEARCH_RESULT]: + def peek( + self, collection: str = None, limit=5, include=None, **kwargs + ) -> Iterator[SEARCH_RESULT]: """ Peek at the first N objects in the collection :param collection: @@ -711,9 +737,9 @@ def get_raw_objects(self, collection) -> Iterator[Dict]: yield json.loads(result[0]) def dump_then_load( - self, - collection: str = None, - target: DBAdapter = None, + self, + collection: str = None, + target: DBAdapter = None, ): """ Dump the collection to a file and then load it into the target adapter @@ -737,20 +763,13 @@ def dump_then_load( # in case it exists already, remove target.remove_collection(collection, exists_ok=True) # using same collection name in target database - target._create_table_if_not_exists( - collection, - vec_dimension, - 
distance, - model - ) + target._create_table_if_not_exists(collection, vec_dimension, distance, model) target.set_collection_metadata(collection, metadata) batch_size = 5000 for i in range(0, len(list(result)), batch_size): - batch = result[i: i + batch_size] + batch = result[i : i + batch_size] target.insert(batch, collection=collection) - - @staticmethod def kill_process(pid): """ diff --git a/src/curate_gpt/store/duckdb_result.py b/src/curate_gpt/store/duckdb_result.py index e49b702..91750e2 100644 --- a/src/curate_gpt/store/duckdb_result.py +++ b/src/curate_gpt/store/duckdb_result.py @@ -16,9 +16,7 @@ class DuckDBSearchResult(BaseModel): metadatas: Optional[Dict[str, Any]] = None embeddings: Optional[List[float]] = None documents: Optional[str] = None - distances: Optional[float] = ( - 0 # TODO: technically this is simple cosim similarity - ) + distances: Optional[float] = 0 # TODO: technically this is simple cosim similarity include: Optional[Set[str]] = None def to_json(self, indent: int = 2): @@ -34,15 +32,15 @@ def __repr__(self, indent: int = 2): def __iter__(self) -> Iterator[SEARCH_RESULT]: # TODO vocab.py for 'VARIABLES', but circular import - metadata_include = 'metadatas' in self.include if self.include else True - embeddings_include = 'embeddings' in self.include if self.include else True - documents_include = 'documents' in self.include if self.include else True - similarity_include = 'distances' in self.include if self.include else True + metadata_include = "metadatas" in self.include if self.include else True + embeddings_include = "embeddings" in self.include if self.include else True + documents_include = "documents" in self.include if self.include else True + similarity_include = "distances" in self.include if self.include else True obj = self.metadatas if metadata_include else {} meta = { - '_embeddings': self.embeddings if embeddings_include else None, - 'documents': self.documents if documents_include else None + "_embeddings": self.embeddings if embeddings_include else None, + "documents": self.documents if documents_include else None, } distance = self.distances if similarity_include else None diff --git a/src/curate_gpt/store/vocab.py b/src/curate_gpt/store/vocab.py index cd3f308..e835a92 100644 --- a/src/curate_gpt/store/vocab.py +++ b/src/curate_gpt/store/vocab.py @@ -25,4 +25,4 @@ "text-embedding-3-small": 1536, "text-embedding-3-large": 3072, } -MODELS = ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"] \ No newline at end of file +MODELS = ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"] diff --git a/tests/store/test_chromadb_adapter.py b/tests/store/test_chromadb_adapter.py index 8fc9f93..20232a7 100644 --- a/tests/store/test_chromadb_adapter.py +++ b/tests/store/test_chromadb_adapter.py @@ -183,7 +183,7 @@ def test_ontology_matches(ontology_db): first_obj = results[0][0] new_definition = "A beach with palm trees" updated_obj = { - "id": first_obj['id'], + "id": first_obj["id"], "label": first_obj["label"], "definition": new_definition, "aliases": first_obj["aliases"], @@ -197,20 +197,18 @@ def test_ontology_matches(ontology_db): # if you wish to update an ID you must insert a new one ontology_db.update([updated_obj], collection=collection) # verify update - updated_res = ontology_db.lookup(first_obj['id'], collection=collection) - assert updated_res['id'] == first_obj['id'] - assert updated_res['definition'] == new_definition - assert updated_res['label'] == first_obj['label'] + updated_res = 
ontology_db.lookup(first_obj["id"], collection=collection) + assert updated_res["id"] == first_obj["id"] + assert updated_res["definition"] == new_definition + assert updated_res["label"] == first_obj["label"] # test upsert - new_obj_insert = { - "id": "Palm Beach", - "key": "value" - } + new_obj_insert = {"id": "Palm Beach", "key": "value"} ontology_db.upsert([new_obj_insert], collection="test_collection") # verify upsert new_results = ontology_db.lookup("Palm Beach", collection="test_collection") - assert new_results['key'] == "value" + assert new_results["key"] == "value" + @pytest.mark.parametrize( "where,num_expected,limit", diff --git a/tests/store/test_duckdb_adapter.py b/tests/store/test_duckdb_adapter.py index 0f23527..bbc6443 100644 --- a/tests/store/test_duckdb_adapter.py +++ b/tests/store/test_duckdb_adapter.py @@ -93,7 +93,6 @@ def _id(obj, dist, meta): assert len(results2) == 1 - def test_the_embedding_function(simple_schema_manager, example_texts): db = DuckDBAdapter(OUTPUT_DUCKDB_PATH) db.conn.execute("DROP TABLE IF EXISTS test_collection") @@ -165,7 +164,7 @@ def test_ontology_matches(ontology_db): assert len(results) == 10 first_obj = results[0][0] - print("the id", first_obj['id']) + print("the id", first_obj["id"]) first_meta = results[0][2] new_id, new_definition = "Palm Beach", "A beach with palm trees" updated_obj = { @@ -184,20 +183,17 @@ def test_ontology_matches(ontology_db): ontology_db.update([updated_obj], collection=collection) # verify update updated_res = ontology_db.lookup(new_id, collection) - assert updated_res['id'] == new_id - assert updated_res['definition'] == new_definition - assert updated_res['label'] == first_obj['label'] + assert updated_res["id"] == new_id + assert updated_res["definition"] == new_definition + assert updated_res["label"] == first_obj["label"] # test upsert - new_obj_insert = { - "id": "Palm Beach", - "key": "value" - } + new_obj_insert = {"id": "Palm Beach", "key": "value"} ontology_db.upsert([new_obj_insert], collection="test_collection") # verify upsert new_results = ontology_db.lookup("Palm Beach", collection="test_collection") assert new_results["id"] == "Palm Beach" - assert new_results['key'] == "value" + assert new_results["key"] == "value" @pytest.mark.parametrize( @@ -215,6 +211,7 @@ def test_where_queries(loaded_ontology_db, where, num_expected, limit, include): ) assert len(results) == num_expected + @pytest.mark.parametrize( "batch_size", [ @@ -237,10 +234,7 @@ def test_load_in_batches(ontology_db, batch_size): # end = time.time() # print(f"Time to insert {len(list(view.objects()))} objects with batch of {batch_size}: {end - start}") - objs = list( - ontology_db.find(collection="other_collection", limit=2000 - ) - ) + objs = list(ontology_db.find(collection="other_collection", limit=2000)) assert len(objs) > 100 From 31e1b2b392ae3aceda0cc586d671d6735d43ed1e Mon Sep 17 00:00:00 2001 From: caufieldjh Date: Tue, 20 Aug 2024 16:14:19 -0400 Subject: [PATCH 3/3] A bit more linting --- src/curate_gpt/store/db_metadata.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/curate_gpt/store/db_metadata.py b/src/curate_gpt/store/db_metadata.py index 992b266..a75d39f 100644 --- a/src/curate_gpt/store/db_metadata.py +++ b/src/curate_gpt/store/db_metadata.py @@ -1,8 +1,7 @@ from pathlib import Path -from pydantic import BaseModel, ConfigDict import yaml -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class DBSettings(BaseModel):
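
Note on the `model_config` additions above: Pydantic v2 reserves the `model_` name prefix by default (`protected_namespaces=('model_',)`), so any field whose name starts with `model_` — e.g. `Task.model_name` in `evaluation_datamodel.py` — triggers a `UserWarning` at class-definition time. Setting `protected_namespaces=()` clears that guard rather than renaming fields that are part of the models' public interface. A minimal, standalone sketch of the behavior (illustrative only, not part of the patch):

    from pydantic import BaseModel, ConfigDict

    class Task(BaseModel):
        # Clearing the protected namespaces suppresses the
        # "Field 'model_name' has conflict with protected namespace" warning
        # without renaming the field and breaking existing callers.
        model_config = ConfigDict(protected_namespaces=())

        model_name: str = "gpt-3.5-turbo"

    print(Task().model_name)  # -> "gpt-3.5-turbo", with no UserWarning emitted

Applying the same `model_config` uniformly across every `BaseModel` subclass, as the first commit does, also keeps the warning from resurfacing if other models gain `model_`-prefixed fields later.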