Skip to content

Commit

Permalink
Feature/encapsulate orchestration (#1265)
Browse files (browse the repository at this point in the history)
* fully encapsulate orchestration

* fully encapsulate orchestration

* complete encapsulation

* revert import cmt
  • Loading branch information
emrgnt-cmplxty authored Sep 25, 2024
1 parent db478ed commit fdbd50f
Show file tree
Hide file tree
Showing 21 changed files with 789 additions and 784 deletions.
3 changes: 2 additions & 1 deletion py/core/base/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .file import FileConfig, FileProvider
from .kg import KGConfig, KGProvider
from .llm import CompletionConfig, CompletionProvider
from .orchestration import OrchestrationConfig, OrchestrationProvider
from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow
from .parsing import OverrideParser, ParsingConfig, ParsingProvider
from .prompt import PromptConfig, PromptProvider

Expand Down Expand Up @@ -55,6 +55,7 @@
# Orchestration provider
"OrchestrationConfig",
"OrchestrationProvider",
"Workflow",
# Parsing provider
"ParsingConfig",
"ParsingProvider",
Expand Down
29 changes: 25 additions & 4 deletions py/core/base/providers/orchestration.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from abc import abstractmethod
from enum import Enum
from typing import Any, Callable

from .base import Provider, ProviderConfig


# Enumerates the workflow groups that can be registered with an orchestration
# provider; a member is passed as the first argument to
# OrchestrationProvider.register_workflows(workflow, service).
class Workflow(Enum):
# Passed by IngestionRouter._register_workflows together with its service.
INGESTION = "ingestion"
# Passed by RestructureRouter._register_workflows together with its service.
RESTRUCTURE = "restructure"


class OrchestrationConfig(ProviderConfig):
provider: str
max_threads: int = 256
Expand All @@ -24,21 +30,36 @@ def __init__(self, config: OrchestrationConfig):
self.worker = None

@abstractmethod
def register_workflow(self, workflow: Any) -> None:
async def start_worker(self):
pass

@abstractmethod
def get_worker(self, name: str, max_threads: int) -> Any:
pass

@abstractmethod
def workflow(self, *args, **kwargs) -> Callable:
def step(self, *args, **kwargs) -> Any:
pass

@abstractmethod
def workflow(self, *args, **kwargs) -> Any:
pass

@abstractmethod
def failure(self, *args, **kwargs) -> Any:
pass

@abstractmethod
def step(self, *args, **kwargs) -> Callable:
def register_workflows(self, workflow: Workflow, service: Any) -> None:
pass

@abstractmethod
def start_worker(self):
def run_workflow(
self,
workflow_name: str,
parameters: dict,
options: dict,
*args,
**kwargs,
) -> Any:
pass
6 changes: 1 addition & 5 deletions py/core/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# from .app_entry import r2r_app
from .assembly import *
from .hatchet import *
from .orchestration import *
from .services import *

__all__ = [
Expand All @@ -22,10 +22,6 @@
"RestructureRouter",
## R2R APP
"R2RApp",
## R2R APP ENTRY
# "r2r_app",
## R2R HATCHET
"r2r_hatchet",
## R2R ASSEMBLY
# Builder
"R2RBuilder",
Expand Down
10 changes: 5 additions & 5 deletions py/core/main/api/auth_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
class AuthRouter(BaseRouter):
def __init__(
self,
auth_service: AuthService,
run_type: RunType = RunType.INGESTION,
orchestration_provider: Optional[OrchestrationProvider] = None,
service: AuthService,
orchestration_provider: OrchestrationProvider,
run_type: RunType = RunType.UNSPECIFIED,
):
super().__init__(auth_service, run_type, orchestration_provider)
self.service: AuthService = auth_service # for type hinting
super().__init__(service, orchestration_provider, run_type)
self.service: AuthService = service # for type hinting

def _register_workflows(self):
pass
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/api/base_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class BaseRouter:
def __init__(
self,
service: "Service",
orchestration_provider: OrchestrationProvider,
run_type: RunType = RunType.UNSPECIFIED,
orchestration_provider: Optional[OrchestrationProvider] = None,
):
self.service = service
self.run_type = run_type
Expand Down
69 changes: 8 additions & 61 deletions py/core/main/api/ingestion_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
WrappedIngestionResponse,
WrappedUpdateResponse,
)
from core.base.providers import OrchestrationProvider
from core.base.providers import OrchestrationProvider, Workflow

from ...main.hatchet import r2r_hatchet
from ..hatchet import IngestFilesWorkflow, UpdateFilesWorkflow
from ..services.ingestion_service import IngestionService
from .base_router import BaseRouter, RunType

Expand All @@ -29,22 +27,15 @@ class IngestionRouter(BaseRouter):
def __init__(
self,
service: IngestionService,
orchestration_provider: OrchestrationProvider,
run_type: RunType = RunType.INGESTION,
orchestration_provider: Optional[OrchestrationProvider] = None,
):
if not orchestration_provider:
raise ValueError(
"IngestionRouter requires an orchestration provider."
)
super().__init__(service, run_type, orchestration_provider)
super().__init__(service, orchestration_provider, run_type)
self.service: IngestionService = service

def _register_workflows(self):
self.orchestration_provider.register_workflow(
IngestFilesWorkflow(self.service)
)
self.orchestration_provider.register_workflow(
UpdateFilesWorkflow(self.service)
self.orchestration_provider.register_workflows(
Workflow.INGESTION, self.service
)

def _load_openapi_extras(self):
Expand Down Expand Up @@ -146,7 +137,7 @@ async def ingest_files_app(
file_data["content_type"],
)

task_id = r2r_hatchet.admin.run_workflow(
task_id = self.orchestration_provider.run_workflow(
"ingest-file",
{"request": workflow_input},
options={
Expand All @@ -165,50 +156,6 @@ async def ingest_files_app(
)
return messages

@self.router.post(
"/retry_ingest_files",
openapi_extra=ingest_files_extras.get("openapi_extra"),
)
@self.base_endpoint
async def retry_ingest_files(
document_ids: list[UUID] = Form(
...,
description=ingest_files_descriptions.get("document_ids"),
),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
response_model=WrappedIngestionResponse,
):
"""
Retry the ingestion of files into the system.
This endpoint allows you to retry the ingestion of files that have previously failed to ingest into R2R.
A valid user authentication token is required to access this endpoint, as regular users can only retry the ingestion of their own files. More expansive collection permissioning is under development.
"""
if not auth_user.is_superuser:
documents_overview = await self.service.providers.database.relational.get_documents_overview(
filter_document_ids=document_ids,
filter_user_ids=[auth_user.id],
)[
"results"
]
if len(documents_overview) != len(document_ids):
raise R2RException(
status_code=404,
message="One or more documents not found.",
)

# FIXME: This is throwing an aiohttp.client_exceptions.ClientConnectionError: Cannot connect to host localhost:8080 ssl:default… can we whitelist the host?
workflow_list = await r2r_hatchet.rest.workflow_run_list()

# TODO: we want to extract the hatchet run ids for the document ids, and then retry them

return {
"message": "Retry tasks queued successfully.",
"task_ids": [str(task_id) for task_id in workflow_list],
"document_ids": [str(doc_id) for doc_id in document_ids],
}

update_files_extras = self.openapi_extras.get("update_files", {})
update_files_descriptions = update_files_extras.get(
"input_descriptions", {}
Expand Down Expand Up @@ -301,8 +248,8 @@ async def update_files_app(
"is_update": True,
}

task_id = r2r_hatchet.admin.run_workflow(
"update-files", {"request": workflow_input}
task_id = self.orchestration_provider.run_workflow(
"update-files", {"request": workflow_input}, {}
)

return {
Expand Down
4 changes: 2 additions & 2 deletions py/core/main/api/management_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ class ManagementRouter(BaseRouter):
def __init__(
self,
service: ManagementService,
orchestration_provider: OrchestrationProvider,
run_type: RunType = RunType.MANAGEMENT,
orchestration_provider: Optional[OrchestrationProvider] = None,
):
super().__init__(service, run_type, orchestration_provider)
super().__init__(service, orchestration_provider, run_type)
self.service: ManagementService = service # for type hinting
self.start_time = datetime.now(timezone.utc)

Expand Down
40 changes: 10 additions & 30 deletions py/core/main/api/restructure_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,8 @@
WrappedKGCreationResponse,
WrappedKGEnrichmentResponse,
)
from core.base.providers import OrchestrationProvider

from ...main.hatchet import r2r_hatchet
from ..hatchet import (
CreateGraphWorkflow,
EnrichGraphWorkflow,
KGCommunitySummaryWorkflow,
KgExtractAndStoreWorkflow,
)
from core.base.providers import OrchestrationProvider, Workflow

from ..services.restructure_service import RestructureService
from .base_router import BaseRouter, RunType

Expand All @@ -30,14 +23,10 @@ class RestructureRouter(BaseRouter):
def __init__(
self,
service: RestructureService,
orchestration_provider: OrchestrationProvider,
run_type: RunType = RunType.RESTRUCTURE,
orchestration_provider: Optional[OrchestrationProvider] = None,
):
if not orchestration_provider:
raise ValueError(
"RestructureRouter requires an orchestration provider."
)
super().__init__(service, run_type, orchestration_provider)
super().__init__(service, orchestration_provider, run_type)
self.service: RestructureService = service

def _load_openapi_extras(self):
Expand All @@ -49,17 +38,8 @@ def _load_openapi_extras(self):
return yaml_content

def _register_workflows(self):
self.orchestration_provider.register_workflow(
EnrichGraphWorkflow(self.service)
)
self.orchestration_provider.register_workflow(
KgExtractAndStoreWorkflow(self.service)
)
self.orchestration_provider.register_workflow(
CreateGraphWorkflow(self.service)
)
self.orchestration_provider.register_workflow(
KGCommunitySummaryWorkflow(self.service)
self.orchestration_provider.register_workflows(
Workflow.RESTRUCTURE, self.service
)

def _setup_routes(self):
Expand Down Expand Up @@ -102,8 +82,8 @@ async def create_graph(
"user": auth_user.json(),
}

task_id = r2r_hatchet.admin.run_workflow(
"create-graph", {"request": workflow_input}
task_id = self.orchestration_provider.run_workflow(
"create-graph", {"request": workflow_input}, {}
)

return {
Expand Down Expand Up @@ -157,8 +137,8 @@ async def enrich_graph(
"user": auth_user.json(),
}

task_id = r2r_hatchet.admin.run_workflow(
"enrich-graph", {"request": workflow_input}
task_id = self.orchestration_provider.run_workflow(
"enrich-graph", {"request": workflow_input}, {}
)

return {
Expand Down
4 changes: 2 additions & 2 deletions py/core/main/api/retrieval_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ class RetrievalRouter(BaseRouter):
def __init__(
self,
service: RetrievalService,
orchestration_provider: OrchestrationProvider,
run_type: RunType = RunType.RETRIEVAL,
orchestration_provider: Optional[OrchestrationProvider] = None,
):
super().__init__(service, run_type, orchestration_provider)
super().__init__(service, orchestration_provider, run_type)
self.service: RetrievalService = service # for type hinting

def _load_openapi_extras(self):
Expand Down
10 changes: 7 additions & 3 deletions py/core/main/assembly/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,20 @@ async def build(self, *args, **kwargs) -> R2RApp:
orchestration_provider = providers.orchestration

routers = {
"auth_router": AuthRouter(services["auth"]).get_router(),
"auth_router": AuthRouter(
services["auth"], orchestration_provider=orchestration_provider
).get_router(),
"ingestion_router": IngestionRouter(
services["ingestion"],
orchestration_provider=orchestration_provider,
).get_router(),
"management_router": ManagementRouter(
services["management"]
services["management"],
orchestration_provider=orchestration_provider,
).get_router(),
"retrieval_router": RetrievalRouter(
services["retrieval"]
services["retrieval"],
orchestration_provider=orchestration_provider,
).get_router(),
"restructure_router": RestructureRouter(
services["restructure"],
Expand Down
4 changes: 3 additions & 1 deletion py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ async def create_auth_provider(
if auth_config.provider == "r2r":
from core.providers import R2RAuthProvider

r2r_auth = R2RAuthProvider(auth_config, crypto_provider, db_provider)
r2r_auth = R2RAuthProvider(
auth_config, crypto_provider, db_provider
)
await r2r_auth.initialize()
return r2r_auth
elif auth_config.provider == "supabase":
Expand Down
18 changes: 0 additions & 18 deletions py/core/main/hatchet/__init__.py

This file was deleted.

6 changes: 0 additions & 6 deletions py/core/main/hatchet/base.py

This file was deleted.

Loading

0 comments on commit fdbd50f

Please sign in to comment.