feat: support Google PaLM2 (close #105)

This commit is contained in:
JustSong
2023-07-23 00:32:47 +08:00
parent fcc1e2d568
commit 8f721d67a5
6 changed files with 222 additions and 33 deletions

View File

@@ -82,6 +82,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeClaude
} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
apiType = APITypeBaidu
} else if strings.HasPrefix(textRequest.Model, "PaLM") {
apiType = APITypePaLM
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -127,6 +129,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
}
var promptTokens int
var completionTokens int
@@ -186,6 +193,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
@@ -323,6 +337,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
textResponse.Usage = *usage
return nil
}
case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp)
if err != nil {
return err
}
streamResponseText = responseText
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
textResponse.Usage = *usage
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}