mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-19 01:54:06 +00:00
feat: support replicate chat models (#1989)
* feat: add Replicate adaptor and integrate into channel and API types * feat: support llm chat on replicate
This commit is contained in:
222
relay/adaptor/replicate/image.go
Normal file
222
relay/adaptor/replicate/image.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"golang.org/x/image/webp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// ImagesEditsHandler just copy response body to client
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-fill-pro
|
||||
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
// c.Writer.WriteHeader(resp.StatusCode)
|
||||
// for k, v := range resp.Header {
|
||||
// c.Writer.Header().Set(k, v[0])
|
||||
// }
|
||||
|
||||
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||
// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
|
||||
// return nil, nil
|
||||
// }
|
||||
|
||||
var errNextLoop = errors.New("next_loop")
|
||||
|
||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
payload, _ := io.ReadAll(resp.Body)
|
||||
return openai.ErrorWrapper(
|
||||
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
|
||||
"bad_status_code", http.StatusInternalServerError),
|
||||
nil
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
respData := new(ImageResponse)
|
||||
if err = json.Unmarshal(respBody, respData); err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
for {
|
||||
err = func() error {
|
||||
// get task
|
||||
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||
http.MethodGet, respData.URLs.Get, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "new request")
|
||||
}
|
||||
|
||||
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
|
||||
taskResp, err := http.DefaultClient.Do(taskReq)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get task")
|
||||
}
|
||||
defer taskResp.Body.Close()
|
||||
|
||||
if taskResp.StatusCode != http.StatusOK {
|
||||
payload, _ := io.ReadAll(taskResp.Body)
|
||||
return errors.Errorf("bad status code [%d]%s",
|
||||
taskResp.StatusCode, string(payload))
|
||||
}
|
||||
|
||||
taskBody, err := io.ReadAll(taskResp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read task response")
|
||||
}
|
||||
|
||||
taskData := new(ImageResponse)
|
||||
if err = json.Unmarshal(taskBody, taskData); err != nil {
|
||||
return errors.Wrap(err, "decode task response")
|
||||
}
|
||||
|
||||
switch taskData.Status {
|
||||
case "succeeded":
|
||||
case "failed", "canceled":
|
||||
return errors.Errorf("task failed: %s", taskData.Status)
|
||||
default:
|
||||
time.Sleep(time.Second * 3)
|
||||
return errNextLoop
|
||||
}
|
||||
|
||||
output, err := taskData.GetOutput()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get output")
|
||||
}
|
||||
if len(output) == 0 {
|
||||
return errors.New("response output is empty")
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var pool errgroup.Group
|
||||
respBody := &openai.ImageResponse{
|
||||
Created: taskData.CompletedAt.Unix(),
|
||||
Data: []openai.ImageData{},
|
||||
}
|
||||
|
||||
for _, imgOut := range output {
|
||||
imgOut := imgOut
|
||||
pool.Go(func() error {
|
||||
// download image
|
||||
downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||
http.MethodGet, imgOut, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "new request")
|
||||
}
|
||||
|
||||
imgResp, err := http.DefaultClient.Do(downloadReq)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "download image")
|
||||
}
|
||||
defer imgResp.Body.Close()
|
||||
|
||||
if imgResp.StatusCode != http.StatusOK {
|
||||
payload, _ := io.ReadAll(imgResp.Body)
|
||||
return errors.Errorf("bad status code [%d]%s",
|
||||
imgResp.StatusCode, string(payload))
|
||||
}
|
||||
|
||||
imgData, err := io.ReadAll(imgResp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read image")
|
||||
}
|
||||
|
||||
imgData, err = ConvertImageToPNG(imgData)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "convert image")
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
respBody.Data = append(respBody.Data, openai.ImageData{
|
||||
B64Json: fmt.Sprintf("data:image/png;base64,%s",
|
||||
base64.StdEncoding.EncodeToString(imgData)),
|
||||
})
|
||||
mu.Unlock()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := pool.Wait(); err != nil {
|
||||
if len(respBody.Data) == 0 {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, respBody)
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
if errors.Is(err, errNextLoop) {
|
||||
continue
|
||||
}
|
||||
|
||||
return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ConvertImageToPNG converts a WebP image to PNG format
|
||||
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
|
||||
// bypass if it's already a PNG image
|
||||
if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
|
||||
return webpData, nil
|
||||
}
|
||||
|
||||
// check if is jpeg, convert to png
|
||||
if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
|
||||
img, _, err := image.Decode(bytes.NewReader(webpData))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode jpeg")
|
||||
}
|
||||
|
||||
var pngBuffer bytes.Buffer
|
||||
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||
return nil, errors.Wrap(err, "encode png")
|
||||
}
|
||||
|
||||
return pngBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// Decode the WebP image
|
||||
img, err := webp.Decode(bytes.NewReader(webpData))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode webp")
|
||||
}
|
||||
|
||||
// Encode the image as PNG
|
||||
var pngBuffer bytes.Buffer
|
||||
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||
return nil, errors.Wrap(err, "encode png")
|
||||
}
|
||||
|
||||
return pngBuffer.Bytes(), nil
|
||||
}
|
Reference in New Issue
Block a user