Skip to content

Commit

Permalink
feat: implement voice_hmi agents
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejmajek committed Aug 19, 2024
1 parent abdbc4b commit e4d5f63
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
60 changes: 55 additions & 5 deletions src/rai_hmi/rai_hmi/voice_hmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,22 @@
# limitations under the License.
#

from typing import List

import rclpy
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import ChatPromptTemplate
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from rclpy.callback_groups import ReentrantCallbackGroup
from std_msgs.msg import String

from rai_hmi.base import BaseHMINode


class VoiceHMINode(BaseHMINode):
def __init__(self, node_name: str, robot_description_package: str):
super().__init__(node_name, robot_description_package)
def __init__(self, node_name: str):
super().__init__(node_name)

self.callback_group = ReentrantCallbackGroup()
self.hmi_subscription = self.create_subscription(
Expand All @@ -36,20 +43,63 @@ def __init__(self, node_name: str, robot_description_package: str):
String, "to_human", 10, callback_group=self.callback_group
)

self.history: List[BaseMessage] = []
self.agent = self.initialize_agent()

self.get_logger().info("Voice HMI node initialized")

def initialize_agent(self):
    """Build the tool-calling agent executor used to answer requests.

    The chat prompt is laid out in this order: the node's system prompt,
    the prior chat history, the current human input and finally the agent
    scratchpad placeholder.

    Returns:
        An ``AgentExecutor`` wrapping a tool-calling agent created from a
        ``gpt-4o`` ``ChatOpenAI`` model, ``self.tools`` and the prompt above.
    """
    message_layout = [
        ("system", self.system_prompt),
        ("placeholder", "{chat_history}"),
        ("human", "{user_input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
    chat_prompt = ChatPromptTemplate.from_messages(message_layout)
    tool_agent = create_tool_calling_agent(
        llm=ChatOpenAI(model="gpt-4o"),
        tools=self.tools,
        prompt=chat_prompt,
    )
    # The executor needs the same tool list as the agent it drives.
    return AgentExecutor(agent=tool_agent, tools=self.tools)

def handle_human_message(self, msg: String):
    """Answer a message from the human and publish the agent's reply.

    The message text is passed to the agent together with the accumulated
    chat history. On success the exchange (human message, then AI reply)
    is appended to ``self.history`` and the reply is published on
    ``self.hmi_publisher``.

    Args:
        msg: Incoming message; its ``data`` field is used as the user input.

    Raises:
        Any exception raised by the agent or the publisher is propagated
        unchanged; ``self.processing`` is still reset in that case.
    """
    self.processing = True
    try:
        response = self.agent.invoke(
            {"user_input": msg.data, "chat_history": self.history}
        )
        output = response["output"]

        # Record the exchange only after the agent has answered, so a failed
        # call leaves the history untouched.
        self.history.append(HumanMessage(msg.data))
        self.history.append(AIMessage(output))

        self.hmi_publisher.publish(String(data=output))
    finally:
        # Always clear the busy flag; without this an exception above would
        # leave the node marked as processing forever.
        self.processing = False

def handle_feedback_request(self, feedback_query: str) -> str:
    """Answer a feedback query using the conversation held so far.

    A temporary history is assembled (system prompt, the feedback
    instruction, then the stored chat history) and sent to the agent with
    an empty user input. ``self.history`` itself is not modified.

    Args:
        feedback_query: The question to answer; it is embedded in a fenced
            block inside the instruction sent to the agent.

    Returns:
        The ``"output"`` field of the agent's response.

    Raises:
        Any exception raised by the agent is propagated unchanged;
        ``self.processing`` is still reset in that case.
    """
    self.processing = True
    try:
        # Each fragment ends with an explicit newline: adjacent string
        # literals are concatenated verbatim, so without them the fence
        # markers would be glued to the surrounding sentences.
        feedback_prompt = (
            "The task executioner is asking for feedback on the following:\n"
            f"```\n{feedback_query}\n```\n"
            "Please provide needed information based on the following chat history:"
        )
        # NOTE(review): the prompt template built in initialize_agent also
        # starts with the system prompt, so it appears twice in the request.
        # Kept as-is; confirm whether the duplication is intended.
        local_history: List[BaseMessage] = [
            SystemMessage(content=self.system_prompt),
            HumanMessage(content=feedback_prompt),
        ]
        local_history.extend(self.history)

        response = self.agent.invoke(
            {"user_input": "", "chat_history": local_history}
        )
        return response["output"]
    finally:
        # Always clear the busy flag, even when the agent call fails.
        self.processing = False


def main(args=None):
    """Entry point: start rclpy, spin the voice HMI node, then clean up.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    voice_hmi_node = VoiceHMINode("voice_hmi_node")
    try:
        # NOTE(review): the node registers a ReentrantCallbackGroup, but
        # rclpy.spin presumably runs callbacks on a single-threaded executor,
        # so they would not overlap - confirm whether a MultiThreadedExecutor
        # is wanted here.
        rclpy.spin(voice_hmi_node)
    finally:
        # Release the node and the rclpy context even when spin is
        # interrupted (e.g. Ctrl-C) or raises.
        voice_hmi_node.destroy_node()
        # The context may already have been shut down (e.g. by a signal
        # handler); only shut it down if it is still valid.
        if rclpy.ok():
            rclpy.shutdown()
1 change: 1 addition & 0 deletions src/rai_hmi/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
entry_points={
"console_scripts": [
"hmi_node = rai_hmi.hmi_node:main",
"voice_hmi_node = rai_hmi.voice_hmi:main",
],
},
)

0 comments on commit e4d5f63

Please sign in to comment.