From fa6000e12d458587d53bf49da13941280d1182e7 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:22:56 -0700 Subject: [PATCH 1/3] Corrections for KG SDK usage --- py/cli/commands/kg.py | 23 ++++++++++------------ py/core/main/api/kg_router.py | 32 ++++++++++++++++++++----------- py/sdk/kg.py | 36 +++++++++++++++++------------------ py/shared/abstractions/kg.py | 5 +++-- 4 files changed, 51 insertions(+), 45 deletions(-) diff --git a/py/cli/commands/kg.py b/py/cli/commands/kg.py index 26e613129..bdd9ff22b 100644 --- a/py/cli/commands/kg.py +++ b/py/cli/commands/kg.py @@ -30,12 +30,9 @@ help="Force the graph creation process.", ) @pass_context -def create_graph( +async def create_graph( ctx, collection_id, run, kg_creation_settings, force_kg_creation ): - """ - Create a new graph. - """ client = ctx.obj if kg_creation_settings: @@ -46,18 +43,19 @@ def create_graph( "Error: kg-creation-settings must be a valid JSON string" ) return - - if not run: - run_type = "estimate" else: - run_type = "run" + kg_creation_settings = {} + + run_type = "run" if run else "estimate" if force_kg_creation: kg_creation_settings = {"force_kg_creation": True} with timer(): response = client.create_graph( - collection_id, run_type, kg_creation_settings + collection_id=collection_id, + run_type=run_type, + kg_creation_settings=kg_creation_settings, ) click.echo(json.dumps(response, indent=2)) @@ -102,11 +100,10 @@ def enrich_graph( "Error: kg-enrichment-settings must be a valid JSON string" ) return - - if not run: - run_type = "estimate" else: - run_type = "run" + kg_enrichment_settings = {} + + run_type = "run" if run else "estimate" if force_kg_enrichment: kg_enrichment_settings = {"force_kg_enrichment": True} diff --git a/py/core/main/api/kg_router.py b/py/core/main/api/kg_router.py index 081ad3a90..227d870ce 100644 --- a/py/core/main/api/kg_router.py +++ b/py/core/main/api/kg_router.py @@ -1,11 +1,13 @@ import logging from pathlib import Path -from typing import Optional +from typing import Optional, Union from uuid import UUID import yaml from fastapi import Body, Depends, Query -from pydantic import Json +from pydantic import Json, ValidationError + +from core import R2RException from core.base import RunType from core.base.api.models import ( @@ -14,7 +16,7 @@ ) from core.base.providers import OrchestrationProvider, Workflow from core.utils import generate_default_user_collection_id -from shared.abstractions.kg import KGRunType +from shared.abstractions.kg import KGRunType, KGCreationSettings from ..services.kg_service import KgService from .base_router import BaseRouter @@ -61,10 +63,10 @@ async def create_graph( description="Collection ID to create graph for.", ), run_type: Optional[KGRunType] = Body( - default=KGRunType.ESTIMATE, + default=None, description="Run type for the graph creation process.", ), - kg_creation_settings: Optional[Json[dict]] = Body( + kg_creation_settings: Optional[dict] = Body( default=None, description="Settings for the graph creation process.", ), @@ -76,10 +78,12 @@ async def create_graph( This step extracts the relevant entities and relationships from the documents and creates a graph based on the extracted information. In order to do GraphRAG, you will need to run the enrich_graph endpoint. """ - if not auth_user.is_superuser: logger.warning("Implement permission checks here.") + if not run_type: + run_type = KGRunType.ESTIMATE + if not collection_id: collection_id = generate_default_user_collection_id( auth_user.id @@ -102,7 +106,7 @@ async def create_graph( ) workflow_input = { - "collection_id": collection_id, + "collection_id": str(collection_id), "kg_creation_settings": server_kg_creation_settings.model_dump_json(), "user": auth_user.json(), } @@ -124,7 +128,7 @@ async def enrich_graph( default=KGRunType.ESTIMATE, description="Run type for the graph enrichment process.", ), - kg_enrichment_settings: Optional[Json[dict]] = Body( + kg_enrichment_settings: Optional[dict] = Body( default=None, description="Settings for the graph enrichment process.", ), @@ -134,10 +138,12 @@ async def enrich_graph( """ This endpoint enriches the graph with additional information. It creates communities of nodes based on their similarity and adds embeddings to the graph. This step is necessary for GraphRAG to work. """ - if not auth_user.is_superuser: logger.warning("Implement permission checks here.") + if not run_type: + run_type = KGRunType.ESTIMATE + server_kg_enrichment_settings = ( self.service.providers.kg.config.kg_enrichment_settings ) @@ -147,9 +153,13 @@ async def enrich_graph( auth_user.id ) + logger.info(f"Running on collection {collection_id}") + if run_type is KGRunType.ESTIMATE: - return await self.service.get_enrichment_estimate(collection_id, server_kg_enrichment_settings) + return await self.service.get_enrichment_estimate( + collection_id, server_kg_enrichment_settings + ) if kg_enrichment_settings: for key, value in kg_enrichment_settings.items(): @@ -157,7 +167,7 @@ async def enrich_graph( setattr(server_kg_enrichment_settings, key, value) workflow_input = { - "collection_id": collection_id, + "collection_id": str(collection_id), "kg_enrichment_settings": server_kg_enrichment_settings.model_dump_json(), "user": auth_user.json(), } diff --git a/py/sdk/kg.py b/py/sdk/kg.py index 9ae1f1372..dc266f34c 100644 --- a/py/sdk/kg.py +++ b/py/sdk/kg.py @@ -16,33 +16,34 @@ class KGMethods: @staticmethod async def create_graph( client, - collection_id: Optional[UUID] = None, - run_type: KGRunType = KGRunType.ESTIMATE, + collection_id: Optional[Union[UUID, str]] = None, + run_type: Optional[Union[str, KGRunType]] = None, kg_creation_settings: Optional[Union[dict, KGCreationSettings]] = None, ) -> KGCreationResponse: """ Create a graph from the given settings. + + Args: + collection_id (Optional[Union[UUID, str]]): The ID of the collection to create the graph for. + run_type (Optional[Union[str, KGRunType]]): The type of run to perform. + kg_creation_settings (Optional[Union[dict, KGCreationSettings]]): Settings for the graph creation process. """ if isinstance(kg_creation_settings, KGCreationSettings): kg_creation_settings = kg_creation_settings.model_dump() - elif kg_creation_settings is None or kg_creation_settings == "{}": - kg_creation_settings = {} data = { - "run_type": run_type, - "kg_creation_settings": json.dumps(kg_creation_settings), + "collection_id": str(collection_id) if collection_id else None, + "run_type": str(run_type) if run_type else None, + "kg_creation_settings": kg_creation_settings or {}, } - if collection_id: - data["collection_id"] = collection_id - return await client._make_request("POST", "create_graph", json=data) @staticmethod async def enrich_graph( client, - collection_id: Optional[UUID] = None, - run_type: KGRunType = KGRunType.ESTIMATE, + collection_id: Optional[Union[UUID, str]] = None, + run_type: Optional[Union[str, KGRunType]] = None, kg_enrichment_settings: Optional[ Union[dict, KGEnrichmentSettings] ] = None, @@ -51,24 +52,21 @@ async def enrich_graph( Perform graph enrichment over the entire graph. Args: - collection_id (str): The ID of the collection to enrich. + collection_id (Optional[Union[UUID, str]]): The ID of the collection to enrich the graph for. + run_type (Optional[Union[str, KGRunType]]): The type of run to perform. kg_enrichment_settings (Optional[Union[dict, KGEnrichmentSettings]]): Settings for the graph enrichment process. Returns: KGEnrichmentResponse: Results of the graph enrichment process. """ if isinstance(kg_enrichment_settings, KGEnrichmentSettings): kg_enrichment_settings = kg_enrichment_settings.model_dump() - elif kg_enrichment_settings is None or kg_enrichment_settings == "{}": - kg_enrichment_settings = {} data = { - "kg_enrichment_settings": json.dumps(kg_enrichment_settings), - "run_type": run_type, + "collection_id": str(collection_id) if collection_id else None, + "run_type": str(run_type) if run_type else None, + "kg_enrichment_settings": kg_enrichment_settings or {}, } - if collection_id: - data["collection_id"] = collection_id - return await client._make_request("POST", "enrich_graph", json=data) @staticmethod diff --git a/py/shared/abstractions/kg.py b/py/shared/abstractions/kg.py index 2561b51a9..85a095b33 100644 --- a/py/shared/abstractions/kg.py +++ b/py/shared/abstractions/kg.py @@ -6,7 +6,7 @@ from .llm import GenerationConfig -class KGRunType(Enum): +class KGRunType(str, Enum): """Type of KG run.""" ESTIMATE = "estimate" @@ -105,6 +105,7 @@ class KGEnrichmentEstimationResponse(R2RSerializable): description="The estimated total time to run the graph enrichment process.", ) + class KGCreationSettings(R2RSerializable): """Settings for knowledge graph creation.""" @@ -167,7 +168,7 @@ class KGEnrichmentSettings(R2RSerializable): description="Whether to skip leiden clustering on the graph or not.", ) - force_enrichment: bool = Field( + force_kg_enrichment: bool = Field( default=False, description="Force run the enrichment step even if graph creation is still in progress for some documents.", ) From 830fa666d510ae49c6ebfea024324fa97aadb7a4 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:28:15 -0700 Subject: [PATCH 2/3] Clean up --- py/core/main/api/kg_router.py | 7 ++----- py/sdk/kg.py | 1 - 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/py/core/main/api/kg_router.py b/py/core/main/api/kg_router.py index 227d870ce..80d7b1a5d 100644 --- a/py/core/main/api/kg_router.py +++ b/py/core/main/api/kg_router.py @@ -1,13 +1,10 @@ import logging from pathlib import Path -from typing import Optional, Union +from typing import Optional from uuid import UUID import yaml from fastapi import Body, Depends, Query -from pydantic import Json, ValidationError - -from core import R2RException from core.base import RunType from core.base.api.models import ( @@ -16,7 +13,7 @@ ) from core.base.providers import OrchestrationProvider, Workflow from core.utils import generate_default_user_collection_id -from shared.abstractions.kg import KGRunType, KGCreationSettings +from shared.abstractions.kg import KGRunType from ..services.kg_service import KgService from .base_router import BaseRouter diff --git a/py/sdk/kg.py b/py/sdk/kg.py index dc266f34c..30ed64f29 100644 --- a/py/sdk/kg.py +++ b/py/sdk/kg.py @@ -1,4 +1,3 @@ -import json from typing import Optional, Union from uuid import UUID From 0ad0827938fc027d906ca02002c330517ae39d01 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:37:43 -0700 Subject: [PATCH 3/3] missed file --- py/cli/commands/kg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/cli/commands/kg.py b/py/cli/commands/kg.py index bdd9ff22b..ebfd6e556 100644 --- a/py/cli/commands/kg.py +++ b/py/cli/commands/kg.py @@ -30,7 +30,7 @@ help="Force the graph creation process.", ) @pass_context -async def create_graph( +def create_graph( ctx, collection_id, run, kg_creation_settings, force_kg_creation ): client = ctx.obj