feat: now use token as the unit of quota (close #33)

This commit is contained in:
JustSong
2023-04-28 16:58:55 +08:00
parent 601fa5cea8
commit 053bb85a1c
5 changed files with 185 additions and 22 deletions

View File

@@ -2,6 +2,8 @@ package controller
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
@@ -11,14 +13,78 @@ import (
"strings"
)
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
//Stream bool `json:"stream"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TextResponse struct {
Usage `json:"usage"`
}
type StreamResponse struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
func Relay(c *gin.Context) {
channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id")
isUnlimitedQuota := c.GetBool("unlimited_quota")
consumeQuota := c.GetBool("consume_quota")
baseURL := common.ChannelBaseURLs[channelType]
if channelType == common.ChannelTypeCustom {
baseURL = c.GetString("base_url")
}
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
err = c.Request.Body.Close()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
var textRequest TextRequest
err = json.Unmarshal(requestBody, &textRequest)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
requestURL := c.Request.URL.String()
req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body)
if err != nil {
@@ -30,16 +96,11 @@ func Relay(c *gin.Context) {
})
return
}
//req.Header = c.Request.Header.Clone()
// Fix HTTP Decompression failed
// https://github.com/stoplightio/prism/issues/1064#issuecomment-824682360
//req.Header.Del("Accept-Encoding")
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("Connection", c.Request.Header.Get("Connection"))
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -50,20 +111,36 @@ func Relay(c *gin.Context) {
})
return
}
err = req.Body.Close()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
var textResponse TextResponse
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
var streamResponseText string
defer func() {
err := req.Body.Close()
if err != nil {
common.SysError("Error closing request body: " + err.Error())
}
if !isUnlimitedQuota && requestURL == "/v1/chat/completions" {
err := model.DecreaseTokenRemainQuotaById(tokenId)
if consumeQuota {
quota := 0
if isStream {
quota = int(float64(len(streamResponseText)) * 0.8)
} else {
quota = textResponse.Usage.TotalTokens
}
err := model.ConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("Error decreasing token remain times: " + err.Error())
common.SysError("Error consuming token remain quota: " + err.Error())
}
}
}()
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
if isStream {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -87,6 +164,18 @@ func Relay(c *gin.Context) {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
data = data[6:]
if data != "[DONE]" {
var streamResponse StreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("Error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Delta.Content
}
}
}
stopChan <- true
}()
@@ -108,6 +197,38 @@ func Relay(c *gin.Context) {
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
err = resp.Body.Close()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
return
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -120,3 +241,12 @@ func Relay(c *gin.Context) {
}
}
}
func RelayNotImplemented(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": "Not Implemented",
"type": "one_api_error",
},
})
}