Skip to content

Commit

Permalink
Merge pull request #356 from vipgupta/gemini_vanna_integration
Browse files Browse the repository at this point in the history
Adding Google Gemini Chat integration to Vanna
  • Loading branch information
zainhoda authored Apr 16, 2024
2 parents bec5f66 + 98a9ff6 commit d2a5cf0
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }}
SNOWFLAKE_USERNAME: ${{ secrets.SNOWFLAKE_USERNAME }}
SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }}
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ mysql = ["PyMySQL"]
bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand Down
1 change: 1 addition & 0 deletions src/vanna/google/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .gemini_chat import GoogleGeminiChat
52 changes: 52 additions & 0 deletions src/vanna/google/gemini_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
from ..base import VannaBase


class GoogleGeminiChat(VannaBase):
def __init__(self, config=None):
VannaBase.__init__(self, config=config)

# default temperature - can be overrided using config
self.temperature = 0.7

if "temperature" in config:
self.temperature = config["temperature"]

if "model_name" in config:
model_name = config["model_name"]
else:
model_name = "gemini-1.0-pro"

self.google_api_key = None

if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
"""
If Google api_key is provided through config
or set as an environment variable, assign it.
"""
import google.generativeai as genai

genai.configure(api_key=config["api_key"])
self.chat_model = genai.GenerativeModel(model_name)
else:
# Authenticate using VertexAI
from vertexai.preview.generative_models import GenerativeModel
self.chat_model = GenerativeModel("gemini-pro")

def system_message(self, message: str) -> any:
return message

def user_message(self, message: str) -> any:
return message

def assistant_message(self, message: str) -> any:
return message

def submit_prompt(self, prompt, **kwargs) -> str:
response = self.chat_model.generate_content(
prompt,
generation_config={
"temperature": self.temperature,
},
)
return response.text
18 changes: 16 additions & 2 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from vanna.anthropic.anthropic_chat import Anthropic_Chat
from vanna.google import GoogleGeminiChat
from vanna.mistral.mistral import Mistral
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.remote import VannaDefault
Expand Down Expand Up @@ -92,9 +93,22 @@ def __init__(self, config=None):


def test_vn_claude():
sql = vn_claude.generate_sql("What are the top 5 customers by sales?")
sql = vn_claude.generate_sql("What are the top 8 customers by sales?")
df = vn_claude.run_sql(sql)
assert len(df) == 5
assert len(df) == 8

class VannaGemini(VannaDB_VectorStore, GoogleGeminiChat):
def __init__(self, config=None):
VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
GoogleGeminiChat.__init__(self, config=config)

vn_gemini = VannaGemini(config={'api_key': os.environ['GEMINI_API_KEY']})
vn_gemini.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_gemini():
sql = vn_gemini.generate_sql("What are the top 9 customers by sales?")
df = vn_gemini.run_sql(sql)
assert len(df) == 9

def test_training_plan():
vn_dummy = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY)
Expand Down

0 comments on commit d2a5cf0

Please sign in to comment.