Skip to content

Commit

Permalink
Merge pull request #3476 from SailorJoe6/Issue_#1358_Human_Input_soln_2
Browse files Browse the repository at this point in the history
Enable human interaction in AutoGenStudio - Solution 2
  • Loading branch information
victordibia committed Sep 5, 2024
2 parents e1bc1e0 + 330262b commit bb11979
Show file tree
Hide file tree
Showing 9 changed files with 761 additions and 144 deletions.
182 changes: 91 additions & 91 deletions samples/apps/autogen-studio/autogenstudio/chatmanager.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
import asyncio
import os
from datetime import datetime
from queue import Queue
from typing import Any, Dict, List, Optional, Tuple, Union

import websockets
from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger

from .datamodel import Message
from .workflowmanager import WorkflowManager

from .websocket_connection_manager import WebSocketConnectionManager

class AutoGenChatManager:
"""
This class handles the automated generation and management of chat interactions
using an automated workflow configuration and message queue.
"""

def __init__(self, message_queue: Queue) -> None:
def __init__(self,
message_queue: Queue,
websocket_manager: WebSocketConnectionManager = None,
human_input_timeout: int = 180) -> None:
"""
Initializes the AutoGenChatManager with a message queue.
:param message_queue: A queue to use for sending messages asynchronously.
"""
self.message_queue = message_queue
self.websocket_manager = websocket_manager
self.a_human_input_timeout = human_input_timeout

def send(self, message: str) -> None:
def send(self, message: dict) -> None:
"""
Sends a message by putting it into the message queue.
Expand All @@ -34,6 +36,46 @@ def send(self, message: str) -> None:
if self.message_queue is not None:
self.message_queue.put_nowait(message)

async def a_send(self, message: dict) -> None:
"""
Asynchronously sends a message via the WebSocketManager class
:param message: The message string to be sent.
"""
for connection, socket_client_id in self.websocket_manager.active_connections:
if message["connection_id"] == socket_client_id:
logger.info(
f"Sending message to connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
)
await self.websocket_manager.send_message(message, connection)
else:
logger.info(
f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
)

async def a_prompt_for_input(self, prompt: dict, timeout: int = 60) -> str:
"""
Sends the user a prompt and waits for a response asynchronously via the WebSocketManager class
:param message: The message string to be sent.
"""

for connection, socket_client_id in self.websocket_manager.active_connections:
if prompt["connection_id"] == socket_client_id:
logger.info(
f"Sending message to connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
)
try:
result = await self.websocket_manager.get_input(prompt, connection, timeout)
return result
except Exception as e:
traceback.print_exc()
return f"Error: {e}\nTERMINATE"
else:
logger.info(
f"Skipping message for connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
)

def chat(
self,
message: Message,
Expand Down Expand Up @@ -72,6 +114,7 @@ def chat(
history=history,
work_dir=work_dir,
send_message_function=self.send,
a_send_message_function=self.a_send,
connection_id=connection_id,
)

Expand All @@ -82,96 +125,53 @@ def chat(
result_message.session_id = message.session_id
return result_message


class WebSocketConnectionManager:
"""
Manages WebSocket connections including sending, broadcasting, and managing the lifecycle of connections.
"""

def __init__(
async def a_chat(
self,
active_connections: List[Tuple[WebSocket, str]] = None,
active_connections_lock: asyncio.Lock = None,
) -> None:
"""
Initializes WebSocketConnectionManager with an optional list of active WebSocket connections.
:param active_connections: A list of tuples, each containing a WebSocket object and its corresponding client_id.
"""
if active_connections is None:
active_connections = []
self.active_connections_lock = active_connections_lock
self.active_connections: List[Tuple[WebSocket, str]] = active_connections

async def connect(self, websocket: WebSocket, client_id: str) -> None:
message: Message,
history: List[Dict[str, Any]],
workflow: Any = None,
connection_id: Optional[str] = None,
user_dir: Optional[str] = None,
**kwargs,
) -> Message:
"""
Accepts a new WebSocket connection and appends it to the active connections list.
Processes an incoming message according to the agent's workflow configuration
and generates a response.
:param websocket: The WebSocket instance representing a client connection.
:param client_id: A string representing the unique identifier of the client.
:param message: An instance of `Message` representing an incoming message.
:param history: A list of dictionaries, each representing a past interaction.
:param flow_config: An instance of `AgentWorkFlowConfig`. If None, defaults to a standard configuration.
:param connection_id: An optional connection identifier.
:param kwargs: Additional keyword arguments.
:return: An instance of `Message` representing a response.
"""
await websocket.accept()
async with self.active_connections_lock:
self.active_connections.append((websocket, client_id))
print(f"New Connection: {client_id}, Total: {len(self.active_connections)}")

async def disconnect(self, websocket: WebSocket) -> None:
"""
Disconnects and removes a WebSocket connection from the active connections list.
# create a working director for workflow based on user_dir/session_id/time_hash
work_dir = os.path.join(
user_dir,
str(message.session_id),
datetime.now().strftime("%Y%m%d_%H-%M-%S"),
)
os.makedirs(work_dir, exist_ok=True)

:param websocket: The WebSocket instance to remove.
"""
async with self.active_connections_lock:
try:
self.active_connections = [conn for conn in self.active_connections if conn[0] != websocket]
print(f"Connection Closed. Total: {len(self.active_connections)}")
except ValueError:
print("Error: WebSocket connection not found")

async def disconnect_all(self) -> None:
"""
Disconnects all active WebSocket connections.
"""
for connection, _ in self.active_connections[:]:
await self.disconnect(connection)
# if no flow config is provided, use the default
if workflow is None:
raise ValueError("Workflow must be specified")

async def send_message(self, message: Union[Dict, str], websocket: WebSocket) -> None:
"""
Sends a JSON message to a single WebSocket connection.
workflow_manager = WorkflowManager(
workflow=workflow,
history=history,
work_dir=work_dir,
send_message_function=self.send,
a_send_message_function=self.a_send,
a_human_input_function=self.a_prompt_for_input,
a_human_input_timeout=self.a_human_input_timeout,
connection_id=connection_id,
)

:param message: A JSON serializable dictionary containing the message to send.
:param websocket: The WebSocket instance through which to send the message.
"""
try:
async with self.active_connections_lock:
await websocket.send_json(message)
except WebSocketDisconnect:
print("Error: Tried to send a message to a closed WebSocket")
await self.disconnect(websocket)
except websockets.exceptions.ConnectionClosedOK:
print("Error: WebSocket connection closed normally")
await self.disconnect(websocket)
except Exception as e:
print(f"Error in sending message: {str(e)}", message)
await self.disconnect(websocket)

async def broadcast(self, message: Dict) -> None:
"""
Broadcasts a JSON message to all active WebSocket connections.
message_text = message.content.strip()
result_message: Message = await workflow_manager.a_run(message=f"{message_text}", clear_history=False, history=history)

:param message: A JSON serializable dictionary containing the message to broadcast.
"""
# Create a message dictionary with the desired format
message_dict = {"message": message}

for connection, _ in self.active_connections[:]:
try:
if connection.client_state == websockets.protocol.State.OPEN:
# Call send_message method with the message dictionary and current WebSocket connection
await self.send_message(message_dict, connection)
else:
print("Error: WebSocket connection is closed")
await self.disconnect(connection)
except (WebSocketDisconnect, websockets.exceptions.ConnectionClosedOK) as e:
print(f"Error: WebSocket disconnected or closed({str(e)})")
await self.disconnect(connection)
result_message.user_id = message.user_id
result_message.session_id = message.session_id
return result_message
10 changes: 7 additions & 3 deletions samples/apps/autogen-studio/autogenstudio/web/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from loguru import logger
from openai import OpenAIError

from ..chatmanager import AutoGenChatManager, WebSocketConnectionManager
from ..chatmanager import AutoGenChatManager
from ..websocket_connection_manager import WebSocketConnectionManager
from ..database import workflow_from_id
from ..database.dbmanager import DBManager
from ..datamodel import Agent, Message, Model, Response, Session, Skill, Workflow
Expand Down Expand Up @@ -64,11 +65,14 @@ def message_handler():
database_engine_uri = folders["database_engine_uri"]
dbmanager = DBManager(engine_uri=database_engine_uri)

HUMAN_INPUT_TIMEOUT_SECONDS = 180

@asynccontextmanager
async def lifespan(app: FastAPI):
print("***** App started *****")
managers["chat"] = AutoGenChatManager(message_queue=message_queue)
managers["chat"] = AutoGenChatManager(message_queue=message_queue,
websocket_manager=websocket_manager,
human_input_timeout=HUMAN_INPUT_TIMEOUT_SECONDS)
dbmanager.create_db_and_tables()

yield
Expand Down Expand Up @@ -449,7 +453,7 @@ async def run_session_workflow(message: Message, session_id: int, workflow_id: i
user_dir = os.path.join(folders["files_static_root"], "user", md5_hash(message.user_id))
os.makedirs(user_dir, exist_ok=True)
workflow = workflow_from_id(workflow_id, dbmanager=dbmanager)
agent_response: Message = managers["chat"].chat(
agent_response: Message = await managers["chat"].a_chat(
message=message,
history=user_message_history,
user_dir=user_dir,
Expand Down
Loading

0 comments on commit bb11979

Please sign in to comment.