diff --git a/py/core/base/providers/__init__.py b/py/core/base/providers/__init__.py index 3318f1c8f..520d488ad 100644 --- a/py/core/base/providers/__init__.py +++ b/py/core/base/providers/__init__.py @@ -18,7 +18,7 @@ from .file import FileConfig, FileProvider from .kg import KGConfig, KGProvider from .llm import CompletionConfig, CompletionProvider -from .orchestration import OrchestrationConfig, OrchestrationProvider +from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow from .parsing import OverrideParser, ParsingConfig, ParsingProvider from .prompt import PromptConfig, PromptProvider @@ -55,6 +55,7 @@ # Orchestration provider "OrchestrationConfig", "OrchestrationProvider", + "Workflow", # Parsing provider "ParsingConfig", "ParsingProvider", diff --git a/py/core/base/providers/orchestration.py b/py/core/base/providers/orchestration.py index a676bf650..42d1d3ec9 100644 --- a/py/core/base/providers/orchestration.py +++ b/py/core/base/providers/orchestration.py @@ -1,9 +1,15 @@ from abc import abstractmethod +from enum import Enum from typing import Any, Callable from .base import Provider, ProviderConfig +class Workflow(Enum): + INGESTION = "ingestion" + RESTRUCTURE = "restructure" + + class OrchestrationConfig(ProviderConfig): provider: str max_threads: int = 256 @@ -24,7 +30,7 @@ def __init__(self, config: OrchestrationConfig): self.worker = None @abstractmethod - def register_workflow(self, workflow: Any) -> None: + async def start_worker(self): pass @abstractmethod @@ -32,13 +38,28 @@ def get_worker(self, name: str, max_threads: int) -> Any: pass @abstractmethod - def workflow(self, *args, **kwargs) -> Callable: + def step(self, *args, **kwargs) -> Any: + pass + + @abstractmethod + def workflow(self, *args, **kwargs) -> Any: + pass + + @abstractmethod + def failure(self, *args, **kwargs) -> Any: pass @abstractmethod - def step(self, *args, **kwargs) -> Callable: + def register_workflows(self, workflow: Workflow, service: Any) -> None: pass @abstractmethod - def start_worker(self): + def run_workflow( + self, + workflow_name: str, + parameters: dict, + options: dict, + *args, + **kwargs, + ) -> Any: pass diff --git a/py/core/main/__init__.py b/py/core/main/__init__.py index cec677483..9447a6f12 100644 --- a/py/core/main/__init__.py +++ b/py/core/main/__init__.py @@ -4,7 +4,7 @@ # from .app_entry import r2r_app from .assembly import * -from .hatchet import * +from .orchestration import * from .services import * __all__ = [ @@ -22,10 +22,6 @@ "RestructureRouter", ## R2R APP "R2RApp", - ## R2R APP ENTRY - # "r2r_app", - ## R2R HATCHET - "r2r_hatchet", ## R2R ASSEMBLY # Builder "R2RBuilder", diff --git a/py/core/main/api/auth_router.py b/py/core/main/api/auth_router.py index 1a3bd7851..e99460847 100644 --- a/py/core/main/api/auth_router.py +++ b/py/core/main/api/auth_router.py @@ -22,12 +22,12 @@ class AuthRouter(BaseRouter): def __init__( self, - auth_service: AuthService, - run_type: RunType = RunType.INGESTION, - orchestration_provider: Optional[OrchestrationProvider] = None, + service: AuthService, + orchestration_provider: OrchestrationProvider, + run_type: RunType = RunType.UNSPECIFIED, ): - super().__init__(auth_service, run_type, orchestration_provider) - self.service: AuthService = auth_service # for type hinting + super().__init__(service, orchestration_provider, run_type) + self.service: AuthService = service # for type hinting def _register_workflows(self): pass diff --git a/py/core/main/api/base_router.py b/py/core/main/api/base_router.py index 
58265e529..3ddfcbaa9 100644 --- a/py/core/main/api/base_router.py +++ b/py/core/main/api/base_router.py @@ -19,8 +19,8 @@ class BaseRouter: def __init__( self, service: "Service", + orchestration_provider: OrchestrationProvider, run_type: RunType = RunType.UNSPECIFIED, - orchestration_provider: Optional[OrchestrationProvider] = None, ): self.service = service self.run_type = run_type diff --git a/py/core/main/api/ingestion_router.py b/py/core/main/api/ingestion_router.py index 58276d30e..385820bbc 100644 --- a/py/core/main/api/ingestion_router.py +++ b/py/core/main/api/ingestion_router.py @@ -15,10 +15,8 @@ WrappedIngestionResponse, WrappedUpdateResponse, ) -from core.base.providers import OrchestrationProvider +from core.base.providers import OrchestrationProvider, Workflow -from ...main.hatchet import r2r_hatchet -from ..hatchet import IngestFilesWorkflow, UpdateFilesWorkflow from ..services.ingestion_service import IngestionService from .base_router import BaseRouter, RunType @@ -29,22 +27,15 @@ class IngestionRouter(BaseRouter): def __init__( self, service: IngestionService, + orchestration_provider: OrchestrationProvider, run_type: RunType = RunType.INGESTION, - orchestration_provider: Optional[OrchestrationProvider] = None, ): - if not orchestration_provider: - raise ValueError( - "IngestionRouter requires an orchestration provider." - ) - super().__init__(service, run_type, orchestration_provider) + super().__init__(service, orchestration_provider, run_type) self.service: IngestionService = service def _register_workflows(self): - self.orchestration_provider.register_workflow( - IngestFilesWorkflow(self.service) - ) - self.orchestration_provider.register_workflow( - UpdateFilesWorkflow(self.service) + self.orchestration_provider.register_workflows( + Workflow.INGESTION, self.service ) def _load_openapi_extras(self): @@ -146,7 +137,7 @@ async def ingest_files_app( file_data["content_type"], ) - task_id = r2r_hatchet.admin.run_workflow( + task_id = self.orchestration_provider.run_workflow( "ingest-file", {"request": workflow_input}, options={ @@ -165,50 +156,6 @@ async def ingest_files_app( ) return messages - @self.router.post( - "/retry_ingest_files", - openapi_extra=ingest_files_extras.get("openapi_extra"), - ) - @self.base_endpoint - async def retry_ingest_files( - document_ids: list[UUID] = Form( - ..., - description=ingest_files_descriptions.get("document_ids"), - ), - auth_user=Depends(self.service.providers.auth.auth_wrapper), - response_model=WrappedIngestionResponse, - ): - """ - Retry the ingestion of files into the system. - - This endpoint allows you to retry the ingestion of files that have previously failed to ingest into R2R. - - A valid user authentication token is required to access this endpoint, as regular users can only retry the ingestion of their own files. More expansive collection permissioning is under development. - """ - if not auth_user.is_superuser: - documents_overview = await self.service.providers.database.relational.get_documents_overview( - filter_document_ids=document_ids, - filter_user_ids=[auth_user.id], - )[ - "results" - ] - if len(documents_overview) != len(document_ids): - raise R2RException( - status_code=404, - message="One or more documents not found.", - ) - - # FIXME: This is throwing an aiohttp.client_exceptions.ClientConnectionError: Cannot connect to host localhost:8080 ssl:default… can we whitelist the host? 
- workflow_list = await r2r_hatchet.rest.workflow_run_list() - - # TODO: we want to extract the hatchet run ids for the document ids, and then retry them - - return { - "message": "Retry tasks queued successfully.", - "task_ids": [str(task_id) for task_id in workflow_list], - "document_ids": [str(doc_id) for doc_id in document_ids], - } - update_files_extras = self.openapi_extras.get("update_files", {}) update_files_descriptions = update_files_extras.get( "input_descriptions", {} @@ -301,8 +248,8 @@ async def update_files_app( "is_update": True, } - task_id = r2r_hatchet.admin.run_workflow( - "update-files", {"request": workflow_input} + task_id = self.orchestration_provider.run_workflow( + "update-files", {"request": workflow_input}, {} ) return { diff --git a/py/core/main/api/management_router.py b/py/core/main/api/management_router.py index cb37b0315..2a959c7cb 100644 --- a/py/core/main/api/management_router.py +++ b/py/core/main/api/management_router.py @@ -41,10 +41,10 @@ class ManagementRouter(BaseRouter): def __init__( self, service: ManagementService, + orchestration_provider: OrchestrationProvider, run_type: RunType = RunType.MANAGEMENT, - orchestration_provider: Optional[OrchestrationProvider] = None, ): - super().__init__(service, run_type, orchestration_provider) + super().__init__(service, orchestration_provider, run_type) self.service: ManagementService = service # for type hinting self.start_time = datetime.now(timezone.utc) diff --git a/py/core/main/api/restructure_router.py b/py/core/main/api/restructure_router.py index d65ee80fb..c1145fca6 100644 --- a/py/core/main/api/restructure_router.py +++ b/py/core/main/api/restructure_router.py @@ -11,15 +11,8 @@ WrappedKGCreationResponse, WrappedKGEnrichmentResponse, ) -from core.base.providers import OrchestrationProvider - -from ...main.hatchet import r2r_hatchet -from ..hatchet import ( - CreateGraphWorkflow, - EnrichGraphWorkflow, - KGCommunitySummaryWorkflow, - KgExtractAndStoreWorkflow, -) +from core.base.providers import OrchestrationProvider, Workflow + from ..services.restructure_service import RestructureService from .base_router import BaseRouter, RunType @@ -30,14 +23,10 @@ class RestructureRouter(BaseRouter): def __init__( self, service: RestructureService, + orchestration_provider: OrchestrationProvider, run_type: RunType = RunType.RESTRUCTURE, - orchestration_provider: Optional[OrchestrationProvider] = None, ): - if not orchestration_provider: - raise ValueError( - "RestructureRouter requires an orchestration provider." 
- ) - super().__init__(service, run_type, orchestration_provider) + super().__init__(service, orchestration_provider, run_type) self.service: RestructureService = service def _load_openapi_extras(self): @@ -49,17 +38,8 @@ def _load_openapi_extras(self): return yaml_content def _register_workflows(self): - self.orchestration_provider.register_workflow( - EnrichGraphWorkflow(self.service) - ) - self.orchestration_provider.register_workflow( - KgExtractAndStoreWorkflow(self.service) - ) - self.orchestration_provider.register_workflow( - CreateGraphWorkflow(self.service) - ) - self.orchestration_provider.register_workflow( - KGCommunitySummaryWorkflow(self.service) + self.orchestration_provider.register_workflows( + Workflow.RESTRUCTURE, self.service ) def _setup_routes(self): @@ -102,8 +82,8 @@ async def create_graph( "user": auth_user.json(), } - task_id = r2r_hatchet.admin.run_workflow( - "create-graph", {"request": workflow_input} + task_id = self.orchestration_provider.run_workflow( + "create-graph", {"request": workflow_input}, {} ) return { @@ -157,8 +137,8 @@ async def enrich_graph( "user": auth_user.json(), } - task_id = r2r_hatchet.admin.run_workflow( - "enrich-graph", {"request": workflow_input} + task_id = self.orchestration_provider.run_workflow( + "enrich-graph", {"request": workflow_input}, {} ) return { diff --git a/py/core/main/api/retrieval_router.py b/py/core/main/api/retrieval_router.py index c84b4ae87..c94c9d7cb 100644 --- a/py/core/main/api/retrieval_router.py +++ b/py/core/main/api/retrieval_router.py @@ -29,10 +29,10 @@ class RetrievalRouter(BaseRouter): def __init__( self, service: RetrievalService, + orchestration_provider: OrchestrationProvider, run_type: RunType = RunType.RETRIEVAL, - orchestration_provider: Optional[OrchestrationProvider] = None, ): - super().__init__(service, run_type, orchestration_provider) + super().__init__(service, orchestration_provider, run_type) self.service: RetrievalService = service # for type hinting def _load_openapi_extras(self): diff --git a/py/core/main/assembly/builder.py b/py/core/main/assembly/builder.py index be2e0e700..8ad2b5172 100644 --- a/py/core/main/assembly/builder.py +++ b/py/core/main/assembly/builder.py @@ -233,16 +233,20 @@ async def build(self, *args, **kwargs) -> R2RApp: orchestration_provider = providers.orchestration routers = { - "auth_router": AuthRouter(services["auth"]).get_router(), + "auth_router": AuthRouter( + services["auth"], orchestration_provider=orchestration_provider + ).get_router(), "ingestion_router": IngestionRouter( services["ingestion"], orchestration_provider=orchestration_provider, ).get_router(), "management_router": ManagementRouter( - services["management"] + services["management"], + orchestration_provider=orchestration_provider, ).get_router(), "retrieval_router": RetrievalRouter( - services["retrieval"] + services["retrieval"], + orchestration_provider=orchestration_provider, ).get_router(), "restructure_router": RestructureRouter( services["restructure"], diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index c6c612525..450f5a035 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -50,7 +50,9 @@ async def create_auth_provider( if auth_config.provider == "r2r": from core.providers import R2RAuthProvider - r2r_auth = R2RAuthProvider(auth_config, crypto_provider, db_provider) + r2r_auth = R2RAuthProvider( + auth_config, crypto_provider, db_provider + ) await r2r_auth.initialize() return r2r_auth elif auth_config.provider 
== "supabase": diff --git a/py/core/main/hatchet/__init__.py b/py/core/main/hatchet/__init__.py deleted file mode 100644 index b63ef5de4..000000000 --- a/py/core/main/hatchet/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from .base import r2r_hatchet -from .ingestion_workflow import IngestFilesWorkflow, UpdateFilesWorkflow -from .restructure_workflow import ( - CreateGraphWorkflow, - EnrichGraphWorkflow, - KGCommunitySummaryWorkflow, - KgExtractAndStoreWorkflow, -) - -__all__ = [ - "r2r_hatchet", - "IngestFilesWorkflow", - "UpdateFilesWorkflow", - "EnrichGraphWorkflow", - "CreateGraphWorkflow", - "KgExtractAndStoreWorkflow", - "KGCommunitySummaryWorkflow", -] diff --git a/py/core/main/hatchet/base.py b/py/core/main/hatchet/base.py deleted file mode 100644 index d38a51d52..000000000 --- a/py/core/main/hatchet/base.py +++ /dev/null @@ -1,6 +0,0 @@ -from hatchet_sdk import Hatchet - -try: - r2r_hatchet = Hatchet() -except ImportError: - r2r_hatchet = None diff --git a/py/core/main/hatchet/ingestion_workflow.py b/py/core/main/hatchet/ingestion_workflow.py deleted file mode 100644 index 107a6cea5..000000000 --- a/py/core/main/hatchet/ingestion_workflow.py +++ /dev/null @@ -1,278 +0,0 @@ -import asyncio -import logging - -from hatchet_sdk import Context - -from core.base import IngestionStatus, increment_version -from core.base.abstractions import DocumentInfo, R2RException - -from ..services import IngestionService, IngestionServiceAdapter -from .base import r2r_hatchet - -logger = logging.getLogger(__name__) - - -@r2r_hatchet.workflow( - name="ingest-file", - timeout="60m", -) -class IngestFilesWorkflow: - def __init__(self, ingestion_service: IngestionService): - self.ingestion_service = ingestion_service - - @r2r_hatchet.step(timeout="60m") - async def parse(self, context: Context) -> dict: - input_data = context.workflow_input()["request"] - print("input_data = ", input_data) - parsed_data = IngestionServiceAdapter.parse_ingest_file_input( - input_data - ) - - ingestion_result = await self.ingestion_service.ingest_file_ingress( - **parsed_data - ) - - document_info = ingestion_result["info"] - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.PARSING, - ) - - extractions_generator = await self.ingestion_service.parse_file( - document_info - ) - - extractions = [] - async for extraction in extractions_generator: - extractions.append(extraction) - - serializable_extractions = [ - fragment.to_dict() for fragment in extractions - ] - - return { - "status": "Successfully extracted data", - "extractions": serializable_extractions, - "document_info": document_info.to_dict(), - } - - @r2r_hatchet.step(parents=["parse"], timeout="60m") - async def chunk(self, context: Context) -> dict: - document_info_dict = context.step_output("parse")["document_info"] - document_info = DocumentInfo(**document_info_dict) - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.CHUNKING, - ) - - extractions = context.step_output("parse")["extractions"] - chunking_config = context.workflow_input()["request"].get( - "chunking_config" - ) - - chunk_generator = await self.ingestion_service.chunk_document( - extractions, - chunking_config, - ) - - chunks = [] - async for chunk in chunk_generator: - chunks.append(chunk) - - serializable_chunks = [chunk.to_dict() for chunk in chunks] - - return { - "status": "Successfully chunked data", - "chunks": serializable_chunks, - "document_info": document_info.to_dict(), - } - - 
@r2r_hatchet.step(parents=["chunk"], timeout="60m") - async def embed(self, context: Context) -> dict: - document_info_dict = context.step_output("chunk")["document_info"] - document_info = DocumentInfo(**document_info_dict) - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.EMBEDDING, - ) - - chunks = context.step_output("chunk")["chunks"] - - embedding_generator = await self.ingestion_service.embed_document( - chunks - ) - - embeddings = [] - async for embedding in embedding_generator: - embeddings.append(embedding) - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.STORING, - ) - - storage_generator = await self.ingestion_service.store_embeddings( # type: ignore - embeddings - ) - - async for _ in storage_generator: - pass - - return { - "document_info": document_info.to_dict(), - } - - @r2r_hatchet.step(parents=["embed"], timeout="60m") - async def finalize(self, context: Context) -> dict: - document_info_dict = context.step_output("embed")["document_info"] - document_info = DocumentInfo(**document_info_dict) - - is_update = context.workflow_input()["request"].get("is_update") - - await self.ingestion_service.finalize_ingestion( - document_info, is_update=is_update - ) - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.SUCCESS, - ) - - return { - "status": "Successfully finalized ingestion", - "document_info": document_info.to_dict(), - } - - @r2r_hatchet.on_failure_step() - async def on_failure(self, context: Context) -> None: - request = context.workflow_input().get("request", {}) - document_id = request.get("document_id") - - if not document_id: - logger.error( - "No document id was found in workflow input to mark a failure." - ) - return - - try: - documents_overview = await self.ingestion_service.providers.database.relational.get_documents_overview( - filter_document_ids=[document_id] - ) - - if not documents_overview: - logger.error( - f"Document with id {document_id} not found in database to mark failure." - ) - return - - document_info = documents_overview[0] - - # Update the document status to FAILURE - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.FAILURE, - ) - - except Exception as e: - logger.error( - f"Failed to update document status for {document_id}: {e}" - ) - - -# TODO: Implement a check to see if the file is actually changed before updating -@r2r_hatchet.workflow(name="update-files", timeout="60m") -class UpdateFilesWorkflow: - def __init__(self, ingestion_service: IngestionService): - self.ingestion_service = ingestion_service - - @r2r_hatchet.step(retries=0, timeout="60m") - async def update_files(self, context: Context) -> None: - data = context.workflow_input()["request"] - parsed_data = IngestionServiceAdapter.parse_update_files_input(data) - - file_datas = parsed_data["file_datas"] - user = parsed_data["user"] - document_ids = parsed_data["document_ids"] - metadatas = parsed_data["metadatas"] - chunking_config = parsed_data["chunking_config"] - file_sizes_in_bytes = parsed_data["file_sizes_in_bytes"] - - if not file_datas: - raise R2RException( - status_code=400, message="No files provided for update." 
- ) - if len(document_ids) != len(file_datas): - raise R2RException( - status_code=400, - message="Number of ids does not match number of files.", - ) - - documents_overview = ( - await self.ingestion_service.providers.database.relational.get_documents_overview( - filter_document_ids=document_ids, - filter_user_ids=None if user.is_superuser else [user.id], - ) - )["results"] - if len(documents_overview) != len(document_ids): - raise R2RException( - status_code=404, - message="One or more documents not found.", - ) - - results = [] - - for idx, ( - file_data, - doc_id, - doc_info, - file_size_in_bytes, - ) in enumerate( - zip( - file_datas, - document_ids, - documents_overview, - file_sizes_in_bytes, - ) - ): - new_version = increment_version(doc_info.version) - - updated_metadata = ( - metadatas[idx] if metadatas else doc_info.metadata - ) - updated_metadata["title"] = ( - updated_metadata.get("title") - or file_data["filename"].split("/")[-1] - ) - - # Prepare input for ingest_file workflow - ingest_input = { - "file_data": file_data, - "user": data.get("user"), - "metadata": updated_metadata, - "document_id": str(doc_id), - "version": new_version, - "chunking_config": ( - chunking_config.model_dump_json() - if chunking_config - else None - ), - "size_in_bytes": file_size_in_bytes, - "is_update": True, - } - - # Spawn ingest_file workflow as a child workflow - child_result = ( - await context.aio.spawn_workflow( - "ingest-file", - {"request": ingest_input}, - key=f"ingest_file_{doc_id}", - ) - ).result() - results.append(child_result) - - await asyncio.gather(*results) - - return None diff --git a/py/core/main/hatchet/restructure_workflow.py b/py/core/main/hatchet/restructure_workflow.py deleted file mode 100644 index 539624e19..000000000 --- a/py/core/main/hatchet/restructure_workflow.py +++ /dev/null @@ -1,350 +0,0 @@ -import asyncio -import json -import logging -import uuid - -from hatchet_sdk import ConcurrencyLimitStrategy, Context - -from core import GenerationConfig, IngestionStatus, KGCreationSettings -from core.base import R2RDocumentProcessingError -from core.base.abstractions import RestructureStatus - -from ..services import RestructureService -from .base import r2r_hatchet - -logger = logging.getLogger(__name__) - - -@r2r_hatchet.workflow(name="kg-extract-and-store", timeout="60m") -class KgExtractAndStoreWorkflow: - def __init__(self, restructure_service: RestructureService): - self.restructure_service = restructure_service - - @r2r_hatchet.step(retries=3, timeout="60m") - async def kg_extract_and_store(self, context: Context) -> dict: - input_data = context.workflow_input()["request"] - document_id = uuid.UUID(input_data["document_id"]) - fragment_merge_count = input_data["fragment_merge_count"] - max_knowledge_triples = input_data["max_knowledge_triples"] - entity_types = input_data["entity_types"] - relation_types = input_data["relation_types"] - - document_overview = await self.restructure_service.providers.database.relational.get_documents_overview( - filter_document_ids=[document_id] - ) - document_overview = document_overview["results"][0] - - try: - - # Set restructure status to 'processing' - document_overview.restructuring_status = ( - RestructureStatus.PROCESSING - ) - - await self.restructure_service.providers.database.relational.upsert_documents_overview( - document_overview - ) - - errors = await self.restructure_service.kg_extract_and_store( - document_id=document_id, - generation_config=GenerationConfig( - **input_data["generation_config"] - ), - 
fragment_merge_count=fragment_merge_count, - max_knowledge_triples=max_knowledge_triples, - entity_types=entity_types, - relation_types=relation_types, - ) - # Set restructure status to 'success' if completed successfully - if len(errors) == 0: - document_overview.restructuring_status = ( - RestructureStatus.SUCCESS - ) - await self.restructure_service.providers.database.relational.upsert_documents_overview( - document_overview - ) - else: - - document_overview.restructuring_status = ( - RestructureStatus.FAILURE - ) - await self.restructure_service.providers.database.relational.upsert_documents_overview( - document_overview - ) - raise R2RDocumentProcessingError( - error_message=f"Error in kg_extract_and_store, list of errors: {errors}", - document_id=document_id, - ) - - except Exception as e: - # Set restructure status to 'failure' if an error occurred - document_overview.restructuring_status = RestructureStatus.FAILURE - await self.restructure_service.providers.database.relational.upsert_documents_overview( - document_overview - ) - raise R2RDocumentProcessingError( - error_message=e, - document_id=document_id, - ) - - return {"result": None} - - -@r2r_hatchet.workflow(name="create-graph", timeout="60m") -class CreateGraphWorkflow: - def __init__(self, restructure_service: RestructureService): - self.restructure_service = restructure_service - - @r2r_hatchet.step(retries=1) - async def kg_extraction_ingress(self, context: Context) -> dict: - input_data = context.workflow_input()["request"] - - kg_creation_settings = KGCreationSettings( - **json.loads(input_data["kg_creation_settings"]) - ) - - documents_overview = ( - await self.restructure_service.providers.database.relational.get_documents_overview() - ) - documents_overview = documents_overview["results"] - - document_ids = [ - doc.id - for doc in documents_overview - if doc.restructuring_status != IngestionStatus.SUCCESS - ] - - document_ids = [str(doc_id) for doc_id in document_ids] - - documents_overviews = await self.restructure_service.providers.database.relational.get_documents_overview( - filter_document_ids=document_ids - ) - documents_overviews = documents_overviews["results"] - - # Only run if restructuring_status is pending or failure - filtered_document_ids = [] - for document_overview in documents_overviews: - restructuring_status = document_overview.restructuring_status - if restructuring_status in [ - RestructureStatus.PENDING, - RestructureStatus.FAILURE, - RestructureStatus.ENRICHMENT_FAILURE, - ]: - filtered_document_ids.append(document_overview.id) - elif restructuring_status == RestructureStatus.SUCCESS: - logger.warning( - f"Graph already created for document ID: {document_overview.id}" - ) - elif restructuring_status == RestructureStatus.PROCESSING: - logger.warning( - f"Graph creation is already in progress for document ID: {document_overview.id}" - ) - elif restructuring_status == RestructureStatus.ENRICHED: - logger.warning( - f"Graph is already enriched for document ID: {document_overview.id}" - ) - else: - logger.warning( - f"Unknown restructuring status for document ID: {document_overview.id}" - ) - - results = [] - for document_id in filtered_document_ids: - logger.info( - f"Running Graph Creation Workflow for document ID: {document_id}" - ) - results.append( - ( - context.aio.spawn_workflow( - "kg-extract-and-store", - { - "request": { - "document_id": str(document_id), - "fragment_merge_count": kg_creation_settings.fragment_merge_count, - "max_knowledge_triples": 
kg_creation_settings.max_knowledge_triples, - "generation_config": kg_creation_settings.generation_config.to_dict(), - "entity_types": kg_creation_settings.entity_types, - "relation_types": kg_creation_settings.relation_types, - } - }, - key=f"kg-extract-and-store_{document_id}", - ) - ) - ) - - if not filtered_document_ids: - logger.info( - "No documents to process, either all graphs were created or in progress, or no documents were provided. Skipping graph creation." - ) - return {"result": "success"} - - logger.info(f"Ran {len(results)} workflows for graph creation") - results = await asyncio.gather(*results) - return {"result": "success"} - - -@r2r_hatchet.workflow(name="enrich-graph", timeout="60m") -class EnrichGraphWorkflow: - def __init__(self, restructure_service: RestructureService): - self.restructure_service = restructure_service - - @r2r_hatchet.step(retries=3, timeout="60m") - async def kg_node_creation(self, context: Context) -> dict: - input_data = context.workflow_input()["request"] - max_description_input_length = input_data[ - "max_description_input_length" - ] - await self.restructure_service.kg_node_creation( - max_description_input_length=max_description_input_length - ) - return {"result": None} - - @r2r_hatchet.step(retries=3, parents=["kg_node_creation"], timeout="60m") - async def kg_clustering(self, context: Context) -> dict: - input_data = context.workflow_input()["request"] - skip_clustering = input_data["skip_clustering"] - force_enrichment = input_data["force_enrichment"] - leiden_params = input_data["leiden_params"] - max_summary_input_length = input_data["max_summary_input_length"] - generation_config = GenerationConfig(**input_data["generation_config"]) - - # todo: check if documets are already being clustered - # check if any documents are still being restructured, need to explicitly set the force_clustering flag to true to run clustering if documents are still being restructured - - documents_overview = ( - await self.restructure_service.providers.database.relational.get_documents_overview() - ) - documents_overview = documents_overview["results"] - - if not force_enrichment: - if any( - document_overview.restructuring_status - == RestructureStatus.PROCESSING - for document_overview in documents_overview - ): - logger.error( - "Graph creation is still in progress for some documents, skipping enrichment, please set force_enrichment to true if you want to run enrichment anyway" - ) - return {"result": None} - - if any( - document_overview.restructuring_status - == RestructureStatus.ENRICHING - for document_overview in documents_overview - ): - logger.error( - "Graph enrichment is still in progress for some documents, skipping enrichment, please set force_enrichment to true if you want to run enrichment anyway" - ) - return {"result": None} - - for document_overview in documents_overview: - if document_overview.restructuring_status in [ - RestructureStatus.SUCCESS, - RestructureStatus.ENRICHMENT_FAILURE, - ]: - document_overview.restructuring_status = ( - RestructureStatus.ENRICHING - ) - - await self.restructure_service.providers.database.relational.upsert_documents_overview( - documents_overview - ) - - try: - if not skip_clustering: - results = await self.restructure_service.kg_clustering( - leiden_params, generation_config - ) - - result = results[0] - - # Run community summary workflows - workflows = [] - for level, community_id in result["intermediate_communities"]: - logger.info( - f"Running KG Community Summary Workflow for community ID: 
{community_id} at level {level}" - ) - workflows.append( - context.aio.spawn_workflow( - "kg-community-summary", - { - "request": { - "community_id": str(community_id), - "level": level, - "generation_config": generation_config.to_dict(), - "max_summary_input_length": max_summary_input_length, - } - }, - key=f"kg-community-summary_{community_id}_{level}", - ) - ) - - results = await asyncio.gather(*workflows) - else: - logger.info( - "Skipping Leiden clustering as skip_clustering is True, also skipping community summary workflows" - ) - return {"result": None} - - except Exception as e: - logger.error(f"Error in kg_clustering: {str(e)}", exc_info=True) - documents_overview = ( - await self.restructure_service.providers.database.relational.get_documents_overview() - ) - documents_overview = documents_overview["results"] - for document_overview in documents_overview: - if ( - document_overview.restructuring_status - == RestructureStatus.ENRICHING - ): - document_overview.restructuring_status = ( - RestructureStatus.ENRICHMENT_FAILURE - ) - await self.restructure_service.providers.database.relational.upsert_documents_overview( - document_overview - ) - logger.error( - f"Error in kg_clustering for document {document_overview.id}: {str(e)}" - ) - raise e - - finally: - - documents_overview = ( - await self.restructure_service.providers.database.relational.get_documents_overview() - ) - documents_overview = documents_overview["results"] - for document_overview in documents_overview: - if ( - document_overview.restructuring_status - == RestructureStatus.ENRICHING - ): - document_overview.restructuring_status = ( - RestructureStatus.ENRICHED - ) - - await self.restructure_service.providers.database.relational.upsert_documents_overview( - documents_overview - ) - return {"result": None} - - -@r2r_hatchet.workflow(name="kg-community-summary", timeout="60m") -class KGCommunitySummaryWorkflow: - def __init__(self, restructure_service: RestructureService): - self.restructure_service = restructure_service - - @r2r_hatchet.step(retries=1, timeout="60m") - async def kg_community_summary(self, context: Context) -> dict: - input_data = context.workflow_input()["request"] - community_id = input_data["community_id"] - level = input_data["level"] - generation_config = GenerationConfig(**input_data["generation_config"]) - max_summary_input_length = input_data["max_summary_input_length"] - await self.restructure_service.kg_community_summary( - community_id=community_id, - level=level, - max_summary_input_length=max_summary_input_length, - generation_config=generation_config, - ) - return {"result": None} diff --git a/py/core/main/orchestration/__init__.py b/py/core/main/orchestration/__init__.py new file mode 100644 index 000000000..7ae291985 --- /dev/null +++ b/py/core/main/orchestration/__init__.py @@ -0,0 +1,7 @@ +from .hatchet.ingestion_workflow import hatchet_ingestion_factory +from .hatchet.restructure_workflow import hatchet_restructure_workflow + +__all__ = [ + "hatchet_ingestion_factory", + "hatchet_restructure_workflow", +] diff --git a/py/core/main/orchestration/hatchet/__init__.py b/py/core/main/orchestration/hatchet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py new file mode 100644 index 000000000..0c00925ef --- /dev/null +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -0,0 +1,288 @@ +import asyncio +import logging +from typing import 
TYPE_CHECKING + +from hatchet_sdk import Context + +from core.base import IngestionStatus, OrchestrationProvider, increment_version +from core.base.abstractions import DocumentInfo, R2RException + +from ...services import IngestionService, IngestionServiceAdapter + +if TYPE_CHECKING: + from hatchet_sdk import Hatchet + +logger = logging.getLogger(__name__) + + +def hatchet_ingestion_factory( + orchestration_provider: OrchestrationProvider, service: IngestionService +) -> list["Hatchet.Workflow"]: + @orchestration_provider.workflow( + name="ingest-file", + timeout="60m", + ) + class HatchetIngestFilesWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def parse(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_file_input( + input_data + ) + + ingestion_result = ( + await self.ingestion_service.ingest_file_ingress(**parsed_data) + ) + + document_info = ingestion_result["info"] + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.PARSING, + ) + + extractions_generator = await self.ingestion_service.parse_file( + document_info + ) + + extractions = [] + async for extraction in extractions_generator: + extractions.append(extraction) + + serializable_extractions = [ + fragment.to_dict() for fragment in extractions + ] + + return { + "status": "Successfully extracted data", + "extractions": serializable_extractions, + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["parse"], timeout="60m") + async def chunk(self, context: Context) -> dict: + document_info_dict = context.step_output("parse")["document_info"] + document_info = DocumentInfo(**document_info_dict) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.CHUNKING, + ) + + extractions = context.step_output("parse")["extractions"] + chunking_config = context.workflow_input()["request"].get( + "chunking_config" + ) + + chunk_generator = await self.ingestion_service.chunk_document( + extractions, + chunking_config, + ) + + chunks = [] + async for chunk in chunk_generator: + chunks.append(chunk) + + serializable_chunks = [chunk.to_dict() for chunk in chunks] + + return { + "status": "Successfully chunked data", + "chunks": serializable_chunks, + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["chunk"], timeout="60m") + async def embed(self, context: Context) -> dict: + document_info_dict = context.step_output("chunk")["document_info"] + document_info = DocumentInfo(**document_info_dict) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.EMBEDDING, + ) + + chunks = context.step_output("chunk")["chunks"] + + embedding_generator = await self.ingestion_service.embed_document( + chunks + ) + + embeddings = [] + async for embedding in embedding_generator: + embeddings.append(embedding) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.STORING, + ) + + storage_generator = await self.ingestion_service.store_embeddings( # type: ignore + embeddings + ) + + async for _ in storage_generator: + pass + + return { + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["embed"], timeout="60m") + async def finalize(self, context: Context) -> dict: + 
document_info_dict = context.step_output("embed")["document_info"] + document_info = DocumentInfo(**document_info_dict) + + is_update = context.workflow_input()["request"].get("is_update") + + await self.ingestion_service.finalize_ingestion( + document_info, is_update=is_update + ) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.SUCCESS, + ) + + return { + "status": "Successfully finalized ingestion", + "document_info": document_info.to_dict(), + } + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + logger.error( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + documents_overview = await self.ingestion_service.providers.database.relational.get_documents_overview( + filter_document_ids=[document_id] + ) + + if not documents_overview: + logger.error( + f"Document with id {document_id} not found in database to mark failure." + ) + return + + document_info = documents_overview[0] + + # Update the document status to FAILURE + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.FAILURE, + ) + + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + + # TODO: Implement a check to see if the file is actually changed before updating + @orchestration_provider.workflow(name="update-files", timeout="60m") + class HatchetUpdateFilesWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(retries=0, timeout="60m") + async def update_files(self, context: Context) -> None: + data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_update_files_input( + data + ) + + file_datas = parsed_data["file_datas"] + user = parsed_data["user"] + document_ids = parsed_data["document_ids"] + metadatas = parsed_data["metadatas"] + chunking_config = parsed_data["chunking_config"] + file_sizes_in_bytes = parsed_data["file_sizes_in_bytes"] + + if not file_datas: + raise R2RException( + status_code=400, message="No files provided for update." 
+ ) + if len(document_ids) != len(file_datas): + raise R2RException( + status_code=400, + message="Number of ids does not match number of files.", + ) + + documents_overview = ( + await self.ingestion_service.providers.database.relational.get_documents_overview( + filter_document_ids=document_ids, + filter_user_ids=None if user.is_superuser else [user.id], + ) + )["results"] + if len(documents_overview) != len(document_ids): + raise R2RException( + status_code=404, + message="One or more documents not found.", + ) + + results = [] + + for idx, ( + file_data, + doc_id, + doc_info, + file_size_in_bytes, + ) in enumerate( + zip( + file_datas, + document_ids, + documents_overview, + file_sizes_in_bytes, + ) + ): + new_version = increment_version(doc_info.version) + + updated_metadata = ( + metadatas[idx] if metadatas else doc_info.metadata + ) + updated_metadata["title"] = ( + updated_metadata.get("title") + or file_data["filename"].split("/")[-1] + ) + + # Prepare input for ingest_file workflow + ingest_input = { + "file_data": file_data, + "user": data.get("user"), + "metadata": updated_metadata, + "document_id": str(doc_id), + "version": new_version, + "chunking_config": ( + chunking_config.model_dump_json() + if chunking_config + else None + ), + "size_in_bytes": file_size_in_bytes, + "is_update": True, + } + + # Spawn ingest_file workflow as a child workflow + child_result = ( + await context.aio.spawn_workflow( + "ingest-file", + {"request": ingest_input}, + key=f"ingest_file_{doc_id}", + ) + ).result() + results.append(child_result) + + await asyncio.gather(*results) + + return None + + ingest_files_workflow = HatchetIngestFilesWorkflow(service) + update_files_workflow = HatchetUpdateFilesWorkflow(service) + return [ingest_files_workflow, update_files_workflow] diff --git a/py/core/main/orchestration/hatchet/restructure_workflow.py b/py/core/main/orchestration/hatchet/restructure_workflow.py new file mode 100644 index 000000000..ccc85b85f --- /dev/null +++ b/py/core/main/orchestration/hatchet/restructure_workflow.py @@ -0,0 +1,376 @@ +import asyncio +import json +import logging +import uuid + +from hatchet_sdk import Context + +from core import GenerationConfig, IngestionStatus, KGCreationSettings +from core.base import OrchestrationProvider, R2RDocumentProcessingError +from core.base.abstractions import RestructureStatus + +from ...services import RestructureService + +logger = logging.getLogger(__name__) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from hatchet_sdk import Hatchet + + +def hatchet_restructure_workflow( + orchestration_provider: OrchestrationProvider, service: RestructureService +) -> list["Hatchet.Workflow"]: + @orchestration_provider.workflow( + name="kg-extract-and-store", timeout="60m" + ) + class KgExtractAndStoreWorkflow: + def __init__(self, restructure_service: RestructureService): + self.restructure_service = restructure_service + + @orchestration_provider.step(retries=3, timeout="60m") + async def kg_extract_and_store(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + document_id = uuid.UUID(input_data["document_id"]) + fragment_merge_count = input_data["fragment_merge_count"] + max_knowledge_triples = input_data["max_knowledge_triples"] + entity_types = input_data["entity_types"] + relation_types = input_data["relation_types"] + + document_overview = await self.restructure_service.providers.database.relational.get_documents_overview( + filter_document_ids=[document_id] + ) + document_overview = 
document_overview["results"][0] + + try: + + # Set restructure status to 'processing' + document_overview.restructuring_status = ( + RestructureStatus.PROCESSING + ) + + await self.restructure_service.providers.database.relational.upsert_documents_overview( + document_overview + ) + + errors = await self.restructure_service.kg_extract_and_store( + document_id=document_id, + generation_config=GenerationConfig( + **input_data["generation_config"] + ), + fragment_merge_count=fragment_merge_count, + max_knowledge_triples=max_knowledge_triples, + entity_types=entity_types, + relation_types=relation_types, + ) + # Set restructure status to 'success' if completed successfully + if len(errors) == 0: + document_overview.restructuring_status = ( + RestructureStatus.SUCCESS + ) + await self.restructure_service.providers.database.relational.upsert_documents_overview( + document_overview + ) + else: + + document_overview.restructuring_status = ( + RestructureStatus.FAILURE + ) + await self.restructure_service.providers.database.relational.upsert_documents_overview( + document_overview + ) + raise R2RDocumentProcessingError( + error_message=f"Error in kg_extract_and_store, list of errors: {errors}", + document_id=document_id, + ) + + except Exception as e: + # Set restructure status to 'failure' if an error occurred + document_overview.restructuring_status = ( + RestructureStatus.FAILURE + ) + await self.restructure_service.providers.database.relational.upsert_documents_overview( + document_overview + ) + raise R2RDocumentProcessingError( + error_message=e, + document_id=document_id, + ) + + return {"result": None} + + @orchestration_provider.workflow(name="create-graph", timeout="60m") + class CreateGraphWorkflow: + def __init__(self, restructure_service: RestructureService): + self.restructure_service = restructure_service + + @orchestration_provider.step(retries=1) + async def kg_extraction_ingress(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + + kg_creation_settings = KGCreationSettings( + **json.loads(input_data["kg_creation_settings"]) + ) + + documents_overview = ( + await self.restructure_service.providers.database.relational.get_documents_overview() + ) + documents_overview = documents_overview["results"] + + document_ids = [ + doc.id + for doc in documents_overview + if doc.restructuring_status != IngestionStatus.SUCCESS + ] + + document_ids = [str(doc_id) for doc_id in document_ids] + + documents_overviews = await self.restructure_service.providers.database.relational.get_documents_overview( + filter_document_ids=document_ids + ) + documents_overviews = documents_overviews["results"] + + # Only run if restructuring_status is pending or failure + filtered_document_ids = [] + for document_overview in documents_overviews: + restructuring_status = document_overview.restructuring_status + if restructuring_status in [ + RestructureStatus.PENDING, + RestructureStatus.FAILURE, + RestructureStatus.ENRICHMENT_FAILURE, + ]: + filtered_document_ids.append(document_overview.id) + elif restructuring_status == RestructureStatus.SUCCESS: + logger.warning( + f"Graph already created for document ID: {document_overview.id}" + ) + elif restructuring_status == RestructureStatus.PROCESSING: + logger.warning( + f"Graph creation is already in progress for document ID: {document_overview.id}" + ) + elif restructuring_status == RestructureStatus.ENRICHED: + logger.warning( + f"Graph is already enriched for document ID: {document_overview.id}" + ) + else: + logger.warning( + 
f"Unknown restructuring status for document ID: {document_overview.id}" + ) + + results = [] + for document_id in filtered_document_ids: + logger.info( + f"Running Graph Creation Workflow for document ID: {document_id}" + ) + results.append( + ( + context.aio.spawn_workflow( + "kg-extract-and-store", + { + "request": { + "document_id": str(document_id), + "fragment_merge_count": kg_creation_settings.fragment_merge_count, + "max_knowledge_triples": kg_creation_settings.max_knowledge_triples, + "generation_config": kg_creation_settings.generation_config.to_dict(), + "entity_types": kg_creation_settings.entity_types, + "relation_types": kg_creation_settings.relation_types, + } + }, + key=f"kg-extract-and-store_{document_id}", + ) + ) + ) + + if not filtered_document_ids: + logger.info( + "No documents to process, either all graphs were created or in progress, or no documents were provided. Skipping graph creation." + ) + return {"result": "success"} + + logger.info(f"Ran {len(results)} workflows for graph creation") + results = await asyncio.gather(*results) + return {"result": "success"} + + @orchestration_provider.workflow(name="enrich-graph", timeout="60m") + class EnrichGraphWorkflow: + def __init__(self, restructure_service: RestructureService): + self.restructure_service = restructure_service + + @orchestration_provider.step(retries=3, timeout="60m") + async def kg_node_creation(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + max_description_input_length = input_data[ + "max_description_input_length" + ] + await self.restructure_service.kg_node_creation( + max_description_input_length=max_description_input_length + ) + return {"result": None} + + @orchestration_provider.step( + retries=3, parents=["kg_node_creation"], timeout="60m" + ) + async def kg_clustering(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + skip_clustering = input_data["skip_clustering"] + force_enrichment = input_data["force_enrichment"] + leiden_params = input_data["leiden_params"] + max_summary_input_length = input_data["max_summary_input_length"] + generation_config = GenerationConfig( + **input_data["generation_config"] + ) + + # todo: check if documets are already being clustered + # check if any documents are still being restructured, need to explicitly set the force_clustering flag to true to run clustering if documents are still being restructured + + documents_overview = ( + await self.restructure_service.providers.database.relational.get_documents_overview() + ) + documents_overview = documents_overview["results"] + + if not force_enrichment: + if any( + document_overview.restructuring_status + == RestructureStatus.PROCESSING + for document_overview in documents_overview + ): + logger.error( + "Graph creation is still in progress for some documents, skipping enrichment, please set force_enrichment to true if you want to run enrichment anyway" + ) + return {"result": None} + + if any( + document_overview.restructuring_status + == RestructureStatus.ENRICHING + for document_overview in documents_overview + ): + logger.error( + "Graph enrichment is still in progress for some documents, skipping enrichment, please set force_enrichment to true if you want to run enrichment anyway" + ) + return {"result": None} + + for document_overview in documents_overview: + if document_overview.restructuring_status in [ + RestructureStatus.SUCCESS, + RestructureStatus.ENRICHMENT_FAILURE, + ]: + document_overview.restructuring_status = ( + 
RestructureStatus.ENRICHING + ) + + await self.restructure_service.providers.database.relational.upsert_documents_overview( + documents_overview + ) + + try: + if not skip_clustering: + results = await self.restructure_service.kg_clustering( + leiden_params, generation_config + ) + + result = results[0] + + # Run community summary workflows + workflows = [] + for level, community_id in result[ + "intermediate_communities" + ]: + logger.info( + f"Running KG Community Summary Workflow for community ID: {community_id} at level {level}" + ) + workflows.append( + context.aio.spawn_workflow( + "kg-community-summary", + { + "request": { + "community_id": str(community_id), + "level": level, + "generation_config": generation_config.to_dict(), + "max_summary_input_length": max_summary_input_length, + } + }, + key=f"kg-community-summary_{community_id}_{level}", + ) + ) + + results = await asyncio.gather(*workflows) + else: + logger.info( + "Skipping Leiden clustering as skip_clustering is True, also skipping community summary workflows" + ) + return {"result": None} + + except Exception as e: + logger.error( + f"Error in kg_clustering: {str(e)}", exc_info=True + ) + documents_overview = ( + await self.restructure_service.providers.database.relational.get_documents_overview() + ) + documents_overview = documents_overview["results"] + for document_overview in documents_overview: + if ( + document_overview.restructuring_status + == RestructureStatus.ENRICHING + ): + document_overview.restructuring_status = ( + RestructureStatus.ENRICHMENT_FAILURE + ) + await self.restructure_service.providers.database.relational.upsert_documents_overview( + document_overview + ) + logger.error( + f"Error in kg_clustering for document {document_overview.id}: {str(e)}" + ) + raise e + + finally: + + documents_overview = ( + await self.restructure_service.providers.database.relational.get_documents_overview() + ) + documents_overview = documents_overview["results"] + for document_overview in documents_overview: + if ( + document_overview.restructuring_status + == RestructureStatus.ENRICHING + ): + document_overview.restructuring_status = ( + RestructureStatus.ENRICHED + ) + + await self.restructure_service.providers.database.relational.upsert_documents_overview( + documents_overview + ) + return {"result": None} + + @orchestration_provider.workflow( + name="kg-community-summary", timeout="60m" + ) + class KGCommunitySummaryWorkflow: + def __init__(self, restructure_service: RestructureService): + self.restructure_service = restructure_service + + @orchestration_provider.step(retries=1, timeout="60m") + async def kg_community_summary(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + community_id = input_data["community_id"] + level = input_data["level"] + generation_config = GenerationConfig( + **input_data["generation_config"] + ) + max_summary_input_length = input_data["max_summary_input_length"] + await self.restructure_service.kg_community_summary( + community_id=community_id, + level=level, + max_summary_input_length=max_summary_input_length, + generation_config=generation_config, + ) + return {"result": None} + + return [ + KgExtractAndStoreWorkflow(service), + CreateGraphWorkflow(service), + EnrichGraphWorkflow(service), + KGCommunitySummaryWorkflow(service), + ] diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index f19c3be0f..a0a3ac276 100644 --- a/py/core/main/services/ingestion_service.py +++ 
b/py/core/main/services/ingestion_service.py @@ -289,7 +289,6 @@ def _parse_user_data(user_data) -> UserResponse: @staticmethod def parse_ingest_file_input(data: dict) -> dict: - print('data["chunking_config"] = ', data["chunking_config"]) return { "user": IngestionServiceAdapter._parse_user_data(data["user"]), "metadata": data["metadata"], diff --git a/py/core/providers/orchestration/hatchet.py b/py/core/providers/orchestration/hatchet.py index 13c89b349..ae7e3ff11 100644 --- a/py/core/providers/orchestration/hatchet.py +++ b/py/core/providers/orchestration/hatchet.py @@ -1,25 +1,30 @@ import asyncio from typing import Any, Callable, Optional -from hatchet_sdk import Hatchet - -from core.base import OrchestrationConfig, OrchestrationProvider +from core.base import OrchestrationConfig, OrchestrationProvider, Workflow class HatchetOrchestrationProvider(OrchestrationProvider): def __init__(self, config: OrchestrationConfig): super().__init__(config) + try: + from hatchet_sdk import Hatchet + except ImportError: + raise ImportError( + "Hatchet SDK not installed. Please install it using `pip install hatchet-sdk`." + ) self.orchestrator = Hatchet() self.config: OrchestrationConfig = config # for type hinting self.worker - def register_workflow(self, workflow: Any) -> None: - if self.worker: - self.worker.register_workflow(workflow) - else: - raise ValueError( - "Worker not initialized. Call get_worker() first." - ) + def workflow(self, *args, **kwargs) -> Callable: + return self.orchestrator.workflow(*args, **kwargs) + + def step(self, *args, **kwargs) -> Callable: + return self.orchestrator.step(*args, **kwargs) + + def failure(self, *args, **kwargs) -> Callable: + return self.orchestrator.on_failure_step(*args, **kwargs) def get_worker(self, name: str, max_threads: Optional[int] = None) -> Any: if not max_threads: @@ -27,12 +32,6 @@ def get_worker(self, name: str, max_threads: Optional[int] = None) -> Any: self.worker = self.orchestrator.worker(name, max_threads) return self.worker - def workflow(self, *args, **kwargs) -> Callable: - return self.orchestrator.workflow(*args, **kwargs) - - def step(self, *args, **kwargs) -> Callable: - return self.orchestrator.step(*args, **kwargs) - async def start_worker(self): if not self.worker: raise ValueError( @@ -40,3 +39,40 @@ async def start_worker(self): ) asyncio.create_task(self.worker.async_start()) + + def run_workflow( + self, + workflow_name: str, + parameters: dict, + options: dict, + *args, + **kwargs, + ) -> Any: + self.orchestrator.admin.run_workflow( + workflow_name, + parameters, + options=options, + *args, + **kwargs, + ) + + def register_workflows(self, workflow: Workflow, service: Any) -> None: + if workflow == Workflow.INGESTION: + from core.main.orchestration.hatchet.ingestion_workflow import ( + hatchet_ingestion_factory, + ) + + workflows = hatchet_ingestion_factory(self, service) + if self.worker: + for workflow in workflows: + self.worker.register_workflow(workflow) + + elif workflow == Workflow.RESTRUCTURE: + from core.main.orchestration.hatchet.restructure_workflow import ( + hatchet_restructure_workflow, + ) + + workflows = hatchet_restructure_workflow(self, service) + if self.worker: + for workflow in workflows: + self.worker.register_workflow(workflow)
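
Note (illustrative, not part of the patch): the contract this diff factors out of the Hatchet-specific code is small — `Workflow` names a bundle of workflows, the `workflow`/`step`/`failure` decorators plus `register_workflows` let a backend decide how those bundles are built (compare `hatchet_ingestion_factory` and `hatchet_restructure_workflow`), and routers only ever call `register_workflows` and `run_workflow`. The sketch below shows what another backend might look like against that interface. The class name `InProcessOrchestrationProvider`, its `registered` list, and its echo-style `run_workflow` return value are hypothetical; the concrete implementation introduced by this patch is `HatchetOrchestrationProvider`.

```python
from typing import Any, Callable

from core.base import OrchestrationConfig, OrchestrationProvider, Workflow


class InProcessOrchestrationProvider(OrchestrationProvider):
    """Hypothetical stand-in for HatchetOrchestrationProvider with no task queue."""

    def __init__(self, config: OrchestrationConfig):
        super().__init__(config)
        # Record which workflow bundles were wired to which service.
        self.registered: list[tuple[Workflow, Any]] = []

    # The decorator hooks exist so workflow modules stay backend-agnostic;
    # in this sketch they return the decorated object unchanged.
    def workflow(self, *args, **kwargs) -> Callable:
        return lambda cls: cls

    def step(self, *args, **kwargs) -> Callable:
        return lambda fn: fn

    def failure(self, *args, **kwargs) -> Callable:
        return lambda fn: fn

    def get_worker(self, name: str, max_threads: int = 1) -> Any:
        self.worker = name  # no real worker pool in this sketch
        return self.worker

    async def start_worker(self) -> None:
        return None  # nothing to start without an external queue

    def register_workflows(self, workflow: Workflow, service: Any) -> None:
        # HatchetOrchestrationProvider instantiates the classes returned by
        # hatchet_ingestion_factory / hatchet_restructure_workflow here; the
        # sketch only records the request.
        self.registered.append((workflow, service))

    def run_workflow(
        self,
        workflow_name: str,
        parameters: dict,
        options: dict,
        *args,
        **kwargs,
    ) -> Any:
        # A real backend would enqueue the job and return a task handle; the
        # sketch echoes the submission so callers can see the contract.
        return {
            "workflow": workflow_name,
            "parameters": parameters,
            "options": options,
        }
```

With this split, swapping orchestration backends (or stubbing one out in tests) only requires another `OrchestrationProvider` subclass; after this patch no router or service module imports `hatchet_sdk` or the removed `r2r_hatchet` singleton directly.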