Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: HMI Node #143

Merged
merged 9 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,4 @@ logs/
!examples/imgs/*.md

src/examples/*-demo
artifact_database.pkl
25 changes: 19 additions & 6 deletions examples/rosbot-xl-generic-node-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@
import rclpy.qos
import rclpy.subscription
import rclpy.task
from langchain.tools.render import render_text_description_and_args
from langchain_openai import ChatOpenAI

from rai.agents.state_based import create_state_based_agent
from rai.node import RaiNode, describe_ros_image, wait_for_2s
from rai.node import RaiNode, describe_ros_image
from rai.tools.ros.native import (
GetCameraImage,
GetMsgFromTopic,
Ros2ShowMsgInterfaceTool,
)
from rai.tools.ros.native_actions import Ros2RunActionSync
from rai.tools.ros.tools import GetOccupancyGridTool
from rai.tools.time import WaitForSecondsTool


def main():
Expand Down Expand Up @@ -68,10 +70,9 @@ def main():
"/wait",
]

SYSTEM_PROMPT = "You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. "
"Do not make assumptions about the environment you are currently in. "
"Use the tooling provided to gather information about the environment."
"You can use ros2 topics, services and actions to operate."
# TODO(boczekbartek): refactor system prompt

SYSTEM_PROMPT = ""

node = RaiNode(
llm=ChatOpenAI(
Expand All @@ -84,7 +85,7 @@ def main():
)

tools = [
wait_for_2s,
WaitForSecondsTool(),
GetMsgFromTopic(node=node),
Ros2RunActionSync(node=node),
GetCameraImage(node=node),
Expand All @@ -94,6 +95,18 @@ def main():

state_retriever = node.get_robot_state

SYSTEM_PROMPT = f"""You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests.
Do not make assumptions about the environment you are currently in.
Use the tooling provided to gather information about the environment:

{render_text_description_and_args(tools)}

You can use ros2 topics, services and actions to operate. """

node.get_logger().info(f"{SYSTEM_PROMPT=}")

node.system_prompt = node.initialize_system_prompt(SYSTEM_PROMPT)

app = create_state_based_agent(
llm=llm,
tools=tools,
Expand Down
12 changes: 6 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ langchain-aws = "^0.1.7"
langchain-openai = "^0.1.8"
langchain-community = "^0.2.4"
transforms3d = "^0.4.1"
langgraph = "^0.0.66"
langgraph = "^0.1.0"
tabulate = "^0.9.0"
lark = "^1.1.9"
langfuse = "^2.36.1"
Expand Down
49 changes: 47 additions & 2 deletions src/rai/rai/agents/state_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
#

import logging
import pickle
import time
from functools import partial
from pathlib import Path
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -75,11 +77,49 @@ class Report(BaseModel):
steps: List[str] = Field(
..., title="Steps", description="The steps taken to solve the problem"
)
success: bool = Field(
..., title="Success", description="Whether the problem was solved"
)
response_to_user: str = Field(
..., title="Response", description="The response to the user"
)


def get_stored_artifacts(
tool_call_id: str, db_path="artifact_database.pkl"
) -> List[Any]:
# TODO(boczekbartek): refactor
db_path = Path(db_path)
if not db_path.is_file():
return []

with db_path.open("rb") as db:
artifact_database = pickle.load(db)
if tool_call_id in artifact_database:
return artifact_database[tool_call_id]

return []


def store_artifacts(
tool_call_id: str, artifacts: List[Any], db_path="artifact_database.pkl"
):
# TODO(boczekbartek): refactor
db_path = Path(db_path)
if not db_path.is_file():
artifact_database = {}
with open("artifact_database.pkl", "wb") as file:
pickle.dump(artifact_database, file)
with open("artifact_database.pkl", "rb") as file:
artifact_database = pickle.load(file)
if tool_call_id not in artifact_database:
artifact_database[tool_call_id] = artifacts
else:
artifact_database[tool_call_id].extend(artifacts)
with open("artifact_database.pkl", "wb") as file:
pickle.dump(artifact_database, file)


class ToolRunner(RunnableCallable):
def __init__(
self,
Expand Down Expand Up @@ -126,13 +166,15 @@ def run_one(call: ToolCall):
content=f"Failed to run tool. Error: {e}",
name=call["name"],
tool_call_id=call["id"],
status="error",
)
except Exception as e:
self.logger.info(f'Error in "{call["name"]}", error: {e}')
output = ToolMessage(
content=f"Failed to run tool. Error: {e}",
name=call["name"],
tool_call_id=call["id"],
status="error",
)

if output.artifact is not None:
Expand All @@ -143,6 +185,7 @@ def run_one(call: ToolCall):
)

artifact = cast(MultimodalArtifact, artifact)
store_artifacts(output.tool_call_id, [artifact])

if artifact is not None: # multimodal case
return ToolMultimodalMessage(
Expand All @@ -160,7 +203,9 @@ def run_one(call: ToolCall):
outputs: List[Any] = []
for raw_output in raw_outputs:
if isinstance(raw_output, ToolMultimodalMessage):
outputs.extend(raw_output.postprocess())
outputs.extend(
raw_output.postprocess()
) # openai please allow tool messages with images!
else:
outputs.append(raw_output)

Expand Down Expand Up @@ -258,7 +303,7 @@ def retriever_wrapper(
info = str_output(retrieved_info)
state["messages"].append(
HumanMultimodalMessage(
content="Retrieved state: {}".format(info), images=images, audios=audios
content=f"Retrieved state: {info}", images=images, audios=audios
)
)
return state
Expand Down
5 changes: 5 additions & 0 deletions src/rai/rai/messages/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def __init__(
_content.extend(_image_content)
self.content = _content

@property
def text(self) -> str:
return self.content[0]["text"]


class HumanMultimodalMessage(HumanMessage, MultimodalMessage):
def __repr_args__(self) -> Any:
Expand Down Expand Up @@ -104,6 +108,7 @@ def _postprocess_openai(self):
human_message = HumanMultimodalMessage(
content=f"Image returned by a tool call {self.tool_call_id}",
images=self.images,
tool_call_id=self.tool_call_id,
)
# at this point self.content is a list of dicts
# we need to extract the text from each dict
Expand Down
Loading