Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #990

Merged
merged 5 commits into from
Aug 27, 2024
Merged

Dev #990

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions py/cli/command_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@


@click.group()
@click.option(
"--base-url", default="http://localhost:8000", help="Base URL for the API"
)
@click.pass_context
def cli(ctx):
def cli(ctx, base_url):
"""R2R CLI for all core operations."""

ctx.obj = R2RClient()
ctx.obj = R2RClient(base_url=base_url)
2 changes: 1 addition & 1 deletion py/cli/utils/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def run_docker_serve(
if not no_conflict:
click.secho(f"Warning: {message}", fg="red", bold=True)
click.echo("This may cause issues when starting the Docker setup.")
if not click.confirm("Do you want to continue?", default=False):
if not click.confirm("Do you want to continue?", default=True):
click.echo("Aborting Docker setup.")
return

Expand Down
2 changes: 1 addition & 1 deletion py/core/agent/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def format_search_results_for_llm(
) -> str:
formatted_results = ""
for i, result in enumerate(results):
text = result.metadata.get("text", "N/A")
text = result.text
formatted_results += f"{i+1}. {text}\n"
return formatted_results

Expand Down
1 change: 1 addition & 0 deletions py/core/base/providers/chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Method(str, Enum):
BY_TITLE = "by_title"
BASIC = "basic"
RECURSIVE = "recursive"
CHARACTER = "character"


class ChunkingConfig(ProviderConfig):
Expand Down
4 changes: 3 additions & 1 deletion py/core/base/utils/splitter/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,9 +628,11 @@ def transform_documents(
class CharacterTextSplitter(TextSplitter):
"""Splitting text that looks at characters."""

DEFAULT_SEPARATOR: str = "\n\n"

def __init__(
self,
separator: str = "\n\n",
separator: str = DEFAULT_SEPARATOR,
is_separator_regex: bool = False,
**kwargs: Any,
) -> None:
Expand Down
20 changes: 0 additions & 20 deletions py/core/configs/neo4j_kg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,6 @@ concurrent_request_limit = 256
stream = false
add_generation_kwargs = { }

[embedding]
provider = "openai"
base_model = "text-embedding-3-small"
base_dimension = 1_536
batch_size = 256
add_title_as_prefix = true

[ingestion]
excluded_parsers = [ "gif", "jpeg", "jpg", "png", "svg", "mp3", "mp4" ]

[kg]
provider = "neo4j"
batch_size = 256
Expand All @@ -35,13 +25,3 @@ kg_extraction_prompt = "graphrag_triplet_extraction_zero_shot"

[kg.kg_search_config]
model = "gpt-4o-mini"

[database]
provider = "postgres"

[agent]
system_instruction_name = "rag_agent"
tool_names = ["search"]

[agent.generation_config]
model = "gpt-4o-mini"
61 changes: 33 additions & 28 deletions py/core/examples/scripts/advanced_kg_cookbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import requests
from bs4 import BeautifulSoup, Comment

from r2r import EntityType, R2RClient, R2RPromptProvider, RelationshipType
from r2r import R2RClient, R2RPromptProvider


def escape_braces(text):
Expand Down Expand Up @@ -88,48 +88,53 @@ def main(

# Specify the entity types for the KG extraction prompt
entity_types = [
EntityType("COMPANY"),
EntityType("SCHOOL"),
EntityType("LOCATION"),
EntityType("PERSON"),
EntityType("DATE"),
EntityType("OTHER"),
EntityType("QUANTITY"),
EntityType("EVENT"),
EntityType("INDUSTRY"),
EntityType("MEDIA"),
"COMPANY",
"SCHOOL",
"LOCATION",
"PERSON",
"DATE",
"OTHER",
"QUANTITY",
"EVENT",
"INDUSTRY",
"MEDIA",
]

# Specify the relations for the KG construction
relations = [
# Founder Relations
RelationshipType("EDUCATED_AT"),
RelationshipType("WORKED_AT"),
RelationshipType("FOUNDED"),
"EDUCATED_AT",
"WORKED_AT",
"FOUNDED",
# Company relations
RelationshipType("RAISED"),
RelationshipType("REVENUE"),
RelationshipType("TEAM_SIZE"),
RelationshipType("LOCATION"),
RelationshipType("ACQUIRED_BY"),
RelationshipType("ANNOUNCED"),
RelationshipType("INDUSTRY"),
"RAISED",
"REVENUE",
"TEAM_SIZE",
"LOCATION",
"ACQUIRED_BY",
"ANNOUNCED",
"INDUSTRY",
# Product relations
RelationshipType("PRODUCT"),
RelationshipType("FEATURES"),
RelationshipType("TECHNOLOGY"),
"PRODUCT",
"FEATURES",
"TECHNOLOGY",
# Additional relations
RelationshipType("HAS"),
RelationshipType("AS_OF"),
RelationshipType("PARTICIPATED"),
RelationshipType("ASSOCIATED"),
"HAS",
"AS_OF",
"PARTICIPATED",
"ASSOCIATED",
]

client = R2RClient(base_url=base_url)
r2r_prompts = R2RPromptProvider()

prompt = "graphrag_triplet_extraction_few_shot"

r2r_prompts.update_prompt(
prompt,
input_types={"entity_types": entity_types, "relations": relations},
)

url_map = get_all_yc_co_directory_urls()

i = 0
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class R2RPipelines(BaseModel):
search_pipeline: SearchPipeline
rag_pipeline: RAGPipeline
streaming_rag_pipeline: RAGPipeline
kg_enrichment_pipeline: KGEnrichmentPipeline
kg_enrichment_pipeline: Optional[KGEnrichmentPipeline]

class Config:
arbitrary_types_allowed = True
Expand Down
27 changes: 14 additions & 13 deletions py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def create_kg_pipe(self, *args, **kwargs) -> Any:
return KGTriplesExtractionPipe(
kg_provider=self.providers.kg,
llm_provider=self.providers.llm,
database_provider=self.providers.database,
prompt_provider=self.providers.prompt,
chunking_provider=self.providers.chunking,
kg_batch_size=self.config.kg.batch_size,
Expand Down Expand Up @@ -515,12 +516,6 @@ def create_ingestion_pipeline(self, *args, **kwargs) -> IngestionPipeline:
ingestion_pipeline.add_pipe(
self.pipes.vector_storage_pipe, embedding_pipe=True
)
# Add KG pipes if provider is set
if self.config.kg.provider is not None:
ingestion_pipeline.add_pipe(self.pipes.kg_pipe, kg_pipe=True)
ingestion_pipeline.add_pipe(
self.pipes.kg_storage_pipe, kg_pipe=True
)

return ingestion_pipeline

Expand Down Expand Up @@ -563,13 +558,19 @@ def create_rag_pipeline(

def create_kg_enrichment_pipeline(
self, *args, **kwargs
) -> KGEnrichmentPipeline:
kg_enrichment_pipeline = KGEnrichmentPipeline()
kg_enrichment_pipeline.add_pipe(self.pipes.kg_node_extraction_pipe)
kg_enrichment_pipeline.add_pipe(self.pipes.kg_node_description_pipe)
kg_enrichment_pipeline.add_pipe(self.pipes.kg_clustering_pipe)

return kg_enrichment_pipeline
) -> Optional[KGEnrichmentPipeline]:
if self.config.kg.provider is not None:
kg_enrichment_pipeline = KGEnrichmentPipeline()
kg_enrichment_pipeline.add_pipe(self.pipes.kg_pipe)
kg_enrichment_pipeline.add_pipe(self.pipes.kg_storage_pipe)
kg_enrichment_pipeline.add_pipe(self.pipes.kg_node_extraction_pipe)
kg_enrichment_pipeline.add_pipe(
self.pipes.kg_node_description_pipe
)
kg_enrichment_pipeline.add_pipe(self.pipes.kg_clustering_pipe)
return kg_enrichment_pipeline
else:
return None

def create_pipelines(
self,
Expand Down
42 changes: 27 additions & 15 deletions py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
logger = logging.getLogger(__name__)
MB_CONVERSION_FACTOR = 1024 * 1024
STARTING_VERSION = "v0"
MAX_FILES_PER_INGESTION = 100
OVERVIEW_FETCH_PAGE_SIZE = 1_000


class IngestionService(Service):
Expand Down Expand Up @@ -66,7 +68,11 @@ async def ingest_files(
raise R2RException(
status_code=400, message="No files provided for ingestion."
)

if len(files) > MAX_FILES_PER_INGESTION:
raise R2RException(
status_code=400,
message=f"Exceeded maximum number of files per ingestion: {MAX_FILES_PER_INGESTION}.",
)
try:
documents = []
for iteration, file in enumerate(files):
Expand Down Expand Up @@ -131,24 +137,30 @@ async def update_files(
generate_user_document_id(file.filename, user.id)
for file in files
]
# Only superusers can modify arbitrary document ids, which this gate guarantees in conjunction with the check that follows
documents_overview = (
(
if len(files) > MAX_FILES_PER_INGESTION:
raise R2RException(
status_code=400,
message=f"Exceeded maximum number of files per ingestion: {MAX_FILES_PER_INGESTION}.",
)

documents_overview = []

offset = 0
while True:
documents_overview_page = (
self.providers.database.relational.get_documents_overview(
filter_document_ids=document_ids,
filter_user_ids=(
[user.id] if not user.is_superuser else None
),
offset=offset,
limit=OVERVIEW_FETCH_PAGE_SIZE,
)
)
if user.is_superuser
else self.providers.database.relational.get_documents_overview(
filter_document_ids=document_ids, filter_user_ids=[user.id]
)
)

if len(documents_overview) != len(files):
raise R2RException(
status_code=404,
message="One or more documents was not found.",
)
documents_overview.extend(documents_overview_page)
if len(documents_overview_page) < OVERVIEW_FETCH_PAGE_SIZE:
break
offset += 1

documents = []
new_versions = []
Expand Down
8 changes: 7 additions & 1 deletion py/core/main/services/restructure_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,14 @@ async def enrich_graph(
"""
try:
# Assuming there's a graph enrichment pipeline

async def input_generator():
input = []
for doc in input:
yield doc

return await self.pipelines.kg_enrichment_pipeline.run(
input=[],
input=input_generator(),
run_manager=self.run_manager,
)

Expand Down
2 changes: 1 addition & 1 deletion py/core/parsers/media/audio_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class AudioParser(AsyncParser[bytes]):
"""A parser for audio data."""

def __init__(
self, api_base: str = "https://api.openai.com/v2/audio/transcriptions"
self, api_base: str = "https://api.openai.com/v1/audio/transcriptions"
):
self.api_base = api_base
self.openai_api_key = os.environ.get("OPENAI_API_KEY")
Expand Down
4 changes: 2 additions & 2 deletions py/core/parsers/media/openai_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def process_frame_with_openai(
api_key: str,
model: str = "gpt-4o",
max_tokens: int = 2_048,
api_base: str = "https://api.openai.com/v2/chat/completions",
api_base: str = "https://api.openai.com/v1/chat/completions",
) -> str:
headers = {
"Content-Type": "application/json",
Expand Down Expand Up @@ -43,7 +43,7 @@ def process_frame_with_openai(
def process_audio_with_openai(
audio_file,
api_key: str,
audio_api_base: str = "https://api.openai.com/v2/audio/transcriptions",
audio_api_base: str = "https://api.openai.com/v1/audio/transcriptions",
) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Access the 'text' key in the transcription dictionary using transcription['text'] instead of transcription.text.

headers = {"Authorization": f"Bearer {api_key}"}

Expand Down
1 change: 1 addition & 0 deletions py/core/pipelines/graph_enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def add_pipe(
*args,
**kwargs,
) -> None:
print("pipe = ", pipe)
logger.debug(
f"Adding pipe {pipe.config.name} to the KGEnrichmentPipeline"
)
Expand Down
3 changes: 3 additions & 0 deletions py/core/pipes/ingestion/kg_extraction_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ async def extract_kg(
"""
Extracts NER triples from a fragment with retries.
"""

logger.info(f"Extracting triples for fragment: {fragment.id}")

messages = self.prompt_provider._get_message_payload(
task_prompt_name=self.kg_provider.config.kg_extraction_prompt,
task_inputs={"input": fragment},
Expand Down
Loading
Loading