refactor: use adaptor to do relay & test

This commit is contained in:
JustSong
2024-02-18 00:15:31 +08:00
parent d548a01c59
commit 1aa374ccfb
63 changed files with 1452 additions and 1332 deletions

View File

@@ -1,22 +1,55 @@
package aiproxy
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
return aiProxyLibraryRequest, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "aiproxy"
}

View File

@@ -0,0 +1,9 @@
package aiproxy
import "github.com/songquanpeng/one-api/relay/channel/openai"
var ModelList = []string{""}
func init() {
ModelList = openai.ModelList
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
@@ -18,7 +19,7 @@ import (
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest {
func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].StringContent()
@@ -45,7 +46,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Message: model.Message{
Role: "assistant",
Content: content,
},
@@ -85,8 +86,8 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
}
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var usage openai.Usage
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -157,7 +158,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var AIProxyLibraryResponse LibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -172,8 +173,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,

View File

@@ -1,22 +1,76 @@
package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
}
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "ali"
}

View File

@@ -0,0 +1,6 @@
package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
@@ -17,7 +18,7 @@ import (
const EnableSearchModelSuffix = "-internet"
func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -44,7 +45,7 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
}
}
func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
@@ -55,7 +56,7 @@ func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequ
}
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
@@ -68,8 +69,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSta
}
if aliResponse.Code != "" {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
@@ -95,7 +96,7 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: openai.Usage{TotalTokens: response.Usage.TotalTokens},
Usage: model.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
@@ -111,7 +112,7 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Message: model.Message{
Role: "assistant",
Content: response.Output.Text,
},
@@ -122,7 +123,7 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
Usage: openai.Usage{
Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
@@ -148,8 +149,8 @@ func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletions
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var usage openai.Usage
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -217,7 +218,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -232,8 +233,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,

View File

@@ -1,22 +1,61 @@
package anthropic
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-api-key", meta.APIKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "authropic"
}

View File

@@ -0,0 +1,5 @@
package anthropic
var ModelList = []string{
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
@@ -25,7 +26,7 @@ func stopReasonClaude2OpenAI(reason string) string {
}
}
func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request {
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeRequest := Request{
Model: textRequest.Model,
Prompt: "",
@@ -72,7 +73,7 @@ func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletio
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Message: model.Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
@@ -88,7 +89,7 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := helper.GetTimestamp()
@@ -153,7 +154,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, responseText
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -168,8 +169,8 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
@@ -179,9 +180,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = model
completionTokens := openai.CountTokenText(claudeResponse.Completion, model)
usage := openai.Usage{
fullTextResponse.Model = modelName
completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName)
usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,

View File

@@ -1,22 +1,89 @@
package baidu
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
var fullRequestURL string
switch meta.ActualModelName {
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-Bot-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
var accessToken string
var err error
if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
return "", err
}
fullRequestURL += "?access_token=" + accessToken
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "baidu"
}

View File

@@ -0,0 +1,10 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@@ -43,7 +44,7 @@ type Error struct {
var baiduTokenStore sync.Map
func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -71,7 +72,7 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Message: model.Message{
Role: "assistant",
Content: response.Result,
},
@@ -103,7 +104,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatC
return &response
}
func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Input: request.ParseInput(),
}
@@ -126,8 +127,8 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
return &openAIEmbeddingResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var usage openai.Usage
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -189,7 +190,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -204,8 +205,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -226,7 +227,7 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return nil, &fullTextResponse.Usage
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -241,8 +242,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSta
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",

View File

@@ -1,18 +1,18 @@
package baidu
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"time"
)
type ChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage openai.Usage `json:"usage"`
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage model.Usage `json:"usage"`
Error
}
@@ -37,7 +37,7 @@ type EmbeddingResponse struct {
Object string `json:"object"`
Created int64 `json:"created"`
Data []EmbeddingData `json:"data"`
Usage openai.Usage `json:"usage"`
Usage model.Usage `json:"usage"`
Error
}

51
relay/channel/common.go Normal file
View File

@@ -0,0 +1,51 @@
package channel
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if meta.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(meta)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, req, meta)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := DoRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := util.HTTPClient.Do(req)
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("resp is nil")
}
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil
}

View File

@@ -0,0 +1,62 @@
package gemini
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
version := helper.AssignOrDefault(meta.APIVersion, "v1")
action := "generateContent"
if meta.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channelhelper.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "google gemini"
}

View File

@@ -0,0 +1,6 @@
package gemini
var ModelList = []string{
"gemini-pro",
"gemini-pro-vision",
}

View File

@@ -1,4 +1,4 @@
package google
package gemini
import (
"bufio"
@@ -11,6 +11,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
@@ -21,14 +22,14 @@ import (
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
GeminiVisionMaxImageNum = 16
VisionMaxImageNum = 16
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
geminiRequest := ChatRequest{
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
SafetySettings: []ChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: config.GeminiSafetySetting,
@@ -46,14 +47,14 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
Threshold: config.GeminiSafetySetting,
},
},
GenerationConfig: GeminiChatGenerationConfig{
GenerationConfig: ChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
geminiRequest.Tools = []ChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
@@ -61,30 +62,30 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
content := GeminiChatContent{
content := ChatContent{
Role: message.Role,
Parts: []GeminiPart{
Parts: []Part{
{
Text: message.StringContent(),
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
var parts []Part
imageNum := 0
for _, part := range openaiContent {
if part.Type == openai.ContentTypeText {
parts = append(parts, GeminiPart{
if part.Type == model.ContentTypeText {
parts = append(parts, Part{
Text: part.Text,
})
} else if part.Type == openai.ContentTypeImageURL {
} else if part.Type == model.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
if imageNum > VisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
parts = append(parts, Part{
InlineData: &InlineData{
MimeType: mimeType,
Data: data,
},
@@ -106,9 +107,9 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
Role: "model",
Parts: []GeminiPart{
Parts: []Part{
{
Text: "Okay",
},
@@ -121,12 +122,12 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
return &geminiRequest
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
type ChatResponse struct {
Candidates []ChatCandidate `json:"candidates"`
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
}
func (g *GeminiChatResponse) GetResponseText() string {
func (g *ChatResponse) GetResponseText() string {
if g == nil {
return ""
}
@@ -136,23 +137,23 @@ func (g *GeminiChatResponse) GetResponseText() string {
return ""
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
type ChatCandidate struct {
Content ChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
type ChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
type ChatPromptFeedback struct {
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
@@ -162,7 +163,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextRespons
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
Message: openai.Message{
Message: model.Message{
Role: "assistant",
Content: "",
},
@@ -176,7 +177,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextRespons
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &constant.StopFinishReason
@@ -187,7 +188,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -257,7 +258,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, responseText
}
func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -266,14 +267,14 @@ func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
var geminiResponse ChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: "No candidates returned",
Type: "server_error",
Param: "",
@@ -283,9 +284,9 @@ func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Model = model
completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
usage := openai.Usage{
fullTextResponse.Model = modelName
completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,

View File

@@ -0,0 +1,41 @@
package gemini
type ChatRequest struct {
Contents []ChatContent `json:"contents"`
SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"`
Tools []ChatTools `json:"tools,omitempty"`
}
type InlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type Part struct {
Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"`
}
type ChatContent struct {
Role string `json:"role,omitempty"`
Parts []Part `json:"parts"`
}
type ChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type ChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type ChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}

View File

@@ -1,22 +0,0 @@
package google
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
}

View File

@@ -1,80 +0,0 @@
package google
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []openai.Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
}

View File

@@ -2,14 +2,18 @@ package channel
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor interface {
GetRequestURL() string
Auth(c *gin.Context) error
ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error)
DoRequest(request *openai.GeneralOpenAIRequest) error
DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error)
GetRequestURL(meta *util.RelayMeta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}

View File

@@ -1,21 +1,80 @@
package openai
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
if meta.ChannelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(meta.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := meta.ActualModelName
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
}
return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
if meta.ChannelType == common.ChannelTypeAzure {
req.Header.Set("api-key", meta.APIKey)
return nil
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.ChannelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
return nil
}
func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) {
return nil, nil
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) {
return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp, meta.Mode)
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "openai"
}

View File

@@ -1,6 +0,0 @@
package openai
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)

View File

@@ -0,0 +1,19 @@
package openai
var ModelList = []string{
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview",
"gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
"text-moderation-latest", "text-moderation-stable",
"text-davinci-edit-001",
"davinci-002", "babbage-002",
"dall-e-2", "dall-e-3",
"whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
}

View File

@@ -0,0 +1,11 @@
package openai
import "github.com/songquanpeng/one-api/relay/model"
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
usage := &model.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}

View File

@@ -8,12 +8,13 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) {
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -90,7 +91,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi
return nil, responseText
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) {
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
var textResponse SlimTextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -105,7 +106,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &ErrorWithStatusCode{
return &model.ErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
@@ -133,9 +134,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += CountTokenText(choice.Message.StringContent(), model)
completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
}
textResponse.Usage = Usage{
textResponse.Usage = model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,

View File

@@ -1,15 +1,6 @@
package openai
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
}
type ImageURL struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}
import "github.com/songquanpeng/one-api/relay/model"
type TextContent struct {
Type string `json:"type,omitempty"`
@@ -17,142 +8,21 @@ type TextContent struct {
}
type ImageContent struct {
Type string `json:"type,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
type OpenAIMessageContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) IsStringContent() bool {
_, ok := m.Content.(string)
return ok
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
func (m Message) ParseContent() []OpenAIMessageContent {
var contentList []OpenAIMessageContent
content, ok := m.Content.(string)
if ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: content,
})
return contentList
}
anyList, ok := m.Content.([]any)
if ok {
for _, contentItem := range anyList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeImageURL,
ImageURL: &ImageURL{
Url: subObj["url"].(string),
},
})
}
}
}
return contentList
}
return nil
}
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
Type string `json:"type,omitempty"`
ImageURL *model.ImageURL `json:"image_url,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Model string `json:"model"`
Messages []model.Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
Model string `json:"model"`
Messages []model.Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
@@ -201,48 +71,30 @@ type TextToSpeechRequest struct {
ResponseFormat string `json:"response_format"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type UsageOrResponseText struct {
*Usage
*model.Usage
ResponseText string
}
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type ErrorWithStatusCode struct {
Error
StatusCode int `json:"status_code"`
}
type SlimTextResponse struct {
Choices []TextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error Error `json:"error"`
Choices []TextResponseChoice `json:"choices"`
model.Usage `json:"usage"`
Error model.Error `json:"error"`
}
type TextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
Index int `json:"index"`
model.Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type TextResponse struct {
Id string `json:"id"`
Model string `json:"model,omitempty"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []TextResponseChoice `json:"choices"`
Usage `json:"usage"`
Id string `json:"id"`
Model string `json:"model,omitempty"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []TextResponseChoice `json:"choices"`
model.Usage `json:"usage"`
}
type EmbeddingResponseItem struct {
@@ -252,10 +104,10 @@ type EmbeddingResponseItem struct {
}
type EmbeddingResponse struct {
Object string `json:"object"`
Data []EmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
Object string `json:"object"`
Data []EmbeddingResponseItem `json:"data"`
Model string `json:"model"`
model.Usage `json:"usage"`
}
type ImageResponse struct {

View File

@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
"math"
"strings"
)
@@ -63,7 +64,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
func CountTokenMessages(messages []Message, model string) int {
func CountTokenMessages(messages []model.Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

View File

@@ -1,12 +1,14 @@
package openai
func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode {
Error := Error{
import "github.com/songquanpeng/one-api/relay/model"
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
Error := model.Error{
Message: err.Error(),
Type: "one_api_error",
Code: code,
}
return &ErrorWithStatusCode{
return &model.ErrorWithStatusCode{
Error: Error,
StatusCode: statusCode,
}

View File

@@ -0,0 +1,56 @@
package palm
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "google palm"
}

View File

@@ -0,0 +1,5 @@
package palm
var ModelList = []string{
"PaLM-2",
}

View File

@@ -0,0 +1,40 @@
package palm
import (
"github.com/songquanpeng/one-api/relay/model"
)
type ChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type Filter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type Prompt struct {
Messages []ChatMessage `json:"messages"`
}
type ChatRequest struct {
Prompt Prompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type ChatResponse struct {
Candidates []ChatMessage `json:"candidates"`
Messages []model.Message `json:"messages"`
Filters []Filter `json:"filters"`
Error Error `json:"error"`
}

View File

@@ -1,4 +1,4 @@
package google
package palm
import (
"encoding/json"
@@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
@@ -16,10 +17,10 @@ import (
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
palmRequest := ChatRequest{
Prompt: Prompt{
Messages: make([]ChatMessage, 0, len(textRequest.Messages)),
},
Temperature: textRequest.Temperature,
CandidateCount: textRequest.N,
@@ -27,7 +28,7 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques
TopK: textRequest.MaxTokens,
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
palmMessage := ChatMessage{
Content: message.StringContent(),
}
if message.Role == "user" {
@@ -40,14 +41,14 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques
return &palmRequest
}
func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
func responsePaLM2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
Message: openai.Message{
Message: model.Message{
Role: "assistant",
Content: candidate.Content,
},
@@ -58,7 +59,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
return &fullTextResponse
}
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
@@ -71,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompl
return &response
}
func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := helper.GetTimestamp()
@@ -90,7 +91,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
stopChan <- true
return
}
var palmResponse PaLMChatResponse
var palmResponse ChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
@@ -130,7 +131,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
return nil, responseText
}
func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -139,14 +140,14 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var palmResponse PaLMChatResponse
var palmResponse ChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
@@ -156,9 +157,9 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
fullTextResponse.Model = model
completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model)
usage := openai.Usage{
fullTextResponse.Model = modelName
completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, modelName)
usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,

View File

@@ -1,22 +1,69 @@
package tencent
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
)
type Adaptor struct {
Sign string
}
func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", a.Sign)
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := ParseConfig(apiKey)
if err != nil {
return nil, err
}
tencentRequest := ConvertRequest(*request)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
// we have to calculate the sign here
a.Sign = GetSign(*tencentRequest, secretKey)
return tencentRequest, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "tencent"
}

View File

@@ -0,0 +1,5 @@
package tencent
var ModelList = []string{
"hunyuan",
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"sort"
@@ -23,7 +24,7 @@ import (
// https://cloud.tencent.com/document/product/1729/97732
func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
@@ -67,7 +68,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
if len(response.Choices) > 0 {
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Message: model.Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
@@ -95,7 +96,7 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -159,7 +160,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, responseText
}
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var TencentResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -174,8 +175,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},

View File

@@ -1,7 +1,7 @@
package tencent
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
@@ -56,7 +56,7 @@ type ChatResponse struct {
Choices []ResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
Usage openai.Usage `json:"usage,omitempty"` // token 数量
Usage model.Usage `json:"usage,omitempty"` // token 数量
Error Error `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参

View File

@@ -1,22 +1,66 @@
package xunfei
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
)
type Adaptor struct {
request *model.GeneralOpenAIRequest
}
func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
// check DoResponse for auth part
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
a.request = request
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}
dummyResp.StatusCode = http.StatusOK
return dummyResp, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
splits := strings.Split(meta.APIKey, "|")
if len(splits) != 3 {
return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
if a.request == nil {
return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
}
if meta.IsStream {
err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2])
} else {
err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2])
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "xunfei"
}

View File

@@ -0,0 +1,5 @@
package xunfei
var ModelList = []string{
"SparkDesk",
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"net/url"
@@ -23,7 +24,7 @@ import (
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -62,7 +63,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
}
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Message: model.Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
@@ -125,14 +126,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
return callUrl
}
func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
common.SetEventStreamHeaders(c)
var usage openai.Usage
var usage model.Usage
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
@@ -155,13 +156,13 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appI
return nil, &usage
}
func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
var usage openai.Usage
var usage model.Usage
var content string
var xunfeiResponse ChatResponse
stop := false
@@ -197,7 +198,7 @@ func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId stri
return nil, &usage
}
func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}

View File

@@ -1,7 +1,7 @@
package xunfei
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
@@ -55,7 +55,7 @@ type ChatResponse struct {
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text openai.Usage `json:"text"`
Text model.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}

View File

@@ -1,22 +1,58 @@
package zhipu
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
method := "invoke"
if meta.IsStream {
method = "sse-invoke"
}
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
token := GetToken(meta.APIKey)
req.Header.Set("Authorization", token)
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "zhipu"
}

View File

@@ -0,0 +1,5 @@
package zhipu
var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
@@ -72,7 +73,7 @@ func GetToken(apikey string) string {
return tokenString
}
func ConvertRequest(request openai.GeneralOpenAIRequest) *Request {
func ConvertRequest(request model.GeneralOpenAIRequest) *Request {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
@@ -110,7 +111,7 @@ func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
for i, choice := range response.Data.Choices {
openaiChoice := openai.TextResponseChoice{
Index: i,
Message: openai.Message{
Message: model.Message{
Role: choice.Role,
Content: strings.Trim(choice.Content, "\""),
},
@@ -136,7 +137,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStr
return &response
}
func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) {
func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
choice.FinishReason = &constant.StopFinishReason
@@ -150,8 +151,8 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.
return &response, &zhipuResponse.Usage
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var usage *openai.Usage
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage *model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -228,7 +229,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
return nil, usage
}
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var zhipuResponse Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -243,8 +244,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if !zhipuResponse.Success {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",

View File

@@ -1,7 +1,7 @@
package zhipu
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"time"
)
@@ -19,11 +19,11 @@ type Request struct {
}
type ResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []Message `json:"choices"`
openai.Usage `json:"usage"`
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []Message `json:"choices"`
model.Usage `json:"usage"`
}
type Response struct {
@@ -34,10 +34,10 @@ type Response struct {
}
type StreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
openai.Usage `json:"usage"`
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
model.Usage `json:"usage"`
}
type tokenData struct {