fix array "required" and "enum" in function calling
LemonHX committed Mar 1, 2024
1 parent 8dd91f1 commit 4bdb49e
Showing 11 changed files with 344 additions and 18 deletions.
25 changes: 25 additions & 0 deletions grpcServer/relay.go
@@ -27,6 +27,7 @@ const CHATGLM_LLM_API = "chatglm"
const AZURE_OPENAI_LLM_API = "azure_openai"
const BAICHUAN_LLM_API = "baichuan"
const GEMINI_LLM_API = "gemini"
const MOONSHOT_LLM_API = "moonshot"

func (uno *UnoForwardServer) BlockingRequestLLM(ctx context.Context, rs *model.LLMRequestSchema) (*model.LLMResponseSchema, error) {
info := rs.GetLlmRequestInfo()
@@ -35,6 +36,18 @@ func (uno *UnoForwardServer) BlockingRequestLLM(ctx context.Context, rs *model.LLMRequestSchema) (*model.LLMResponseSchema, error) {
cli := NewOpenAIClient(info)
return OpenAIChatCompletion(cli, rs)

case MOONSHOT_LLM_API:
cli := NewOpenAIClient(info)
if functionCallingRequestMake(rs) {
res, err := OpenAIChatCompletion(cli, rs)
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
}
functionCallingResponseHandle(res)
return res, nil
}
return OpenAIChatCompletion(cli, rs)

case CHATGLM_LLM_API:
cli := NewChatGLMClient(info)
return ChatGLMChatCompletion(cli, rs)
@@ -69,6 +82,18 @@ func (uno *UnoForwardServer) StreamRequestLLM(rs *model.LLMRequestSchema, sv mod
case OPENAI_LLM_API:
cli := NewOpenAIClient(info)
return OpenAIChatCompletionStreaming(cli, rs, sv)
case MOONSHOT_LLM_API:
cli := NewOpenAIClient(info)
if functionCallingRequestMake(rs) {
res, err := OpenAIChatCompletion(cli, rs)
if err != nil {
return status.Errorf(codes.Internal, err.Error())
}
functionCallingResponseHandle(res)
functionCallingResponseToStream(res, sv)
return nil
}
return OpenAIChatCompletionStreaming(cli, rs, sv)
case CHATGLM_LLM_API:
cli := NewChatGLMClient(info)
return ChatGLMChatCompletionStreaming(cli, rs, sv)
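The Moonshot streaming case above degrades to a blocking call whenever function calling is in play, then replays the finished response over the stream. The diff does not show functionCallingResponseToStream; as a rough sketch of that replay pattern (every type and field name below is an assumption, not the repository's API):

    package sketch

    // Stand-in types; the real ones live in the repository's model package.
    type ToolCall struct{ Name, Arguments string }

    type TokenUsage struct{ Prompt, Completion int }

    type Response struct {
        ToolCalls     []*ToolCall
        LlmTokenCount *TokenUsage
    }

    type Chunk struct {
        ToolCalls     []*ToolCall
        LlmTokenCount *TokenUsage
    }

    // replayToStream emits the already-complete response as one chunk, then a
    // terminal chunk carrying the token count; the tests below treat a non-nil
    // LlmTokenCount as the end-of-stream marker.
    func replayToStream(res *Response, send func(*Chunk) error) error {
        if err := send(&Chunk{ToolCalls: res.ToolCalls}); err != nil {
            return err
        }
        return send(&Chunk{LlmTokenCount: res.LlmTokenCount})
    }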
30 changes: 28 additions & 2 deletions grpcServer/relay_function_calling_middleware.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"log"
"strings"
"time"

openai "github.com/sashabaranov/go-openai"
@@ -14,8 +15,34 @@ import (
type Callings struct {
}

func removeWhiteSpaces(str string) string {
var b strings.Builder
b.Grow(len(str))
for _, ch := range str {
if ch != ' ' && ch != '\t' && ch != '\n' {
b.WriteRune(ch)
}
}
return b.String()
}
func getParams(url string) (params []ChatGLM.GLMFunctionCall, err error) {
url = removeWhiteSpaces(url)
err = json.Unmarshal([]byte(url), &params)
if err != nil {
// ```json<>```
uurl1, found1 := strings.CutPrefix(url, "```json")
if found1 {
uurl1, _ = strings.CutSuffix(uurl1, "```")
err = json.Unmarshal([]byte(uurl1), &params)
return
}
uurl2, found2 := strings.CutPrefix(url, "```")
if found2 {
uurl2, _ = strings.CutSuffix(uurl2, "```")
err = json.Unmarshal([]byte(uurl2), &params)
return
}
}
return
}
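getParams is tolerant of fenced replies: after stripping whitespace it tries a direct unmarshal, then retries with a leading ```json or bare ``` fence and the trailing ``` removed (strings.CutPrefix and CutSuffix are standard library as of Go 1.20). A self-contained demonstration of the fence-stripping path:

    package main

    import (
        "encoding/json"
        "fmt"
        "strings"
    )

    func main() {
        // A model reply wrapped in a markdown fence, as handled above.
        raw := "```json[{\"name\":\"get_weather\",\"parameters\":{\"location\":\"Boston\"}}]```"
        if body, ok := strings.CutPrefix(raw, "```json"); ok {
            body, _ = strings.CutSuffix(body, "```")
            var calls []map[string]any
            if err := json.Unmarshal([]byte(body), &calls); err == nil {
                fmt.Println(calls[0]["name"]) // get_weather
            }
        }
    }

One caveat worth flagging: removeWhiteSpaces also strips spaces inside JSON string values, so an argument like "San Francisco, CA" would arrive as "SanFrancisco,CA" after this pass.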

@@ -53,7 +80,7 @@ func functionCallingRequestMake(req *model.LLMRequestSchema) bool {
tools[i].Function.Parameters.(map[string]any)["properties"].(map[string]any)[f.Parameters[j].Name].(map[string]any)["enum"] = f.Parameters[j].Enums
}
}
log.Printf("%#v\n", tools[i])
log.Printf("find function: %#v", tools[i].Function)
}

tools_json_byte, err := json.Marshal(tools)
@@ -80,7 +107,6 @@ func functionCallingResponseHandle(resp *model.LLMResponseSchema) {
if err != nil {
log.Fatal(err)
}
log.Printf("%#v\n", function_calling)
resp.Message.Content = ""
resp.ToolCalls = []*model.ToolCall{}
for _, f := range function_calling {
3 changes: 3 additions & 0 deletions grpcServer/relay_gin.go
@@ -26,6 +26,9 @@ func getProvider(m string) (string, error) {
if strings.Contains(m, "gpt") {
return "openai", nil
}
if strings.Contains(m, "moonshot") {
return "moonshot", nil
}
return "", errors.New("could not get provider")
}

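Provider lookup is substring-based, so any model id containing "moonshot" now resolves to the new provider. A self-contained mirror of the function for illustration (the real one is unexported in grpcServer/relay_gin.go above):

    package main

    import (
        "errors"
        "fmt"
        "strings"
    )

    // getProvider mirrored locally so the example compiles on its own.
    func getProvider(m string) (string, error) {
        if strings.Contains(m, "gpt") {
            return "openai", nil
        }
        if strings.Contains(m, "moonshot") {
            return "moonshot", nil
        }
        return "", errors.New("could not get provider")
    }

    func main() {
        p, _ := getProvider("moonshot-v1-8k")
        fmt.Println(p) // moonshot
    }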
25 changes: 20 additions & 5 deletions relay/reqTransformer/ChatGPT.go
@@ -1,6 +1,7 @@
package reqTransformer

import (
"encoding/json"
"log"

openai "github.com/sashabaranov/go-openai"
@@ -21,9 +22,15 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
if !ok {
openai_params = make(map[string]any)
}
requireds, ok := openai_params["requireds"].([]string)
var requireds []string
openai_requireds, ok := openai_params["required"].([]any)
if !ok {
requireds = make([]string, 0)
} else {
requireds = make([]string, len(openai_requireds))
for i, v := range openai_requireds {
requireds[i] = v.(string)
}
}
openai_params_properties, ok := openai_params["properties"].(map[string]any)
if !ok {
@@ -45,9 +52,12 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
if ok {
param.Description = desc
}
enum, ok := m["enum"].([]string)
json.Marshal(m["enum"])
enum, ok := m["enum"].([]any)
if ok {
param.Enums = enum
for _, e := range enum {
param.Enums = append(param.Enums, e.(string))
}
}
params = append(params, &param)
}
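These two hunks are the fix the commit title names. The old code read the wrong key ("requireds" instead of "required") and type-asserted JSON-decoded arrays with .([]string), which can never succeed: encoding/json decodes every JSON array into []any, so the elements must be converted one by one. The same applies to "enum". A minimal, self-contained demonstration:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        var params map[string]any
        json.Unmarshal([]byte(`{"required":["location","unit"]}`), &params)

        _, ok := params["required"].([]string)
        fmt.Println(ok) // false: JSON arrays decode to []any, never []string

        vals, _ := params["required"].([]any)
        required := make([]string, len(vals))
        for i, v := range vals {
            required[i] = v.(string) // element-wise conversion, as in the fix
        }
        fmt.Println(required) // [location unit]
    }

Incidentally, the bare json.Marshal(m["enum"]) call above discards both of its return values and reads like leftover debugging.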
@@ -63,6 +73,11 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
if req.ToolChoice == "none" {
usefc = false
}
url := "TODO: URL"
switch api {
case "moonshot":
url = "https://api.moonshot.cn/v1"
}
return &model.LLMRequestSchema{
Messages: messages,
LlmRequestInfo: &model.LLMRequestInfo{
@@ -71,7 +86,7 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
Temperature: float64(req.Temperature),
TopP: float64(req.TopP),
TopK: float64(0),
Url: "TODO: URL",
Url: url,
Token: token,
UseFunctionCalling: usefc,
Functions: tools,
@@ -127,7 +142,7 @@ func ChatGPTGrpcChatCompletionReq(rs *model.LLMRequestSchema) openai.ChatComplet
tools[i].Function.Parameters.(map[string]any)["properties"].(map[string]any)[f.Parameters[j].Name].(map[string]any)["enum"] = f.Parameters[j].Enums
}
}
log.Printf("%#v\n", tools[i])
log.Printf("converting function call to grpc: %#v", tools[i].Function)
}
return openai.ChatCompletionRequest{
Model: info.GetModel(),
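The new url switch works because Moonshot serves an OpenAI-compatible API at https://api.moonshot.cn/v1, which is also why NewOpenAIClient is reused verbatim in relay.go. With the go-openai dependency already imported here, the wiring is plausibly just a base-URL override; a sketch, assuming NewOpenAIClient does something similar:

    package sketch

    import openai "github.com/sashabaranov/go-openai"

    // newMoonshotClient points a go-openai client at Moonshot's
    // OpenAI-compatible endpoint. Assumed to approximate what the
    // repository's NewOpenAIClient helper does; not taken from it.
    func newMoonshotClient(token string) *openai.Client {
        cfg := openai.DefaultConfig(token)
        cfg.BaseURL = "https://api.moonshot.cn/v1"
        return openai.NewClientWithConfig(cfg)
    }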
2 changes: 1 addition & 1 deletion tests_grpc/baichuan_test.go
@@ -141,5 +141,5 @@ func TestBaichuanFunctionCalling(t *testing.T) {
if err != nil {
t.Error(err)
}
log.Printf("res: %#v\n", res.ToolCalls[0])
log.Printf("res: %#v", res.ToolCalls[0])
}
2 changes: 1 addition & 1 deletion tests_grpc/chatglm_test.go
@@ -141,5 +141,5 @@ func TestChatGLMFunctionCalling(t *testing.T) {
if err != nil {
t.Error(err)
}
log.Printf("res: %#v\n", res.ToolCalls[0])
log.Printf("res: %#v", res.ToolCalls[0])
}
133 changes: 133 additions & 0 deletions tests_grpc/moonshot_test.go
@@ -0,0 +1,133 @@
package tests_grpc_test

import (
"context"
"log"
"os"
"testing"

"github.com/joho/godotenv"
"go.limit.dev/unollm/grpcServer"
"go.limit.dev/unollm/model"
"go.limit.dev/unollm/utils"
)

func TestMoonShot(t *testing.T) {
godotenv.Load("../.env")

messages := make([]*model.LLMChatCompletionMessage, 0)
messages = append(messages, &model.LLMChatCompletionMessage{
Role: "user",
Content: "假如今天下大雨,我是否需要带伞?",
})
OPENAIApiKey := os.Getenv("TEST_MOONSHOT_API")
req_info := model.LLMRequestInfo{
LlmApiType: grpcServer.MOONSHOT_LLM_API,
Model: "moonshot-v1-8k",
Temperature: 0.9,
TopP: 0.9,
TopK: 1,
Url: "https://api.moonshot.cn/v1",
Token: OPENAIApiKey,
}
req := model.LLMRequestSchema{
Messages: messages,
LlmRequestInfo: &req_info,
}
mockServer := grpcServer.UnoForwardServer{}
res, err := mockServer.BlockingRequestLLM(context.Background(), &req)
if err != nil {
t.Error(err)
}
log.Println("res: ", res)
}

func TestMoonShotStreaming(t *testing.T) {
godotenv.Load("../.env")

messages := make([]*model.LLMChatCompletionMessage, 0)
messages = append(messages, &model.LLMChatCompletionMessage{
Role: "user",
Content: "假如今天下大雨,我是否需要带伞?",
})
OPENAIApiKey := os.Getenv("TEST_MOONSHOT_API")
req_info := model.LLMRequestInfo{
LlmApiType: grpcServer.MOONSHOT_LLM_API,
Model: "moonshot-v1-8k",
Temperature: 0.9,
TopP: 0.9,
TopK: 1,
Url: "https://api.moonshot.cn/v1",
Token: OPENAIApiKey,
}
req := model.LLMRequestSchema{
Messages: messages,
LlmRequestInfo: &req_info,
}
mockServer := grpcServer.UnoForwardServer{}
mockServerPipe := utils.MockServerStream{
Stream: make(chan *model.PartialLLMResponse, 1000),
}
err := mockServer.StreamRequestLLM(&req, &mockServerPipe)
if err != nil {
t.Fatal(err)
}
for {
res := <-mockServerPipe.Stream
log.Println(res)
if res.LlmTokenCount != nil {
log.Println(res.LlmTokenCount)
return
}
}
}

func TestMoonShotFunctionCalling(t *testing.T) {
godotenv.Load("../.env")

messages := make([]*model.LLMChatCompletionMessage, 0)
messages = append(messages, &model.LLMChatCompletionMessage{
Role: "user",
Content: "whats the weather like in Poston?",
})
OPENAIApiKey := os.Getenv("TEST_MOONSHOT_API")
req_info := model.LLMRequestInfo{
LlmApiType: grpcServer.MOONSHOT_LLM_API,
Model: "moonshot-v1-8k",
Temperature: 0.9,
TopP: 0.9,
TopK: 1,
Url: "https://api.moonshot.cn/v1",
Token: OPENAIApiKey,
Functions: []*model.Function{
{
Name: "get_weather",
Description: "Get the weather of a location",
Parameters: []*model.FunctionCallingParameter{
{
Name: "location",
Type: "string",
Description: "The city and state, e.g. San Francisco, CA",
},
{
Name: "unit",
Type: "string",
Enums: []string{"celsius", "fahrenheit"},
},
},
Requireds: []string{"location", "unit"},
},
},
UseFunctionCalling: true,
}
req := model.LLMRequestSchema{
Messages: messages,
LlmRequestInfo: &req_info,
}
mockServer := grpcServer.UnoForwardServer{}
res, err := mockServer.BlockingRequestLLM(context.Background(), &req)
if err != nil {
t.Fatal(err)
}
log.Printf("res: %#v", res.ToolCalls[0])
}
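All three tests load ../.env and read TEST_MOONSHOT_API; with no key set they will error rather than skip. A small guard like the following (a suggestion of this note, not part of the commit) would let credential-less runs skip cleanly:

    package tests_grpc_test

    import (
        "os"
        "testing"
    )

    // moonshotKey skips the calling test when no API key is configured.
    func moonshotKey(t *testing.T) string {
        t.Helper()
        key := os.Getenv("TEST_MOONSHOT_API")
        if key == "" {
            t.Skip("TEST_MOONSHOT_API not set")
        }
        return key
    }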
4 changes: 2 additions & 2 deletions tests_grpc/openai_test.go
@@ -130,7 +130,7 @@ func TestOpenAIFunctionCalling(t *testing.T) {
if err != nil {
t.Fatal(err)
}
log.Printf("res: %#v\n", res.ToolCalls[0])
log.Printf("res: %#v", res.ToolCalls[0])
}

func TestOpenAIEmbedding(t *testing.T) {
@@ -149,5 +149,5 @@ func TestOpenAIEmbedding(t *testing.T) {
if err != nil {
t.Fatal(err)
}
log.Printf("res: %#v\n", res)
log.Printf("res: %#v", res)
}
6 changes: 3 additions & 3 deletions tests_http/baichuan_test.go
@@ -42,7 +42,7 @@ func GinTestBaichuanStreaming(t *testing.T) {
t.Error(e)
break
}
log.Printf("%#v\n", cv.Choices[0].Delta)
log.Printf("%#v", cv.Choices[0].Delta)
}
}

@@ -65,7 +65,7 @@ func GinTestBaichuanBlocking(t *testing.T) {
if err != nil {
t.Fatal(err)
}
log.Printf("%#v\n", resp.Choices[0])
log.Printf("%#v", resp.Choices[0])
}

func GinTestBaichuanFunctionCalling(t *testing.T) {
@@ -120,6 +120,6 @@ func GinTestBaichuanFunctionCalling(t *testing.T) {
t.Error(e)
break
}
log.Printf("%#v\n", cv.Choices[0].Delta)
log.Printf("%#v", cv.Choices[0].Delta)
}
}