Skip to content

Commit

Permalink
merge upstream
Browse files Browse the repository at this point in the history
Signed-off-by: wozulong <>
  • Loading branch information
wozulong committed Jul 1, 2024
2 parents 61006be + 584eefe commit 895ee09
Show file tree
Hide file tree
Showing 41 changed files with 2,131 additions and 1,662 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。

> [!WARNING]
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
Expand Down Expand Up @@ -85,8 +83,13 @@
```
可以实现400错误转为500错误,从而重试

## 比原版One API多出的配置
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒

## 部署
### 部署要求
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
- 远程数据库:MySQL 版本 >= 5.7.8,PgSQL 版本 >= 9.6
### 基于 Docker 进行部署
```shell
# 使用 SQLite 的部署命令:
Expand Down
12 changes: 6 additions & 6 deletions common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second

var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second

var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)

var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second

var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")

const (
RequestIdKey = "X-Oneapi-Request-Id"
Expand All @@ -150,10 +150,10 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60

GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60

UploadRateLimitNum = 10
Expand Down
26 changes: 26 additions & 0 deletions common/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package common

import (
"fmt"
"os"
"strconv"
)

func GetEnvOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}

func GetEnvOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
19 changes: 19 additions & 0 deletions common/go-channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package common
import (
"fmt"
"runtime/debug"
"time"
)

func SafeGoroutine(f func()) {
Expand Down Expand Up @@ -45,3 +46,21 @@ func SafeSendString(ch chan string, value string) (closed bool) {
// If the code reaches here, then the channel was not closed.
return false
}

// SafeSendStringTimeout send, return true, else return false
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = false
}
}()

// This will panic if the channel is closed.
select {
case ch <- value:
return true
case <-time.After(time.Duration(timeout) * time.Second):
return false
}
}
4 changes: 3 additions & 1 deletion common/group-ratio.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package common

import "encoding/json"
import (
"encoding/json"
)

var GroupRatio = map[string]float64{
"default": 1,
Expand Down
1 change: 1 addition & 0 deletions common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ var defaultModelRatio = map[string]float64{
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
Expand Down
4 changes: 3 additions & 1 deletion common/topup-ratio.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package common

import "encoding/json"
import (
"encoding/json"
)

var TopupGroupRatio = map[string]float64{
"default": 1,
Expand Down
20 changes: 0 additions & 20 deletions common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"net"
"net/http"
"net/url"
"os"
"os/exec"
"runtime"
"strconv"
Expand Down Expand Up @@ -196,25 +195,6 @@ func Max(a int, b int) int {
}
}

func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}

func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}

func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
Expand Down
7 changes: 7 additions & 0 deletions constant/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package constant

import (
"one-api/common"
)

var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
22 changes: 12 additions & 10 deletions controller/channel-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,18 @@ func testAllChannels(notify bool) error {
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
StatusCode: -1,
Error: *openaiErr,
LocalError: false,
}
if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
if openaiErr != nil {
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
StatusCode: -1,
Error: *openaiErr,
LocalError: false,
}
if isChannelEnabled && service.ShouldDisableChannel(&openAiErrWithStatus) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
Expand Down
3 changes: 0 additions & 3 deletions controller/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ import (
)

func UpdateMidjourneyTaskBulk() {
if !common.IsMasterNode {
return
}
//imageModel := "midjourney"
ctx := context.TODO()
for {
Expand Down
11 changes: 0 additions & 11 deletions dto/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,3 @@ type OpenAIModels struct {
Root string `json:"root"`
Parent *string `json:"parent"`
}

type ModelPricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
}
14 changes: 8 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
common.SafeGoroutine(func() {
controller.UpdateMidjourneyTaskBulk()
})
common.SafeGoroutine(func() {
controller.UpdateTaskBulk()
})
if common.IsMasterNode {
common.SafeGoroutine(func() {
controller.UpdateMidjourneyTaskBulk()
})
common.SafeGoroutine(func() {
controller.UpdateTaskBulk()
})
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
Expand Down
7 changes: 6 additions & 1 deletion model/ability.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ func getPriority(group string, model string, retry int) (int, error) {
return 0, err
}

if len(priorities) == 0 {
// 如果没有查询到优先级,则返回错误
return 0, errors.New("数据库一致性被破坏")
}

// 确定要使用的优先级
var priorityToUse int
if retry >= len(priorities) {
Expand Down Expand Up @@ -199,7 +204,7 @@ func FixAbility() (int, error) {

// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err
Expand Down
6 changes: 3 additions & 3 deletions model/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ func InitDB() (err error) {
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))

if !common.IsMasterNode {
return nil
Expand Down
22 changes: 16 additions & 6 deletions model/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,36 @@ package model

import (
"one-api/common"
"one-api/dto"
"sync"
"time"
)

type Pricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
}

var (
pricingMap []dto.ModelPricing
pricingMap []Pricing
lastGetPricingTime time.Time
updatePricingLock sync.Mutex
)

func GetPricing(group string) []dto.ModelPricing {
func GetPricing(group string) []Pricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()

if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing()
}
if group != "" {
userPricingMap := make([]dto.ModelPricing, 0)
userPricingMap := make([]Pricing, 0)
models := GetGroupModels(group)
for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) {
Expand All @@ -42,9 +52,9 @@ func updatePricing() {
allModels[model] = i
}

pricingMap = make([]dto.ModelPricing, 0)
pricingMap = make([]Pricing, 0)
for model, _ := range allModels {
pricing := dto.ModelPricing{
pricing := Pricing{
Available: true,
ModelName: model,
}
Expand Down
6 changes: 6 additions & 0 deletions relay/channel/aws/relay-aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
Expand Down Expand Up @@ -156,6 +157,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
var usage relaymodel.Usage
var id string
var model string
isFirst := true
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
Expand All @@ -166,6 +168,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i

switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion relay/channel/claude/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
}
Expand Down
1 change: 1 addition & 0 deletions relay/channel/claude/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ var ModelList = []string{
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
}

var ChannelName = "claude"
Loading

0 comments on commit 895ee09

Please sign in to comment.