diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 51f2dac7..35e49d4c 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -902,13 +902,13 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_mysql - def connect_to_oracle( + def connect_to_oracle( self, user: str = None, password: str = None, dsn: str = None, ): - + """ Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] **Example:** @@ -953,12 +953,12 @@ def connect_to_oracle( raise ImproperlyConfigured("Please set your Oracle db password") conn = None - + try: conn = oracledb.connect( user=user, password=password, - dsn=dsn, + dsn=dsn, ) except oracledb.Error as e: raise ValidationError(e) @@ -966,10 +966,10 @@ def connect_to_oracle( def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]: if conn: try: - sql = sql.rstrip() + sql = sql.rstrip() if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error. - sql = sql[:-1] - + sql = sql[:-1] + cs = conn.cursor() cs.execute(sql) results = cs.fetchall() @@ -981,7 +981,7 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]: return df except oracledb.Error as e: - conn.rollback() + conn.rollback() raise ValidationError(e) except Exception as e: