Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement apiToken failover mechanism #1256

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions plugins/wasm-go/extensions/ai-proxy/config/config.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package config

import (
"github.com/tidwall/gjson"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/tidwall/gjson"
)

// @Name ai-proxy
Expand Down
32 changes: 29 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,17 @@ func main() {
}

// parseConfig populates pluginConfig from the raw JSON plugin configuration,
// then validates and completes it. Once both steps succeed, it calls
// SetApiTokensFailover on the provider config obtained from pluginConfig.
//
// It returns the first error reported by Validate or Complete, or nil.
func parseConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error {
	pluginConfig.FromJson(json)

	validateErr := pluginConfig.Validate()
	if validateErr != nil {
		return validateErr
	}

	completeErr := pluginConfig.Complete()
	if completeErr != nil {
		return completeErr
	}

	// The provider config is fetched only after Complete has succeeded.
	activeProviderConfig := pluginConfig.GetProviderConfig()
	activeProviderConfig.SetApiTokensFailover(log)
	return nil
}

Expand All @@ -72,8 +74,23 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
ctx.SetContext(ctxKeyApiName, apiName)

if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()

providerConfig := pluginConfig.GetProviderConfig()
apiTokenInUse := providerConfig.GetRandomToken()
if providerConfig.IsFailoverEnabled() {
// Use the health check token if it is a health check request.
if apiTokenHealthCheck, _ := proxywasm.GetHttpRequestHeader("ApiToken-Health-Check"); apiTokenHealthCheck != "" {
apiTokenInUse = apiTokenHealthCheck
} else {
// If apiToken failover is enabled, only use an available apiToken.
apiTokenInUse = providerConfig.GetGlobalRandomToken(log)
}
}
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiTokenInUse)
ctx.SetContext(provider.ApiTokenInUse, apiTokenInUse)

hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
Expand All @@ -85,6 +102,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
return action
}

_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
}
Expand Down Expand Up @@ -145,6 +163,14 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()

providerConfig := pluginConfig.GetProviderConfig()
// If apiToken failover is enabled and the request is not a health check request, handle unavailable apiToken.
if providerConfig.IsFailoverEnabled() && ctx.GetContext(provider.ApiTokenHealthCheck) == nil {
unavailableApiToken := ctx.GetContext(provider.ApiTokenInUse).(string)
providerConfig.HandleUnavailableApiToken(unavailableApiToken, log)
}

return types.ActionContinue
}

Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
_ = util.OverwriteRequestHost(ai360Domain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", ctx.GetContext(ApiTokenInUse).(string))
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
}
_ = util.OverwriteRequestPath(baichuanChatCompletionPath)
_ = util.OverwriteRequestHost(baichuanDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string))
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
}
Expand Down
8 changes: 4 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
return types.ActionContinue, errors.New("request model is empty")
}
// Rewrite the request path according to the model.
path := b.getRequestPath(request.Model)
path := b.getRequestPath(ctx, request.Model)
_ = util.OverwriteRequestPath(path)

if b.config.context == nil {
Expand Down Expand Up @@ -126,7 +126,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := b.getRequestPath(mappedModel)
path := b.getRequestPath(ctx, mappedModel)
_ = util.OverwriteRequestPath(path)

if b.config.context == nil {
Expand Down Expand Up @@ -226,13 +226,13 @@ type baiduTextGenRequest struct {
UserId string `json:"user_id,omitempty"`
}

func (b *baiduProvider) getRequestPath(baiduModel string) string {
func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
if !ok {
suffix = baiduModel
}
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken())
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, ctx.GetContext(ApiTokenInUse).(string))
}

func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa

_ = util.OverwriteRequestPath(claudeChatCompletionPath)
_ = util.OverwriteRequestHost(claudeDomain)
_ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetRandomToken())
_ = proxywasm.ReplaceHttpRequestHeader("x-api-key", ctx.GetContext(ApiTokenInUse).(string))

if c.config.claudeVersion == "" {
c.config.claudeVersion = defaultVersion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A
}
_ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
_ = util.OverwriteRequestHost(cloudflareDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + c.config.GetRandomToken())
_ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string))

_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/deepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(deeplChatCompletionPath)
_ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetRandomToken())
_ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + ctx.GetContext(ApiTokenInUse).(string))
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
return types.HeaderStopIteration, nil
Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
}
_ = util.OverwriteRequestPath(deepseekChatCompletionPath)
_ = util.OverwriteRequestHost(deepseekDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = util.OverwriteRequestAuthorization("Bearer " + ctx.GetContext(ApiTokenInUse).(string))
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
}
Expand Down
Loading
Loading