Skip to content

Commit

Permalink
Feature/merge dev to main (#962)
Browse files (browse the repository at this point in the history)
* merge dev and main

* git rm

* add back collection fix
  • Loading branch information
emrgnt-cmplxty authored Aug 23, 2024
1 parent 7b1d406 commit 8ad54be
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 42 deletions.
2 changes: 1 addition & 1 deletion py/cli/commands/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def serve(
).replace(":", "")

if docker:

run_docker_serve(
client,
host,
Expand Down
2 changes: 1 addition & 1 deletion py/cli/utils/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def build_docker_command(
os.environ["TRAEFIK_PORT"] = str(available_port + 1)

os.environ["CONFIG_PATH"] = (
os.path.basename(config_path) if config_path else ""
os.path.abspath(config_path) if config_path else ""
)

os.environ["R2R_IMAGE"] = image or ""
Expand Down
2 changes: 1 addition & 1 deletion py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from .restructure import KGEnrichmentSettings
from .search import (
AggregateSearchResult,
KGLocalSearchResult,
KGGlobalSearchResult,
KGLocalSearchResult,
KGSearchResult,
KGSearchSettings,
VectorSearchResult,
Expand Down
22 changes: 14 additions & 8 deletions py/core/base/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ class Config:
},
}


class KGLocalSearchResult(BaseModel):
"""Result of a local knowledge graph search operation."""

query: str
entities: dict[str, Any]
relationships: dict[str, Any]
Expand All @@ -70,6 +72,7 @@ def __repr__(self) -> str:

class KGGlobalSearchResult(BaseModel):
"""Result of a global knowledge graph search operation."""

query: str
search_result: list[str]

Expand All @@ -80,26 +83,29 @@ def __repr__(self) -> str:
return self.__str__()

def dict(self) -> dict:
return {
"query": self.query,
"search_result": self.search_result
}
return {"query": self.query, "search_result": self.search_result}


class KGSearchResult(BaseModel):
"""Result of a knowledge graph search operation."""

local_result: Optional[KGLocalSearchResult] = None
global_result: Optional[KGGlobalSearchResult] = None

def __str__(self) -> str:
return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})"

def __repr__(self) -> str:
return self.__str__()

def dict(self) -> dict:
return {
"local_result": self.local_result.dict() if self.local_result else None,
"global_result": self.global_result.dict() if self.global_result else None
"local_result": (
self.local_result.dict() if self.local_result else None
),
"global_result": (
self.global_result.dict() if self.global_result else None
),
}


Expand Down
4 changes: 3 additions & 1 deletion py/core/base/providers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def register(self, email: str, password: str) -> Dict[str, str]:
pass

@abstractmethod
def verify_email(self, email: str, verification_code: str) -> Dict[str, str]:
def verify_email(
self, email: str, verification_code: str
) -> Dict[str, str]:
pass

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/api/routes/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def verify_email_app(
email: EmailStr = Body(..., description="User's email address"),
verification_code: str = Body(
..., description="Email verification code"
)
),
):
"""
Verify a user's email address.
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/api/routes/retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def search_app(
description=search_descriptions.get("kg_search_settings"),
),
auth_user=Depends(self.engine.providers.auth.auth_wrapper),
) -> WrappedSearchResponse:
) -> WrappedSearchResponse:
"""
Perform a search query on the vector database and knowledge graph.
Expand Down
25 changes: 22 additions & 3 deletions py/core/main/assembly/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
import os
from enum import Enum
from typing import Any
from pathlib import Path
from typing import Any, Optional

import toml
from pydantic import BaseModel
Expand Down Expand Up @@ -58,7 +59,13 @@ class R2RConfig:
prompt: PromptConfig
agent: AgentConfig

def __init__(self, config_data: dict[str, Any]):
def __init__(
self, config_data: dict[str, Any], base_path: Optional[Path] = None
):
"""
:param config_data: dictionary of configuration parameters
:param base_path: base path when a relative path is specified for the prompts directory
"""
# Load the default configuration
default_config = self.load_default_config()

Expand All @@ -79,6 +86,18 @@ def __init__(self, config_data: dict[str, Any]):
and default_config[section]["provider"] != "null"
):
self._validate_config_section(default_config, section, keys)
if (
section == "prompt"
and "file_path" in default_config[section]
and not Path(
default_config[section]["file_path"]
).is_absolute()
and base_path
):
# Make file_path absolute and relative to the base path
default_config[section]["file_path"] = str(
base_path / default_config[section]["file_path"]
)
setattr(self, section, default_config[section])
self.completion = CompletionConfig.create(**self.completion)
# override GenerationConfig defaults
Expand Down Expand Up @@ -122,7 +141,7 @@ def from_toml(cls, config_path: str = None) -> "R2RConfig":
with open(config_path) as f:
config_data = toml.load(f)

return cls(config_data)
return cls(config_data, base_path=Path(config_path).parent)

def to_toml(self):
config_data = {
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/services/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def verify_email(self, email: str, verification_code: str) -> bool:
raise R2RException(
status_code=400, message="Invalid or expired verification code"
)

user = self.providers.database.relational.get_user_by_id(user_id)
if not user or user.email != email:
raise R2RException(
Expand Down
4 changes: 2 additions & 2 deletions py/core/parsers/structured/csv_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import IO, AsyncGenerator, Union
from typing import IO, AsyncGenerator, Optional, Union

from core.base.abstractions.document import DataType
from core.base.parsers.base_parser import AsyncParser
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self):
self.StringIO = StringIO

def get_delimiter(
self, file_path: str | None = None, file: IO[bytes] | None = None
self, file_path: Optional[str] = None, file: Optional[IO[bytes]] = None
):

sniffer = self.csv.Sniffer()
Expand Down
18 changes: 13 additions & 5 deletions py/core/pipes/retrieval/kg_search_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
KGSearchSettings,
PipeType,
PromptProvider,
R2RException,
RunLoggingSingleton,
R2RException
)
from core.base.abstractions.search import (
KGGlobalSearchResult,
Expand All @@ -25,6 +25,7 @@

logger = logging.getLogger(__name__)


class KGSearchSearchPipe(GeneratorPipe):
"""
Embeds and stores documents using a specified embedding model and database.
Expand Down Expand Up @@ -132,11 +133,18 @@ async def local_search(
)
all_search_results.append(search_result)


if len(all_search_results[0])==0:
raise R2RException("No search results found. Please make sure you have run the KG enrichment step before running the search: r2r enrich-graph", 400)
if len(all_search_results[0]) == 0:
raise R2RException(
"No search results found. Please make sure you have run the KG enrichment step before running the search: r2r enrich-graph",
400,
)

yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2])
yield KGLocalSearchResult(
query=message,
entities=all_search_results[0],
relationships=all_search_results[1],
communities=all_search_results[2],
)

async def global_search(
self,
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
) or os.getenv("POSTGRES_VECS_COLLECTION")
if not collection_name:
raise ValueError(
"Error, please set a valid POSTGRES_VECS_COLLECTION environment variable or set a 'collection' in the 'database' settings of your `r2r.toml`."
"Error, please set a valid POSTGRES_VECS_COLLECTION environment variable or set a 'vecs_collection' in the 'database' settings of your `r2r.toml`."
)

if not all([user, password, host, port, db_name, collection_name]):
Expand Down
6 changes: 4 additions & 2 deletions py/core/providers/database/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _create(self):
tsvector_update_trigger(fts, 'pg_catalog.english', text);
"""
)
)
)
return self

def _drop(self):
Expand Down Expand Up @@ -1174,6 +1174,8 @@ def _build_table(name: str, meta: MetaData, dimension: int) -> Table:
table.c.fts,
table.c.text,
postgresql_using="gin",
postgresql_ops={"text": "gin_trgm_ops"}, # alternative, gin_tsvector_ops
postgresql_ops={
"text": "gin_trgm_ops"
}, # alternative, gin_tsvector_ops
)
return table
2 changes: 1 addition & 1 deletion py/core/providers/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _initialize_vector_db(self, dimension: int) -> None:
self.collection = self.vx.get_or_create_collection(
name=self.collection_name, dimension=dimension
)

def upsert(self, entry: VectorEntry) -> None:
if self.collection is None:
raise ValueError(
Expand Down
30 changes: 17 additions & 13 deletions py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class Config:
},
}


class KGLocalSearchResult(BaseModel):
query: str
entities: list[dict[str, Any]]
Expand All @@ -195,15 +196,16 @@ class KGLocalSearchResult(BaseModel):

def __str__(self) -> str:
return f"KGLocalSearchResult(query={self.query}, entities={self.entities}, relationships={self.relationships}, communities={self.communities})"

def dict(self) -> dict:
return {
"query": self.query,
"entities": self.entities,
"relationships": self.relationships,
"communities": self.communities
"communities": self.communities,
}


class KGGlobalSearchResult(BaseModel):
query: str
search_result: list[str]
Expand All @@ -213,12 +215,9 @@ def __str__(self) -> str:

def __repr__(self) -> str:
return self.__str__()

def dict(self) -> dict:
return {
"query": self.query,
"search_result": self.search_result
}
return {"query": self.query, "search_result": self.search_result}


class KGSearchResult(BaseModel):
Expand All @@ -230,11 +229,15 @@ def __str__(self) -> str:

def __repr__(self) -> str:
return self.__str__()

def dict(self) -> dict:
return {
"local_result": self.local_result.dict() if self.local_result else None,
"global_result": self.global_result.dict() if self.global_result else None
"local_result": (
self.local_result.dict() if self.local_result else None
),
"global_result": (
self.global_result.dict() if self.global_result else None
),
}

class Config:
Expand All @@ -245,7 +248,7 @@ class Config:
"entities": {
"Paris": {
"name": "Paris",
"description": "Paris is the capital of France."
"description": "Paris is the capital of France.",
}
},
"relationships": {},
Expand All @@ -255,11 +258,12 @@ class Config:
"query": "What is the capital of France?",
"search_result": [
"Paris is the capital and most populous city of France."
]
}
],
},
}
}


class R2RException(Exception):
def __init__(
self, message: str, status_code: int, detail: Optional[Any] = None
Expand Down

0 comments on commit 8ad54be

Please sign in to comment.