feat: update ali stream implementation & enable internet search (#856)

* Update relay-ali.go: 改进stream模式,添加联网搜索能力

通义千问支持stream的增量模式,不需要每次去掉上次的前缀;实测qwen-max联网模式效果不错,添加了联网模式。如果别的模型有问题可以改为单独给qwen-max开放

* 删除"stream参数"

刚发现原来阿里api没有这个参数,上次误加了。

* refactor: only enable search when specified

* fix: remove custom suffix when get model ratio

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
This commit is contained in:
moondie
2023-12-24 16:17:21 +08:00
committed by GitHub
parent a763681c2e
commit ee9e746520
3 changed files with 32 additions and 15 deletions

View File

@@ -23,10 +23,11 @@ type AliInput struct {
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}
type AliChatRequest struct {
@@ -81,6 +82,8 @@ type AliChatResponse struct {
AliError
}
const AliEnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
@@ -90,17 +93,21 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
}
return &AliChatRequest{
Model: request.Model,
Model: aliModel,
Input: AliInput{
Messages: messages,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
Parameters: AliParameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
}
}
@@ -202,7 +209,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStre
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Model: "qwen",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
@@ -240,7 +247,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
stopChan <- true
}()
setEventStreamHeaders(c)
lastResponseText := ""
//lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -256,8 +263,8 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())