Skip to content

Commit

Permalink
feat: Task and TaskFeedback actions
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejmajek committed Aug 19, 2024
1 parent e4d5f63 commit 7d0dfb3
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 51 deletions.
81 changes: 81 additions & 0 deletions src/rai_hmi/rai_hmi/action_handler_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from abc import abstractmethod

from rclpy.action import ActionClient, ActionServer
from rclpy.node import Node

from rai_hmi.task import Task
from rai_interfaces.action import Task as TaskAction
from rai_interfaces.action import TaskFeedback


class TaskActionMixin(Node):
    """
    Mixin class to handle the Task action client and TaskFeedback action server.

    Provides methods to:
    - Send a task to the action server.
    - Handle feedback from the action server.
    - Handle task result responses.
    - Implement an action server for providing task feedback.

    Abstract Methods:
        handle_task_feedback_request: Must be implemented by subclasses to
            process task feedback requests.

    NOTE: `initialize_task_action_client_and_server` must be called (after
    Node initialization) before `add_task_to_queue` is used, since the
    action client is created there, not in `__init__`.
    """

    def initialize_task_action_client_and_server(self):
        """Initialize the action client ("perform_task") and the feedback
        action server ("provide_task_feedback")."""
        self.task_action_client = ActionClient(self, TaskAction, "perform_task")
        self.task_feedback_action_server = ActionServer(
            self, TaskFeedback, "provide_task_feedback", self.handle_task_feedback
        )

    def add_task_to_queue(self, task: Task):
        """Send a task goal to the action server to be handled by the rai node.

        Logs an error and returns without raising if the server is not
        available within 10 seconds.
        """
        if not self.task_action_client.wait_for_server(timeout_sec=10.0):
            self.get_logger().error("Task action server not available!")
            return

        goal_msg = TaskAction.Goal()
        goal_msg.task = task.name
        goal_msg.description = task.description
        goal_msg.priority = task.priority

        self.get_logger().info(f"Sending task to action server: {goal_msg.task}")
        self._send_goal_future = self.task_action_client.send_goal_async(
            goal_msg, feedback_callback=self.task_feedback_callback
        )
        self._send_goal_future.add_done_callback(self.task_goal_response_callback)

    def task_goal_response_callback(self, future):
        """Handle the server's accept/reject response for a sent goal.

        On acceptance, chains `task_result_callback` onto the result future.
        """
        goal_handle = future.result()
        if not goal_handle.accepted:
            self.get_logger().error("Task goal rejected by action server.")
            return

        self.get_logger().info("Task goal accepted by action server.")
        self._get_result_future = goal_handle.get_result_async()
        self._get_result_future.add_done_callback(self.task_result_callback)

    def task_feedback_callback(self, feedback_msg):
        """Log feedback received from the action server."""
        self.get_logger().info(f"Task feedback received: {feedback_msg.feedback}")

    def task_result_callback(self, future):
        """Log the final result of a task.

        Uses getattr for `result_message` because the interface's Result
        section may define only `success`; a missing field must not raise
        AttributeError inside an executor callback.
        """
        result = future.result().result
        result_message = getattr(result, "result_message", "")
        if result.success:
            self.get_logger().info(f"Task completed successfully: {result_message}")
        else:
            self.get_logger().error(f"Task failed: {result_message}")

    @abstractmethod
    def handle_task_feedback_request(self, goal_handle):
        """Abstract method to handle a TaskFeedback action request.

        NOTE: `abstractmethod` is not enforced here (no ABCMeta via Node);
        subclasses must still override this for the feedback server to work.
        """

    def handle_task_feedback(self, goal_handle):
        """ActionServer execute callback; delegates to the subclass hook."""
        return self.handle_task_feedback_request(goal_handle)
33 changes: 2 additions & 31 deletions src/rai_hmi/rai_hmi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
#

from abc import abstractmethod
from enum import Enum
from typing import List, Optional, Tuple, cast

Expand All @@ -23,21 +22,19 @@
from langchain_core.documents import Document
from langchain_core.tools import BaseTool
from langchain_openai import OpenAIEmbeddings
from rclpy.node import Node
from std_msgs.msg import String
from std_srvs.srv import Trigger

from rai_hmi.task import Task
from rai_hmi.action_handler_mixin import TaskActionMixin
from rai_hmi.tools import QueryDatabaseTool, QueueTaskTool
from rai_interfaces.srv import Feedback


class HMIStatus(Enum):
WAITING = "waiting"
PROCESSING = "processing"


class BaseHMINode(Node):
class BaseHMINode(TaskActionMixin):
"""
Base class for Human-Machine Interface (HMI) nodes in a robotic system.
Expand Down Expand Up @@ -90,10 +87,6 @@ def __init__(self, node_name: str):
Trigger, "rai_whoami_identity_service"
)

self.feedback_service = self.create_service(
Feedback, "feedback_request", self.feedback_request_callback
)

self.agent = None
# order of the initialization is important
self.system_prompt = self._initialize_system_prompt()
Expand Down Expand Up @@ -122,22 +115,6 @@ def query_faiss_index_with_scores(
output = self.faiss_index.similarity_search_with_score(query, k)
return output

@abstractmethod
def handle_feedback_request(self, feedback_query: str) -> str:
"""Abstract method to handle feedback requests."""

def feedback_request_callback(
self, request: Feedback.Request, response: Feedback.Response
):
"""Callback method for the feedback service."""
feedback_query = request.query
self.get_logger().info(f"Received feedback request: {feedback_query}")

feedback_response = self.handle_feedback_request(feedback_query)

response.response = feedback_response
return response

def _initialize_system_prompt(self):
while not self.constitution_service.wait_for_service(timeout_sec=1.0):
self.get_logger().info(
Expand Down Expand Up @@ -194,9 +171,3 @@ def _load_documentation(self) -> Optional[FAISS]:
)
return None
return faiss_index

def add_task_to_queue(self, task: Task):
"""Publishes a task to be handled by the rai node."""
msg = String()
msg.data = task.json()
self.task_addition_request_publisher.publish(msg)
24 changes: 4 additions & 20 deletions src/rai_hmi/rai_hmi/voice_hmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import rclpy
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import ChatPromptTemplate
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from rclpy.callback_groups import ReentrantCallbackGroup
from std_msgs.msg import String
Expand Down Expand Up @@ -76,25 +76,9 @@ def handle_human_message(self, msg: String):
self.hmi_publisher.publish(String(data=output))
self.processing = False

def handle_feedback_request(self, feedback_query: str) -> str:
self.processing = True

# handle feedback request
feedback_prompt = (
"The task executioner is asking for feedback on the following:"
f"```\n{feedback_query}\n```"
"Please provide needed information based on the following chat history:"
)
local_history: List[BaseMessage] = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=feedback_prompt),
]
local_history.extend(self.history)
response = self.agent.invoke({"user_input": "", "chat_history": local_history})
output = response["output"]

self.processing = False
return output
# TODO: Implement
def handle_task_feedback_request(self, goal_handle):
    """Handle a TaskFeedback action request.

    Stub override of the TaskActionMixin abstract hook; currently does
    nothing and implicitly returns None — TODO: implement.
    """
    pass


def main(args=None):
Expand Down
2 changes: 2 additions & 0 deletions src/rai_interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"msg/RAIDetectionArray.msg"
"srv/RAIGroundingDino.srv"
"srv/Feedback.srv"
"action/Task.action"
"action/TaskFeedback.action"
DEPENDENCIES std_msgs vision_msgs sensor_msgs
)

Expand Down
12 changes: 12 additions & 0 deletions src/rai_interfaces/action/Task.action
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Goal
string task
string description
string priority

---
# Result
bool success
# Human-readable completion/failure detail; consumed as
# `result.result_message` by TaskActionMixin.task_result_callback.
string result_message

---
# Feedback
string current_status
13 changes: 13 additions & 0 deletions src/rai_interfaces/action/TaskFeedback.action
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Goal
string issue_description
string current_task
string additional_info

---
# Result
# NOTE(review): field name looks like a typo for "information"; renaming
# would change the generated message API, so it is left as-is — confirm.
string informations
bool success

---
# Feedback
string feedback

0 comments on commit 7d0dfb3

Please sign in to comment.