Skip to content

Commit

Permalink
Merge pull request #415 from zyclove/support-presto
Browse files Browse the repository at this point in the history
feat:add support presto
  • Loading branch information
zainhoda authored May 6, 2024
2 parents 7b962d5 + 8a7242c commit 7225854
Showing 1 changed file with 110 additions and 0 deletions.
110 changes: 110 additions & 0 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,116 @@ def run_sql_mssql(sql: str):
self.dialect = "T-SQL / Microsoft SQL Server"
self.run_sql = run_sql_mssql
self.run_sql_is_set = True
def connect_to_presto(
self,
host: str,
catalog: str = 'hive',
schema: str = 'default',
user: str = None,
password: str = None,
port: int = None,
combined_pem_path: str = None,
protocol: str = 'https',
requests_kwargs: dict = None
):
"""
Connect to a Presto database using the specified parameters.
Args:
host (str): The host address of the Presto database.
catalog (str): The catalog to use in the Presto environment.
schema (str): The schema to use in the Presto environment.
user (str): The username for authentication.
password (str): The password for authentication.
port (int): The port number for the Presto connection.
combined_pem_path (str): The path to the combined pem file for SSL connection.
protocol (str): The protocol to use for the connection (default is 'https').
requests_kwargs (dict): Additional keyword arguments for requests.
Raises:
DependencyError: If required dependencies are not installed.
ImproperlyConfigured: If essential configuration settings are missing.
Returns:
None
"""
try:
from pyhive import presto
except ImportError:
raise DependencyError(
"You need to install required dependencies to execute this method,"
" run command: \npip install pyhive"
)

if not host:
host = os.getenv("PRESTO_HOST")

if not host:
raise ImproperlyConfigured("Please set your presto host")

if not catalog:
catalog = os.getenv("PRESTO_CATALOG")

if not catalog:
raise ImproperlyConfigured("Please set your presto catalog")

if not user:
user = os.getenv("PRESTO_USER")

if not user:
raise ImproperlyConfigured("Please set your presto user")

if not password:
password = os.getenv("PRESTO_PASSWORD")

if not port:
port = os.getenv("PRESTO_PORT")

if not port:
raise ImproperlyConfigured("Please set your presto port")

conn = None

try:
if requests_kwargs is None and combined_pem_path is not None:
# use the combined pem file to verify the SSL connection
requests_kwargs = {
'verify': combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证
}
conn = presto.Connection(host=host,
username=user,
password=password,
catalog=catalog,
schema=schema,
port=port,
protocol=protocol,
requests_kwargs=requests_kwargs)
except presto.Error as e:
raise ValidationError(e)

def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
if conn:
try:
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(
results, columns=[desc[0] for desc in cs.description]
)
return df

except presto.Error as e:
print(e)
raise ValidationError(e)

except Exception as e:
print(e)
raise e

self.run_sql_is_set = True
self.run_sql = run_sql_presto

def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
"""
Expand Down

0 comments on commit 7225854

Please sign in to comment.