Mirror of https://github.com/songquanpeng/one-api.git (synced 2025-10-18 17:51:28 +00:00)

Commit: refactor: use adaptor to do relay & test
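Summary: the relay layer is reworked around a common channel.Adaptor interface. Each upstream channel (aiproxy, ali, anthropic, baidu, gemini, openai) now implements GetRequestURL, SetupRequestHeader, ConvertRequest, DoRequest, DoResponse, GetModelList, and GetChannelName; the shared HTTP plumbing lands in relay/channel/common.go, per-channel model lists land in constants.go files, and the Gemini code moves out of the google package into its own gemini package.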
@@ -1,22 +1,55 @@
 package aiproxy

 import (
+	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/channel"
+	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/util"
+	"io"
 	"net/http"
 )

 type Adaptor struct {
 }

-func (a *Adaptor) Auth(c *gin.Context) error {
-	return nil
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+	return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+	channel.SetupCommonRequestHeader(c, req, meta)
+	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+	return nil
 }

-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
-	return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	aiProxyLibraryRequest := ConvertRequest(*request)
+	aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
+	return aiProxyLibraryRequest, nil
 }

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
-	return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+	if meta.IsStream {
+		err, usage = StreamHandler(c, resp)
+	} else {
+		err, usage = Handler(c, resp)
+	}
+	return
 }
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return "aiproxy"
+}
relay/channel/aiproxy/constants.go (new file)
@@ -0,0 +1,9 @@
package aiproxy

import "github.com/songquanpeng/one-api/relay/channel/openai"

var ModelList = []string{""}

func init() {
	ModelList = openai.ModelList
}
@@ -10,6 +10,7 @@ import (
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/model"
 	"io"
 	"net/http"
 	"strconv"

@@ -18,7 +19,7 @@ import (

 // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答

-func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
 	query := ""
 	if len(request.Messages) != 0 {
 		query = request.Messages[len(request.Messages)-1].StringContent()

@@ -45,7 +46,7 @@ func responseAIProxyLibrary2OpenAI(...)
 	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
 	choice := openai.TextResponseChoice{
 		Index: 0,
-		Message: openai.Message{
+		Message: model.Message{
 			Role:    "assistant",
 			Content: content,
 		},

@@ -85,8 +86,8 @@ func streamResponseAIProxyLibrary2OpenAI(...)
 	}
 }

-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
-	var usage openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var usage model.Usage
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {

@@ -157,7 +158,7 @@ func StreamHandler(...)
 	return nil, &usage
 }

-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	var AIProxyLibraryResponse LibraryResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

@@ -172,8 +173,8 @@ func Handler(...)
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if AIProxyLibraryResponse.ErrCode != 0 {
-		return &openai.ErrorWithStatusCode{
-			Error: openai.Error{
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
 				Message: AIProxyLibraryResponse.Message,
 				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
 				Code:    AIProxyLibraryResponse.ErrCode,
@@ -1,22 +1,76 @@
 package ali

 import (
+	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/channel"
+	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/util"
+	"io"
 	"net/http"
 )

 type Adaptor struct {
 }

-func (a *Adaptor) Auth(c *gin.Context) error {
-	return nil
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+	fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
+	if meta.Mode == constant.RelayModeEmbeddings {
+		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
+	}
+	return fullRequestURL, nil
 }

+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+	channel.SetupCommonRequestHeader(c, req, meta)
+	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+	if meta.IsStream {
+		req.Header.Set("X-DashScope-SSE", "enable")
+	}
+	if c.GetString("plugin") != "" {
+		req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
+	}
+	return nil
+}
+
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
-	return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch relayMode {
+	case constant.RelayModeEmbeddings:
+		baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
+		return baiduEmbeddingRequest, nil
+	default:
+		baiduRequest := ConvertRequest(*request)
+		return baiduRequest, nil
+	}
 }

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
-	return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoRequestHelper(a, c, meta, requestBody)
 }
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+	if meta.IsStream {
+		err, usage = StreamHandler(c, resp)
+	} else {
+		switch meta.Mode {
+		case constant.RelayModeEmbeddings:
+			err, usage = EmbeddingHandler(c, resp)
+		default:
+			err, usage = Handler(c, resp)
+		}
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return "ali"
+}
relay/channel/ali/constants.go (new file)
@@ -0,0 +1,6 @@
package ali

var ModelList = []string{
	"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
	"text-embedding-v1",
}
@@ -8,6 +8,7 @@ import (
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/model"
 	"io"
 	"net/http"
 	"strings"

@@ -17,7 +18,7 @@ import (

 const EnableSearchModelSuffix = "-internet"

-func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 	messages := make([]Message, 0, len(request.Messages))
 	for i := 0; i < len(request.Messages); i++ {
 		message := request.Messages[i]

@@ -44,7 +45,7 @@ func ConvertRequest(...)
 	}
 }

-func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
 	return &EmbeddingRequest{
 		Model: "text-embedding-v1",
 		Input: struct {

@@ -55,7 +56,7 @@ func ConvertEmbeddingRequest(...)
 	}
 }

-func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	var aliResponse EmbeddingResponse
 	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
 	if err != nil {

@@ -68,8 +69,8 @@ func EmbeddingHandler(...)
 	}

 	if aliResponse.Code != "" {
-		return &openai.ErrorWithStatusCode{
-			Error: openai.Error{
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
 				Message: aliResponse.Message,
 				Type:    aliResponse.Code,
 				Param:   aliResponse.RequestId,

@@ -95,7 +96,7 @@ func embeddingResponseAli2OpenAI(...)
 		Object: "list",
 		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
 		Model:  "text-embedding-v1",
-		Usage:  openai.Usage{TotalTokens: response.Usage.TotalTokens},
+		Usage:  model.Usage{TotalTokens: response.Usage.TotalTokens},
 	}

 	for _, item := range response.Output.Embeddings {

@@ -111,7 +112,7 @@ func embeddingResponseAli2OpenAI(...)
 func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
 	choice := openai.TextResponseChoice{
 		Index: 0,
-		Message: openai.Message{
+		Message: model.Message{
 			Role:    "assistant",
 			Content: response.Output.Text,
 		},

@@ -122,7 +123,7 @@ func responseAli2OpenAI(...)
 		Object:  "chat.completion",
 		Created: helper.GetTimestamp(),
 		Choices: []openai.TextResponseChoice{choice},
-		Usage: openai.Usage{
+		Usage: model.Usage{
 			PromptTokens:     response.Usage.InputTokens,
 			CompletionTokens: response.Usage.OutputTokens,
 			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,

@@ -148,8 +149,8 @@ func streamResponseAli2OpenAI(...)
 	return &response
 }

-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
-	var usage openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var usage model.Usage
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {

@@ -217,7 +218,7 @@ func StreamHandler(...)
 	return nil, &usage
 }

-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	var aliResponse ChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

@@ -232,8 +233,8 @@ func Handler(...)
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if aliResponse.Code != "" {
-		return &openai.ErrorWithStatusCode{
-			Error: openai.Error{
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
 				Message: aliResponse.Message,
 				Type:    aliResponse.Code,
 				Param:   aliResponse.RequestId,
@@ -1,22 +1,61 @@
 package anthropic

 import (
+	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/relay/channel"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/util"
+	"io"
 	"net/http"
 )

 type Adaptor struct {
 }

-func (a *Adaptor) Auth(c *gin.Context) error {
-	return nil
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+	return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil
 }

+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+	channel.SetupCommonRequestHeader(c, req, meta)
+	req.Header.Set("x-api-key", meta.APIKey)
+	anthropicVersion := c.Request.Header.Get("anthropic-version")
+	if anthropicVersion == "" {
+		anthropicVersion = "2023-06-01"
+	}
+	req.Header.Set("anthropic-version", anthropicVersion)
+	return nil
+}
+
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
-	return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return ConvertRequest(*request), nil
 }

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
-	return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoRequestHelper(a, c, meta, requestBody)
 }
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+	if meta.IsStream {
+		var responseText string
+		err, responseText = StreamHandler(c, resp)
+		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+	} else {
+		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return "anthropic"
+}
relay/channel/anthropic/constants.go (new file)
@@ -0,0 +1,5 @@
package anthropic

var ModelList = []string{
	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
}
@@ -9,6 +9,7 @@ import (
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/model"
 	"io"
 	"net/http"
 	"strings"

@@ -25,7 +26,7 @@ func stopReasonClaude2OpenAI(reason string) string {
 	}
 }

-func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request {
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 	claudeRequest := Request{
 		Model:  textRequest.Model,
 		Prompt: "",

@@ -72,7 +73,7 @@ func streamResponseClaude2OpenAI(...)
 func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 	choice := openai.TextResponseChoice{
 		Index: 0,
-		Message: openai.Message{
+		Message: model.Message{
 			Role:    "assistant",
 			Content: strings.TrimPrefix(claudeResponse.Completion, " "),
 			Name:    nil,

@@ -88,7 +89,7 @@ func responseClaude2OpenAI(...)
 	return &fullTextResponse
 }

-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
 	responseText := ""
 	responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
 	createdTime := helper.GetTimestamp()

@@ -153,7 +154,7 @@ func StreamHandler(...)
 	return nil, responseText
 }

-func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

@@ -168,8 +169,8 @@ func Handler(...)
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if claudeResponse.Error.Type != "" {
-		return &openai.ErrorWithStatusCode{
-			Error: openai.Error{
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
 				Message: claudeResponse.Error.Message,
 				Type:    claudeResponse.Error.Type,
 				Param:   "",

@@ -179,9 +180,9 @@ func Handler(...)
 		}, nil
 	}
 	fullTextResponse := responseClaude2OpenAI(&claudeResponse)
-	fullTextResponse.Model = model
-	completionTokens := openai.CountTokenText(claudeResponse.Completion, model)
-	usage := openai.Usage{
+	fullTextResponse.Model = modelName
+	completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName)
+	usage := model.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      promptTokens + completionTokens,
@@ -1,22 +1,89 @@
 package baidu

 import (
+	"errors"
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/channel"
+	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/util"
+	"io"
 	"net/http"
 )

 type Adaptor struct {
 }

-func (a *Adaptor) Auth(c *gin.Context) error {
-	return nil
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+	var fullRequestURL string
+	switch meta.ActualModelName {
+	case "ERNIE-Bot-4":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
+	case "ERNIE-Bot-8K":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
+	case "ERNIE-Bot":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
+	case "ERNIE-Speed":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
+	case "ERNIE-Bot-turbo":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
+	case "BLOOMZ-7B":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+	case "Embedding-V1":
+		fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
+	}
+	var accessToken string
+	var err error
+	if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
+		return "", err
+	}
+	fullRequestURL += "?access_token=" + accessToken
+	return fullRequestURL, nil
 }

+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+	channel.SetupCommonRequestHeader(c, req, meta)
+	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+	return nil
+}
+
-func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
-	return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch relayMode {
+	case constant.RelayModeEmbeddings:
+		baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
+		return baiduEmbeddingRequest, nil
+	default:
+		baiduRequest := ConvertRequest(*request)
+		return baiduRequest, nil
+	}
 }

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
-	return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoRequestHelper(a, c, meta, requestBody)
 }
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+	if meta.IsStream {
+		err, usage = StreamHandler(c, resp)
+	} else {
+		switch meta.Mode {
+		case constant.RelayModeEmbeddings:
+			err, usage = EmbeddingHandler(c, resp)
+		default:
+			err, usage = Handler(c, resp)
+		}
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return "baidu"
+}
relay/channel/baidu/constants.go (new file)
@@ -0,0 +1,10 @@
package baidu

var ModelList = []string{
	"ERNIE-Bot-4",
	"ERNIE-Bot-8K",
	"ERNIE-Bot",
	"ERNIE-Speed",
	"ERNIE-Bot-turbo",
	"Embedding-V1",
}
@@ -10,6 +10,7 @@ import (
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
 	"net/http"

@@ -43,7 +44,7 @@ type Error struct {

 var baiduTokenStore sync.Map

-func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 	messages := make([]Message, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {

@@ -71,7 +72,7 @@ func ConvertRequest(...)
 func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
 	choice := openai.TextResponseChoice{
 		Index: 0,
-		Message: openai.Message{
+		Message: model.Message{
 			Role:    "assistant",
 			Content: response.Result,
 		},

@@ -103,7 +104,7 @@ func streamResponseBaidu2OpenAI(...)
 	return &response
 }

-func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
 	return &EmbeddingRequest{
 		Input: request.ParseInput(),
 	}

@@ -126,8 +127,8 @@ func embeddingResponseBaidu2OpenAI(...)
 	return &openAIEmbeddingResponse
 }

-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
-	var usage openai.Usage
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var usage model.Usage
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {

@@ -189,7 +190,7 @@ func StreamHandler(...)
 	return nil, &usage
 }

-func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	var baiduResponse ChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

@@ -204,8 +205,8 @@ func Handler(...)
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if baiduResponse.ErrorMsg != "" {
-		return &openai.ErrorWithStatusCode{
-			Error: openai.Error{
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
 				Message: baiduResponse.ErrorMsg,
 				Type:    "baidu_error",
 				Param:   "",

@@ -226,7 +227,7 @@ func Handler(...)
 	return nil, &fullTextResponse.Usage
 }

-func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	var baiduResponse EmbeddingResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

@@ -241,8 +242,8 @@ func EmbeddingHandler(...)
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if baiduResponse.ErrorMsg != "" {
-		return &openai.ErrorWithStatusCode{
-			Error: openai.Error{
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
 				Message: baiduResponse.ErrorMsg,
 				Type:    "baidu_error",
 				Param:   "",
@@ -1,18 +1,18 @@
 package baidu

 import (
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/model"
 	"time"
 )

 type ChatResponse struct {
-	Id               string       `json:"id"`
-	Object           string       `json:"object"`
-	Created          int64        `json:"created"`
-	Result           string       `json:"result"`
-	IsTruncated      bool         `json:"is_truncated"`
-	NeedClearHistory bool         `json:"need_clear_history"`
-	Usage            openai.Usage `json:"usage"`
+	Id               string      `json:"id"`
+	Object           string      `json:"object"`
+	Created          int64       `json:"created"`
+	Result           string      `json:"result"`
+	IsTruncated      bool        `json:"is_truncated"`
+	NeedClearHistory bool        `json:"need_clear_history"`
+	Usage            model.Usage `json:"usage"`
 	Error
 }

@@ -37,7 +37,7 @@ type EmbeddingResponse struct {
 	Object  string          `json:"object"`
 	Created int64           `json:"created"`
 	Data    []EmbeddingData `json:"data"`
-	Usage   openai.Usage    `json:"usage"`
+	Usage   model.Usage     `json:"usage"`
 	Error
 }
relay/channel/common.go (new file)
@@ -0,0 +1,51 @@
package channel

import (
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/util"
	"io"
	"net/http"
)

func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
	if meta.IsStream && c.Request.Header.Get("Accept") == "" {
		req.Header.Set("Accept", "text/event-stream")
	}
}

func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
	fullRequestURL, err := a.GetRequestURL(meta)
	if err != nil {
		return nil, fmt.Errorf("get request url failed: %w", err)
	}
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
	if err != nil {
		return nil, fmt.Errorf("new request failed: %w", err)
	}
	err = a.SetupRequestHeader(c, req, meta)
	if err != nil {
		return nil, fmt.Errorf("setup request header failed: %w", err)
	}
	resp, err := DoRequest(c, req)
	if err != nil {
		return nil, fmt.Errorf("do request failed: %w", err)
	}
	return resp, nil
}

func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
	resp, err := util.HTTPClient.Do(req)
	if err != nil {
		return nil, err
	}
	if resp == nil {
		return nil, errors.New("resp is nil")
	}
	_ = req.Body.Close()
	_ = c.Request.Body.Close()
	return resp, nil
}
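DoRequestHelper is the shared transport step: the adaptor contributes the URL and headers, the helper performs the upstream HTTP call and closes the request bodies. A rough sketch of how a relay step could compose these pieces end to end follows — the controller function, its name, and its error handling are assumptions for illustration, not code from this commit:

package controller

import (
	"bytes"
	"encoding/json"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/channel"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"
)

// relayWithAdaptor (hypothetical) drives one request through an Adaptor:
// convert, marshal, send via DoRequest (which delegates to DoRequestHelper),
// then let the adaptor decode the response and count usage.
func relayWithAdaptor(c *gin.Context, a channel.Adaptor, meta *util.RelayMeta,
	request *model.GeneralOpenAIRequest) (*model.Usage, *model.ErrorWithStatusCode) {
	// Translate the OpenAI-style request into the channel's wire format.
	convertedRequest, err := a.ConvertRequest(c, meta.Mode, request)
	if err != nil {
		return nil, openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
	}
	// Serialize the converted request as the upstream body.
	jsonData, err := json.Marshal(convertedRequest)
	if err != nil {
		return nil, openai.ErrorWrapper(err, "marshal_request_failed", http.StatusInternalServerError)
	}
	// GetRequestURL + SetupRequestHeader + HTTP call happen inside DoRequestHelper.
	resp, err := a.DoRequest(c, meta, bytes.NewReader(jsonData))
	if err != nil {
		return nil, openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}
	// The adaptor translates the upstream response back and reports usage.
	return a.DoResponse(c, resp, meta)
}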
relay/channel/gemini/adaptor.go (new file)
@@ -0,0 +1,62 @@
package gemini

import (
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/helper"
	channelhelper "github.com/songquanpeng/one-api/relay/channel"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"
	"io"
	"net/http"
)

type Adaptor struct {
}

func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
	version := helper.AssignOrDefault(meta.APIVersion, "v1")
	action := "generateContent"
	if meta.IsStream {
		action = "streamGenerateContent"
	}
	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
	channelhelper.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("x-goog-api-key", meta.APIKey)
	return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return ConvertRequest(*request), nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
	return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if meta.IsStream {
		var responseText string
		err, responseText = StreamHandler(c, resp)
		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
	} else {
		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
	}
	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "google gemini"
}
relay/channel/gemini/constants.go (new file)
@@ -0,0 +1,6 @@
package gemini

var ModelList = []string{
	"gemini-pro",
	"gemini-pro-vision",
}
@@ -1,4 +1,4 @@
-package google
+package gemini

 import (
 	"bufio"

@@ -11,6 +11,7 @@ import (
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/model"
 	"io"
 	"net/http"
 	"strings"

@@ -21,14 +22,14 @@ import (
 // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn

 const (
-	GeminiVisionMaxImageNum = 16
+	VisionMaxImageNum = 16
 )

 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
-	geminiRequest := GeminiChatRequest{
-		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
-		SafetySettings: []GeminiChatSafetySettings{
+func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
+	geminiRequest := ChatRequest{
+		Contents: make([]ChatContent, 0, len(textRequest.Messages)),
+		SafetySettings: []ChatSafetySettings{
 			{
 				Category:  "HARM_CATEGORY_HARASSMENT",
 				Threshold: config.GeminiSafetySetting,

@@ -46,14 +47,14 @@ func ConvertRequest(...)
 				Threshold: config.GeminiSafetySetting,
 			},
 		},
-		GenerationConfig: GeminiChatGenerationConfig{
+		GenerationConfig: ChatGenerationConfig{
 			Temperature:     textRequest.Temperature,
 			TopP:            textRequest.TopP,
 			MaxOutputTokens: textRequest.MaxTokens,
 		},
 	}
 	if textRequest.Functions != nil {
-		geminiRequest.Tools = []GeminiChatTools{
+		geminiRequest.Tools = []ChatTools{
 			{
 				FunctionDeclarations: textRequest.Functions,
 			},

@@ -61,30 +62,30 @@ func ConvertRequest(...)
 	}
 	shouldAddDummyModelMessage := false
 	for _, message := range textRequest.Messages {
-		content := GeminiChatContent{
+		content := ChatContent{
 			Role: message.Role,
-			Parts: []GeminiPart{
+			Parts: []Part{
 				{
 					Text: message.StringContent(),
 				},
 			},
 		}
 		openaiContent := message.ParseContent()
-		var parts []GeminiPart
+		var parts []Part
 		imageNum := 0
 		for _, part := range openaiContent {
-			if part.Type == openai.ContentTypeText {
-				parts = append(parts, GeminiPart{
+			if part.Type == model.ContentTypeText {
+				parts = append(parts, Part{
 					Text: part.Text,
 				})
-			} else if part.Type == openai.ContentTypeImageURL {
+			} else if part.Type == model.ContentTypeImageURL {
 				imageNum += 1
-				if imageNum > GeminiVisionMaxImageNum {
+				if imageNum > VisionMaxImageNum {
 					continue
 				}
 				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
-				parts = append(parts, GeminiPart{
-					InlineData: &GeminiInlineData{
+				parts = append(parts, Part{
+					InlineData: &InlineData{
 						MimeType: mimeType,
 						Data:     data,
 					},

@@ -106,9 +107,9 @@ func ConvertRequest(...)

 	// If a system message is the last message, we need to add a dummy model message to make gemini happy
 	if shouldAddDummyModelMessage {
-		geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
+		geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
 			Role: "model",
-			Parts: []GeminiPart{
+			Parts: []Part{
 				{
 					Text: "Okay",
 				},

@@ -121,12 +122,12 @@ func ConvertRequest(...)
 	return &geminiRequest
 }

-type GeminiChatResponse struct {
-	Candidates     []GeminiChatCandidate    `json:"candidates"`
-	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+type ChatResponse struct {
+	Candidates     []ChatCandidate    `json:"candidates"`
+	PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
 }

-func (g *GeminiChatResponse) GetResponseText() string {
+func (g *ChatResponse) GetResponseText() string {
 	if g == nil {
 		return ""
 	}

@@ -136,23 +137,23 @@ func (g *GeminiChatResponse) GetResponseText() string {
 	return ""
 }

-type GeminiChatCandidate struct {
-	Content       GeminiChatContent        `json:"content"`
-	FinishReason  string                   `json:"finishReason"`
-	Index         int64                    `json:"index"`
-	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+type ChatCandidate struct {
+	Content       ChatContent        `json:"content"`
+	FinishReason  string             `json:"finishReason"`
+	Index         int64              `json:"index"`
+	SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
 }

-type GeminiChatSafetyRating struct {
+type ChatSafetyRating struct {
 	Category    string `json:"category"`
 	Probability string `json:"probability"`
 }

-type GeminiChatPromptFeedback struct {
-	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+type ChatPromptFeedback struct {
+	SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
 }

-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
+func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 	fullTextResponse := openai.TextResponse{
 		Id:     fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
 		Object: "chat.completion",

@@ -162,7 +163,7 @@ func responseGeminiChat2OpenAI(...)
 	for i, candidate := range response.Candidates {
 		choice := openai.TextResponseChoice{
 			Index: i,
-			Message: openai.Message{
+			Message: model.Message{
 				Role:    "assistant",
 				Content: "",
 			},

@@ -176,7 +177,7 @@ func responseGeminiChat2OpenAI(...)
 	return &fullTextResponse
 }

-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
+func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
 	var choice openai.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = geminiResponse.GetResponseText()
 	choice.FinishReason = &constant.StopFinishReason

@@ -187,7 +188,7 @@ func streamResponseGeminiChat2OpenAI(...)
 	return &response
 }

-func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
 	responseText := ""
 	dataChan := make(chan string)
 	stopChan := make(chan bool)

@@ -257,7 +258,7 @@ func StreamHandler(...)
 	return nil, responseText
 }

-func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

@@ -266,14 +267,14 @@ func Handler(...)
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	var geminiResponse GeminiChatResponse
+	var geminiResponse ChatResponse
 	err = json.Unmarshal(responseBody, &geminiResponse)
 	if err != nil {
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if len(geminiResponse.Candidates) == 0 {
-		return &openai.ErrorWithStatusCode{
-			Error: openai.Error{
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
 				Message: "No candidates returned",
 				Type:    "server_error",
 				Param:   "",

@@ -283,9 +284,9 @@ func Handler(...)
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	fullTextResponse.Model = model
-	completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
-	usage := openai.Usage{
+	fullTextResponse.Model = modelName
+	completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
+	usage := model.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      promptTokens + completionTokens,
relay/channel/gemini/model.go (new file)
@@ -0,0 +1,41 @@
package gemini

type ChatRequest struct {
	Contents         []ChatContent        `json:"contents"`
	SafetySettings   []ChatSafetySettings `json:"safety_settings,omitempty"`
	GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"`
	Tools            []ChatTools          `json:"tools,omitempty"`
}

type InlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

type Part struct {
	Text       string      `json:"text,omitempty"`
	InlineData *InlineData `json:"inlineData,omitempty"`
}

type ChatContent struct {
	Role  string `json:"role,omitempty"`
	Parts []Part `json:"parts"`
}

type ChatSafetySettings struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}

type ChatTools struct {
	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}

type ChatGenerationConfig struct {
	Temperature     float64  `json:"temperature,omitempty"`
	TopP            float64  `json:"topP,omitempty"`
	TopK            float64  `json:"topK,omitempty"`
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
	CandidateCount  int      `json:"candidateCount,omitempty"`
	StopSequences   []string `json:"stopSequences,omitempty"`
}
@@ -1,22 +0,0 @@ (deleted file)
package google

import (
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"net/http"
)

type Adaptor struct {
}

func (a *Adaptor) Auth(c *gin.Context) error {
	return nil
}

func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
	return nil, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
	return nil, nil, nil
}
@@ -1,80 +0,0 @@ (deleted file)
package google

import (
	"github.com/songquanpeng/one-api/relay/channel/openai"
)

type GeminiChatRequest struct {
	Contents         []GeminiChatContent        `json:"contents"`
	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
	Tools            []GeminiChatTools          `json:"tools,omitempty"`
}

type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

type GeminiPart struct {
	Text       string            `json:"text,omitempty"`
	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}

type GeminiChatContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

type GeminiChatSafetySettings struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}

type GeminiChatTools struct {
	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}

type GeminiChatGenerationConfig struct {
	Temperature     float64  `json:"temperature,omitempty"`
	TopP            float64  `json:"topP,omitempty"`
	TopK            float64  `json:"topK,omitempty"`
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
	CandidateCount  int      `json:"candidateCount,omitempty"`
	StopSequences   []string `json:"stopSequences,omitempty"`
}

type PaLMChatMessage struct {
	Author  string `json:"author"`
	Content string `json:"content"`
}

type PaLMFilter struct {
	Reason  string `json:"reason"`
	Message string `json:"message"`
}

type PaLMPrompt struct {
	Messages []PaLMChatMessage `json:"messages"`
}

type PaLMChatRequest struct {
	Prompt         PaLMPrompt `json:"prompt"`
	Temperature    float64    `json:"temperature,omitempty"`
	CandidateCount int        `json:"candidateCount,omitempty"`
	TopP           float64    `json:"topP,omitempty"`
	TopK           int        `json:"topK,omitempty"`
}

type PaLMError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Status  string `json:"status"`
}

type PaLMChatResponse struct {
	Candidates []PaLMChatMessage `json:"candidates"`
	Messages   []openai.Message  `json:"messages"`
	Filters    []PaLMFilter      `json:"filters"`
	Error      PaLMError         `json:"error"`
}
@@ -2,14 +2,18 @@ package channel

 import (
 	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/util"
+	"io"
 	"net/http"
 )

 type Adaptor interface {
-	GetRequestURL() string
-	Auth(c *gin.Context) error
-	ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error)
-	DoRequest(request *openai.GeneralOpenAIRequest) error
-	DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error)
+	GetRequestURL(meta *util.RelayMeta) (string, error)
+	SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
+	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
+	DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
+	DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
 	GetModelList() []string
 	GetChannelName() string
 }
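With every channel satisfying the same interface, the relay controller can pick an implementation by API type instead of branching per channel. A minimal sketch of such a factory, assuming a helper package and constant.APIType* identifiers that this excerpt does not show (the names below are illustrative, not confirmed by the diff):

package helper

import (
	"github.com/songquanpeng/one-api/relay/channel"
	"github.com/songquanpeng/one-api/relay/channel/aiproxy"
	"github.com/songquanpeng/one-api/relay/channel/ali"
	"github.com/songquanpeng/one-api/relay/channel/anthropic"
	"github.com/songquanpeng/one-api/relay/channel/baidu"
	"github.com/songquanpeng/one-api/relay/channel/gemini"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/constant"
)

// GetAdaptor maps an API type to its channel implementation; the
// constant.APIType* names are assumed here for illustration.
func GetAdaptor(apiType int) channel.Adaptor {
	switch apiType {
	case constant.APITypeAIProxyLibrary:
		return &aiproxy.Adaptor{}
	case constant.APITypeAli:
		return &ali.Adaptor{}
	case constant.APITypeAnthropic:
		return &anthropic.Adaptor{}
	case constant.APITypeBaidu:
		return &baidu.Adaptor{}
	case constant.APITypeGemini:
		return &gemini.Adaptor{}
	case constant.APITypeOpenAI:
		return &openai.Adaptor{}
	}
	return nil
}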
@@ -1,21 +1,80 @@
 package openai

 import (
+	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/relay/channel"
+	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/util"
+	"io"
 	"net/http"
+	"strings"
 )

 type Adaptor struct {
 }

-func (a *Adaptor) Auth(c *gin.Context) error {
-	return nil
+func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+	if meta.ChannelType == common.ChannelTypeAzure {
+		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
+		requestURL := strings.Split(meta.RequestURLPath, "?")[0]
+		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
+		task := strings.TrimPrefix(requestURL, "/v1/")
+		model_ := meta.ActualModelName
+		model_ = strings.Replace(model_, ".", "", -1)
+		// https://github.com/songquanpeng/one-api/issues/67
+		model_ = strings.TrimSuffix(model_, "-0301")
+		model_ = strings.TrimSuffix(model_, "-0314")
+		model_ = strings.TrimSuffix(model_, "-0613")
+
+		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+		return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
+	}
+	return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
 }

+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+	channel.SetupCommonRequestHeader(c, req, meta)
+	if meta.ChannelType == common.ChannelTypeAzure {
+		req.Header.Set("api-key", meta.APIKey)
+		return nil
+	}
+	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+	if meta.ChannelType == common.ChannelTypeOpenRouter {
+		req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+		req.Header.Set("X-Title", "One API")
+	}
+	return nil
+}
+
-func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) {
-	return nil, nil
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
 }

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) {
-	return nil, nil, nil
+func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoRequestHelper(a, c, meta, requestBody)
 }
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+	if meta.IsStream {
+		var responseText string
+		err, responseText = StreamHandler(c, resp, meta.Mode)
+		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
+	} else {
+		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return "openai"
+}
@@ -1,6 +0,0 @@ (deleted file)
package openai

const (
	ContentTypeText     = "text"
	ContentTypeImageURL = "image_url"
)
relay/channel/openai/constants.go (new file)
@@ -0,0 +1,19 @@
package openai

var ModelList = []string{
	"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
	"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
	"gpt-3.5-turbo-instruct",
	"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
	"gpt-4-turbo-preview",
	"gpt-4-vision-preview",
	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
	"text-moderation-latest", "text-moderation-stable",
	"text-davinci-edit-001",
	"davinci-002", "babbage-002",
	"dall-e-2", "dall-e-3",
	"whisper-1",
	"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
}
relay/channel/openai/helper.go (new file)
@@ -0,0 +1,11 @@
package openai

import "github.com/songquanpeng/one-api/relay/model"

func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
	usage := &model.Usage{}
	usage.PromptTokens = promptTokens
	usage.CompletionTokens = CountTokenText(responseText, modelName)
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return usage
}
@@ -8,12 +8,13 @@ import (
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/model"
 	"io"
 	"net/http"
 	"strings"
 )

-func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) {
 	responseText := ""
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {

@@ -90,7 +91,7 @@ func StreamHandler(...)
 	return nil, responseText
 }

-func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) {
+func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 	var textResponse SlimTextResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

@@ -105,7 +106,7 @@ func Handler(...)
 		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if textResponse.Error.Type != "" {
-		return &ErrorWithStatusCode{
+		return &model.ErrorWithStatusCode{
 			Error:      textResponse.Error,
 			StatusCode: resp.StatusCode,
 		}, nil

@@ -133,9 +134,9 @@ func Handler(...)
 	if textResponse.Usage.TotalTokens == 0 {
 		completionTokens := 0
 		for _, choice := range textResponse.Choices {
-			completionTokens += CountTokenText(choice.Message.StringContent(), model)
+			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
 		}
-		textResponse.Usage = Usage{
+		textResponse.Usage = model.Usage{
 			PromptTokens:     promptTokens,
 			CompletionTokens: completionTokens,
 			TotalTokens:      promptTokens + completionTokens,
@@ -1,15 +1,6 @@
package openai

type Message struct {
    Role    string  `json:"role"`
    Content any     `json:"content"`
    Name    *string `json:"name,omitempty"`
}

type ImageURL struct {
    Url    string `json:"url,omitempty"`
    Detail string `json:"detail,omitempty"`
}
import "github.com/songquanpeng/one-api/relay/model"

type TextContent struct {
    Type string `json:"type,omitempty"`
@@ -17,142 +8,21 @@ type TextContent struct {
}

type ImageContent struct {
    Type     string    `json:"type,omitempty"`
    ImageURL *ImageURL `json:"image_url,omitempty"`
}

type OpenAIMessageContent struct {
    Type     string    `json:"type,omitempty"`
    Text     string    `json:"text"`
    ImageURL *ImageURL `json:"image_url,omitempty"`
}

func (m Message) IsStringContent() bool {
    _, ok := m.Content.(string)
    return ok
}

func (m Message) StringContent() string {
    content, ok := m.Content.(string)
    if ok {
        return content
    }
    contentList, ok := m.Content.([]any)
    if ok {
        var contentStr string
        for _, contentItem := range contentList {
            contentMap, ok := contentItem.(map[string]any)
            if !ok {
                continue
            }
            if contentMap["type"] == ContentTypeText {
                if subStr, ok := contentMap["text"].(string); ok {
                    contentStr += subStr
                }
            }
        }
        return contentStr
    }
    return ""
}

func (m Message) ParseContent() []OpenAIMessageContent {
    var contentList []OpenAIMessageContent
    content, ok := m.Content.(string)
    if ok {
        contentList = append(contentList, OpenAIMessageContent{
            Type: ContentTypeText,
            Text: content,
        })
        return contentList
    }
    anyList, ok := m.Content.([]any)
    if ok {
        for _, contentItem := range anyList {
            contentMap, ok := contentItem.(map[string]any)
            if !ok {
                continue
            }
            switch contentMap["type"] {
            case ContentTypeText:
                if subStr, ok := contentMap["text"].(string); ok {
                    contentList = append(contentList, OpenAIMessageContent{
                        Type: ContentTypeText,
                        Text: subStr,
                    })
                }
            case ContentTypeImageURL:
                if subObj, ok := contentMap["image_url"].(map[string]any); ok {
                    contentList = append(contentList, OpenAIMessageContent{
                        Type: ContentTypeImageURL,
                        ImageURL: &ImageURL{
                            Url: subObj["url"].(string),
                        },
                    })
                }
            }
        }
        return contentList
    }
    return nil
}

type ResponseFormat struct {
    Type string `json:"type,omitempty"`
}

type GeneralOpenAIRequest struct {
    Model            string          `json:"model,omitempty"`
    Messages         []Message       `json:"messages,omitempty"`
    Prompt           any             `json:"prompt,omitempty"`
    Stream           bool            `json:"stream,omitempty"`
    MaxTokens        int             `json:"max_tokens,omitempty"`
    Temperature      float64         `json:"temperature,omitempty"`
    TopP             float64         `json:"top_p,omitempty"`
    N                int             `json:"n,omitempty"`
    Input            any             `json:"input,omitempty"`
    Instruction      string          `json:"instruction,omitempty"`
    Size             string          `json:"size,omitempty"`
    Functions        any             `json:"functions,omitempty"`
    FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
    PresencePenalty  float64         `json:"presence_penalty,omitempty"`
    ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
    Seed             float64         `json:"seed,omitempty"`
    Tools            any             `json:"tools,omitempty"`
    ToolChoice       any             `json:"tool_choice,omitempty"`
    User             string          `json:"user,omitempty"`
}

func (r GeneralOpenAIRequest) ParseInput() []string {
    if r.Input == nil {
        return nil
    }
    var input []string
    switch r.Input.(type) {
    case string:
        input = []string{r.Input.(string)}
    case []any:
        input = make([]string, 0, len(r.Input.([]any)))
        for _, item := range r.Input.([]any) {
            if str, ok := item.(string); ok {
                input = append(input, str)
            }
        }
    }
    return input
    Type     string          `json:"type,omitempty"`
    ImageURL *model.ImageURL `json:"image_url,omitempty"`
}

type ChatRequest struct {
    Model     string    `json:"model"`
    Messages  []Message `json:"messages"`
    MaxTokens int       `json:"max_tokens"`
    Model     string          `json:"model"`
    Messages  []model.Message `json:"messages"`
    MaxTokens int             `json:"max_tokens"`
}

type TextRequest struct {
    Model     string    `json:"model"`
    Messages  []Message `json:"messages"`
    Prompt    string    `json:"prompt"`
    MaxTokens int       `json:"max_tokens"`
    Model     string          `json:"model"`
    Messages  []model.Message `json:"messages"`
    Prompt    string          `json:"prompt"`
    MaxTokens int             `json:"max_tokens"`
    //Stream bool `json:"stream"`
}

@@ -201,48 +71,30 @@ type TextToSpeechRequest struct {
    ResponseFormat string `json:"response_format"`
}

type Usage struct {
    PromptTokens     int `json:"prompt_tokens"`
    CompletionTokens int `json:"completion_tokens"`
    TotalTokens      int `json:"total_tokens"`
}

type UsageOrResponseText struct {
    *Usage
    *model.Usage
    ResponseText string
}

type Error struct {
    Message string `json:"message"`
    Type    string `json:"type"`
    Param   string `json:"param"`
    Code    any    `json:"code"`
}

type ErrorWithStatusCode struct {
    Error
    StatusCode int `json:"status_code"`
}

type SlimTextResponse struct {
    Choices []TextResponseChoice `json:"choices"`
    Usage   `json:"usage"`
    Error   Error `json:"error"`
    Choices     []TextResponseChoice `json:"choices"`
    model.Usage `json:"usage"`
    Error       model.Error `json:"error"`
}

type TextResponseChoice struct {
    Index        int `json:"index"`
    Message      `json:"message"`
    FinishReason string `json:"finish_reason"`
    Index         int `json:"index"`
    model.Message `json:"message"`
    FinishReason  string `json:"finish_reason"`
}

type TextResponse struct {
    Id      string               `json:"id"`
    Model   string               `json:"model,omitempty"`
    Object  string               `json:"object"`
    Created int64                `json:"created"`
    Choices []TextResponseChoice `json:"choices"`
    Usage   `json:"usage"`
    Id          string               `json:"id"`
    Model       string               `json:"model,omitempty"`
    Object      string               `json:"object"`
    Created     int64                `json:"created"`
    Choices     []TextResponseChoice `json:"choices"`
    model.Usage `json:"usage"`
}

type EmbeddingResponseItem struct {
@@ -252,10 +104,10 @@ type EmbeddingResponseItem struct {
}

type EmbeddingResponse struct {
    Object string                  `json:"object"`
    Data   []EmbeddingResponseItem `json:"data"`
    Model  string                  `json:"model"`
    Usage  `json:"usage"`
    Object      string                  `json:"object"`
    Data        []EmbeddingResponseItem `json:"data"`
    Model       string                  `json:"model"`
    model.Usage `json:"usage"`
}

type ImageResponse struct {
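The Message helpers above move into relay/model unchanged: Content holds either a plain string or the OpenAI array-of-parts form. A hedged sketch of the two shapes StringContent accepts, assuming the ContentTypeText/ContentTypeImageURL constants hold the OpenAI wire strings "text" and "image_url":

    // Plain string content:
    m := model.Message{Role: "user", Content: "hello"}
    _ = m.StringContent() // "hello"

    // Array-of-parts content, as json.Unmarshal produces for an `any` field:
    m = model.Message{Role: "user", Content: []any{
        map[string]any{"type": "text", "text": "describe "},
        map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/cat.png"}},
        map[string]any{"type": "text", "text": "this image"},
    }}
    _ = m.StringContent() // "describe this image" (image parts are skipped)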
@@ -8,6 +8,7 @@ import (
    "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/common/image"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/model"
    "math"
    "strings"
)
@@ -63,7 +64,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
    return len(tokenEncoder.Encode(text, nil, nil))
}

func CountTokenMessages(messages []Message, model string) int {
func CountTokenMessages(messages []model.Message, model string) int {
    tokenEncoder := getTokenEncoder(model)
    // Reference:
    // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -1,12 +1,14 @@
package openai

func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode {
    Error := Error{
import "github.com/songquanpeng/one-api/relay/model"

func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
    Error := model.Error{
        Message: err.Error(),
        Type:    "one_api_error",
        Code:    code,
    }
    return &ErrorWithStatusCode{
    return &model.ErrorWithStatusCode{
        Error:      Error,
        StatusCode: statusCode,
    }
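ErrorWrapper is the single place where a plain Go error becomes a relay-level model.ErrorWithStatusCode; the handlers throughout this commit all follow the same pattern:

    // e.g. after a failed response-body read inside a handler:
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    // The wrapped error carries Type "one_api_error", the given code string,
    // and the HTTP status to relay back to the client.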
56
relay/channel/palm/adaptor.go
Normal file
@@ -0,0 +1,56 @@
package palm

import (
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/channel"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/songquanpeng/one-api/relay/util"
    "io"
    "net/http"
)

type Adaptor struct {
}

func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
    return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
    channel.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("x-goog-api-key", meta.APIKey)
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return ConvertRequest(*request), nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
    return channel.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        var responseText string
        err, responseText = StreamHandler(c, resp)
        usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
    } else {
        err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "google palm"
}
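The palm adaptor implements the full channel adaptor surface. The relay/channel.Adaptor interface itself is not part of this excerpt; reconstructed from the method sets implemented by the adaptors in this commit, it plausibly looks like this (a sketch, not the authoritative definition):

    type Adaptor interface {
        GetRequestURL(meta *util.RelayMeta) (string, error)
        SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
        ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
        DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
        DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (*model.Usage, *model.ErrorWithStatusCode)
        GetModelList() []string
        GetChannelName() string
    }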
5
relay/channel/palm/constants.go
Normal file
@@ -0,0 +1,5 @@
package palm

var ModelList = []string{
    "PaLM-2",
}
40
relay/channel/palm/model.go
Normal file
@@ -0,0 +1,40 @@
package palm

import (
    "github.com/songquanpeng/one-api/relay/model"
)

type ChatMessage struct {
    Author  string `json:"author"`
    Content string `json:"content"`
}

type Filter struct {
    Reason  string `json:"reason"`
    Message string `json:"message"`
}

type Prompt struct {
    Messages []ChatMessage `json:"messages"`
}

type ChatRequest struct {
    Prompt         Prompt  `json:"prompt"`
    Temperature    float64 `json:"temperature,omitempty"`
    CandidateCount int     `json:"candidateCount,omitempty"`
    TopP           float64 `json:"topP,omitempty"`
    TopK           int     `json:"topK,omitempty"`
}

type Error struct {
    Code    int    `json:"code"`
    Message string `json:"message"`
    Status  string `json:"status"`
}

type ChatResponse struct {
    Candidates []ChatMessage   `json:"candidates"`
    Messages   []model.Message `json:"messages"`
    Filters    []Filter        `json:"filters"`
    Error      Error           `json:"error"`
}
@@ -1,4 +1,4 @@
package google
package palm

import (
    "encoding/json"
@@ -9,6 +9,7 @@ import (
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
    "net/http"
)
@@ -16,10 +17,10 @@ import (
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body

func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest {
    palmRequest := PaLMChatRequest{
        Prompt: PaLMPrompt{
            Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
    palmRequest := ChatRequest{
        Prompt: Prompt{
            Messages: make([]ChatMessage, 0, len(textRequest.Messages)),
        },
        Temperature:    textRequest.Temperature,
        CandidateCount: textRequest.N,
@@ -27,7 +28,7 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques
        TopK: textRequest.MaxTokens,
    }
    for _, message := range textRequest.Messages {
        palmMessage := PaLMChatMessage{
        palmMessage := ChatMessage{
            Content: message.StringContent(),
        }
        if message.Role == "user" {
@@ -40,14 +41,14 @@ func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatReques
    return &palmRequest
}

func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
func responsePaLM2OpenAI(response *ChatResponse) *openai.TextResponse {
    fullTextResponse := openai.TextResponse{
        Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
    }
    for i, candidate := range response.Candidates {
        choice := openai.TextResponseChoice{
            Index: i,
            Message: openai.Message{
            Message: model.Message{
                Role:    "assistant",
                Content: candidate.Content,
            },
@@ -58,7 +59,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
    return &fullTextResponse
}

func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
    var choice openai.ChatCompletionsStreamResponseChoice
    if len(palmResponse.Candidates) > 0 {
        choice.Delta.Content = palmResponse.Candidates[0].Content
@@ -71,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompl
    return &response
}

func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
    responseText := ""
    responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
    createdTime := helper.GetTimestamp()
@@ -90,7 +91,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
        stopChan <- true
        return
    }
    var palmResponse PaLMChatResponse
    var palmResponse ChatResponse
    err = json.Unmarshal(responseBody, &palmResponse)
    if err != nil {
        logger.SysError("error unmarshalling stream response: " + err.Error())
@@ -130,7 +131,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
    return nil, responseText
}

func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -139,14 +140,14 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    var palmResponse PaLMChatResponse
    var palmResponse ChatResponse
    err = json.Unmarshal(responseBody, &palmResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
        return &openai.ErrorWithStatusCode{
            Error: openai.Error{
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: palmResponse.Error.Message,
                Type:    palmResponse.Error.Status,
                Param:   "",
@@ -156,9 +157,9 @@ func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
        }, nil
    }
    fullTextResponse := responsePaLM2OpenAI(&palmResponse)
    fullTextResponse.Model = model
    completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model)
    usage := openai.Usage{
    fullTextResponse.Model = modelName
    completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, modelName)
    usage := model.Usage{
        PromptTokens:     promptTokens,
        CompletionTokens: completionTokens,
        TotalTokens:      promptTokens + completionTokens,
@@ -1,22 +1,69 @@
package tencent

import (
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/channel"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/songquanpeng/one-api/relay/util"
    "io"
    "net/http"
    "strings"
)

type Adaptor struct {
    Sign string
}

func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
    return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
    channel.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("Authorization", a.Sign)
    return nil
}

func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
    return nil, nil
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    apiKey := c.Request.Header.Get("Authorization")
    apiKey = strings.TrimPrefix(apiKey, "Bearer ")
    appId, secretId, secretKey, err := ParseConfig(apiKey)
    if err != nil {
        return nil, err
    }
    tencentRequest := ConvertRequest(*request)
    tencentRequest.AppId = appId
    tencentRequest.SecretId = secretId
    // we have to calculate the sign here
    a.Sign = GetSign(*tencentRequest, secretKey)
    return tencentRequest, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
    return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
    return channel.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        var responseText string
        err, responseText = StreamHandler(c, resp)
        usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
    } else {
        err, usage = Handler(c, resp)
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "tencent"
}
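Two things are worth noting in this adaptor: the Tencent signature depends on the fully converted request body, so GetSign must run inside ConvertRequest and stash its result on the Adaptor for SetupRequestHeader to pick up later; and ParseConfig (not shown in this excerpt) unpacks three Tencent credentials from the single channel key. A hypothetical sketch of a parser matching that call site, assuming a "|"-separated key layout like the one the xunfei adaptor below uses, and assuming an int64 AppId:

    // Hypothetical: ParseConfig for a key of the form "appId|secretId|secretKey".
    func ParseConfig(apiKey string) (appId int64, secretId string, secretKey string, err error) {
        parts := strings.Split(apiKey, "|")
        if len(parts) != 3 {
            return 0, "", "", errors.New("invalid tencent api key format")
        }
        appId, err = strconv.ParseInt(parts[0], 10, 64) // AppId type assumed
        return appId, parts[1], parts[2], err
    }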
5
relay/channel/tencent/constants.go
Normal file
@@ -0,0 +1,5 @@
package tencent

var ModelList = []string{
    "hunyuan",
}
@@ -14,6 +14,7 @@ import (
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
    "net/http"
    "sort"
@@ -23,7 +24,7 @@ import (

// https://cloud.tencent.com/document/product/1729/97732

func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
    messages := make([]Message, 0, len(request.Messages))
    for i := 0; i < len(request.Messages); i++ {
        message := request.Messages[i]
@@ -67,7 +68,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
    if len(response.Choices) > 0 {
        choice := openai.TextResponseChoice{
            Index: 0,
            Message: openai.Message{
            Message: model.Message{
                Role:    "assistant",
                Content: response.Choices[0].Messages.Content,
            },
@@ -95,7 +96,7 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
    return &response
}

func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
    var responseText string
    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -159,7 +160,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
    return nil, responseText
}

func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var TencentResponse ChatResponse
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
@@ -174,8 +175,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if TencentResponse.Error.Code != 0 {
        return &openai.ErrorWithStatusCode{
            Error: openai.Error{
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: TencentResponse.Error.Message,
                Code:    TencentResponse.Error.Code,
            },
@@ -1,7 +1,7 @@
package tencent

import (
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/model"
)

type Message struct {
@@ -56,7 +56,7 @@ type ChatResponse struct {
    Choices []ResponseChoices `json:"choices,omitempty"` // result
    Created string            `json:"created,omitempty"` // unix timestamp as a string
    Id      string            `json:"id,omitempty"`      // session id
    Usage   openai.Usage      `json:"usage,omitempty"`   // token counts
    Usage   model.Usage       `json:"usage,omitempty"`   // token counts
    Error   Error             `json:"error,omitempty"`   // error information; note: this field may be null, meaning no valid value could be obtained
    Note    string            `json:"note,omitempty"`    // remark
    ReqID   string            `json:"req_id,omitempty"`  // unique request id, returned with every request; used when reporting issues with the call
@@ -1,22 +1,66 @@
package xunfei

import (
    "errors"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/channel"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/songquanpeng/one-api/relay/util"
    "io"
    "net/http"
    "strings"
)

type Adaptor struct {
    request *model.GeneralOpenAIRequest
}

func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
    return "", nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
    channel.SetupCommonRequestHeader(c, req, meta)
    // check DoResponse for the auth part
    return nil
}

func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    a.request = request
    return nil, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
    return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
    // xunfei's request is not an http request, so we don't need to do anything here
    dummyResp := &http.Response{}
    dummyResp.StatusCode = http.StatusOK
    return dummyResp, nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    splits := strings.Split(meta.APIKey, "|")
    if len(splits) != 3 {
        return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
    }
    if a.request == nil {
        return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
    }
    if meta.IsStream {
        err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2])
    } else {
        err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2])
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "xunfei"
}
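Because the xunfei channel speaks WebSocket rather than plain HTTP, DoRequest hands back a placeholder 200 response and both authentication and the actual exchange happen in DoResponse. The split above, combined with the handler signatures later in this diff (appId, apiSecret, apiKey), implies the channel key layout:

    // Channel key layout implied by strings.Split(meta.APIKey, "|") and the
    // StreamHandler/Handler parameter order shown below:
    //   "<appId>|<apiSecret>|<apiKey>"
    splits := strings.Split(meta.APIKey, "|")
    appId, apiSecret, apiKey := splits[0], splits[1], splits[2]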
5
relay/channel/xunfei/constants.go
Normal file
@@ -0,0 +1,5 @@
package xunfei

var ModelList = []string{
    "SparkDesk",
}
@@ -13,6 +13,7 @@ import (
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
    "net/http"
    "net/url"
@@ -23,7 +24,7 @@ import (
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html

func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
    messages := make([]Message, 0, len(request.Messages))
    for _, message := range request.Messages {
        if message.Role == "system" {
@@ -62,7 +63,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
    }
    choice := openai.TextResponseChoice{
        Index: 0,
        Message: openai.Message{
        Message: model.Message{
            Role:    "assistant",
            Content: response.Payload.Choices.Text[0].Content,
        },
@@ -125,14 +126,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
    return callUrl
}

func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
    domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
    dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
    if err != nil {
        return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
    }
    common.SetEventStreamHeaders(c)
    var usage openai.Usage
    var usage model.Usage
    c.Stream(func(w io.Writer) bool {
        select {
        case xunfeiResponse := <-dataChan:
@@ -155,13 +156,13 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appI
    return nil, &usage
}

func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) {
    domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
    dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
    if err != nil {
        return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
    }
    var usage openai.Usage
    var usage model.Usage
    var content string
    var xunfeiResponse ChatResponse
    stop := false
@@ -197,7 +198,7 @@ func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId stri
    return nil, &usage
}

func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
    d := websocket.Dialer{
        HandshakeTimeout: 5 * time.Second,
    }
@@ -1,7 +1,7 @@
package xunfei

import (
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/model"
)

type Message struct {
@@ -55,7 +55,7 @@ type ChatResponse struct {
            // CompletionTokens string `json:"completion_tokens"`
            // TotalTokens string `json:"total_tokens"`
            //} `json:"text"`
            Text openai.Usage `json:"text"`
            Text model.Usage `json:"text"`
        } `json:"usage"`
    } `json:"payload"`
}
@@ -1,22 +1,58 @@
package zhipu

import (
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/channel"
    "github.com/songquanpeng/one-api/relay/model"
    "github.com/songquanpeng/one-api/relay/util"
    "io"
    "net/http"
)

type Adaptor struct {
}

func (a *Adaptor) Auth(c *gin.Context) error {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
    method := "invoke"
    if meta.IsStream {
        method = "sse-invoke"
    }
    return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
    channel.SetupCommonRequestHeader(c, req, meta)
    token := GetToken(meta.APIKey)
    req.Header.Set("Authorization", token)
    return nil
}

func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
    return nil, nil
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return ConvertRequest(*request), nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
    return nil, nil, nil
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
    return channel.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {
        err, usage = Handler(c, resp)
    }
    return
}

func (a *Adaptor) GetModelList() []string {
    return ModelList
}

func (a *Adaptor) GetChannelName() string {
    return "zhipu"
}
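GetRequestURL switches between the blocking and SSE endpoints based on meta.IsStream. For a concrete example (base URL illustrative, not confirmed by this excerpt), with model chatglm_turbo the adaptor targets:

    // non-stream: https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_turbo/invoke
    // stream:     https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_turbo/sse-invoke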
5
relay/channel/zhipu/constants.go
Normal file
@@ -0,0 +1,5 @@
package zhipu

var ModelList = []string{
    "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
}
@@ -10,6 +10,7 @@ import (
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
    "net/http"
    "strings"
@@ -72,7 +73,7 @@ func GetToken(apikey string) string {
    return tokenString
}

func ConvertRequest(request openai.GeneralOpenAIRequest) *Request {
func ConvertRequest(request model.GeneralOpenAIRequest) *Request {
    messages := make([]Message, 0, len(request.Messages))
    for _, message := range request.Messages {
        if message.Role == "system" {
@@ -110,7 +111,7 @@ func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
    for i, choice := range response.Data.Choices {
        openaiChoice := openai.TextResponseChoice{
            Index: i,
            Message: openai.Message{
            Message: model.Message{
                Role:    choice.Role,
                Content: strings.Trim(choice.Content, "\""),
            },
@@ -136,7 +137,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStr
    return &response
}

func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) {
func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) {
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta.Content = ""
    choice.FinishReason = &constant.StopFinishReason
@@ -150,8 +151,8 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.
    return &response, &zhipuResponse.Usage
}

func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
    var usage *openai.Usage
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var usage *model.Usage
    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
        if atEOF && len(data) == 0 {
@@ -228,7 +229,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
    return nil, usage
}

func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var zhipuResponse Response
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
@@ -243,8 +244,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if !zhipuResponse.Success {
        return &openai.ErrorWithStatusCode{
            Error: openai.Error{
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: zhipuResponse.Msg,
                Type:    "zhipu_error",
                Param:   "",
@@ -1,7 +1,7 @@
package zhipu

import (
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/model"
    "time"
)

@@ -19,11 +19,11 @@ type Request struct {
}

type ResponseData struct {
    TaskId     string    `json:"task_id"`
    RequestId  string    `json:"request_id"`
    TaskStatus string    `json:"task_status"`
    Choices    []Message `json:"choices"`
    openai.Usage `json:"usage"`
    TaskId      string    `json:"task_id"`
    RequestId   string    `json:"request_id"`
    TaskStatus  string    `json:"task_status"`
    Choices     []Message `json:"choices"`
    model.Usage `json:"usage"`
}

type Response struct {
@@ -34,10 +34,10 @@ type Response struct {
}

type StreamMetaResponse struct {
    RequestId  string `json:"request_id"`
    TaskId     string `json:"task_id"`
    TaskStatus string `json:"task_status"`
    openai.Usage `json:"usage"`
    RequestId   string `json:"request_id"`
    TaskId      string `json:"task_id"`
    TaskStatus  string `json:"task_status"`
    model.Usage `json:"usage"`
}

type tokenData struct {