diff --git a/chat.go b/chat.go index 0f56216fd..c09861c8c 100644 --- a/chat.go +++ b/chat.go @@ -14,8 +14,8 @@ const ( ) var ( - ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") //nolint:lll - ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll + ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll + ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll ) type ChatCompletionMessage struct { @@ -71,14 +71,12 @@ func (c *Client) CreateChatCompletion( return } - switch request.Model { - case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K: - default: + urlSuffix := "/chat/completions" + if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrChatCompletionInvalidModel return } - urlSuffix := "/chat/completions" req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return diff --git a/chat_stream.go b/chat_stream.go index 26c4bfc15..821129295 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -37,8 +37,14 @@ func (c *Client) CreateChatCompletionStream( ctx context.Context, request ChatCompletionRequest, ) (stream *ChatCompletionStream, err error) { + urlSuffix := "/chat/completions" + if !checkEndpointSupportsModel(urlSuffix, request.Model) { + err = ErrChatCompletionInvalidModel + return + } + request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", "/chat/completions", request) + req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request) if err != nil { return } diff --git a/chat_stream_test.go b/chat_stream_test.go index de604fa8b..a21ceee38 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -13,6 +13,28 
@@ import ( "testing" ) +func TestChatCompletionsStreamWrongModel(t *testing.T) { + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + req := ChatCompletionRequest{ + MaxTokens: 5, + Model: "ada", + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } + _, err := client.CreateChatCompletionStream(ctx, req) + if !errors.Is(err, ErrChatCompletionInvalidModel) { + t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) + } +} + func TestCreateChatCompletionStream(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") diff --git a/chat_test.go b/chat_test.go index 8866ff2ae..2d569a423 100644 --- a/chat_test.go +++ b/chat_test.go @@ -34,7 +34,7 @@ func TestChatCompletionsWrongModel(t *testing.T) { } _, err := client.CreateChatCompletion(ctx, req) if !errors.Is(err, ErrChatCompletionInvalidModel) { - t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err) + t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) } } diff --git a/completion.go b/completion.go index 22211d39f..6617e5a7f 100644 --- a/completion.go +++ b/completion.go @@ -45,6 +45,38 @@ const ( CodexCodeDavinci001 = "code-davinci-001" ) +var disabledModelsForEndpoints = map[string]map[string]bool{ + "/completions": { + GPT3Dot5Turbo: true, + GPT3Dot5Turbo0301: true, + GPT4: true, + GPT40314: true, + GPT432K: true, + GPT432K0314: true, + }, + "/chat/completions": { + CodexCodeDavinci002: true, + CodexCodeCushman001: true, + CodexCodeDavinci001: true, + GPT3TextDavinci003: true, + GPT3TextDavinci002: true, + GPT3TextCurie001: true, + GPT3TextBabbage001: true, + GPT3TextAda001: true, + GPT3TextDavinci001: true, + GPT3DavinciInstructBeta: 
true, + GPT3Davinci: true, + GPT3CurieInstructBeta: true, + GPT3Curie: true, + GPT3Ada: true, + GPT3Babbage: true, + }, +} + +func checkEndpointSupportsModel(endpoint, model string) bool { + return !disabledModelsForEndpoints[endpoint][model] +} + // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { Model string `json:"model"` @@ -105,12 +137,12 @@ func (c *Client) CreateCompletion( return } - if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo { + urlSuffix := "/completions" + if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrCompletionUnsupportedModel return } - urlSuffix := "/completions" req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return diff --git a/request_builder_test.go b/request_builder_test.go index f0f99ee5b..0f14f93fa 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -61,7 +61,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } - _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) + _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) if !errors.Is(err, errTestRequestBuilderFailed) { t.Fatalf("Did not return error when request builder failed: %v", err) } diff --git a/stream.go b/stream.go index 322d27fb9..944546a60 100644 --- a/stream.go +++ b/stream.go @@ -22,8 +22,14 @@ func (c *Client) CreateCompletionStream( ctx context.Context, request CompletionRequest, ) (stream *CompletionStream, err error) { + urlSuffix := "/completions" + if !checkEndpointSupportsModel(urlSuffix, request.Model) { + err = ErrCompletionUnsupportedModel + return + } + request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", "/completions", request) + req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request) if err != nil { return } diff --git 
a/stream_test.go b/stream_test.go index ce560c644..7d01ebdda 100644 --- a/stream_test.go +++ b/stream_test.go @@ -12,6 +12,23 @@ import ( "testing" ) +func TestCompletionsStreamWrongModel(t *testing.T) { + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := NewClientWithConfig(config) + + _, err := client.CreateCompletionStream( + context.Background(), + CompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + }, + ) + if !errors.Is(err, ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletionStream should return ErrCompletionUnsupportedModel, but returned: %v", err) + } +} + func TestCreateCompletionStream(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") @@ -140,7 +157,7 @@ func TestCreateCompletionStreamError(t *testing.T) { request := CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: GPT3TextDavinci003, Prompt: "Hello!", Stream: true, }