From f0e83697c05e082f185c885c2a6a0e2edce703ab Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Thu, 22 Aug 2024 15:47:36 -0700 Subject: [PATCH 1/2] fix group logic --- py/cli/commands/management.py | 2 +- py/core/base/abstractions/document.py | 2 +- .../base/api/models/management/responses.py | 1 + py/core/examples/scripts/run_auth_workflow.py | 1 - py/core/main/api/routes/management/base.py | 18 +++-- py/core/main/api/routes/retrieval/base.py | 14 +++- py/core/main/app.py | 12 +-- py/core/main/engine.py | 18 +++++ py/core/main/services/management_service.py | 37 ++++----- py/core/main/services/retrieval_service.py | 1 - py/core/parsers/media/audio_parser.py | 2 +- py/core/parsers/media/img_parser.py | 2 +- py/core/parsers/media/openai_helpers.py | 4 +- py/core/providers/database/document.py | 59 +++++++++++--- py/core/providers/database/group.py | 32 ++++++++ py/core/providers/database/vecs/collection.py | 42 +++++++--- py/core/providers/database/vector.py | 79 ++++++++++++++++--- py/core/providers/parsing/r2r_parsing.py | 1 + py/sdk/client.py | 2 +- py/sdk/management.py | 4 +- 20 files changed, 252 insertions(+), 81 deletions(-) diff --git a/py/cli/commands/management.py b/py/cli/commands/management.py index 52ffc6742..9294e56f3 100644 --- a/py/cli/commands/management.py +++ b/py/cli/commands/management.py @@ -153,7 +153,7 @@ def inspect_knowledge_graph(client, limit): ## TODO: Implement remove_document_from_group -## TODO: Implement get_document_groups +## TODO: Implement document_groups ## TODO: Implement get_documents_in_group diff --git a/py/core/base/abstractions/document.py b/py/core/base/abstractions/document.py index 3b9b936c1..12cf53b42 100644 --- a/py/core/base/abstractions/document.py +++ b/py/core/base/abstractions/document.py @@ -26,10 +26,10 @@ class DocumentStatus(str, Enum): class DocumentType(str, Enum): """Types of documents that can be stored.""" - CSV = "csv" DOCX = "docx" HTML = "html" + HTM = "htm" JSON = "json" MD = "md" PDF = "pdf" diff --git a/py/core/base/api/models/management/responses.py b/py/core/base/api/models/management/responses.py index 64636c6ee..ffc3d5dc2 100644 --- a/py/core/base/api/models/management/responses.py +++ b/py/core/base/api/models/management/responses.py @@ -61,6 +61,7 @@ class DocumentOverviewResponse(BaseModel): updated_at: datetime status: str version: str + group_ids: list[UUID] metadata: Dict[str, Any] diff --git a/py/core/examples/scripts/run_auth_workflow.py b/py/core/examples/scripts/run_auth_workflow.py index ad4837fc6..1729a7679 100644 --- a/py/core/examples/scripts/run_auth_workflow.py +++ b/py/core/examples/scripts/run_auth_workflow.py @@ -1,4 +1,3 @@ -# TODO: need to import this from the package, not from the local directory from r2r import R2RClient if __name__ == "__main__": diff --git a/py/core/main/api/routes/management/base.py b/py/core/main/api/routes/management/base.py index 62b312cd1..05ed0e26f 100644 --- a/py/core/main/api/routes/management/base.py +++ b/py/core/main/api/routes/management/base.py @@ -362,16 +362,17 @@ async def remove_user_from_group_app( user_id: str = Body(..., description="User ID"), group_id: str = Body(..., description="Group ID"), auth_user=Depends(self.engine.providers.auth.auth_wrapper), - ) -> WrappedGroupResponse: + ): if not auth_user.is_superuser: raise R2RException( "Only a superuser can remove users from groups.", 403 ) user_uuid = UUID(user_id) group_uuid = UUID(group_id) - return await self.engine.aremove_user_from_group( + await self.engine.aremove_user_from_group( user_uuid, 
            group_uuid
        )
+        return None  # TODO - Provide response model

        @self.router.get("/get_users_in_group/{group_id}")
@@ -414,7 +415,7 @@ async def assign_document_to_group_app(
            document_id: str = Body(..., description="Document ID"),
            group_id: str = Body(..., description="Group ID"),
            auth_user=Depends(self.engine.providers.auth.auth_wrapper),
-        ) -> WrappedGroupResponse:
+        ):
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can assign documents to groups.", 403
@@ -431,24 +432,25 @@ async def remove_document_from_group_app(
            document_id: str = Body(..., description="Document ID"),
            group_id: str = Body(..., description="Group ID"),
            auth_user=Depends(self.engine.providers.auth.auth_wrapper),
-        ) -> WrappedGroupResponse:
+        ) -> None:
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can remove documents from groups.", 403
                )
            document_uuid = UUID(document_id)
            group_uuid = UUID(group_id)
-            return await self.engine.aremove_document_from_group(
+            await self.engine.aremove_document_from_group(
                document_uuid, group_uuid
            )
+            return None

-        @self.router.get("/get_document_groups/{document_id}")
+        @self.router.get("/document_groups/{document_id}")
        @self.base_endpoint
-        async def get_document_groups_app(
+        async def document_groups_app(
            document_id: str = Path(..., description="Document ID"),
            auth_user=Depends(self.engine.providers.auth.auth_wrapper),
        ) -> WrappedGroupListResponse:
-            return await self.engine.aget_document_groups(document_id)
+            return await self.engine.adocument_groups(document_id)

        @self.router.get("/group/{group_id}/documents")
        @self.base_endpoint
diff --git a/py/core/main/api/routes/retrieval/base.py b/py/core/main/api/routes/retrieval/base.py
index 4c801a4c0..b65c3f776 100644
--- a/py/core/main/api/routes/retrieval/base.py
+++ b/py/core/main/api/routes/retrieval/base.py
""" + print('auth_user = ', auth_user) + print('initial vector_search_settings = ', vector_search_settings) + + print(f"Received query: {query}") + print(f"Received vector_search_settings: {vector_search_settings}") + print(f"Received kg_search_settings: {kg_search_settings}") + print(f"Auth user: {auth_user}") user_groups = set(auth_user.group_ids) selected_groups = set(vector_search_settings.selected_group_ids) @@ -81,10 +88,12 @@ async def search_app( "User does not have access to the specified group(s): " f"{selected_groups - allowed_groups}" ) + print('initial vector_search_settings filters = ', vector_search_settings.filters) filters = { "$or": [ - {"user_id": str(auth_user.id)}, + {"user_id": {"$eq": str(auth_user.id)}}, + # {"group_ids": {"$any": list([str(ele) for ele in allowed_groups])}}, {"group_ids": {"$overlap": list(allowed_groups)}}, ] } @@ -92,7 +101,8 @@ async def search_app( filters = {"$and": [filters, vector_search_settings.filters]} vector_search_settings.filters = filters - + print('final vector_search_settings = ', vector_search_settings) + print('final vector_search_settings filters = ', vector_search_settings.filters) results = await self.engine.asearch( query=query, vector_search_settings=vector_search_settings, diff --git a/py/core/main/app.py b/py/core/main/app.py index 8276c8e7e..54d60f536 100644 --- a/py/core/main/app.py +++ b/py/core/main/app.py @@ -42,13 +42,13 @@ def _setup_routes(self): ) # Include routers in the app - self.app.include_router(ingestion_router, prefix="/v1") - self.app.include_router(management_router, prefix="/v1") - self.app.include_router(retrieval_router, prefix="/v1") - self.app.include_router(auth_router, prefix="/v1") - self.app.include_router(restructure_router, prefix="/v1") + self.app.include_router(ingestion_router, prefix="/v2") + self.app.include_router(management_router, prefix="/v2") + self.app.include_router(retrieval_router, prefix="/v2") + self.app.include_router(auth_router, prefix="/v2") + self.app.include_router(restructure_router, prefix="/v2") - @self.app.router.get("/v1/openapi_spec") + @self.app.router.get("/v2/openapi_spec") async def openapi_spec(): from fastapi.openapi.utils import get_openapi diff --git a/py/core/main/engine.py b/py/core/main/engine.py index 3fbb6f1e0..b5092e6ae 100644 --- a/py/core/main/engine.py +++ b/py/core/main/engine.py @@ -251,6 +251,12 @@ async def aget_groups_for_user(self, *args, **kwargs): *args, **kwargs ) + @syncable + async def aassign_document_to_group(self, *args, **kwargs): + return await self.management_service.aassign_document_to_group( + *args, **kwargs + ) + @syncable async def agroups_overview(self, *args, **kwargs): return await self.management_service.agroups_overview(*args, **kwargs) @@ -260,3 +266,15 @@ async def aget_documents_in_group(self, *args, **kwargs): return await self.management_service.aget_documents_in_group( *args, **kwargs ) + + @syncable + async def adocument_groups(self, *args, **kwargs): + return await self.management_service.adocument_groups( + *args, **kwargs + ) + + @syncable + async def aremove_document_from_group(self, *args, **kwargs): + return await self.management_service.aremove_document_from_group( + *args, **kwargs + ) \ No newline at end of file diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index f572a55e1..70db37000 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -408,32 +408,29 @@ async def aassign_document_to_group( 
self, document_id: str, group_id: UUID ): - if self.providers.database.vector.assign_document_to_group( + self.providers.database.relational.assign_document_to_group( document_id, group_id - ): - return {"message": "Document assigned to group successfully"} - else: - raise R2RException( - status_code=404, - message="Document not found or assignment failed", - ) + ) + self.providers.database.vector.assign_document_to_group( + document_id, group_id + ) + return {"message": "Document assigned to group successfully"} @telemetry_event("RemoveDocumentFromGroup") async def aremove_document_from_group( self, document_id: str, group_id: UUID ): - if self.providers.database.vector.remove_document_from_group( + self.providers.database.relational.remove_document_from_group( document_id, group_id - ): - return {"message": "Document removed from group successfully"} - else: - raise R2RException( - status_code=404, message="Document not found or removal failed" - ) + ) + self.providers.database.vector.remove_document_from_group( + document_id, group_id + ) + return {"message": "Document removed from group successfully"} - @telemetry_event("GetDocumentGroups") - async def aget_document_groups(self, document_id: str): - group_ids = self.providers.database.relational.get_document_groups( + @telemetry_event("DocumentGroups") + async def adocument_groups(self, document_id: str): + group_ids = self.providers.database.relational.document_groups( document_id ) return {"group_ids": [str(group_id) for group_id in group_ids]} @@ -615,3 +612,7 @@ async def aget_documents_in_group( return self.providers.database.relational.get_documents_in_group( group_id, offset, limit ) + + @telemetry_event("DocumentGroups") + async def adocument_groups(self, document_id: str) -> list[str]: + return self.providers.database.relational.document_groups(document_id) \ No newline at end of file diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index d9e72bc8a..6e570041f 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -54,7 +54,6 @@ async def search( query: str, vector_search_settings: VectorSearchSettings = VectorSearchSettings(), kg_search_settings: KGSearchSettings = KGSearchSettings(), - user: Optional[UserResponse] = None, *args, **kwargs, ) -> SearchResponse: diff --git a/py/core/parsers/media/audio_parser.py b/py/core/parsers/media/audio_parser.py index 000857b99..81e0ef4f3 100644 --- a/py/core/parsers/media/audio_parser.py +++ b/py/core/parsers/media/audio_parser.py @@ -9,7 +9,7 @@ class AudioParser(AsyncParser[bytes]): """A parser for audio data.""" def __init__( - self, api_base: str = "https://api.openai.com/v1/audio/transcriptions" + self, api_base: str = "https://api.openai.com/v2/audio/transcriptions" ): self.api_base = api_base self.openai_api_key = os.environ.get("OPENAI_API_KEY") diff --git a/py/core/parsers/media/img_parser.py b/py/core/parsers/media/img_parser.py index 67cc80e8e..c679b2772 100644 --- a/py/core/parsers/media/img_parser.py +++ b/py/core/parsers/media/img_parser.py @@ -13,7 +13,7 @@ def __init__( self, model: str = "gpt-4o", max_tokens: int = 2_048, - api_base: str = "https://api.openai.com/v1/chat/completions", + api_base: str = "https://api.openai.com/v2/chat/completions", ): self.model = model self.max_tokens = max_tokens diff --git a/py/core/parsers/media/openai_helpers.py b/py/core/parsers/media/openai_helpers.py index 583a685e9..910261cca 100644 --- a/py/core/parsers/media/openai_helpers.py 
+++ b/py/core/parsers/media/openai_helpers.py @@ -8,7 +8,7 @@ def process_frame_with_openai( api_key: str, model: str = "gpt-4o", max_tokens: int = 2_048, - api_base: str = "https://api.openai.com/v1/chat/completions", + api_base: str = "https://api.openai.com/v2/chat/completions", ) -> str: headers = { "Content-Type": "application/json", @@ -43,7 +43,7 @@ def process_frame_with_openai( def process_audio_with_openai( audio_file, api_key: str, - audio_api_base: str = "https://api.openai.com/v1/audio/transcriptions", + audio_api_base: str = "https://api.openai.com/v2/audio/transcriptions", ) -> str: headers = {"Authorization": f"Bearer {api_key}"} diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py index 0af205ec3..5823d3d77 100644 --- a/py/core/providers/database/document.py +++ b/py/core/providers/database/document.py @@ -1,7 +1,8 @@ from typing import Optional from uuid import UUID -from core.base import DocumentInfo, DocumentStatus, DocumentType +from core.base import DocumentInfo, DocumentStatus, DocumentType, R2RException +from core.base.api.models.management.responses import GroupResponse from .base import DatabaseMixin @@ -109,16 +110,54 @@ def get_documents_overview( ) for row in results ] - - def get_document_groups(self, document_id: str) -> list[str]: + + def document_groups(self, document_id: UUID) -> list[GroupResponse]: query = f""" - SELECT group_ids - FROM {self._get_table_name('document_info')} - WHERE document_id = :document_id + SELECT g.group_id, g.name, g.description, g.created_at, g.updated_at + FROM {self._get_table_name('groups')} g + JOIN {self._get_table_name('document_info')} d ON g.group_id = ANY(d.group_ids) + WHERE d.document_id = :document_id """ params = {"document_id": document_id} - result = self.execute_query(query, params).fetchone() + results = self.execute_query(query, params).fetchall() - if result and result[0]: - return [str(group_id) for group_id in result[0]] - return [] + return [ + GroupResponse( + group_id=row[0], + name=row[1], + description=row[2], + created_at=row[3], + updated_at=row[4], + ) + for row in results + ] + + def remove_document_from_group(self, document_id: UUID, group_id: UUID) -> None: + """ + Remove a document from a group. + + Args: + document_id (UUID): The ID of the document to remove. + group_id (UUID): The ID of the group to remove the document from. + + Raises: + R2RException: If the group doesn't exist or if the document is not in the group. + """ + if not self.group_exists(group_id): + raise R2RException(status_code=404, message="Group not found") + + query = f""" + UPDATE {self._get_table_name('document_info')} + SET group_ids = array_remove(group_ids, :group_id) + WHERE document_id = :document_id AND :group_id = ANY(group_ids) + RETURNING document_id + """ + result = self.execute_query( + query, {"document_id": document_id, "group_id": group_id} + ).fetchone() + + if not result: + raise R2RException( + status_code=404, + message="Document not found in the specified group" + ) \ No newline at end of file diff --git a/py/core/providers/database/group.py b/py/core/providers/database/group.py index ab3cb811e..2c1105e68 100644 --- a/py/core/providers/database/group.py +++ b/py/core/providers/database/group.py @@ -376,3 +376,35 @@ def get_groups_for_user( ) for row in results ] + + def assign_document_to_group(self, document_id: UUID, group_id: UUID) -> None: + """ + Assign a document to a group. + + Args: + document_id (UUID): The ID of the document to assign. 
+ group_id (UUID): The ID of the group to assign the document to. + + Raises: + R2RException: If the group doesn't exist or if the document is not found. + """ + if not self.group_exists(group_id): + raise R2RException(status_code=404, message="Group not found") + + query = f""" + UPDATE {self._get_table_name('document_info')} + SET group_ids = array_append(group_ids, :group_id) + WHERE document_id = :document_id AND NOT (:group_id = ANY(group_ids)) + RETURNING document_id + """ + result = self.execute_query( + query, {"document_id": document_id, "group_id": group_id} + ).fetchone() + + + if not result: + raise R2RException( + status_code=404, + message="Document not found or already assigned to the group" + ) + diff --git a/py/core/providers/database/vecs/collection.py b/py/core/providers/database/vecs/collection.py index ccae50c00..d9f3485a2 100644 --- a/py/core/providers/database/vecs/collection.py +++ b/py/core/providers/database/vecs/collection.py @@ -703,26 +703,31 @@ def expand_query(query: str) -> str: combined_rank, ) .where( - sa.or_( - self.table.c.fts.op("@@")(ts_query), - sa.func.similarity(self.table.c.text, query_text) > 0.1, - self.table.c.fts.op("@@")( - sa.func.phraseto_tsquery("english", query_text) - ), - self.table.c.fts.op("@@")( - sa.func.to_tsquery( - "english", - " & ".join( - f"{word}:*" for word in query_text.split() - ), - ) + sa.and_( + sa.or_( + self.table.c.fts.op("@@")(ts_query), + sa.func.similarity(self.table.c.text, query_text) > 0.1, + self.table.c.fts.op("@@")( + sa.func.phraseto_tsquery("english", query_text) + ), + self.table.c.fts.op("@@")( + sa.func.to_tsquery( + "english", + " & ".join( + f"{word}:*" for word in query_text.split() + ), + ) + ), ), + self.build_filters(search_settings.filters) ) ) .order_by(sa.desc("rank")) .limit(search_settings.hybrid_search_settings.full_text_limit) ) + + with self.client.Session() as sess: results = sess.execute(stmt).fetchall() @@ -776,6 +781,14 @@ def parse_condition(key, value): return ~column.in_(clause) elif op == "$overlap": return column.overlap(clause) + elif op == "$contains": + return column.contains(clause) + elif op == "$any": + if key == "group_ids": + # Use ANY for UUID array comparison + return func.array_to_string(column, ',').like(f"%{clause}%") + # New operator for checking if any element in the array matches + return column.any(clause) else: raise FilterError( f"Unsupported operator for column {key}: {op}" @@ -785,6 +798,9 @@ def parse_condition(key, value): else: # Handle JSON-based filters json_col = self.table.c.metadata + if not key.startswith("metadata."): + raise FilterError("metadata key must start with 'metadata.'") + key = key.split('metadata.')[1] if isinstance(value, dict): if len(value) > 1: raise FilterError("only one operator permitted") diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index a4aadcba3..03d347fb1 100644 --- a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -176,19 +176,6 @@ def upsert_entries(self, entries: list[VectorEntry]) -> None: ] ) - def get_document_groups(self, document_id: str) -> list[str]: - query = text( - f""" - SELECT group_ids - FROM document_info_{self.collection_name} - WHERE document_id = :document_id - """ - ) - with self.vx.Session() as sess: - result = sess.execute(query, {"document_id": document_id}) - group_ids = result.scalar() - return [str(group_id) for group_id in (group_ids or [])] - def semantic_search( self, query_vector: list[float], search_settings: 
VectorSearchSettings ) -> list[VectorSearchResult]: @@ -337,6 +324,72 @@ def delete( return self.collection.delete(filters=filters) + def assign_document_to_group(self, document_id: str, group_id: str) -> None: + """ + Assign a document to a group in the vector database. + + Args: + document_id (str): The ID of the document to assign. + group_id (str): The ID of the group to assign the document to. + + Raises: + ValueError: If the collection is not initialized. + """ + if self.collection is None: + raise ValueError( + "Please call `initialize_collection` before attempting to run `assign_document_to_group`." + ) + + table_name = self.collection.table.name + query = text( + f""" + UPDATE vecs."{table_name}" + SET group_ids = array_append(group_ids, :group_id) + WHERE document_id = :document_id AND NOT (:group_id = ANY(group_ids)) + RETURNING document_id + """ + ) + + with self.vx.Session() as sess: + result = sess.execute(query, {"document_id": document_id, "group_id": group_id}) + sess.commit() + + if result.rowcount == 0: + logger.warning(f"Document {document_id} not found or already assigned to group {group_id}") + + def remove_document_from_group(self, document_id: str, group_id: str) -> None: + """ + Remove a document from a group in the vector database. + + Args: + document_id (str): The ID of the document to remove. + group_id (str): The ID of the group to remove the document from. + + Raises: + ValueError: If the collection is not initialized. + """ + if self.collection is None: + raise ValueError( + "Please call `initialize_collection` before attempting to run `remove_document_from_group`." + ) + + table_name = self.collection.table.name + query = text( + f""" + UPDATE vecs."{table_name}" + SET group_ids = array_remove(group_ids, :group_id) + WHERE document_id = :document_id AND :group_id = ANY(group_ids) + RETURNING document_id + """ + ) + + with self.vx.Session() as sess: + result = sess.execute(query, {"document_id": document_id, "group_id": group_id}) + sess.commit() + + if result.rowcount == 0: + logger.warning(f"Document {document_id} not found in group {group_id} or already removed") + def get_document_chunks(self, document_id: str) -> list[dict]: if not self.collection: raise ValueError("Collection is not initialized.") diff --git a/py/core/providers/parsing/r2r_parsing.py b/py/core/providers/parsing/r2r_parsing.py index 55cbeb37c..0ad60a085 100644 --- a/py/core/providers/parsing/r2r_parsing.py +++ b/py/core/providers/parsing/r2r_parsing.py @@ -21,6 +21,7 @@ class R2RParsingProvider(ParsingProvider): DocumentType.CSV: [parsers.CSVParser, parsers.CSVParserAdvanced], DocumentType.DOCX: [parsers.DOCXParser], DocumentType.HTML: [parsers.HTMLParser], + DocumentType.HTM: [parsers.HTMLParser], DocumentType.JSON: [parsers.JSONParser], DocumentType.MD: [parsers.MDParser], DocumentType.PDF: [parsers.PDFParser, parsers.PDFParserUnstructured], diff --git a/py/sdk/client.py b/py/sdk/client.py index 8de4ac637..8c8d38d3b 100644 --- a/py/sdk/client.py +++ b/py/sdk/client.py @@ -76,7 +76,7 @@ class R2RAsyncClient: def __init__( self, base_url: str = "http://localhost:8000", - prefix: str = "/v1", + prefix: str = "/v2", custom_client=None, timeout: float = 300.0, ): diff --git a/py/sdk/management.py b/py/sdk/management.py index 3abff1b00..74830e7c7 100644 --- a/py/sdk/management.py +++ b/py/sdk/management.py @@ -471,7 +471,7 @@ async def remove_document_from_group( ) @staticmethod - async def get_document_groups( + async def document_groups( client, document_id: str, ) -> dict: @@ -485,7 
+485,7 @@ async def get_document_groups( dict: The list of groups that the document is assigned to. """ return await client._make_request( - "GET", f"get_document_groups/{document_id}" + "GET", f"document_groups/{document_id}" ) @staticmethod From 29f30ec5807701578b593402cd6ccdab81dc6c0a Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Thu, 22 Aug 2024 17:39:25 -0700 Subject: [PATCH 2/2] up --- py/cli/commands/management.py | 95 +++++----- .../examples/scripts/advanced_kg_cookbook.py | 2 +- .../examples/scripts/run_group_workflow.py | 164 ++++++++++++++++++ py/core/main/api/routes/auth/base.py | 18 +- py/core/main/api/routes/management/base.py | 41 +++-- py/core/main/engine.py | 4 +- py/core/main/services/auth_service.py | 4 + py/core/main/services/management_service.py | 50 ++++-- py/core/providers/database/document.py | 68 ++------ py/core/providers/database/group.py | 124 +++++++++++-- py/core/providers/database/user.py | 48 ++++- py/core/providers/database/vector.py | 103 +++++++++-- py/sdk/auth.py | 20 +-- py/sdk/management.py | 108 ++++++++++-- py/tests/test_client.py | 6 +- py/tests/test_groups.py | 12 +- 16 files changed, 654 insertions(+), 213 deletions(-) create mode 100644 py/core/examples/scripts/run_group_workflow.py diff --git a/py/cli/commands/management.py b/py/cli/commands/management.py index 9294e56f3..f284f3a27 100644 --- a/py/cli/commands/management.py +++ b/py/cli/commands/management.py @@ -36,13 +36,23 @@ def app_settings(client): @cli.command() @click.option("--user-ids", multiple=True, help="User IDs to overview") +@click.option( + "--offset", + default=None, + help="The offset to start from. Defaults to 0.", +) +@click.option( + "--limit", + default=None, + help="The maximum number of nodes to return. Defaults to 100.", +) @click.pass_obj -def users_overview(client, user_ids): +def users_overview(client, user_ids, offset, limit): """Get an overview of users.""" user_ids = list(user_ids) if user_ids else None with timer(): - response = client.users_overview(user_ids) + response = client.users_overview(user_ids, offset, limit) for user in response: click.echo(user) @@ -73,13 +83,23 @@ def delete(client, filter): @cli.command() @click.option("--document-ids", multiple=True, help="Document IDs to overview") +@click.option( + "--offset", + default=None, + help="The offset to start from. Defaults to 0.", +) +@click.option( + "--limit", + default=None, + help="The maximum number of nodes to return. Defaults to 100.", +) @click.pass_obj -def documents_overview(client, document_ids): +def documents_overview(client, document_ids, offset, limit): """Get an overview of documents.""" document_ids = list(document_ids) if document_ids else None with timer(): - response = client.documents_overview(document_ids) + response = client.documents_overview(document_ids, offset, limit) for document in response["results"]: click.echo(document) @@ -87,11 +107,21 @@ def documents_overview(client, document_ids): @cli.command() @click.option("--document-id", help="Document ID to retrieve chunks for") +@click.option( + "--offset", + default=None, + help="The offset to start from. Defaults to 0.", +) +@click.option( + "--limit", + default=None, + help="The maximum number of nodes to return. 
Defaults to 100.", +) @click.pass_obj -def document_chunks(client, document_id): +def document_chunks(client, document_id, offset, limit): """Get chunks of a specific document.""" with timer(): - response = client.document_chunks(document_id) + response = client.document_chunks(document_id, offset, limit) chunks = response.get("results", []) click.echo(f"\nNumber of chunks: {len(chunks)}") @@ -103,57 +133,20 @@ def document_chunks(client, document_id): @cli.command() +@click.option( + "--offset", + default=None, + help="The offset to start from. Defaults to 0.", +) @click.option( "--limit", default=None, help="The maximum number of nodes to return. Defaults to 100.", ) @click.pass_obj -def inspect_knowledge_graph(client, limit): +def inspect_knowledge_graph(client, offset, limit): """Inspect the knowledge graph.""" with timer(): - response = client.inspect_knowledge_graph(limit) - - click.echo(response) - - -## TODO: Implement groups_overview - - -## TODO: Implement create_group - - -## TODO: Implement get_group - - -## TODO: Implement update_group - - -## TODO: Implement delete_group - - -## TODO: Implement list_groups - - -## TODO: Implement add_user_to_group - - -## TODO: Implement remove_user_from_group - - -## TODO: Implement get_users_in_group - - -## TODO: Implement get_groups_for_user - - -## TODO: Implement assign_document_to_group - - -## TODO: Implement remove_document_from_group - - -## TODO: Implement document_groups - + response = client.inspect_knowledge_graph(offset, limit) -## TODO: Implement get_documents_in_group + click.echo(response) \ No newline at end of file diff --git a/py/core/examples/scripts/advanced_kg_cookbook.py b/py/core/examples/scripts/advanced_kg_cookbook.py index 8eaa2ea73..f69d0a486 100644 --- a/py/core/examples/scripts/advanced_kg_cookbook.py +++ b/py/core/examples/scripts/advanced_kg_cookbook.py @@ -155,7 +155,7 @@ def main( print("Inspecting Knowledge Graph") print( - client.inspect_knowledge_graph(1000, print_descriptions=True)[ + client.inspect_knowledge_graph(0, 1000, print_descriptions=True)[ "results" ] ) diff --git a/py/core/examples/scripts/run_group_workflow.py b/py/core/examples/scripts/run_group_workflow.py new file mode 100644 index 000000000..38353ef20 --- /dev/null +++ b/py/core/examples/scripts/run_group_workflow.py @@ -0,0 +1,164 @@ +import os +from r2r import R2RClient + + +if __name__ == "__main__": + # Initialize the R2R client + client = R2RClient("http://localhost:8000") # Replace with your R2R deployment URL + + # Admin login + print("Logging in as admin...") + login_result = client.login("admin@example.com", "change_me_immediately") + print("Admin login result:", login_result) + + # Create two groups + print("\nCreating two groups...") + group1_result = client.create_group("TestGroup1", "A test group for document access") + group2_result = client.create_group("TestGroup2", "Another test group") + print("Group1 creation result:", group1_result) + print("Group2 creation result:", group2_result) + group1_id = group1_result['results']['group_id'] + group2_id = group2_result['results']['group_id'] + + # Get groups overview + print("\nGetting groups overview...") + groups_overview = client.groups_overview() + print("Groups overview:", groups_overview) + + # Get specific group + print("\nGetting specific group...") + group1_details = client.get_group(group1_id) + print("Group1 details:", group1_details) + + # List all groups + print("\nListing all groups...") + groups_list = client.list_groups() + print("Groups list:", groups_list) + + 
# Update a group + print("\nUpdating Group1...") + update_result = client.update_group(group1_id, name="UpdatedTestGroup1", description="Updated description") + print("Group update result:", update_result) + + # Ingest two documents + print("\nIngesting two documents...") + script_path = os.path.dirname(__file__) + sample_file1 = os.path.join(script_path, "core", "examples", "data", "aristotle_v2.txt") + sample_file2 = os.path.join(script_path, "core", "examples", "data", "aristotle.txt") + ingestion_result1 = client.ingest_files([sample_file1]) + ingestion_result2 = client.ingest_files([sample_file2]) + print("Document1 ingestion result:", ingestion_result1) + print("Document2 ingestion result:", ingestion_result2) + document1_id = ingestion_result1['results']['processed_documents'][0]['id'] + document2_id = ingestion_result2['results']['processed_documents'][0]['id'] + + # Assign documents to groups + print("\nAssigning documents to groups...") + assign_result1 = client.assign_document_to_group(document1_id, group1_id) + assign_result2 = client.assign_document_to_group(document2_id, group2_id) + print("Document1 assignment result:", assign_result1) + print("Document2 assignment result:", assign_result2) + + # document1_id = "c3291abf-8a4e-5d9d-80fd-232ef6fd8526" + # Get document groups + print("\nGetting groups for Document1...") + doc1_groups = client.document_groups(document1_id) + print("Document1 groups:", doc1_groups) + + # Create three test users + print("\nCreating three test users...") + user1_result = client.register("user1@test.com", "password123") + user2_result = client.register("user2@test.com", "password123") + user3_result = client.register("user3@test.com", "password123") + print("User1 creation result:", user1_result) + print("User2 creation result:", user2_result) + print("User3 creation result:", user3_result) + + # Add users to groups + print("\nAdding users to groups...") + add_user1_result = client.add_user_to_group(user1_result['results']['id'], group1_id) + add_user2_result = client.add_user_to_group(user2_result['results']['id'], group2_id) + add_user3_result1 = client.add_user_to_group(user3_result['results']['id'], group1_id) + add_user3_result2 = client.add_user_to_group(user3_result['results']['id'], group2_id) + print("Add user1 to group1 result:", add_user1_result) + print("Add user2 to group2 result:", add_user2_result) + print("Add user3 to group1 result:", add_user3_result1) + print("Add user3 to group2 result:", add_user3_result2) + + # Get users in a group + print("\nGetting users in Group1...") + users_in_group1 = client.user_groups(group1_id) + print("Users in Group1:", users_in_group1) + + # Get groups for a user + print("\nGetting groups for User3...") + user3_groups = client.user_groups(user3_result['results']['id']) + print("User3 groups:", user3_groups) + + # Get documents in a group + print("\nGetting documents in Group1...") + docs_in_group1 = client.documents_in_group(group1_id) + print("Documents in Group1:", docs_in_group1) + + # Remove user from group + print("\nRemoving User3 from Group1...") + remove_user_result = client.remove_user_from_group(user3_result['results']['id'], group1_id) + print("Remove user result:", remove_user_result) + + # Remove document from group + print("\nRemoving Document1 from Group1...") + remove_doc_result = client.remove_document_from_group(document1_id, group1_id) + print("Remove document result:", remove_doc_result) + + # Logout admin + print("\nLogging out admin...") + client.logout() + + # Login as user1 + 
print("\nLogging in as user1...") + client.login("user1@test.com", "password123") + + # Search for documents (should see document1 but not document2) + print("\nUser1 searching for documents...") + search_result_user1 = client.search("philosophy", {"selected_group_ids": [group1_id]}) + print("User1 search result:", search_result_user1) + + # Logout user1 + print("\nLogging out user1...") + client.logout() + + # Login as user3 + print("\nLogging in as user3...") + client.login("user3@test.com", "password123") + + # Search for documents (should see only document2 after removal from Group1) + print("\nUser3 searching for documents...") + try: + search_result_user3 = client.search("philosophy", {"selected_group_ids": [group1_id, group2_id]}) + except Exception as e: + print("User3 search result error:", e) + search_result_user3 = client.search("philosophy", {"selected_group_ids": [group2_id]}) + + print("User3 search result:", search_result_user3) + + # Logout user3 + print("\nLogging out user3...") + client.logout() + + # Clean up + print("\nCleaning up...") + # Login as admin again + client.login("admin@example.com", "change_me_immediately") + + # Delete the groups + print("Deleting the groups...") + client.delete_group(group1_id) + client.delete_group(group2_id) + + # Logout admin + print("\nLogging out admin...") + client.logout() + + print("\nWorkflow completed.") + + \ No newline at end of file diff --git a/py/core/main/api/routes/auth/base.py b/py/core/main/api/routes/auth/base.py index 49871643c..94a85dab2 100644 --- a/py/core/main/api/routes/auth/base.py +++ b/py/core/main/api/routes/auth/base.py @@ -1,5 +1,5 @@ import uuid -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from core.base.api.models.auth.responses import ( GenericMessageResponse, @@ -7,7 +7,7 @@ WrappedTokenResponse, WrappedUserResponse, ) -from fastapi import Body, Depends +from fastapi import Path, Body, Depends from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from pydantic import EmailStr @@ -200,14 +200,18 @@ async def reset_password_app( return GenericMessageResponse(message=result["message"]) @self.router.delete( - "/user", response_model=WrappedGenericMessageResponse + "/user/{user_id}", response_model=WrappedGenericMessageResponse ) @self.base_endpoint async def delete_user_app( - user_id: str = Body(..., description="ID of the user to delete"), - password: str | None = Body( + user_id: str = Path(..., description="ID of the user to delete"), + password: Optional[str] = Body( None, description="User's current password" ), + delete_vector_data: Optional[bool] = Body( + False, + description="Whether to delete the user's vector data", + ), auth_user=Depends(self.engine.providers.auth.auth_wrapper), ): """ @@ -218,6 +222,8 @@ async def delete_user_app( """ if auth_user.id != user_id and not auth_user.is_superuser: raise Exception("User ID does not match authenticated user") + if not auth_user.is_superuser and not password: + raise Exception("Password is required for non-superusers") user_uuid = uuid.UUID(user_id) - result = await self.engine.adelete_user(user_uuid, password) + result = await self.engine.adelete_user(user_uuid, password, delete_vector_data) return GenericMessageResponse(message=result["message"]) diff --git a/py/core/main/api/routes/management/base.py b/py/core/main/api/routes/management/base.py index 05ed0e26f..05623d7b9 100644 --- a/py/core/main/api/routes/management/base.py +++ b/py/core/main/api/routes/management/base.py @@ -162,6 +162,8 @@ 
async def score_completion( @self.base_endpoint async def users_overview_app( user_ids: Optional[list[str]] = Query([]), + offset: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), auth_user=Depends(self.engine.providers.auth.auth_wrapper), ) -> WrappedUserOverviewResponse: if not auth_user.is_superuser: @@ -174,7 +176,7 @@ async def users_overview_app( [UUID(user_id) for user_id in user_ids] if user_ids else None ) - return await self.engine.ausers_overview(user_ids=user_uuids) + return await self.engine.ausers_overview(user_ids=user_uuids, offset=offset, limit=limit) @self.router.delete("/delete", status_code=204) @self.base_endpoint @@ -189,6 +191,8 @@ async def delete_app( @self.base_endpoint async def documents_overview_app( document_ids: list[str] = Query([]), + offset: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), auth_user=Depends(self.engine.providers.auth.auth_wrapper), ) -> WrappedDocumentOverviewResponse: request_user_ids = ( @@ -201,6 +205,8 @@ async def documents_overview_app( user_ids=request_user_ids, group_ids=auth_user.group_ids, document_ids=document_uuids, + offset=offset, + limit=limit, ) return result @@ -208,10 +214,12 @@ async def documents_overview_app( @self.base_endpoint async def document_chunks_app( document_id: str = Path(...), + offset: Optional[int] = Query(0, ge=0), + limit: Optional[int] = Query(100, ge=0), auth_user=Depends(self.engine.providers.auth.auth_wrapper), ) -> WrappedDocumentChunkResponse: document_uuid = UUID(document_id) - chunks = await self.engine.adocument_chunks(document_uuid) + chunks = await self.engine.adocument_chunks(document_uuid, offset, limit) if not chunks: raise R2RException( @@ -232,6 +240,7 @@ async def document_chunks_app( @self.router.get("/inspect_knowledge_graph") @self.base_endpoint async def inspect_knowledge_graph( + offset: int = 0, limit: int = 100, print_descriptions: bool = False, auth_user=Depends(self.engine.providers.auth.auth_wrapper), @@ -242,15 +251,15 @@ async def inspect_knowledge_graph( 403, ) return await self.engine.ainspect_knowledge_graph( - limit=limit, print_descriptions=print_descriptions + offset=offset, limit=limit, print_descriptions=print_descriptions ) @self.router.get("/groups_overview") @self.base_endpoint async def groups_overview_app( group_ids: Optional[list[str]] = Query(None), - limit: Optional[int] = Query(100, ge=1, le=1000), offset: Optional[int] = Query(0, ge=0), + limit: Optional[int] = Query(100, ge=1, le=1000), auth_user=Depends(self.engine.providers.auth.auth_wrapper), ) -> WrappedGroupOverviewResponse: if not auth_user.is_superuser: @@ -265,7 +274,7 @@ async def groups_overview_app( else None ) return await self.engine.agroups_overview( - group_ids=group_uuids, limit=limit, offset=offset + group_ids=group_uuids, offset=offset, limit=limit ) @self.router.post("/create_group") @@ -317,7 +326,7 @@ async def update_group_app( async def delete_group_app( group_id: str = Path(..., description="Group ID"), auth_user=Depends(self.engine.providers.auth.auth_wrapper), - ) -> WrappedGroupResponse: + ): if not auth_user.is_superuser: raise R2RException("Only a superuser can delete groups.", 403) group_uuid = UUID(group_id) @@ -396,10 +405,14 @@ async def get_users_in_group_app( limit=min(max(limit, 1), 1000), ) - @self.router.get("/get_groups_for_user/{user_id}") + @self.router.get("/user_groups/{user_id}") @self.base_endpoint async def get_groups_for_user_app( user_id: str = Path(..., description="User ID"), + offset: int = Query(0, ge=0, 
description="Pagination offset"), + limit: int = Query( + 100, ge=1, le=1000, description="Pagination limit" + ), auth_user=Depends(self.engine.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: @@ -407,7 +420,7 @@ async def get_groups_for_user_app( "Only a superuser can get groups for a user.", 403 ) user_uuid = UUID(user_id) - return await self.engine.aget_groups_for_user(user_uuid) + return await self.engine.aget_groups_for_user(user_uuid, offset, limit) @self.router.post("/assign_document_to_group") @self.base_endpoint @@ -448,13 +461,19 @@ async def remove_document_from_group_app( @self.base_endpoint async def document_groups_app( document_id: str = Path(..., description="Document ID"), + offset: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), auth_user=Depends(self.engine.providers.auth.auth_wrapper), ) -> WrappedGroupListResponse: - return await self.engine.adocument_groups(document_id) + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can get the groups belonging to a document.", 403 + ) + return await self.engine.adocument_groups(document_id, offset, limit) @self.router.get("/group/{group_id}/documents") @self.base_endpoint - async def get_documents_in_group_app( + async def documents_in_group_app( group_id: str = Path(..., description="Group ID"), offset: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), @@ -465,6 +484,6 @@ async def get_documents_in_group_app( "Only a superuser can get documents in a group.", 403 ) group_uuid = UUID(group_id) - return await self.engine.aget_documents_in_group( + return await self.engine.adocuments_in_group( group_uuid, offset, limit ) diff --git a/py/core/main/engine.py b/py/core/main/engine.py index b5092e6ae..9a759a02b 100644 --- a/py/core/main/engine.py +++ b/py/core/main/engine.py @@ -262,8 +262,8 @@ async def agroups_overview(self, *args, **kwargs): return await self.management_service.agroups_overview(*args, **kwargs) @syncable - async def aget_documents_in_group(self, *args, **kwargs): - return await self.management_service.aget_documents_in_group( + async def adocuments_in_group(self, *args, **kwargs): + return await self.management_service.adocuments_in_group( *args, **kwargs ) diff --git a/py/core/main/services/auth_service.py b/py/core/main/services/auth_service.py index 5fd0fb9d7..84b318729 100644 --- a/py/core/main/services/auth_service.py +++ b/py/core/main/services/auth_service.py @@ -133,6 +133,7 @@ async def delete_user( self, user_id: UUID, password: Optional[str] = None, + delete_vector_data: bool = False, is_superuser: bool = False, ) -> dict[str, str]: user = self.providers.database.relational.get_user_by_id(user_id) @@ -146,6 +147,9 @@ async def delete_user( ): raise R2RException(status_code=400, message="Incorrect password") self.providers.database.relational.delete_user(user_id) + if (delete_vector_data): + self.providers.database.vector.delete_user(user_id) + return {"message": f"User account {user_id} deleted successfully."} @telemetry_event("CleanExpiredBlacklistedTokens") diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index 70db37000..8e14c5cf1 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -240,11 +240,13 @@ async def ascore_completion( async def ausers_overview( self, user_ids: Optional[list[UUID]] = None, + offset: int = 0, + limit: int = 100, *args, **kwargs, ): return self.providers.database.relational.get_users_overview( - 
[str(ele) for ele in user_ids] if user_ids else None + [str(ele) for ele in user_ids] if user_ids else None, offset=offset, limit=limit ) @telemetry_event("Delete") @@ -290,6 +292,8 @@ async def adocuments_overview( user_ids: Optional[list[UUID]] = None, group_ids: Optional[list[UUID]] = None, document_ids: Optional[list[UUID]] = None, + offset: Optional[int] = 0, + limit: Optional[int] = 100, *args: Any, **kwargs: Any, ): @@ -297,32 +301,39 @@ async def adocuments_overview( filter_document_ids=document_ids, filter_user_ids=user_ids, filter_group_ids=group_ids, + offset=offset, + limit=limit, ) @telemetry_event("DocumentChunks") async def document_chunks( self, document_id: UUID, + offset: int = 0, + limit: int = 100, *args, **kwargs, ): - return self.providers.database.vector.get_document_chunks(document_id) + return self.providers.database.vector.get_document_chunks(document_id, offset=offset, limit=limit) @telemetry_event("UsersOverview") async def users_overview( self, user_ids: Optional[list[UUID]], + offset: int = 0, + limit: int = 100, *args, **kwargs, ): return self.providers.database.relational.get_users_overview( - [str(ele) for ele in user_ids] + [str(ele) for ele in user_ids], offset=offset, limit=limit ) @telemetry_event("InspectKnowledgeGraph") async def inspect_knowledge_graph( self, - limit=10000, + offset: int = 0, + limit=1000, print_descriptions: bool = False, *args: Any, **kwargs: Any, @@ -335,6 +346,7 @@ async def inspect_knowledge_graph( rel_query = f""" MATCH (n1)-[r]->(n2) return n1.name AS subject, n1.description AS subject_description, n2.name AS object, n2.description AS object_description, type(r) AS relation, r.description AS relation_description + SKIP {offset} LIMIT {limit} """ @@ -429,9 +441,9 @@ async def aremove_document_from_group( return {"message": "Document removed from group successfully"} @telemetry_event("DocumentGroups") - async def adocument_groups(self, document_id: str): + async def adocument_groups(self, document_id: str, offset: int = 0, limit: int = 100): group_ids = self.providers.database.relational.document_groups( - document_id + document_id, offset=offset, limit=limit ) return {"group_ids": [str(group_id) for group_id in group_ids]} @@ -554,8 +566,10 @@ async def aupdate_group( @telemetry_event("DeleteGroup") async def adelete_group(self, group_id: UUID) -> bool: - return self.providers.database.relational.delete_group(group_id) - + self.providers.database.relational.delete_group(group_id) + self.providers.database.vector.delete_group(group_id) + return True + @telemetry_event("ListGroups") async def alist_groups( self, offset: int = 0, limit: int = 100 @@ -583,12 +597,12 @@ async def aget_users_in_group( self, group_id: UUID, offset: int = 0, limit: int = 100 ) -> list[dict]: return self.providers.database.relational.get_users_in_group( - group_id, offset, limit + group_id, offset=offset, limit=limit ) @telemetry_event("GetGroupsForUser") - async def aget_groups_for_user(self, user_id: UUID) -> list[dict]: - return self.providers.database.relational.get_groups_for_user(user_id) + async def aget_groups_for_user(self, user_id: UUID, offset: int = 0, limit: int = 100) -> list[dict]: + return self.providers.database.relational.get_groups_for_user(user_id, offset, limit) @telemetry_event("GroupsOverview") async def agroups_overview( @@ -601,18 +615,18 @@ async def agroups_overview( ): return self.providers.database.relational.get_groups_overview( [str(ele) for ele in group_ids] if group_ids else None, - offset, - limit, + offset=offset, + 
            limit=limit,
        )

    @telemetry_event("GetDocumentsInGroup")
-    async def aget_documents_in_group(
+    async def adocuments_in_group(
        self, group_id: UUID, offset: int = 0, limit: int = 100
    ) -> list[dict]:
-        return self.providers.database.relational.get_documents_in_group(
-            group_id, offset, limit
+        return self.providers.database.relational.documents_in_group(
+            group_id, offset=offset, limit=limit
        )

    @telemetry_event("DocumentGroups")
-    async def adocument_groups(self, document_id: str) -> list[str]:
-        return self.providers.database.relational.document_groups(document_id)
\ No newline at end of file
+    async def adocument_groups(self, document_id: str, offset: int = 0, limit: int = 100) -> list[str]:
+        return self.providers.database.relational.document_groups(document_id, offset, limit)
\ No newline at end of file
diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py
index 5823d3d77..f14915c1f 100644
--- a/py/core/providers/database/document.py
+++ b/py/core/providers/database/document.py
@@ -70,9 +70,11 @@ def get_documents_overview(
        filter_user_ids: Optional[list[UUID]] = None,
        filter_document_ids: Optional[list[UUID]] = None,
        filter_group_ids: Optional[list[UUID]] = None,
+        offset: int = 0,
+        limit: int = 100,
    ):
        conditions = []
-        params = {}
+        params = {"offset": offset, "limit": limit}

        if filter_document_ids:
            conditions.append("document_id = ANY(:document_ids)")
@@ -93,8 +95,14 @@
        if conditions:
            query += " WHERE " + " AND ".join(conditions)

+        query += """
+            ORDER BY created_at DESC
+            OFFSET :offset
+            LIMIT :limit
+        """
+
        results = self.execute_query(query, params).fetchall()
-        return [
+        documents = [
            DocumentInfo(
                id=row[0],
                group_ids=row[1],
@@ -110,54 +118,13 @@
            )
            for row in results
        ]
-
-    def document_groups(self, document_id: UUID) -> list[GroupResponse]:
-        query = f"""
-        SELECT g.group_id, g.name, g.description, g.created_at, g.updated_at
-        FROM {self._get_table_name('groups')} g
-        JOIN {self._get_table_name('document_info')} d ON g.group_id = ANY(d.group_ids)
-        WHERE d.document_id = :document_id
-        """
-        params = {"document_id": document_id}
-        results = self.execute_query(query, params).fetchall()
-
-        return [
-            GroupResponse(
-                group_id=row[0],
-                name=row[1],
-                description=row[2],
-                created_at=row[3],
-                updated_at=row[4],
-            )
-            for row in results
-        ]
-
-    def remove_document_from_group(self, document_id: UUID, group_id: UUID) -> None:
-        """
-        Remove a document from a group.
-
-        Args:
-            document_id (UUID): The ID of the document to remove.
-            group_id (UUID): The ID of the group to remove the document from.
-
-        Raises:
-            R2RException: If the group doesn't exist or if the document is not in the group.
+ # Get total count for pagination metadata + count_query = f""" + SELECT COUNT(*) + FROM {self._get_table_name('document_info')} """ - if not self.group_exists(group_id): - raise R2RException(status_code=404, message="Group not found") + if conditions: + count_query += " WHERE " + " AND ".join(conditions) - query = f""" - UPDATE {self._get_table_name('document_info')} - SET group_ids = array_remove(group_ids, :group_id) - WHERE document_id = :document_id AND :group_id = ANY(group_ids) - RETURNING document_id - """ - result = self.execute_query( - query, {"document_id": document_id, "group_id": group_id} - ).fetchone() - - if not result: - raise R2RException( - status_code=404, - message="Document not found in the specified group" - ) \ No newline at end of file + return documents \ No newline at end of file diff --git a/py/core/providers/database/group.py b/py/core/providers/database/group.py index 2c1105e68..4d4ea9972 100644 --- a/py/core/providers/database/group.py +++ b/py/core/providers/database/group.py @@ -105,17 +105,34 @@ def update_group( created_at=result[3], updated_at=result[4], ) - + def delete_group(self, group_id: UUID) -> None: - query = f""" + # Remove group_id from users + user_update_query = f""" + UPDATE {self._get_table_name('users')} + SET group_ids = array_remove(group_ids, :group_id) + WHERE :group_id = ANY(group_ids) + """ + self.execute_query(user_update_query, {"group_id": group_id}) + + # Remove group_id from documents in the relational database + doc_update_query = f""" + UPDATE {self._get_table_name('document_info')} + SET group_ids = array_remove(group_ids, :group_id) + WHERE :group_id = ANY(group_ids) + """ + self.execute_query(doc_update_query, {"group_id": group_id}) + + # Delete the group + delete_query = f""" DELETE FROM {self._get_table_name('groups')} WHERE group_id = :group_id RETURNING group_id """ - result = self.execute_query(query, {"group_id": group_id}).fetchone() + result = self.execute_query(delete_query, {"group_id": group_id}).fetchone() + if not result: raise R2RException(status_code=404, message="Group not found") - return None def list_groups( self, offset: int = 0, limit: int = 100 @@ -258,7 +275,7 @@ def get_users_in_group( for row in results ] - def get_documents_in_group( + def documents_in_group( self, group_id: UUID, offset: int = 0, limit: int = 100 ) -> list[DocumentInfo]: """ @@ -376,7 +393,7 @@ def get_groups_for_user( ) for row in results ] - + def assign_document_to_group(self, document_id: UUID, group_id: UUID) -> None: """ Assign a document to a group. @@ -386,25 +403,106 @@ def assign_document_to_group(self, document_id: UUID, group_id: UUID) -> None: group_id (UUID): The ID of the group to assign the document to. Raises: - R2RException: If the group doesn't exist or if the document is not found. + R2RException: If the group doesn't exist, if the document is not found, + or if there's a database error. 
+ """ + try: + if not self.group_exists(group_id): + raise R2RException(status_code=404, message="Group not found") + + # First, check if the document exists + document_check_query = f""" + SELECT 1 FROM {self._get_table_name('document_info')} + WHERE document_id = :document_id + """ + document_exists = self.execute_query( + document_check_query, {"document_id": document_id} + ).fetchone() + + if not document_exists: + raise R2RException(status_code=404, message="Document not found") + + # If document exists, proceed with the assignment + assign_query = f""" + UPDATE {self._get_table_name('document_info')} + SET group_ids = array_append(group_ids, :group_id) + WHERE document_id = :document_id AND NOT (:group_id = ANY(group_ids)) + RETURNING document_id + """ + result = self.execute_query( + assign_query, {"document_id": document_id, "group_id": group_id} + ).fetchone() + + if not result: + # Document exists but was already assigned to the group + raise R2RException( + status_code=409, + message="Document is already assigned to the group" + ) + + except R2RException: + # Re-raise R2RExceptions as they are already handled + raise + except Exception as e: + raise R2RException( + status_code=500, + message="An error occurred while assigning the document to the group" + ) + + def document_groups(self, document_id: UUID, offset: int = 0, limit: int = 100) -> list[GroupResponse]: + query = f""" + SELECT g.group_id, g.name, g.description, g.created_at, g.updated_at + FROM {self._get_table_name('groups')} g + JOIN {self._get_table_name('document_info')} d ON g.group_id = ANY(d.group_ids) + WHERE d.document_id = :document_id + ORDER BY g.name + OFFSET :offset + LIMIT :limit + """ + params = { + "document_id": document_id, + "offset": offset, + "limit": limit + } + results = self.execute_query(query, params).fetchall() + + return [ + GroupResponse( + group_id=row[0], + name=row[1], + description=row[2], + created_at=row[3], + updated_at=row[4], + ) + for row in results + ] + + def remove_document_from_group(self, document_id: UUID, group_id: UUID) -> None: + """ + Remove a document from a group. + + Args: + document_id (UUID): The ID of the document to remove. + group_id (UUID): The ID of the group to remove the document from. + + Raises: + R2RException: If the group doesn't exist or if the document is not in the group. 
""" if not self.group_exists(group_id): raise R2RException(status_code=404, message="Group not found") query = f""" UPDATE {self._get_table_name('document_info')} - SET group_ids = array_append(group_ids, :group_id) - WHERE document_id = :document_id AND NOT (:group_id = ANY(group_ids)) + SET group_ids = array_remove(group_ids, :group_id) + WHERE document_id = :document_id AND :group_id = ANY(group_ids) RETURNING document_id """ result = self.execute_query( query, {"document_id": document_id, "group_id": group_id} ).fetchone() - if not result: raise R2RException( status_code=404, - message="Document not found or already assigned to the group" - ) - + message="Document not found in the specified group" + ) \ No newline at end of file diff --git a/py/core/providers/database/user.py b/py/core/providers/database/user.py index 7208b1859..5ff701e26 100644 --- a/py/core/providers/database/user.py +++ b/py/core/providers/database/user.py @@ -193,13 +193,53 @@ def update_user(self, user: UserResponse) -> UserResponse: group_ids=result[10], ) - def delete_user(self, user_id: UUID): - query = f""" + + def delete_user(self, user_id: UUID) -> None: + print("A") + # Get the groups the user belongs to + group_query = f""" + SELECT group_ids FROM {self._get_table_name('users')} + WHERE user_id = :user_id + """ + group_result = self.execute_query(group_query, {"user_id": user_id}).fetchone() + + print("B") + if not group_result: + raise R2RException(status_code=404, message="User not found") + + user_groups = group_result[0] + + print("C") + # Remove user from all groups they belong to + if user_groups: + group_update_query = f""" + UPDATE {self._get_table_name('groups')} + SET user_ids = array_remove(user_ids, :user_id) + WHERE group_id = ANY(:group_ids) + """ + self.execute_query(group_update_query, {"user_id": user_id, "group_ids": user_groups}) + + print("D") + # Remove user from documents + doc_update_query = f""" + UPDATE {self._get_table_name('document_info')} + SET user_id = NULL + WHERE user_id = :user_id + """ + self.execute_query(doc_update_query, {"user_id": user_id}) + + print("E") + # Delete the user + delete_query = f""" DELETE FROM {self._get_table_name('users')} WHERE user_id = :user_id + RETURNING user_id """ - result = self.execute_query(query, {"user_id": user_id}) - if result.rowcount == 0: + + print("F") + result = self.execute_query(delete_query, {"user_id": user_id}).fetchone() + + if not result: raise R2RException(status_code=404, message="User not found") def update_user_password(self, user_id: UUID, new_hashed_password: str): diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index 03d347fb1..48b7197c6 100644 --- a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -1,3 +1,4 @@ +from uuid import UUID import logging import os from typing import Any, Optional @@ -390,7 +391,70 @@ def remove_document_from_group(self, document_id: str, group_id: str) -> None: if result.rowcount == 0: logger.warning(f"Document {document_id} not found in group {group_id} or already removed") - def get_document_chunks(self, document_id: str) -> list[dict]: + def remove_group_from_documents(self, group_id: str) -> None: + if self.collection is None: + raise ValueError("Collection is not initialized.") + + table_name = self.collection.table.name + query = text( + f""" + UPDATE vecs."{table_name}" + SET group_ids = array_remove(group_ids, :group_id) + WHERE :group_id = ANY(group_ids) + """ + ) + + with self.vx.Session() as sess: + 
+            sess.commit()
+
+    def delete_user(self, user_id: str) -> None:
+        if self.collection is None:
+            raise ValueError("Collection is not initialized.")
+
+        table_name = self.collection.table.name
+        query = text(
+            f"""
+            UPDATE vecs."{table_name}"
+            SET user_id = NULL
+            WHERE user_id = :user_id
+            """
+        )
+
+        with self.vx.Session() as sess:
+            sess.execute(query, {"user_id": user_id})
+            sess.commit()
+
+    def delete_group(self, group_id: str) -> None:
+        """
+        Remove the specified group ID from all documents in the vector database.
+
+        Args:
+            group_id (str): The ID of the group to remove from all documents.
+
+        Raises:
+            ValueError: If the collection is not initialized.
+        """
+        if self.collection is None:
+            raise ValueError("Collection is not initialized.")
+
+        table_name = self.collection.table.name
+        query = text(
+            f"""
+            UPDATE vecs."{table_name}"
+            SET group_ids = array_remove(group_ids, :group_id)
+            WHERE :group_id = ANY(group_ids)
+            """
+        )
+
+        with self.vx.Session() as sess:
+            result = sess.execute(query, {"group_id": group_id})
+            sess.commit()
+
+        affected_rows = result.rowcount
+        logger.info(f"Removed group {group_id} from {affected_rows} documents.")
+
+    def get_document_chunks(self, document_id: str, offset: int = 0, limit: int = 100) -> dict:
         if not self.collection:
             raise ValueError("Collection is not initialized.")
 
@@ -401,22 +465,36 @@ def get_document_chunks(self, document_id: str) -> list[dict]:
             FROM vecs."{table_name}"
             WHERE document_id = :document_id
             ORDER BY CAST(metadata->>'chunk_order' AS INTEGER)
+            LIMIT :limit OFFSET :offset
+            """
+        )
+
+        count_query = text(
+            f"""
+            SELECT COUNT(*)
+            FROM vecs."{table_name}"
+            WHERE document_id = :document_id
             """
         )
 
-        params = {"document_id": document_id}
+        params = {"document_id": document_id, "limit": limit, "offset": offset}
 
         with self.vx.Session() as sess:
            results = sess.execute(query, params).fetchall()
-            return [
-                {
-                    "fragment_id": result[0],
-                    "extraction_id": result[1],
-                    "document_id": result[2],
-                    "user_id": result[3],
-                    "group_ids": result[4],
-                    "text": result[5],
-                    "metadata": result[6],
-                }
-                for result in results
-            ]
+            total_count = sess.execute(count_query, {"document_id": document_id}).scalar()
+
+        return {
+            "results": [
+                {
+                    "fragment_id": result[0],
+                    "extraction_id": result[1],
+                    "document_id": result[2],
+                    "user_id": result[3],
+                    "group_ids": result[4],
+                    "text": result[5],
+                    "metadata": result[6],
+                }
+                for result in results
+            ],
+            "total_count": total_count,
+        }
\ No newline at end of file
diff --git a/py/sdk/auth.py b/py/sdk/auth.py
index e71f7bbe2..5ef62b89e 100644
--- a/py/sdk/auth.py
+++ b/py/sdk/auth.py
@@ -170,22 +170,4 @@ async def confirm_password_reset(
             dict: The response from the server.
         """
         data = {"reset_token": reset_token, "new_password": new_password}
-        return await client._make_request("POST", "reset_password", json=data)
-
-    @staticmethod
-    async def delete_user(client, user_id: str, password: str = None) -> dict:
-        """
-        Deletes the user with the given user ID.
-
-        Args:
-            user_id (str): The ID of the user to delete.
-            password (str, optional): The password of the user to delete.
-
-        Returns:
-            dict: The response from the server.
-        """
-        data = {"user_id": user_id, "password": password}
-        response = await client._make_request("DELETE", "user", json=data)
-        client.access_token = None
-        client._refresh_token = None
-        return response
+        return await client._make_request("POST", "reset_password", json=data)
\ No newline at end of file
diff --git a/py/sdk/management.py b/py/sdk/management.py
index 74830e7c7..1c95c4445 100644
--- a/py/sdk/management.py
+++ b/py/sdk/management.py
@@ -95,6 +95,8 @@ async def score_completion(
     async def users_overview(
         client,
         user_ids: Optional[list[str]] = None,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
     ) -> dict:
         """
         An overview of users in the R2R deployment.
@@ -108,6 +110,10 @@ async def users_overview(
         params = {}
         if user_ids is not None:
             params["user_ids"] = [str(uid) for uid in user_ids]
+        if offset is not None:
+            params["offset"] = offset
+        if limit is not None:
+            params["limit"] = limit
         return await client._make_request(
             "GET", "users_overview", params=params
         )
@@ -136,6 +142,8 @@ async def delete(
     async def documents_overview(
         client,
         document_ids: Optional[list[Union[UUID, str]]] = None,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
     ) -> dict:
         """
         Get an overview of documents in the R2R deployment.
@@ -152,7 +160,10 @@ async def documents_overview(
         )
         if document_ids:
             params["document_ids"] = document_ids
-
+        if offset is not None:
+            params["offset"] = offset
+        if limit is not None:
+            params["limit"] = limit
         return await client._make_request(
             "GET", "documents_overview", params=params
         )
@@ -161,6 +172,8 @@ async def documents_overview(
     async def document_chunks(
         client,
         document_id: str,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
    ) -> dict:
         """
         Get the chunks for a document.
@@ -171,13 +184,24 @@ async def document_chunks(
         Returns:
             dict: The chunks for the document.
         """
-        return await client._make_request(
-            "GET", f"document_chunks/{document_id}"
-        )
+        params = {}
+        if offset is not None:
+            params["offset"] = offset
+        if limit is not None:
+            params["limit"] = limit
+        if params:
+            return await client._make_request(
+                "GET", f"document_chunks/{document_id}", params=params
+            )
+        else:
+            return await client._make_request(
+                "GET", f"document_chunks/{document_id}"
+            )
 
     @staticmethod
     async def inspect_knowledge_graph(
         client,
+        offset: Optional[int] = None,
         limit: Optional[int] = None,
     ) -> dict:
         """
@@ -190,6 +214,8 @@ async def inspect_knowledge_graph(
             dict: The knowledge graph inspection results.
         """
         params = {}
+        if offset is not None:
+            params["offset"] = offset
         if limit is not None:
             params["limit"] = limit
         return await client._make_request(
@@ -200,8 +226,8 @@ async def inspect_knowledge_graph(
     async def groups_overview(
         client,
         group_ids: Optional[list[str]] = None,
-        limit: Optional[int] = None,
         offset: Optional[int] = None,
+        limit: Optional[int] = None,
     ) -> dict:
         """
         Get an overview of existing groups.
@@ -217,10 +243,10 @@ async def groups_overview(
         params = {}
         if group_ids:
             params["group_ids"] = group_ids
-        if limit:
-            params["limit"] = limit
         if offset:
             params["offset"] = offset
+        if limit:
+            params["limit"] = limit
         return await client._make_request(
             "GET", "groups_overview", params=params
         )
@@ -305,6 +331,36 @@ async def delete_group(
         """
         return await client._make_request("DELETE", f"delete_group/{group_id}")
 
+
+    @staticmethod
+    async def delete_user(
+        client,
+        user_id: str,
+        password: Optional[str] = None,
+        delete_vector_data: bool = False,
+    ) -> dict:
+        """
+        Delete a user by their ID.
+
+        Args:
+            user_id (str): The ID of the user to delete.
+            password (str, optional): The password of the user to delete.
+            delete_vector_data (bool): Whether to also delete the user's vector data.
+
+        Returns:
+            dict: The response from the server.
+        """
+        params = {}
+        if password is not None:
+            params["password"] = password
+        if delete_vector_data:
+            params["delete_vector_data"] = delete_vector_data
+        if not params:
+            return await client._make_request("DELETE", f"user/{user_id}")
+        else:
+            return await client._make_request("DELETE", f"user/{user_id}", json=params)
+
+
     @staticmethod
     async def list_groups(
         client,
@@ -404,9 +458,11 @@ async def get_users_in_group(
         )
 
     @staticmethod
-    async def get_groups_for_user(
+    async def user_groups(
         client,
         user_id: str,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
     ) -> dict:
         """
         Get all groups that a user is a member of.
@@ -417,9 +473,19 @@ async def get_groups_for_user(
         Returns:
             dict: The list of groups that the user is a member of.
         """
-        return await client._make_request(
-            "GET", f"get_groups_for_user/{user_id}"
-        )
+        params = {}
+        if offset is not None:
+            params["offset"] = offset
+        if limit is not None:
+            params["limit"] = limit
+        if params:
+            return await client._make_request(
+                "GET", f"user_groups/{user_id}", params=params
+            )
+        else:
+            return await client._make_request(
+                "GET", f"user_groups/{user_id}"
+            )
 
     @staticmethod
     async def assign_document_to_group(
@@ -474,6 +540,8 @@ async def remove_document_from_group(
     async def document_groups(
         client,
         document_id: str,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None
     ) -> dict:
         """
         Get all groups that a document is assigned to.
@@ -484,12 +552,22 @@ async def document_groups(
         Returns:
             dict: The list of groups that the document is assigned to.
         """
-        return await client._make_request(
-            "GET", f"document_groups/{document_id}"
-        )
+        params = {}
+        if offset is not None:
+            params["offset"] = offset
+        if limit is not None:
+            params["limit"] = limit
+        if params:
+            return await client._make_request(
+                "GET", f"document_groups/{document_id}", params=params
+            )
+        else:
+            return await client._make_request(
+                "GET", f"document_groups/{document_id}"
+            )
 
     @staticmethod
-    async def get_documents_in_group(
+    async def documents_in_group(
         client,
         group_id: str,
         offset: Optional[int] = None,
diff --git a/py/tests/test_client.py b/py/tests/test_client.py
index ff983cc9a..312e6ebfe 100644
--- a/py/tests/test_client.py
+++ b/py/tests/test_client.py
@@ -100,7 +100,7 @@ def update_user(user):
         return updated_user
 
     db.relational.update_user.side_effect = update_user
-    db.relational.get_documents_in_group.return_value = [
+    db.relational.documents_in_group.return_value = [
         DocumentInfo(
             user_id=uuid.uuid4(),
             id=uuid.uuid4(),
@@ -357,7 +357,7 @@ async def test_user_profile(r2r_client, mock_db):
 
 
 @pytest.mark.asyncio
-async def test_get_documents_in_group(r2r_client, mock_db):
+async def test_documents_in_group(r2r_client, mock_db):
     # Register and login as a superuser
     user_data = {"email": "superuser@example.com", "password": "password123"}
     r2r_client.register(**user_data)
@@ -369,7 +369,7 @@ async def test_documents_in_group(r2r_client, mock_db):
 
     # Get documents in group
     group_id = uuid.uuid4()
-    response = r2r_client.get_documents_in_group(group_id)
+    response = r2r_client.documents_in_group(group_id)
 
     assert "results" in response
     assert len(response["results"]) == 100  # Default limit
diff --git a/py/tests/test_groups.py b/py/tests/test_groups.py
index 2562bd38a..0e13c2212 100644
--- a/py/tests/test_groups.py
+++ b/py/tests/test_groups.py
@@ -362,21 +362,21 @@ def test_get_groups_for_user_with_pagination(pg_db, test_user):
pg_db.relational.delete_group(group.group_id) -def test_get_documents_in_group(pg_db, test_group, test_documents): +def test_documents_in_group(pg_db, test_group, test_documents): # Test getting all documents - all_docs = pg_db.relational.get_documents_in_group(test_group.group_id) + all_docs = pg_db.relational.documents_in_group(test_group.group_id) assert len(all_docs) == 5 assert all(isinstance(doc, DocumentInfo) for doc in all_docs) assert all(test_group.group_id in doc.group_ids for doc in all_docs) # Test pagination - first page - first_page = pg_db.relational.get_documents_in_group( + first_page = pg_db.relational.documents_in_group( test_group.group_id, offset=0, limit=3 ) assert len(first_page) == 3 # Test pagination - second page - second_page = pg_db.relational.get_documents_in_group( + second_page = pg_db.relational.documents_in_group( test_group.group_id, offset=3, limit=3 ) assert len(second_page) == 2 @@ -394,11 +394,11 @@ def test_get_documents_in_group(pg_db, test_group, test_documents): # Test with non-existent group non_existent_id = UUID("00000000-0000-0000-0000-000000000000") with pytest.raises(R2RException): - pg_db.relational.get_documents_in_group(non_existent_id) + pg_db.relational.documents_in_group(non_existent_id) # Test with empty group empty_group = pg_db.relational.create_group("Empty Group", "No documents") - empty_docs = pg_db.relational.get_documents_in_group(empty_group.group_id) + empty_docs = pg_db.relational.documents_in_group(empty_group.group_id) assert len(empty_docs) == 0 # Clean up
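
The group logic touched by this patch reduces to two Postgres idioms: an array column (group_ids) updated with array_append/array_remove behind a membership guard, and OFFSET/LIMIT pagination. The standalone sketch below illustrates that pattern only; the DSN, the simplified schema with text columns (R2R's real tables use UUID types and prefixed names), and the function names are placeholders, not R2R code.

from sqlalchemy import create_engine, text

# Assumed schema for this sketch (not R2R's real one):
#   CREATE TABLE document_info (document_id text PRIMARY KEY, group_ids text[] DEFAULT '{}');
#   CREATE TABLE "groups"      (group_id text PRIMARY KEY, name text);
# Placeholder DSN; any reachable Postgres with those tables will do.
engine = create_engine("postgresql+psycopg2://user:pass@localhost:5432/sandbox")


def assign_document_to_group(document_id: str, group_id: str) -> bool:
    # array_append fires only when the membership guard shows the group is absent,
    # so repeated calls stay idempotent; RETURNING reveals whether a row changed.
    with engine.begin() as conn:
        row = conn.execute(
            text(
                """
                UPDATE document_info
                SET group_ids = array_append(group_ids, :group_id)
                WHERE document_id = :document_id
                  AND NOT (:group_id = ANY(group_ids))
                RETURNING document_id
                """
            ),
            {"document_id": document_id, "group_id": group_id},
        ).fetchone()
    return row is not None  # False -> missing document or already assigned


def remove_document_from_group(document_id: str, group_id: str) -> bool:
    # Mirror image: array_remove guarded by current membership.
    with engine.begin() as conn:
        row = conn.execute(
            text(
                """
                UPDATE document_info
                SET group_ids = array_remove(group_ids, :group_id)
                WHERE document_id = :document_id
                  AND :group_id = ANY(group_ids)
                RETURNING document_id
                """
            ),
            {"document_id": document_id, "group_id": group_id},
        ).fetchone()
    return row is not None  # False -> document was not in the group


def document_groups(document_id: str, offset: int = 0, limit: int = 100) -> list[str]:
    # Paginated listing, same OFFSET/LIMIT shape the patch adds to the providers.
    with engine.connect() as conn:
        rows = conn.execute(
            text(
                """
                SELECT g.group_id
                FROM "groups" g
                JOIN document_info d ON g.group_id = ANY(d.group_ids)
                WHERE d.document_id = :document_id
                ORDER BY g.name
                OFFSET :offset LIMIT :limit
                """
            ),
            {"document_id": document_id, "offset": offset, "limit": limit},
        ).fetchall()
    return [row[0] for row in rows]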