From 4bdb49eb93767543ed4901d21f32c57b83bb4a08 Mon Sep 17 00:00:00 2001
From: LemonHX
Date: Fri, 1 Mar 2024 15:11:50 +0800
Subject: [PATCH] fix array "required" and "enum" in function calling

encoding/json decodes every JSON array inside a map[string]any as []any,
never as []string, so the previous assertions on "required" and "enum"
always failed and both fields were silently dropped ("required" was also
read from the misspelled key "requireds"). Convert both arrays
element-wise instead, and add moonshot as a provider with blocking,
streaming and function-calling tests over gRPC and HTTP.

---
 grpcServer/relay.go                           |  25 ++++
 .../relay_function_calling_middleware.go      |  18 +++-
 grpcServer/relay_gin.go                       |   3 +
 relay/reqTransformer/ChatGPT.go               |  23 +++-
 tests_grpc/baichuan_test.go                   |   2 +-
 tests_grpc/chatglm_test.go                    |   2 +-
 tests_grpc/moonshot_test.go                   | 133 ++++++++++++++++++
 tests_grpc/openai_test.go                     |   4 +-
 tests_http/baichuan_test.go                   |   6 +-
 tests_http/chatglm_test.go                    |   8 +-
 tests_http/moonshot_test.go                   | 124 ++++++++++++++++
 11 files changed, 330 insertions(+), 18 deletions(-)
 create mode 100644 tests_grpc/moonshot_test.go
 create mode 100644 tests_http/moonshot_test.go

diff --git a/grpcServer/relay.go b/grpcServer/relay.go
index 2a0b3b3..51bd56f 100644
--- a/grpcServer/relay.go
+++ b/grpcServer/relay.go
@@ -27,6 +27,7 @@ const CHATGLM_LLM_API = "chatglm"
 const AZURE_OPENAI_LLM_API = "azure_openai"
 const BAICHUAN_LLM_API = "baichuan"
 const GEMINI_LLM_API = "gemini"
+const MOONSHOT_LLM_API = "moonshot"
 
 func (uno *UnoForwardServer) BlockingRequestLLM(ctx context.Context, rs *model.LLMRequestSchema) (*model.LLMResponseSchema, error) {
     info := rs.GetLlmRequestInfo()
@@ -35,6 +36,18 @@ func (uno *UnoForwardServer) BlockingRequestLLM(ctx context.Context, rs *model.L
         cli := NewOpenAIClient(info)
         return OpenAIChatCompletion(cli, rs)
 
+    case MOONSHOT_LLM_API:
+        cli := NewOpenAIClient(info)
+        if functionCallingRequestMake(rs) {
+            res, err := OpenAIChatCompletion(cli, rs)
+            if err != nil {
+                return nil, status.Errorf(codes.Internal, err.Error())
+            }
+            functionCallingResponseHandle(res)
+            return res, nil
+        }
+        return OpenAIChatCompletion(cli, rs)
+
     case CHATGLM_LLM_API:
         cli := NewChatGLMClient(info)
         return ChatGLMChatCompletion(cli, rs)
@@ -69,6 +82,18 @@ func (uno *UnoForwardServer) StreamRequestLLM(rs *model.LLMRequestSchema, sv mod
     case OPENAI_LLM_API:
         cli := NewOpenAIClient(info)
         return OpenAIChatCompletionStreaming(cli, rs, sv)
+    case MOONSHOT_LLM_API:
+        cli := NewOpenAIClient(info)
+        if functionCallingRequestMake(rs) {
+            res, err := OpenAIChatCompletion(cli, rs)
+            if err != nil {
+                return status.Errorf(codes.Internal, err.Error())
+            }
+            functionCallingResponseHandle(res)
+            functionCallingResponseToStream(res, sv)
+            return nil
+        }
+        return OpenAIChatCompletionStreaming(cli, rs, sv)
     case CHATGLM_LLM_API:
         cli := NewChatGLMClient(info)
         return ChatGLMChatCompletionStreaming(cli, rs, sv)
diff --git a/grpcServer/relay_function_calling_middleware.go b/grpcServer/relay_function_calling_middleware.go
index 26d1d68..872594e 100644
--- a/grpcServer/relay_function_calling_middleware.go
+++ b/grpcServer/relay_function_calling_middleware.go
@@ -4,6 +4,7 @@ import (
     "encoding/json"
     "fmt"
     "log"
+    "strings"
     "time"
 
     openai "github.com/sashabaranov/go-openai"
@@ -14,8 +15,22 @@
 type Callings struct {
 }
 
 func getParams(url string) (params []ChatGLM.GLMFunctionCall, err error) {
+    url = strings.TrimSpace(url)
     err = json.Unmarshal([]byte(url), &params)
+    if err != nil {
+        // the model may wrap the JSON payload in a markdown fence: ```json ... ```
+        // trim only the fences and surrounding whitespace; stripping whitespace
+        // inside the payload would corrupt string values such as "San Francisco, CA"
+        if inner, found := strings.CutPrefix(url, "```json"); found {
+            inner, _ = strings.CutSuffix(strings.TrimSpace(inner), "```")
+            err = json.Unmarshal([]byte(inner), &params)
+            return
+        }
+        if inner, found := strings.CutPrefix(url, "```"); found {
+            inner, _ = strings.CutSuffix(strings.TrimSpace(inner), "```")
+            err = json.Unmarshal([]byte(inner), &params)
+            return
+        }
+    }
     return
 }
 
@@ -53,7 +68,7 @@ func functionCallingRequestMake(req *model.LLMRequestSchema) bool {
                 tools[i].Function.Parameters.(map[string]any)["properties"].(map[string]any)[f.Parameters[j].Name].(map[string]any)["enum"] = f.Parameters[j].Enums
             }
         }
-        log.Printf("%#v\n", tools[i])
+        log.Printf("find function: %#v", tools[i].Function)
     }
 
     tools_json_byte, err := json.Marshal(tools)
@@ -80,7 +95,6 @@ func functionCallingResponseHandle(resp *model.LLMResponseSchema) {
     if err != nil {
         log.Fatal(err)
     }
-    log.Printf("%#v\n", function_calling)
     resp.Message.Content = ""
     resp.ToolCalls = []*model.ToolCall{}
     for _, f := range function_calling {
diff --git a/grpcServer/relay_gin.go b/grpcServer/relay_gin.go
index 665dd55..39b3240 100644
--- a/grpcServer/relay_gin.go
+++ b/grpcServer/relay_gin.go
@@ -26,6 +26,9 @@ func getProvider(m string) (string, error) {
     if strings.Contains(m, "gpt") {
         return "openai", nil
     }
+    if strings.Contains(m, "moonshot") {
+        return "moonshot", nil
+    }
     return "", errors.New("could not get provider")
 }
 
diff --git a/relay/reqTransformer/ChatGPT.go b/relay/reqTransformer/ChatGPT.go
index 0c3d5cf..814e73b 100644
--- a/relay/reqTransformer/ChatGPT.go
+++ b/relay/reqTransformer/ChatGPT.go
@@ -21,9 +21,15 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
     if !ok {
         openai_params = make(map[string]any)
     }
-    requireds, ok := openai_params["requireds"].([]string)
+    var requireds []string
+    openai_requireds, ok := openai_params["required"].([]any)
     if !ok {
         requireds = make([]string, 0)
+    } else {
+        requireds = make([]string, len(openai_requireds))
+        for i, v := range openai_requireds {
+            requireds[i] = v.(string)
+        }
     }
     openai_params_properties, ok := openai_params["properties"].(map[string]any)
     if !ok {
@@ -45,9 +51,11 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
         if ok {
             param.Description = desc
         }
-        enum, ok := m["enum"].([]string)
+        enum, ok := m["enum"].([]any)
         if ok {
-            param.Enums = enum
+            for _, e := range enum {
+                param.Enums = append(param.Enums, e.(string))
+            }
         }
         params = append(params, &param)
     }
@@ -63,6 +71,11 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
     if req.ToolChoice == "none" {
         usefc = false
     }
+    url := "TODO: URL"
+    switch api {
+    case "moonshot":
+        url = "https://api.moonshot.cn/v1"
+    }
     return &model.LLMRequestSchema{
         Messages: messages,
         LlmRequestInfo: &model.LLMRequestInfo{
@@ -71,7 +84,7 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
             Temperature:        float64(req.Temperature),
             TopP:               float64(req.TopP),
             TopK:               float64(0),
-            Url:                "TODO: URL",
+            Url:                url,
             Token:              token,
             UseFunctionCalling: usefc,
             Functions:          tools,
@@ -127,7 +140,7 @@ func ChatGPTGrpcChatCompletionReq(rs *model.LLMRequestSchema) openai.ChatComplet
                 tools[i].Function.Parameters.(map[string]any)["properties"].(map[string]any)[f.Parameters[j].Name].(map[string]any)["enum"] = f.Parameters[j].Enums
             }
         }
-        log.Printf("%#v\n", tools[i])
+        log.Printf("converting function call to grpc: %#v", tools[i].Function)
     }
     return openai.ChatCompletionRequest{
         Model: info.GetModel(),
diff --git a/tests_grpc/baichuan_test.go b/tests_grpc/baichuan_test.go
index 18beffc..ca76898 100644
--- a/tests_grpc/baichuan_test.go
+++ b/tests_grpc/baichuan_test.go
@@ -141,5 +141,5 @@ func TestBaichuanFunctionCalling(t *testing.T) {
     if err != nil {
         t.Error(err)
     }
-    log.Printf("res: %#v\n", res.ToolCalls[0])
+    log.Printf("res: %#v", res.ToolCalls[0])
 }
diff --git a/tests_grpc/chatglm_test.go b/tests_grpc/chatglm_test.go
index 1be2491..3ffc980 100644
--- a/tests_grpc/chatglm_test.go
+++ b/tests_grpc/chatglm_test.go
@@ -141,5 +141,5 @@ func TestChatGLMFunctionCalling(t *testing.T) {
     if err != nil {
         t.Error(err)
     }
-    log.Printf("res: %#v\n", res.ToolCalls[0])
+    log.Printf("res: %#v", res.ToolCalls[0])
 }
diff --git a/tests_grpc/moonshot_test.go b/tests_grpc/moonshot_test.go
new file mode 100644
index 0000000..31ffefa
--- /dev/null
+++ b/tests_grpc/moonshot_test.go
@@ -0,0 +1,133 @@
+package tests_grpc_test
+
+import (
+    "context"
+    "log"
+    "os"
+    "testing"
+
+    "github.com/joho/godotenv"
+    "go.limit.dev/unollm/grpcServer"
+    "go.limit.dev/unollm/model"
+    "go.limit.dev/unollm/utils"
+)
+
+func TestMoonShot(t *testing.T) {
+    godotenv.Load("../.env")
+
+    messages := make([]*model.LLMChatCompletionMessage, 0)
+    messages = append(messages, &model.LLMChatCompletionMessage{
+        Role:    "user",
+        Content: "假如今天下大雨,我是否需要带伞?",
+    })
+    moonshotAPIKey := os.Getenv("TEST_MOONSHOT_API")
+    req_info := model.LLMRequestInfo{
+        LlmApiType:  grpcServer.MOONSHOT_LLM_API,
+        Model:       "moonshot-v1-8k",
+        Temperature: 0.9,
+        TopP:        0.9,
+        TopK:        1,
+        Url:         "https://api.moonshot.cn/v1",
+        Token:       moonshotAPIKey,
+    }
+    req := model.LLMRequestSchema{
+        Messages:       messages,
+        LlmRequestInfo: &req_info,
+    }
+    mockServer := grpcServer.UnoForwardServer{}
+    res, err := mockServer.BlockingRequestLLM(context.Background(), &req)
+    if err != nil {
+        t.Error(err)
+    }
+    log.Println("res: ", res)
+}
+
+func TestMoonShotStreaming(t *testing.T) {
+    godotenv.Load("../.env")
+
+    messages := make([]*model.LLMChatCompletionMessage, 0)
+    messages = append(messages, &model.LLMChatCompletionMessage{
+        Role:    "user",
+        Content: "假如今天下大雨,我是否需要带伞?",
+    })
+    moonshotAPIKey := os.Getenv("TEST_MOONSHOT_API")
+    req_info := model.LLMRequestInfo{
+        LlmApiType:  grpcServer.MOONSHOT_LLM_API,
+        Model:       "moonshot-v1-8k",
+        Temperature: 0.9,
+        TopP:        0.9,
+        TopK:        1,
+        Url:         "https://api.moonshot.cn/v1",
+        Token:       moonshotAPIKey,
+    }
+    req := model.LLMRequestSchema{
+        Messages:       messages,
+        LlmRequestInfo: &req_info,
+    }
+    mockServer := grpcServer.UnoForwardServer{}
+    mockServerPipe := utils.MockServerStream{
+        Stream: make(chan *model.PartialLLMResponse, 1000),
+    }
+    err := mockServer.StreamRequestLLM(&req, &mockServerPipe)
+    if err != nil {
+        t.Fatal(err)
+    }
+    for {
+        res := <-mockServerPipe.Stream
+        log.Println(res)
+        if res.LlmTokenCount != nil {
+            log.Println(res.LlmTokenCount)
+            return
+        }
+    }
+}
+
+func TestMoonShotFunctionCalling(t *testing.T) {
+    godotenv.Load("../.env")
+
+    messages := make([]*model.LLMChatCompletionMessage, 0)
+    messages = append(messages, &model.LLMChatCompletionMessage{
+        Role:    "user",
+        Content: "what's the weather like in Boston?",
+    })
+    moonshotAPIKey := os.Getenv("TEST_MOONSHOT_API")
+    req_info := model.LLMRequestInfo{
+        LlmApiType:  grpcServer.MOONSHOT_LLM_API,
+        Model:       "moonshot-v1-8k",
+        Temperature: 0.9,
+        TopP:        0.9,
+        TopK:        1,
+        Url:         "https://api.moonshot.cn/v1",
+        Token:       moonshotAPIKey,
+        Functions: []*model.Function{
+            {
+                Name:        "get_weather",
+                Description: "Get the weather of a location",
+                Parameters: []*model.FunctionCallingParameter{
+                    {
+                        Name:        "location",
+                        Type:        "string",
+                        Description: "The city and state, e.g. San Francisco, CA",
San Francisco, CA", + }, + { + Name: "unit", + Type: "string", + Enums: []string{"celsius", "fahrenheit"}, + }, + }, + Requireds: []string{"location", "unit"}, + }, + }, + UseFunctionCalling: true, + } + req := model.LLMRequestSchema{ + Messages: messages, + LlmRequestInfo: &req_info, + } + mockServer := grpcServer.UnoForwardServer{} + res, err := mockServer.BlockingRequestLLM(context.Background(), &req) + if err != nil { + t.Fatal(err) + } + log.Printf("res: %#v", res.ToolCalls[0]) +} diff --git a/tests_grpc/openai_test.go b/tests_grpc/openai_test.go index 5f1f5d4..c80762e 100644 --- a/tests_grpc/openai_test.go +++ b/tests_grpc/openai_test.go @@ -130,7 +130,7 @@ func TestOpenAIFunctionCalling(t *testing.T) { if err != nil { t.Fatal(err) } - log.Printf("res: %#v\n", res.ToolCalls[0]) + log.Printf("res: %#v", res.ToolCalls[0]) } func TestOpenAIEmbedding(t *testing.T) { @@ -149,5 +149,5 @@ func TestOpenAIEmbedding(t *testing.T) { if err != nil { t.Fatal(err) } - log.Printf("res: %#v\n", res) + log.Printf("res: %#v", res) } diff --git a/tests_http/baichuan_test.go b/tests_http/baichuan_test.go index 99ddf2d..e3cb1ee 100644 --- a/tests_http/baichuan_test.go +++ b/tests_http/baichuan_test.go @@ -42,7 +42,7 @@ func GinTestBaichuanStreaming(t *testing.T) { t.Error(e) break } - log.Printf("%#v\n", cv.Choices[0].Delta) + log.Printf("%#v", cv.Choices[0].Delta) } } @@ -65,7 +65,7 @@ func GinTestBaichuanBlocking(t *testing.T) { if err != nil { t.Fatal(err) } - log.Printf("%#v\n", resp.Choices[0]) + log.Printf("%#v", resp.Choices[0]) } func GinTestBaichuanFunctionCalling(t *testing.T) { @@ -120,6 +120,6 @@ func GinTestBaichuanFunctionCalling(t *testing.T) { t.Error(e) break } - log.Printf("%#v\n", cv.Choices[0].Delta) + log.Printf("%#v", cv.Choices[0].Delta) } } diff --git a/tests_http/chatglm_test.go b/tests_http/chatglm_test.go index 874b2e1..4f52ed6 100644 --- a/tests_http/chatglm_test.go +++ b/tests_http/chatglm_test.go @@ -42,7 +42,7 @@ func TestChatGLMStreaming(t *testing.T) { t.Error(e) break } - log.Printf("%#v\n", cv.Choices[0].Delta) + log.Printf("%#v", cv.Choices[0].Delta) } } @@ -65,7 +65,7 @@ func TestChatGLMBlocking(t *testing.T) { if err != nil { t.Fatal(err) } - log.Printf("%#v\n", resp.Choices[0]) + log.Printf("%#v", resp.Choices[0]) } func TestChatGLMFunctionCalling(t *testing.T) { @@ -120,7 +120,7 @@ func TestChatGLMFunctionCalling(t *testing.T) { t.Error(e) break } - log.Printf("%#v\n", cv.Choices[0].Delta.ToolCalls) + log.Printf("%#v", cv.Choices[0].Delta.ToolCalls) } } @@ -136,5 +136,5 @@ func TestChatGLMEmbedding(t *testing.T) { if err != nil { log.Panic(err) } - log.Printf("%#v\n", resp) + log.Printf("%#v", resp) } diff --git a/tests_http/moonshot_test.go b/tests_http/moonshot_test.go new file mode 100644 index 0000000..2d105b0 --- /dev/null +++ b/tests_http/moonshot_test.go @@ -0,0 +1,124 @@ +package tests_http_test + +import ( + "context" + "errors" + "io" + "log" + "os" + "testing" + + "github.com/joho/godotenv" + openai "github.com/sashabaranov/go-openai" + tests_http "go.limit.dev/unollm/tests_http" +) + +func TestMoonShotStreaming(t *testing.T) { + godotenv.Load("../.env") + + client := tests_http.GetClient(os.Getenv("TEST_MOONSHOT_API")) + + resp, err := client.CreateChatCompletionStream(context.Background(), + openai.ChatCompletionRequest{ + Model: "moonshot-v1-8k", + Messages: []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "如果今天下雨,我需要打伞吗?", + }, + }, + }, + ) + if err != nil { + t.Fatal(err) + } + for { + cv, e := resp.Recv() + if e != nil { + if 
+                break
+            }
+            t.Error(e)
+            break
+        }
+        log.Printf("%#v", cv.Choices[0].Delta)
+    }
+}
+
+func TestMoonShotBlocking(t *testing.T) {
+    godotenv.Load("../.env")
+
+    client := tests_http.GetClient(os.Getenv("TEST_MOONSHOT_API"))
+
+    resp, err := client.CreateChatCompletion(context.Background(),
+        openai.ChatCompletionRequest{
+            Model: "moonshot-v1-8k",
+            Messages: []openai.ChatCompletionMessage{
+                {
+                    Role:    "user",
+                    Content: "如果今天下雨,我需要打伞吗?",
+                },
+            },
+        },
+    )
+    if err != nil {
+        t.Fatal(err)
+    }
+    log.Printf("%#v", resp.Choices[0])
+}
+
+func TestMoonShotFunctionCalling(t *testing.T) {
+    godotenv.Load("../.env")
+
+    client := tests_http.GetClient(os.Getenv("TEST_MOONSHOT_API"))
+
+    resp, err := client.CreateChatCompletionStream(context.Background(),
+        openai.ChatCompletionRequest{
+            Model: "moonshot-v1-8k",
+            Messages: []openai.ChatCompletionMessage{
+                {
+                    Role:    "user",
+                    Content: "请返回北京和天津的天气情况。",
+                },
+            },
+            ToolChoice: "auto",
+            Tools: []openai.Tool{
+                {
+                    Type: openai.ToolType("function"),
+                    Function: &openai.FunctionDefinition{
+                        Name:        "get_weather",
+                        Description: "Get the weather of a location",
+                        Parameters: map[string]any{
+                            "type": "object",
+                            "properties": map[string]any{
+                                "location": map[string]any{
+                                    "type":        "string",
+                                    "description": "The city and state, e.g. San Francisco, CA",
+                                },
+                                "unit": map[string]any{
+                                    "type": "string",
+                                    "enum": []string{"celsius", "fahrenheit"},
+                                },
+                            },
+                            "required": []string{"location", "unit"},
+                        },
+                    },
+                },
+            },
+        },
+    )
+    if err != nil {
+        t.Fatal(err)
+    }
+    for {
+        cv, e := resp.Recv()
+        if e != nil {
+            if errors.Is(e, io.EOF) {
+                break
+            }
+            log.Panic(e)
+            break
+        }
+        log.Printf("%#v", cv.Choices[0].Delta.ToolCalls)
+    }
+}
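
Note for reviewers on the root cause: encoding/json decodes every JSON array
inside a map[string]any as []any, never as []string, so the old assertions
openai_params["required"].([]string) and m["enum"].([]string) always failed
and the schema fields were silently dropped. Below is a minimal, self-contained
sketch of that behavior (illustrative only, not part of the patch; it reuses
the same "required" array as the tests above):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        var params map[string]any
        if err := json.Unmarshal([]byte("{\"required\": [\"location\", \"unit\"]}"), &params); err != nil {
            panic(err)
        }

        // A direct assertion to []string fails: the decoded value is []any.
        _, ok := params["required"].([]string)
        fmt.Println("as []string:", ok) // as []string: false

        // Element-wise conversion, mirroring what the patch now does.
        raw, ok := params["required"].([]any)
        fmt.Println("as []any:", ok) // as []any: true
        required := make([]string, len(raw))
        for i, v := range raw {
            required[i] = v.(string)
        }
        fmt.Println(required) // [location unit]
    }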