mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-16 08:04:10 +00:00
refactor: use tiktoken-go to calculate token number
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -44,6 +45,13 @@ type StreamResponse struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
|
||||
|
||||
func countToken(text string) int {
|
||||
token := tokenEncoder.Encode(text, nil, nil)
|
||||
return len(token)
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
err := relayHelper(c)
|
||||
if err != nil {
|
||||
@@ -64,6 +72,7 @@ func relayHelper(c *gin.Context) error {
|
||||
if channelType == common.ChannelTypeCustom {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
var textRequest TextRequest
|
||||
if consumeQuota {
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
@@ -73,7 +82,6 @@ func relayHelper(c *gin.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var textRequest TextRequest
|
||||
err = json.Unmarshal(requestBody, &textRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -112,7 +120,12 @@ func relayHelper(c *gin.Context) error {
|
||||
if consumeQuota {
|
||||
quota := 0
|
||||
if isStream {
|
||||
quota = int(float64(len(streamResponseText)) * common.BytesNumber2Quota)
|
||||
var text string
|
||||
for _, message := range textRequest.Messages {
|
||||
text += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
|
||||
}
|
||||
text += fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
|
||||
quota = countToken(text) + 3
|
||||
} else {
|
||||
quota = textResponse.Usage.TotalTokens
|
||||
}
|
||||
|
Reference in New Issue
Block a user