Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add config to each pydantic BaseModel; fix #59 #60

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/curate_gpt/agents/bootstrap_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

import yaml
from jinja2 import Template
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from curate_gpt.agents.base_agent import BaseAgent
from curate_gpt.conf.prompts import PROMPTS_DIR
from curate_gpt.extract import AnnotatedObject


class KnowledgeBaseSpecification(BaseModel):
model_config = ConfigDict(protected_namespaces=())
kb_name: str
description: str
attributes: str
Expand Down
4 changes: 3 additions & 1 deletion src/curate_gpt/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import yaml
from llm import Conversation
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from curate_gpt.agents.base_agent import BaseAgent
from curate_gpt.utils.tokens import estimate_num_tokens, max_tokens_by_model
Expand All @@ -23,6 +23,8 @@ class ChatResponse(BaseModel):
TODO: Rename class to indicate that it is provenance-enabled chat
"""

model_config = ConfigDict(protected_namespaces=())

body: str
"""Text of response."""

Expand Down
8 changes: 7 additions & 1 deletion src/curate_gpt/agents/concept_recognition_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from curate_gpt.agents.base_agent import BaseAgent

Expand All @@ -18,6 +18,8 @@
class Span(BaseModel):
"""An individual span of text containing a single concept."""

model_config = ConfigDict(protected_namespaces=())

text: str

start: Optional[int] = None
Expand All @@ -36,6 +38,8 @@ class Span(BaseModel):
class GroundingResult(BaseModel):
"""Result of grounding text."""

model_config = ConfigDict(protected_namespaces=())

input_text: str
"""Text that is supplied for grounding, assumed to contain a single context."""

Expand All @@ -62,6 +66,8 @@ class AnnotationMethod(str, Enum):
class AnnotatedText(BaseModel):
"""In input text annotated with concept instances."""

model_config = ConfigDict(protected_namespaces=())

input_text: str
"""Text that is supplied for annotation."""

Expand Down
4 changes: 3 additions & 1 deletion src/curate_gpt/agents/dase_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from typing import Any, ClassVar, Dict, List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from curate_gpt.agents.base_agent import BaseAgent
from curate_gpt.agents.chat_agent import ChatResponse
Expand All @@ -17,6 +17,8 @@


class PredictedFieldValue(BaseModel):

model_config = ConfigDict(protected_namespaces=())
id: str
original_id: Optional[str] = None
predicted_value: Optional[str] = None
Expand Down
4 changes: 3 additions & 1 deletion src/curate_gpt/agents/dragon_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Union

import yaml
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from curate_gpt.agents.base_agent import BaseAgent
from curate_gpt.extract import AnnotatedObject
Expand All @@ -17,6 +17,8 @@


class PredictedFieldValue(BaseModel):

model_config = ConfigDict(protected_namespaces=())
id: str
original_id: Optional[str] = None
predicted_value: Optional[str] = None
Expand Down
5 changes: 4 additions & 1 deletion src/curate_gpt/agents/mapping_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import inflection
import yaml
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from curate_gpt.agents.base_agent import BaseAgent
from curate_gpt.store.db_adapter import SEARCH_RESULT
Expand Down Expand Up @@ -46,12 +46,15 @@ class MappingPredicate(str, Enum):
class Mapping(BaseModel):
"""Response from chat engine."""

model_config = ConfigDict(protected_namespaces=())
subject_id: str
object_id: str
predicate_id: Optional[MappingPredicate] = None


class MappingSet(BaseModel):

model_config = ConfigDict(protected_namespaces=())
mappings: List[Mapping]
prompt: str = None
response_text: str = None
Expand Down
4 changes: 3 additions & 1 deletion src/curate_gpt/app/cart.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Dict, List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class CartItem(BaseModel):
"""
A cart item is a single item in a cart
"""

model_config = ConfigDict(protected_namespaces=())
object: Union[Dict, BaseModel]
object_type: Optional[str] = None
source: Optional[str] = None
Expand All @@ -19,6 +20,7 @@ class Cart(BaseModel):
A cart is a list of items that can be added to or removed from
"""

model_config = ConfigDict(protected_namespaces=())
items: List[CartItem] = []

@property
Expand Down
6 changes: 5 additions & 1 deletion src/curate_gpt/evaluation/evaluation_datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from pathlib import Path
from typing import Dict, List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class StratifiedCollection(BaseModel):
"""
A collection of objects that have been split into training, test, and validation sets.
"""

model_config = ConfigDict(protected_namespaces=())
source: str = None
training_set: List[Dict] = None
testing_set: List[Dict] = None
Expand All @@ -33,6 +34,8 @@ class AggregationMethod(str, Enum):


class ClassificationMetrics(BaseModel):

model_config = ConfigDict(protected_namespaces=())
precision: float
recall: float
f1_score: float
Expand All @@ -49,6 +52,7 @@ class Task(BaseModel):
A task to be run by the evaluation runner.
"""

model_config = ConfigDict(protected_namespaces=())
model_name: str = "gpt-3.5-turbo"
embedding_model_name: str = "openai:"
generate_background: bool = False
Expand Down
3 changes: 3 additions & 0 deletions src/curate_gpt/extract/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import llm
from linkml_runtime import SchemaView
from pydantic import BaseModel as BaseModel
from pydantic import ConfigDict

from curate_gpt.store.schema_proxy import SchemaProxy

Expand All @@ -20,6 +21,8 @@ class AnnotatedObject(BaseModel):
Annotated object shadows a basic dictionary object
"""

model_config = ConfigDict(protected_namespaces=())

# object: Union[Dict[str, Any], List[Dict[str, Any]]] = {} - TODO: pydantic bug?
object: Any = {}
annotations: Dict[str, Any] = {}
Expand Down
5 changes: 4 additions & 1 deletion src/curate_gpt/store/db_metadata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from pathlib import Path

import yaml
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class DBSettings(BaseModel):

model_config = ConfigDict(protected_namespaces=())

name: str = "duckdb"
"""Name of the database."""

Expand Down
5 changes: 4 additions & 1 deletion src/curate_gpt/store/duckdb_result.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

SEARCH_RESULT = Tuple[Dict[str, Any], Dict, float, Optional[Dict]]


class DuckDBSearchResult(BaseModel):

model_config = ConfigDict(protected_namespaces=())

ids: Optional[str] = None
metadatas: Optional[Dict[str, Any]] = None
embeddings: Optional[List[float]] = None
Expand Down
8 changes: 7 additions & 1 deletion src/curate_gpt/store/in_memory_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from typing import ClassVar, Dict, Iterable, Iterator, List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from curate_gpt import DBAdapter
from curate_gpt.store.db_adapter import OBJECT, PROJECTION, QUERY, SEARCH_RESULT
Expand All @@ -14,6 +14,9 @@


class Collection(BaseModel):

model_config = ConfigDict(protected_namespaces=())

objects: List[Dict] = []
metadata: Dict = {}

Expand All @@ -25,6 +28,9 @@ def delete(self, key_value: str, key: str) -> None:


class CollectionIndex(BaseModel):

model_config = ConfigDict(protected_namespaces=())

collections: Dict[str, Collection] = {}

def get_collection(self, name: str) -> Collection:
Expand Down
4 changes: 3 additions & 1 deletion src/curate_gpt/store/metadata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, Optional

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class CollectionMetadata(BaseModel):
Expand All @@ -10,6 +10,8 @@ class CollectionMetadata(BaseModel):
This is an open class, so additional metadata can be added.
"""

model_config = ConfigDict(protected_namespaces=())

name: Optional[str] = None
"""Name of the collection"""

Expand Down
4 changes: 3 additions & 1 deletion src/curate_gpt/store/schema_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from linkml_runtime import SchemaView
from linkml_runtime.linkml_model import SchemaDefinition
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


@dataclass
Expand All @@ -13,6 +13,8 @@ class SchemaProxy:
Manage connection to a schema
"""

model_config = ConfigDict(protected_namespaces=())

schema_source: Union[str, Path, SchemaDefinition] = None
_pydantic_root_model: BaseModel = None
_schemaview: SchemaView = None
Expand Down
5 changes: 4 additions & 1 deletion src/curate_gpt/utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from copy import copy
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class Outcome(BaseModel):

model_config = ConfigDict(protected_namespaces=())

prediction: Union[Dict[str, Any], List[Dict[str, Any]]] = {}
expected: Union[Dict[str, Any], List[Dict[str, Any]]] = {}
parameters: Dict[str, Any] = {}
Expand Down
6 changes: 4 additions & 2 deletions src/curate_gpt/wrappers/general/github_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Chat with a Google Drive."""
"""Chat with issues from a GitHub repository."""

import logging
import os
Expand All @@ -8,20 +8,22 @@

import requests
import requests_cache
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from curate_gpt.wrappers.base_wrapper import BaseWrapper

logger = logging.getLogger(__name__)


class Comment(BaseModel):
model_config = ConfigDict(protected_namespaces=())
id: str
user: str = None
body: str = None


class Issue(BaseModel):
model_config = ConfigDict(protected_namespaces=())
id: str
number: int = None
type: str = None
Expand Down
5 changes: 4 additions & 1 deletion src/curate_gpt/wrappers/ontology/ontology.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

from pydantic import BaseModel, Extra
from pydantic import BaseModel, ConfigDict, Extra


class Relationship(BaseModel):
Expand All @@ -10,6 +10,7 @@ class Relationship(BaseModel):
Corresponds to an edge in an OBO graph.
"""

model_config = ConfigDict(protected_namespaces=())
predicate: str
target: str

Expand All @@ -21,6 +22,7 @@ class OntologyClass(BaseModel, extra=Extra.allow):
Corresponds to a node in an OBO graph.
"""

model_config = ConfigDict(protected_namespaces=())
id: str
label: str = None
definition: str = None
Expand All @@ -37,4 +39,5 @@ class Ontology(BaseModel):
Corresponds to an OBO graph.
"""

model_config = ConfigDict(protected_namespaces=())
elements: List[OntologyClass] = None
4 changes: 3 additions & 1 deletion tests/extract/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from linkml_runtime.utils.schema_builder import SchemaBuilder
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from curate_gpt.extract.basic_extractor import BasicExtractor
from curate_gpt.extract.extractor import AnnotatedObject
Expand All @@ -12,11 +12,13 @@


class Occupation(BaseModel):
model_config = ConfigDict(protected_namespaces=())
category: str
current: bool


class Person(BaseModel):
model_config = ConfigDict(protected_namespaces=())
name: str
age: int
occupations: List[Occupation]
Expand Down
Loading