Merge pull request #85 from iQuxLE/tests_with_apikey_fix
Fixing tests with provided OPENAI_API_KEY
caufieldjh committed Sep 20, 2024
2 parents d549679 + b071835 commit e35bed8
Showing 16 changed files with 193 additions and 111 deletions.
7 changes: 4 additions & 3 deletions src/curate_gpt/evaluation/dae_evaluator.py
@@ -17,7 +17,6 @@
logger = logging.getLogger(__name__)


# TODO: missing abstract method evaluate_object causes src/tests/evaluation/test_runner to fail
@dataclass
class DatabaseAugmentedCompletionEvaluator(BaseEvaluator):
"""
@@ -50,8 +49,7 @@ def evaluate(
"""
agent = self.agent
db = agent.knowledge_source
# TODO: use get()
test_objs = list(db.peek(collection=test_collection, limit=num_tests))
test_objs = list(db.find(collection=test_collection))
if any(obj for obj in test_objs if any(f not in obj for f in self.fields_to_predict)):
logger.info("Alternate strategy to get test objs; query whole collection")
test_objs = db.peek(collection=test_collection, limit=1000000)
@@ -133,3 +131,6 @@ def evaluate(
report_tsv_file.flush()
aggregated = aggregate_metrics(all_metrics)
return aggregated

def evaluate_object(self, obj, **kwargs) -> ClassificationMetrics:
pass
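
The new no-op evaluate_object satisfies the abstract contract flagged by the removed TODO. A minimal sketch of why the stub matters, assuming BaseEvaluator declares evaluate_object as abstract (real signatures may differ):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

class BaseEvaluator(ABC):
    @abstractmethod
    def evaluate_object(self, obj, **kwargs):
        """Evaluate a single object."""

@dataclass
class ConcreteEvaluator(BaseEvaluator):
    def evaluate_object(self, obj, **kwargs):
        pass  # a no-op override is enough to make the class instantiable

ConcreteEvaluator()       # OK: all abstract methods are implemented

@dataclass
class BrokenEvaluator(BaseEvaluator):
    pass

try:
    BrokenEvaluator()     # TypeError: abstract method not implemented
except TypeError as e:
    print(e)
```
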
2 changes: 1 addition & 1 deletion src/curate_gpt/evaluation/splitter.py
@@ -122,7 +122,7 @@ def stratify_collection_to_store(
size = len(objs)
cn = f"{collection}_{sn}_{size}"
collections[sn] = cn
logging.info(f"Writing {size} objects to {cn}")
logger.info(f"Writing {size} objects to {cn}")
if cn in existing_collections:
logger.info(f"Collection {cn} already exists")
if not force:
2 changes: 1 addition & 1 deletion src/curate_gpt/extract/openai_extractor.py
@@ -113,7 +113,7 @@ def extract(
logger.debug(f"RESPONSE = {response}")
# print(response)
choice = response.choices[0]
message = choice["message"]
message = choice.message
if "function_call" not in message:
if self.raise_error_if_unparsable:
raise ValueError("No function call in response")
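
This tracks the OpenAI v1 Python client, whose responses are typed objects rather than dicts, so choice["message"] raises a TypeError. A hedged sketch of the access pattern (model name and prompt are placeholders; on typed messages the function_call membership check in the surrounding context would likewise need attribute-style access):

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "Say hello"}],
)
choice = response.choices[0]
message = choice.message           # attribute access; choice["message"] raises TypeError
if message.function_call is None:  # typed field, not a dict key
    print(message.content)
```
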
4 changes: 3 additions & 1 deletion src/curate_gpt/store/duckdb_adapter.py
@@ -1064,8 +1064,10 @@ def _get_embedding_dimension(self, model_name: str) -> int:
if isinstance(model_name, str):
if model_name.startswith("openai:"):
model_key = model_name.split("openai:", 1)[1]
if model_key == "" or model_key not in MODEL_MAP.keys():
model_key = DEFAULT_OPENAI_MODEL
model_info = MODEL_MAP.get(model_key, DEFAULT_OPENAI_MODEL)
return MODEL_MAP[model_info][1]
return model_info[1]
else:
return MODEL_MAP[DEFAULT_OPENAI_MODEL][1]

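
The previous code used model_info, a value already pulled out of the map, as a key back into MODEL_MAP, so any recognized key raised a KeyError. The fix normalizes empty or unknown keys to the default before a single lookup. A sketch with hypothetical MODEL_MAP entries (the dimensions match OpenAI's embedding models, but the adapter's actual map may differ):

```python
# key -> (embedding function family, embedding dimension); entries are illustrative
MODEL_MAP = {
    "text-embedding-ada-002": ("openai", 1536),
    "text-embedding-3-small": ("openai", 1536),
    "text-embedding-3-large": ("openai", 3072),
}
DEFAULT_OPENAI_MODEL = "text-embedding-ada-002"

def get_embedding_dimension(model_name: str) -> int:
    """Resolve an 'openai:<model>' spec to its embedding dimension."""
    if model_name.startswith("openai:"):
        model_key = model_name.split("openai:", 1)[1]
        if model_key == "" or model_key not in MODEL_MAP:
            model_key = DEFAULT_OPENAI_MODEL  # normalize unknown/empty keys
        return MODEL_MAP[model_key][1]
    return MODEL_MAP[DEFAULT_OPENAI_MODEL][1]

assert get_embedding_dimension("openai:") == 1536                        # empty -> default
assert get_embedding_dimension("openai:text-embedding-3-large") == 3072
assert get_embedding_dimension("openai:no-such-model") == 1536           # unknown -> default
```
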
6 changes: 3 additions & 3 deletions src/curate_gpt/wrappers/clinical/clinvar_wrapper.py
@@ -33,8 +33,8 @@ def objects_from_dict(self, results: Dict) -> List[Dict]:
for r in results["eSummaryResult"]["DocumentSummarySet"]["DocumentSummary"]:
obj = {}
obj["id"] = "clinvar:" + r["accession"]
obj["clinical_significance"] = r["clinical_significance"]["description"]
obj["clinical_significance_status"] = r["clinical_significance"]["review_status"]
obj["clinical_significance"] = r["germline_classification"]["description"]
obj["clinical_significance_status"] = r["germline_classification"]["review_status"]
obj["gene_sort"] = r["gene_sort"]
if "genes" in r and r["genes"]:
if "gene" in r["genes"]:
@@ -46,7 +46,7 @@ def objects_from_dict(self, results: Dict) -> List[Dict]:
obj["protein_change"] = r["protein_change"]
obj["title"] = r["title"]
obj["traits"] = [
self._trait_from_dict(t) for t in r["trait_set"]["trait"] if isinstance(t, dict)
self._trait_from_dict(t) for t in r.get("trait_set", {}).get("trait", []) if isinstance(t, dict)
]
objs.append(obj)
return objs
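
Besides tracking NCBI's rename of clinical_significance to germline_classification in the eSummary payload, the trait extraction now tolerates records without a trait_set. A small sketch of the chained .get() pattern (records are illustrative):

```python
records = [
    {"trait_set": {"trait": [{"trait_name": "Hereditary cancer syndrome"}]}},  # usual shape
    {"trait_set": {"trait": {"trait_name": "single trait, not a list"}}},      # dict, not list
    {},                                                                         # no trait_set
]

for r in records:
    traits = [
        t
        for t in r.get("trait_set", {}).get("trait", [])
        if isinstance(t, dict)  # drops the string keys yielded when 'trait' is a bare dict
    ]
    print(traits)
# record 1 -> [{'trait_name': 'Hereditary cancer syndrome'}]
# records 2 and 3 -> [] instead of raising KeyError
```
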
8 changes: 7 additions & 1 deletion tests/agents/conftest.py
@@ -6,13 +6,19 @@
from tests import INPUT_DBS


# TODO: this has to be reviewed; isolate more, don't use one db for multiple tests
# - the current setup does not allow a reset
# - set_collection is fragile; it is easier to set/create a new collection for each test
# - the loaded test ontology could also be mocked for ease
# - ? reuse the structure from tests/wrappers (vstore/wrapper fixtures)
# - or create a collection per test, load each with the full data, and reset/remove it afterwards
@pytest.fixture
def go_test_chroma_db() -> ChromaDBAdapter:
"""
Fixture for a ChromaDBAdapter instance with the test ontology loaded.
Note: the chromadb is not checked into github - instead,
this relies on test_chromadb_dapter.test_store to create the test db.
this relies on test_chromadb_adapter.test_store to create the test db.
"""
db = ChromaDBAdapter(str(INPUT_DBS / "go-nucleus-chroma"))
db.schema_proxy = SchemaProxy(ONTOLOGY_MODEL_PATH)
1 change: 1 addition & 0 deletions tests/store/test_duckdb_adapter.py
@@ -160,6 +160,7 @@ def test_the_embedding_function_variations(
expected_name = "test_collection"
else:
# Specific case: Collection specified, model may or may not be specified
print("\n\n",model,"\n\n")
db.insert(objs, collection=collection, model=model)
expected_model = model if model else "all-MiniLM-L6-v2"
expected_name = collection
25 changes: 25 additions & 0 deletions tests/utils/helper.py
@@ -0,0 +1,25 @@
from pathlib import Path

from src.curate_gpt.store import ChromaDBAdapter

DEBUG_MODE = False

def create_db_dir(tmp_path, out_dir) -> Path:
"""Creates a temporary directory or uses the provided debug directory."""
if DEBUG_MODE:
temp_dir = out_dir
if not temp_dir.exists():
temp_dir.mkdir(parents=True, exist_ok=True)
return temp_dir
else:
return tmp_path


def setup_db(temp_dir: Path) -> ChromaDBAdapter:
"""Sets up the DBAdapter and optionally resets it."""
# TODO: for now ChromaDB, later add DuckDB
# db = get_store("chromadb", str(temp_dir))
db = ChromaDBAdapter(str(temp_dir))
# reset only when we use the db in try block, or in test
return db
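
These helpers centralize the pattern the reworked fixtures below use: pytest's tmp_path by default, or a stable on-disk directory when DEBUG_MODE is flipped on for post-mortem inspection. A sketch of a consuming test (the output path and the roundtrip calls are illustrative):

```python
from pathlib import Path

import pytest

from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

EXAMPLE_OUT = Path("tests/output/example_tmp")  # illustrative debug location

@pytest.fixture
def db(tmp_path):
    temp_dir = create_db_dir(tmp_path, EXAMPLE_OUT)  # tmp_path unless DEBUG_MODE
    db = setup_db(temp_dir)
    try:
        yield db
    finally:
        if not DEBUG_MODE:
            db.reset()  # keep the on-disk store only when debugging

def test_roundtrip(db):
    db.insert([{"id": "X:1", "label": "example"}], collection="demo")
    assert list(db.find(collection="demo"))
```
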

20 changes: 14 additions & 6 deletions tests/wrappers/test_bioportal.py
@@ -2,11 +2,11 @@

import pytest

from curate_gpt import ChromaDBAdapter
from curate_gpt.extract import BasicExtractor
from curate_gpt.wrappers.ontology.bioportal_wrapper import BioportalWrapper
from curate_gpt.wrappers.ontology.ontology_wrapper import OntologyWrapper
from tests import OUTPUT_DIR
from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

TEMP_OAKVIEW_DB = OUTPUT_DIR / "bioportal_tmp"

@@ -18,15 +18,23 @@


@pytest.fixture
def vstore() -> OntologyWrapper:
db = ChromaDBAdapter(str(TEMP_OAKVIEW_DB))
def vstore(tmp_path) -> OntologyWrapper:
tmp_dir = create_db_dir(tmp_path=tmp_path, out_dir=TEMP_OAKVIEW_DB)
db = setup_db(tmp_dir)
db.reset()
view = BioportalWrapper(local_store=db, extractor=BasicExtractor())
assert view.fetch_definitions is False
try:
view = BioportalWrapper(local_store=db, extractor=BasicExtractor())
assert view.fetch_definitions is False
yield view
except Exception as e:
raise e
finally:
if not DEBUG_MODE:
db.reset()

# view = BioportalView(oak_adapter=adapter, local_store=db, extractor=BasicExtractor())
# view.fetch_definitions = False
# view.fetch_relationships = False
return view


@pytest.mark.skip(reason="OAK bp wrapper doesn't support definitions yet")
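
The fixture previously built the wrapper and returned it, leaving the store behind; it now yields inside try/finally so teardown runs whether the test passes or raises. The same pattern is applied in the ClinVar, evidence-agent, and ontology tests below. A distilled sketch with a stand-in wrapper:

```python
import pytest

from tests.utils.helper import DEBUG_MODE, setup_db

class StubWrapper:  # stand-in for BioportalWrapper and friends
    def __init__(self, local_store):
        self.local_store = local_store

@pytest.fixture
def wrapped_store(tmp_path):
    db = setup_db(tmp_path)
    db.reset()  # start from a clean collection
    try:
        yield StubWrapper(local_store=db)
    finally:
        if not DEBUG_MODE:
            db.reset()  # teardown runs on success and on failure
```
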
20 changes: 13 additions & 7 deletions tests/wrappers/test_clinvar.py
@@ -1,17 +1,17 @@
import logging
import shutil
import time

import pytest
import requests
import yaml

from curate_gpt import ChromaDBAdapter
from curate_gpt.agents.chat_agent import ChatAgent
from curate_gpt.agents.dragon_agent import DragonAgent
from curate_gpt.extract import BasicExtractor
from curate_gpt.wrappers.clinical.clinvar_wrapper import ClinVarWrapper
from tests import INPUT_DIR, OUTPUT_DIR
from tests.store.conftest import requires_openai_api_key
from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

TEMP_DB = OUTPUT_DIR / "obj_tmp"

@@ -29,12 +29,18 @@ def test_clinvar_transform():


@pytest.fixture
def wrapper() -> ClinVarWrapper:
shutil.rmtree(TEMP_DB, ignore_errors=True)
db = ChromaDBAdapter(str(TEMP_DB))
def wrapper(tmp_path) -> ClinVarWrapper:
temp_dir = create_db_dir(tmp_path, TEMP_DB)
db = setup_db(temp_dir)
extractor = BasicExtractor()
db.reset()
return ClinVarWrapper(local_store=db, extractor=extractor)
try:
yield ClinVarWrapper(local_store=db, extractor=extractor)
except requests.exceptions.ConnectionError as e:
logger.error(f"Connection error occurred: {e}")
raise e
finally:
if not DEBUG_MODE:
db.reset()


@requires_openai_api_key
20 changes: 13 additions & 7 deletions tests/wrappers/test_evidence_agent.py
@@ -1,17 +1,16 @@
import logging
import shutil
from typing import Type

import pytest
import yaml

from curate_gpt import ChromaDBAdapter
from curate_gpt.agents.evidence_agent import EvidenceAgent
from curate_gpt.extract import BasicExtractor
from curate_gpt.wrappers import BaseWrapper
from curate_gpt.wrappers.literature import PubmedWrapper, WikipediaWrapper
from tests import OUTPUT_DIR
from tests.store.conftest import requires_openai_api_key
from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

TEMP_PUBMED_DB = OUTPUT_DIR / "pmid_tmp"

@@ -30,12 +29,19 @@
WikipediaWrapper,
],
)
def test_evidence_inference(source: Type[BaseWrapper]):
shutil.rmtree(TEMP_PUBMED_DB, ignore_errors=True)
db = ChromaDBAdapter(str(TEMP_PUBMED_DB))
def test_evidence_inference(tmp_path, source: Type[BaseWrapper]):
tmp_dir = create_db_dir(tmp_path=tmp_path, out_dir=TEMP_PUBMED_DB)
db = setup_db(tmp_dir)
extractor = BasicExtractor()
db.reset()
pubmed = source(local_store=db, extractor=extractor)
try:
pubmed = source(local_store=db, extractor=extractor)
except Exception as e:
raise e
finally:
if not DEBUG_MODE:
if tmp_dir.exists():
db.reset()

ea = EvidenceAgent(chat_agent=pubmed)
obj = {
"label": "acinar cells of the salivary gland",
15 changes: 7 additions & 8 deletions tests/wrappers/test_ncbi_biosample.py
@@ -1,25 +1,24 @@
import logging
import shutil
import time

import yaml

from curate_gpt import ChromaDBAdapter
from curate_gpt.agents.chat_agent import ChatAgent
from curate_gpt.extract import BasicExtractor
from curate_gpt.wrappers.investigation.ncbi_biosample_wrapper import NCBIBiosampleWrapper
from tests import OUTPUT_DIR
from tests.store.conftest import requires_openai_api_key
from tests.utils.helper import create_db_dir, setup_db

TEMP_BIOSAMPLE_DB = OUTPUT_DIR / "biosample_tmp"

logger = logging.getLogger(__name__)


@requires_openai_api_key
def test_biosample_search():
shutil.rmtree(TEMP_BIOSAMPLE_DB, ignore_errors=True)
db = ChromaDBAdapter(str(TEMP_BIOSAMPLE_DB))
def test_biosample_search(tmp_path):
temp_dir = create_db_dir(tmp_path, TEMP_BIOSAMPLE_DB)
db = setup_db(temp_dir)
extractor = BasicExtractor()
db.reset()
wrapper = NCBIBiosampleWrapper(local_store=db, extractor=extractor)
@@ -33,9 +32,9 @@ def test_biosample_search():


@requires_openai_api_key
def test_biosample_chat():
shutil.rmtree(TEMP_BIOSAMPLE_DB, ignore_errors=True)
db = ChromaDBAdapter(str(TEMP_BIOSAMPLE_DB))
def test_biosample_chat(tmp_path):
temp_dir = create_db_dir(tmp_path, TEMP_BIOSAMPLE_DB)
db = setup_db(temp_dir)
extractor = BasicExtractor()
db.reset()
wrapper = NCBIBiosampleWrapper(local_store=db, extractor=extractor)
45 changes: 23 additions & 22 deletions tests/wrappers/test_ontology.py
@@ -1,21 +1,19 @@
import logging
import os
import shutil
import tempfile
from pprint import pprint

import pytest
from oaklib import get_adapter
from oaklib.datamodels.obograph import GraphDocument

from curate_gpt import ChromaDBAdapter
from curate_gpt.extract import BasicExtractor
from curate_gpt.wrappers.ontology.ontology_wrapper import OntologyWrapper
from tests import INPUT_DIR, OUTPUT_DIR
from tests.store.conftest import requires_openai_api_key
from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

TEMP_OAKVIEW_DB = OUTPUT_DIR / "oaktmp"
TEMP_OAKVIEW_DB2 = OUTPUT_DIR / "oaktmp2"
TEMP_OAK_OBJ = OUTPUT_DIR / "oak_tmp_obj"
TEMP_OAK_IND = OUTPUT_DIR / "oak_tmp_ind"
TEMP_OAK_SEARCH = OUTPUT_DIR / "oak_tmp_search"

# logger = logging.getLogger(__name__)

@@ -25,20 +23,27 @@


@pytest.fixture
def vstore():
with tempfile.TemporaryDirectory() as temp_dir:
db_path = os.path.join(temp_dir, "test_db")
adapter = get_adapter(INPUT_DIR / "go-nucleus.db")
db = ChromaDBAdapter(db_path)
wrapper = OntologyWrapper(oak_adapter=adapter, local_store=db, extractor=BasicExtractor())
def vstore(request, tmp_path):
temp_db_base = request.param
temp_dir = create_db_dir(tmp_path, temp_db_base)
db = setup_db(temp_dir)
extractor = BasicExtractor()
# mock, possible connection error?
adapter = get_adapter(INPUT_DIR / "go-nucleus.db")
try:
wrapper = OntologyWrapper(oak_adapter=adapter, local_store=db, extractor=extractor)
db.insert(wrapper.objects())
yield wrapper
except Exception as e:
raise e
finally:
if not DEBUG_MODE:
db.reset()


@pytest.mark.parametrize('vstore', [TEMP_OAK_OBJ], indirect=True)
def test_oak_objects(vstore):
"""Test that the objects are extracted from the oak adapter."""
shutil.rmtree(TEMP_OAKVIEW_DB, ignore_errors=True)
# vstore.local_store.reset()
objs = list(vstore.objects())
[nucleus] = [obj for obj in objs if obj["id"] == "Nucleus"]
assert nucleus["label"] == "nucleus"
@@ -50,22 +55,17 @@ def test_oak_objects(vstore):
assert len(reversed.graphs[0].edges) == 2


@pytest.mark.parametrize('vstore', [TEMP_OAK_IND], indirect=True)
def test_oak_index(vstore):
"""Test that the objects are indexed in the local store."""
shutil.rmtree(TEMP_OAKVIEW_DB2, ignore_errors=True)
adapter = get_adapter(INPUT_DIR / "go-nucleus.db")
db = ChromaDBAdapter(str(TEMP_OAKVIEW_DB2))
db.reset()
wrapper = OntologyWrapper(oak_adapter=adapter, local_store=db, extractor=BasicExtractor())
db.insert(wrapper.objects())
g = wrapper.unwrap_object(
g = vstore.unwrap_object(
{
"id": "Nucleus",
"label": "nucleus",
"relationships": [{"predicate": "rdfs:subClassOf", "target": "Organelle"}],
"original_id": "GO:0005634",
},
store=db,
store=vstore.local_store,
)
if isinstance(g, GraphDocument):
pprint(g.__dict__, width=100, indent=2)
@@ -80,6 +80,7 @@ def test_oak_index(vstore):
print(edge.sub, edge.pred, edge.obj)


@pytest.mark.parametrize('vstore', [TEMP_OAK_SEARCH], indirect=True)
@requires_openai_api_key
def test_oak_search(vstore):
"""Test that the objects are indexed and searchable in the local store."""
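
Instead of one shared on-disk store, each test now routes its own directory name into the shared fixture via indirect parametrization: pytest delivers the parameter to the fixture as request.param rather than to the test function. A reduced, runnable sketch:

```python
import pytest

@pytest.fixture
def vstore(request, tmp_path):
    temp_db_base = request.param         # value from the parametrize list
    store_dir = tmp_path / temp_db_base  # per-test location
    yield str(store_dir)                 # stand-in for the real wrapper setup

@pytest.mark.parametrize("vstore", ["oak_tmp_obj"], indirect=True)
def test_uses_fixture(vstore):
    assert vstore.endswith("oak_tmp_obj")
```
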