refactor: use adaptor to do relay & test

This commit is contained in:
JustSong
2024-02-18 00:15:31 +08:00
parent d548a01c59
commit 1aa374ccfb
63 changed files with 1452 additions and 1332 deletions

View File

@@ -8,7 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
@@ -17,7 +17,7 @@ import (
"github.com/gin-gonic/gin"
)
func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
if !config.AutomaticDisableChannelEnabled {
return false
}
@@ -33,7 +33,7 @@ func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
return false
}
func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
if !config.AutomaticEnableChannelEnabled {
return false
}
@@ -47,11 +47,11 @@ func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
}
type GeneralErrorResponse struct {
Error openai.Error `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Error relaymodel.Error `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
@@ -87,10 +87,10 @@ func (e GeneralErrorResponse) ToMessage() string {
return ""
}
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) {
ErrorWithStatusCode = &openai.ErrorWithStatusCode{
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: openai.Error{
Error: relaymodel.Error{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",

View File

@@ -8,35 +8,41 @@ import (
)
type RelayMeta struct {
Mode int
ChannelType int
ChannelId int
TokenId int
TokenName string
UserId int
Group string
ModelMapping map[string]string
BaseURL string
APIVersion string
APIKey string
APIType int
Config map[string]string
Mode int
ChannelType int
ChannelId int
TokenId int
TokenName string
UserId int
Group string
ModelMapping map[string]string
BaseURL string
APIVersion string
APIKey string
APIType int
Config map[string]string
IsStream bool
OriginModelName string
ActualModelName string
RequestURLPath string
PromptTokens int // only for DoResponse
}
func GetRelayMeta(c *gin.Context) *RelayMeta {
meta := RelayMeta{
Mode: constant.Path2RelayMode(c.Request.URL.Path),
ChannelType: c.GetInt("channel"),
ChannelId: c.GetInt("channel_id"),
TokenId: c.GetInt("token_id"),
TokenName: c.GetString("token_name"),
UserId: c.GetInt("id"),
Group: c.GetString("group"),
ModelMapping: c.GetStringMapString("model_mapping"),
BaseURL: c.GetString("base_url"),
APIVersion: c.GetString("api_version"),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Config: nil,
Mode: constant.Path2RelayMode(c.Request.URL.Path),
ChannelType: c.GetInt("channel"),
ChannelId: c.GetInt("channel_id"),
TokenId: c.GetInt("token_id"),
TokenName: c.GetString("token_name"),
UserId: c.GetInt("id"),
Group: c.GetString("group"),
ModelMapping: c.GetStringMapString("model_mapping"),
BaseURL: c.GetString("base_url"),
APIVersion: c.GetString("api_version"),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Config: nil,
RequestURLPath: c.Request.URL.String(),
}
if meta.ChannelType == common.ChannelTypeAzure {
meta.APIVersion = GetAzureAPIVersion(c)

View File

@@ -2,12 +2,12 @@ package util
import (
"errors"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"math"
)
func ValidateTextRequest(textRequest *openai.GeneralOpenAIRequest, relayMode int) error {
func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error {
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
return errors.New("max_tokens is invalid")
}