From 013097384db17b7a517b46446bbf138e06ad021d Mon Sep 17 00:00:00 2001
From: Stephen von Takach
Date: Thu, 2 Nov 2023 16:42:23 +1100
Subject: [PATCH] feat(chat_gpt): compress chats once tasks are completed

ensures we have the tokens to continue making requests
---
 shard.lock                                    |   2 +-
 .../controllers/chat_gpt/chat_manager.cr      | 113 +++++++++++++++---
 2 files changed, 95 insertions(+), 20 deletions(-)

diff --git a/shard.lock b/shard.lock
index e49b7325..2a73ce7b 100644
--- a/shard.lock
+++ b/shard.lock
@@ -191,7 +191,7 @@ shards:
 
   openai:
     git: https://github.com/spider-gazelle/crystal-openai.git
-    version: 0.9.0+git.commit.e6bfaba7758f992d7cb81cad0109180d5be2d958
+    version: 0.9.0+git.commit.b6c669a09a57aabcd2d156e0dcc85f0a6255408d
 
   openapi-generator:
     git: https://github.com/place-labs/openapi-generator.git
diff --git a/src/placeos-rest-api/controllers/chat_gpt/chat_manager.cr b/src/placeos-rest-api/controllers/chat_gpt/chat_manager.cr
index a2d36698..079d26a2 100644
--- a/src/placeos-rest-api/controllers/chat_gpt/chat_manager.cr
+++ b/src/placeos-rest-api/controllers/chat_gpt/chat_manager.cr
@@ -66,8 +66,9 @@ module PlaceOS::Api
           ws_sockets[ws_id] = {ws, id, c, req, e}
           {ws, id, c, req, e}
         end
-        resp = openai_interaction(client, completion_req, executor, message, chat_id)
-        ws.send(resp.to_json)
+        openai_interaction(client, completion_req, executor, message, chat_id) do |resp|
+          ws.send(resp.to_json)
+        end
       end
     rescue error
       Log.warn(exception: error) { "failure processing chat message" }
@@ -77,7 +78,7 @@ module PlaceOS::Api
 
     private def setup(chat, chat_payload)
       client = build_client
-      executor = build_executor(chat, chat_payload)
+      executor = build_executor(chat)
       chat_completion = build_completion(build_prompt(chat, chat_payload), executor.functions)
 
       {client, executor, chat_completion}
@@ -103,25 +104,87 @@ module PlaceOS::Api
       )
     end
 
-    private def openai_interaction(client, request, executor, message, chat_id) : NamedTuple(chat_id: String, message: String?)
+    @total_tokens : Int32 = 0
+
+    private def openai_interaction(client, request, executor, message, chat_id, &) : Nil
       request.messages << OpenAI::ChatMessage.new(role: :user, content: message)
       save_history(chat_id, :user, message)
+
+      # track token usage
+      discardable_tokens = 0
+      tracking_total = 0
+      calculate_discard = false
+
       loop do
+        # ensure new request will fit here
+        # cleanup old messages, saving first system prompt and then removing messages beyond that until we're within the limit
+        # we could also restore messages once a task has been completed if there is space
+        # TODO::
+
+        # track token usage
         resp = client.chat_completion(request)
+        @total_tokens = resp.usage.total_tokens
+
+        if calculate_discard
+          discardable_tokens += resp.usage.prompt_tokens - tracking_total
+          calculate_discard = false
+        end
+        tracking_total = @total_tokens
+
+        # save relevant history
         msg = resp.choices.first.message
         request.messages << msg
-        save_history(chat_id, msg)
+        save_history(chat_id, msg) unless msg.function_call || (msg.role.function? && msg.name != "task_complete")
 
+        # perform function calls until we get a response for the user
         if func_call = msg.function_call
-          func_res = executor.execute(func_call)
+          discardable_tokens += resp.usage.completion_tokens
+
+          # handle the AI not providing a valid function name, we want it to retry
+          func_res = begin
+            executor.execute(func_call)
+          rescue ex
+            Log.error(exception: ex) { "executing function call" }
+            reply = "Encountered error: #{ex.message}"
+            result = DriverResponse.new(reply).as(JSON::Serializable)
+            request.messages << OpenAI::ChatMessage.new(:function, result.to_pretty_json, func_call.name)
+            next
+          end
+
+          # process the function result
+          case func_res.name
+          when "task_complete"
+            cleanup_messages(request, discardable_tokens)
+            discardable_tokens = 0
+            summary = TaskCompleted.from_json func_call.arguments.as_s
+            yield({chat_id: chat_id, message: "condensing progress: #{summary.details}", type: :progress, function: func_res.name, usage: resp.usage, compressed_usage: @total_tokens})
+          when "list_function_schemas"
+            calculate_discard = true
+            discover = FunctionDiscovery.from_json func_call.arguments.as_s
+            yield({chat_id: chat_id, message: "checking #{discover.id} capabilities", type: :progress, function: func_res.name, usage: resp.usage})
+          when "call_function"
+            calculate_discard = true
+            execute = FunctionExecutor.from_json func_call.arguments.as_s
+            yield({chat_id: chat_id, message: "performing action: #{execute.id}.#{execute.function}(#{execute.parameters})", type: :progress, function: func_res.name, usage: resp.usage})
+          end
           request.messages << func_res
-          save_history(chat_id, msg)
           next
         end
-        break {chat_id: chat_id, message: msg.content}
+
+        cleanup_messages(request, discardable_tokens)
+        yield({chat_id: chat_id, message: msg.content, type: :response, usage: resp.usage, compressed_usage: @total_tokens})
+        break
       end
     end
 
+    private def cleanup_messages(request, discardable_tokens)
+      # keep task summaries
+      request.messages.reject! { |mess| mess.function_call || (mess.role.function? && mess.name != "task_complete") }
+
+      # a good estimate of the total tokens once the cleanup is complete
+      @total_tokens = @total_tokens - discardable_tokens
+    end
+
     private def save_history(chat_id : String, role : PlaceOS::Model::ChatMessage::Role, message : String, func_name : String? = nil, func_args : JSON::Any? = nil) : Nil
       PlaceOS::Model::ChatMessage.create!(role: role, chat_id: chat_id, content: message, function_name: func_name, function_args: func_args)
     end
@@ -148,7 +211,8 @@ module PlaceOS::Api
           str << "my phone number is: #{user.phone}\n" if user.phone.presence
           str << "my swipe card number is: #{user.card_number}\n" if user.card_number.presence
           str << "my user_id is: #{user.id}\n"
-          str << "use these details in function calls as required\n"
+          str << "use these details in function calls as required.\n"
+          str << "perform one task at a time, making as many function calls as required to complete a task. Once a task is complete call the task_complete function with details of the progress you've made.\n"
           str << "the chat client prepends the date-time each message was sent at in the following format YYYY-MM-DD HH:mm:ss +ZZ:ZZ:ZZ"
         }
       )
@@ -177,19 +241,12 @@ module PlaceOS::Api
       Payload.from_json grab_driver_status(chat, LLM_DRIVER, LLM_DRIVER_PROMPT)
     end
 
-    private def build_executor(chat, payload : Payload?)
+    private def build_executor(chat)
       executor = OpenAI::FunctionExecutor.new
 
-      description = if payload
-                      "You have the following capability list, described in the following JSON:\n```json\n#{payload.capabilities.to_json}\n```\n" +
-                        "if a request could benefit from these capabilities, obtain the list of functions by providing the id string."
-                    else
-                      "if a request could benefit from a capability, obtain the list of functions by providing the id string"
-                    end
-
       executor.add(
         name: "list_function_schemas",
-        description: description,
+        description: "if a request could benefit from a capability, obtain the list of function schemas by providing the id string",
         clz: FunctionDiscovery
       ) do |call|
         request = call.as(FunctionDiscovery)
@@ -206,7 +263,8 @@ module PlaceOS::Api
       executor.add(
         name: "call_function",
         description: "Executes functionality offered by a capability, you'll need to obtain the function schema to perform requests",
-        clz: FunctionExecutor) do |call|
+        clz: FunctionExecutor
+      ) do |call|
         request = call.as(FunctionExecutor)
         reply = "No response received"
         begin
@@ -219,6 +277,15 @@ module PlaceOS::Api
         DriverResponse.new(reply).as(JSON::Serializable)
       end
 
+      executor.add(
+        name: "task_complete",
+        description: "Once a task is complete, call this function with the details that are relevant to the conversation. Provide enough detail so you don't perform the actions again and can formulate a response to the user",
+        clz: TaskCompleted
+      ) do |call|
+        request = call.as(TaskCompleted)
+        request.as(JSON::Serializable)
+      end
+
       executor
     end
@@ -279,6 +346,14 @@ module PlaceOS::Api
       getter id : String
     end
 
+    private struct TaskCompleted
+      extend OpenAI::FuncMarker
+      include JSON::Serializable
+
+      @[JSON::Field(description: "the details of the task that are relevant to continuing the conversation")]
+      getter details : String
+    end
+
     private record DriverResponse, body : String do
       include JSON::Serializable
    end
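
Note (not part of the patch): a minimal sketch of how a WebSocket client might consume the payloads now yielded by openai_interaction — intermediate :progress frames while list_function_schemas, call_function and task_complete run, followed by a single :response frame with the compressed token usage. The endpoint URL and the example message are assumptions; only the frame fields (type, message, function, usage, compressed_usage) come from the change above.

require "http/web_socket"
require "json"

# hypothetical chat websocket endpoint; adjust to where the ChatGPT controller is mounted
ws = HTTP::WebSocket.new(URI.parse("wss://example.placeos.run/api/engine/v2/chatgpt/chat/sys-12345"))

ws.on_message do |raw|
  frame = JSON.parse(raw)
  case frame["type"]?.try(&.as_s)
  when "progress"
    # emitted while functions execute and history is condensed
    puts "progress (#{frame["function"]?}): #{frame["message"]}"
  when "response"
    # final answer; compressed_usage reflects tokens remaining after cleanup_messages
    puts "assistant: #{frame["message"]} (compressed_usage: #{frame["compressed_usage"]?})"
  else
    puts "unhandled frame: #{raw}"
  end
end

# example user message (assumed format; the prompt expects the client to prepend the date-time)
ws.send "2023-11-02 09:00:00 +11:00 what meetings do I have today?"
ws.run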