refactor: refactor relay part (#957)

* refactor: refactor relay part

* refactor: refactor config part
This commit is contained in:
JustSong
2024-01-21 23:21:42 +08:00
committed by GitHub
parent e2ed0399f0
commit 2d760d4a01
81 changed files with 1795 additions and 1459 deletions

19
relay/util/billing.go Normal file
View File

@@ -0,0 +1,19 @@
package util
import (
"context"
"one-api/common/logger"
"one-api/model"
)
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
}
}(ctx)
}
}

View File

@@ -7,6 +7,8 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/common/config"
"one-api/common/logger"
"one-api/model"
"one-api/relay/channel/openai"
"strconv"
@@ -16,7 +18,7 @@ import (
)
func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled {
if !config.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
@@ -32,7 +34,7 @@ func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
}
func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
if !common.AutomaticEnableChannelEnabled {
if !config.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
@@ -138,11 +140,11 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
@@ -152,11 +154,11 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
func GetAPIVersion(c *gin.Context) string {
func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {

View File

@@ -2,7 +2,7 @@ package util
import (
"net/http"
"one-api/common"
"one-api/common/config"
"time"
)
@@ -10,11 +10,11 @@ var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client
func init() {
if common.RelayTimeout == 0 {
if config.RelayTimeout == 0 {
HTTPClient = &http.Client{}
} else {
HTTPClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
Timeout: time.Duration(config.RelayTimeout) * time.Second,
}
}

View File

@@ -0,0 +1,12 @@
package util
func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
if mapping == nil {
return modelName, false
}
mappedModelName := mapping[modelName]
if mappedModelName != "" {
return mappedModelName, true
}
return modelName, false
}

44
relay/util/relay_meta.go Normal file
View File

@@ -0,0 +1,44 @@
package util
import (
"github.com/gin-gonic/gin"
"one-api/common"
"strings"
)
type RelayMeta struct {
ChannelType int
ChannelId int
TokenId int
TokenName string
UserId int
Group string
ModelMapping map[string]string
BaseURL string
APIVersion string
APIKey string
Config map[string]string
}
func GetRelayMeta(c *gin.Context) *RelayMeta {
meta := RelayMeta{
ChannelType: c.GetInt("channel"),
ChannelId: c.GetInt("channel_id"),
TokenId: c.GetInt("token_id"),
TokenName: c.GetString("token_name"),
UserId: c.GetInt("id"),
Group: c.GetString("group"),
ModelMapping: c.GetStringMapString("model_mapping"),
BaseURL: c.GetString("base_url"),
APIVersion: c.GetString("api_version"),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Config: nil,
}
if meta.ChannelType == common.ChannelTypeAzure {
meta.APIVersion = GetAzureAPIVersion(c)
}
if meta.BaseURL == "" {
meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
}
return &meta
}

37
relay/util/validation.go Normal file
View File

@@ -0,0 +1,37 @@
package util
import (
"errors"
"math"
"one-api/relay/channel/openai"
"one-api/relay/constant"
)
func ValidateTextRequest(textRequest *openai.GeneralOpenAIRequest, relayMode int) error {
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
return errors.New("max_tokens is invalid")
}
if textRequest.Model == "" {
return errors.New("model is required")
}
switch relayMode {
case constant.RelayModeCompletions:
if textRequest.Prompt == "" {
return errors.New("field prompt is required")
}
case constant.RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errors.New("field messages is required")
}
case constant.RelayModeEmbeddings:
case constant.RelayModeModerations:
if textRequest.Input == "" {
return errors.New("field input is required")
}
case constant.RelayModeEdits:
if textRequest.Instruction == "" {
return errors.New("field instruction is required")
}
}
return nil
}