refactor: use tiktoken-go to calculate token number

This commit is contained in:
JustSong
2023-04-28 18:36:17 +08:00
parent aea6c859e7
commit b08cd7e104
3 changed files with 21 additions and 2 deletions

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io"
"net/http"
"one-api/common"
@@ -44,6 +45,13 @@ type StreamResponse struct {
} `json:"choices"`
}
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
func countToken(text string) int {
token := tokenEncoder.Encode(text, nil, nil)
return len(token)
}
func Relay(c *gin.Context) {
err := relayHelper(c)
if err != nil {
@@ -64,6 +72,7 @@ func relayHelper(c *gin.Context) error {
if channelType == common.ChannelTypeCustom {
baseURL = c.GetString("base_url")
}
var textRequest TextRequest
if consumeQuota {
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
@@ -73,7 +82,6 @@ func relayHelper(c *gin.Context) error {
if err != nil {
return err
}
var textRequest TextRequest
err = json.Unmarshal(requestBody, &textRequest)
if err != nil {
return err
@@ -112,7 +120,12 @@ func relayHelper(c *gin.Context) error {
if consumeQuota {
quota := 0
if isStream {
quota = int(float64(len(streamResponseText)) * common.BytesNumber2Quota)
var text string
for _, message := range textRequest.Messages {
text += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
}
text += fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
quota = countToken(text) + 3
} else {
quota = textResponse.Usage.TotalTokens
}