diff --git a/pyproject.toml b/pyproject.toml
index 2383a7c2..a0cdd14f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,3 +42,4 @@ gemini = ["google-generativeai"]
 marqo = ["marqo"]
 zhipuai = ["zhipuai"]
 qdrant = ["qdrant-client"]
+vllm = ["vllm"]
diff --git a/src/vanna/vllm/__init__.py b/src/vanna/vllm/__init__.py
new file mode 100644
index 00000000..171086e8
--- /dev/null
+++ b/src/vanna/vllm/__init__.py
@@ -0,0 +1 @@
+from .vllm import Vllm
diff --git a/src/vanna/vllm/vllm.py b/src/vanna/vllm/vllm.py
new file mode 100644
index 00000000..0dd67e4f
--- /dev/null
+++ b/src/vanna/vllm/vllm.py
@@ -0,0 +1,76 @@
+import re
+
+import requests
+
+from ..base import VannaBase
+
+
+class Vllm(VannaBase):
+    def __init__(self, config=None):
+        if config is None or "vllm_host" not in config:
+            self.host = "http://localhost:8000"
+        else:
+            self.host = config["vllm_host"]
+
+        if config is None or "model" not in config:
+            raise ValueError("config must specify a 'model' for vllm")
+        else:
+            self.model = config["model"]
+
+    def system_message(self, message: str) -> any:
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> any:
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> any:
+        return {"role": "assistant", "content": message}
+
+    def extract_sql_query(self, text):
+        """
+        Extracts the first SQL statement starting with the word 'select' (ignoring case),
+        matching until the first semicolon, three backticks, or the end of the string,
+        and removes three backticks if they exist in the extracted string.
+
+        Args:
+        - text (str): The string to search within for an SQL statement.
+
+        Returns:
+        - str: The first SQL statement found with three backticks removed, or the original text if no match is found.
+        """
+        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
+        pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
+
+        match = pattern.search(text)
+        if match:
+            # Remove three backticks from the matched string if they exist
+            return match.group(0).replace("```", "")
+        else:
+            return text
+
+    def generate_sql(self, question: str, **kwargs) -> str:
+        # Use the base class to build the prompt and query the LLM
+        sql = super().generate_sql(question, **kwargs)
+
+        # Undo the escaping some models add to underscores, e.g. "\_" -> "_"
+        sql = sql.replace("\\_", "_")
+
+        sql = sql.replace("\\", "")
+
+        return self.extract_sql_query(sql)
+
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        url = f"{self.host}/v1/chat/completions"
+        data = {
+            "model": self.model,
+            "stream": False,
+            "messages": prompt,
+        }
+
+        response = requests.post(url, json=data)
+
+        response_dict = response.json()
+
+        self.log(response.text)
+
+        return response_dict["choices"][0]["message"]["content"]
diff --git a/tests/test_imports.py b/tests/test_imports.py
index c55df02d..53153e2d 100644
--- a/tests/test_imports.py
+++ b/tests/test_imports.py
@@ -15,7 +15,6 @@ def test_regular_imports():
     from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
     from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings
 
-
 def test_shortcut_imports():
     from vanna.anthropic import Anthropic_Chat
     from vanna.base import VannaBase
@@ -25,4 +24,5 @@
     from vanna.ollama import Ollama
     from vanna.openai import OpenAI_Chat, OpenAI_Embeddings
     from vanna.vannadb import VannaDB_VectorStore
+    from vanna.vllm import Vllm
     from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings
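
Usage sketch (not part of the diff): a minimal example of wiring the new Vllm class into vanna, assuming the ChromaDB_VectorStore shortcut import and mixin pattern used by vanna's other LLM backends; the model name is a placeholder for whatever model the local vLLM server is actually serving.

from vanna.chromadb import ChromaDB_VectorStore
from vanna.vllm import Vllm


class MyVanna(ChromaDB_VectorStore, Vllm):
    def __init__(self, config=None):
        # Initialize both the vector store and the vLLM chat backend
        ChromaDB_VectorStore.__init__(self, config=config)
        Vllm.__init__(self, config=config)


# "vllm_host" is optional and defaults to http://localhost:8000;
# "model" is required and must name a model served by the vLLM instance
# ("my-served-model" below is a placeholder).
vn = MyVanna(config={"vllm_host": "http://localhost:8000", "model": "my-served-model"})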