From 91decc36831def0a74f31cf3a92ecc178136f42c Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Wed, 3 May 2023 10:57:56 +0800 Subject: [PATCH] perf: model framwork --- src/constants/model.ts | 39 ++++++++++--------- src/pages/api/chat/chat.ts | 9 ++--- src/pages/api/openapi/chat/chat.ts | 9 ++--- src/pages/api/openapi/chat/chatGpt.ts | 7 +--- src/pages/api/openapi/chat/lafGpt.ts | 11 ++---- src/pages/api/openapi/chat/vectorGpt.ts | 13 ++++--- src/pages/chat/index.tsx | 4 +- .../model/detail/components/ModelEditForm.tsx | 9 ++--- src/pages/model/detail/index.tsx | 2 +- .../model/list/components/ModelPhoneList.tsx | 2 +- .../model/list/components/ModelTable.tsx | 2 +- src/pages/number/components/PayModal.tsx | 19 --------- src/service/events/generateQA.ts | 4 +- src/service/events/pushBill.ts | 13 +++---- src/service/models/model.ts | 4 +- src/service/tools/searchKb.ts | 14 ++++--- src/service/utils/tools.ts | 4 +- src/types/mongoSchema.d.ts | 4 +- src/utils/tools.ts | 6 +-- 19 files changed, 71 insertions(+), 104 deletions(-) diff --git a/src/constants/model.ts b/src/constants/model.ts index bf14d404a..a9bc019af 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -2,33 +2,34 @@ import type { ModelSchema } from '@/types/mongoSchema'; export const embeddingModel = 'text-embedding-ada-002'; -export enum ChatModelEnum { +export enum OpenAiChatEnum { 'GPT35' = 'gpt-3.5-turbo', 'GPT4' = 'gpt-4', 'GPT432k' = 'gpt-4-32k' } + +export type ChatModelType = `${OpenAiChatEnum}`; + export const ChatModelMap = { - // ui name - [ChatModelEnum.GPT35]: 'ChatGpt', - [ChatModelEnum.GPT4]: 'Gpt4', - [ChatModelEnum.GPT432k]: 'Gpt4-32k' -}; - -export type ChatModelConstantType = { - chatModel: `${ChatModelEnum}`; - contextMaxToken: number; - maxTemperature: number; - price: number; // 多少钱 / 1token,单位: 0.00001元 -}; - -export const modelList: ChatModelConstantType[] = [ - { - chatModel: ChatModelEnum.GPT35, + [OpenAiChatEnum.GPT35]: { + name: 'ChatGpt', contextMaxToken: 4096, maxTemperature: 1.5, price: 3 + }, + [OpenAiChatEnum.GPT4]: { + name: 'Gpt4', + contextMaxToken: 8000, + maxTemperature: 1.5, + price: 30 + }, + [OpenAiChatEnum.GPT432k]: { + name: 'Gpt4-32k', + contextMaxToken: 8000, + maxTemperature: 1.5, + price: 30 } -]; +}; export enum ModelStatusEnum { running = 'running', @@ -106,7 +107,7 @@ export const defaultModel: ModelSchema = { searchMode: ModelVectorSearchModeEnum.hightSimilarity, systemPrompt: '', temperature: 0, - chatModel: ChatModelEnum.GPT35 + chatModel: OpenAiChatEnum.GPT35 }, share: { isShare: false, diff --git a/src/pages/api/chat/chat.ts b/src/pages/api/chat/chat.ts index 5cfb8e44d..f70cfe2e3 100644 --- a/src/pages/api/chat/chat.ts +++ b/src/pages/api/chat/chat.ts @@ -5,7 +5,7 @@ import { axiosConfig, openaiChatFilter } from '@/service/utils/tools'; import { ChatItemSimpleType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; -import { modelList, ModelVectorSearchModeMap } from '@/constants/model'; +import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; import { gpt35StreamResponse } from '@/service/utils/openai'; import { searchKb_openai } from '@/service/tools/searchKb'; @@ -47,10 +47,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) authorization }); - const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); - if (!modelConstantsData) { - throw new Error('模型加载异常'); - } + const modelConstantsData = ChatModelMap[model.chat.chatModel]; // 读取对话内容 const prompts = [...content, prompt]; @@ -61,7 +58,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) apiKey: userApiKey || systemKey, isPay: !userApiKey, text: prompt.value, - similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22, + similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity, model, userId }); diff --git a/src/pages/api/openapi/chat/chat.ts b/src/pages/api/openapi/chat/chat.ts index 4c17c8434..21f68c3f2 100644 --- a/src/pages/api/openapi/chat/chat.ts +++ b/src/pages/api/openapi/chat/chat.ts @@ -1,11 +1,11 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase } from '@/service/mongo'; import { getOpenAIApi, authOpenApiKey, authModel } from '@/service/utils/auth'; -import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; +import { axiosConfig, openaiChatFilter } from '@/service/utils/tools'; import { ChatItemSimpleType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; -import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; +import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; import { gpt35StreamResponse } from '@/service/utils/openai'; import { searchKb_openai } from '@/service/tools/searchKb'; @@ -58,10 +58,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) modelId }); - const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); - if (!modelConstantsData) { - throw new Error('模型加载异常'); - } + const modelConstantsData = ChatModelMap[model.chat.chatModel]; // 使用了知识库搜索 if (model.chat.useKb) { diff --git a/src/pages/api/openapi/chat/chatGpt.ts b/src/pages/api/openapi/chat/chatGpt.ts index dc338957c..bef5e9e92 100644 --- a/src/pages/api/openapi/chat/chatGpt.ts +++ b/src/pages/api/openapi/chat/chatGpt.ts @@ -5,7 +5,7 @@ import { axiosConfig, openaiChatFilter } from '@/service/utils/tools'; import { ChatItemSimpleType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; -import { modelList } from '@/constants/model'; +import { ChatModelMap } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; import { gpt35StreamResponse } from '@/service/utils/openai'; @@ -60,10 +60,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) throw new Error('无权使用该模型'); } - const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); - if (!modelConstantsData) { - throw new Error('模型加载异常'); - } + const modelConstantsData = ChatModelMap[model.chat.chatModel]; // 如果有系统提示词,自动插入 if (model.chat.systemPrompt) { diff --git a/src/pages/api/openapi/chat/lafGpt.ts b/src/pages/api/openapi/chat/lafGpt.ts index 711b2e7fc..8d6cfd5a6 100644 --- a/src/pages/api/openapi/chat/lafGpt.ts +++ b/src/pages/api/openapi/chat/lafGpt.ts @@ -1,11 +1,11 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase, Model } from '@/service/mongo'; import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth'; -import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; +import { axiosConfig, openaiChatFilter } from '@/service/utils/tools'; import { ChatItemSimpleType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; -import { modelList, ModelVectorSearchModeMap, ChatModelEnum } from '@/constants/model'; +import { ChatModelMap, ModelVectorSearchModeMap, OpenAiChatEnum } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; import { gpt35StreamResponse } from '@/service/utils/openai'; import { searchKb_openai } from '@/service/tools/searchKb'; @@ -53,10 +53,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) throw new Error('找不到模型'); } - const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); - if (!modelConstantsData) { - throw new Error('model is undefined'); - } + const modelConstantsData = ChatModelMap[model.chat.chatModel]; console.log('laf gpt start'); @@ -66,7 +63,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 请求一次 chatgpt 拆解需求 const promptResponse = await chatAPI.createChatCompletion( { - model: ChatModelEnum.GPT35, + model: OpenAiChatEnum.GPT35, temperature: 0, frequency_penalty: 0.5, // 越大,重复内容越少 presence_penalty: -0.5, // 越大,越容易出现新内容 diff --git a/src/pages/api/openapi/chat/vectorGpt.ts b/src/pages/api/openapi/chat/vectorGpt.ts index 20cdd4983..d6f40696d 100644 --- a/src/pages/api/openapi/chat/vectorGpt.ts +++ b/src/pages/api/openapi/chat/vectorGpt.ts @@ -1,11 +1,15 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase, Model } from '@/service/mongo'; -import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools'; +import { axiosConfig, openaiChatFilter } from '@/service/utils/tools'; import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth'; import { ChatItemSimpleType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; -import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; +import { + ChatModelMap, + ModelVectorSearchModeMap, + ModelVectorSearchModeEnum +} from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; import { gpt35StreamResponse } from '@/service/utils/openai'; import { searchKb_openai } from '@/service/tools/searchKb'; @@ -62,10 +66,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) throw new Error('无权使用该模型'); } - const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel); - if (!modelConstantsData) { - throw new Error('模型初始化异常'); - } + const modelConstantsData = ChatModelMap[model.chat.chatModel]; // 获取向量匹配到的提示词 const { code, searchPrompt } = await searchKb_openai({ diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index 53be9cf17..cdf7b14e1 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -27,7 +27,7 @@ import { import { useToast } from '@/hooks/useToast'; import { useScreen } from '@/hooks/useScreen'; import { useQuery } from '@tanstack/react-query'; -import { ChatModelEnum } from '@/constants/model'; +import { OpenAiChatEnum } from '@/constants/model'; import dynamic from 'next/dynamic'; import { useGlobalStore } from '@/store/global'; import { useCopyData } from '@/utils/tools'; @@ -69,7 +69,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { name: '', avatar: '/icon/logo.png', intro: '', - chatModel: ChatModelEnum.GPT35, + chatModel: OpenAiChatEnum.GPT35, history: [] }); // 聊天框整体数据 diff --git a/src/pages/model/detail/components/ModelEditForm.tsx b/src/pages/model/detail/components/ModelEditForm.tsx index 4c17f8f57..72a397cd2 100644 --- a/src/pages/model/detail/components/ModelEditForm.tsx +++ b/src/pages/model/detail/components/ModelEditForm.tsx @@ -21,7 +21,7 @@ import { import { QuestionOutlineIcon } from '@chakra-ui/icons'; import type { ModelSchema } from '@/types/mongoSchema'; import { UseFormReturn } from 'react-hook-form'; -import { ChatModelMap, modelList, ModelVectorSearchModeMap } from '@/constants/model'; +import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model'; import { formatPrice } from '@/utils/user'; import { useConfirm } from '@/hooks/useConfirm'; import { useSelectFile } from '@/hooks/useSelectFile'; @@ -110,17 +110,14 @@ const ModelEditForm = ({ 对话模型: - {ChatModelMap[getValues('chat.chatModel')]} + {ChatModelMap[getValues('chat.chatModel')].name} 价格: - {formatPrice( - modelList.find((item) => item.chatModel === getValues('chat.chatModel'))?.price || 0, - 1000 - )} + {formatPrice(ChatModelMap[getValues('chat.chatModel')].price, 1000)} 元/1K tokens(包括上下文和回答) diff --git a/src/pages/model/detail/index.tsx b/src/pages/model/detail/index.tsx index a4a65e1a2..437571676 100644 --- a/src/pages/model/detail/index.tsx +++ b/src/pages/model/detail/index.tsx @@ -5,7 +5,7 @@ import type { ModelSchema } from '@/types/mongoSchema'; import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react'; import { useToast } from '@/hooks/useToast'; import { useForm } from 'react-hook-form'; -import { formatModelStatus, modelList, defaultModel } from '@/constants/model'; +import { formatModelStatus, defaultModel } from '@/constants/model'; import { useGlobalStore } from '@/store/global'; import { useScreen } from '@/hooks/useScreen'; import { useQuery } from '@tanstack/react-query'; diff --git a/src/pages/model/list/components/ModelPhoneList.tsx b/src/pages/model/list/components/ModelPhoneList.tsx index b6d4f1fa0..0190db08a 100644 --- a/src/pages/model/list/components/ModelPhoneList.tsx +++ b/src/pages/model/list/components/ModelPhoneList.tsx @@ -43,7 +43,7 @@ const ModelPhoneList = ({ 对话模型: - {ChatModelMap[model.chat.chatModel]} + {ChatModelMap[model.chat.chatModel].name} 模型温度: diff --git a/src/pages/model/list/components/ModelTable.tsx b/src/pages/model/list/components/ModelTable.tsx index de3e5a6a9..c73909fd6 100644 --- a/src/pages/model/list/components/ModelTable.tsx +++ b/src/pages/model/list/components/ModelTable.tsx @@ -36,7 +36,7 @@ const ModelTable = ({ key: 'service', render: (model: ModelSchema) => ( - {ChatModelMap[model.chat.chatModel]} + {ChatModelMap[model.chat.chatModel].name} ) }, diff --git a/src/pages/number/components/PayModal.tsx b/src/pages/number/components/PayModal.tsx index af1ce457b..e8e0c33ca 100644 --- a/src/pages/number/components/PayModal.tsx +++ b/src/pages/number/components/PayModal.tsx @@ -85,25 +85,6 @@ const PayModal = ({ onClose }: { onClose: () => void }) => { {!payId && ( <> - {/* 价格表 */} - {/* - - - - 模型类型 - 价格(元/1K tokens,包含所有上下文) - - - - {modelList.map((item, i) => ( - - {item.name} - {formatPrice(item.price, 1000)} - - ))} - - - */} {[5, 10, 20, 50].map((item) => ( item.chatModel === chatModel); // 计算价格 - const unitPrice = modelItem?.price || 5; + const unitPrice = ChatModelMap[chatModel]?.price || 5; const price = unitPrice * tokens; try { @@ -88,8 +86,7 @@ export const pushSplitDataBill = async ({ if (isPay) { try { // 获取模型单价格, 都是用 gpt35 拆分 - const modelItem = modelList.find((item) => item.chatModel === ChatModelEnum.GPT35); - const unitPrice = modelItem?.price || 3; + const unitPrice = ChatModelMap[OpenAiChatEnum.GPT35]?.price || 3; // 计算价格 const price = unitPrice * tokenLen; @@ -97,7 +94,7 @@ export const pushSplitDataBill = async ({ const res = await Bill.create({ userId, type, - modelName: ChatModelEnum.GPT35, + modelName: OpenAiChatEnum.GPT35, textLen: text.length, tokenLen, price diff --git a/src/service/models/model.ts b/src/service/models/model.ts index ec6b56868..dee8c384a 100644 --- a/src/service/models/model.ts +++ b/src/service/models/model.ts @@ -4,7 +4,7 @@ import { ModelVectorSearchModeMap, ModelVectorSearchModeEnum, ChatModelMap, - ChatModelEnum + OpenAiChatEnum } from '@/constants/model'; const ModelSchema = new Schema({ @@ -57,7 +57,7 @@ const ModelSchema = new Schema({ // 聊天时使用的模型 type: String, enum: Object.keys(ChatModelMap), - default: ChatModelEnum.GPT35 + default: OpenAiChatEnum.GPT35 } }, share: { diff --git a/src/service/tools/searchKb.ts b/src/service/tools/searchKb.ts index 464fb1642..b283dd78d 100644 --- a/src/service/tools/searchKb.ts +++ b/src/service/tools/searchKb.ts @@ -1,6 +1,6 @@ import { openaiCreateEmbedding } from '../utils/openai'; import { PgClient } from '@/service/pg'; -import { ModelDataStatusEnum, ModelVectorSearchModeEnum } from '@/constants/model'; +import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/constants/model'; import { ModelSchema } from '@/types/mongoSchema'; import { systemPromptFilter } from '../utils/tools'; @@ -9,9 +9,9 @@ import { systemPromptFilter } from '../utils/tools'; */ export const searchKb_openai = async ({ apiKey, - isPay, + isPay = true, text, - similarity, + similarity = 0.2, model, userId }: { @@ -20,7 +20,7 @@ export const searchKb_openai = async ({ text: string; model: ModelSchema; userId: string; - similarity: number; + similarity?: number; }): Promise<{ code: 200 | 201; searchPrompt?: { @@ -28,6 +28,8 @@ export const searchKb_openai = async ({ value: string; }; }> => { + const modelConstantsData = ChatModelMap[model.chat.chatModel]; + // 获取提示词的向量 const { vector: promptVector } = await openaiCreateEmbedding({ isPay, @@ -78,11 +80,11 @@ export const searchKb_openai = async ({ } // 有匹配情况下,system 添加知识库内容。 - // 系统提示词过滤,最多 2500 tokens + // 系统提示词过滤,最多 65% tokens const filterSystemPrompt = systemPromptFilter({ model: model.chat.chatModel, prompts: systemPrompts, - maxTokens: 2500 + maxTokens: Math.floor(modelConstantsData.contextMaxToken * 0.65) }); return { diff --git a/src/service/utils/tools.ts b/src/service/utils/tools.ts index e6b74dc71..c6c00254f 100644 --- a/src/service/utils/tools.ts +++ b/src/service/utils/tools.ts @@ -3,7 +3,7 @@ import jwt from 'jsonwebtoken'; import { ChatItemSimpleType } from '@/types/chat'; import { countChatTokens, sliceTextByToken } from '@/utils/tools'; import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai'; -import { ChatModelEnum } from '@/constants/model'; +import type { ChatModelType } from '@/constants/model'; /* 密码加密 */ export const hashPassword = (psw: string) => { @@ -44,7 +44,7 @@ export const openaiChatFilter = ({ prompts, maxTokens }: { - model: `${ChatModelEnum}`; + model: ChatModelType; prompts: ChatItemSimpleType[]; maxTokens: number; }) => { diff --git a/src/types/mongoSchema.d.ts b/src/types/mongoSchema.d.ts index e8e649be9..ed6a2ca99 100644 --- a/src/types/mongoSchema.d.ts +++ b/src/types/mongoSchema.d.ts @@ -3,7 +3,7 @@ import { ModelStatusEnum, ModelNameEnum, ModelVectorSearchModeEnum, - ChatModelEnum + ChatModelType } from '@/constants/model'; import type { DataType } from './data'; @@ -41,7 +41,7 @@ export interface ModelSchema { searchMode: `${ModelVectorSearchModeEnum}`; systemPrompt: string; temperature: number; - chatModel: `${ChatModelEnum}`; // 聊天时用的模型,训练后就是训练的模型 + chatModel: ChatModelType; // 聊天时用的模型,训练后就是训练的模型 }; share: { isShare: boolean; diff --git a/src/utils/tools.ts b/src/utils/tools.ts index 4b833745f..e14933914 100644 --- a/src/utils/tools.ts +++ b/src/utils/tools.ts @@ -2,7 +2,7 @@ import crypto from 'crypto'; import { useToast } from '@/hooks/useToast'; import { encoding_for_model, type Tiktoken } from '@dqbd/tiktoken'; import Graphemer from 'graphemer'; -import { ChatModelEnum } from '@/constants/model'; +import type { ChatModelType } from '@/constants/model'; const textDecoder = new TextDecoder(); const graphemer = new Graphemer(); @@ -130,7 +130,7 @@ export const countChatTokens = ({ model = 'gpt-3.5-turbo', messages }: { - model?: `${ChatModelEnum}`; + model?: ChatModelType; messages: { role: 'system' | 'user' | 'assistant'; content: string }[]; }) => { const text = getChatGPTEncodingText(messages, model); @@ -142,7 +142,7 @@ export const sliceTextByToken = ({ text, length }: { - model?: `${ChatModelEnum}`; + model?: ChatModelType; text: string; length: number; }) => {