From 7d0dfb3115e58ccbe78ef0b5c6044d7689eb04c9 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Tue, 20 Aug 2024 00:59:50 +0200 Subject: [PATCH] feat: Task and TaskFeedback actions --- src/rai_hmi/rai_hmi/action_handler_mixin.py | 81 +++++++++++++++++++ src/rai_hmi/rai_hmi/base.py | 33 +------- src/rai_hmi/rai_hmi/voice_hmi.py | 24 +----- src/rai_interfaces/CMakeLists.txt | 2 + src/rai_interfaces/action/Task.action | 12 +++ src/rai_interfaces/action/TaskFeedback.action | 13 +++ 6 files changed, 114 insertions(+), 51 deletions(-) create mode 100644 src/rai_hmi/rai_hmi/action_handler_mixin.py create mode 100644 src/rai_interfaces/action/Task.action create mode 100644 src/rai_interfaces/action/TaskFeedback.action diff --git a/src/rai_hmi/rai_hmi/action_handler_mixin.py b/src/rai_hmi/rai_hmi/action_handler_mixin.py new file mode 100644 index 00000000..cced2814 --- /dev/null +++ b/src/rai_hmi/rai_hmi/action_handler_mixin.py @@ -0,0 +1,81 @@ +from abc import abstractmethod + +from rclpy.action import ActionClient, ActionServer +from rclpy.node import Node + +from rai_hmi.task import Task +from rai_interfaces.action import Task as TaskAction +from rai_interfaces.action import TaskFeedback + + +class TaskActionMixin(Node): + """ + Mixin class to handle Task action client and TaskFeedback action server. + + Provides methods to: + - Send a task to the action server. + - Handle feedback from the action server. + - Handle task result responses. + - Implement an action server for providing task feedback. + + Abstract Methods: + handle_task_feedback_request: Must be implemented by subclasses to process task feedback requests. + """ + + def initialize_task_action_client_and_server(self): + """Initialize the action client and server for task handling.""" + self.task_action_client = ActionClient(self, TaskAction, "perform_task") + self.task_feedback_action_server = ActionServer( + self, TaskFeedback, "provide_task_feedback", self.handle_task_feedback + ) + + def add_task_to_queue(self, task: Task): + """Sends a task to the action server to be handled by the rai node.""" + + if not self.task_action_client.wait_for_server(timeout_sec=10.0): + self.get_logger().error("Task action server not available!") + return + + goal_msg = TaskAction.Goal() + goal_msg.task = task.name + goal_msg.description = task.description + goal_msg.priority = task.priority + + self.get_logger().info(f"Sending task to action server: {goal_msg.task}") + self._send_goal_future = self.task_action_client.send_goal_async( + goal_msg, feedback_callback=self.task_feedback_callback + ) + self._send_goal_future.add_done_callback(self.task_goal_response_callback) + + def task_goal_response_callback(self, future): + """Callback for handling the response from the action server when the goal is sent.""" + goal_handle = future.result() + if not goal_handle.accepted: + self.get_logger().error("Task goal rejected by action server.") + return + + self.get_logger().info("Task goal accepted by action server.") + self._get_result_future = goal_handle.get_result_async() + self._get_result_future.add_done_callback(self.task_result_callback) + + def task_feedback_callback(self, feedback_msg): + """Callback for receiving feedback from the action server.""" + self.get_logger().info(f"Task feedback received: {feedback_msg.feedback}") + + def task_result_callback(self, future): + """Callback for handling the result from the action server.""" + result = future.result().result + if result.success: + self.get_logger().info( + f"Task completed successfully: {result.result_message}" + ) + else: + self.get_logger().error(f"Task failed: {result.result_message}") + + @abstractmethod + def handle_task_feedback_request(self, goal_handle): + """Abstract method to handle TaskFeedback action request.""" + + def handle_task_feedback(self, goal_handle): + """Handles the TaskFeedback action request.""" + return self.handle_task_feedback_request(goal_handle) diff --git a/src/rai_hmi/rai_hmi/base.py b/src/rai_hmi/rai_hmi/base.py index 789072b0..4f545257 100644 --- a/src/rai_hmi/rai_hmi/base.py +++ b/src/rai_hmi/rai_hmi/base.py @@ -13,7 +13,6 @@ # limitations under the License. # -from abc import abstractmethod from enum import Enum from typing import List, Optional, Tuple, cast @@ -23,13 +22,11 @@ from langchain_core.documents import Document from langchain_core.tools import BaseTool from langchain_openai import OpenAIEmbeddings -from rclpy.node import Node from std_msgs.msg import String from std_srvs.srv import Trigger -from rai_hmi.task import Task +from rai_hmi.action_handler_mixin import TaskActionMixin from rai_hmi.tools import QueryDatabaseTool, QueueTaskTool -from rai_interfaces.srv import Feedback class HMIStatus(Enum): @@ -37,7 +34,7 @@ class HMIStatus(Enum): PROCESSING = "processing" -class BaseHMINode(Node): +class BaseHMINode(TaskActionMixin): """ Base class for Human-Machine Interface (HMI) nodes in a robotic system. @@ -90,10 +87,6 @@ def __init__(self, node_name: str): Trigger, "rai_whoami_identity_service" ) - self.feedback_service = self.create_service( - Feedback, "feedback_request", self.feedback_request_callback - ) - self.agent = None # order of the initialization is important self.system_prompt = self._initialize_system_prompt() @@ -122,22 +115,6 @@ def query_faiss_index_with_scores( output = self.faiss_index.similarity_search_with_score(query, k) return output - @abstractmethod - def handle_feedback_request(self, feedback_query: str) -> str: - """Abstract method to handle feedback requests.""" - - def feedback_request_callback( - self, request: Feedback.Request, response: Feedback.Response - ): - """Callback method for the feedback service.""" - feedback_query = request.query - self.get_logger().info(f"Received feedback request: {feedback_query}") - - feedback_response = self.handle_feedback_request(feedback_query) - - response.response = feedback_response - return response - def _initialize_system_prompt(self): while not self.constitution_service.wait_for_service(timeout_sec=1.0): self.get_logger().info( @@ -194,9 +171,3 @@ def _load_documentation(self) -> Optional[FAISS]: ) return None return faiss_index - - def add_task_to_queue(self, task: Task): - """Publishes a task to be handled by the rai node.""" - msg = String() - msg.data = task.json() - self.task_addition_request_publisher.publish(msg) diff --git a/src/rai_hmi/rai_hmi/voice_hmi.py b/src/rai_hmi/rai_hmi/voice_hmi.py index c5a6ba5c..9dd3df33 100644 --- a/src/rai_hmi/rai_hmi/voice_hmi.py +++ b/src/rai_hmi/rai_hmi/voice_hmi.py @@ -18,7 +18,7 @@ import rclpy from langchain.agents import AgentExecutor, create_tool_calling_agent from langchain.prompts import ChatPromptTemplate -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_openai import ChatOpenAI from rclpy.callback_groups import ReentrantCallbackGroup from std_msgs.msg import String @@ -76,25 +76,9 @@ def handle_human_message(self, msg: String): self.hmi_publisher.publish(String(data=output)) self.processing = False - def handle_feedback_request(self, feedback_query: str) -> str: - self.processing = True - - # handle feedback request - feedback_prompt = ( - "The task executioner is asking for feedback on the following:" - f"```\n{feedback_query}\n```" - "Please provide needed information based on the following chat history:" - ) - local_history: List[BaseMessage] = [ - SystemMessage(content=self.system_prompt), - HumanMessage(content=feedback_prompt), - ] - local_history.extend(self.history) - response = self.agent.invoke({"user_input": "", "chat_history": local_history}) - output = response["output"] - - self.processing = False - return output + # TODO: Implement + def handle_task_feedback_request(self, goal_handle): + pass def main(args=None): diff --git a/src/rai_interfaces/CMakeLists.txt b/src/rai_interfaces/CMakeLists.txt index f3565825..29438c14 100644 --- a/src/rai_interfaces/CMakeLists.txt +++ b/src/rai_interfaces/CMakeLists.txt @@ -22,6 +22,8 @@ rosidl_generate_interfaces(${PROJECT_NAME} "msg/RAIDetectionArray.msg" "srv/RAIGroundingDino.srv" "srv/Feedback.srv" + "action/Task.action" + "action/TaskFeedback.action" DEPENDENCIES std_msgs vision_msgs sensor_msgs ) diff --git a/src/rai_interfaces/action/Task.action b/src/rai_interfaces/action/Task.action new file mode 100644 index 00000000..26e4b1cc --- /dev/null +++ b/src/rai_interfaces/action/Task.action @@ -0,0 +1,12 @@ +# Goal +string task +string description +string priority + +--- +# Result +bool success + +--- +# Feedback +string current_status diff --git a/src/rai_interfaces/action/TaskFeedback.action b/src/rai_interfaces/action/TaskFeedback.action new file mode 100644 index 00000000..055cc2fa --- /dev/null +++ b/src/rai_interfaces/action/TaskFeedback.action @@ -0,0 +1,13 @@ +# Goal +string issue_description +string current_task +string additional_info + +--- +# Result +string informations +bool success + +--- +# Feedback +string feedback