Skip to content

Commit

Permalink
feat(chat_gpt): compress chats once tasks are completed
Browse files Browse the repository at this point in the history
ensures we have the tokens to continue making requests
  • Loading branch information
stakach committed Nov 2, 2023
1 parent 1783a73 commit 0130973
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 20 deletions.
2 changes: 1 addition & 1 deletion shard.lock
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ shards:

openai:
git: https://github.com/spider-gazelle/crystal-openai.git
version: 0.9.0+git.commit.e6bfaba7758f992d7cb81cad0109180d5be2d958
version: 0.9.0+git.commit.b6c669a09a57aabcd2d156e0dcc85f0a6255408d

openapi-generator:
git: https://github.com/place-labs/openapi-generator.git
Expand Down
113 changes: 94 additions & 19 deletions src/placeos-rest-api/controllers/chat_gpt/chat_manager.cr
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ module PlaceOS::Api
ws_sockets[ws_id] = {ws, id, c, req, e}
{ws, id, c, req, e}
end
resp = openai_interaction(client, completion_req, executor, message, chat_id)
ws.send(resp.to_json)
openai_interaction(client, completion_req, executor, message, chat_id) do |resp|
ws.send(resp.to_json)
end
end
rescue error
Log.warn(exception: error) { "failure processing chat message" }
Expand All @@ -77,7 +78,7 @@ module PlaceOS::Api

private def setup(chat, chat_payload)
client = build_client
executor = build_executor(chat, chat_payload)
executor = build_executor(chat)
chat_completion = build_completion(build_prompt(chat, chat_payload), executor.functions)

{client, executor, chat_completion}
Expand All @@ -103,25 +104,87 @@ module PlaceOS::Api
)
end

private def openai_interaction(client, request, executor, message, chat_id) : NamedTuple(chat_id: String, message: String?)
@total_tokens : Int32 = 0

private def openai_interaction(client, request, executor, message, chat_id, &) : Nil
request.messages << OpenAI::ChatMessage.new(role: :user, content: message)
save_history(chat_id, :user, message)

# track token usage
discardable_tokens = 0
tracking_total = 0
calculate_discard = false

loop do
# ensure new request will fit here
# cleanup old messages, saving first system prompt and then removing messages beyond that until we're within the limit
# we could also restore messages once a task has been completed if there is space
# TODO::

# track token usage
resp = client.chat_completion(request)
@total_tokens = resp.usage.total_tokens

if calculate_discard
discardable_tokens += resp.usage.prompt_tokens - tracking_total
calculate_discard = false
end
tracking_total = @total_tokens

# save relevant history
msg = resp.choices.first.message
request.messages << msg
save_history(chat_id, msg)
save_history(chat_id, msg) unless msg.function_call || (msg.role.function? && msg.name != "task_complete")

# perform function calls until we get a response for the user
if func_call = msg.function_call
func_res = executor.execute(func_call)
discardable_tokens += resp.usage.completion_tokens

# handle the AI not providing a valid function name, we want it to retry
func_res = begin
executor.execute(func_call)
rescue ex
Log.error(exception: ex) { "executing function call" }
reply = "Encountered error: #{ex.message}"
result = DriverResponse.new(reply).as(JSON::Serializable)
request.messages << OpenAI::ChatMessage.new(:function, result.to_pretty_json, func_call.name)
next
end

# process the function result
case func_res.name
when "task_complete"
cleanup_messages(request, discardable_tokens)
discardable_tokens = 0
summary = TaskCompleted.from_json func_call.arguments.as_s
yield({chat_id: chat_id, message: "condensing progress: #{summary.details}", type: :progress, function: func_res.name, usage: resp.usage, compressed_usage: @total_tokens})
when "list_function_schemas"
calculate_discard = true
discover = FunctionDiscovery.from_json func_call.arguments.as_s
yield({chat_id: chat_id, message: "checking #{discover.id} capabilities", type: :progress, function: func_res.name, usage: resp.usage})
when "call_function"
calculate_discard = true
execute = FunctionExecutor.from_json func_call.arguments.as_s
yield({chat_id: chat_id, message: "performing action: #{execute.id}.#{execute.function}(#{execute.parameters})", type: :progress, function: func_res.name, usage: resp.usage})
end
request.messages << func_res
save_history(chat_id, msg)
next
end
break {chat_id: chat_id, message: msg.content}

cleanup_messages(request, discardable_tokens)
yield({chat_id: chat_id, message: msg.content, type: :response, usage: resp.usage, compressed_usage: @total_tokens})
break
end
end

# Shrinks the conversation once a task has finished: drops every function
# call message and every function result except "task_complete" results,
# which act as condensed summaries of the work already done.
private def cleanup_messages(request, discardable_tokens)
  request.messages.reject! do |message|
    message.function_call || (message.role.function? && message.name != "task_complete")
  end

  # good estimate of the context size remaining after the discard above
  @total_tokens -= discardable_tokens
end

# Persists one chat message to the database so the conversation history
# survives beyond the websocket session.
#
# - role: originator of the message (e.g. :user, as passed by the caller)
# - func_name / func_args: set when the message represents a function call
#   NOTE(review): persistence is best-effort from this block's perspective;
#   `create!` raises on failure — callers rely on the surrounding rescue.
private def save_history(chat_id : String, role : PlaceOS::Model::ChatMessage::Role, message : String, func_name : String? = nil, func_args : JSON::Any? = nil) : Nil
  PlaceOS::Model::ChatMessage.create!(role: role, chat_id: chat_id, content: message, function_name: func_name, function_args: func_args)
end
Expand All @@ -148,7 +211,8 @@ module PlaceOS::Api
str << "my phone number is: #{user.phone}\n" if user.phone.presence
str << "my swipe card number is: #{user.card_number}\n" if user.card_number.presence
str << "my user_id is: #{user.id}\n"
str << "use these details in function calls as required\n"
str << "use these details in function calls as required.\n"
str << "perform one task at a time, making as many function calls as required to complete a task. Once a task is complete call the task_complete function with details of the progress you've made.\n"
str << "the chat client prepends the date-time each message was sent at in the following format YYYY-MM-DD HH:mm:ss +ZZ:ZZ:ZZ"
}
)
Expand Down Expand Up @@ -177,19 +241,12 @@ module PlaceOS::Api
Payload.from_json grab_driver_status(chat, LLM_DRIVER, LLM_DRIVER_PROMPT)
end

private def build_executor(chat, payload : Payload?)
private def build_executor(chat)
executor = OpenAI::FunctionExecutor.new

description = if payload
"You have the following capability list, described in the following JSON:\n```json\n#{payload.capabilities.to_json}\n```\n" +
"if a request could benefit from these capabilities, obtain the list of functions by providing the id string."
else
"if a request could benefit from a capability, obtain the list of functions by providing the id string"
end

executor.add(
name: "list_function_schemas",
description: description,
description: "if a request could benefit from a capability, obtain the list of function schemas by providing the id string",
clz: FunctionDiscovery
) do |call|
request = call.as(FunctionDiscovery)
Expand All @@ -206,7 +263,8 @@ module PlaceOS::Api
executor.add(
name: "call_function",
description: "Executes functionality offered by a capability, you'll need to obtain the function schema to perform requests",
clz: FunctionExecutor) do |call|
clz: FunctionExecutor
) do |call|
request = call.as(FunctionExecutor)
reply = "No response received"
begin
Expand All @@ -219,6 +277,15 @@ module PlaceOS::Api
DriverResponse.new(reply).as(JSON::Serializable)
end

executor.add(
name: "task_complete",
description: "Once a task is complete, call this function with the details that are relevant to the conversion. Provide enough detail so you don't perform the actions again and can formulate a response to the user",
clz: TaskCompleted
) do |call|
request = call.as(TaskCompleted)
request.as(JSON::Serializable)
end

executor
end

Expand Down Expand Up @@ -279,6 +346,14 @@ module PlaceOS::Api
getter id : String
end

# Payload the LLM sends when it invokes the `task_complete` function.
# Captures a summary of the finished task so earlier messages can be
# discarded from the context window without losing progress.
private struct TaskCompleted
  extend OpenAI::FuncMarker
  include JSON::Serializable

  # Fix: description read "conversion" — a typo for "conversation"; this
  # text is shown to the model as part of the function schema.
  @[JSON::Field(description: "the details of the task that are relevant to continuing the conversation")]
  getter details : String
end

# Wraps a plain-text driver reply (or an error message) so it can be
# serialised to JSON and fed back to the LLM as a function-call result.
private record DriverResponse, body : String do
  include JSON::Serializable
end
Expand Down

0 comments on commit 0130973

Please sign in to comment.