From 7b1d406e6bc2fdab0f50fdfbeeef7e7032c5e434 Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Date: Fri, 23 Aug 2024 11:40:49 -0700 Subject: [PATCH] collection docs (#955) --- py/core/base/abstractions/search.py | 2 +- py/core/base/providers/auth.py | 2 +- py/core/main/api/routes/auth/base.py | 3 ++- py/core/main/api/routes/retrieval/base.py | 2 +- py/core/main/services/auth_service.py | 8 +++++++- py/core/providers/database/vecs/collection.py | 6 +++--- py/core/providers/database/vector.py | 5 +---- 7 files changed, 16 insertions(+), 12 deletions(-) diff --git a/py/core/base/abstractions/search.py b/py/core/base/abstractions/search.py index d8e3bca0d..07e923055 100644 --- a/py/core/base/abstractions/search.py +++ b/py/core/base/abstractions/search.py @@ -122,7 +122,7 @@ def dict(self) -> dict: if self.vector_search_results else [] ), - "kg_search_results": self.kg_search_results or [], + "kg_search_results": self.kg_search_results or None, } diff --git a/py/core/base/providers/auth.py b/py/core/base/providers/auth.py index 17f2a2508..cc7493b9e 100644 --- a/py/core/base/providers/auth.py +++ b/py/core/base/providers/auth.py @@ -82,7 +82,7 @@ def register(self, email: str, password: str) -> Dict[str, str]: pass @abstractmethod - def verify_email(self, verification_code: str) -> Dict[str, str]: + def verify_email(self, email: str, verification_code: str) -> Dict[str, str]: pass @abstractmethod diff --git a/py/core/main/api/routes/auth/base.py b/py/core/main/api/routes/auth/base.py index 9dac245b9..da21ddb01 100644 --- a/py/core/main/api/routes/auth/base.py +++ b/py/core/main/api/routes/auth/base.py @@ -45,6 +45,7 @@ async def register_app( ) @self.base_endpoint async def verify_email_app( + email: EmailStr = Body(..., description="User's email address"), verification_code: str = Body( ..., description="Email verification code" ) @@ -55,7 +56,7 @@ async def verify_email_app( This endpoint is used to confirm a user's email address using the verification code sent to their email after registration. """ - result = await self.engine.averify_email(verification_code) + result = await self.engine.averify_email(email, verification_code) return GenericMessageResponse(message=result["message"]) @self.router.post("/login", response_model=WrappedTokenResponse) diff --git a/py/core/main/api/routes/retrieval/base.py b/py/core/main/api/routes/retrieval/base.py index 87dbb789d..1e2601de9 100644 --- a/py/core/main/api/routes/retrieval/base.py +++ b/py/core/main/api/routes/retrieval/base.py @@ -62,7 +62,7 @@ async def search_app( description=search_descriptions.get("kg_search_settings"), ), auth_user=Depends(self.engine.providers.auth.auth_wrapper), - ) -> WrappedSearchResponse: + ) -> WrappedSearchResponse: """ Perform a search query on the vector database and knowledge graph. diff --git a/py/core/main/services/auth_service.py b/py/core/main/services/auth_service.py index d704ae91d..c18024404 100644 --- a/py/core/main/services/auth_service.py +++ b/py/core/main/services/auth_service.py @@ -35,7 +35,7 @@ async def register(self, email: str, password: str) -> UserResponse: return self.providers.auth.register(email, password) @telemetry_event("VerifyEmail") - async def verify_email(self, verification_code: str) -> bool: + async def verify_email(self, email: str, verification_code: str) -> bool: if not self.config.auth.require_email_verification: raise R2RException( @@ -49,6 +49,12 @@ async def verify_email(self, verification_code: str) -> bool: raise R2RException( status_code=400, message="Invalid or expired verification code" ) + + user = self.providers.database.relational.get_user_by_id(user_id) + if not user or user.email != email: + raise R2RException( + status_code=400, message="Invalid or expired verification code" + ) self.providers.database.relational.mark_user_as_verified(user_id) self.providers.database.relational.remove_verification_code( diff --git a/py/core/providers/database/vecs/collection.py b/py/core/providers/database/vecs/collection.py index a692ff48d..44cf7c290 100644 --- a/py/core/providers/database/vecs/collection.py +++ b/py/core/providers/database/vecs/collection.py @@ -378,7 +378,7 @@ def _create(self): tsvector_update_trigger(fts, 'pg_catalog.english', text); """ ) - ) + ) return self def _drop(self): @@ -1168,12 +1168,12 @@ def _build_table(name: str, meta: MetaData, dimension: int) -> Table: extend_existing=True, ) - # Add GIN index for full-text search and trigram similarity + # # Add GIN index for full-text search and trigram similarity Index( f"idx_{name}_fts_trgm", table.c.fts, table.c.text, postgresql_using="gin", - postgresql_ops={"text": "gin_trgm_ops"}, # Remove gin_tsvector_ops + postgresql_ops={"text": "gin_trgm_ops"}, # alternative, gin_tsvector_ops ) return table diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index 3a140b556..cc1034140 100644 --- a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -131,10 +131,7 @@ def _initialize_vector_db(self, dimension: int) -> None: self.collection = self.vx.get_or_create_collection( name=self.collection_name, dimension=dimension ) - self.collection.create_index(measure="cosine_distance") - self.collection.create_index(measure="l2_distance") - self.collection.create_index(measure="max_inner_product") - + def upsert(self, entry: VectorEntry) -> None: if self.collection is None: raise ValueError(