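diff --git a/plugins/wasm-go/extensions/ai-json-resp/README.md b/plugins/wasm-go/extensions/ai-json-resp/README.md new file mode 100644 index 0000000000..9fc5a5eee5 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-json-resp/README.md @@ -0,0 +1,202 @@ +## Introduction + +**Note** + +> Requires a data-plane proxy wasm version of 0.2.100 or later. +> + +> When compiling, include the version tag, for example: tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./ + + +The LLM response structuring plugin structures AI responses according to a default or user-configured JSON Schema so that downstream plugins can process them. Note that only `non-streaming responses` are currently supported. + + +### Configuration + +| Name | Type | Requirement | Default | **Description** | +| --- | --- | --- | --- | --- | +| serviceName | str | required | - | Name of the AI service, or of a gateway service that supports AI-Proxy | +| serviceDomain | str | optional | - | Domain or IP address of the AI service or AI-Proxy-capable gateway service | +| servicePath | str | optional | '/v1/chat/completions' | Base path of the AI service or AI-Proxy-capable gateway service | +| serviceUrl | str | optional | - | URL of the AI service or AI-Proxy-capable gateway service; the plugin extracts the domain and path from it to fill in serviceDomain or servicePath when they are not configured | +| servicePort | int | optional | 443 | Port of the gateway service | +| serviceTimeout | int | optional | 50000 | Default request timeout, in milliseconds | +| maxRetry | int | optional | 3 | Number of retries when the answer cannot be extracted or formatted correctly | +| contentPath | str | optional | "choices.0.message.content" | gjson path used to extract the result from the LLM response | +| jsonSchema | str (json) | optional | - | JSON Schema used for validation; if empty, the plugin only verifies that the response is valid JSON and returns it | +| enableSwagger | bool | optional | false | Whether to validate using the Swagger specification | +| enableOas3 | bool | optional | true | Whether to validate using the OAS3 specification | +| enableContentDisposition | bool | optional | true | Whether to add the Content-Disposition header; if enabled, `Content-Disposition: attachment; filename="response.json"` is added to the response headers | + +> For performance reasons, the maximum supported JSON Schema depth is 6 by default. A JSON Schema deeper than this is not used to validate responses; the plugin only checks that the returned response is valid JSON. + + +### Request and response parameters + +- **Request parameters**: The plugin accepts requests in the OpenAI format, with `model` and `messages` fields. `model` is the AI model name, and `messages` is the list of conversation messages, each with a `role` (the message role) and a `content` (the message content). + ```json + { + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "give me a api doc for add the variable x 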
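to x+5"} + ] + } + ``` + For other request parameters, refer to the documentation of the configured AI service or gateway service. +- **Response parameters**: + - Returns a `JSON response` that satisfies the constraints of the defined JSON Schema + - If no JSON Schema is defined, returns a valid `JSON response` + - If an internal error occurs, returns `{ "Code": 10XX, "Msg": "error message" }`. + +## Request example + +```bash +curl -X POST "http://localhost:8001/v1/chat/completions" \ +-H "Content-Type: application/json" \ +-d '{ + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "give me a api doc for add the variable x to x+5"} + ] +}' + +``` + +## Response example +### Normal response +Under normal conditions, the system returns JSON data validated against the JSON Schema. If no JSON Schema is configured, it returns any valid JSON data. +```json +{ + "apiVersion": "1.0", + "request": { + "endpoint": "/add_to_five", + "method": "POST", + "port": 8080, + "headers": { + "Content-Type": "application/json" + }, + "body": { + "x": 7 + } + } +} +``` + +### Error response +When an error occurs, the status code is `500` and the body is a JSON error message with two fields: the error code `Code` and the error message `Msg`. +```json +{ + "Code": 1006, + "Msg": "retry count exceeds max retry count" +} +``` + +### Error codes +| Code | Description | +| --- | --- | +| 1001 | The configured JSON Schema is not valid JSON | +| 1002 | Compiling the configured JSON Schema failed: it is not a valid JSON Schema, or its depth exceeds jsonSchemaMaxDepth while rejectOnDepthExceeded is true | +| 1003 | No valid JSON could be extracted from the response | +| 1004 | The response is an empty string | +| 1005 | The response does not conform to the JSON Schema | +| 1006 | The retry count exceeded the maximum limit | +| 1007 | The response content could not be obtained; the upstream service may be misconfigured or the contentPath may be wrong | +| 1008 | serviceDomain is empty; note that serviceDomain and serviceUrl cannot both be empty | + +## Service configuration +This plugin needs an upstream service configured to support automatic retries when an error occurs. The supported configurations are mainly an `AI service with an OpenAI-compatible API` or a `local gateway service`. + +### AI service with an OpenAI-compatible API +Taking qwen as an example, the basic configuration is as follows: + +YAML configuration: +```yaml +serviceName: qwen +serviceDomain: dashscope.aliyuncs.com +apiKey: [Your API Key] +servicePath: /compatible-mode/v1/chat/completions +jsonSchema: + title: ReasoningSchema + type: object + properties: + reasoning_steps: + type: array + items: + type: string + description: The reasoning steps leading to the final conclusion. + answer: + type: string + description: The final answer, taking into account the reasoning steps. 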
+ required: + - reasoning_steps + - answer + additionalProperties: false +``` + +JSON configuration: +```json +{ + "serviceName": "qwen", + "serviceUrl": "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions", + "apiKey": "[Your API Key]", + "jsonSchema": { + "title": "ActionItemsSchema", + "type": "object", + "properties": { + "action_items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Description of the action item." + }, + "due_date": { + "type": ["string", "null"], + "description": "Due date for the action item, can be null if not specified." + }, + "owner": { + "type": ["string", "null"], + "description": "Owner responsible for the action item, can be null if not specified." + } + }, + "required": ["description", "due_date", "owner"], + "additionalProperties": false + }, + "description": "List of action items from the meeting." + } + }, + "required": ["action_items"], + "additionalProperties": false + } +} +``` + +### Local gateway service +To reuse services that are already configured, the plugin can also point at a local gateway service. For example, if the gateway already has the [AI-proxy service](../ai-proxy/README.md) configured, it can be set up as follows: +1. Create a service with the fixed IP 127.0.0.1, for example localservice.static +```yaml +- name: outbound|10000||localservice.static + connect_timeout: 30s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: outbound|8001||localservice.static + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 127.0.0.1 + port_value: 10000 +``` +2. Add the localservice.static service settings to the configuration file +```yaml +serviceName: localservice +serviceDomain: 127.0.0.1 +servicePort: 10000 +``` +3. 
Automatically extract the request Path, Headers, and other information +The plugin automatically extracts the request Path, Headers, and other information, which avoids configuring the AI service twice. diff --git a/plugins/wasm-go/extensions/ai-json-resp/go.mod b/plugins/wasm-go/extensions/ai-json-resp/go.mod new file mode 100644 index 0000000000..dba1b5f01e --- /dev/null +++ b/plugins/wasm-go/extensions/ai-json-resp/go.mod @@ -0,0 +1,21 @@ +module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-json-resp + +go 1.18 + +replace github.com/alibaba/higress/plugins/wasm-go => ../.. + +require ( + github.com/alibaba/higress/plugins/wasm-go v1.4.2 + github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f +) + +require ( + github.com/google/uuid v1.3.0 // indirect + github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect + github.com/magefile/mage v1.14.0 // indirect + github.com/santhosh-tekuri/jsonschema v1.2.4 // indirect + github.com/tidwall/gjson v1.14.3 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/resp v0.1.1 // indirect +) diff --git a/plugins/wasm-go/extensions/ai-json-resp/go.sum b/plugins/wasm-go/extensions/ai-json-resp/go.sum new file mode 100644 index 0000000000..4a1363d5b8 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-json-resp/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= +github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod 
h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= +github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis= +github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= +github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= +github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-json-resp/main.go b/plugins/wasm-go/extensions/ai-json-resp/main.go new file mode 100644 
index 0000000000..1034396d70 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-json-resp/main.go @@ -0,0 +1,573 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "encoding/json" + "errors" + "net/http" + "strconv" + "strings" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/santhosh-tekuri/jsonschema" + "github.com/tidwall/gjson" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +const ( + DEFAULT_SCHEMA = "defaultSchema" + HTTP_STATUS_OK = uint32(200) + HTTP_STATUS_INTERNAL_SERVER_ERROR = uint32(500) + FROM_THIS_PLUGIN_KEY = "fromThisPlugin" + EXTEND_HEADER_KEY = "X-HIGRESS-AI-JSON-RESP" + + JSON_SCHEMA_INVALID_CODE = 1001 + JSON_SCHEMA_COMPILE_FAILED_CODE = 1002 + CANNOT_FIND_JSON_IN_RESPONSE_CODE = 1003 + CONTENT_IS_EMPTY_CODE = 1004 + JSON_MISMATCH_SCHEMA_CODE = 1005 + REACH_MAX_RETRY_COUNT_CODE = 1006 + SERVICE_UNAVAILABLE_CODE = 1007 + SERVICE_CONFIG_INVALID_CODE = 1008 +) + +type RejectStruct struct { + RejectCode uint32 `json:"Code"` + RejectMsg string `json:"Msg"` +} + +func (r RejectStruct) GetBytes() []byte { + jsonData, _ := json.Marshal(r) + return jsonData +} + +func (r RejectStruct) GetShortMsg() string { + return "ai-json-resp." 
+ strings.Split(r.RejectMsg, ":")[0] +} + +type PluginConfig struct { + // @Title zh-CN 服务名称 + // @Description zh-CN 用以请求服务的名称(网关或其他AI服务) + serviceName string `required:"true" json:"serviceName" yaml:"serviceName"` + // @Title zh-CN 服务域名 + // @Description zh-CN 用以请求服务的域名 + serviceDomain string `required:"false" json:"serviceDomain" yaml:"serviceDomain"` + // @Title zh-CN 服务端口 + // @Description zh-CN 用以请求服务的端口 + servicePort int `required:"false" json:"servicePort" yaml:"servicePort"` + // @Title zh-CN 服务URL + // @Description zh-CN 用以请求服务的URL,若提供则会覆盖serviceDomain和servicePort + serviceUrl string `required:"false" json:"serviceUrl" yaml:"serviceUrl"` + // @Title zh-CN API Key + // @Description zh-CN 若使用AI服务,需要填写请求服务的API Key + apiKey string `required:"false" json:"apiKey" yaml:"apiKey"` + // @Title zh-CN 请求端点 + // @Description zh-CN 用以请求服务的端点, 默认为"/v1/chat/completions" + servicePath string `required:"false" json:"servicePath" yaml:"servicePath"` + // @Title zh-CN 服务超时时间 + // @Description zh-CN 用以请求服务的超时时间 + serviceTimeout int `required:"false" json:"serviceTimeout" yaml:"serviceTimeout"` + // @Title zh-CN 最大重试次数 + // @Description zh-CN 用以请求服务的最大重试次数 + maxRetry int `required:"false" json:"maxRetry" yaml:"maxRetry"` + // @Title zh-CN 内容路径 + // @Description zh-CN 从AI服务返回的响应中提取json的gpath路径 + contentPath string `required:"false" json:"contentPath" yaml:"contentPath"` + // @Title zh-CN Json Schema + // @Description zh-CN 用以验证响应json的Json Schema, 为空则只验证返回的响应是否为合法json + jsonSchema map[string]interface{} `required:"false" json:"jsonSchema" yaml:"jsonSchema"` + // @Title zh-CN 是否启用swagger + // @Description zh-CN 是否启用swagger进行Json Schema验证 + enableSwagger bool `required:"false" json:"enableSwagger" yaml:"enableSwagger"` + // @Title zh-CN 是否启用oas3 + // @Description zh-CN 是否启用oas3进行Json Schema验证 + enableOas3 bool `required:"false" json:"enableOas3" yaml:"enableOas3"` + // @Title zh-CN 是否启用Content-Disposition + // @Description zh-CN 是否启用Content-Disposition, 
若启用则会在响应头中添加Content-Disposition: attachment; filename="response.json" + enableContentDisposition bool `required:"false" json:"enableContentDisposition" yaml:"enableContentDisposition"` + + serviceClient wrapper.HttpClient + draft *jsonschema.Draft + compiler *jsonschema.Compiler + compile *jsonschema.Schema + rejectStruct RejectStruct + jsonSchemaMaxDepth int + enableJsonSchemaValidation bool +} + +func main() { + wrapper.SetCtx( + "ai-json-resp", + wrapper.ParseConfigBy(parseConfig), + wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), + wrapper.ProcessRequestBodyBy(onHttpRequestBody), + ) +} + +type RequestContext struct { + Path string + ReqHeaders [][2]string + ReqBody []byte + RespHeader [][2]string + RespBody []byte + HistoryMessages []chatMessage +} + +func parseUrl(url string) (string, string) { + if url == "" { + return "", "" + } + url = strings.TrimPrefix(url, "http://") + url = strings.TrimPrefix(url, "https://") + index := strings.Index(url, "/") + if index == -1 { + return url, "" + } + return url[:index], url[index:] +} + +func parseConfig(result gjson.Result, config *PluginConfig, log wrapper.Log) error { + config.serviceName = result.Get("serviceName").String() + config.serviceUrl = result.Get("serviceUrl").String() + config.serviceDomain = result.Get("serviceDomain").String() + config.servicePath = result.Get("servicePath").String() + config.servicePort = int(result.Get("servicePort").Int()) + if config.serviceUrl != "" { + domain, url := parseUrl(config.serviceUrl) + log.Debugf("serviceUrl: %s, the parsed domain: %s, the parsed url: %s", config.serviceUrl, domain, url) + if config.serviceDomain == "" { + config.serviceDomain = domain + } + if config.servicePath == "" { + config.servicePath = url + } + } + if config.servicePort == 0 { + config.servicePort = 443 + } + config.serviceTimeout = int(result.Get("serviceTimeout").Int()) + config.apiKey = result.Get("apiKey").String() + config.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""} + if 
config.serviceTimeout == 0 { + config.serviceTimeout = 50000 + } + config.maxRetry = int(result.Get("maxRetry").Int()) + if config.maxRetry == 0 { + config.maxRetry = 3 + } + config.contentPath = result.Get("contentPath").String() + if config.contentPath == "" { + config.contentPath = "choices.0.message.content" + } + + if jsonSchemaValue := result.Get("jsonSchema"); jsonSchemaValue.Exists() { + if schemaValue, ok := jsonSchemaValue.Value().(map[string]interface{}); ok { + config.jsonSchema = schemaValue + + } else { + config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema is not valid"} + } + } else { + config.jsonSchema = nil + } + + if config.serviceDomain == "" { + // report the service-config error code (1008) documented in the README, not the schema one + config.rejectStruct = RejectStruct{SERVICE_CONFIG_INVALID_CODE, "service domain is empty"} + } + + config.serviceClient = wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.serviceName, + Port: int64(config.servicePort), + Domain: config.serviceDomain, + }) + + enableSwagger := result.Get("enableSwagger").Bool() + enableOas3 := result.Get("enableOas3").Bool() + + // set draft version + if enableSwagger { + config.draft = jsonschema.Draft4 + } + if enableOas3 { + config.draft = jsonschema.Draft7 + } + if !enableSwagger && !enableOas3 { + config.draft = jsonschema.Draft7 + } + + // create compiler + compiler := jsonschema.NewCompiler() + compiler.Draft = config.draft + config.compiler = compiler + + // set max depth of json schema + config.jsonSchemaMaxDepth = 6 + + enableContentDispositionValue := result.Get("enableContentDisposition") + if !enableContentDispositionValue.Exists() { + config.enableContentDisposition = true + } else { + config.enableContentDisposition = enableContentDispositionValue.Bool() + } + + config.enableJsonSchemaValidation = true + + jsonSchemaBytes, err := json.Marshal(config.jsonSchema) + if err != nil { + config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema marshal failed"} + return err + } + + maxDepth := 
GetMaxDepth(config.jsonSchema) + log.Debugf("max depth of json schema: %d", maxDepth) + if maxDepth > config.jsonSchemaMaxDepth { + config.enableJsonSchemaValidation = false + log.Infof("Json Schema depth exceeded: %d from %d , Json Schema validation will not be used.", maxDepth, config.jsonSchemaMaxDepth) + } + + if config.enableJsonSchemaValidation { + jsonSchemaStr := string(jsonSchemaBytes) + config.compiler.AddResource(DEFAULT_SCHEMA, strings.NewReader(jsonSchemaStr)) + // Test if the Json Schema is valid + compile, err := config.compiler.Compile(DEFAULT_SCHEMA) + if err != nil { + log.Infof("Json Schema compile failed: %v", err) + config.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()} + config.compile = nil + } else { + config.compile = compile + } + } + + return nil +} + +func (r *RequestContext) assembleReqBody(config PluginConfig) []byte { + var reqBodystrut chatCompletionRequest + json.Unmarshal(r.ReqBody, &reqBodystrut) + content := gjson.ParseBytes(r.RespBody).Get(config.contentPath).String() + jsonSchemaBytes, _ := json.Marshal(config.jsonSchema) + jsonSchemaStr := string(jsonSchemaBytes) + + askQuestion := "Given the Json Schema: " + jsonSchemaStr + ", please help me convert the following content to a pure json: " + content + askQuestion += "\n Do not respond other content except the pure json!!!!" + + reqBodystrut.Messages = append(r.HistoryMessages, []chatMessage{ + { + Role: "user", + Content: askQuestion, + }, + }...) 
+ + reqBody, _ := json.Marshal(reqBodystrut) + return reqBody +} + +func (r *RequestContext) SaveBodyToHistMsg(log wrapper.Log, reqBody []byte, respBody []byte) { + r.RespBody = respBody + lastUserMessage := "" + lastSystemMessage := "" + + var reqBodystrut chatCompletionRequest + err := json.Unmarshal(reqBody, &reqBodystrut) + if err != nil { + log.Debugf("unmarshal reqBody failed: %v", err) + } else { + if len(reqBodystrut.Messages) != 0 { + lastUserMessage = reqBodystrut.Messages[len(reqBodystrut.Messages)-1].Content + } + } + + var respBodystrut chatCompletionResponse + err = json.Unmarshal(respBody, &respBodystrut) + if err != nil { + log.Debugf("unmarshal respBody failed: %v", err) + } else { + if len(respBodystrut.Choices) != 0 { + lastSystemMessage = respBodystrut.Choices[len(respBodystrut.Choices)-1].Message.Content + } + } + + if lastUserMessage != "" { + r.HistoryMessages = append(r.HistoryMessages, chatMessage{ + Role: "user", + Content: lastUserMessage, + }) + } + + if lastSystemMessage != "" { + r.HistoryMessages = append(r.HistoryMessages, chatMessage{ + Role: "system", + Content: lastSystemMessage, + }) + } +} + +func (r *RequestContext) SaveStrToHistMsg(log wrapper.Log, errMsg string) { + r.HistoryMessages = append(r.HistoryMessages, chatMessage{ + Role: "system", + Content: errMsg, + }) +} + +func (c *PluginConfig) ValidateBody(body []byte) error { + var respJsonStrct chatCompletionResponse + err := json.Unmarshal(body, &respJsonStrct) + if err != nil { + c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "service unavailable: " + string(body)} + return errors.New(c.rejectStruct.RejectMsg) + } + content := gjson.ParseBytes(body).Get(c.contentPath) + if !content.Exists() { + c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "response body does not contain the content: " + string(body)} + return errors.New(c.rejectStruct.RejectMsg) + } + return nil +} + +func (c *PluginConfig) ValidateJson(body []byte, log wrapper.Log) (string, error) { + 
content := gjson.ParseBytes(body).Get(c.contentPath).String() + // first extract json from response body + if content == "" { + log.Infof("response body does not contain the content") + c.rejectStruct = RejectStruct{CONTENT_IS_EMPTY_CODE, "response body does not contain the content"} + return "", errors.New(c.rejectStruct.RejectMsg) + } + jsonStr, err := c.ExtractJson(content) + + if err != nil { + log.Infof("response body does not contain the valid json: %v", err.Error()) + c.rejectStruct = RejectStruct{CANNOT_FIND_JSON_IN_RESPONSE_CODE, "response body does not contain the valid json: " + err.Error()} + return "", errors.New(c.rejectStruct.RejectMsg) + } + + if c.jsonSchema != nil && c.enableJsonSchemaValidation { + compile, err := c.compiler.Compile(DEFAULT_SCHEMA) + if err != nil { + log.Infof("Json Schema compile failed: %v", err) + c.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()} + c.compile = nil + // return here: calling Validate on a nil schema would panic + return "", errors.New(c.rejectStruct.RejectMsg) + } + c.compile = compile + + // validate the json + err = c.compile.Validate(strings.NewReader(jsonStr)) + if err != nil { + log.Infof("response body does not match the Json Schema: %v", err) + c.rejectStruct = RejectStruct{JSON_MISMATCH_SCHEMA_CODE, "response body does not match the Json Schema: " + err.Error()} + return "", errors.New(c.rejectStruct.RejectMsg) + } + } + c.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""} + return jsonStr, nil +} + +func (c *PluginConfig) ExtractJson(bodyStr string) (string, error) { + // simply extract json from response body string: first "{" through last "}" + startIndex := strings.Index(bodyStr, "{") + endIndex := strings.LastIndex(bodyStr, "}") + + // if either brace is missing, or the last "}" precedes the first "{" + if startIndex == -1 || endIndex == -1 || startIndex > endIndex { + return "", errors.New("cannot find json in the response body") + } + + jsonStr := bodyStr[startIndex : endIndex+1] + + // attempt to parse the JSON + var result map[string]interface{} + err := json.Unmarshal([]byte(jsonStr), &result) + if err != nil { + 
return "", err + } + return jsonStr, nil +} + +func sendResponse(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, body []byte) { + log.Infof("Final send: Code %d, Message %s, Body: %s", config.rejectStruct.RejectCode, config.rejectStruct.RejectMsg, string(body)) + header := [][2]string{ + {"Content-Type", "application/json"}, + } + if body != nil && config.enableContentDisposition { + header = append(header, [2]string{"Content-Disposition", "attachment; filename=\"response.json\""}) + } + if config.rejectStruct.RejectCode != HTTP_STATUS_OK { + proxywasm.SendHttpResponseWithDetail(HTTP_STATUS_INTERNAL_SERVER_ERROR, config.rejectStruct.GetShortMsg(), nil, config.rejectStruct.GetBytes(), -1) + } else { + proxywasm.SendHttpResponse(HTTP_STATUS_OK, header, body, -1) + } +} + +func recursiveRefineJson(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, retryCount int, requestContext *RequestContext) { + // if retry count exceeds max retry count, return the response + if retryCount >= config.maxRetry { + log.Debugf("retry count exceeds max retry count") + // report more useful error by appending the last of previous error message + config.rejectStruct = RejectStruct{REACH_MAX_RETRY_COUNT_CODE, "retry count exceeds max retry count: " + config.rejectStruct.RejectMsg} + sendResponse(ctx, config, log, nil) + return + } + + // recursively refine json + config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.assembleReqBody(config), + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + err := config.ValidateBody(responseBody) + if err != nil { + sendResponse(ctx, config, log, nil) + return + } + retryCount++ + requestContext.SaveBodyToHistMsg(log, requestContext.assembleReqBody(config), responseBody) + log.Debugf("[retry request %d/%d] resp code: %d", retryCount, config.maxRetry, statusCode) + validateJson, err := config.ValidateJson(responseBody, log) + if err == nil { + sendResponse(ctx, config, 
log, []byte(validateJson)) + } else { + requestContext.SaveStrToHistMsg(log, err.Error()) + recursiveRefineJson(ctx, config, log, retryCount, requestContext) + } + }, uint32(config.serviceTimeout)) +} + +func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { + if config.rejectStruct.RejectCode != HTTP_STATUS_OK { + sendResponse(ctx, config, log, nil) + return types.ActionPause + } + + // verify if the request is from this plugin + extendHeaderValue, err := proxywasm.GetHttpRequestHeader(EXTEND_HEADER_KEY) + if err == nil { + fromThisPlugin, convErr := strconv.ParseBool(extendHeaderValue) + if convErr != nil { + log.Debugf("failed to parse header value as bool: %v", convErr) + ctx.SetContext(FROM_THIS_PLUGIN_KEY, false) + } + if fromThisPlugin { + ctx.SetContext(FROM_THIS_PLUGIN_KEY, true) + return types.ActionContinue + } + } else { + ctx.SetContext(FROM_THIS_PLUGIN_KEY, false) + } + + path, err := proxywasm.GetHttpRequestHeader(":path") + if err != nil { + log.Infof("get request path failed: %v", err) + path = "" + } else { + ctx.SetContext("path", path) + } + + headers, err := proxywasm.GetHttpRequestHeaders() + if err != nil { + log.Infof("get request header failed: %v", err) + } + + apiKey, err := proxywasm.GetHttpRequestHeader("Authorization") + if err != nil { + log.Infof("get request header failed: %v", err) + apiKey = "" + } + if apiKey != "" { + // remove the Authorization header + proxywasm.RemoveHttpRequestHeader("Authorization") + // remove the Authorization header from the headers + for i, header := range headers { + if header[0] == "Authorization" { + headers = append(headers[:i], headers[i+1:]...) 
+ break + } + } + } + if config.apiKey != "" { + // do not write the API key itself to the log + log.Debugf("add Authorization header from the configured apiKey") + headers = append(headers, [2]string{"Authorization", "Bearer " + config.apiKey}) + } + ctx.SetContext("headers", headers) + + return types.ActionContinue +} + +func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { + // if the request is from this plugin, continue the request + fromThisPlugin, ok := ctx.GetContext(FROM_THIS_PLUGIN_KEY).(bool) + if ok && fromThisPlugin { + log.Debugf("detected buffer_request, sending request to AI service") + return types.ActionContinue + } + + var headers [][2]string + if h, ok := ctx.GetContext("headers").([][2]string); ok { + headers = append(h, [2]string{EXTEND_HEADER_KEY, "true"}) + } else { + log.Debugf("cannot get headers from context, use default headers") + headers = [][2]string{ + {"Content-Type", "application/json"}, + {EXTEND_HEADER_KEY, "true"}, + } + } + + // if there is any error in the config, return the response directly + if config.rejectStruct.RejectCode != HTTP_STATUS_OK { + sendResponse(ctx, config, log, nil) + return types.ActionContinue + } + + // assign in the function scope; declaring path inside the if statement would shadow it and discard the value + path, ok := ctx.GetContext("path").(string) + if ok { + log.Debugf("use path: %s", path) + } else { + log.Debugf("cannot get path from context, use default path") + path = "/v1/chat/completions" + } + + if config.servicePath != "" { + log.Debugf("use base path: %s", config.servicePath) + path = config.servicePath + } + + requestContext := &RequestContext{ + Path: path, + ReqHeaders: headers, + ReqBody: body, + } + + config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.ReqBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + err := config.ValidateBody(responseBody) + if err != nil { + sendResponse(ctx, config, log, nil) + return + } + requestContext.SaveBodyToHistMsg(log, body, responseBody) + log.Debugf("[first 
request] resp code: %d", statusCode) + validateJson, err := config.ValidateJson(responseBody, log) + if err == nil { + sendResponse(ctx, config, log, []byte(validateJson)) + return + } else { + retryCount := 0 + requestContext.SaveStrToHistMsg(log, err.Error()) + recursiveRefineJson(ctx, config, log, retryCount, requestContext) + } + }, uint32(config.serviceTimeout)) + + return types.ActionPause +} diff --git a/plugins/wasm-go/extensions/ai-json-resp/model.go b/plugins/wasm-go/extensions/ai-json-resp/model.go new file mode 100644 index 0000000000..f04232caef --- /dev/null +++ b/plugins/wasm-go/extensions/ai-json-resp/model.go @@ -0,0 +1,180 @@ +// adopt from https://github.com/alibaba/higress/blob/main/plugins/wasm-go/extensions/ai-proxy/provider/model.go +package main + +import "strings" + +const ( + streamEventIdItemKey = "id:" + streamEventNameItemKey = "event:" + streamBuiltInItemKey = ":" + streamHttpStatusValuePrefix = "HTTP_STATUS/" + streamDataItemKey = "data:" + streamEndDataValue = "[DONE]" +) + +type chatCompletionRequest struct { + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Seed int `json:"seed,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *streamOptions `json:"stream_options,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Tools []tool `json:"tools,omitempty"` + ToolChoice *toolChoice `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` + Stop []string `json:"stop,omitempty"` + ResponseFormat map[string]interface{} `json:"response_format,omitempty"` +} + +type streamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` +} + +type tool struct { + Type string `json:"type"` + Function function `json:"function"` +} + 
+type function struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Parameters map[string]interface{} `json:"parameters,omitempty"` +} + +type toolChoice struct { + Type string `json:"type"` + Function function `json:"function"` +} + +type chatCompletionResponse struct { + Id string `json:"id,omitempty"` + Choices []chatCompletionChoice `json:"choices"` + Created int64 `json:"created,omitempty"` + Model string `json:"model,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + Object string `json:"object,omitempty"` + Usage usage `json:"usage,omitempty"` +} + +type chatCompletionChoice struct { + Index int `json:"index"` + Message *chatMessage `json:"message,omitempty"` + Delta *chatMessage `json:"delta,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` +} + +type usage struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +type chatMessage struct { + Name string `json:"name,omitempty"` + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + ToolCalls []toolCall `json:"tool_calls,omitempty"` +} + +func (m *chatMessage) IsEmpty() bool { + if m.Content != "" { + return false + } + if len(m.ToolCalls) != 0 { + nonEmpty := false + for _, toolCall := range m.ToolCalls { + if !toolCall.Function.IsEmpty() { + nonEmpty = true + break + } + } + if nonEmpty { + return false + } + } + return true +} + +type toolCall struct { + Index int `json:"index"` + Id string `json:"id"` + Type string `json:"type"` + Function functionCall `json:"function"` +} + +type functionCall struct { + Id string `json:"id"` + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +func (m *functionCall) IsEmpty() bool { + return m.Name == "" && m.Arguments == "" +} + +type streamEvent struct { + Id string `json:"id"` + Event string `json:"event"` + Data string 
`json:"data"` + HttpStatus string `json:"http_status"` +} + +func (e *streamEvent) setValue(key, value string) { + switch key { + case streamEventIdItemKey: + e.Id = value + case streamEventNameItemKey: + e.Event = value + case streamDataItemKey: + e.Data = value + case streamBuiltInItemKey: + if strings.HasPrefix(value, streamHttpStatusValuePrefix) { + e.HttpStatus = value[len(streamHttpStatusValuePrefix):] + } + } +} + +type embeddingsRequest struct { + Input interface{} `json:"input"` + Model string `json:"model"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + User string `json:"user,omitempty"` +} + +type embeddingsResponse struct { + Object string `json:"object"` + Data []embedding `json:"data"` + Model string `json:"model"` + Usage usage `json:"usage"` +} + +type embedding struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +func (r embeddingsRequest) ParseInput() []string { + if r.Input == nil { + return nil + } + var input []string + switch r.Input.(type) { + case string: + input = []string{r.Input.(string)} + case []any: + input = make([]string, 0, len(r.Input.([]any))) + for _, item := range r.Input.([]any) { + if str, ok := item.(string); ok { + input = append(input, str) + } + } + } + return input +} diff --git a/plugins/wasm-go/extensions/ai-json-resp/util.go b/plugins/wasm-go/extensions/ai-json-resp/util.go new file mode 100644 index 0000000000..3403c4cbfa --- /dev/null +++ b/plugins/wasm-go/extensions/ai-json-resp/util.go @@ -0,0 +1,33 @@ +package main + +func GetMaxDepth(data interface{}) int { + type item struct { + value interface{} + depth int + } + + maxDepth := 0 + stack := []item{{value: data, depth: 1}} + + for len(stack) > 0 { + currentItem := stack[len(stack)-1] + stack = stack[:len(stack)-1] + + if currentItem.depth > maxDepth { + maxDepth = currentItem.depth + } + + switch v := currentItem.value.(type) { + case 
map[string]interface{}: + for _, value := range v { + stack = append(stack, item{value: value, depth: currentItem.depth + 1}) + } + case []interface{}: + for _, value := range v { + stack = append(stack, item{value: value, depth: currentItem.depth + 1}) + } + } + } + + return maxDepth +}