mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-15 15:30:26 +00:00
feat: support claude and gemini in vertex ai (#1621)
* feat: support claude and gemini in vertex ai
* fix: do not show api key field in channel page when the type is VertexAI
* fix: update getToken function to include channelId in cache key
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/palm"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/zhipu"
|
||||
"github.com/songquanpeng/one-api/relay/apitype"
|
||||
@@ -55,6 +56,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
|
||||
return &cloudflare.Adaptor{}
|
||||
case apitype.DeepL:
|
||||
return &deepl.Adaptor{}
|
||||
case apitype.VertexAI:
|
||||
return &vertexai.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
101
relay/adaptor/vertexai/adaptor.go
Normal file
101
relay/adaptor/vertexai/adaptor.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
// Compile-time assertion that *Adaptor satisfies the adaptor.Adaptor interface.
var _ adaptor.Adaptor = new(Adaptor)

// channelName is the identifier reported by GetChannelName.
const channelName = "vertexai"

// Adaptor routes OpenAI-style requests to Vertex AI publisher models,
// delegating model-specific conversion and response handling to the
// sub-adaptor registered for the requested model (Claude or Gemini).
type Adaptor struct {}
|
||||
|
||||
// Init is a no-op: this adaptor needs no per-request initialization.
func (a *Adaptor) Init(meta *meta.Meta) {
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(request.Model)
|
||||
if adaptor == nil {
|
||||
return nil, errors.New("adaptor not found")
|
||||
}
|
||||
|
||||
return adaptor.ConvertRequest(c, relayMode, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
adaptor := GetAdaptor(meta.OriginModelName)
|
||||
if adaptor == nil {
|
||||
return nil, &relaymodel.ErrorWithStatusCode{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: relaymodel.Error{
|
||||
Message: "adaptor not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
return adaptor.DoResponse(c, resp, meta)
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() (models []string) {
|
||||
models = modelList
|
||||
return
|
||||
}
|
||||
|
||||
// GetChannelName reports the channel type name ("vertexai") for this adaptor.
func (a *Adaptor) GetChannelName() string {
	return channelName
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
suffix := ""
|
||||
if strings.HasPrefix(meta.ActualModelName, "gemini") {
|
||||
if meta.IsStream {
|
||||
suffix = "streamGenerateContent"
|
||||
} else {
|
||||
suffix = "generateContent"
|
||||
}
|
||||
} else {
|
||||
if meta.IsStream {
|
||||
suffix = "streamRawPredict"
|
||||
} else {
|
||||
suffix = "rawPredict"
|
||||
}
|
||||
}
|
||||
|
||||
baseUrl := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", meta.Config.Region, meta.Config.VertexAIProjectID, meta.Config.Region, meta.ActualModelName, suffix)
|
||||
return baseUrl, nil
|
||||
}
|
||||
|
||||
// SetupRequestHeader sets the common relay headers plus an OAuth2 bearer
// token minted from the channel's service-account credentials (VertexAIADC).
// The gin context doubles as the context.Context passed to getToken.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	token, err := getToken(c, meta.ChannelId, meta.Config.VertexAIADC)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+token)
	return nil
}
|
||||
|
||||
// ConvertImageRequest passes the image request through unchanged; the
// sub-adaptors do not transform image payloads.
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	return request, nil
}
|
||||
|
||||
// DoRequest forwards the converted request through the shared relay HTTP
// helper, which builds the outbound request using this adaptor's
// GetRequestURL and SetupRequestHeader.
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
|
54
relay/adaptor/vertexai/claude/adapter.go
Normal file
54
relay/adaptor/vertexai/claude/adapter.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
|
||||
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
// ModelList enumerates the Claude models reachable through Vertex AI.
// Vertex AI model IDs use "@" between the model name and its version stamp.
var ModelList = []string{
	"claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229",
}

// anthropicVersion is the fixed value Vertex AI requires in the request
// body's "anthropic_version" field.
const anthropicVersion = "vertex-2023-10-16"

// Adaptor implements the Claude sub-adaptor for the Vertex AI channel.
type Adaptor struct {
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
claudeReq := anthropic.ConvertRequest(*request)
|
||||
req := Request{
|
||||
AnthropicVersion: anthropicVersion,
|
||||
// Model: claudeReq.Model,
|
||||
Messages: claudeReq.Messages,
|
||||
MaxTokens: claudeReq.MaxTokens,
|
||||
Temperature: claudeReq.Temperature,
|
||||
TopP: claudeReq.TopP,
|
||||
TopK: claudeReq.TopK,
|
||||
Stream: claudeReq.Stream,
|
||||
Tools: claudeReq.Tools,
|
||||
}
|
||||
|
||||
c.Set(ctxkey.RequestModel, request.Model)
|
||||
c.Set(ctxkey.ConvertedRequest, req)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
err, usage = anthropic.StreamHandler(c, resp)
|
||||
} else {
|
||||
err, usage = anthropic.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
return
|
||||
}
|
19
relay/adaptor/vertexai/claude/model.go
Normal file
19
relay/adaptor/vertexai/claude/model.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package vertexai
|
||||
|
||||
import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
|
||||
|
||||
// Request is the body sent to Vertex AI's Claude (Anthropic Messages)
// endpoint. It mirrors the Anthropic request format except that the model
// is carried in the URL instead of the body, and anthropic_version is
// mandatory.
type Request struct {
	// AnthropicVersion must be "vertex-2023-10-16"
	AnthropicVersion string `json:"anthropic_version"`
	// Model is deliberately absent: Vertex AI takes it from the URL.
	// Model string `json:"model"`
	Messages []anthropic.Message `json:"messages"`
	System string `json:"system,omitempty"`
	MaxTokens int `json:"max_tokens,omitempty"`
	StopSequences []string `json:"stop_sequences,omitempty"`
	Stream bool `json:"stream,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP float64 `json:"top_p,omitempty"`
	TopK int `json:"top_k,omitempty"`
	Tools []anthropic.Tool `json:"tools,omitempty"`
	ToolChoice any `json:"tool_choice,omitempty"`
}
|
49
relay/adaptor/vertexai/gemini/adapter.go
Normal file
49
relay/adaptor/vertexai/gemini/adapter.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
// ModelList enumerates the Gemini models reachable through Vertex AI.
var ModelList = []string{
	"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
}

// Adaptor implements the Gemini sub-adaptor for the Vertex AI channel.
type Adaptor struct {
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
geminiRequest := gemini.ConvertRequest(*request)
|
||||
c.Set(ctxkey.RequestModel, request.Model)
|
||||
c.Set(ctxkey.ConvertedRequest, geminiRequest)
|
||||
return geminiRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
var responseText string
|
||||
err, responseText = gemini.StreamHandler(c, resp)
|
||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||
} else {
|
||||
switch meta.Mode {
|
||||
case relaymode.Embeddings:
|
||||
err, usage = gemini.EmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = gemini.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
51
relay/adaptor/vertexai/registry.go
Normal file
51
relay/adaptor/vertexai/registry.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude"
|
||||
gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
// VertexAIModelType identifies which vendor-specific sub-adaptor serves a
// given model name.
type VertexAIModelType int

const (
	// NOTE(review): "Verter" looks like a typo for "Vertex". The constants
	// are exported, so renaming could break external users — confirm no
	// outside references before fixing.
	VerterAIClaude VertexAIModelType = iota + 1
	VerterAIGemini
)

// modelMapping routes a model name to its sub-adaptor type; modelList is
// the flattened catalog of every supported model. Both are populated once
// in init below.
var modelMapping = map[string]VertexAIModelType{}
var modelList = []string{}
|
||||
|
||||
func init() {
|
||||
modelList = append(modelList, claude.ModelList...)
|
||||
for _, model := range claude.ModelList {
|
||||
modelMapping[model] = VerterAIClaude
|
||||
}
|
||||
|
||||
modelList = append(modelList, gemini.ModelList...)
|
||||
for _, model := range gemini.ModelList {
|
||||
modelMapping[model] = VerterAIGemini
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// innerAIAdapter is the subset of behavior each vendor-specific sub-adaptor
// (Claude, Gemini) must provide. URL construction and authentication are
// handled by the enclosing vertexai.Adaptor, so only request conversion and
// response handling are delegated.
type innerAIAdapter interface {
	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
	DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
}
|
||||
|
||||
func GetAdaptor(model string) innerAIAdapter {
|
||||
adaptorType := modelMapping[model]
|
||||
switch adaptorType {
|
||||
case VerterAIClaude:
|
||||
return &claude.Adaptor{}
|
||||
case VerterAIGemini:
|
||||
return &gemini.Adaptor{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
63
relay/adaptor/vertexai/token.go
Normal file
63
relay/adaptor/vertexai/token.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
credentials "cloud.google.com/go/iam/credentials/apiv1"
|
||||
"cloud.google.com/go/iam/credentials/apiv1/credentialspb"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
// ApplicationDefaultCredentials mirrors the JSON layout of a Google Cloud
// service-account key file (the "application default credentials" format).
// getToken only reads ClientEmail; the remaining fields are kept so the
// full file round-trips through json.Unmarshal without loss.
type ApplicationDefaultCredentials struct {
	Type string `json:"type"`
	ProjectID string `json:"project_id"`
	PrivateKeyID string `json:"private_key_id"`
	PrivateKey string `json:"private_key"`
	ClientEmail string `json:"client_email"`
	ClientID string `json:"client_id"`
	AuthURI string `json:"auth_uri"`
	TokenURI string `json:"token_uri"`
	AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"`
	ClientX509CertURL string `json:"client_x509_cert_url"`
	UniverseDomain string `json:"universe_domain"`
}
|
||||
|
||||
|
||||
// Cache holds minted access tokens keyed per channel. Entries expire after
// 50 minutes — below the typical 1-hour lifetime of Google access tokens —
// so a cached token is never served after it expires upstream.
var Cache = cache.New(50*time.Minute, 55*time.Minute)

// defaultScope requests access to the Google Cloud Platform APIs.
const defaultScope = "https://www.googleapis.com/auth/cloud-platform"
||||
|
||||
func getToken(ctx context.Context, channelId int, adcJson string) (string, error) {
|
||||
cacheKey := fmt.Sprintf("vertexai-token-%d", channelId)
|
||||
if token, found := Cache.Get(cacheKey); found {
|
||||
return token.(string), nil
|
||||
}
|
||||
adc := &ApplicationDefaultCredentials{}
|
||||
if err := json.Unmarshal([]byte(adcJson), adc); err != nil {
|
||||
return "", fmt.Errorf("Failed to decode credentials file: %w", err)
|
||||
}
|
||||
|
||||
c, err := credentials.NewIamCredentialsClient(ctx, option.WithCredentialsJSON([]byte(adcJson)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Failed to create client: %w", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
req := &credentialspb.GenerateAccessTokenRequest{
|
||||
// See https://pkg.go.dev/cloud.google.com/go/iam/credentials/apiv1/credentialspb#GenerateAccessTokenRequest.
|
||||
Name: fmt.Sprintf("projects/-/serviceAccounts/%s", adc.ClientEmail),
|
||||
Scope: []string{defaultScope},
|
||||
}
|
||||
resp, err := c.GenerateAccessToken(ctx, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Failed to generate access token: %w", err)
|
||||
}
|
||||
_ = resp
|
||||
|
||||
Cache.Set(cacheKey, resp.AccessToken, cache.DefaultExpiration)
|
||||
return resp.AccessToken, nil
|
||||
}
|
@@ -17,6 +17,7 @@ const (
|
||||
Cohere
|
||||
Cloudflare
|
||||
DeepL
|
||||
VertexAI
|
||||
|
||||
Dummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
@@ -43,5 +43,6 @@ const (
|
||||
TogetherAI
|
||||
Doubao
|
||||
Novita
|
||||
VertextAI
|
||||
Dummy
|
||||
)
|
||||
|
@@ -35,6 +35,8 @@ func ToAPIType(channelType int) int {
|
||||
apiType = apitype.Cloudflare
|
||||
case DeepL:
|
||||
apiType = apitype.DeepL
|
||||
case VertextAI:
|
||||
apiType = apitype.VertexAI
|
||||
}
|
||||
|
||||
return apiType
|
||||
|
@@ -43,6 +43,7 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.together.xyz", // 39
|
||||
"https://ark.cn-beijing.volces.com", // 40
|
||||
"https://api.novita.ai/v3/openai", // 41
|
||||
"", // 42
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
Reference in New Issue
Block a user