diff --git a/docs/api-reference/endpoint/ingest_chunks.mdx b/docs/api-reference/endpoint/ingest_chunks.mdx new file mode 100644 index 000000000..0264de4f7 --- /dev/null +++ b/docs/api-reference/endpoint/ingest_chunks.mdx @@ -0,0 +1,4 @@ +--- +title: 'Ingest Chunks' +openapi: 'POST /v2/ingest_chunks' +--- diff --git a/docs/mint.json b/docs/mint.json index dd03e1058..344d22004 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -294,6 +294,7 @@ "group": "Document Ingestion", "pages": [ "api-reference/endpoint/ingest_files", + "api-reference/endpoint/ingest_chunks", "api-reference/endpoint/update_files" ] }, diff --git a/py/cli/commands/ingestion.py b/py/cli/commands/ingestion.py index 043e4dff6..6f6a1920f 100644 --- a/py/cli/commands/ingestion.py +++ b/py/cli/commands/ingestion.py @@ -123,13 +123,13 @@ def update_files(ctx, file_paths, document_ids, metadatas): @cli.command() -@click.option("--v2", is_flag=True, help="use aristotle_v2.txt (a smaller file)") +@click.option( + "--v2", is_flag=True, help="use aristotle_v2.txt (a smaller file)" +) @pass_context def ingest_sample_file(ctx, v2=False): """Ingest the first sample file into R2R.""" - sample_file_url = ( - f"https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle{'_v2' if v2 else ''}.txt" - ) + sample_file_url = f"https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle{'_v2' if v2 else ''}.txt" client = ctx.obj with timer(): diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 19d99f647..bcd63fa31 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -7,8 +7,9 @@ DocumentInfo, DocumentType, IngestionStatus, - KGExtractionStatus, KGEnrichmentStatus, + KGExtractionStatus, + RawChunk, ) from shared.abstractions.embedding import ( EmbeddingPurpose, @@ -77,10 +78,11 @@ "Document", "DocumentExtraction", "DocumentInfo", + "DocumentType", "IngestionStatus", "KGExtractionStatus", "KGEnrichmentStatus", - "DocumentType", + "RawChunk", # Embedding abstractions "EmbeddingPurpose", "default_embedding_prefixes", diff --git a/py/core/main/api/auth_router.py b/py/core/main/api/auth_router.py index 81bcd4c4e..a06db2590 100644 --- a/py/core/main/api/auth_router.py +++ b/py/core/main/api/auth_router.py @@ -1,5 +1,5 @@ -from uuid import UUID from typing import Optional +from uuid import UUID from fastapi import Body, Depends, Path from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm diff --git a/py/core/main/api/data/ingestion_router_openapi.yml b/py/core/main/api/data/ingestion_router_openapi.yml index 48c836b18..cb241066b 100644 --- a/py/core/main/api/data/ingestion_router_openapi.yml +++ b/py/core/main/api/data/ingestion_router_openapi.yml @@ -59,3 +59,51 @@ update_files: document_ids: "An optional list of document ids for each file. If not provided, the system will attempt to generate the corresponding unique from the `generate_document_id` method." metadatas: "An optional list of JSON metadata to affix to each file" ingestion_config: "JSON string for chunking configuration override" + +ingest_chunks: + openapi_extra: + x-codeSamples: + - lang: Python + source: | + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ + result = client.ingest_chunks( + chunks=[ + { + "text": "Another chunk of text", + }, + { + "text": "Yet another chunk of text", + }, + { + "text": "A chunk of text", + }, + ], ) + - lang: Shell + source: | + curl -X POST "https://api.example.com/ingest_chunks" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -d '{ + "chunks": [ + { + "text": "Another chunk of text" + }, + { + "text": "Yet another chunk of text" + }, + { + "text": "A chunk of text" + } + ], + "document_id": "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + "metadata": {} + }' + + input_descriptions: + chunks: "A list of text chunks to ingest into the system." + document_id: "An optional document id to associate the chunks with. If not provided, a unique document id will be generated." + metadata: "Optional JSON metadata to associate with the ingested chunks." \ No newline at end of file diff --git a/py/core/main/api/ingestion_router.py b/py/core/main/api/ingestion_router.py index 45297df39..fff763948 100644 --- a/py/core/main/api/ingestion_router.py +++ b/py/core/main/api/ingestion_router.py @@ -6,10 +6,10 @@ from uuid import UUID import yaml -from fastapi import Depends, File, Form, UploadFile +from fastapi import Body, Depends, File, Form, UploadFile from pydantic import Json -from core.base import R2RException, generate_document_id +from core.base import R2RException, RawChunk, generate_document_id from core.base.api.models import ( WrappedIngestionResponse, WrappedUpdateResponse, @@ -38,12 +38,17 @@ def _register_workflows(self): self.service, { "ingest-files": ( - "Ingestion task queued successfully." + "Ingest files task queued successfully." + if self.orchestration_provider.config.provider != "simple" + else "Ingestion task completed successfully." + ), + "ingest-chunks": ( + "Ingest chunks task queued successfully." if self.orchestration_provider.config.provider != "simple" else "Ingestion task completed successfully." ), "update-files": ( - "Update task queued successfully." + "Update file task queued successfully." if self.orchestration_provider.config.provider != "simple" else "Update task queued successfully." ), @@ -96,6 +101,7 @@ async def ingest_files_app( A valid user authentication token is required to access this endpoint, as regular users can only ingest files for their own access. More expansive collection permissioning is under development. """ + self._validate_ingestion_config(ingestion_config) # Check if the user is a superuser if not auth_user.is_superuser: @@ -253,6 +259,60 @@ async def update_files_app( raw_message["document_ids"] = workflow_input["document_ids"] return raw_message + ingest_chunks_extras = self.openapi_extras.get("ingest_chunks", {}) + ingest_chunks_descriptions = ingest_chunks_extras.get( + "input_descriptions", {} + ) + + @self.router.post( + "/ingest_chunks", + openapi_extra=ingest_chunks_extras.get("openapi_extra"), + ) + @self.base_endpoint + async def ingest_chunks_app( + chunks: Json[list[RawChunk]] = Body( + {}, description=ingest_chunks_descriptions.get("chunks") + ), + document_id: Optional[UUID] = Body( + None, description=ingest_chunks_descriptions.get("document_id") + ), + metadata: Optional[Json[dict]] = Body( + None, description=ingest_files_descriptions.get("metadata") + ), + auth_user=Depends(self.service.providers.auth.auth_wrapper), + response_model=WrappedIngestionResponse, + ): + """ + Ingest text chunks into the system. 
+ + This endpoint accepts a JSON payload of pre-parsed text chunks, enabling you to ingest content that has already been chunked outside of R2R. + + A valid user authentication token is required to access this endpoint, as regular users can only ingest chunks for their own access. More expansive collection permissioning is under development. + """ + if not document_id: + document_id = generate_document_id( + chunks[0].text[:20], auth_user.id + ) + + workflow_input = { + "document_id": str(document_id), + "chunks": [chunk.model_dump() for chunk in chunks], + "metadata": metadata or {}, + "user": auth_user.model_dump_json(), + } + + raw_message = await self.orchestration_provider.run_workflow( + "ingest-chunks", + {"request": workflow_input}, + options={ + "additional_metadata": { + "document_id": str(document_id), + } + }, + ) + raw_message["document_id"] = str(document_id) + return raw_message + @staticmethod def _validate_ingestion_config(ingestion_config): from ..assembly.factory import R2RProviderFactory diff --git a/py/core/main/api/kg_router.py b/py/core/main/api/kg_router.py index 081ad3a90..e99f665e3 100644 --- a/py/core/main/api/kg_router.py +++ b/py/core/main/api/kg_router.py @@ -149,7 +149,9 @@ async def enrich_graph( if run_type is KGRunType.ESTIMATE: - return await self.service.get_enrichment_estimate(collection_id, server_kg_enrichment_settings) + return await self.service.get_enrichment_estimate( + collection_id, server_kg_enrichment_settings + ) if kg_enrichment_settings: for key, value in kg_enrichment_settings.items(): diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index 2e5cd27e1..2a8739ae6 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -4,7 +4,13 @@ from hatchet_sdk import Context -from core.base import IngestionStatus, OrchestrationProvider, increment_version +from core.base import ( + DocumentExtraction, + IngestionStatus, + OrchestrationProvider, + generate_extraction_id, + increment_version, +) from core.base.abstractions import DocumentInfo, R2RException from core.utils import generate_default_user_collection_id @@ -130,7 +136,7 @@ async def embed(self, context: Context) -> dict: return { "status": "Successfully finalized ingestion", - "document_info": document_info.to_dict(), + "document_info": document_info.model_dump(), } @orchestration_provider.failure() @@ -272,9 +278,154 @@ async def update_files(self, context: Context) -> None: return None + @orchestration_provider.workflow( + name="ingest-chunks", + timeout="60m", + ) + class HatchetIngestChunksWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def ingest(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input( + input_data + ) + + document_info = await self.ingestion_service.ingest_chunks_ingress( + **parsed_data + ) + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.EMBEDDING + ) + document_id = document_info.id + + extractions = [ + DocumentExtraction( + id=generate_extraction_id(document_id, i), + document_id=document_id, + collection_ids=[], + user_id=document_info.user_id, + data=chunk.text, + metadata=parsed_data["metadata"], + ).to_dict() + for i, chunk in enumerate(parsed_data["chunks"]) + ] + return { +
"status": "Successfully ingested chunks", + "extractions": extractions, + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["ingest"], timeout="60m") + async def embed(self, context: Context) -> dict: + document_info_dict = context.step_output("ingest")["document_info"] + document_info = DocumentInfo(**document_info_dict) + + extractions = context.step_output("ingest")["extractions"] + + embedding_generator = await self.ingestion_service.embed_document( + extractions + ) + embeddings = [ + embedding.model_dump() + async for embedding in embedding_generator + ] + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.STORING + ) + + storage_generator = await self.ingestion_service.store_embeddings( + embeddings + ) + async for _ in storage_generator: + pass + + return { + "status": "Successfully embedded and stored chunks", + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["embed"], timeout="60m") + async def finalize(self, context: Context) -> dict: + document_info_dict = context.step_output("embed")["document_info"] + document_info = DocumentInfo(**document_info_dict) + + await self.ingestion_service.finalize_ingestion( + document_info, is_update=False + ) + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.SUCCESS + ) + + try: + collection_id = await self.ingestion_service.providers.database.relational.assign_document_to_collection( + document_id=document_info.id, + collection_id=generate_default_user_collection_id( + str(document_info.user_id) + ), + ) + self.ingestion_service.providers.database.vector.assign_document_to_collection( + document_id=document_info.id, collection_id=collection_id + ) + except Exception as e: + logger.error( + f"Error during assigning document to collection: {str(e)}" + ) + + return { + "status": "Successfully finalized ingestion", + "document_info": document_info.to_dict(), + } + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + logger.error( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + documents_overview = ( + await self.ingestion_service.providers.database.relational.get_documents_overview( + filter_document_ids=[document_id] + ) + )["results"] + + if not documents_overview: + logger.error( + f"Document with id {document_id} not found in database to mark failure." 
+ ) + return + + document_info = documents_overview[0] + + if ( + not document_info.ingestion_status + == IngestionStatus.SUCCESS + ): + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.FAILED + ) + + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + ingest_files_workflow = HatchetIngestFilesWorkflow(service) update_files_workflow = HatchetUpdateFilesWorkflow(service) + ingest_chunks_workflow = HatchetIngestChunksWorkflow(service) return { "ingest_files": ingest_files_workflow, "update_files": update_files_workflow, + "ingest_chunks": ingest_chunks_workflow, } diff --git a/py/core/main/orchestration/simple/ingestion_workflow.py b/py/core/main/orchestration/simple/ingestion_workflow.py index 4e4e06c11..29a2c4d85 100644 --- a/py/core/main/orchestration/simple/ingestion_workflow.py +++ b/py/core/main/orchestration/simple/ingestion_workflow.py @@ -1,8 +1,11 @@ import asyncio import logging -from core.base import R2RException, increment_version -from core.utils import generate_default_user_collection_id +from core.base import DocumentExtraction, R2RException, increment_version +from core.utils import ( + generate_default_user_collection_id, + generate_extraction_id, +) from ...services import IngestionService @@ -168,4 +171,80 @@ async def update_files(input_data): await asyncio.gather(*results) - return {"ingest-files": ingest_files, "update-files": update_files} + async def ingest_chunks(input_data): + try: + from core.base import IngestionStatus + from core.main import IngestionServiceAdapter + + parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input( + input_data + ) + + document_info = await service.ingest_chunks_ingress(**parsed_data) + + await service.update_document_status( + document_info, status=IngestionStatus.EMBEDDING + ) + document_id = document_info.id + + extractions = [ + DocumentExtraction( + id=generate_extraction_id(document_id, i), + document_id=document_id, + collection_ids=[], + user_id=document_info.user_id, + data=chunk.text, + metadata=parsed_data["metadata"], + ).model_dump() + for i, chunk in enumerate(parsed_data["chunks"]) + ] + + embedding_generator = await service.embed_document(extractions) + embeddings = [ + embedding.model_dump() + async for embedding in embedding_generator + ] + + await service.update_document_status( + document_info, status=IngestionStatus.STORING + ) + storage_generator = await service.store_embeddings(embeddings) + async for _ in storage_generator: + pass + + await service.finalize_ingestion(document_info, is_update=False) + + await service.update_document_status( + document_info, status=IngestionStatus.SUCCESS + ) + + try: + collection_id = await service.providers.database.relational.assign_document_to_collection( + document_id=document_info.id, + collection_id=generate_default_user_collection_id( + str(document_info.user_id) + ), + ) + service.providers.database.vector.assign_document_to_collection( + document_id=document_info.id, collection_id=collection_id + ) + except Exception as e: + logger.error( + f"Error during assigning document to collection: {str(e)}" + ) + + except Exception as e: + if document_info is not None: + await service.update_document_status( + document_info, status=IngestionStatus.FAILED + ) + raise R2RException( + status_code=500, + message=f"Error during chunk ingestion: {str(e)}", + ) + + return { + "ingest-files": ingest_files, + "update-files": update_files, + "ingest-chunks": ingest_chunks, + } diff 
--git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 48bfa9d79..ffe45080d 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -9,9 +9,9 @@ DocumentExtraction, DocumentInfo, DocumentType, - IngestionConfig, IngestionStatus, R2RException, + RawChunk, RunLoggingSingleton, RunManager, VectorEntry, @@ -78,7 +78,7 @@ async def ingest_file_ingress( metadata = metadata or {} version = version or STARTING_VERSION - document_info = self._create_document_info( + document_info = self._create_document_info_from_file( document_id, user, file_data["filename"], @@ -120,7 +120,7 @@ async def ingest_file_ingress( "info": document_info, } - def _create_document_info( + def _create_document_info_from_file( self, document_id: UUID, user: UserResponse, @@ -153,6 +153,33 @@ def _create_document_info( updated_at=datetime.now(), ) + def _create_document_info_from_chunks( + self, + document_id: UUID, + user: UserResponse, + chunks: list[RawChunk], + metadata: dict, + version: str, + ) -> DocumentInfo: + metadata = metadata or {} + metadata["version"] = version + + return DocumentInfo( + id=document_id, + user_id=user.id, + collection_ids=metadata.get("collection_ids", []), + type=DocumentType.TXT, + title=metadata.get("title", f"Ingested Chunks - {document_id}"), + metadata=metadata, + version=version, + size_in_bytes=sum( + len(chunk.text.encode("utf-8")) for chunk in chunks + ), + ingestion_status=IngestionStatus.PENDING, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + async def parse_file( self, document_info: DocumentInfo, ingestion_config: dict ) -> AsyncGenerator[DocumentExtraction, None]: @@ -256,6 +283,53 @@ async def _collect_results(self, result_gen: Any) -> list[dict]: results.append(res.model_dump_json()) return results + @telemetry_event("IngestChunks") + async def ingest_chunks_ingress( + self, + document_id: UUID, + metadata: Optional[dict], + chunks: list[dict], + user: UserResponse, + *args: Any, + **kwargs: Any, + ) -> DocumentInfo: + if not chunks: + raise R2RException( + status_code=400, message="No chunks provided for ingestion." 
+ ) + + metadata = metadata or {} + version = STARTING_VERSION + + document_info = self._create_document_info_from_chunks( + document_id, + user, + chunks, + metadata, + version, + ) + + existing_document_info = ( + await self.providers.database.relational.get_documents_overview( + filter_user_ids=[user.id], + filter_document_ids=[document_id], + ) + )["results"] + + if len(existing_document_info) > 0: + existing_doc = existing_document_info[0] + if existing_doc.ingestion_status != IngestionStatus.FAILED: + raise R2RException( + status_code=409, + message=f"Document {document_id} was already ingested and is not in a failed state.", + ) + + await self.providers.database.relational.upsert_documents_overview( + document_info + ) + + return document_info + class IngestionServiceAdapter: @staticmethod @@ -285,12 +359,10 @@ def parse_ingest_file_input(data: dict) -> dict: } @staticmethod - def parse_update_files_input(data: dict) -> dict: + def parse_ingest_chunks_input(data: dict) -> dict: return { "user": IngestionServiceAdapter._parse_user_data(data["user"]), - "document_ids": [UUID(doc_id) for doc_id in data["document_ids"]], - "metadatas": data["metadatas"], - "ingestion_config": data["ingestion_config"], - "file_sizes_in_bytes": data["file_sizes_in_bytes"], - "file_datas": data["file_datas"], + "metadata": data["metadata"], + "document_id": data["document_id"], + "chunks": [RawChunk.from_dict(chunk) for chunk in data["chunks"]], } diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index d0024ed65..858b85d2f 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -3,11 +3,12 @@ from typing import Any, AsyncGenerator, Optional from uuid import UUID - from core.base import KGExtractionStatus, RunLoggingSingleton, RunManager -from core.base.abstractions import KGCreationSettings, KGEnrichmentSettings - -from core.base.abstractions import GenerationConfig +from core.base.abstractions import ( + GenerationConfig, + KGCreationSettings, + KGEnrichmentSettings, +) from core.telemetry.telemetry_decorator import telemetry_event from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders diff --git a/py/core/pipes/kg/entity_description.py b/py/core/pipes/kg/entity_description.py index 2966e5564..adc2b48e3 100644 --- a/py/core/pipes/kg/entity_description.py +++ b/py/core/pipes/kg/entity_description.py @@ -14,8 +14,8 @@ PipeType, RunLoggingSingleton, ) -from core.base.pipes.base_pipe import AsyncPipe from core.base.abstractions import Entity +from core.base.pipes.base_pipe import AsyncPipe logger = logging.getLogger(__name__) @@ -193,4 +193,6 @@ async def process_entity( for result in asyncio.as_completed(workflows): yield await result - logger.info(f"Processed {total_entities} entities for document {document_id}") \ No newline at end of file + logger.info( + f"Processed {total_entities} entities for document {document_id}" + ) diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py index 332b751d3..5d0522efb 100644 --- a/py/core/providers/database/document.py +++ b/py/core/providers/database/document.py @@ -21,8 +21,8 @@ DocumentInfo, DocumentType, IngestionStatus, - KGExtractionStatus, KGEnrichmentStatus, + KGExtractionStatus, R2RException, ) diff --git a/py/core/providers/kg/postgres.py b/py/core/providers/kg/postgres.py index d86b8ab40..07338027d 100644 --- a/py/core/providers/kg/postgres.py +++ b/py/core/providers/kg/postgres.py @@ -12,21 +12,19 @@ EmbeddingProvider, Entity, 
KGConfig, - KGExtractionStatus, KGExtraction, + KGExtractionStatus, KGProvider, Triple, ) from shared.abstractions import ( KGCreationEstimationResponse, + KGCreationSettings, KGEnrichmentEstimationResponse, KGEnrichmentSettings, - KGCreationSettings, ) - from shared.utils import llm_cost_per_million_tokens - logger = logging.getLogger(__name__) @@ -333,8 +331,8 @@ async def get_entity_map( ORDER BY name ASC LIMIT {limit} OFFSET {offset} ) - SELECT e.name, e.description, e.category, - (SELECT array_agg(DISTINCT x) FROM unnest(e.extraction_ids) x) AS extraction_ids, + SELECT e.name, e.description, e.category, + (SELECT array_agg(DISTINCT x) FROM unnest(e.extraction_ids) x) AS extraction_ids, e.document_id FROM {self._get_table_name("entity_raw")} e JOIN entities_list el ON e.name = el.name @@ -363,7 +361,7 @@ async def get_entity_map( LIMIT {limit} OFFSET {offset} ) - SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description, + SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description, (SELECT array_agg(DISTINCT x) FROM unnest(t.extraction_ids) x) AS extraction_ids, t.document_id FROM {self._get_table_name("triple_raw")} t JOIN entities_list el ON t.subject = el.name @@ -459,17 +457,23 @@ async def vector_query(self, query: str, **kwargs: Any) -> Any: filter_query = "" if collection_ids_dict: filter_query = "WHERE collection_id = ANY($3)" - filter_ids = collection_ids_dict['$overlap'] + filter_ids = collection_ids_dict["$overlap"] if search_type == "__Community__": logger.info(f"Searching in collection ids: {filter_ids}") - if search_type == "__Entity__" or search_type == "__Relationship__": + if ( + search_type == "__Entity__" + or search_type == "__Relationship__" + ): filter_query = "WHERE document_id = ANY($3)" query = f""" SELECT distinct document_id FROM {self._get_table_name('document_info')} WHERE $1 = ANY(collection_ids) """ - filter_ids = [doc_id['document_id'] for doc_id in await self.fetch_query(query, filter_ids)] + filter_ids = [ + doc_id["document_id"] + for doc_id in await self.fetch_query(query, filter_ids) + ] logger.info(f"Searching in document ids: {filter_ids}") QUERY = f""" @@ -477,9 +481,13 @@ async def vector_query(self, query: str, **kwargs: Any) -> Any: """ if filter_query != "": - results = await self.fetch_query(QUERY, (str(query_embedding), limit, filter_ids)) + results = await self.fetch_query( + QUERY, (str(query_embedding), limit, filter_ids) + ) else: - results = await self.fetch_query(QUERY, (str(query_embedding), limit)) + results = await self.fetch_query( + QUERY, (str(query_embedding), limit) + ) for result in results: yield { @@ -731,7 +739,6 @@ async def delete_graph_for_collection( QUERY, [KGExtractionStatus.PENDING, collection_id] ) - def _get_str_estimation_output(self, x: tuple[Any, Any]) -> str: if isinstance(x[0], int) and isinstance(x[1], int): return " - ".join(map(str, x)) @@ -767,7 +774,8 @@ async def get_creation_estimate( // kg_creation_settings.extraction_merge_count ) # 4 chunks per llm estimated_entities = ( - (total_chunks * 10, total_chunks * 20) + total_chunks * 10, + total_chunks * 20, ) # 25 entities per 4 chunks estimated_triples = ( int(estimated_entities[0] * 1.25), @@ -785,8 +793,14 @@ async def get_creation_estimate( ) # in millions estimated_cost = ( - total_in_out_tokens[0] * llm_cost_per_million_tokens(kg_creation_settings.generation_config.model), - total_in_out_tokens[1] * llm_cost_per_million_tokens(kg_creation_settings.generation_config.model), + total_in_out_tokens[0] + * 
llm_cost_per_million_tokens( + kg_creation_settings.generation_config.model + ), + total_in_out_tokens[1] + * llm_cost_per_million_tokens( + kg_creation_settings.generation_config.model + ), ) total_time_in_minutes = ( @@ -794,28 +808,40 @@ async def get_creation_estimate( total_in_out_tokens[1] * 10 / 60, ) # 10 minutes per million tokens - return KGCreationEstimationResponse( - message="These are estimated ranges, actual values may vary. To run the KG creation process, run `create-graph` with `--run` in the cli, or `run_mode=\"run\"` in the client.", + message='These are estimated ranges, actual values may vary. To run the KG creation process, run `create-graph` with `--run` in the cli, or `run_mode="run"` in the client.', document_count=len(document_ids), number_of_jobs_created=len(document_ids) + 1, total_chunks=total_chunks, - estimated_entities=self._get_str_estimation_output(estimated_entities), - estimated_triples=self._get_str_estimation_output(estimated_triples), - estimated_llm_calls=self._get_str_estimation_output(estimated_llm_calls), - estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output(total_in_out_tokens), - estimated_cost_in_usd=self._get_str_estimation_output(estimated_cost), - estimated_total_time_in_minutes="Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + self._get_str_estimation_output(total_time_in_minutes), + estimated_entities=self._get_str_estimation_output( + estimated_entities + ), + estimated_triples=self._get_str_estimation_output( + estimated_triples + ), + estimated_llm_calls=self._get_str_estimation_output( + estimated_llm_calls + ), + estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output( + total_in_out_tokens + ), + estimated_cost_in_usd=self._get_str_estimation_output( + estimated_cost + ), + estimated_total_time_in_minutes="Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + + self._get_str_estimation_output(total_time_in_minutes), ) async def get_enrichment_estimate( - self, collection_id: UUID, - kg_enrichment_settings: KGEnrichmentSettings + self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings ) -> KGEnrichmentEstimationResponse: - document_ids = [doc.id for doc in (await self.db_provider.documents_in_collection( - collection_id - ))["results"]] + document_ids = [ + doc.id + for doc in ( + await self.db_provider.documents_in_collection(collection_id) + )["results"] + ] QUERY = f""" SELECT COUNT(*) FROM {self._get_table_name("entity_embedding")} WHERE document_id = ANY($1); @@ -825,7 +851,9 @@ async def get_enrichment_estimate( ] if not entity_count: - raise ValueError("No entities found in the graph. Please run `create-graph` first.") + raise ValueError( + "No entities found in the graph. Please run `create-graph` first." 
+ ) QUERY = f""" SELECT COUNT(*) FROM {self._get_table_name("triple_raw")} WHERE document_id = ANY($1); """ @@ -839,25 +867,36 @@ 2000 * estimated_llm_calls[0] / 1000000, 2000 * estimated_llm_calls[1] / 1000000, ) - cost_per_million_tokens = llm_cost_per_million_tokens(kg_enrichment_settings.generation_config.model) + cost_per_million_tokens = llm_cost_per_million_tokens( + kg_enrichment_settings.generation_config.model + ) estimated_cost = ( - estimated_total_in_out_tokens_in_millions[0] * cost_per_million_tokens, - estimated_total_in_out_tokens_in_millions[1] * cost_per_million_tokens, + estimated_total_in_out_tokens_in_millions[0] + * cost_per_million_tokens, + estimated_total_in_out_tokens_in_millions[1] + * cost_per_million_tokens, ) - + estimated_total_time = ( estimated_total_in_out_tokens_in_millions[0] * 10 / 60, estimated_total_in_out_tokens_in_millions[1] * 10 / 60, ) return KGEnrichmentEstimationResponse( - message="These are estimated ranges, actual values may vary. To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_mode=\"run\"` in the client.", + message='These are estimated ranges, actual values may vary. To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_mode="run"` in the client.', total_entities=entity_count, total_triples=triple_count, - estimated_llm_calls=self._get_str_estimation_output(estimated_llm_calls), - estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output(estimated_total_in_out_tokens_in_millions), - estimated_cost_in_usd=self._get_str_estimation_output(estimated_cost), - estimated_total_time_in_minutes="Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + self._get_str_estimation_output(estimated_total_time), + estimated_llm_calls=self._get_str_estimation_output( + estimated_llm_calls + ), + estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output( + estimated_total_in_out_tokens_in_millions + ), + estimated_cost_in_usd=self._get_str_estimation_output( + estimated_cost + ), + estimated_total_time_in_minutes="Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + + self._get_str_estimation_output(estimated_total_time), ) async def create_vector_index(self): diff --git a/py/sdk/ingestion.py b/py/sdk/ingestion.py index 26ab9f59d..f15383e7d 100644 --- a/py/sdk/ingestion.py +++ b/py/sdk/ingestion.py @@ -128,3 +128,30 @@ async def update_files( return await client._make_request( "POST", "update_files", data=data, files=files ) + + @staticmethod + async def ingest_chunks( + client, + chunks: list[dict], + document_id: Optional[UUID] = None, + metadata: Optional[dict] = None, + ) -> dict: + """ + Ingest pre-parsed text chunks into your R2R deployment + + Args: + chunks (List[dict]): List of text chunks to ingest, each containing a `text` field. + document_id (Optional[UUID]): An optional document id to associate the chunks with. If not provided, a unique document id will be generated. + metadata (Optional[dict]): Optional JSON metadata to associate with the ingested chunks. + + Returns: + dict: Ingestion results containing the ingestion task message and the assigned document id. 
+ """ + + data = { + "chunks": chunks, + "document_id": document_id, + "metadata": metadata, + } + return await client._make_request("POST", "ingest_chunks", json=data) diff --git a/py/shared/abstractions/__init__.py b/py/shared/abstractions/__init__.py index 6cb7b57ca..89658ccea 100644 --- a/py/shared/abstractions/__init__.py +++ b/py/shared/abstractions/__init__.py @@ -7,8 +7,9 @@ DocumentInfo, DocumentType, IngestionStatus, - KGExtractionStatus, KGEnrichmentStatus, + KGExtractionStatus, + RawChunk, ) from .embedding import EmbeddingPurpose, default_embedding_prefixes from .exception import R2RDocumentProcessingError, R2RException @@ -70,6 +71,7 @@ "KGExtractionStatus", "KGEnrichmentStatus", "DocumentType", + "RawChunk", # Embedding abstractions "EmbeddingPurpose", "default_embedding_prefixes", diff --git a/py/shared/abstractions/document.py b/py/shared/abstractions/document.py index 84fae5554..eb8143f70 100644 --- a/py/shared/abstractions/document.py +++ b/py/shared/abstractions/document.py @@ -185,3 +185,7 @@ class DocumentExtraction(R2RSerializable): user_id: UUID data: DataType metadata: dict + + +class RawChunk(R2RSerializable): + text: str diff --git a/py/shared/abstractions/kg.py b/py/shared/abstractions/kg.py index 2561b51a9..43e14e26b 100644 --- a/py/shared/abstractions/kg.py +++ b/py/shared/abstractions/kg.py @@ -105,6 +105,7 @@ class KGEnrichmentEstimationResponse(R2RSerializable): description="The estimated total time to run the graph enrichment process.", ) + class KGCreationSettings(R2RSerializable): """Settings for knowledge graph creation.""" diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index b0862e8a6..d2d3a34d9 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -320,7 +320,6 @@ class KGSearchSettings(R2RSerializable): description="Configuration for text generation during graph search.", ) - # TODO: add these back in # entity_types: list = [] # relationships: list = [] diff --git a/py/shared/utils/__init__.py b/py/shared/utils/__init__.py index 7bf87e669..9ba1517c8 100644 --- a/py/shared/utils/__init__.py +++ b/py/shared/utils/__init__.py @@ -13,9 +13,9 @@ generate_run_id, generate_user_id, increment_version, + llm_cost_per_million_tokens, run_pipeline, to_async_generator, - llm_cost_per_million_tokens, ) from .splitter.text import RecursiveCharacterTextSplitter, TextSplitter diff --git a/py/shared/utils/base_utils.py b/py/shared/utils/base_utils.py index aaabfe9c2..666eeca42 100644 --- a/py/shared/utils/base_utils.py +++ b/py/shared/utils/base_utils.py @@ -208,25 +208,34 @@ def format_relations(predicates: list[RelationshipType]) -> str: lines = [predicate.name for predicate in predicates] return "\n".join(lines) -def llm_cost_per_million_tokens(model: str, input_output_ratio: float = 2) -> float: + +def llm_cost_per_million_tokens( + model: str, input_output_ratio: float = 2 +) -> float: """ Returns the cost per million tokens for a given model and input/output ratio. - + Input/Output ratio is the ratio of input tokens to output tokens. 
""" # improving this to use provider in the future - model = model.split("/")[-1] # simplifying assumption + model = model.split("/")[-1] # simplifying assumption cost_dict = { "gpt-4o-mini": (0.15, 0.6), "gpt-4o": (2.5, 10), } if model in cost_dict: - return (cost_dict[model][0] * input_output_ratio * cost_dict[model][1])/(1 + input_output_ratio) + return ( + cost_dict[model][0] * input_output_ratio * cost_dict[model][1] + ) / (1 + input_output_ratio) else: # use gpt-4o as default logger.warning(f"Unknown model: {model}. Using gpt-4o as default.") - return (cost_dict["gpt-4o"][0] * input_output_ratio * cost_dict["gpt-4o"][1])/(1 + input_output_ratio) \ No newline at end of file + return ( + cost_dict["gpt-4o"][0] + * input_output_ratio + * cost_dict["gpt-4o"][1] + ) / (1 + input_output_ratio) diff --git a/py/tests/core/providers/database/relational/test_document_db.py b/py/tests/core/providers/database/relational/test_document_db.py index 459d1f842..93f87e636 100644 --- a/py/tests/core/providers/database/relational/test_document_db.py +++ b/py/tests/core/providers/database/relational/test_document_db.py @@ -7,8 +7,8 @@ DocumentInfo, DocumentType, IngestionStatus, - KGExtractionStatus, KGEnrichmentStatus, + KGExtractionStatus, )