diff --git a/common/constants.go b/common/constants.go index ac901139..63c627bc 100644 --- a/common/constants.go +++ b/common/constants.go @@ -38,35 +38,37 @@ const ( ) const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeAPI2D = 2 - ChannelTypeAzure = 3 - ChannelTypeCloseAI = 4 - ChannelTypeOpenAISB = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 - ChannelTypeOpenRouter = 20 - ChannelTypeAIProxyLibrary = 21 - ChannelTypeFastGPT = 22 - ChannelTypeTencent = 23 - ChannelTypeGemini = 24 - ChannelTypeMoonshot = 25 - ChannelTypeBaichuan = 26 - ChannelTypeMinimax = 27 - ChannelTypeMistral = 28 + ChannelTypeUnknown = iota + ChannelTypeOpenAI + ChannelTypeAPI2D + ChannelTypeAzure + ChannelTypeCloseAI + ChannelTypeOpenAISB + ChannelTypeOpenAIMax + ChannelTypeOhMyGPT + ChannelTypeCustom + ChannelTypeAILS + ChannelTypeAIProxy + ChannelTypePaLM + ChannelTypeAPI2GPT + ChannelTypeAIGC2D + ChannelTypeAnthropic + ChannelTypeBaidu + ChannelTypeZhipu + ChannelTypeAli + ChannelTypeXunfei + ChannelType360 + ChannelTypeOpenRouter + ChannelTypeAIProxyLibrary + ChannelTypeFastGPT + ChannelTypeTencent + ChannelTypeGemini + ChannelTypeMoonshot + ChannelTypeBaichuan + ChannelTypeMinimax + ChannelTypeMistral + + ChannelTypeDummy ) var ChannelBaseURLs = []string{ diff --git a/controller/model.go b/controller/model.go index 0d0d2658..0486634c 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,6 +3,7 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/minimax" @@ -11,6 +12,8 @@ import ( "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "net/http" ) // https://platform.openai.com/docs/api-reference/models/list @@ -42,6 +45,7 @@ type OpenAIModels struct { var openAIModels []OpenAIModels var openAIModelsMap map[string]OpenAIModels +var channelId2Models map[int][]string func init() { var permission []OpenAIModelPermission @@ -138,6 +142,23 @@ func init() { for _, model := range openAIModels { openAIModelsMap[model.Id] = model } + channelId2Models = make(map[int][]string) + for i := 1; i < common.ChannelTypeDummy; i++ { + adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) + meta := &util.RelayMeta{ + ChannelType: i, + } + adaptor.Init(meta) + channelId2Models[i] = adaptor.GetModelList() + } +} + +func DashboardListModels(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channelId2Models, + }) } func ListModels(c *gin.Context) { diff --git a/router/api-router.go b/router/api-router.go index 6d143da7..dc1fdc6b 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.Use(middleware.GlobalAPIRateLimit()) { apiRouter.GET("/status", controller.GetStatus) + apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels) apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/about", controller.GetAbout) apiRouter.GET("/home_page_content", controller.GetHomePageContent) diff --git a/web/default/src/components/ChannelsTable.js b/web/default/src/components/ChannelsTable.js index 7117fe53..358b9262 100644 --- a/web/default/src/components/ChannelsTable.js +++ b/web/default/src/components/ChannelsTable.js @@ -1,7 +1,16 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; import { Link } from 'react-router-dom'; -import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; +import { + API, + loadChannelModels, + setPromptShown, + shouldShowPrompt, + showError, + showInfo, + showSuccess, + timestamp2string +} from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { renderGroup, renderNumber } from '../helpers/render'; @@ -95,6 +104,7 @@ const ChannelsTable = () => { .catch((reason) => { showError(reason); }); + loadChannelModels().then(); }, []); const manageChannel = async (id, action, idx, value) => { diff --git a/web/default/src/helpers/utils.js b/web/default/src/helpers/utils.js index 28ae4992..eb935843 100644 --- a/web/default/src/helpers/utils.js +++ b/web/default/src/helpers/utils.js @@ -1,11 +1,13 @@ import { toast } from 'react-toastify'; import { toastConstants } from '../constants'; import React from 'react'; +import { API } from './api'; const HTMLToastContent = ({ htmlContent }) => { return
; }; export default HTMLToastContent; + export function isAdmin() { let user = localStorage.getItem('user'); if (!user) return false; @@ -29,7 +31,7 @@ export function getSystemName() { export function getLogo() { let logo = localStorage.getItem('logo'); if (!logo) return '/logo.png'; - return logo + return logo; } export function getFooterHTML() { @@ -196,4 +198,30 @@ export function shouldShowPrompt(id) { export function setPromptShown(id) { localStorage.setItem(`prompt-${id}`, 'true'); +} + +let channelModels = undefined; +export async function loadChannelModels() { + const res = await API.get('/api/models'); + const { success, data } = res.data; + if (!success) { + return; + } + channelModels = data; + localStorage.setItem('channel_models', JSON.stringify(data)); +} + +export function getChannelModels(type) { + if (channelModels !== undefined && type in channelModels) { + return channelModels[type]; + } + let models = localStorage.getItem('channel_models'); + if (!models) { + return []; + } + channelModels = JSON.parse(models); + if (type in channelModels) { + return channelModels[type]; + } + return []; } \ No newline at end of file diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index 693242f9..59cce0d4 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react'; import { useNavigate, useParams } from 'react-router-dom'; -import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; +import { API, getChannelModels, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; import { CHANNEL_OPTIONS } from '../../constants'; const MODEL_MAPPING_EXAMPLE = { @@ -56,60 +56,12 @@ const EditChannel = () => { const [customModel, setCustomModel] = useState(''); const handleInputChange = (e, { name, value }) => { setInputs((inputs) => ({ ...inputs, [name]: value })); - if (name === 'type' && inputs.models.length === 0) { - let localModels = []; - switch (value) { - case 14: - localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1']; - break; - case 11: - localModels = ['PaLM-2']; - break; - case 15: - localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; - break; - case 17: - localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']; - let withInternetVersion = []; - for (let i = 0; i < localModels.length; i++) { - if (localModels[i].startsWith('qwen-')) { - withInternetVersion.push(localModels[i] + '-internet'); - } - } - localModels = [...localModels, ...withInternetVersion]; - break; - case 16: - localModels = ["glm-4", "glm-4v", "glm-3-turbo",'chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; - break; - case 18: - localModels = [ - 'SparkDesk', - 'SparkDesk-v1.1', - 'SparkDesk-v2.1', - 'SparkDesk-v3.1', - 'SparkDesk-v3.5' - ]; - break; - case 19: - localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; - break; - case 23: - localModels = ['hunyuan']; - break; - case 24: - localModels = ['gemini-pro', 'gemini-pro-vision']; - break; - case 25: - localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k']; - break; - case 26: - localModels = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding']; - break; - case 27: - localModels = ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat']; - break; + if (name === 'type') { + let localModels = getChannelModels(value); + if (inputs.models.length === 0) { + setInputs((inputs) => ({ ...inputs, models: localModels })); } - setInputs((inputs) => ({ ...inputs, models: localModels })); + setBasicModels(localModels); } };