feat: support cogview-3

This commit is contained in:
JustSong
2024-04-06 00:30:08 +08:00
parent 0cb224e62e
commit 880e12c855
9 changed files with 127 additions and 73 deletions

View File

@@ -54,9 +54,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e
return imageRequest, nil
}
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" {
return true
}
_, ok := constant.ImageSizeRatios[model][size]
return ok
}
func getImageSizeRatio(model string, size string) float64 {
ratio, ok := constant.ImageSizeRatios[model][size]
if !ok {
return 1
}
return ratio
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
// model validation
_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
if !hasValidSize {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
@@ -64,7 +80,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] {
if len(imageRequest.Prompt) > constant.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
@@ -81,10 +97,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
if !hasValidSize {
return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size)
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2

View File

@@ -20,12 +20,11 @@ import (
)
func isWithinRange(element string, value int) bool {
if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
if _, ok := constant.ImageGenerationAmounts[element]; !ok {
return false
}
min := constant.DalleGenerationImageAmounts[element][0]
max := constant.DalleGenerationImageAmounts[element][1]
min := constant.ImageGenerationAmounts[element][0]
max := constant.ImageGenerationAmounts[element][1]
return value >= min && value <= max
}
@@ -81,7 +80,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
}
jsonStr, err := json.Marshal(finalRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)