refactor: abusing goroutines and channel (#1561)

* refactor: abusing goroutines

* fix: trim data prefix

* refactor: move functions to render package

* refactor: add back trim & flush

---------

Co-authored-by: JustSong <quanpengsong@gmail.com>
This commit is contained in:
zijiren
2024-06-30 18:36:33 +08:00
committed by GitHub
parent ae1cd29f94
commit b21b3b5b46
14 changed files with 614 additions and 728 deletions

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@@ -25,88 +26,68 @@ const (
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
scanner.Split(bufio.ScanLines)
var usage *model.Usage
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < dataPrefixLength { // ignore blank line or wrong format
continue
}
if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
continue
}
if strings.HasPrefix(data[dataPrefixLength:], done) {
dataChan <- data
continue
}
switch relayMode {
case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
dataChan <- data // if error happened, pass the data to client
continue // just ignore the error
}
if len(streamResponse.Choices) == 0 {
// but for empty choice, we should not pass it to client, this is for azure
continue // just ignore empty choice
}
dataChan <- data
for _, choice := range streamResponse.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
case relaymode.Completions:
dataChan <- data
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
}
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
for scanner.Scan() {
data := scanner.Text()
if len(data) < dataPrefixLength { // ignore blank line or wrong format
continue
}
})
if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
continue
}
if strings.HasPrefix(data[dataPrefixLength:], done) {
render.StringData(c, data)
continue
}
switch relayMode {
case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error
}
if len(streamResponse.Choices) == 0 {
// but for empty choice, we should not pass it to client, this is for azure
continue // just ignore empty choice
}
render.StringData(c, data)
for _, choice := range streamResponse.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
case relaymode.Completions:
render.StringData(c, data)
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
}
return nil, responseText, usage
}