mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-22 04:13:53 +00:00
refactor: use adaptor to do relay & test
This commit is contained in:
@@ -12,19 +12,21 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/channel/ali"
|
||||
"github.com/songquanpeng/one-api/relay/channel/anthropic"
|
||||
"github.com/songquanpeng/one-api/relay/channel/baidu"
|
||||
"github.com/songquanpeng/one-api/relay/channel/google"
|
||||
"github.com/songquanpeng/one-api/relay/channel/gemini"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/channel/palm"
|
||||
"github.com/songquanpeng/one-api/relay/channel/tencent"
|
||||
"github.com/songquanpeng/one-api/relay/channel/xunfei"
|
||||
"github.com/songquanpeng/one-api/relay/channel/zhipu"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
|
||||
func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *model.GeneralOpenAIRequest) (string, error) {
|
||||
fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
|
||||
switch meta.APIType {
|
||||
case constant.APITypeOpenAI:
|
||||
@@ -43,7 +45,7 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
|
||||
}
|
||||
case constant.APITypeClaude:
|
||||
case constant.APITypeAnthropic:
|
||||
fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL)
|
||||
case constant.APITypeBaidu:
|
||||
switch textRequest.Model {
|
||||
@@ -92,19 +94,10 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
|
||||
func GetRequestBody(c *gin.Context, textRequest model.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
switch apiType {
|
||||
case constant.APITypeClaude:
|
||||
case constant.APITypeAnthropic:
|
||||
claudeRequest := anthropic.ConvertRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(claudeRequest)
|
||||
if err != nil {
|
||||
@@ -127,14 +120,14 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
case constant.APITypePaLM:
|
||||
palmRequest := google.ConvertPaLMRequest(textRequest)
|
||||
palmRequest := palm.ConvertRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(palmRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case constant.APITypeGemini:
|
||||
geminiChatRequest := google.ConvertGeminiRequest(textRequest)
|
||||
geminiChatRequest := gemini.ConvertRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(geminiChatRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -187,19 +180,20 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
default:
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
}
|
||||
return requestBody, nil
|
||||
}
|
||||
|
||||
func SetupRequestHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
|
||||
SetupAuthHeaders(c, req, meta, isStream)
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
|
||||
apiKey := meta.APIKey
|
||||
switch meta.APIType {
|
||||
@@ -213,7 +207,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
|
||||
req.Header.Set("X-Title", "One API")
|
||||
}
|
||||
}
|
||||
case constant.APITypeClaude:
|
||||
case constant.APITypeAnthropic:
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
@@ -242,7 +236,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
|
||||
}
|
||||
}
|
||||
|
||||
func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) {
|
||||
func DoResponse(c *gin.Context, textRequest *model.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
var responseText string
|
||||
switch apiType {
|
||||
case constant.APITypeOpenAI:
|
||||
@@ -251,7 +245,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
|
||||
} else {
|
||||
err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model)
|
||||
}
|
||||
case constant.APITypeClaude:
|
||||
case constant.APITypeAnthropic:
|
||||
if isStream {
|
||||
err, responseText = anthropic.StreamHandler(c, resp)
|
||||
} else {
|
||||
@@ -270,15 +264,15 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
|
||||
}
|
||||
case constant.APITypePaLM:
|
||||
if isStream { // PaLM2 API does not support stream
|
||||
err, responseText = google.PaLMStreamHandler(c, resp)
|
||||
err, responseText = palm.StreamHandler(c, resp)
|
||||
} else {
|
||||
err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
|
||||
err, usage = palm.Handler(c, resp, promptTokens, textRequest.Model)
|
||||
}
|
||||
case constant.APITypeGemini:
|
||||
if isStream {
|
||||
err, responseText = google.StreamHandler(c, resp)
|
||||
err, responseText = gemini.StreamHandler(c, resp)
|
||||
} else {
|
||||
err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
|
||||
err, usage = gemini.Handler(c, resp, promptTokens, textRequest.Model)
|
||||
}
|
||||
case constant.APITypeZhipu:
|
||||
if isStream {
|
||||
@@ -328,7 +322,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
|
||||
return nil, err
|
||||
}
|
||||
if usage == nil && responseText != "" {
|
||||
usage = &openai.Usage{}
|
||||
usage = &model.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
Reference in New Issue
Block a user