Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Corrections for KG SDK usage #1330

Merged
merged 4 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions py/cli/commands/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
def create_graph(
ctx, collection_id, run, kg_creation_settings, force_kg_creation
):
"""
Create a new graph.
"""
client = ctx.obj

if kg_creation_settings:
Expand All @@ -46,18 +43,19 @@ def create_graph(
"Error: kg-creation-settings must be a valid JSON string"
)
return

if not run:
run_type = "estimate"
else:
run_type = "run"
kg_creation_settings = {}

run_type = "run" if run else "estimate"

if force_kg_creation:
kg_creation_settings = {"force_kg_creation": True}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

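When `force_kg_creation` is true, `kg_creation_settings` is replaced with a brand-new dict, which discards any settings already parsed from the `kg-creation-settings` JSON string. Consider updating the existing dictionary instead of overwriting it, e.g. `kg_creation_settings["force_kg_creation"] = True`.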


with timer():
response = client.create_graph(
collection_id, run_type, kg_creation_settings
collection_id=collection_id,
run_type=run_type,
kg_creation_settings=kg_creation_settings,
)

click.echo(json.dumps(response, indent=2))
Expand Down Expand Up @@ -102,11 +100,10 @@ def enrich_graph(
"Error: kg-enrichment-settings must be a valid JSON string"
)
return

if not run:
run_type = "estimate"
else:
run_type = "run"
kg_enrichment_settings = {}

run_type = "run" if run else "estimate"

if force_kg_enrichment:
kg_enrichment_settings = {"force_kg_enrichment": True}
Expand Down
21 changes: 13 additions & 8 deletions py/core/main/api/kg_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import yaml
from fastapi import Body, Depends, Query
from pydantic import Json

from core.base import RunType
from core.base.api.models import (
Expand Down Expand Up @@ -65,10 +64,10 @@ async def create_graph(
description="Collection ID to create graph for.",
),
run_type: Optional[KGRunType] = Body(
default=KGRunType.ESTIMATE,
default=None,
description="Run type for the graph creation process.",
),
kg_creation_settings: Optional[Json[dict]] = Body(
kg_creation_settings: Optional[dict] = Body(
default=None,
description="Settings for the graph creation process.",
),
Expand All @@ -79,10 +78,12 @@ async def create_graph(
This step extracts the relevant entities and relationships from the documents and creates a graph based on the extracted information.
In order to do GraphRAG, you will need to run the enrich_graph endpoint.
"""

if not auth_user.is_superuser:
logger.warning("Implement permission checks here.")

if not run_type:
run_type = KGRunType.ESTIMATE

if not collection_id:
collection_id = generate_default_user_collection_id(
auth_user.id
Expand All @@ -105,7 +106,7 @@ async def create_graph(
)

workflow_input = {
"collection_id": collection_id,
"collection_id": str(collection_id),
"kg_creation_settings": server_kg_creation_settings.model_dump_json(),
"user": auth_user.json(),
}
Expand All @@ -127,7 +128,7 @@ async def enrich_graph(
default=KGRunType.ESTIMATE,
description="Run type for the graph enrichment process.",
),
kg_enrichment_settings: Optional[Json[dict]] = Body(
kg_enrichment_settings: Optional[dict] = Body(
default=None,
description="Settings for the graph enrichment process.",
),
Expand All @@ -136,10 +137,12 @@ async def enrich_graph(
"""
This endpoint enriches the graph with additional information. It creates communities of nodes based on their similarity and adds embeddings to the graph. This step is necessary for GraphRAG to work.
"""

if not auth_user.is_superuser:
logger.warning("Implement permission checks here.")

if not run_type:
run_type = KGRunType.ESTIMATE

server_kg_enrichment_settings = (
self.service.providers.kg.config.kg_enrichment_settings
)
Expand All @@ -149,6 +152,8 @@ async def enrich_graph(
auth_user.id
)

logger.info(f"Running on collection {collection_id}")

if run_type is KGRunType.ESTIMATE:

return await self.service.get_enrichment_estimate(
Expand All @@ -161,7 +166,7 @@ async def enrich_graph(
setattr(server_kg_enrichment_settings, key, value)

workflow_input = {
"collection_id": collection_id,
"collection_id": str(collection_id),
"kg_enrichment_settings": server_kg_enrichment_settings.model_dump_json(),
"user": auth_user.json(),
}
Expand Down
37 changes: 17 additions & 20 deletions py/sdk/kg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from typing import Optional, Union
from uuid import UUID

Expand All @@ -16,33 +15,34 @@ class KGMethods:
@staticmethod
async def create_graph(
client,
collection_id: Optional[UUID] = None,
run_type: KGRunType = KGRunType.ESTIMATE,
collection_id: Optional[Union[UUID, str]] = None,
run_type: Optional[Union[str, KGRunType]] = None,
kg_creation_settings: Optional[Union[dict, KGCreationSettings]] = None,
) -> KGCreationResponse:
"""
Create a graph from the given settings.

Args:
collection_id (Optional[Union[UUID, str]]): The ID of the collection to create the graph for.
run_type (Optional[Union[str, KGRunType]]): The type of run to perform.
kg_creation_settings (Optional[Union[dict, KGCreationSettings]]): Settings for the graph creation process.
"""
if isinstance(kg_creation_settings, KGCreationSettings):
kg_creation_settings = kg_creation_settings.model_dump()
elif kg_creation_settings is None or kg_creation_settings == "{}":
kg_creation_settings = {}

data = {
"run_type": run_type,
"kg_creation_settings": json.dumps(kg_creation_settings),
"collection_id": str(collection_id) if collection_id else None,
"run_type": str(run_type) if run_type else None,
"kg_creation_settings": kg_creation_settings or {},
}

if collection_id:
data["collection_id"] = collection_id

return await client._make_request("POST", "create_graph", json=data)

@staticmethod
async def enrich_graph(
client,
collection_id: Optional[UUID] = None,
run_type: KGRunType = KGRunType.ESTIMATE,
collection_id: Optional[Union[UUID, str]] = None,
run_type: Optional[Union[str, KGRunType]] = None,
kg_enrichment_settings: Optional[
Union[dict, KGEnrichmentSettings]
] = None,
Expand All @@ -51,24 +51,21 @@ async def enrich_graph(
Perform graph enrichment over the entire graph.

Args:
collection_id (str): The ID of the collection to enrich.
collection_id (Optional[Union[UUID, str]]): The ID of the collection to enrich the graph for.
run_type (Optional[Union[str, KGRunType]]): The type of run to perform.
kg_enrichment_settings (Optional[Union[dict, KGEnrichmentSettings]]): Settings for the graph enrichment process.
Returns:
KGEnrichmentResponse: Results of the graph enrichment process.
"""
if isinstance(kg_enrichment_settings, KGEnrichmentSettings):
kg_enrichment_settings = kg_enrichment_settings.model_dump()
elif kg_enrichment_settings is None or kg_enrichment_settings == "{}":
kg_enrichment_settings = {}

data = {
"kg_enrichment_settings": json.dumps(kg_enrichment_settings),
"run_type": run_type,
"collection_id": str(collection_id) if collection_id else None,
"run_type": str(run_type) if run_type else None,
"kg_enrichment_settings": kg_enrichment_settings or {},
}

if collection_id:
data["collection_id"] = collection_id

return await client._make_request("POST", "enrich_graph", json=data)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions py/shared/abstractions/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .llm import GenerationConfig


class KGRunType(Enum):
class KGRunType(str, Enum):
"""Type of KG run."""

ESTIMATE = "estimate"
Expand Down Expand Up @@ -75,7 +75,7 @@ class KGEnrichmentSettings(R2RSerializable):
description="Whether to skip leiden clustering on the graph or not.",
)

force_enrichment: bool = Field(
force_kg_enrichment: bool = Field(
default=False,
description="Force run the enrichment step even if graph creation is still in progress for some documents.",
)
Expand Down
Loading