Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Sep 16, 2024
1 parent 01decb5 commit cb817ff
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
1 change: 0 additions & 1 deletion py/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .cli import cli, main
from .command_group import cli as command_group_cli

from .commands import auth, ingestion, management, retrieval, server

__all__ = [
Expand Down
4 changes: 3 additions & 1 deletion py/core/main/hatchet/restructure_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def __init__(self, restructure_service: RestructureService):
@r2r_hatchet.step(retries=3, timeout="60m")
async def kg_node_creation(self, context: Context) -> None:
input_data = context.workflow_input()["request"]
max_description_input_length = input_data["max_description_input_length"]
max_description_input_length = input_data[
"max_description_input_length"
]
await self.restructure_service.kg_node_creation(
max_description_input_length=max_description_input_length
)
Expand Down
23 changes: 16 additions & 7 deletions py/core/pipes/kg/node_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
import random
from typing import Any, AsyncGenerator, Optional
from uuid import UUID

Expand All @@ -17,8 +18,6 @@
from core.base.abstractions.graph import Entity, Triple
from core.base.pipes.base_pipe import AsyncPipe

import random

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -127,7 +126,9 @@ async def _run_logic(
Ensure the summary is coherent, informative, and captures the essence of the entity within the context of the provided information.
"""

async def process_entity(entity, triples, max_description_input_length):
async def process_entity(
entity, triples, max_description_input_length
):

# if embedding is present in the entity, just return it
# in the future disable this to override and recompute the descriptions for all entities
Expand All @@ -136,8 +137,8 @@ async def process_entity(entity, triples, max_description_input_length):

entity_info = f"{entity.name}, {entity.description}"
triples_txt = [
f"{i+1}: {triple.subject}, {triple.object}, {triple.predicate} - Summary: {triple.description}"
for i, triple in enumerate(triples)
f"{i+1}: {triple.subject}, {triple.object}, {triple.predicate} - Summary: {triple.description}"
for i, triple in enumerate(triples)
]

# truncate the descriptions to the max_description_input_length
Expand Down Expand Up @@ -196,13 +197,21 @@ async def process_entity(entity, triples, max_description_input_length):

return out_entity

max_description_input_length = input.message["max_description_input_length"]
max_description_input_length = input.message[
"max_description_input_length"
]
node_extrations = input.message["node_extrations"]

tasks = []
count = 0
async for entity, triples in node_extrations:
tasks.append(asyncio.create_task(process_entity(entity, triples, max_description_input_length)))
tasks.append(
asyncio.create_task(
process_entity(
entity, triples, max_description_input_length
)
)
)
count += 1

logger.info(f"KG Node Description pipe: Created {count} tasks")
Expand Down

0 comments on commit cb817ff

Please sign in to comment.