feat: support baidu's models now (close #286)

This commit is contained in:
JustSong
2023-07-22 23:24:09 +08:00
parent 3c940113ab
commit 9a1db61675
7 changed files with 268 additions and 2 deletions

View File

@@ -18,6 +18,7 @@ const (
APITypeOpenAI = iota
APITypeClaude
APITypePaLM
APITypeBaidu
)
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -79,6 +80,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType := APITypeOpenAI
if strings.HasPrefix(textRequest.Model, "claude") {
apiType = APITypeClaude
} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
apiType = APITypeBaidu
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -112,6 +115,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
}
case APITypeBaidu:
switch textRequest.Model {
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
}
var promptTokens int
var completionTokens int
@@ -164,6 +179,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonStr, err := json.Marshal(baiduRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
@@ -216,7 +238,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if strings.HasPrefix(textRequest.Model, "gpt-4") {
completionRatio = 2
}
if isStream {
if isStream && apiType != APITypeBaidu {
completionTokens = countTokenText(streamResponseText, textRequest.Model)
} else {
promptTokens = textResponse.Usage.PromptTokens
@@ -285,6 +307,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
textResponse.Usage = *usage
return nil
}
case APITypeBaidu:
if isStream {
err, usage := baiduStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage = *usage
return nil
} else {
err, usage := baiduHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage = *usage
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}