Skip to content

Commit

Permalink
Ingestion refactor (#991)
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyaspimpalgaonkar authored Aug 27, 2024
1 parent e0d8109 commit 2807a2b
Show file tree
Hide file tree
Showing 22 changed files with 155 additions and 253 deletions.
6 changes: 4 additions & 2 deletions py/core/base/abstractions/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ class DocumentInfo(BaseModel):
title: Optional[str] = None
version: str
size_in_bytes: int
status: DocumentStatus = DocumentStatus.PROCESSING
ingestion_status: DocumentStatus = DocumentStatus.PROCESSING
restructuring_status: DocumentStatus = DocumentStatus.PROCESSING
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None

Expand All @@ -108,7 +109,8 @@ def convert_to_db_entry(self):
"title": self.title or "N/A",
"version": self.version,
"size_in_bytes": self.size_in_bytes,
"status": self.status,
"ingestion_status": self.ingestion_status,
"restructuring_status": self.restructuring_status,
"created_at": self.created_at or now,
"updated_at": self.updated_at or now,
}
Expand Down
7 changes: 5 additions & 2 deletions py/core/base/abstractions/graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import logging
import uuid
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Optional

from pydantic import BaseModel

Expand Down Expand Up @@ -397,5 +398,7 @@ def from_dict(
class KGExtraction(BaseModel):
"""An extraction from a document that is part of a knowledge graph."""

entities: Union[list[Entity], dict[str, Entity]]
fragment_id: uuid.UUID
document_id: uuid.UUID
entities: dict[str, Entity]
triples: list[Triple]
8 changes: 7 additions & 1 deletion py/core/base/abstractions/restructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@ class KGEnrichmentSettings(BaseModel):
description="The maximum number of knowledge triples to extract from each chunk.",
)

generation_config: GenerationConfig = Field(
generation_config_triplet: GenerationConfig = Field(
default_factory=GenerationConfig,
        description="Configuration for text generation during knowledge triple (triplet) extraction.",
)

generation_config_enrichment: GenerationConfig = Field(
default_factory=GenerationConfig,
description="Configuration for text generation during graph enrichment.",
)

leiden_params: dict = Field(
default_factory=dict,
description="Parameters for the Leiden algorithm.",
Expand Down
3 changes: 2 additions & 1 deletion py/core/base/api/models/management/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ class DocumentOverviewResponse(BaseModel):
type: str
created_at: datetime
updated_at: datetime
status: str
ingestion_status: str
restructuring_status: str
version: str
group_ids: list[UUID]
metadata: Dict[str, Any]
Expand Down
1 change: 0 additions & 1 deletion py/core/base/providers/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class KGConfig(ProviderConfig):
batch_size: Optional[int] = 1
kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction"
kg_search_prompt: Optional[str] = "kg_search"
kg_extraction_config: Optional[GenerationConfig] = None
kg_search_config: Optional[GenerationConfig] = None
kg_store_path: Optional[str] = None
kg_enrichment_settings: Optional[KGEnrichmentSettings] = (
Expand Down
19 changes: 13 additions & 6 deletions py/core/configs/local_llm_neo4j_kg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,26 @@ excluded_parsers = [ "gif", "jpeg", "jpg", "png", "svg", "mp3", "mp4" ]

[kg]
provider = "neo4j"
batch_size = 1
max_entities = 10
max_relations = 20
kg_extraction_prompt = "zero_shot_ner_kg_extraction"
kg_extraction_prompt = "graphrag_triplet_extraction_zero_shot"

[kg.kg_extraction_config]
model = "ollama/sciphi/triplex"
model = "ollama/llama3.1"
temperature = 1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }


[kg.kg_enrichment_settings]
max_knowledge_triples = 100
generation_config_triplet = { model = "ollama/llama3.1" } # and other params, model used for triplet extraction
generation_config_enrichment = { model = "ollama/llama3.1" } # and other params, model used for node description and graph clustering
leiden_params = { max_cluster_size = 1000 } # more params in graspologic/partition/leiden.py

[kg.kg_search_config]
model = "ollama/llama3.1"

[database]
provider = "postgres"

Expand All @@ -43,4 +50,4 @@ system_instruction_name = "rag_agent"
tool_names = ["search"]

[agent.generation_config]
model = "ollama/llama3.1"
model = "ollama/llama3.1"
15 changes: 10 additions & 5 deletions py/core/configs/neo4j_kg.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
[chunking] # use larger chunk sizes for kg extraction
provider = "r2r"
method = "recursive"
chunk_size = 4096
chunk_overlap = 200


[completion]
provider = "litellm"
concurrent_request_limit = 256
Expand All @@ -15,13 +22,11 @@ provider = "neo4j"
batch_size = 256
kg_extraction_prompt = "graphrag_triplet_extraction_zero_shot"

[kg.kg_extraction_config]
model = "gpt-4o-mini"

[kg.kg_enrichment_settings]
max_knowledge_triples = 100
generation_config = { model = "gpt-4o-mini" } # and other params
generation_config_triplet = { model = "gpt-4o-mini" } # and other params, model used for triplet extraction
generation_config_enrichment = { model = "gpt-4o-mini" } # and other params, model used for node description and graph clustering
leiden_params = { max_cluster_size = 1000 } # more params in graspologic/partition/leiden.py

[kg.kg_search_config]
model = "gpt-4o-mini"
model = "gpt-4o-mini"
3 changes: 2 additions & 1 deletion py/core/main/api/routes/base_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ async def wrapper(*args, **kwargs):
value=str(e),
)
logger.error(
f"Error in base endpoint {func.__name__}() - \n\n{str(e)})"
f"Error in base endpoint {func.__name__}() - \n\n{str(e)}",
exc_info=True,
)
raise HTTPException(
status_code=500,
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/assembly/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class R2RConfig:
"kg": [
"provider",
"batch_size",
"kg_extraction_config",
"kg_enrichment_settings",
],
"parsing": ["provider", "excluded_parsers"],
"chunking": ["provider", "method"],
Expand Down
10 changes: 6 additions & 4 deletions py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ async def ingest_documents(
*args: Any,
**kwargs: Any,
):

if len(documents) == 0:
raise R2RException(
status_code=400, message="No documents provided for ingestion."
Expand Down Expand Up @@ -277,7 +278,8 @@ async def ingest_documents(
document.id in existing_document_info
# apply `geq` check to prevent re-ingestion of updated documents
and (existing_document_info[document.id].version >= version)
and existing_document_info[document.id].status == "success"
and existing_document_info[document.id].ingestion_status
== "success"
):
logger.error(
f"Document with ID {document.id} was already successfully processed."
Expand Down Expand Up @@ -309,7 +311,7 @@ async def ingest_documents(
title=title,
version=version,
size_in_bytes=len(document.data),
status="processing",
ingestion_status="processing",
created_at=now,
updated_at=now,
)
Expand Down Expand Up @@ -417,9 +419,9 @@ async def _process_ingestion_results(
for document_info in document_infos:
if document_info.id not in skipped_ids:
if document_info.id in failed_ids:
document_info.status = "failure"
document_info.ingestion_status = "failure"
elif document_info.id in successful_ids:
document_info.status = "success"
document_info.ingestion_status = "success"
documents_to_upsert.append(document_info)

if documents_to_upsert:
Expand Down
10 changes: 8 additions & 2 deletions py/core/main/services/restructure_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional, Union

from core.base import R2RException, RunLoggingSingleton, RunManager
from core.base.abstractions import KGEnrichmentSettings
Expand Down Expand Up @@ -33,7 +33,9 @@ def __init__(

async def enrich_graph(
self,
enrich_graph_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
kg_enrichment_settings: Optional[
Union[dict, KGEnrichmentSettings]
] = None,
) -> Dict[str, Any]:
"""
Perform graph enrichment.
Expand All @@ -49,8 +51,12 @@ async def input_generator():
for doc in input:
yield doc

if not kg_enrichment_settings or kg_enrichment_settings == {}:
kg_enrichment_settings = self.config.kg.kg_enrichment_settings

return await self.pipelines.kg_enrichment_pipeline.run(
input=input_generator(),
kg_enrichment_settings=kg_enrichment_settings,
run_manager=self.run_manager,
)

Expand Down
Loading

0 comments on commit 2807a2b

Please sign in to comment.