diff --git a/middleware/auth.go b/middleware/auth.go index edd15de5..d2c9b3cf 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -143,6 +143,12 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] } token, err := model.ValidateUserToken(key) + if token != nil { + id := c.GetInt("id") + if id == 0 { + c.Set("id", token.Id) + } + } if err != nil { abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return diff --git a/model/token.go b/model/token.go index 27907afb..272c5734 100644 --- a/model/token.go +++ b/model/token.go @@ -51,12 +51,12 @@ func ValidateUserToken(key string) (token *Token, err error) { if token.Status == common.TokenStatusExhausted { keyPrefix := key[:3] keySuffix := key[len(key)-3:] - return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") + return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") } else if token.Status == common.TokenStatusExpired { - return nil, errors.New("该令牌已过期") + return token, errors.New("该令牌已过期") } if token.Status != common.TokenStatusEnabled { - return nil, errors.New("该令牌状态不可用") + return token, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if !common.RedisEnabled { @@ -66,7 +66,7 @@ func ValidateUserToken(key string) (token *Token, err error) { common.SysError("failed to update token status" + err.Error()) } } - return nil, errors.New("该令牌已过期") + return token, errors.New("该令牌已过期") } if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !common.RedisEnabled { @@ -79,7 +79,7 @@ func ValidateUserToken(key string) (token *Token, err error) { } keyPrefix := key[:3] keySuffix := key[len(key)-3:] - return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)) + return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)) } return token, nil }