diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 29c7e2dc..8f1f1502 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1287,6 +1287,116 @@ def run_sql_mssql(sql: str): self.dialect = "T-SQL / Microsoft SQL Server" self.run_sql = run_sql_mssql self.run_sql_is_set = True + def connect_to_presto( + self, + host: str, + catalog: str = 'hive', + schema: str = 'default', + user: str = None, + password: str = None, + port: int = None, + combined_pem_path: str = None, + protocol: str = 'https', + requests_kwargs: dict = None + ): + """ + Connect to a Presto database using the specified parameters. + + Args: + host (str): The host address of the Presto database. + catalog (str): The catalog to use in the Presto environment. + schema (str): The schema to use in the Presto environment. + user (str): The username for authentication. + password (str): The password for authentication. + port (int): The port number for the Presto connection. + combined_pem_path (str): The path to the combined pem file for SSL connection. + protocol (str): The protocol to use for the connection (default is 'https'). + requests_kwargs (dict): Additional keyword arguments for requests. + + Raises: + DependencyError: If required dependencies are not installed. + ImproperlyConfigured: If essential configuration settings are missing. + + Returns: + None + """ + try: + from pyhive import presto + except ImportError: + raise DependencyError( + "You need to install required dependencies to execute this method," + " run command: \npip install pyhive" + ) + + if not host: + host = os.getenv("PRESTO_HOST") + + if not host: + raise ImproperlyConfigured("Please set your presto host") + + if not catalog: + catalog = os.getenv("PRESTO_CATALOG") + + if not catalog: + raise ImproperlyConfigured("Please set your presto catalog") + + if not user: + user = os.getenv("PRESTO_USER") + + if not user: + raise ImproperlyConfigured("Please set your presto user") + + if not password: + password = os.getenv("PRESTO_PASSWORD") + + if not port: + port = os.getenv("PRESTO_PORT") + + if not port: + raise ImproperlyConfigured("Please set your presto port") + + conn = None + + try: + if requests_kwargs is None and combined_pem_path is not None: + # use the combined pem file to verify the SSL connection + requests_kwargs = { + 'verify': combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证 + } + conn = presto.Connection(host=host, + username=user, + password=password, + catalog=catalog, + schema=schema, + port=port, + protocol=protocol, + requests_kwargs=requests_kwargs) + except presto.Error as e: + raise ValidationError(e) + + def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]: + if conn: + try: + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() + + # Create a pandas dataframe from the results + df = pd.DataFrame( + results, columns=[desc[0] for desc in cs.description] + ) + return df + + except presto.Error as e: + print(e) + raise ValidationError(e) + + except Exception as e: + print(e) + raise e + + self.run_sql_is_set = True + self.run_sql = run_sql_presto def run_sql(self, sql: str, **kwargs) -> pd.DataFrame: """