refactor: use adaptor to do relay & test

This commit is contained in:
JustSong
2024-02-18 00:15:31 +08:00
parent d548a01c59
commit 1aa374ccfb
63 changed files with 1452 additions and 1332 deletions

View File

@@ -12,19 +12,21 @@ import (
"github.com/songquanpeng/one-api/relay/channel/ali"
"github.com/songquanpeng/one-api/relay/channel/anthropic"
"github.com/songquanpeng/one-api/relay/channel/baidu"
"github.com/songquanpeng/one-api/relay/channel/google"
"github.com/songquanpeng/one-api/relay/channel/gemini"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channel/palm"
"github.com/songquanpeng/one-api/relay/channel/tencent"
"github.com/songquanpeng/one-api/relay/channel/xunfei"
"github.com/songquanpeng/one-api/relay/channel/zhipu"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
)
func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *model.GeneralOpenAIRequest) (string, error) {
fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
switch meta.APIType {
case constant.APITypeOpenAI:
@@ -43,7 +45,7 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
}
case constant.APITypeClaude:
case constant.APITypeAnthropic:
fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL)
case constant.APITypeBaidu:
switch textRequest.Model {
@@ -92,19 +94,10 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.
return fullRequestURL, nil
}
func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
func GetRequestBody(c *gin.Context, textRequest model.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return nil, err
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
switch apiType {
case constant.APITypeClaude:
case constant.APITypeAnthropic:
claudeRequest := anthropic.ConvertRequest(textRequest)
jsonStr, err := json.Marshal(claudeRequest)
if err != nil {
@@ -127,14 +120,14 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
}
requestBody = bytes.NewBuffer(jsonData)
case constant.APITypePaLM:
palmRequest := google.ConvertPaLMRequest(textRequest)
palmRequest := palm.ConvertRequest(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return nil, err
}
requestBody = bytes.NewBuffer(jsonStr)
case constant.APITypeGemini:
geminiChatRequest := google.ConvertGeminiRequest(textRequest)
geminiChatRequest := gemini.ConvertRequest(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return nil, err
@@ -187,19 +180,20 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
return nil, err
}
requestBody = bytes.NewBuffer(jsonStr)
default:
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return nil, err
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
}
return requestBody, nil
}
func SetupRequestHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
SetupAuthHeaders(c, req, meta, isStream)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
apiKey := meta.APIKey
switch meta.APIType {
@@ -213,7 +207,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
req.Header.Set("X-Title", "One API")
}
}
case constant.APITypeClaude:
case constant.APITypeAnthropic:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
@@ -242,7 +236,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
}
}
func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) {
func DoResponse(c *gin.Context, textRequest *model.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *model.Usage, err *model.ErrorWithStatusCode) {
var responseText string
switch apiType {
case constant.APITypeOpenAI:
@@ -251,7 +245,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
} else {
err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model)
}
case constant.APITypeClaude:
case constant.APITypeAnthropic:
if isStream {
err, responseText = anthropic.StreamHandler(c, resp)
} else {
@@ -270,15 +264,15 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
}
case constant.APITypePaLM:
if isStream { // PaLM2 API does not support stream
err, responseText = google.PaLMStreamHandler(c, resp)
err, responseText = palm.StreamHandler(c, resp)
} else {
err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
err, usage = palm.Handler(c, resp, promptTokens, textRequest.Model)
}
case constant.APITypeGemini:
if isStream {
err, responseText = google.StreamHandler(c, resp)
err, responseText = gemini.StreamHandler(c, resp)
} else {
err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
err, usage = gemini.Handler(c, resp, promptTokens, textRequest.Model)
}
case constant.APITypeZhipu:
if isStream {
@@ -328,7 +322,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
return nil, err
}
if usage == nil && responseText != "" {
usage = &openai.Usage{}
usage = &model.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens