Skip to content

Commit

Permalink
Better configuration (#79)
Browse files Browse the repository at this point in the history
* Configurable Transport (#75)

* new functions to allow HTTPClient configuration

* updated go.mod for testing from remote

* updated go.mod for remote testing

* revert go.mod replace directives

* Fixed NewOrgClientWithTransport comment

* Make client fully configurable

* make empty messages limit configurable #70 #71

* make auth token private in config

* add docs

* lint

---------

Co-authored-by: Michael Fox <[email protected]>
  • Loading branch information
sashabaranov and mwillfox committed Feb 20, 2023
1 parent 133d2c9 commit 1eb5d62
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 53 deletions.
45 changes: 18 additions & 27 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,34 @@ import (
"net/http"
)

const apiURLv1 = "https://api.openai.com/v1"

func newTransport() *http.Client {
return &http.Client{}
}

// Client is OpenAI GPT-3 API client.
type Client struct {
BaseURL string
HTTPClient *http.Client
authToken string
idOrg string
config ClientConfig
}

// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
return &Client{
BaseURL: apiURLv1,
HTTPClient: newTransport(),
authToken: authToken,
idOrg: "",
}
config := DefaultConfig(authToken)
return &Client{config}
}

// NewClientWithConfig creates new OpenAI API client for specified config.
func NewClientWithConfig(config ClientConfig) *Client {
return &Client{config}
}

// NewOrgClient creates new OpenAI API client for specified Organization ID.
//
// Deprecated: Please use NewClientWithConfig.
func NewOrgClient(authToken, org string) *Client {
return &Client{
BaseURL: apiURLv1,
HTTPClient: newTransport(),
authToken: authToken,
idOrg: org,
}
config := DefaultConfig(authToken)
config.OrgID = org
return &Client{config}
}

func (c *Client) sendRequest(req *http.Request, v interface{}) error {
req.Header.Set("Accept", "application/json; charset=utf-8")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))

// Check whether Content-Type is already set, Upload Files API requires
// Content-Type == multipart/form-data
Expand All @@ -51,11 +42,11 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}

if len(c.idOrg) > 0 {
req.Header.Set("OpenAI-Organization", c.idOrg)
if len(c.config.OrgID) > 0 {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}

res, err := c.HTTPClient.Do(req)
res, err := c.config.HTTPClient.Do(req)
if err != nil {
return err
}
Expand Down Expand Up @@ -86,5 +77,5 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
}

func (c *Client) fullURL(suffix string) string {
return fmt.Sprintf("%s%s", c.BaseURL, suffix)
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}
6 changes: 4 additions & 2 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ func TestAPIError(t *testing.T) {

func TestRequestError(t *testing.T) {
var err error
c := NewClient("dummy")
c.BaseURL = "https://httpbin.org/status/418?"

config := DefaultConfig("dummy")
config.BaseURL = "https://httpbin.org/status/418?"
c := NewClientWithConfig(config)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
Expand Down
5 changes: 3 additions & 2 deletions completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ func TestCompletions(t *testing.T) {
ts.Start()
defer ts.Close()

client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"

req := CompletionRequest{
MaxTokens: 5,
Expand Down
33 changes: 33 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package gogpt

import (
"net/http"
)

const (
apiURLv1 = "https://api.openai.com/v1"
defaultEmptyMessagesLimit uint = 300
)

// ClientConfig is a configuration of a client.
type ClientConfig struct {
authToken string

HTTPClient *http.Client

BaseURL string
OrgID string

EmptyMessagesLimit uint
}

func DefaultConfig(authToken string) ClientConfig {
return ClientConfig{
HTTPClient: &http.Client{},
BaseURL: apiURLv1,
OrgID: "",
authToken: authToken,

EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}
5 changes: 3 additions & 2 deletions edits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ func TestEdits(t *testing.T) {
ts.Start()
defer ts.Close()

client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"

// create an edit request
model := "ada"
Expand Down
5 changes: 3 additions & 2 deletions files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ func TestFileUpload(t *testing.T) {
ts.Start()
defer ts.Close()

client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"

req := FileRequest{
FileName: "test.go",
Expand Down
10 changes: 6 additions & 4 deletions image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ func TestImages(t *testing.T) {
ts.Start()
defer ts.Close()

client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"

req := ImageRequest{}
req.Prompt = "Lorem ipsum"
Expand Down Expand Up @@ -94,9 +95,10 @@ func TestImageEdit(t *testing.T) {
ts.Start()
defer ts.Close()

client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"

origin, err := os.Create("image.png")
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions moderation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ func TestModerations(t *testing.T) {
ts.Start()
defer ts.Close()

client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"

// create an edit request
model := "text-moderation-stable"
Expand Down
13 changes: 8 additions & 5 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ import (
)

var (
emptyMessagesLimit = 300
ErrTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
)

type CompletionStream struct {
emptyMessagesLimit uint

reader *bufio.Reader
response *http.Response
}

func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
emptyMessagesCount := 0
var emptyMessagesCount uint

waitForData:
line, err := stream.reader.ReadBytes('\n')
Expand All @@ -33,7 +34,7 @@ waitForData:
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, headerData) {
emptyMessagesCount++
if emptyMessagesCount > emptyMessagesLimit {
if emptyMessagesCount > stream.emptyMessagesLimit {
err = ErrTooManyEmptyStreamMessages
return
}
Expand Down Expand Up @@ -74,18 +75,20 @@ func (c *Client) CreateCompletionStream(
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
if err != nil {
return
}

req = req.WithContext(ctx)
resp, err := c.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
if err != nil {
return
}

stream = &CompletionStream{
emptyMessagesLimit: c.config.EmptyMessagesLimit,

reader: bufio.NewReader(resp.Body),
response: resp,
}
Expand Down
15 changes: 8 additions & 7 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,15 @@ func TestCreateCompletionStream(t *testing.T) {
defer server.Close()

// Client portion of the test
client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}

client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = server.URL + "/v1"

request := CompletionRequest{
Prompt: "Ex falso quodlibet",
Expand All @@ -48,11 +54,6 @@ func TestCreateCompletionStream(t *testing.T) {
Stream: true,
}

client.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}

stream, err := client.CreateCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
Expand Down

0 comments on commit 1eb5d62

Please sign in to comment.