mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-17 08:38:24 +00:00
feat: support Google PaLM2 (close #105)
This commit is contained in:
@@ -82,6 +82,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
apiType = APITypeClaude
|
||||
} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
|
||||
apiType = APITypeBaidu
|
||||
} else if strings.HasPrefix(textRequest.Model, "PaLM") {
|
||||
apiType = APITypePaLM
|
||||
}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
@@ -127,6 +129,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
|
||||
case APITypePaLM:
|
||||
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?key=" + apiKey
|
||||
}
|
||||
var promptTokens int
|
||||
var completionTokens int
|
||||
@@ -186,6 +193,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypePaLM:
|
||||
palmRequest := requestOpenAI2PaLM(textRequest)
|
||||
jsonStr, err := json.Marshal(palmRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
@@ -323,6 +337,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
textResponse.Usage = *usage
|
||||
return nil
|
||||
}
|
||||
case APITypePaLM:
|
||||
if textRequest.Stream { // PaLM2 API does not support stream
|
||||
err, responseText := palmStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
streamResponseText = responseText
|
||||
return nil
|
||||
} else {
|
||||
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage = *usage
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||
}
|
||||
|
Reference in New Issue
Block a user