From 3063e676bf5932024d76be8e8d9e41df06d4e8cc Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 16:20:59 +0100 Subject: [PATCH] Feat Implement assistants API (#535) * chore: implement assistants API * fix * fix * chore: add tests * fix tests * fix linting --- assistant.go | 260 ++++++++++++++++++++++++++++++++++++++++++++++ assistant_test.go | 202 +++++++++++++++++++++++++++++++++++ client_test.go | 27 +++++ 3 files changed, 489 insertions(+) create mode 100644 assistant.go create mode 100644 assistant_test.go diff --git a/assistant.go b/assistant.go new file mode 100644 index 000000000..d75eebef3 --- /dev/null +++ b/assistant.go @@ -0,0 +1,260 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + assistantsSuffix = "/assistants" + assistantsFilesSuffix = "/files" +) + +type Assistant struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []any `json:"tools,omitempty"` + + httpHeader +} + +type AssistantTool struct { + Type string `json:"type"` +} + +type AssistantToolCodeInterpreter struct { + AssistantTool +} + +type AssistantToolRetrieval struct { + AssistantTool +} + +type AssistantToolFunction struct { + AssistantTool + Function FunctionDefinition `json:"function"` +} + +type AssistantRequest struct { + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []any `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// AssistantsList is a list of assistants. +type AssistantsList struct { + Assistants []Assistant `json:"data"` + + httpHeader +} + +type AssistantFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` + + httpHeader +} + +type AssistantFileRequest struct { + FileID string `json:"file_id"` +} + +type AssistantFilesList struct { + AssistantFiles []AssistantFile `json:"data"` + + httpHeader +} + +// CreateAssistant creates a new assistant. +func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveAssistant retrieves an assistant. +func (c *Client) RetrieveAssistant( + ctx context.Context, + assistantID string, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyAssistant modifies an assistant. +func (c *Client) ModifyAssistant( + ctx context.Context, + assistantID string, + request AssistantRequest, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteAssistant deletes an assistant. +func (c *Client) DeleteAssistant( + ctx context.Context, + assistantID string, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListAssistants Lists the currently available assistants. +func (c *Client) ListAssistants( + ctx context.Context, + limit *int, + order *string, + after *string, + before *string, +) (reponse AssistantsList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &reponse) + return +} + +// CreateAssistantFile creates a new assistant file. +func (c *Client) CreateAssistantFile( + ctx context.Context, + assistantID string, + request AssistantFileRequest, +) (response AssistantFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveAssistantFile retrieves an assistant file. +func (c *Client) RetrieveAssistantFile( + ctx context.Context, + assistantID string, + fileID string, +) (response AssistantFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteAssistantFile deletes an existing file. +func (c *Client) DeleteAssistantFile( + ctx context.Context, + assistantID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, nil) + return +} + +// ListAssistantFiles Lists the currently available files for an assistant. +func (c *Client) ListAssistantFiles( + ctx context.Context, + assistantID string, + limit *int, + order *string, + after *string, + before *string, +) (response AssistantFilesList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/assistant_test.go b/assistant_test.go new file mode 100644 index 000000000..eb6f42458 --- /dev/null +++ b/assistant_test.go @@ -0,0 +1,202 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestAssistant Tests the assistant endpoint of the API using the mocked server. +func TestAssistant(t *testing.T) { + assistantID := "asst_abc123" + assistantName := "Ambrogio" + assistantDescription := "Ambrogio is a friendly assistant." + assitantInstructions := `You are a personal math tutor. +When asked a question, write and run Python code to answer the question.` + assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files/"+assistantFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "assistant.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFilesList{ + AssistantFiles: []openai.AssistantFile{ + { + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.AssistantFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: request.FileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assitantInstructions, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantsList{ + Assistants: []openai.Assistant{ + { + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assitantInstructions, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assitantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") + + _, err = client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + + _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assitantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + _, err = client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") + + _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + + _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + + _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") + + _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + + err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") +} diff --git a/client_test.go b/client_test.go index 2c1d749ed..bff2597c5 100644 --- a/client_test.go +++ b/client_test.go @@ -274,6 +274,33 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteFineTuneModel", func() (any, error) { return client.DeleteFineTuneModel(ctx, "") }}, + {"CreateAssistant", func() (any, error) { + return client.CreateAssistant(ctx, AssistantRequest{}) + }}, + {"RetrieveAssistant", func() (any, error) { + return client.RetrieveAssistant(ctx, "") + }}, + {"ModifyAssistant", func() (any, error) { + return client.ModifyAssistant(ctx, "", AssistantRequest{}) + }}, + {"DeleteAssistant", func() (any, error) { + return client.DeleteAssistant(ctx, "") + }}, + {"ListAssistants", func() (any, error) { + return client.ListAssistants(ctx, nil, nil, nil, nil) + }}, + {"CreateAssistantFile", func() (any, error) { + return client.CreateAssistantFile(ctx, "", AssistantFileRequest{}) + }}, + {"ListAssistantFiles", func() (any, error) { + return client.ListAssistantFiles(ctx, "", nil, nil, nil, nil) + }}, + {"RetrieveAssistantFile", func() (any, error) { + return client.RetrieveAssistantFile(ctx, "", "") + }}, + {"DeleteAssistantFile", func() (any, error) { + return nil, client.DeleteAssistantFile(ctx, "", "") + }}, } for _, testCase := range testCases {