mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-14 23:00:27 +00:00
feat: able to set model limitation for token (close #178)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
@@ -107,6 +108,18 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
requestModel, err := getRequestModel(c)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
c.Set("request_model", requestModel)
|
||||
if token.Models != nil && *token.Models != "" {
|
||||
if !isModelInList(requestModel, *token.Models) {
|
||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
|
@@ -2,14 +2,12 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
@@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
requestModel := c.GetString("request_model")
|
||||
var err error
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e-2"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
requestModel = modelRequest.Model
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
|
||||
if channel != nil {
|
||||
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
|
@@ -1,9 +1,12 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.Abort()
|
||||
logger.Error(c.Request.Context(), message)
|
||||
}
|
||||
|
||||
func getRequestModel(c *gin.Context) (string, error) {
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e-2"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
return modelRequest.Model, nil
|
||||
}
|
||||
|
||||
func isModelInList(modelName string, models string) bool {
|
||||
modelList := strings.Split(models, ",")
|
||||
for _, model := range modelList {
|
||||
if modelName == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
Reference in New Issue
Block a user