Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Net HTTP persistent adapter compatibility #509

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions lib/middleware/beta.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# frozen_string_literal: true

module OpenAI
class MiddlewareBeta < Faraday::Middleware
BETA_REGEX = %r{
\A/#{OpenAI.configuration.api_version}
/(assistants|batches|threads|vector_stores)
}ix.freeze

def on_request(env)
return unless env[:url].path.match?(BETA_REGEX)

env[:request_headers].merge!(
{
"OpenAI-Beta" => "assistants=v2"
}
)
end
end
end
19 changes: 19 additions & 0 deletions lib/middleware/errors.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# frozen_string_literal: true

module OpenAI
class MiddlewareErrors < Faraday::Middleware
def call(env)
@app.call(env)
rescue Faraday::Error => e
raise e unless e.response.is_a?(Hash)

logger = Logger.new($stdout)
logger.formatter = proc do |_severity, _datetime, _progname, msg|
"\033[31mOpenAI HTTP Error (spotted in ruby-openai #{VERSION}): #{msg}\n\033[0m"
end
logger.error(e.response[:body])

raise e
end
end
end
17 changes: 2 additions & 15 deletions lib/openai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,8 @@ module OpenAI
class Error < StandardError; end
class ConfigurationError < Error; end

class MiddlewareErrors < Faraday::Middleware
def call(env)
@app.call(env)
rescue Faraday::Error => e
raise e unless e.response.is_a?(Hash)

logger = Logger.new($stdout)
logger.formatter = proc do |_severity, _datetime, _progname, msg|
"\033[31mOpenAI HTTP Error (spotted in ruby-openai #{VERSION}): #{msg}\n\033[0m"
end
logger.error(e.response[:body])

raise e
end
end
autoload :MiddlewareErrors, "middleware/errors"
autoload :MiddlewareBeta, "middleware/beta"

class Configuration
attr_accessor :access_token,
Expand Down
4 changes: 1 addition & 3 deletions lib/openai/assistants.rb
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
module OpenAI
class Assistants
BETA_VERSION = "v2".freeze

def initialize(client:)
@client = client.beta(assistants: OpenAI::Assistants::BETA_VERSION)
@client = client
end

def list
Expand Down
2 changes: 1 addition & 1 deletion lib/openai/batches.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OpenAI
class Batches
def initialize(client:)
@client = client.beta(assistants: OpenAI::Assistants::BETA_VERSION)
@client = client
end

def list(parameters: {})
Expand Down
24 changes: 19 additions & 5 deletions lib/openai/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Client
request_timeout
extra_headers
].freeze
attr_reader *CONFIG_KEYS, :faraday_middleware
attr_reader *CONFIG_KEYS

def initialize(config = {}, &faraday_middleware)
CONFIG_KEYS.each do |key|
Expand All @@ -23,7 +23,12 @@ def initialize(config = {}, &faraday_middleware)
config[key].nil? ? OpenAI.configuration.send(key) : config[key]
)
end
@faraday_middleware = faraday_middleware

@connection = build_connection
faraday_middleware&.call(@connection)

@multipart_connection = build_connection(multipart: true)
faraday_middleware&.call(@multipart_connection)
end

def chat(parameters: {})
Expand Down Expand Up @@ -102,9 +107,18 @@ def azure?
@api_type&.to_sym == :azure
end

def beta(apis)
dup.tap do |client|
client.add_headers("OpenAI-Beta": apis.map { |k, v| "#{k}=#{v}" }.join(";"))
private

attr_reader :connection, :multipart_connection

def build_connection(multipart: false)
Faraday.new do |faraday|
faraday.options[:timeout] = @request_timeout
faraday.request(:multipart) if multipart
faraday.use MiddlewareErrors if @log_errors
faraday.use MiddlewareBeta
faraday.response :raise_error
faraday.response :json
end
end
end
Expand Down
24 changes: 5 additions & 19 deletions lib/openai/http.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,32 @@ module HTTP
include HTTPHeaders

def get(path:, parameters: nil)
parse_jsonl(conn.get(uri(path: path), parameters) do |req|
parse_jsonl(connection.get(uri(path: path), parameters) do |req|
req.headers = headers
end&.body)
end

def post(path:)
parse_jsonl(conn.post(uri(path: path)) do |req|
parse_jsonl(connection.post(uri(path: path)) do |req|
req.headers = headers
end&.body)
end

def json_post(path:, parameters:)
conn.post(uri(path: path)) do |req|
connection.post(uri(path: path)) do |req|
configure_json_post_request(req, parameters)
end&.body
end

def multipart_post(path:, parameters: nil)
conn(multipart: true).post(uri(path: path)) do |req|
multipart_connection.post(uri(path: path)) do |req|
req.headers = headers.merge({ "Content-Type" => "multipart/form-data" })
req.body = multipart_parameters(parameters)
end&.body
end

def delete(path:)
conn.delete(uri(path: path)) do |req|
connection.delete(uri(path: path)) do |req|
req.headers = headers
end&.body
end
Expand Down Expand Up @@ -70,20 +70,6 @@ def to_json_stream(user_proc:)
end
end

def conn(multipart: false)
connection = Faraday.new do |f|
f.options[:timeout] = @request_timeout
f.request(:multipart) if multipart
f.use MiddlewareErrors if @log_errors
f.response :raise_error
f.response :json
end

@faraday_middleware&.call(connection)

connection
end

def uri(path:)
if azure?
base = File.join(@uri_base, path)
Expand Down
2 changes: 1 addition & 1 deletion lib/openai/images.rb
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module OpenAI
class Images
def initialize(client: nil)
def initialize(client:)
@client = client
end

Expand Down
2 changes: 1 addition & 1 deletion lib/openai/messages.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OpenAI
class Messages
def initialize(client:)
@client = client.beta(assistants: OpenAI::Assistants::BETA_VERSION)
@client = client
end

def list(thread_id:, parameters: {})
Expand Down
2 changes: 1 addition & 1 deletion lib/openai/run_steps.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OpenAI
class RunSteps
def initialize(client:)
@client = client.beta(assistants: OpenAI::Assistants::BETA_VERSION)
@client = client
end

def list(thread_id:, run_id:, parameters: {})
Expand Down
2 changes: 1 addition & 1 deletion lib/openai/runs.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OpenAI
class Runs
def initialize(client:)
@client = client.beta(assistants: OpenAI::Assistants::BETA_VERSION)
@client = client
end

def list(thread_id:, parameters: {})
Expand Down
2 changes: 1 addition & 1 deletion lib/openai/threads.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OpenAI
class Threads
def initialize(client:)
@client = client.beta(assistants: OpenAI::Assistants::BETA_VERSION)
@client = client
end

def retrieve(id:)
Expand Down
2 changes: 1 addition & 1 deletion lib/openai/vector_store_file_batches.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OpenAI
class VectorStoreFileBatches
def initialize(client:)
@client = client.beta(assistants: OpenAI::Assistants::BETA_VERSION)
@client = client
end

def list(vector_store_id:, id:, parameters: {})
Expand Down
2 changes: 1 addition & 1 deletion lib/openai/vector_store_files.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OpenAI
class VectorStoreFiles
def initialize(client:)
@client = client.beta(assistants: OpenAI::Assistants::BETA_VERSION)
@client = client
end

def list(vector_store_id:, parameters: {})
Expand Down
2 changes: 1 addition & 1 deletion lib/openai/vector_stores.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OpenAI
class VectorStores
def initialize(client:)
@client = client.beta(assistants: OpenAI::Assistants::BETA_VERSION)
@client = client
end

def list(parameters: {})
Expand Down
20 changes: 6 additions & 14 deletions spec/openai/client/client_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
expect(c0.uri_base).to eq(OpenAI::Configuration::DEFAULT_URI_BASE)
expect(c0.send(:headers).values).to include("Bearer #{c0.access_token}")
expect(c0.send(:headers).values).to include(c0.organization_id)
expect(c0.send(:conn).options.timeout).to eq(OpenAI::Configuration::DEFAULT_REQUEST_TIMEOUT)
expect(c0.send(:connection).options.timeout).to eq(
OpenAI::Configuration::DEFAULT_REQUEST_TIMEOUT
)
expect(c0.send(:uri, path: "")).to include(OpenAI::Configuration::DEFAULT_URI_BASE)
expect(c0.send(:headers).values).to include("X-Default")
expect(c0.send(:headers).values).not_to include("X-Test")
Expand All @@ -55,7 +57,7 @@
expect(c1.request_timeout).to eq(60)
expect(c1.uri_base).to eq("https://oai.hconeai.com/")
expect(c1.send(:headers).values).to include(c1.access_token)
expect(c1.send(:conn).options.timeout).to eq(60)
expect(c1.send(:connection).options.timeout).to eq(60)
expect(c1.send(:uri, path: "")).to include("https://oai.hconeai.com/")
expect(c1.send(:headers).values).not_to include("X-Default")
expect(c1.send(:headers).values).to include("X-Test")
Expand All @@ -67,7 +69,7 @@
expect(c2.uri_base).to eq("https://example.com/")
expect(c2.send(:headers).values).to include("Bearer #{c2.access_token}")
expect(c2.send(:headers).values).to include(c2.organization_id)
expect(c2.send(:conn).options.timeout).to eq(1)
expect(c2.send(:connection).options.timeout).to eq(1)
expect(c2.send(:uri, path: "")).to include("https://example.com/")
expect(c2.send(:headers).values).to include("X-Default")
expect(c2.send(:headers).values).not_to include("X-Test")
Expand Down Expand Up @@ -112,14 +114,6 @@
end
end

context "when using beta APIs" do
let(:client) { OpenAI::Client.new.beta(assistants: "v2") }

it "sends the appropriate header value" do
expect(client.send(:headers)["OpenAI-Beta"]).to eq "assistants=v2"
end
end

context "with a block" do
let(:client) do
OpenAI::Client.new do |client|
Expand All @@ -128,9 +122,7 @@
end

it "sets the logger" do
connection = Faraday.new
client.faraday_middleware.call(connection)
expect(connection.builder.handlers).to include Faraday::Response::Logger
expect(client.send(:connection).builder.handlers).to include Faraday::Response::Logger
end
end
end