Skip to content

Commit

Permalink
548: Added support for additional db connect options.
Browse files Browse the repository at this point in the history
fix
  • Loading branch information
pygeek committed Jul 11, 2024
1 parent 8cc20fb commit 61ad271
Showing 1 changed file with 74 additions and 50 deletions.
124 changes: 74 additions & 50 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def get_sql_prompt(
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs,
**kwargs
):
"""
Example:
Expand Down Expand Up @@ -718,6 +718,7 @@ def connect_to_snowflake(
database: str,
role: Union[str, None] = None,
warehouse: Union[str, None] = None,
**kwargs
):
try:
snowflake = __import__("snowflake.connector")
Expand Down Expand Up @@ -764,7 +765,8 @@ def connect_to_snowflake(
password=password,
account=account,
database=database,
client_session_keep_alive=True
client_session_keep_alive=True,
**kwargs
)

def run_sql_snowflake(sql: str) -> pd.DataFrame:
Expand Down Expand Up @@ -831,6 +833,7 @@ def connect_to_postgres(
user: str = None,
password: str = None,
port: int = None,
**kwargs
):
"""
Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Expand Down Expand Up @@ -900,6 +903,7 @@ def connect_to_postgres(
user=user,
password=password,
port=port,
**kwargs
)
except psycopg2.Error as e:
raise ValidationError(e)
Expand Down Expand Up @@ -931,12 +935,13 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:


def connect_to_mysql(
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
**kwargs
):

try:
Expand Down Expand Up @@ -980,12 +985,15 @@ def connect_to_mysql(
conn = None

try:
conn = pymysql.connect(host=host,
user=user,
password=password,
database=dbname,
port=port,
cursorclass=pymysql.cursors.DictCursor)
conn = pymysql.connect(
host=host,
user=user,
password=password,
database=dbname,
port=port,
cursorclass=pymysql.cursors.DictCursor,
**kwargs
)
except pymysql.Error as e:
raise ValidationError(e)

Expand Down Expand Up @@ -1015,12 +1023,13 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql = run_sql_mysql

def connect_to_clickhouse(
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
**kwargs
):

try:
Expand Down Expand Up @@ -1070,6 +1079,7 @@ def connect_to_clickhouse(
username=user,
password=password,
database=dbname,
**kwargs
)
print(conn)
except Exception as e:
Expand All @@ -1087,15 +1097,16 @@ def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:

except Exception as e:
raise e

self.run_sql_is_set = True
self.run_sql = run_sql_clickhouse

def connect_to_oracle(
self,
user: str = None,
password: str = None,
dsn: str = None,
self,
user: str = None,
password: str = None,
dsn: str = None,
**kwargs
):

"""
Expand Down Expand Up @@ -1148,7 +1159,8 @@ def connect_to_oracle(
user=user,
password=password,
dsn=dsn,
)
**kwargs
)
except oracledb.Error as e:
raise ValidationError(e)

Expand Down Expand Up @@ -1180,7 +1192,12 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql_is_set = True
self.run_sql = run_sql_oracle

def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
def connect_to_bigquery(
self,
cred_file_path: str = None,
project_id: str = None,
**kwargs
):
"""
Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
**Example:**
Expand Down Expand Up @@ -1242,7 +1259,11 @@ def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None
)

try:
conn = bigquery.Client(project=project_id, credentials=credentials)
conn = bigquery.Client(
project=project_id,
credentials=credentials,
**kwargs
)
except:
raise ImproperlyConfigured(
"Could not connect to bigquery please correct credentials"
Expand All @@ -1265,7 +1286,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql_is_set = True
self.run_sql = run_sql_bigquery

def connect_to_duckdb(self, url: str, init_sql: str = None):
def connect_to_duckdb(self, url: str, init_sql: str = None, **kwargs):
"""
Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Expand Down Expand Up @@ -1303,7 +1324,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
f.write(response.content)

# Connect to the database
conn = duckdb.connect(path)
conn = duckdb.connect(path, **kwargs)
if init_sql:
conn.query(init_sql)

Expand All @@ -1314,7 +1335,7 @@ def run_sql_duckdb(sql: str):
self.run_sql = run_sql_duckdb
self.run_sql_is_set = True

def connect_to_mssql(self, odbc_conn_str: str):
def connect_to_mssql(self, odbc_conn_str: str, **kwargs):
"""
Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Expand Down Expand Up @@ -1347,7 +1368,7 @@ def connect_to_mssql(self, odbc_conn_str: str):

from sqlalchemy import create_engine

engine = create_engine(connection_url)
engine = create_engine(connection_url, **kwargs)

def run_sql_mssql(sql: str):
# Execute the SQL statement and return the result as a pandas DataFrame
Expand All @@ -1362,16 +1383,17 @@ def run_sql_mssql(sql: str):
self.run_sql = run_sql_mssql
self.run_sql_is_set = True
def connect_to_presto(
self,
host: str,
catalog: str = 'hive',
schema: str = 'default',
user: str = None,
password: str = None,
port: int = None,
combined_pem_path: str = None,
protocol: str = 'https',
requests_kwargs: dict = None
self,
host: str,
catalog: str = 'hive',
schema: str = 'default',
user: str = None,
password: str = None,
port: int = None,
combined_pem_path: str = None,
protocol: str = 'https',
requests_kwargs: dict = None,
**kwargs
):
"""
Connect to a Presto database using the specified parameters.
Expand Down Expand Up @@ -1444,7 +1466,8 @@ def connect_to_presto(
schema=schema,
port=port,
protocol=protocol,
requests_kwargs=requests_kwargs)
requests_kwargs=requests_kwargs,
**kwargs)
except presto.Error as e:
raise ValidationError(e)

Expand Down Expand Up @@ -1477,13 +1500,14 @@ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql = run_sql_presto

def connect_to_hive(
self,
host: str = None,
dbname: str = 'default',
user: str = None,
password: str = None,
port: int = None,
auth: str = 'CUSTOM'
self,
host: str = None,
dbname: str = 'default',
user: str = None,
password: str = None,
port: int = None,
auth: str = 'CUSTOM',
**kwargs
):
"""
Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Expand Down

0 comments on commit 61ad271

Please sign in to comment.