Skip to content

Commit

Permalink
merge upstream
Browse files Browse the repository at this point in the history
Signed-off-by: wozulong <>
  • Loading branch information
wozulong committed Aug 6, 2024
2 parents 9186907 + 6787873 commit 1c37130
Show file tree
Hide file tree
Showing 24 changed files with 242 additions and 112 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
- `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置

## 部署
### 部署要求
Expand Down
1 change: 1 addition & 0 deletions common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ var defaultModelPrice = map[string]float64{
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
"mj_upload": 0.05,
}

var (
Expand Down
29 changes: 29 additions & 0 deletions constant/env.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package constant

import (
"fmt"
"one-api/common"
"os"
"strings"
)

var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
Expand All @@ -15,3 +18,29 @@ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)

var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)

var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
}

func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
}
2 changes: 2 additions & 0 deletions constant/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const (
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
)

var MidjourneyModel2Action = map[string]string{
Expand All @@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
}
2 changes: 1 addition & 1 deletion controller/channel-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func testAllChannels(notify bool) error {
}

// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
if !channel.GetAutoBan() {
ban = false
}

Expand Down
98 changes: 65 additions & 33 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package controller

import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
Expand Down Expand Up @@ -39,44 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode

func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey)
channelId := c.GetInt("channel_id")
channelType := c.GetInt("channel_type")
channelName := c.GetString("channel_name")
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
if openaiErr != nil {
go processChannelError(c, channelId, channelType, channelName, openaiErr)
} else {
retryTimes = 0
}
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
var openaiErr *dto.OpenAIErrorWithStatusCode

for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)

requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
openaiErr = relayHandler(c, relayMode)
if openaiErr != nil {
go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr)
openaiErr = relayRequest(c, relayMode, channel)

if openaiErr == nil {
return // 成功处理请求,直接返回
}

go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)

if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
common.LogInfo(c, retryLogStr)
}

if openaiErr != nil {
Expand All @@ -90,7 +82,42 @@ func Relay(c *gin.Context) {
}
}

func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}

func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
}

func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
autoBanInt = 0
}
return &model.Channel{
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}

func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil {
return false
}
Expand All @@ -114,6 +141,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
}
return false
}
if openaiErr.StatusCode == 408 {
Expand All @@ -129,9 +160,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}

func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan {
service.DisableChannel(channelId, channelName, err.Error.Message)
}
Expand Down Expand Up @@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) {
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)

requestBody, err := common.GetRequestBody(c)
Expand All @@ -225,7 +257,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
common.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {
Expand Down
6 changes: 3 additions & 3 deletions controller/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -806,11 +806,11 @@ type topUpRequest struct {
Key string `json:"key"`
}

var lock = sync.Mutex{}
var topUpLock = sync.Mutex{}

func TopUp(c *gin.Context) {
lock.Lock()
defer lock.Unlock()
topUpLock.Lock()
defer topUpLock.Unlock()
req := topUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {
Expand Down
6 changes: 6 additions & 0 deletions dto/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ type MidjourneyResponse struct {
Result string `json:"result"`
}

type MidjourneyUploadResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result []string `json:"result"`
}

type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse
Expand Down
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error())
}

// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
if common.RedisEnabled {
Expand Down
6 changes: 6 additions & 0 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0]
}
token, err := model.ValidateUserToken(key)
if token != nil {
id := c.GetInt("id")
if id == 0 {
c.Set("id", token.Id)
}
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
Expand Down
8 changes: 1 addition & 7 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel == nil {
return
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", ban)
c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
Expand Down
7 changes: 7 additions & 0 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
channel.OtherInfo = string(otherInfoBytes)
}

func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
}
return *channel.AutoBan == 1
}

func (channel *Channel) Save() error {
return DB.Save(channel).Error
}
Expand Down
Loading

0 comments on commit 1c37130

Please sign in to comment.