feat: add Google Gemini Pro support (#826)

* fest: Add Google Gemini Pro, fix #810

* fest: Add tooling to Gemini; Add OpenAI-like system prompt to Gemini

* refactor: removing unused if statement

* fest: Add dummy model message for system message in gemini model

* chore: update implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
This commit is contained in:
David Zhuang
2023-12-16 23:48:32 -05:00
committed by GitHub
parent 366b82128f
commit 5cf23d8698
12 changed files with 353 additions and 3 deletions

View File

@@ -27,6 +27,7 @@ const (
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
)
var httpClient *http.Client
@@ -118,6 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -177,6 +180,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
requestBaseURL = baseURL
}
version := "v1"
if c.GetString("api_version") != "" {
version = c.GetString("api_version")
}
action := "generateContent"
// actually gemini does not support stream, it's fake
//if textRequest.Stream {
// action = "streamGenerateContent"
//}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
@@ -274,6 +295,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeGemini:
geminiChatRequest := requestOpenAI2Gemini(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
@@ -367,6 +395,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
case APITypeGemini:
// do not set Authorization header
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
@@ -527,6 +557,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
}
case APITypeGemini:
if textRequest.Stream {
err, responseText := geminiChatStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)