From 864d1423afdbbb98c10225410c84b949fe081a0d Mon Sep 17 00:00:00 2001 From: Vipul Gupta Date: Fri, 12 Apr 2024 00:50:54 +0530 Subject: [PATCH 1/2] Adding Google Gemini Chat integration to Vanna --- pyproject.toml | 3 +- src/vanna/google/__init__.py | 1 + src/vanna/google/gemini_chat.py | 52 +++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 src/vanna/google/__init__.py create mode 100644 src/vanna/google/gemini_chat.py diff --git a/pyproject.toml b/pyproject.toml index b9d3e275..e41bf181 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,8 @@ mysql = ["PyMySQL"] bigquery = ["google-cloud-bigquery"] snowflake = ["snowflake-connector-python"] duckdb = ["duckdb"] -all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo"] +google = ["google-generativeai", "google-cloud-aiplatform"] +all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform"] test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] diff --git a/src/vanna/google/__init__.py b/src/vanna/google/__init__.py new file mode 100644 index 00000000..b0592623 --- /dev/null +++ b/src/vanna/google/__init__.py @@ -0,0 +1 @@ +from .gemini_chat import GoogleGeminiChat \ No newline at end of file diff --git a/src/vanna/google/gemini_chat.py b/src/vanna/google/gemini_chat.py new file mode 100644 index 00000000..2a857f00 --- /dev/null +++ b/src/vanna/google/gemini_chat.py @@ -0,0 +1,52 @@ +import os +from ..base import VannaBase + + +class GoogleGeminiChat(VannaBase): + def __init__(self, config=None): + VannaBase.__init__(self, config=config) + + # default temperature - can be overrided using config + self.temperature = 0.7 + + if "temperature" in config: + self.temperature = config["temperature"] + + if "model_name" in config: + model_name = config["model_name"] + else: + model_name = "gemini-1.0-pro" + + self.google_api_key = None + + if "api_key" in config or os.getenv("GOOGLE_API_KEY"): + """ + If Google api_key is provided through config + or set as an environment variable, assign it. + """ + import google.generativeai as genai + + genai.configure(api_key=config["api_key"]) + self.chat_model = genai.GenerativeModel(model_name) + else: + # Authenticate using VertexAI + from vertexai.preview.generative_models import GenerativeModel + self.chat_model = GenerativeModel("gemini-pro") + + def system_message(self, message: str) -> any: + return message + + def user_message(self, message: str) -> any: + return message + + def assistant_message(self, message: str) -> any: + return message + + def submit_prompt(self, prompt, **kwargs) -> str: + response = self.chat_model.generate_content( + prompt, + generation_config={ + "temperature": self.temperature, + }, + ) + return response.text From 98a9ff60732cba31ccc81477e6b3fe91c6965766 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Tue, 16 Apr 2024 11:21:46 -0400 Subject: [PATCH 2/2] Add Gemini to integration tests --- .github/workflows/tests.yml | 1 + tests/test_vanna.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c454b524..7f3acda9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,6 +27,7 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }} SNOWFLAKE_USERNAME: ${{ secrets.SNOWFLAKE_USERNAME }} SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }} diff --git a/tests/test_vanna.py b/tests/test_vanna.py index 1d11c3f3..82b20195 100644 --- a/tests/test_vanna.py +++ b/tests/test_vanna.py @@ -1,6 +1,7 @@ import os from vanna.anthropic.anthropic_chat import Anthropic_Chat +from vanna.google import GoogleGeminiChat from vanna.mistral.mistral import Mistral from vanna.openai.openai_chat import OpenAI_Chat from vanna.remote import VannaDefault @@ -92,9 +93,22 @@ def __init__(self, config=None): def test_vn_claude(): - sql = vn_claude.generate_sql("What are the top 5 customers by sales?") + sql = vn_claude.generate_sql("What are the top 8 customers by sales?") df = vn_claude.run_sql(sql) - assert len(df) == 5 + assert len(df) == 8 + +class VannaGemini(VannaDB_VectorStore, GoogleGeminiChat): + def __init__(self, config=None): + VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) + GoogleGeminiChat.__init__(self, config=config) + +vn_gemini = VannaGemini(config={'api_key': os.environ['GEMINI_API_KEY']}) +vn_gemini.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') + +def test_vn_gemini(): + sql = vn_gemini.generate_sql("What are the top 9 customers by sales?") + df = vn_gemini.run_sql(sql) + assert len(df) == 9 def test_training_plan(): vn_dummy = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY)