diff --git a/src/rai_hmi/rai_hmi/text_hmi.py b/src/rai_hmi/rai_hmi/text_hmi.py index bad973f8..0c799184 100644 --- a/src/rai_hmi/rai_hmi/text_hmi.py +++ b/src/rai_hmi/rai_hmi/text_hmi.py @@ -24,6 +24,7 @@ AIMessage, BaseMessage, HumanMessage, + SystemMessage, ToolCall, ToolMessage, ) @@ -84,9 +85,9 @@ def initialize_agent(_node: BaseHMINode): return agent -def initialize_session_memory(): +def initialize_session_memory(system_prompt: str = ""): if "memory" not in st.session_state: - st.session_state.memory = [] + st.session_state.memory = [SystemMessage(content=system_prompt)] if "tool_calls" not in st.session_state: st.session_state.tool_calls = {} @@ -134,6 +135,8 @@ def handle_history_message(message: BaseMessage): base_64_image = artifact["images"][0] image = decode_base64_into_image(base_64_image) st.image(image) + elif isinstance(message, SystemMessage): + return # we do not handle system messages else: raise ValueError("Unknown message type") @@ -142,7 +145,7 @@ def handle_history_message(message: BaseMessage): with st.spinner("Initializing ROS 2 node..."): node = initialize_ros_node(robot_description_package) agent = initialize_agent(_node=node) - initialize_session_memory() + initialize_session_memory(system_prompt=node.system_prompt) status = { "robot_database": node.faiss_index is not None,