Skip to content

Commit

Permalink
add a global project config (#1333)
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty authored Oct 4, 2024
1 parent 4cb83fb commit ff3d746
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 21 deletions.
6 changes: 4 additions & 2 deletions py/core/base/logging/run_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,9 @@ def __init__(self, config: PostgresLoggingConfig):
self.log_table = config.log_table
self.log_info_table = config.log_info_table
self.config = config
self.project_name = os.getenv("R2R_PROJECT_NAME", "default")
self.project_name = config.app.project_name or os.getenv(
"R2R_PROJECT_NAME", "default"
)
self.pool = None
if not os.getenv("POSTGRES_DBNAME"):
raise ValueError(
Expand Down Expand Up @@ -581,7 +583,7 @@ def get_instance(cls):
return cls.SUPPORTED_PROVIDERS[cls._config.provider](cls._config)

@classmethod
def configure(cls, logging_config: LoggingConfig = LoggingConfig()):
def configure(cls, logging_config: LoggingConfig):
if not cls._is_configured:
cls._config = logging_config
cls._is_configured = True
Expand Down
3 changes: 2 additions & 1 deletion py/core/base/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .auth import AuthConfig, AuthProvider
from .base import Provider, ProviderConfig
from .base import AppConfig, Provider, ProviderConfig
from .crypto import CryptoConfig, CryptoProvider
from .database import (
DatabaseConfig,
Expand All @@ -21,6 +21,7 @@
"AuthConfig",
"AuthProvider",
# Base provider classes
"AppConfig",
"Provider",
"ProviderConfig",
# Ingestion provider
Expand Down
12 changes: 12 additions & 0 deletions py/core/base/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@

from pydantic import BaseModel

from ..abstractions import R2RSerializable


class AppConfig(R2RSerializable):
project_name: Optional[str] = None

@classmethod
def create(cls, *args, **kwargs):
project_name = kwargs.get("project_name")
return AppConfig(project_name=project_name)


class ProviderConfig(BaseModel, ABC):
"""A base provider configuration class"""

app: AppConfig # Add an app_config field
extra_fields: dict[str, Any] = {}
provider: Optional[str] = None

Expand Down
29 changes: 16 additions & 13 deletions py/core/main/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..base.abstractions import GenerationConfig
from ..base.agent.agent import AgentConfig
from ..base.logging.run_logger import LoggingConfig
from ..base.providers import AppConfig
from ..base.providers.auth import AuthConfig
from ..base.providers.crypto import CryptoConfig
from ..base.providers.database import DatabaseConfig
Expand Down Expand Up @@ -40,6 +41,7 @@ class R2RConfig:
CONFIG_OPTIONS["default"] = None

REQUIRED_KEYS: dict[str, list] = {
"app": [],
"completion": ["provider"],
"crypto": ["provider"],
"auth": ["provider"],
Expand All @@ -64,6 +66,7 @@ class R2RConfig:
"orchestration": ["provider"],
}

app: AppConfig
auth: AuthConfig
completion: CompletionConfig
crypto: CryptoConfig
Expand Down Expand Up @@ -118,24 +121,24 @@ def __init__(
)
setattr(self, section, default_config[section])

self.completion = CompletionConfig.create(**self.completion) # type: ignore
self.app = AppConfig.create(**self.app) # type: ignore
self.auth = AuthConfig.create(**self.auth, app=self.app) # type: ignore
self.completion = CompletionConfig.create(**self.completion, app=self.app) # type: ignore
self.crypto = CryptoConfig.create(**self.crypto, app=self.app) # type: ignore
self.database = DatabaseConfig.create(**self.database, app=self.app) # type: ignore
self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app) # type: ignore
self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app) # type: ignore
self.kg = KGConfig.create(**self.kg, app=self.app) # type: ignore
self.logging = LoggingConfig.create(**self.logging, app=self.app) # type: ignore
self.prompt = PromptConfig.create(**self.prompt, app=self.app) # type: ignore
self.agent = AgentConfig.create(**self.agent, app=self.app) # type: ignore
self.file = FileConfig.create(**self.file, app=self.app) # type: ignore
self.orchestration = OrchestrationConfig.create(**self.orchestration, app=self.app) # type: ignore
# override GenerationConfig defaults
GenerationConfig.set_default(
**self.completion.generation_config.dict()
)

self.auth = AuthConfig.create(**self.auth) # type: ignore
self.crypto = CryptoConfig.create(**self.crypto) # type: ignore
self.database = DatabaseConfig.create(**self.database) # type: ignore
self.embedding = EmbeddingConfig.create(**self.embedding) # type: ignore
self.ingestion = IngestionConfig.create(**self.ingestion) # type: ignore
self.kg = KGConfig.create(**self.kg) # type: ignore
self.logging = LoggingConfig.create(**self.logging) # type: ignore
self.prompt = PromptConfig.create(**self.prompt) # type: ignore
self.agent = AgentConfig.create(**self.agent) # type: ignore
self.file = FileConfig.create(**self.file) # type: ignore
self.orchestration = OrchestrationConfig.create(**self.orchestration) # type: ignore

def _validate_config_section(
self, config_data: dict[str, Any], section: str, keys: list
):
Expand Down
6 changes: 3 additions & 3 deletions py/core/main/orchestration/hatchet/kg_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def get_input_data_dict(input_data):
class KGExtractDescribeEmbedWorkflow:
def __init__(self, kg_service: KgService):
self.kg_service = kg_service

@orchestration_provider.concurrency(
max_runs=orchestration_provider.config.kg_creation_concurrency_limit,
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN
max_runs=orchestration_provider.config.kg_creation_concurrency_limit,
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
)
def concurrency(self, context) -> str:
return str(context.workflow_input()["request"]["collection_id"])
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
self.db_name = db_name

project_name = (
config.project_name
config.app.project_name
or os.getenv("R2R_PROJECT_NAME")
# Remove the following line after deprecation
or os.getenv("POSTGRES_PROJECT_NAME")
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/orchestration/hatchet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_worker(self, name: str, max_threads: Optional[int] = None) -> Any:
max_threads = self.config.max_threads
self.worker = self.orchestrator.worker(name, max_threads)
return self.worker

def concurrency(self, *args, **kwargs) -> Callable:
return self.orchestrator.concurrency(*args, **kwargs)

Expand Down
4 changes: 4 additions & 0 deletions py/r2r.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[app]
# app settings are global available like `r2r_config.agent.app`
# project_name = "my_project" # optional, can also set with `R2R_PROJECT_NAME` env var

[agent]
system_instruction_name = "rag_agent"
tool_names = ["search"]
Expand Down

0 comments on commit ff3d746

Please sign in to comment.