Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
naqvis committed Oct 17, 2023
1 parent 08b0748 commit 473fb3f
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 63 deletions.
2 changes: 1 addition & 1 deletion spec/openai/chatgpt_spec.cr
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
require "../helper"

module PlaceOS::Api
# NOTE: `focus: true` must never be committed — it restricts the whole spec
# suite to only this describe block.
describe ChatGPT do
::Spec.before_each do
Model::ChatMessage.clear
Model::Chat.clear
Expand Down
6 changes: 2 additions & 4 deletions src/constants.cr
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ module PlaceOS::Api
PROD = ENV["SG_ENV"]?.try(&.downcase) == "production"

# Open AI
# Global fallback credentials; per-authority settings (authority.internals["openai"])
# take precedence — see ChatGPT#config.
OPENAI_API_KEY  = ENV["OPENAI_API_KEY"]?
OPENAI_API_BASE = ENV["OPENAI_API_BASE"]? # Set this to Azure URL only if Azure OpenAI is used

# CHANGELOG
#################################################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@ module PlaceOS::Api
before_action :can_read, only: [:index, :show]
before_action :can_write, only: [:chat, :delete]

# Before-action filter: resolves the request's authority from the host and
# caches it in @authority for the `getter!` accessor. Rejects the request
# with 401 when no authority matches.
@[AC::Route::Filter(:before_action)]
def check_authority
  found = current_authority
  @authority = found
  return if found

  Log.warn { {message: "authority not found", action: "authorize!", host: request.hostname} }
  raise Error::Unauthorized.new "authority not found"
end

# Lazily-constructed chat session manager bound to this controller instance.
getter chat_manager : ChatGPT::ChatManager { ChatGPT::ChatManager.new(self) }
# Authority set by the `check_authority` filter; `getter!` raises if read
# before the filter has run. NOTE(review): the trailing `?` in the type is
# redundant — `getter!` already declares a nilable ivar.
getter! authority : Model::Authority?

# list user chats
@[AC::Route::GET("/")]
Expand Down Expand Up @@ -38,17 +47,11 @@ module PlaceOS::Api
def chat(socket, system_id : String,
@[AC::Param::Info(name: "resume", description: "To resume previous chat session. Provide session chat id", example: "chats-xxxx")]
resume : String? = nil) : Nil
chat = (resume && PlaceOS::Model::Chat.find!(resume.not_nil!)) || begin
PlaceOS::Model::Chat.create!(user_id: current_user.id.as(String), system_id: system_id, summary: "")
end

begin
chat_manager.start_chat(socket, chat, !!resume)
rescue e : RemoteDriver::Error
handle_execute_error(e)
rescue e
render_error(HTTP::Status::INTERNAL_SERVER_ERROR, e.message, backtrace: e.backtrace)
end
chat_manager.start_session(socket, (resume && PlaceOS::Model::Chat.find!(resume.not_nil!)) || nil, system_id)
rescue e : RemoteDriver::Error
handle_execute_error(e)
rescue e
render_error(HTTP::Status::INTERNAL_SERVER_ERROR, e.message, backtrace: e.backtrace)
end

# remove chat and associated history
Expand All @@ -60,6 +63,18 @@ module PlaceOS::Api
end
chat.destroy
end

# Per-authority OpenAI connection settings; `api_base` is only set when an
# Azure OpenAI deployment should be targeted.
record Config, api_key : String, api_base : String?

# Resolves OpenAI credentials for the current authority, preferring the
# authority's `internals["openai"]` settings and falling back to the
# process-wide environment constants.
#
# Raises Error::NotFound when no api key is configured anywhere.
protected def config
  # FIX: use the nil-safe `[]?` lookup — a plain `[]` raises KeyError for
  # authorities with no "openai" section, turning the fallback path into a 500.
  if internals = authority.internals["openai"]?
    key = internals["api_key"]?.try &.as_s || Api::OPENAI_API_KEY || raise Error::NotFound.new("missing openai api_key configuration")
    Config.new(key, internals["api_base"]?.try &.as_s || Api::OPENAI_API_BASE)
  else
    key = Api::OPENAI_API_KEY || raise Error::NotFound.new("missing openai api_key configuration")
    Config.new(key, Api::OPENAI_API_BASE)
  end
end
end
end

Expand Down
95 changes: 48 additions & 47 deletions src/placeos-rest-api/controllers/openai/chat_manager.cr
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ module PlaceOS::Api
Log = ::Log.for(self)
alias RemoteDriver = ::PlaceOS::Driver::Proxy::RemoteDriver

# Active websocket sessions keyed by the socket's object_id:
# {socket, chat_id, openai client, completion request, function executor}
private getter ws_sockets = {} of UInt64 => {HTTP::WebSocket, String, OpenAI::Client, OpenAI::ChatCompletionRequest, OpenAI::FunctionExecutor}
# Keep-alive ping timers, one per socket, cancelled on close.
private getter ws_ping_tasks : Hash(UInt64, Tasker::Repeat(Nil)) = {} of UInt64 => Tasker::Repeat(Nil)

private getter ws_lock = Mutex.new(protection: :reentrant)
private getter session_ch = Channel(Nil).new
private getter app : ChatGPT

LLM_DRIVER = "LLM"
Expand All @@ -20,63 +19,69 @@ module PlaceOS::Api
def initialize(@app)
end

# Registers a websocket chat session.
#
# When `existing_chat` is given the previous session is resumed (no driver
# prompt is re-seeded); otherwise Chat creation is deferred to the first
# inbound message (see `manage_chat`). Any prior session on the same socket
# object is closed and replaced. A 10s ping keeps the socket alive; all
# registry state is cleaned up on close.
def start_session(ws : HTTP::WebSocket, existing_chat : PlaceOS::Model::Chat?, system_id : String)
  ws_lock.synchronize do
    ws_id = ws.object_id
    if existing_socket = ws_sockets[ws_id]?
      existing_socket[0].close rescue nil
    end

    if chat = existing_chat
      Log.debug { {chat_id: chat.id, message: "resuming chat session"} }
      # nil payload: resuming — the driver prompt was already seeded when the
      # chat was first created.
      client, executor, chat_completion = setup(chat, nil)
      ws_sockets[ws_id] = {ws, chat.id.as(String), client, chat_completion, executor}
    else
      Log.debug { {message: "starting new chat session"} }
    end

    ws_ping_tasks[ws_id] = Tasker.every(10.seconds) do
      ws.ping rescue nil
      nil
    end

    ws.on_message { |message| manage_chat(ws, message, system_id) }

    ws.on_close do
      if task = ws_ping_tasks.delete(ws_id)
        task.cancel
      end
      ws_sockets.delete(ws_id)
    end
  end
end

private def setup(chat, chat_prompt)
# Handles one inbound websocket message for a chat session.
#
# Looks up this socket's cached session tuple; on the very first message of a
# brand-new chat it lazily creates the Chat record (using that message as the
# summary), seeds it with the driver prompt, builds the OpenAI
# client/executor/completion via `setup`, and caches everything in
# `ws_sockets` keyed by the socket's object_id. The OpenAI response is
# serialised to JSON and sent back on the socket.
#
# NOTE(review): runs entirely under the reentrant `ws_lock`, so a slow OpenAI
# round-trip blocks all other sockets' message handling — confirm intended.
private def manage_chat(ws : HTTP::WebSocket, message : String, system_id : String)
  ws_lock.synchronize do
    ws_id = ws.object_id
    # Tuple layout: {socket, chat_id, client, completion_request, executor}
    _, chat_id, client, completion_req, executor = ws_sockets[ws_id]? || begin
      chat = PlaceOS::Model::Chat.create!(user_id: app.current_user.id.as(String), system_id: system_id, summary: message)
      id = chat.id.as(String)
      # `driver_prompt` supplies the system/capability seed messages for a new chat
      c, e, req = setup(chat, driver_prompt(chat))
      ws_sockets[ws_id] = {ws, id, c, req, e}
      {ws, id, c, req, e}
    end
    resp = openai_interaction(client, completion_req, executor, message, chat_id)
    ws.send(resp.to_json)
  end
end

# Builds the per-chat OpenAI plumbing: API client, function executor bound to
# this chat, and the initial completion request (seeded from `chat_payload`
# when starting a fresh chat; nil when resuming).
private def setup(chat, chat_payload)
  client = build_client
  executor = build_executor(chat)
  chat_completion = build_completion(build_prompt(chat, chat_payload), executor.functions)

  {client, executor, chat_completion}
end

# Constructs an OpenAI client from the per-authority configuration resolved by
# the controller (`ChatGPT#config`). When an `api_base` is configured the
# client targets an Azure OpenAI deployment; otherwise the default OpenAI
# endpoint is used.
private def build_client
  app_config = app.config
  config = if base = app_config.api_base
             OpenAI::Client::Config.azure(api_key: app_config.api_key, api_base: base)
           else
             OpenAI::Client::Config.default(api_key: app_config.api_key)
           end

  OpenAI::Client.new(config)
end

private def build_completion(messages, functions)
Expand Down Expand Up @@ -115,13 +120,13 @@ module PlaceOS::Api
save_history(chat_id, PlaceOS::Model::ChatMessage::Role.parse(msg.role.to_s), msg.content || "", msg.name, msg.function_call.try &.arguments)
end

private def build_prompt(chat : PlaceOS::Model::Chat, chat_prompt : ChatPrompt?)
private def build_prompt(chat : PlaceOS::Model::Chat, chat_payload : Payload?)
messages = [] of OpenAI::ChatMessage

if prompt = chat_prompt
messages << OpenAI::ChatMessage.new(role: :assistant, content: prompt.payload.prompt)
messages << OpenAI::ChatMessage.new(role: :assistant, content: "You have the following capabilities: #{prompt.payload.capabilities.to_json}")
messages << OpenAI::ChatMessage.new(role: :assistant, content: "You have access to the following API: #{function_schemas(chat, prompt.payload.capabilities).to_json}")
if payload = chat_payload
messages << OpenAI::ChatMessage.new(role: :assistant, content: payload.prompt)
messages << OpenAI::ChatMessage.new(role: :assistant, content: "You have the following capabilities: #{payload.capabilities.to_json}")
messages << OpenAI::ChatMessage.new(role: :assistant, content: "You have access to the following API: #{function_schemas(chat, payload.capabilities).to_json}")
messages << OpenAI::ChatMessage.new(role: :assistant, content: "If you were asked to perform any function of given capabilities, perform the action and reply with a confirmation telling what you have done.")

messages.each { |m| save_history(chat.id.as(String), m) }
Expand All @@ -144,10 +149,10 @@ module PlaceOS::Api
messages
end

# Fetches the seed prompt/capabilities payload from the system's LLM driver.
# Returns nil when the driver call does not return a 2xx status.
private def driver_prompt(chat : PlaceOS::Model::Chat) : Payload?
  resp, code = exec_driver_func(chat, LLM_DRIVER, LLM_DRIVER_CHAT, nil)
  # FIX: the previous check (`code > 200 && code < 299`) excluded both 200 and
  # 299, so a plain 200 OK from the driver silently produced no prompt.
  if (200..299).includes?(code)
    Payload.from_json(resp)
  end
end

Expand Down Expand Up @@ -218,10 +223,6 @@ module PlaceOS::Api
include JSON::Serializable
end

record ChatPrompt, message : String, payload : Payload do
include JSON::Serializable
end

struct Payload
include JSON::Serializable

Expand Down

0 comments on commit 473fb3f

Please sign in to comment.