diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 492516ead..fb09bb1be 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -529,7 +529,7 @@ def get_sql_prompt( question_sql_list: list, ddl_list: list, doc_list: list, - **kwargs, + **kwargs ): """ Example: @@ -718,6 +718,7 @@ def connect_to_snowflake( database: str, role: Union[str, None] = None, warehouse: Union[str, None] = None, + **kwargs ): try: snowflake = __import__("snowflake.connector") @@ -764,7 +765,8 @@ def connect_to_snowflake( password=password, account=account, database=database, - client_session_keep_alive=True + client_session_keep_alive=True, + **kwargs ) def run_sql_snowflake(sql: str) -> pd.DataFrame: @@ -831,6 +833,7 @@ def connect_to_postgres( user: str = None, password: str = None, port: int = None, + **kwargs ): """ Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -900,6 +903,7 @@ def connect_to_postgres( user=user, password=password, port=port, + **kwargs ) except psycopg2.Error as e: raise ValidationError(e) @@ -931,12 +935,13 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]: def connect_to_mysql( - self, - host: str = None, - dbname: str = None, - user: str = None, - password: str = None, - port: int = None, + self, + host: str = None, + dbname: str = None, + user: str = None, + password: str = None, + port: int = None, + **kwargs ): try: @@ -980,12 +985,15 @@ def connect_to_mysql( conn = None try: - conn = pymysql.connect(host=host, - user=user, - password=password, - database=dbname, - port=port, - cursorclass=pymysql.cursors.DictCursor) + conn = pymysql.connect( + host=host, + user=user, + password=password, + database=dbname, + port=port, + cursorclass=pymysql.cursors.DictCursor, + **kwargs + ) except pymysql.Error as e: raise ValidationError(e) @@ -1015,12 +1023,13 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]: self.run_sql = run_sql_mysql def connect_to_clickhouse( - self, - host: str = None, - dbname: str = None, - user: str = None, - password: str = None, - port: int = None, + self, + host: str = None, + dbname: str = None, + user: str = None, + password: str = None, + port: int = None, + **kwargs ): try: @@ -1070,6 +1079,7 @@ def connect_to_clickhouse( username=user, password=password, database=dbname, + **kwargs ) print(conn) except Exception as e: @@ -1087,15 +1097,16 @@ def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]: except Exception as e: raise e - + self.run_sql_is_set = True self.run_sql = run_sql_clickhouse def connect_to_oracle( - self, - user: str = None, - password: str = None, - dsn: str = None, + self, + user: str = None, + password: str = None, + dsn: str = None, + **kwargs ): """ @@ -1148,7 +1159,8 @@ def connect_to_oracle( user=user, password=password, dsn=dsn, - ) + **kwargs + ) except oracledb.Error as e: raise ValidationError(e) @@ -1180,7 +1192,12 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_oracle - def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None): + def connect_to_bigquery( + self, + cred_file_path: str = None, + project_id: str = None, + **kwargs + ): """ Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] **Example:** @@ -1242,7 +1259,11 @@ def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None ) try: - conn = bigquery.Client(project=project_id, credentials=credentials) + conn = bigquery.Client( + project=project_id, + credentials=credentials, + **kwargs + ) except: raise ImproperlyConfigured( "Could not connect to bigquery please correct credentials" @@ -1265,7 +1286,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_bigquery - def connect_to_duckdb(self, url: str, init_sql: str = None): + def connect_to_duckdb(self, url: str, init_sql: str = None, **kwargs): """ Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -1303,7 +1324,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None): f.write(response.content) # Connect to the database - conn = duckdb.connect(path) + conn = duckdb.connect(path, **kwargs) if init_sql: conn.query(init_sql) @@ -1314,7 +1335,7 @@ def run_sql_duckdb(sql: str): self.run_sql = run_sql_duckdb self.run_sql_is_set = True - def connect_to_mssql(self, odbc_conn_str: str): + def connect_to_mssql(self, odbc_conn_str: str, **kwargs): """ Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -1347,7 +1368,7 @@ def connect_to_mssql(self, odbc_conn_str: str): from sqlalchemy import create_engine - engine = create_engine(connection_url) + engine = create_engine(connection_url, **kwargs) def run_sql_mssql(sql: str): # Execute the SQL statement and return the result as a pandas DataFrame @@ -1362,16 +1383,17 @@ def run_sql_mssql(sql: str): self.run_sql = run_sql_mssql self.run_sql_is_set = True def connect_to_presto( - self, - host: str, - catalog: str = 'hive', - schema: str = 'default', - user: str = None, - password: str = None, - port: int = None, - combined_pem_path: str = None, - protocol: str = 'https', - requests_kwargs: dict = None + self, + host: str, + catalog: str = 'hive', + schema: str = 'default', + user: str = None, + password: str = None, + port: int = None, + combined_pem_path: str = None, + protocol: str = 'https', + requests_kwargs: dict = None, + **kwargs ): """ Connect to a Presto database using the specified parameters. @@ -1444,7 +1466,8 @@ def connect_to_presto( schema=schema, port=port, protocol=protocol, - requests_kwargs=requests_kwargs) + requests_kwargs=requests_kwargs, + **kwargs) except presto.Error as e: raise ValidationError(e) @@ -1477,13 +1500,14 @@ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]: self.run_sql = run_sql_presto def connect_to_hive( - self, - host: str = None, - dbname: str = 'default', - user: str = None, - password: str = None, - port: int = None, - auth: str = 'CUSTOM' + self, + host: str = None, + dbname: str = 'default', + user: str = None, + password: str = None, + port: int = None, + auth: str = 'CUSTOM', + **kwargs ): """ Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]