mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-20 19:01:02 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			61 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			61 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package middleware
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"github.com/gin-gonic/gin"
 | |
| 	"github.com/songquanpeng/one-api/common"
 | |
| 	"github.com/songquanpeng/one-api/common/helper"
 | |
| 	"github.com/songquanpeng/one-api/common/logger"
 | |
| 	"strings"
 | |
| )
 | |
| 
 | |
| func abortWithMessage(c *gin.Context, statusCode int, message string) {
 | |
| 	c.JSON(statusCode, gin.H{
 | |
| 		"error": gin.H{
 | |
| 			"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
 | |
| 			"type":    "one_api_error",
 | |
| 		},
 | |
| 	})
 | |
| 	c.Abort()
 | |
| 	logger.Error(c.Request.Context(), message)
 | |
| }
 | |
| 
 | |
| func getRequestModel(c *gin.Context) (string, error) {
 | |
| 	var modelRequest ModelRequest
 | |
| 	err := common.UnmarshalBodyReusable(c, &modelRequest)
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
 | |
| 	}
 | |
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | |
| 		if modelRequest.Model == "" {
 | |
| 			modelRequest.Model = "text-moderation-stable"
 | |
| 		}
 | |
| 	}
 | |
| 	if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
 | |
| 		if modelRequest.Model == "" {
 | |
| 			modelRequest.Model = c.Param("model")
 | |
| 		}
 | |
| 	}
 | |
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 | |
| 		if modelRequest.Model == "" {
 | |
| 			modelRequest.Model = "dall-e-2"
 | |
| 		}
 | |
| 	}
 | |
| 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
 | |
| 		if modelRequest.Model == "" {
 | |
| 			modelRequest.Model = "whisper-1"
 | |
| 		}
 | |
| 	}
 | |
| 	return modelRequest.Model, nil
 | |
| }
 | |
| 
 | |
// isModelInList reports whether modelName is exactly equal to one of the
// comma-separated entries in models. Entries are compared verbatim: no
// whitespace trimming or case folding is applied, and empty segments (including
// the single empty segment of an empty models string) count as entries.
func isModelInList(modelName string, models string) bool {
	remaining := models
	for {
		comma := strings.IndexByte(remaining, ',')
		if comma < 0 {
			// Last (or only) segment.
			return remaining == modelName
		}
		if remaining[:comma] == modelName {
			return true
		}
		remaining = remaining[comma+1:]
	}
}
 | 
