From 2099a87908cf99ac20ee2ce6766458bbcf4314c2 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Wed, 29 Mar 2023 00:22:48 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=A8=A1=E5=9E=8B=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat: 模型数据导入 feat: redis 向量入库 feat: 向量索引 feat: 文件导入模型 perf: 交互 perf: prompt --- .env.template | 4 +- package.json | 1 + pnpm-lock.yaml | 95 +++++++ src/api/model.ts | 33 ++- src/components/Layout/index.tsx | 12 +- src/constants/model.ts | 26 +- src/constants/redis.ts | 1 + src/hooks/usePaging.ts | 7 +- src/pages/api/chat/chatGpt.ts | 2 +- src/pages/api/chat/vectorGpt.ts | 241 ++++++++++++++++++ src/pages/api/data/splitData.ts | 4 +- src/pages/api/model/create.ts | 17 +- ...elMoedlDataById.ts => delModelDataById.ts} | 8 +- src/pages/api/model/data/getModelData.ts | 11 +- ...pushModelData.ts => pushModelDataInput.ts} | 6 +- .../api/model/data/pushModelDataSelectData.ts | 57 +++++ src/pages/api/model/data/putModelData.ts | 12 +- src/pages/api/model/data/splitData.ts | 67 +++++ src/pages/api/model/del.ts | 41 ++- src/pages/api/model/update.ts | 2 +- src/pages/chat/index.tsx | 1 + src/pages/data/list.tsx | 4 +- .../detail/components/InputDataModal.tsx | 141 ++++++++++ .../model/detail/components/ModelDataCard.tsx | 202 +++++++++++++++ .../{ => detail}/components/ModelEditForm.tsx | 82 +++++- .../detail/components/SelectFileModal.tsx | 155 +++++++++++ .../{ => detail}/components/Training.tsx | 0 .../model/{detail.tsx => detail/index.tsx} | 163 ++++-------- .../{ => list}/components/CreateModel.tsx | 0 .../{ => list}/components/ModelPhoneList.tsx | 0 .../{ => list}/components/ModelTable.tsx | 0 src/pages/model/{list.tsx => list/index.tsx} | 0 src/service/events/generateQA.ts | 151 +++++------ src/service/events/generateVector.ts | 88 +++++++ src/service/events/pushBill.ts | 2 +- src/service/models/modelData.ts | 19 +- src/service/models/splitData.ts | 31 +++ src/service/mongo.ts | 5 +- src/service/redis.ts | 45 ++++ src/service/utils/tools.ts | 18 ++ src/types/index.d.ts | 3 + src/types/model.d.ts | 9 + src/types/mongoSchema.d.ts | 22 +- src/types/redis.d.ts | 6 + src/utils/tools.ts | 12 + 45 files changed, 1522 insertions(+), 284 deletions(-) create mode 100644 src/constants/redis.ts create mode 100644 src/pages/api/chat/vectorGpt.ts rename src/pages/api/model/data/{delMoedlDataById.ts => delModelDataById.ts} (88%) rename src/pages/api/model/data/{pushModelData.ts => pushModelDataInput.ts} (84%) create mode 100644 src/pages/api/model/data/pushModelDataSelectData.ts create mode 100644 src/pages/api/model/data/splitData.ts create mode 100644 src/pages/model/detail/components/InputDataModal.tsx create mode 100644 src/pages/model/detail/components/ModelDataCard.tsx rename src/pages/model/{ => detail}/components/ModelEditForm.tsx (72%) create mode 100644 src/pages/model/detail/components/SelectFileModal.tsx rename src/pages/model/{ => detail}/components/Training.tsx (100%) rename src/pages/model/{detail.tsx => detail/index.tsx} (64%) rename src/pages/model/{ => list}/components/CreateModel.tsx (100%) rename src/pages/model/{ => list}/components/ModelPhoneList.tsx (100%) rename src/pages/model/{ => list}/components/ModelTable.tsx (100%) rename src/pages/model/{list.tsx => list/index.tsx} (100%) create mode 100644 src/service/events/generateVector.ts create mode 100644 src/service/models/splitData.ts create mode 100644 src/service/redis.ts create mode 100644 src/types/redis.d.ts diff --git a/.env.template b/.env.template index 33c0a8203..0d1b43381 100644 --- a/.env.template +++ b/.env.template @@ -3,4 +3,6 @@ AXIOS_PROXY_PORT=33210 MONGODB_URI= MY_MAIL= MAILE_CODE= -TOKEN_KEY= \ No newline at end of file +TOKEN_KEY= +OPENAIKEY= +REDIS_URL= \ No newline at end of file diff --git a/package.json b/package.json index b7d61f55b..a041f7726 100644 --- a/package.json +++ b/package.json @@ -41,6 +41,7 @@ "react-hook-form": "^7.43.1", "react-markdown": "^8.0.5", "react-syntax-highlighter": "^15.5.0", + "redis": "^4.6.5", "rehype-katex": "^6.0.2", "remark-gfm": "^3.0.1", "remark-math": "^5.1.1", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a9109580f..affab94bf 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -47,6 +47,7 @@ specifiers: react-hook-form: ^7.43.1 react-markdown: ^8.0.5 react-syntax-highlighter: ^15.5.0 + redis: ^4.6.5 rehype-katex: ^6.0.2 remark-gfm: ^3.0.1 remark-math: ^5.1.1 @@ -87,6 +88,7 @@ dependencies: react-hook-form: registry.npmmirror.com/react-hook-form/7.43.1_react@18.2.0 react-markdown: registry.npmmirror.com/react-markdown/8.0.5_pmekkgnqduwlme35zpnqhenc34 react-syntax-highlighter: registry.npmmirror.com/react-syntax-highlighter/15.5.0_react@18.2.0 + redis: registry.npmmirror.com/redis/4.6.5 rehype-katex: registry.npmmirror.com/rehype-katex/6.0.2 remark-gfm: registry.npmmirror.com/remark-gfm/3.0.1 remark-math: registry.npmmirror.com/remark-math/5.1.1 @@ -4504,6 +4506,72 @@ packages: version: 2.11.6 dev: false + registry.npmmirror.com/@redis/bloom/1.2.0_@redis+client@1.5.6: + resolution: {integrity: sha512-HG2DFjYKbpNmVXsa0keLHp/3leGJz1mjh09f2RLGGLQZzSHpkmZWuwJbAvo3QcRY8p80m5+ZdXZdYOSBLlp7Cg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/bloom/-/bloom-1.2.0.tgz} + id: registry.npmmirror.com/@redis/bloom/1.2.0 + name: '@redis/bloom' + version: 1.2.0 + peerDependencies: + '@redis/client': ^1.0.0 + dependencies: + '@redis/client': registry.npmmirror.com/@redis/client/1.5.6 + dev: false + + registry.npmmirror.com/@redis/client/1.5.6: + resolution: {integrity: sha512-dFD1S6je+A47Lj22jN/upVU2fj4huR7S9APd7/ziUXsIXDL+11GPYti4Suv5y8FuXaN+0ZG4JF+y1houEJ7ToA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/client/-/client-1.5.6.tgz} + name: '@redis/client' + version: 1.5.6 + engines: {node: '>=14'} + dependencies: + cluster-key-slot: registry.npmmirror.com/cluster-key-slot/1.1.2 + generic-pool: registry.npmmirror.com/generic-pool/3.9.0 + yallist: registry.npmmirror.com/yallist/4.0.0 + dev: false + + registry.npmmirror.com/@redis/graph/1.1.0_@redis+client@1.5.6: + resolution: {integrity: sha512-16yZWngxyXPd+MJxeSr0dqh2AIOi8j9yXKcKCwVaKDbH3HTuETpDVPcLujhFYVPtYrngSco31BUcSa9TH31Gqg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/graph/-/graph-1.1.0.tgz} + id: registry.npmmirror.com/@redis/graph/1.1.0 + name: '@redis/graph' + version: 1.1.0 + peerDependencies: + '@redis/client': ^1.0.0 + dependencies: + '@redis/client': registry.npmmirror.com/@redis/client/1.5.6 + dev: false + + registry.npmmirror.com/@redis/json/1.0.4_@redis+client@1.5.6: + resolution: {integrity: sha512-LUZE2Gdrhg0Rx7AN+cZkb1e6HjoSKaeeW8rYnt89Tly13GBI5eP4CwDVr+MY8BAYfCg4/N15OUrtLoona9uSgw==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/json/-/json-1.0.4.tgz} + id: registry.npmmirror.com/@redis/json/1.0.4 + name: '@redis/json' + version: 1.0.4 + peerDependencies: + '@redis/client': ^1.0.0 + dependencies: + '@redis/client': registry.npmmirror.com/@redis/client/1.5.6 + dev: false + + registry.npmmirror.com/@redis/search/1.1.2_@redis+client@1.5.6: + resolution: {integrity: sha512-/cMfstG/fOh/SsE+4/BQGeuH/JJloeWuH+qJzM8dbxuWvdWibWAOAHHCZTMPhV3xIlH4/cUEIA8OV5QnYpaVoA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/search/-/search-1.1.2.tgz} + id: registry.npmmirror.com/@redis/search/1.1.2 + name: '@redis/search' + version: 1.1.2 + peerDependencies: + '@redis/client': ^1.0.0 + dependencies: + '@redis/client': registry.npmmirror.com/@redis/client/1.5.6 + dev: false + + registry.npmmirror.com/@redis/time-series/1.0.4_@redis+client@1.5.6: + resolution: {integrity: sha512-ThUIgo2U/g7cCuZavucQTQzA9g9JbDDY2f64u3AbAoz/8vE2lt2U37LamDUVChhaDA3IRT9R6VvJwqnUfTJzng==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/time-series/-/time-series-1.0.4.tgz} + id: registry.npmmirror.com/@redis/time-series/1.0.4 + name: '@redis/time-series' + version: 1.0.4 + peerDependencies: + '@redis/client': ^1.0.0 + dependencies: + '@redis/client': registry.npmmirror.com/@redis/client/1.5.6 + dev: false + registry.npmmirror.com/@rushstack/eslint-patch/1.2.0: resolution: {integrity: sha512-sXo/qW2/pAcmT43VoRKOJbDOfV3cYpq3szSVfIThQXNt+E4DfKj361vaAt3c88U5tPUxzEswam7GW48PJqtKAg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@rushstack/eslint-patch/-/eslint-patch-1.2.0.tgz} name: '@rushstack/eslint-patch' @@ -5562,6 +5630,13 @@ packages: version: 0.0.1 dev: false + registry.npmmirror.com/cluster-key-slot/1.1.2: + resolution: {integrity: sha512-RMr0FhtfXemyinomL4hrWcYJxmX6deFdCxpJzhDttxgO1+bcCnkk+9drydLVDmAMG7NE6aN/fl4F7ucU/90gAA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/cluster-key-slot/-/cluster-key-slot-1.1.2.tgz} + name: cluster-key-slot + version: 1.1.2 + engines: {node: '>=0.10.0'} + dev: false + registry.npmmirror.com/color-convert/1.9.3: resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/color-convert/-/color-convert-1.9.3.tgz} name: color-convert @@ -6799,6 +6874,13 @@ packages: version: 1.2.3 dev: true + registry.npmmirror.com/generic-pool/3.9.0: + resolution: {integrity: sha512-hymDOu5B53XvN4QT9dBmZxPX4CWhBPPLguTZ9MMFeFa/Kg0xWVfylOVNlJji/E7yTZWFd/q9GO5TxDLq156D7g==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/generic-pool/-/generic-pool-3.9.0.tgz} + name: generic-pool + version: 3.9.0 + engines: {node: '>= 4'} + dev: false + registry.npmmirror.com/gensync/1.0.0-beta.2: resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/gensync/-/gensync-1.0.0-beta.2.tgz} name: gensync @@ -9367,6 +9449,19 @@ packages: picomatch: registry.npmmirror.com/picomatch/2.3.1 dev: false + registry.npmmirror.com/redis/4.6.5: + resolution: {integrity: sha512-O0OWA36gDQbswOdUuAhRL6mTZpHFN525HlgZgDaVNgCJIAZR3ya06NTESb0R+TUZ+BFaDpz6NnnVvoMx9meUFg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/redis/-/redis-4.6.5.tgz} + name: redis + version: 4.6.5 + dependencies: + '@redis/bloom': registry.npmmirror.com/@redis/bloom/1.2.0_@redis+client@1.5.6 + '@redis/client': registry.npmmirror.com/@redis/client/1.5.6 + '@redis/graph': registry.npmmirror.com/@redis/graph/1.1.0_@redis+client@1.5.6 + '@redis/json': registry.npmmirror.com/@redis/json/1.0.4_@redis+client@1.5.6 + '@redis/search': registry.npmmirror.com/@redis/search/1.1.2_@redis+client@1.5.6 + '@redis/time-series': registry.npmmirror.com/@redis/time-series/1.0.4_@redis+client@1.5.6 + dev: false + registry.npmmirror.com/refractor/3.6.0: resolution: {integrity: sha512-MY9W41IOWxxk31o+YvFCNyNzdkc9M20NoZK5vq6jkv4I/uh2zkWcfudj0Q1fovjUQJrNewS9NMzeTtqPf+n5EA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/refractor/-/refractor-3.6.0.tgz} name: refractor diff --git a/src/api/model.ts b/src/api/model.ts index 8381667fe..0d5cdc1e1 100644 --- a/src/api/model.ts +++ b/src/api/model.ts @@ -1,7 +1,10 @@ import { GET, POST, DELETE, PUT } from './request'; -import type { ModelSchema } from '@/types/mongoSchema'; +import type { ModelSchema, ModelDataSchema } from '@/types/mongoSchema'; import { ModelUpdateParams } from '@/types/model'; import { TrainingItemType } from '../types/training'; +import { PagingData } from '@/types'; +import { RequestPaging } from '../types/index'; +import { Obj2Query } from '@/utils/tools'; export const getMyModels = () => GET('/model/list'); @@ -16,13 +19,35 @@ export const putModelById = (id: string, data: ModelUpdateParams) => PUT(`/model/update?modelId=${id}`, data); export const postTrainModel = (id: string, form: FormData) => - POST(`/model/train?modelId=${id}`, form, { + POST(`/model/train/train?modelId=${id}`, form, { headers: { 'content-type': 'multipart/form-data' } }); -export const putModelTrainingStatus = (id: string) => PUT(`/model/putTrainStatus?modelId=${id}`); +export const putModelTrainingStatus = (id: string) => + PUT(`/model/train/putTrainStatus?modelId=${id}`); export const getModelTrainings = (id: string) => - GET(`/model/getTrainings?modelId=${id}`); + GET(`/model/train/getTrainings?modelId=${id}`); + +/* 模型 data */ + +type GetModelDataListProps = RequestPaging & { + modelId: string; +}; +export const getModelDataList = (props: GetModelDataListProps) => + GET(`/model/data/getModelData?${Obj2Query(props)}`); + +export const postModelDataInput = (data: { + modelId: string; + data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[]; +}) => POST(`/model/data/pushModelDataInput`, data); + +export const postModelDataFileText = (modelId: string, text: string) => + POST(`/model/data/splitData`, { modelId, text }); + +export const putModelDataById = (data: { dataId: string; text: string }) => + PUT('/model/data/putModelData', data); +export const delOneModelData = (dataId: string) => + DELETE(`/model/data/delModelDataById?dataId=${dataId}`); diff --git a/src/components/Layout/index.tsx b/src/components/Layout/index.tsx index ee448021f..358ee25db 100644 --- a/src/components/Layout/index.tsx +++ b/src/components/Layout/index.tsx @@ -26,12 +26,12 @@ const navbarList = [ link: '/model/list', activeLink: ['/model/list', '/model/detail'] }, - { - label: '数据', - icon: 'icon-datafull', - link: '/data/list', - activeLink: ['/data/list', '/data/detail'] - }, + // { + // label: '数据', + // icon: 'icon-datafull', + // link: '/data/list', + // activeLink: ['/data/list', '/data/detail'] + // }, { label: '账号', icon: 'icon-yonghu-yuan', diff --git a/src/constants/model.ts b/src/constants/model.ts index ef1a4800d..784908bcd 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -1,11 +1,17 @@ -import type { ServiceName } from '@/types/mongoSchema'; -import { ModelSchema } from '../types/mongoSchema'; +import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema'; export enum ChatModelNameEnum { GPT35 = 'gpt-3.5-turbo', + VECTOR_GPT = 'VECTOR_GPT', GPT3 = 'text-davinci-003' } +export const ChatModelNameMap = { + [ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo', + [ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo', + [ChatModelNameEnum.GPT3]: 'text-davinci-003' +}; + export type ModelConstantsData = { serviceCompany: `${ServiceName}`; name: string; @@ -29,6 +35,17 @@ export const modelList: ModelConstantsData[] = [ trainedMaxToken: 2000, maxTemperature: 2, price: 3 + }, + { + serviceCompany: 'openai', + name: '知识库', + model: ChatModelNameEnum.VECTOR_GPT, + trainName: 'vector', + maxToken: 4000, + contextMaxToken: 7500, + trainedMaxToken: 2000, + maxTemperature: 1, + price: 3 } // { // serviceCompany: 'openai', @@ -76,6 +93,11 @@ export const formatModelStatus = { } }; +export const ModelDataStatusMap = { + 0: '训练完成', + 1: '训练中' +}; + export const defaultModel: ModelSchema = { _id: '', userId: '', diff --git a/src/constants/redis.ts b/src/constants/redis.ts new file mode 100644 index 000000000..9b0edc618 --- /dev/null +++ b/src/constants/redis.ts @@ -0,0 +1 @@ +export const VecModelDataIndex = 'model:data'; diff --git a/src/hooks/usePaging.ts b/src/hooks/usePaging.ts index 2f8b88c85..1bf8cdfe3 100644 --- a/src/hooks/usePaging.ts +++ b/src/hooks/usePaging.ts @@ -8,7 +8,7 @@ export const usePaging = ({ pageSize = 10, params = {} }: { - api: (data: any) => Promise>; + api: (data: any) => any; pageSize?: number; params?: Record; }) => { @@ -30,7 +30,7 @@ export const usePaging = ({ setRequesting(true); try { - const res = await api({ + const res: PagingData = await api({ pageNum: num, pageSize, ...params @@ -75,6 +75,7 @@ export const usePaging = ({ requesting, isLoadAll, nextPage, - initRequesting + initRequesting, + setData }; }; diff --git a/src/pages/api/chat/chatGpt.ts b/src/pages/api/chat/chatGpt.ts index 8a877bbaf..c42f95411 100644 --- a/src/pages/api/chat/chatGpt.ts +++ b/src/pages/api/chat/chatGpt.ts @@ -46,7 +46,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const model: ModelSchema = chat.modelId; const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); if (!modelConstantsData) { - throw new Error('模型异常,请用 chatgpt 模型'); + throw new Error('模型加载异常'); } // 读取对话内容 diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts new file mode 100644 index 000000000..9f0a14d0f --- /dev/null +++ b/src/pages/api/chat/vectorGpt.ts @@ -0,0 +1,241 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; +import { connectToDatabase, ModelData } from '@/service/mongo'; +import { getOpenAIApi, authChat } from '@/service/utils/chat'; +import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; +import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; +import { ChatItemType } from '@/types/chat'; +import { jsonRes } from '@/service/response'; +import type { ModelSchema } from '@/types/mongoSchema'; +import { PassThrough } from 'stream'; +import { modelList } from '@/constants/model'; +import { pushChatBill } from '@/service/events/pushBill'; +import { connectRedis } from '@/service/redis'; +import { VecModelDataIndex } from '@/constants/redis'; +import { vectorToBuffer } from '@/utils/tools'; + +let vectorData = [ + -0.025028639, -0.010407282, 0.026523087, -0.0107438695, -0.006967359, 0.010043768, -0.012043097, + 0.008724345, -0.028919589, -0.0117738275, 0.0050690062, 0.02961969 +].concat(new Array(1524).fill(0)); + +/* 发送提示词 */ +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + let step = 0; // step=1时,表示开始了流响应 + const stream = new PassThrough(); + stream.on('error', () => { + console.log('error: ', 'stream error'); + stream.destroy(); + }); + res.on('close', () => { + stream.destroy(); + }); + res.on('error', () => { + console.log('error: ', 'request error'); + stream.destroy(); + }); + + try { + const { chatId, prompt } = req.body as { + prompt: ChatItemType; + chatId: string; + }; + + const { authorization } = req.headers; + if (!chatId || !prompt) { + throw new Error('缺少参数'); + } + + await connectToDatabase(); + const redis = await connectRedis(); + + const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization); + + const model: ModelSchema = chat.modelId; + const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); + if (!modelConstantsData) { + throw new Error('模型加载异常'); + } + + // 读取对话内容 + const prompts = [...chat.content, prompt]; + + // 获取 chatAPI + const chatAPI = getOpenAIApi(userApiKey || systemKey); + + // 把输入的内容转成向量 + const promptVector = await chatAPI + .createEmbedding( + { + model: 'text-embedding-ada-002', + input: prompt.value + }, + { + timeout: 120000, + httpsAgent + } + ) + .then((res) => res?.data?.data?.[0]?.embedding || []); + + const binary = vectorToBuffer(promptVector); + + // 搜索系统提示词, 按相似度从 redis 中搜出前3条不同 dataId 的数据 + const redisData: any[] = await redis.sendCommand([ + 'FT.SEARCH', + `idx:${VecModelDataIndex}`, + `@modelId:{${String(chat.modelId._id)}} @vector:[VECTOR_RANGE 0.2 $blob]`, + // `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`, + 'RETURN', + '1', + 'dataId', + // 'SORTBY', + // 'score', + 'PARAMS', + '2', + 'blob', + binary, + 'DIALECT', + '2' + ]); + + // 格式化响应值,获取去重后的id + let formatIds = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20] + .map((i) => { + if (!redisData[i] || !redisData[i][1]) return ''; + return redisData[i][1]; + }) + .filter((item) => item); + formatIds = Array.from(new Set(formatIds)); + + if (formatIds.length === 0) { + throw new Error('对不起,我没有找到你的问题'); + } + + // 从 mongo 中取出原文作为提示词 + const textArr = ( + await Promise.all( + [2, 4, 6, 8, 10, 12, 14, 16, 18, 20].map((i) => { + if (!redisData[i] || !redisData[i][1]) return ''; + return ModelData.findById(redisData[i][1]) + .select('text') + .then((res) => res?.text || ''); + }) + ) + ).filter((item) => item); + + // textArr 筛选,最多 3000 tokens + const systemPrompt = systemPromptFilter(textArr, 2800); + + prompts.unshift({ + obj: 'SYSTEM', + value: `请根据下面的知识回答问题: ${systemPrompt}` + }); + + // 控制在 tokens 数量,防止超出 + const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken); + + // 格式化文本内容成 chatgpt 格式 + const map = { + Human: ChatCompletionRequestMessageRoleEnum.User, + AI: ChatCompletionRequestMessageRoleEnum.Assistant, + SYSTEM: ChatCompletionRequestMessageRoleEnum.System + }; + const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map( + (item: ChatItemType) => ({ + role: map[item.obj], + content: item.value + }) + ); + // console.log(formatPrompts); + // 计算温度 + const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); + + let startTime = Date.now(); + // 发出请求 + const chatResponse = await chatAPI.createChatCompletion( + { + model: model.service.chatModel, + temperature: temperature, + // max_tokens: modelConstantsData.maxToken, + messages: formatPrompts, + frequency_penalty: 0.5, // 越大,重复内容越少 + presence_penalty: -0.5, // 越大,越容易出现新内容 + stream: true + }, + { + timeout: 40000, + responseType: 'stream', + httpsAgent + } + ); + + console.log('api response time:', `${(Date.now() - startTime) / 1000}s`); + + // 创建响应流 + res.setHeader('Content-Type', 'text/event-stream;charset-utf-8'); + res.setHeader('Access-Control-Allow-Origin', '*'); + res.setHeader('X-Accel-Buffering', 'no'); + res.setHeader('Cache-Control', 'no-cache, no-transform'); + step = 1; + + let responseContent = ''; + stream.pipe(res); + + const onParse = async (event: ParsedEvent | ReconnectInterval) => { + if (event.type !== 'event') return; + const data = event.data; + if (data === '[DONE]') return; + try { + const json = JSON.parse(data); + const content: string = json?.choices?.[0].delta.content || ''; + if (!content || (responseContent === '' && content === '\n')) return; + + responseContent += content; + // console.log('content:', content) + !stream.destroyed && stream.push(content.replace(/\n/g, '
')); + } catch (error) { + error; + } + }; + + const decoder = new TextDecoder(); + try { + for await (const chunk of chatResponse.data as any) { + if (stream.destroyed) { + // 流被中断了,直接忽略后面的内容 + break; + } + const parser = createParser(onParse); + parser.feed(decoder.decode(chunk)); + } + } catch (error) { + console.log('pipe error', error); + } + // close stream + !stream.destroyed && stream.push(null); + stream.destroy(); + + const promptsContent = formatPrompts.map((item) => item.content).join(''); + // 只有使用平台的 key 才计费 + pushChatBill({ + isPay: !userApiKey, + modelName: model.service.modelName, + userId, + chatId, + text: promptsContent + responseContent + }); + // jsonRes(res); + } catch (err: any) { + if (step === 1) { + // 直接结束流 + console.log('error,结束'); + stream.destroy(); + } else { + res.status(500); + jsonRes(res, { + code: 500, + error: err + }); + } + } +} diff --git a/src/pages/api/data/splitData.ts b/src/pages/api/data/splitData.ts index 16bded265..13143ae6e 100644 --- a/src/pages/api/data/splitData.ts +++ b/src/pages/api/data/splitData.ts @@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) if (!DataRecord) { throw new Error('找不到数据集'); } - const replaceText = text.replace(/[\r\n\\n]+/g, ' '); + const replaceText = text.replace(/[\\n]+/g, ' '); // 文本拆分成 chunk let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || []; @@ -35,7 +35,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) chunks.forEach((chunk) => { splitText += chunk; const tokens = encode(splitText).length; - if (tokens >= 980) { + if (tokens >= 780) { dataItems.push({ userId, dataId, diff --git a/src/pages/api/model/create.ts b/src/pages/api/model/create.ts index 259c46bea..b7def6ae5 100644 --- a/src/pages/api/model/create.ts +++ b/src/pages/api/model/create.ts @@ -3,7 +3,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; -import { ModelStatusEnum, modelList, ChatModelNameEnum } from '@/constants/model'; +import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model'; import { Model } from '@/service/models/model'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -33,15 +33,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); - // 重名校验 - const authRepeatName = await Model.findOne({ - name, - userId - }); - if (authRepeatName) { - throw new Error('模型名重复'); - } - // 上限校验 const authCount = await Model.countDocuments({ userId @@ -57,9 +48,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< status: ModelStatusEnum.running, service: { company: modelItem.serviceCompany, - trainId: modelItem.trainName, - chatModel: modelItem.model, - modelName: modelItem.model + trainId: '', + chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型 + modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作 } }); diff --git a/src/pages/api/model/data/delMoedlDataById.ts b/src/pages/api/model/data/delModelDataById.ts similarity index 88% rename from src/pages/api/model/data/delMoedlDataById.ts rename to src/pages/api/model/data/delModelDataById.ts index 13c8ea07a..959a4baf9 100644 --- a/src/pages/api/model/data/delMoedlDataById.ts +++ b/src/pages/api/model/data/delModelDataById.ts @@ -5,8 +5,8 @@ import { authToken } from '@/service/utils/tools'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - let { modelId } = req.query as { - modelId: string; + let { dataId } = req.query as { + dataId: string; }; const { authorization } = req.headers; @@ -14,7 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< throw new Error('无权操作'); } - if (!modelId) { + if (!dataId) { throw new Error('缺少参数'); } @@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); await ModelData.deleteOne({ - modelId, + _id: dataId, userId }); diff --git a/src/pages/api/model/data/getModelData.ts b/src/pages/api/model/data/getModelData.ts index 3c41ec458..ec30dd9ed 100644 --- a/src/pages/api/model/data/getModelData.ts +++ b/src/pages/api/model/data/getModelData.ts @@ -14,6 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< pageNum: string; pageSize: string; }; + const { authorization } = req.headers; pageNum = +pageNum; @@ -41,7 +42,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< .limit(pageSize); jsonRes(res, { - data + data: { + pageNum, + pageSize, + data, + total: await ModelData.countDocuments({ + modelId, + userId + }) + } }); } catch (err) { jsonRes(res, { diff --git a/src/pages/api/model/data/pushModelData.ts b/src/pages/api/model/data/pushModelDataInput.ts similarity index 84% rename from src/pages/api/model/data/pushModelData.ts rename to src/pages/api/model/data/pushModelDataInput.ts index 426dfab69..9b0b5614e 100644 --- a/src/pages/api/model/data/pushModelData.ts +++ b/src/pages/api/model/data/pushModelDataInput.ts @@ -2,12 +2,14 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase, ModelData, Model } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; +import { ModelDataSchema } from '@/types/mongoSchema'; +import { generateVector } from '@/service/events/generateVector'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { const { modelId, data } = req.body as { modelId: string; - data: { q: string; a: string }[]; + data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[]; }; const { authorization } = req.headers; @@ -43,6 +45,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< })) ); + generateVector(true); + jsonRes(res, { data: model }); diff --git a/src/pages/api/model/data/pushModelDataSelectData.ts b/src/pages/api/model/data/pushModelDataSelectData.ts new file mode 100644 index 000000000..f7e01606e --- /dev/null +++ b/src/pages/api/model/data/pushModelDataSelectData.ts @@ -0,0 +1,57 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { connectToDatabase, DataItem, ModelData } from '@/service/mongo'; +import { authToken } from '@/service/utils/tools'; +import { customAlphabet } from 'nanoid'; +const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12); + +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + let { dataIds, modelId } = req.body as { dataIds: string[]; modelId: string }; + + if (!dataIds) { + throw new Error('参数错误'); + } + await connectToDatabase(); + + const { authorization } = req.headers; + + const userId = await authToken(authorization); + + const dataItems = ( + await Promise.all( + dataIds.map((dataId) => + DataItem.find<{ _id: string; result: { q: string }[]; text: string }>( + { + userId, + dataId + }, + 'result text' + ) + ) + ) + ).flat(); + + // push data + await ModelData.insertMany( + dataItems.map((item) => ({ + modelId: modelId, + userId, + text: item.text, + q: item.result.map((item) => ({ + id: nanoid(), + text: item.q + })) + })) + ); + + jsonRes(res, { + data: dataItems + }); + } catch (err) { + jsonRes(res, { + code: 500, + error: err + }); + } +} diff --git a/src/pages/api/model/data/putModelData.ts b/src/pages/api/model/data/putModelData.ts index 776865f40..2c13b8526 100644 --- a/src/pages/api/model/data/putModelData.ts +++ b/src/pages/api/model/data/putModelData.ts @@ -5,9 +5,9 @@ import { authToken } from '@/service/utils/tools'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - let { modelId, answer } = req.body as { - modelId: string; - answer: string; + let { dataId, text } = req.body as { + dataId: string; + text: string; }; const { authorization } = req.headers; @@ -15,7 +15,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< throw new Error('无权操作'); } - if (!modelId) { + if (!dataId) { throw new Error('缺少参数'); } @@ -26,11 +26,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await ModelData.updateOne( { - modelId, + _id: dataId, userId }, { - a: answer + text } ); diff --git a/src/pages/api/model/data/splitData.ts b/src/pages/api/model/data/splitData.ts new file mode 100644 index 000000000..379d952e1 --- /dev/null +++ b/src/pages/api/model/data/splitData.ts @@ -0,0 +1,67 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { connectToDatabase, SplitData, Model } from '@/service/mongo'; +import { authToken } from '@/service/utils/tools'; +import { generateQA } from '@/service/events/generateQA'; +import { encode } from 'gpt-token-utils'; + +/* 拆分数据成QA */ +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + const { text, modelId } = req.body as { text: string; modelId: string }; + if (!text || !modelId) { + throw new Error('参数错误'); + } + await connectToDatabase(); + + const { authorization } = req.headers; + + const userId = await authToken(authorization); + + // 验证是否是该用户的 model + const model = await Model.findOne({ + _id: modelId, + userId + }); + + if (!model) { + throw new Error('无权操作该模型'); + } + + const replaceText = text.replace(/(\\n|\n)+/g, ' '); + + // 文本拆分成 chunk + let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || []; + + const textList: string[] = []; + let splitText = ''; + + chunks.forEach((chunk) => { + splitText += chunk; + const tokens = encode(splitText).length; + if (tokens >= 980) { + textList.push(splitText); + splitText = ''; + } + }); + + // 批量插入数据 + await SplitData.create({ + userId, + modelId, + rawText: text, + textList + }); + + // generateQA(); + + jsonRes(res, { + data: { chunks, replaceText } + }); + } catch (err) { + jsonRes(res, { + code: 500, + error: err + }); + } +} diff --git a/src/pages/api/model/del.ts b/src/pages/api/model/del.ts index 18dc5b8fa..976ca96f6 100644 --- a/src/pages/api/model/del.ts +++ b/src/pages/api/model/del.ts @@ -1,6 +1,6 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { Chat, Model, Training, connectToDatabase } from '@/service/mongo'; +import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo'; import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; import { TrainingStatusEnum } from '@/constants/model'; import { getOpenAIApi } from '@/service/utils/chat'; @@ -26,16 +26,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); - // 删除模型 - await Model.deleteOne({ - _id: modelId, - userId - }); - + let requestQueue: any[] = []; // 删除对应的聊天 - await Chat.deleteMany({ - modelId - }); + requestQueue.push( + Chat.deleteMany({ + modelId + }) + ); + + // 删除数据集 + requestQueue.push( + ModelData.deleteMany({ + modelId + }) + ); // 查看是否正在训练 const training: TrainingItemType | null = await Training.findOne({ @@ -56,9 +60,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< } // 删除对应训练记录 - await Training.deleteMany({ - modelId - }); + requestQueue.push( + Training.deleteMany({ + modelId + }) + ); + + // 删除模型 + requestQueue.push( + Model.deleteOne({ + _id: modelId, + userId + }) + ); + await requestQueue; jsonRes(res); } catch (err) { diff --git a/src/pages/api/model/update.ts b/src/pages/api/model/update.ts index 1de24e856..af9d013c2 100644 --- a/src/pages/api/model/update.ts +++ b/src/pages/api/model/update.ts @@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< systemPrompt, intro, temperature, - service, + // service, security } ); diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index 7f95376b3..6eff04351 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -119,6 +119,7 @@ const Chat = ({ chatId }: { chatId: string }) => { async (prompts: ChatSiteItemType) => { const urlMap: Record = { [ChatModelNameEnum.GPT35]: '/api/chat/chatGpt', + [ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt', [ChatModelNameEnum.GPT3]: '/api/chat/gpt3' }; diff --git a/src/pages/data/list.tsx b/src/pages/data/list.tsx index e1dd0d07a..378d21634 100644 --- a/src/pages/data/list.tsx +++ b/src/pages/data/list.tsx @@ -184,7 +184,7 @@ const DataList = () => { > 导入 - + {/* 导出 @@ -200,7 +200,7 @@ const DataList = () => { )} - + */} + + {/* 介绍: