Skip to content

Commit

Permalink
feat(rai_node): execute mission as ros2 action (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
boczekbartek authored Sep 17, 2024
1 parent b737575 commit f84f3ec
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 162 deletions.
25 changes: 19 additions & 6 deletions examples/rosbot-xl-generic-node-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@
import rclpy.qos
import rclpy.subscription
import rclpy.task
from langchain.tools.render import render_text_description_and_args
from langchain_openai import ChatOpenAI

from rai.agents.state_based import create_state_based_agent
from rai.node import RaiNode, describe_ros_image, wait_for_2s
from rai.node import RaiNode, describe_ros_image
from rai.tools.ros.native import (
GetCameraImage,
GetMsgFromTopic,
Ros2ShowMsgInterfaceTool,
)
from rai.tools.ros.native_actions import Ros2RunActionSync
from rai.tools.ros.tools import GetOccupancyGridTool
from rai.tools.time import WaitForSecondsTool


def main():
Expand Down Expand Up @@ -68,10 +70,9 @@ def main():
"/wait",
]

SYSTEM_PROMPT = "You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. "
"Do not make assumptions about the environment you are currently in. "
"Use the tooling provided to gather information about the environment."
"You can use ros2 topics, services and actions to operate."
# TODO(boczekbartek): refactor system prompt

SYSTEM_PROMPT = ""

node = RaiNode(
llm=ChatOpenAI(
Expand All @@ -84,7 +85,7 @@ def main():
)

tools = [
wait_for_2s,
WaitForSecondsTool(),
GetMsgFromTopic(node=node),
Ros2RunActionSync(node=node),
GetCameraImage(node=node),
Expand All @@ -94,6 +95,18 @@ def main():

state_retriever = node.get_robot_state

SYSTEM_PROMPT = f"""You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests.
Do not make assumptions about the environment you are currently in.
Use the tooling provided to gather information about the environment:
{render_text_description_and_args(tools)}
You can use ros2 topics, services and actions to operate. """

node.get_logger().info(f"{SYSTEM_PROMPT=}")

node.system_prompt = node.initialize_system_prompt(SYSTEM_PROMPT)

app = create_state_based_agent(
llm=llm,
tools=tools,
Expand Down
30 changes: 25 additions & 5 deletions src/rai/rai/agents/state_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pickle
import time
from functools import partial
from pathlib import Path
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -76,20 +77,39 @@ class Report(BaseModel):
steps: List[str] = Field(
..., title="Steps", description="The steps taken to solve the problem"
)
success: bool = Field(
..., title="Success", description="Whether the problem was solved"
)
response_to_user: str = Field(
..., title="Response", description="The response to the user"
)


def get_stored_artifacts(tool_call_id: str) -> List[Any]:
with open("artifact_database.pkl", "rb") as file:
artifact_database = pickle.load(file)
def get_stored_artifacts(
tool_call_id: str, db_path="artifact_database.pkl"
) -> List[Any]:
# TODO(boczekbartek): refactor
db_path = Path(db_path)
if not db_path.is_file():
return []

with db_path.open("rb") as db:
artifact_database = pickle.load(db)
if tool_call_id in artifact_database:
return artifact_database[tool_call_id]

return []


def store_artifacts(tool_call_id: str, artifacts: List[Any]):
def store_artifacts(
tool_call_id: str, artifacts: List[Any], db_path="artifact_database.pkl"
):
# TODO(boczekbartek): refactor
db_path = Path(db_path)
if not db_path.is_file():
artifact_database = {}
with open("artifact_database.pkl", "wb") as file:
pickle.dump(artifact_database, file)
with open("artifact_database.pkl", "rb") as file:
artifact_database = pickle.load(file)
if tool_call_id not in artifact_database:
Expand Down Expand Up @@ -283,7 +303,7 @@ def retriever_wrapper(
info = str_output(retrieved_info)
state["messages"].append(
HumanMultimodalMessage(
content="Retrieved state: {}".format(info), images=images, audios=audios
content=f"Retrieved state: {info}", images=images, audios=audios
)
)
return state
Expand Down
Loading

0 comments on commit f84f3ec

Please sign in to comment.