Skip to content

Commit

Permalink
Merge pull request #397 from vanna-ai/decider
Browse files Browse the repository at this point in the history
Add debugger, intermediate SQL, updated prompt
  • Loading branch information
zainhoda authored Apr 30, 2024
2 parents f2cec1f + 298bae1 commit 6787539
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 67 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "sqlalchemy"
"requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "flask-sock", "sqlalchemy"
]

[project.urls]
Expand Down
191 changes: 150 additions & 41 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import plotly.express as px
import plotly.graph_objects as go
import requests
import sqlparse

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
Expand All @@ -70,14 +71,25 @@

class VannaBase(ABC):
def __init__(self, config=None):
if config is None:
config = {}

self.config = config
self.run_sql_is_set = False
self.static_documentation = ""
self.dialect = self.config.get("dialect", "SQL")
self.language = self.config.get("language", None)

def log(self, message: str, title: str = "Info"):
    """Log a message from the SQL-generation pipeline.

    The default implementation simply prints the message to stdout; the
    ``title`` argument is accepted (callers pass labels such as
    "SQL Prompt" or "LLM Response") but ignored here. Subclasses may
    override this to route titled messages to a proper logging sink.

    Args:
        message (str): The text to log.
        title (str): A short label categorizing the message. Unused by
            this default implementation.
    """
    print(message)

def generate_sql(self, question: str, **kwargs) -> str:
def _response_language(self) -> str:
if self.language is None:
return ""

return f"Respond in the {self.language} language."

def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> str:
"""
Example:
```python
Expand All @@ -99,6 +111,7 @@ def generate_sql(self, question: str, **kwargs) -> str:
Args:
question (str): The question to generate a SQL query for.
allow_llm_to_see_data (bool): Whether to allow the LLM to see the data (for the purposes of introspecting the data to generate the final SQL).
Returns:
str: The SQL query that answers the question.
Expand All @@ -118,45 +131,129 @@ def generate_sql(self, question: str, **kwargs) -> str:
doc_list=doc_list,
**kwargs,
)
self.log(prompt)
self.log(title="SQL Prompt", message=prompt)
llm_response = self.submit_prompt(prompt, **kwargs)
self.log(llm_response)
self.log(title="LLM Response", message=llm_response)

if 'intermediate_sql' in llm_response:
if not allow_llm_to_see_data:
return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."

if allow_llm_to_see_data:
intermediate_sql = self.extract_sql(llm_response)

try:
self.log(title="Running Intermediate SQL", message=intermediate_sql)
df = self.run_sql(intermediate_sql)

prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()],
**kwargs,
)
self.log(title="Final SQL Prompt", message=prompt)
llm_response = self.submit_prompt(prompt, **kwargs)
self.log(title="LLM Response", message=llm_response)
except Exception as e:
return f"Error running intermediate SQL: {e}"


return self.extract_sql(llm_response)

def extract_sql(self, llm_response: str) -> str:
    """
    Example:
    ```python
    vn.extract_sql("Here's the SQL query in a code block: ```sql\nSELECT * FROM customers\n```")
    ```

    Extracts the SQL query from the LLM response. This is useful in case the
    LLM response contains other information besides the SQL query.
    Override this function if your LLM responses need custom extraction logic.

    The patterns are tried in order, and within each pattern the LAST match
    wins (the model's final answer is assumed to be the definitive one).

    Args:
        llm_response (str): The LLM response.

    Returns:
        str: The extracted SQL query, or the raw response unchanged when no
        SQL pattern is recognized.
    """

    # If the llm_response contains a CTE (with clause), extract the last sql between WITH and ;
    sqls = re.findall(r"WITH.*?;", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    # If the llm_response is not markdown formatted, extract the last sql by finding SELECT and ; in the response
    sqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    # If the llm_response contains a markdown code block, with or without the sql tag, extract the last sql from it
    sqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    sqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    # No recognizable SQL pattern: hand the raw response back to the caller.
    return llm_response

def is_sql_valid(self, sql: str) -> bool:
    """
    Example:
    ```python
    vn.is_sql_valid("SELECT * FROM customers")
    ```

    Checks if the SQL query is valid. This is usually used to check if we should run the SQL query or not.
    By default it checks if the SQL query is a SELECT statement. You can override this method to enable running other types of SQL queries.

    Args:
        sql (str): The SQL query to check.

    Returns:
        bool: True if the SQL query is valid, False otherwise.
    """

    parsed = sqlparse.parse(sql)

    # Accept the query as soon as any parsed statement is classified as SELECT.
    for statement in parsed:
        if statement.get_type() == 'SELECT':
            return True

    # No SELECT statement found: by default only SELECTs are considered runnable.
    return False

def should_generate_chart(self, df: pd.DataFrame) -> bool:
    """
    Example:
    ```python
    vn.should_generate_chart(df)
    ```

    Checks if a chart should be generated for the given DataFrame. By default, it checks if the DataFrame has more than one row and has numerical columns.
    You can override this method to customize the logic for generating charts.

    Args:
        df (pd.DataFrame): The DataFrame to check.

    Returns:
        bool: True if a chart should be generated, False otherwise.
    """

    # A chart is only worthwhile for multi-row results with at least one numeric column.
    return len(df) > 1 and df.select_dtypes(include=['number']).shape[1] > 0

def generate_followup_questions(
self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
Expand Down Expand Up @@ -184,7 +281,8 @@ def generate_followup_questions(
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
),
self.user_message(
f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." +
self._response_language()
),
]

Expand Down Expand Up @@ -228,7 +326,8 @@ def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str:
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
),
self.user_message(
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." +
self._response_language()
),
]

Expand Down Expand Up @@ -375,7 +474,7 @@ def add_ddl_to_prompt(
self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
) -> str:
if len(ddl_list) > 0:
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt += "\n===Tables \n"

for ddl in ddl_list:
if (
Expand All @@ -394,7 +493,7 @@ def add_documentation_to_prompt(
max_tokens: int = 14000,
) -> str:
if len(documentation_list) > 0:
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt += "\n===Additional Context \n\n"

for documentation in documentation_list:
if (
Expand All @@ -410,7 +509,7 @@ def add_sql_to_prompt(
self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
) -> str:
if len(sql_list) > 0:
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt += "\n===Question-SQL Pairs\n\n"

for question in sql_list:
if (
Expand Down Expand Up @@ -456,7 +555,8 @@ def get_sql_prompt(
"""

if initial_prompt is None:
initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"
initial_prompt = f"You are a {self.dialect} expert. "
"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "

initial_prompt = self.add_ddl_to_prompt(
initial_prompt, ddl_list, max_tokens=14000
Expand All @@ -469,6 +569,15 @@ def get_sql_prompt(
initial_prompt, doc_list, max_tokens=14000
)

initial_prompt += (
"===Response Guidelines \n"
"1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
"2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
"3. If the provided context is insufficient, please explain why it can't be generated. \n"
"4. Please use the most relevant table(s). \n"
"5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
)

message_log = [self.system_message(initial_prompt)]

for example in question_sql_list:
Expand Down Expand Up @@ -676,7 +785,7 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame:

return df

self.static_documentation = "This is a Snowflake database"
self.dialect = "Snowflake SQL"
self.run_sql = run_sql_snowflake
self.run_sql_is_set = True

Expand Down Expand Up @@ -710,7 +819,7 @@ def connect_to_sqlite(self, url: str):
def run_sql_sqlite(sql: str):
return pd.read_sql_query(sql, conn)

self.static_documentation = "This is a SQLite database"
self.dialect = "SQLite"
self.run_sql = run_sql_sqlite
self.run_sql_is_set = True

Expand Down Expand Up @@ -815,7 +924,7 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
conn.rollback()
raise e

self.static_documentation = "This is a Postgres database"
self.dialect = "PostgreSQL"
self.run_sql_is_set = True
self.run_sql = run_sql_postgres

Expand Down Expand Up @@ -1078,7 +1187,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
raise errors
return None

self.static_documentation = "This is a BigQuery database"
self.dialect = "BigQuery SQL"
self.run_sql_is_set = True
self.run_sql = run_sql_bigquery

Expand Down Expand Up @@ -1127,7 +1236,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
def run_sql_duckdb(sql: str):
return conn.query(sql).to_df()

self.static_documentation = "This is a DuckDB database"
self.dialect = "DuckDB SQL"
self.run_sql = run_sql_duckdb
self.run_sql_is_set = True

Expand Down Expand Up @@ -1174,7 +1283,7 @@ def run_sql_mssql(sql: str):

raise Exception("Couldn't run sql")

self.static_documentation = "This is a Microsoft SQL Server database"
self.dialect = "T-SQL / Microsoft SQL Server"
self.run_sql = run_sql_mssql
self.run_sql_is_set = True

Expand Down
Loading

0 comments on commit 6787539

Please sign in to comment.