diff --git a/py/core/base/logging/run_logger.py b/py/core/base/logging/run_logger.py index 6f562908c..67b7575e3 100644 --- a/py/core/base/logging/run_logger.py +++ b/py/core/base/logging/run_logger.py @@ -344,7 +344,9 @@ def __init__(self, config: PostgresLoggingConfig): self.log_table = config.log_table self.log_info_table = config.log_info_table self.config = config - self.project_name = os.getenv("R2R_PROJECT_NAME", "default") + self.project_name = config.app.project_name or os.getenv( + "R2R_PROJECT_NAME", "default" + ) self.pool = None if not os.getenv("POSTGRES_DBNAME"): raise ValueError( @@ -581,7 +583,7 @@ def get_instance(cls): return cls.SUPPORTED_PROVIDERS[cls._config.provider](cls._config) @classmethod - def configure(cls, logging_config: LoggingConfig = LoggingConfig()): + def configure(cls, logging_config: LoggingConfig): if not cls._is_configured: cls._config = logging_config cls._is_configured = True diff --git a/py/core/base/providers/__init__.py b/py/core/base/providers/__init__.py index 7fa7341e5..306eb0ad5 100644 --- a/py/core/base/providers/__init__.py +++ b/py/core/base/providers/__init__.py @@ -1,5 +1,5 @@ from .auth import AuthConfig, AuthProvider -from .base import Provider, ProviderConfig +from .base import AppConfig, Provider, ProviderConfig from .crypto import CryptoConfig, CryptoProvider from .database import ( DatabaseConfig, @@ -21,6 +21,7 @@ "AuthConfig", "AuthProvider", # Base provider classes + "AppConfig", "Provider", "ProviderConfig", # Ingestion provider diff --git a/py/core/base/providers/base.py b/py/core/base/providers/base.py index 30f6211cd..052ae93d6 100644 --- a/py/core/base/providers/base.py +++ b/py/core/base/providers/base.py @@ -3,10 +3,22 @@ from pydantic import BaseModel +from ..abstractions import R2RSerializable + + +class AppConfig(R2RSerializable): + project_name: Optional[str] = None + + @classmethod + def create(cls, *args, **kwargs): + project_name = kwargs.get("project_name") + return AppConfig(project_name=project_name) + class ProviderConfig(BaseModel, ABC): """A base provider configuration class""" + app: AppConfig # Add an app_config field extra_fields: dict[str, Any] = {} provider: Optional[str] = None diff --git a/py/core/main/config.py b/py/core/main/config.py index 37798af1c..e5de2ef2f 100644 --- a/py/core/main/config.py +++ b/py/core/main/config.py @@ -10,6 +10,7 @@ from ..base.abstractions import GenerationConfig from ..base.agent.agent import AgentConfig from ..base.logging.run_logger import LoggingConfig +from ..base.providers import AppConfig from ..base.providers.auth import AuthConfig from ..base.providers.crypto import CryptoConfig from ..base.providers.database import DatabaseConfig @@ -40,6 +41,7 @@ class R2RConfig: CONFIG_OPTIONS["default"] = None REQUIRED_KEYS: dict[str, list] = { + "app": [], "completion": ["provider"], "crypto": ["provider"], "auth": ["provider"], @@ -64,6 +66,7 @@ class R2RConfig: "orchestration": ["provider"], } + app: AppConfig auth: AuthConfig completion: CompletionConfig crypto: CryptoConfig @@ -118,24 +121,24 @@ def __init__( ) setattr(self, section, default_config[section]) - self.completion = CompletionConfig.create(**self.completion) # type: ignore + self.app = AppConfig.create(**self.app) # type: ignore + self.auth = AuthConfig.create(**self.auth, app=self.app) # type: ignore + self.completion = CompletionConfig.create(**self.completion, app=self.app) # type: ignore + self.crypto = CryptoConfig.create(**self.crypto, app=self.app) # type: ignore + self.database = DatabaseConfig.create(**self.database, app=self.app) # type: ignore + self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app) # type: ignore + self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app) # type: ignore + self.kg = KGConfig.create(**self.kg, app=self.app) # type: ignore + self.logging = LoggingConfig.create(**self.logging, app=self.app) # type: ignore + self.prompt = PromptConfig.create(**self.prompt, app=self.app) # type: ignore + self.agent = AgentConfig.create(**self.agent, app=self.app) # type: ignore + self.file = FileConfig.create(**self.file, app=self.app) # type: ignore + self.orchestration = OrchestrationConfig.create(**self.orchestration, app=self.app) # type: ignore # override GenerationConfig defaults GenerationConfig.set_default( **self.completion.generation_config.dict() ) - self.auth = AuthConfig.create(**self.auth) # type: ignore - self.crypto = CryptoConfig.create(**self.crypto) # type: ignore - self.database = DatabaseConfig.create(**self.database) # type: ignore - self.embedding = EmbeddingConfig.create(**self.embedding) # type: ignore - self.ingestion = IngestionConfig.create(**self.ingestion) # type: ignore - self.kg = KGConfig.create(**self.kg) # type: ignore - self.logging = LoggingConfig.create(**self.logging) # type: ignore - self.prompt = PromptConfig.create(**self.prompt) # type: ignore - self.agent = AgentConfig.create(**self.agent) # type: ignore - self.file = FileConfig.create(**self.file) # type: ignore - self.orchestration = OrchestrationConfig.create(**self.orchestration) # type: ignore - def _validate_config_section( self, config_data: dict[str, Any], section: str, keys: list ): diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py index 68abbb39d..644ae1a85 100644 --- a/py/core/main/orchestration/hatchet/kg_workflow.py +++ b/py/core/main/orchestration/hatchet/kg_workflow.py @@ -40,10 +40,10 @@ def get_input_data_dict(input_data): class KGExtractDescribeEmbedWorkflow: def __init__(self, kg_service: KgService): self.kg_service = kg_service - + @orchestration_provider.concurrency( - max_runs=orchestration_provider.config.kg_creation_concurrency_limit, - limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN + max_runs=orchestration_provider.config.kg_creation_concurrency_limit, + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, ) def concurrency(self, context) -> str: return str(context.workflow_input()["request"]["collection_id"]) diff --git a/py/core/providers/database/postgres.py b/py/core/providers/database/postgres.py index 5f8d58a05..fe945ba36 100644 --- a/py/core/providers/database/postgres.py +++ b/py/core/providers/database/postgres.py @@ -72,7 +72,7 @@ def __init__( self.db_name = db_name project_name = ( - config.project_name + config.app.project_name or os.getenv("R2R_PROJECT_NAME") # Remove the following line after deprecation or os.getenv("POSTGRES_PROJECT_NAME") diff --git a/py/core/providers/orchestration/hatchet.py b/py/core/providers/orchestration/hatchet.py index 7b430e2ec..d641d4705 100644 --- a/py/core/providers/orchestration/hatchet.py +++ b/py/core/providers/orchestration/hatchet.py @@ -34,7 +34,7 @@ def get_worker(self, name: str, max_threads: Optional[int] = None) -> Any: max_threads = self.config.max_threads self.worker = self.orchestrator.worker(name, max_threads) return self.worker - + def concurrency(self, *args, **kwargs) -> Callable: return self.orchestrator.concurrency(*args, **kwargs) diff --git a/py/r2r.toml b/py/r2r.toml index 8867a1b32..10e98fc4b 100644 --- a/py/r2r.toml +++ b/py/r2r.toml @@ -1,3 +1,7 @@ +[app] +# app settings are global available like `r2r_config.agent.app` +# project_name = "my_project" # optional, can also set with `R2R_PROJECT_NAME` env var + [agent] system_instruction_name = "rag_agent" tool_names = ["search"]