mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-15 23:54:30 +00:00
feat: now use token as the unit of quota (close #33)
This commit is contained in:
@@ -2,6 +2,8 @@ package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
@@ -11,14 +13,78 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TextRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Prompt string `json:"prompt"`
|
||||
//Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type StreamResponse struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
channelType := c.GetInt("channel")
|
||||
tokenId := c.GetInt("token_id")
|
||||
isUnlimitedQuota := c.GetBool("unlimited_quota")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
if channelType == common.ChannelTypeCustom {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
var textRequest TextRequest
|
||||
err = json.Unmarshal(requestBody, &textRequest)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
// Reset request body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
requestURL := c.Request.URL.String()
|
||||
req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body)
|
||||
if err != nil {
|
||||
@@ -30,16 +96,11 @@ func Relay(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
//req.Header = c.Request.Header.Clone()
|
||||
// Fix HTTP Decompression failed
|
||||
// https://github.com/stoplightio/prism/issues/1064#issuecomment-824682360
|
||||
//req.Header.Del("Accept-Encoding")
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
||||
client := &http.Client{}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -50,20 +111,36 @@ func Relay(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var textResponse TextResponse
|
||||
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
|
||||
var streamResponseText string
|
||||
|
||||
defer func() {
|
||||
err := req.Body.Close()
|
||||
if err != nil {
|
||||
common.SysError("Error closing request body: " + err.Error())
|
||||
}
|
||||
if !isUnlimitedQuota && requestURL == "/v1/chat/completions" {
|
||||
err := model.DecreaseTokenRemainQuotaById(tokenId)
|
||||
if consumeQuota {
|
||||
quota := 0
|
||||
if isStream {
|
||||
quota = int(float64(len(streamResponseText)) * 0.8)
|
||||
} else {
|
||||
quota = textResponse.Usage.TotalTokens
|
||||
}
|
||||
err := model.ConsumeTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
common.SysError("Error decreasing token remain times: " + err.Error())
|
||||
common.SysError("Error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
|
||||
|
||||
if isStream {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
@@ -87,6 +164,18 @@ func Relay(c *gin.Context) {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if data != "[DONE]" {
|
||||
var streamResponse StreamResponse
|
||||
err = json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
common.SysError("Error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
streamResponseText += choice.Delta.Content
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
@@ -108,6 +197,38 @@ func Relay(c *gin.Context) {
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -120,3 +241,12 @@ func Relay(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RelayNotImplemented(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Not Implemented",
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user