Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialization #10

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Contributing to vectara-agemt
# Contributing to vectara-agent

Thank you for your interest in `vectara-agentic` and considering contributing to our project!
Whether it's a bug, a new feature updates to the documentation or anything else - we truly appreciate your time and effort.
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ The `Agent` class defines a few helpful methods to help you understand the inter
* The `report()` method prints out the agent object’s type, the tools, and the LLMs used for the main agent and tool calling.
* The `token_counts()` method tells you how many tokens you have used in the current session for both the main agent and tool calling LLMs. This can be helpful if you want to track spend by token.

## Serialization

The `Agent` class supports serialization. Use the `dumps()` to serialize and `loads()` to read back from a serialized stream.

## Observability

vectara-agentic supports observability via the existing integration of LlamaIndex and Arize Phoenix.
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ llama-index-tools-tavily_research==0.2.0
llama-index-callbacks-arize-phoenix==0.2.1
pydantic==2.8.2
retrying==1.3.4
pymongo==4.6.1
pymongo==4.6.3
python-dotenv==1.0.1
tiktoken==0.7.0
dill==0.3.8
14 changes: 14 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ def test_from_corpus(self):
self.assertIsInstance(agent, Agent)
self.assertEqual(agent._topic, "question answering")

def test_serialization(self):
agent = Agent.from_corpus(
tool_name="RAG Tool",
vectara_customer_id="4584783",
vectara_corpus_id="4",
vectara_api_key="api_key",
data_description="information",
assistant_specialty="question answering",
)

agent_reloaded = agent.loads(agent.dumps())
self.assertIsInstance(agent_reloaded, Agent)
self.assertEqual(agent, agent_reloaded)


if __name__ == "__main__":
unittest.main()
188 changes: 180 additions & 8 deletions vectara_agentic/agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""
This module contains the Agent class for handling different types of agents and their interactions.
"""
from typing import List, Callable, Optional
from typing import List, Callable, Optional, Dict, Any
import os
from datetime import date
import time
import json
import dill

import logging
logger = logging.getLogger('opentelemetry.exporter.otlp.proto.http.trace_exporter')
Expand All @@ -22,16 +24,18 @@
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core import set_global_handler
from llama_index.core.tools.types import ToolMetadata

import phoenix as px

from dotenv import load_dotenv

from .types import AgentType, AgentStatusType, LLMRole, ObserverType
from .types import AgentType, AgentStatusType, LLMRole, ObserverType, ToolType
from .utils import get_llm, get_tokenizer_for_model
from ._prompts import REACT_PROMPT_TEMPLATE, GENERAL_PROMPT_TEMPLATE
from ._callback import AgentCallbackHandler
from .tools import VectaraToolFactory
from .tools import VectaraToolFactory, VectaraTool


load_dotenv(override=True)

Expand Down Expand Up @@ -67,14 +71,14 @@ class Agent:
"""
Agent class for handling different types of agents and their interactions.
"""

def __init__(
self,
tools: list[FunctionTool],
topic: str = "general",
custom_instructions: str = "",
verbose: bool = True,
update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
agent_type: AgentType = AgentType(os.getenv("VECTARA_AGENTIC_AGENT_TYPE", "OPENAI")),
) -> None:
"""
Initialize the agent with the specified type, tools, topic, and system message.
Expand All @@ -87,7 +91,7 @@ def __init__(
verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
update_func (Callable): A callback function the code calls on any agent updates.
"""
self.agent_type = AgentType(os.getenv("VECTARA_AGENTIC_AGENT_TYPE", "OPENAI"))
self.agent_type = agent_type
self.tools = tools
self.llm = get_llm(LLMRole.MAIN)
self._custom_instructions = custom_instructions
Expand Down Expand Up @@ -152,6 +156,55 @@ def __init__(
else:
print("No observer set.")

def __eq__(self, other):
if not isinstance(other, Agent):
print(f"Comparison failed: other is not an instance of Agent. (self: {type(self)}, other: {type(other)})")
return False

# Compare agent_type
if self.agent_type != other.agent_type:
print(f"Comparison failed: agent_type differs. (self.agent_type: {self.agent_type}, other.agent_type: {other.agent_type})")
return False

# Compare tools
if self.tools != other.tools:
print(f"Comparison failed: tools differ. (self.tools: {self.tools}, other.tools: {other.tools})")
return False

# Compare topic
if self._topic != other._topic:
print(f"Comparison failed: topic differs. (self.topic: {self._topic}, other.topic: {other._topic})")
return False

# Compare custom_instructions
if self._custom_instructions != other._custom_instructions:
print(f"Comparison failed: custom_instructions differ. (self.custom_instructions: {self._custom_instructions}, other.custom_instructions: {other._custom_instructions})")
return False

# Compare verbose
if self.verbose != other.verbose:
print(f"Comparison failed: verbose differs. (self.verbose: {self.verbose}, other.verbose: {other.verbose})")
return False

# Compare agent
if self.agent.memory.chat_store != other.agent.memory.chat_store:
print(f"Comparison failed: agent memory differs. (self.agent: {repr(self.agent.memory.chat_store)}, other.agent: {repr(other.agent.memory.chat_store)})")
return False

# If all comparisons pass
print("All comparisons passed. Objects are equal.")
return True


# def __eq__(self, other):
# print(f"self: {self.__dict__}, other: {other.__dict__}")
# return (isinstance(other, Agent) and
# self.agent_type == other.agent_type and self.tools == other.tools and
# self.agent.memory.chat_store == other.agent.memory.chat_store and
# self._topic == other._topic and
# self._custom_instructions == other._custom_instructions and
# self.verbose == other.verbose)

@classmethod
def from_tools(
cls,
Expand Down Expand Up @@ -208,7 +261,7 @@ def from_corpus(
data_description (str): The description of the data.
assistant_specialty (str): The specialty of the assistant.
verbose (bool, optional): Whether to print verbose output.
vectara_filter_fields (List[dict], optional): The filterable attributes (each dict includes name, type, and description).
vectara_filter_fields (List[dict], optional): The filterable attributes (each dict maps field name to Tuple[type, description]).
vectara_lambda_val (float, optional): The lambda value for Vectara hybrid search.
vectara_reranker (str, optional): The Vectara reranker name (default "mmr")
vectara_rerank_k (int, optional): The number of results to use with reranking.
Expand All @@ -224,9 +277,10 @@ def from_corpus(
vectara_customer_id=vectara_customer_id,
vectara_corpus_id=vectara_corpus_id)
field_definitions = {}
field_definitions['query'] = (str, Field(description="The user query"))
field_definitions['query'] = (str, Field(description="The user query")) # type: ignore
for field in vectara_filter_fields:
field_definitions[field['name']] = (eval(field['type']), Field(description=field['description'], default=None)) # type: ignore
field_definitions[field['name']] = (eval(field['type']),
Field(description=field['description'])) # type: ignore
QueryArgs = create_model( # type: ignore
"QueryArgs",
**field_definitions
Expand Down Expand Up @@ -315,3 +369,121 @@ def chat(self, prompt: str) -> str:
except Exception as e:
import traceback
return f"Vectara Agentic: encountered an exception ({e}) at ({traceback.format_exc()}), and can't respond."

# Serialization methods

def to_dict(self) -> Dict[str, Any]:
"""Serialize the Agent instance to a dictionary."""
tool_info = []

for tool in self.tools:
tool_dict = {
"tool_type": tool.tool_type.value,
"name": tool._metadata.name,
"description": tool._metadata.description,
"fn": dill.dumps(tool.fn).decode('latin-1') if tool.fn else None, # Serialize fn
"async_fn": dill.dumps(tool.async_fn).decode('latin-1') if tool.async_fn else None, # Serialize async_fn
}
tool_info.append(tool_dict)

return {
"agent_type": self.agent_type.value,
"tools": tool_info,
"topic": self._topic,
"custom_instructions": self._custom_instructions,
"verbose": self.verbose,
}


def dumps(self) -> str:
"""Serialize the Agent instance to a JSON string."""
return json.dumps(self.to_dict())

@classmethod
def loads(cls, data: str) -> "Agent":
"""Create an Agent instance from a JSON string."""
return cls.from_dict(json.loads(data))


def to_dict(self) -> Dict[str, Any]:
"""Serialize the Agent instance to a dictionary."""
tool_info = []

for tool in self.tools:
# Serialize each tool's metadata, function, and dynamic model schema (QueryArgs)
tool_dict = {
"tool_type": tool.tool_type.value,
"name": tool._metadata.name,
"description": tool._metadata.description,
"fn": dill.dumps(tool.fn).decode('latin-1') if tool.fn else None, # Serialize fn
"async_fn": dill.dumps(tool.async_fn).decode('latin-1') if tool.async_fn else None, # Serialize async_fn
"fn_schema": tool._metadata.fn_schema.model_json_schema() if hasattr(tool._metadata, 'fn_schema') else None, # Serialize schema if available
}
tool_info.append(tool_dict)

return {
"agent_type": self.agent_type.value,
"memory": dill.dumps(self.agent.memory).decode('latin-1'),
"tools": tool_info,
"topic": self._topic,
"custom_instructions": self._custom_instructions,
"verbose": self.verbose,
}

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Agent":
"""Create an Agent instance from a dictionary."""
agent_type = AgentType(data["agent_type"])
tools = []

JSON_TYPE_TO_PYTHON = {
"string": "str",
"integer": "int",
"boolean": "bool",
"array": "list",
"object": "dict",
"number": "float",
}

for tool_data in data["tools"]:
# Recreate the dynamic model using the schema info
if tool_data.get("fn_schema"):
field_definitions = {}
for field,values in tool_data["fn_schema"]["properties"].items():
if 'default' in values:
field_definitions[field] = (eval(JSON_TYPE_TO_PYTHON.get(values['type'], values['type'])),
Field(description=values['description'], default=values['default'])) # type: ignore
else:
field_definitions[field] = (eval(JSON_TYPE_TO_PYTHON.get(values['type'], values['type'])),
Field(description=values['description'])) # type: ignore
query_args_model = create_model( # type: ignore
"QueryArgs",
**field_definitions
)
else:
query_args_model = create_model("QueryArgs")

fn = dill.loads(tool_data["fn"].encode('latin-1')) if tool_data["fn"] else None
async_fn = dill.loads(tool_data["async_fn"].encode('latin-1')) if tool_data["async_fn"] else None

tool = VectaraTool.from_defaults(
tool_type=ToolType(tool_data["tool_type"]),
name=tool_data["name"],
description=tool_data["description"],
fn=fn,
async_fn=async_fn,
fn_schema=query_args_model # Re-assign the recreated dynamic model
)
tools.append(tool)

agent = cls(
tools=tools,
agent_type=agent_type,
topic=data["topic"],
custom_instructions=data["custom_instructions"],
verbose=data["verbose"],
)
memory = dill.loads(data["memory"].encode('latin-1')) if data.get("memory") else None
if memory:
agent.agent.memory = memory
return agent
29 changes: 25 additions & 4 deletions vectara_agentic/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,33 @@ def from_defaults(
tool = FunctionTool.from_defaults(fn, name, description, return_direct, fn_schema, async_fn, tool_metadata)
vectara_tool = cls(
tool_type=tool_type,
fn=tool.fn,
metadata=tool.metadata,
async_fn=tool.async_fn
fn=tool._fn,
metadata=tool._metadata,
async_fn=tool._async_fn
)
return vectara_tool

def __eq__(self, other):
if self.tool_type != other.tool_type:
return False

# Check if fn_schema is an instance of a BaseModel or a class itself (metaclass)
self_schema_dict = self.metadata.fn_schema.model_fields
other_schema_dict = other.metadata.fn_schema.model_fields
is_equal = True
for key in self_schema_dict.keys():
if key not in other_schema_dict:
is_equal = False
print("Not Equal 1")
break
if (self_schema_dict[key].annotation != other_schema_dict[key].annotation or
self_schema_dict[key].description != other_schema_dict[key].description or
self_schema_dict[key].is_required() != other_schema_dict[key].is_required()):
is_equal = False
print("Not Equal 2")
break
return is_equal


class VectaraToolFactory:
"""
Expand Down Expand Up @@ -255,7 +276,7 @@ def rag_function(*args, **kwargs) -> ToolOutput:
)
return out

fields = tool_args_schema.__fields__
fields = tool_args_schema.model_fields
params = [
inspect.Parameter(
name=field_name,
Expand Down