From ae4243b522858aa88dcdc784f7667890eb8a4bb9 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Sat, 1 Apr 2023 22:31:56 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E7=9F=A5=E8=AF=86=E5=BA=93=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 5 +- src/api/model.ts | 7 +- src/constants/model.ts | 7 +- src/constants/redis.ts | 7 +- src/pages/api/chat/vectorGpt.ts | 46 ++---- src/pages/api/model/data/delModelDataById.ts | 26 +-- src/pages/api/model/data/getModelData.ts | 39 +++-- .../api/model/data/pushModelDataInput.ts | 31 +++- src/pages/api/model/data/pushModelDataJson.ts | 78 +++++++++ .../api/model/data/pushModelDataSelectData.ts | 57 ------- src/pages/api/model/data/putModelData.ts | 21 ++- src/pages/api/model/del.ts | 78 ++++----- src/pages/api/timer/testVector.ts | 68 -------- src/pages/chat/index.tsx | 155 +++++++++--------- .../detail/components/InputDataModal.tsx | 34 ++-- .../model/detail/components/ModelDataCard.tsx | 68 +++++--- .../detail/components/SelectFileModal.tsx | 59 +++---- .../detail/components/SelectJsonModal.tsx | 145 ++++++++++++++++ src/service/events/generateQA.ts | 48 +++--- src/service/events/generateVector.ts | 93 +++++------ src/service/models/modelData.ts | 37 ----- src/service/mongo.ts | 1 - src/service/redis.ts | 4 +- src/types/mongoSchema.d.ts | 2 +- src/types/redis.d.ts | 7 +- src/utils/tools.ts | 6 + 26 files changed, 611 insertions(+), 518 deletions(-) create mode 100644 src/pages/api/model/data/pushModelDataJson.ts delete mode 100644 src/pages/api/model/data/pushModelDataSelectData.ts delete mode 100644 src/pages/api/timer/testVector.ts create mode 100644 src/pages/model/detail/components/SelectJsonModal.tsx delete mode 100644 src/service/models/modelData.ts diff --git a/README.md b/README.md index c2a271032..bafcb9ad8 100644 --- a/README.md +++ b/README.md @@ -107,5 +107,6 @@ echo "Restart clash" ```bash # 索引 # FT.CREATE idx:model:data ON JSON PREFIX 1 model:data: SCHEMA $.modelId AS modelId TAG $.dataId AS dataId TAG $.vector AS vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32 -FT.CREATE idx:model:data:hash ON HASH PREFIX 1 model:data: SCHEMA modelId TAG dataId TAG vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32 -``` \ No newline at end of file +# FT.CREATE idx:model:data:hash ON HASH PREFIX 1 model:data: SCHEMA modelId TAG dataId TAG vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32 +FT.CREATE idx:model:data ON HASH PREFIX 1 model:data: SCHEMA modelId TAG userId TAG q TEXT text TEXT vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32 +``` diff --git a/src/api/model.ts b/src/api/model.ts index 192bfe51c..ff11ce8f4 100644 --- a/src/api/model.ts +++ b/src/api/model.ts @@ -44,11 +44,16 @@ export const getModelSplitDataList = (modelId: string) => export const postModelDataInput = (data: { modelId: string; data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[]; -}) => POST(`/model/data/pushModelDataInput`, data); +}) => POST(`/model/data/pushModelDataInput`, data); export const postModelDataFileText = (modelId: string, text: string) => POST(`/model/data/splitData`, { modelId, text }); +export const postModelDataJsonData = ( + modelId: string, + jsonData: { prompt: string; completion: string; vector?: number[] }[] +) => POST(`/model/data/pushModelDataJson`, { modelId, data: jsonData }); + export const putModelDataById = (data: { dataId: string; text: string }) => PUT('/model/data/putModelData', data); export const delOneModelData = (dataId: string) => diff --git a/src/constants/model.ts b/src/constants/model.ts index 784908bcd..61fd8b9a4 100644 --- a/src/constants/model.ts +++ b/src/constants/model.ts @@ -1,4 +1,5 @@ import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema'; +import type { RedisModelDataItemType } from '@/types/redis'; export enum ChatModelNameEnum { GPT35 = 'gpt-3.5-turbo', @@ -93,9 +94,9 @@ export const formatModelStatus = { } }; -export const ModelDataStatusMap = { - 0: '训练完成', - 1: '训练中' +export const ModelDataStatusMap: Record = { + ready: '训练完成', + waiting: '训练中' }; export const defaultModel: ModelSchema = { diff --git a/src/constants/redis.ts b/src/constants/redis.ts index 9b0edc618..cb045a03c 100644 --- a/src/constants/redis.ts +++ b/src/constants/redis.ts @@ -1 +1,6 @@ -export const VecModelDataIndex = 'model:data'; +export const VecModelDataPrefix = 'model:data'; +export const VecModelDataIdx = `idx:${VecModelDataPrefix}:hash`; +export enum ModelDataStatusEnum { + ready = 'ready', + waiting = 'waiting' +} diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts index 1507aa13b..dca7a3c1e 100644 --- a/src/pages/api/chat/vectorGpt.ts +++ b/src/pages/api/chat/vectorGpt.ts @@ -1,6 +1,6 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; -import { connectToDatabase, ModelData } from '@/service/mongo'; +import { connectToDatabase } from '@/service/mongo'; import { getOpenAIApi, authChat } from '@/service/utils/chat'; import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; @@ -11,7 +11,7 @@ import { PassThrough } from 'stream'; import { modelList } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; import { connectRedis } from '@/service/redis'; -import { VecModelDataIndex } from '@/constants/redis'; +import { VecModelDataPrefix } from '@/constants/redis'; import { vectorToBuffer } from '@/utils/tools'; /* 发送提示词 */ @@ -73,17 +73,17 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) ) .then((res) => res?.data?.data?.[0]?.embedding || []); - // 搜索系统提示词, 按相似度从 redis 中搜出前3条不同 dataId 的数据 + // 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text const redisData: any[] = await redis.sendCommand([ 'FT.SEARCH', - `idx:${VecModelDataIndex}:hash`, + `idx:${VecModelDataPrefix}:hash`, `@modelId:{${String( chat.modelId._id )}} @vector:[VECTOR_RANGE 0.15 $blob]=>{$YIELD_DISTANCE_AS: score}`, // `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`, 'RETURN', '1', - 'dataId', + 'text', 'SORTBY', 'score', 'PARAMS', @@ -97,42 +97,28 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) '2' ]); - // 格式化响应值,获取去重后的id - let formatIds = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20] + // 格式化响应值,获取 qa + const formatRedisPrompt = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20] .map((i) => { - if (!redisData[i] || !redisData[i][1]) return ''; - return redisData[i][1]; + if (!redisData[i]) return ''; + const text = (redisData[i][1] as string) || ''; + + if (!text) return ''; + + return text; }) .filter((item) => item); - formatIds = Array.from(new Set(formatIds)); - if (formatIds.length === 0) { + if (formatRedisPrompt.length === 0) { throw new Error('对不起,我没有找到你的问题'); } - // 从 mongo 中取出原文作为提示词 - const textArr = ( - await Promise.all( - [2, 4, 6, 8, 10, 12, 14, 16, 18, 20].map((i) => { - if (!redisData[i] || !redisData[i][1]) return ''; - return ModelData.findById(redisData[i][1]) - .select('text q') - .then((res) => { - if (!res) return ''; - // const questions = res.q.map((item) => item.text).join(' '); - const answer = res.text; - return `${answer}`; - }); - }) - ) - ).filter((item) => item); - // textArr 筛选,最多 3000 tokens - const systemPrompt = systemPromptFilter(textArr, 3400); + const systemPrompt = systemPromptFilter(formatRedisPrompt, 3400); prompts.unshift({ obj: 'SYSTEM', - value: `${model.systemPrompt}。 我的知识库: "${systemPrompt}"` + value: `${model.systemPrompt} 我的知识库: "${systemPrompt}"` }); // 控制在 tokens 数量,防止超出 diff --git a/src/pages/api/model/data/delModelDataById.ts b/src/pages/api/model/data/delModelDataById.ts index 9f34d8556..aa4059c07 100644 --- a/src/pages/api/model/data/delModelDataById.ts +++ b/src/pages/api/model/data/delModelDataById.ts @@ -1,9 +1,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, ModelData } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { connectRedis } from '@/service/redis'; -import { VecModelDataIndex } from '@/constants/redis'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -23,25 +21,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< // 凭证校验 const userId = await authToken(authorization); - await connectToDatabase(); const redis = await connectRedis(); - const data = await ModelData.findById(dataId); - - await ModelData.deleteOne({ - _id: dataId, - userId - }); - - // 删除 redis 数据 - data?.q.forEach(async (item) => { - try { - await redis.json.del(`${VecModelDataIndex}:${item.id}`); - } catch (error) { - console.log(error); - } - }); - + // 校验是否为该用户的数据 + const dataItemUserId = await redis.hGet(dataId, 'userId'); + if (dataItemUserId !== userId) { + throw new Error('无权操作'); + } + // 删除 + await redis.del(dataId); jsonRes(res); } catch (err) { console.log(err); diff --git a/src/pages/api/model/data/getModelData.ts b/src/pages/api/model/data/getModelData.ts index ec30dd9ed..ce93594db 100644 --- a/src/pages/api/model/data/getModelData.ts +++ b/src/pages/api/model/data/getModelData.ts @@ -1,7 +1,10 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, ModelData } from '@/service/mongo'; +import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; +import { connectRedis } from '@/service/redis'; +import { VecModelDataIdx } from '@/constants/redis'; +import { SearchOptions } from 'redis'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -32,24 +35,34 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< const userId = await authToken(authorization); await connectToDatabase(); + const redis = await connectRedis(); - const data = await ModelData.find({ - modelId, - userId - }) - .sort({ _id: -1 }) // 按照创建时间倒序排列 - .skip((pageNum - 1) * pageSize) - .limit(pageSize); + // 从 redis 中获取数据 + const searchRes = await redis.ft.search( + VecModelDataIdx, + `@modelId:{${modelId}} @userId:{${userId}}`, + { + RETURN: ['q', 'text', 'status'], + LIMIT: { + from: (pageNum - 1) * pageSize, + size: pageSize + }, + SORTBY: { + BY: 'modelId', + DIRECTION: 'DESC' + } + } + ); jsonRes(res, { data: { pageNum, pageSize, - data, - total: await ModelData.countDocuments({ - modelId, - userId - }) + data: searchRes.documents.map((item) => ({ + id: item.id, + ...item.value + })), + total: searchRes.total } }); } catch (err) { diff --git a/src/pages/api/model/data/pushModelDataInput.ts b/src/pages/api/model/data/pushModelDataInput.ts index 9b0b5614e..65679fcf5 100644 --- a/src/pages/api/model/data/pushModelDataInput.ts +++ b/src/pages/api/model/data/pushModelDataInput.ts @@ -1,9 +1,11 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, ModelData, Model } from '@/service/mongo'; +import { connectToDatabase, Model } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { ModelDataSchema } from '@/types/mongoSchema'; import { generateVector } from '@/service/events/generateVector'; +import { connectRedis } from '@/service/redis'; +import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -25,6 +27,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< const userId = await authToken(authorization); await connectToDatabase(); + const redis = await connectRedis(); // 验证是否是该用户的 model const model = await Model.findOne({ @@ -36,19 +39,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< throw new Error('无权操作该模型'); } - // push data - await ModelData.insertMany( - data.map((item) => ({ - ...item, - modelId, - userId - })) + const insertRes = await Promise.allSettled( + data.map((item) => { + return redis.sendCommand([ + 'HMSET', + `${VecModelDataPrefix}:${item.q.id}`, + 'userId', + userId, + 'modelId', + modelId, + 'q', + item.q.text, + 'text', + item.text, + 'status', + ModelDataStatusEnum.waiting + ]); + }) ); generateVector(true); jsonRes(res, { - data: model + data: insertRes.filter((item) => item.status === 'rejected').length }); } catch (err) { jsonRes(res, { diff --git a/src/pages/api/model/data/pushModelDataJson.ts b/src/pages/api/model/data/pushModelDataJson.ts new file mode 100644 index 000000000..9b81b7dcb --- /dev/null +++ b/src/pages/api/model/data/pushModelDataJson.ts @@ -0,0 +1,78 @@ +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { connectToDatabase, Model } from '@/service/mongo'; +import { authToken } from '@/service/utils/tools'; +import { generateVector } from '@/service/events/generateVector'; +import { vectorToBuffer, formatVector } from '@/utils/tools'; +import { connectRedis } from '@/service/redis'; +import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis'; +import { customAlphabet } from 'nanoid'; +const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12); + +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + const { modelId, data } = req.body as { + modelId: string; + data: { prompt: string; completion: string; vector?: number[] }[]; + }; + const { authorization } = req.headers; + + if (!authorization) { + throw new Error('无权操作'); + } + + if (!modelId || !Array.isArray(data)) { + throw new Error('缺少参数'); + } + + // 凭证校验 + const userId = await authToken(authorization); + + await connectToDatabase(); + const redis = await connectRedis(); + + // 验证是否是该用户的 model + const model = await Model.findOne({ + _id: modelId, + userId + }); + + if (!model) { + throw new Error('无权操作该模型'); + } + + // 插入 redis + const insertRedisRes = await Promise.allSettled( + data.map((item) => { + const vector = item.vector; + + return redis.sendCommand([ + 'HMSET', + `${VecModelDataPrefix}:${nanoid()}`, + 'userId', + userId, + 'modelId', + String(modelId), + ...(vector ? ['vector', vectorToBuffer(formatVector(vector))] : []), + 'q', + item.prompt, + 'text', + item.completion, + 'status', + vector ? ModelDataStatusEnum.ready : ModelDataStatusEnum.waiting + ]); + }) + ); + + generateVector(true); + + jsonRes(res, { + data: insertRedisRes.filter((item) => item.status === 'rejected').length + }); + } catch (err) { + jsonRes(res, { + code: 500, + error: err + }); + } +} diff --git a/src/pages/api/model/data/pushModelDataSelectData.ts b/src/pages/api/model/data/pushModelDataSelectData.ts deleted file mode 100644 index f7e01606e..000000000 --- a/src/pages/api/model/data/pushModelDataSelectData.ts +++ /dev/null @@ -1,57 +0,0 @@ -import type { NextApiRequest, NextApiResponse } from 'next'; -import { jsonRes } from '@/service/response'; -import { connectToDatabase, DataItem, ModelData } from '@/service/mongo'; -import { authToken } from '@/service/utils/tools'; -import { customAlphabet } from 'nanoid'; -const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12); - -export default async function handler(req: NextApiRequest, res: NextApiResponse) { - try { - let { dataIds, modelId } = req.body as { dataIds: string[]; modelId: string }; - - if (!dataIds) { - throw new Error('参数错误'); - } - await connectToDatabase(); - - const { authorization } = req.headers; - - const userId = await authToken(authorization); - - const dataItems = ( - await Promise.all( - dataIds.map((dataId) => - DataItem.find<{ _id: string; result: { q: string }[]; text: string }>( - { - userId, - dataId - }, - 'result text' - ) - ) - ) - ).flat(); - - // push data - await ModelData.insertMany( - dataItems.map((item) => ({ - modelId: modelId, - userId, - text: item.text, - q: item.result.map((item) => ({ - id: nanoid(), - text: item.q - })) - })) - ); - - jsonRes(res, { - data: dataItems - }); - } catch (err) { - jsonRes(res, { - code: 500, - error: err - }); - } -} diff --git a/src/pages/api/model/data/putModelData.ts b/src/pages/api/model/data/putModelData.ts index 2c13b8526..f4d24997e 100644 --- a/src/pages/api/model/data/putModelData.ts +++ b/src/pages/api/model/data/putModelData.ts @@ -1,7 +1,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { connectToDatabase, ModelData } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; +import { connectRedis } from '@/service/redis'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -22,17 +22,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< // 凭证校验 const userId = await authToken(authorization); - await connectToDatabase(); + const redis = await connectRedis(); - await ModelData.updateOne( - { - _id: dataId, - userId - }, - { - text - } - ); + // 校验是否为该用户的数据 + const dataItemUserId = await redis.hGet(dataId, 'userId'); + if (dataItemUserId !== userId) { + throw new Error('无权操作'); + } + + // 更新 + await redis.hSet(dataId, 'text', text); jsonRes(res); } catch (err) { diff --git a/src/pages/api/model/del.ts b/src/pages/api/model/del.ts index 2d6e1729f..aeb367f44 100644 --- a/src/pages/api/model/del.ts +++ b/src/pages/api/model/del.ts @@ -1,13 +1,12 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; -import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo'; +import { Chat, Model, Training, connectToDatabase } from '@/service/mongo'; import { authToken, getUserApiOpenai } from '@/service/utils/tools'; import { TrainingStatusEnum } from '@/constants/model'; -import { getOpenAIApi } from '@/service/utils/chat'; import { TrainingItemType } from '@/types/training'; import { httpsAgent } from '@/service/utils/tools'; import { connectRedis } from '@/service/redis'; -import { VecModelDataIndex } from '@/constants/redis'; +import { VecModelDataIdx } from '@/constants/redis'; /* 获取我的模型 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -26,39 +25,38 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< // 凭证校验 const userId = await authToken(authorization); + // 验证是否是该用户的 model + const model = await Model.findOne({ + _id: modelId, + userId + }); + + if (!model) { + throw new Error('无权操作该模型'); + } + await connectToDatabase(); const redis = await connectRedis(); - const modelDataList = await ModelData.find({ + // 获取 redis 中模型关联的所有数据 + const searchRes = await redis.ft.search( + VecModelDataIdx, + `@modelId:{${modelId}} @userId:{${userId}}`, + { + LIMIT: { + from: 0, + size: 10000 + } + } + ); + // 删除 redis 内容 + await Promise.all(searchRes.documents.map((item) => redis.del(item.id))); + + // 删除对应的聊天 + await Chat.deleteMany({ modelId }); - // 删除 redis - modelDataList?.forEach((modelData) => - modelData.q.forEach(async (item) => { - try { - await redis.json.del(`${VecModelDataIndex}:${item.id}`); - } catch (error) { - console.log(error); - } - }) - ); - - let requestQueue: any[] = []; - // 删除对应的聊天 - requestQueue.push( - Chat.deleteMany({ - modelId - }) - ); - - // 删除数据集 - requestQueue.push( - ModelData.deleteMany({ - modelId - }) - ); - // 查看是否正在训练 const training: TrainingItemType | null = await Training.findOne({ modelId, @@ -78,21 +76,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< } // 删除对应训练记录 - requestQueue.push( - Training.deleteMany({ - modelId - }) - ); + await Training.deleteMany({ + modelId + }); // 删除模型 - requestQueue.push( - Model.deleteOne({ - _id: modelId, - userId - }) - ); - - await Promise.all(requestQueue); + await Model.deleteOne({ + _id: modelId, + userId + }); jsonRes(res); } catch (err) { diff --git a/src/pages/api/timer/testVector.ts b/src/pages/api/timer/testVector.ts deleted file mode 100644 index 89f5f07f8..000000000 --- a/src/pages/api/timer/testVector.ts +++ /dev/null @@ -1,68 +0,0 @@ -// Next.js API route support: https://nextjs.org/docs/api-routes/introduction -import type { NextApiRequest, NextApiResponse } from 'next'; -import { jsonRes } from '@/service/response'; -import { connectToDatabase, Bill } from '@/service/mongo'; -import { authToken } from '@/service/utils/tools'; -import type { BillSchema } from '@/types/mongoSchema'; -import { VecModelDataIndex } from '@/constants/redis'; -import { connectRedis } from '@/service/redis'; -import { vectorToBuffer } from '@/utils/tools'; - -let vectorData = [ - -0.025028639, -0.010407282, 0.026523087, -0.0107438695, -0.006967359, 0.010043768, -0.012043097, - 0.008724345, -0.028919589, -0.0117738275, 0.0050690062, 0.02961969 -].concat(new Array(1524).fill(0)); -let vectorData2 = [ - 0.025028639, 0.010407282, 0.026523087, 0.0107438695, -0.006967359, 0.010043768, -0.012043097, - 0.008724345, 0.028919589, 0.0117738275, 0.0050690062, 0.02961969 -].concat(new Array(1524).fill(0)); - -export default async function handler(req: NextApiRequest, res: NextApiResponse) { - try { - if (process.env.NODE_ENV !== 'development') { - throw new Error('不是开发环境'); - } - await connectToDatabase(); - - const redis = await connectRedis(); - - await redis.sendCommand([ - 'HMSET', - 'model:data:333', - 'vector', - vectorToBuffer(vectorData2), - 'modelId', - '1133', - 'dataId', - 'safadfa' - ]); - - // search - const response = await redis.sendCommand([ - 'FT.SEARCH', - 'idx:model:data:hash', - '@modelId:{1133} @vector:[VECTOR_RANGE 0.15 $blob]=>{$YIELD_DISTANCE_AS: score}', - 'RETURN', - '2', - 'modelId', - 'dataId', - 'PARAMS', - '2', - 'blob', - vectorToBuffer(vectorData2), - 'SORTBY', - 'score', - 'DIALECT', - '2' - ]); - - jsonRes(res, { - data: response - }); - } catch (err) { - jsonRes(res, { - code: 500, - error: err - }); - } -} diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index b432e6978..dc93787eb 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -190,97 +190,91 @@ const Chat = ({ chatId }: { chatId: string }) => { /** * 发送一个内容 */ - const sendPrompt = useCallback( - async (e?: React.MouseEvent) => { - e?.stopPropagation(); - e?.preventDefault(); + const sendPrompt = useCallback(async () => { + const storeInput = inputVal; + // 去除空行 + const val = inputVal + .trim() + .split('\n') + .filter((val) => val) + .join('\n'); + if (!chatData?.modelId || !val || !ChatBox.current || isChatting) { + return; + } - const storeInput = inputVal; - // 去除空行 - const val = inputVal - .trim() - .split('\n') - .filter((val) => val) - .join('\n'); - if (!chatData?.modelId || !val || !ChatBox.current || isChatting) { - return; + // 长度校验 + const tokens = encode(val).length; + const model = modelList.find((item) => item.model === chatData.modelName); + + if (model && tokens >= model.maxToken) { + toast({ + title: '单次输入超出 4000 tokens', + status: 'warning' + }); + return; + } + + const newChatList: ChatSiteItemType[] = [ + ...chatData.history, + { + obj: 'Human', + value: val, + status: 'finish' + }, + { + obj: 'AI', + value: '', + status: 'loading' } + ]; - // 长度校验 - const tokens = encode(val).length; - const model = modelList.find((item) => item.model === chatData.modelName); + // 插入内容 + setChatData((state) => ({ + ...state, + history: newChatList + })); - if (model && tokens >= model.maxToken) { - toast({ - title: '单次输入超出 4000 tokens', - status: 'warning' + // 清空输入内容 + resetInputVal(''); + scrollToBottom(); + + try { + await gptChatPrompt(newChatList[newChatList.length - 2]); + + // 如果是 Human 第一次发送,插入历史记录 + const humanChat = newChatList.filter((item) => item.obj === 'Human'); + if (humanChat.length === 1) { + pushChatHistory({ + chatId, + title: humanChat[0].value }); - return; } + } catch (err: any) { + toast({ + title: typeof err === 'string' ? err : err?.message || '聊天出错了~', + status: 'warning', + duration: 5000, + isClosable: true + }); - const newChatList: ChatSiteItemType[] = [ - ...chatData.history, - { - obj: 'Human', - value: val, - status: 'finish' - }, - { - obj: 'AI', - value: '', - status: 'loading' - } - ]; + resetInputVal(storeInput); - // 插入内容 setChatData((state) => ({ ...state, - history: newChatList + history: newChatList.slice(0, newChatList.length - 2) })); - - // 清空输入内容 - resetInputVal(''); - scrollToBottom(); - - try { - await gptChatPrompt(newChatList[newChatList.length - 2]); - - // 如果是 Human 第一次发送,插入历史记录 - const humanChat = newChatList.filter((item) => item.obj === 'Human'); - if (humanChat.length === 1) { - pushChatHistory({ - chatId, - title: humanChat[0].value - }); - } - } catch (err: any) { - toast({ - title: typeof err === 'string' ? err : err?.message || '聊天出错了~', - status: 'warning', - duration: 5000, - isClosable: true - }); - - resetInputVal(storeInput); - - setChatData((state) => ({ - ...state, - history: newChatList.slice(0, newChatList.length - 2) - })); - } - }, - [ - inputVal, - chatData, - isChatting, - resetInputVal, - scrollToBottom, - toast, - gptChatPrompt, - pushChatHistory, - chatId - ] - ); + } + }, [ + inputVal, + chatData, + isChatting, + resetInputVal, + scrollToBottom, + toast, + gptChatPrompt, + pushChatHistory, + chatId + ]); // 删除一句话 const delChatRecord = useCallback( @@ -474,6 +468,7 @@ const Chat = ({ chatId }: { chatId: string }) => { flex={1} w={0} py={0} + pr={0} border={'none'} _focusVisible={{ border: 'none' diff --git a/src/pages/model/detail/components/InputDataModal.tsx b/src/pages/model/detail/components/InputDataModal.tsx index bb0757320..97f4f746f 100644 --- a/src/pages/model/detail/components/InputDataModal.tsx +++ b/src/pages/model/detail/components/InputDataModal.tsx @@ -45,24 +45,22 @@ const InputDataModal = ({ setImporting(true); try { - await postModelDataInput({ + const res = await postModelDataInput({ modelId: modelId, data: [ { text: e.text, - q: [ - { - id: nanoid(), - text: e.q - } - ] + q: { + id: nanoid(), + text: e.q + } } ] }); toast({ - title: '导入数据成功,需要一段时间训练', - status: 'success' + title: res === 0 ? '导入数据成功,需要一段时间训练' : '数据导入异常', + status: res === 0 ? 'success' : 'warning' }); onClose(); onSuccess(); @@ -88,8 +86,15 @@ const InputDataModal = ({ 手动导入 - - + + 问题 @@ -169,7 +172,7 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => { aria-label={'delete'} size={'sm'} onClick={async () => { - await delOneModelData(item._id); + await delOneModelData(item.id); refetchData(pageNum); }} /> @@ -188,8 +191,19 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => { {isOpenInputModal && ( )} - {isOpenSelectModal && ( - + {isOpenSelectFileModal && ( + + )} + {isOpenSelectJsonModal && ( + )} ); diff --git a/src/pages/model/detail/components/SelectFileModal.tsx b/src/pages/model/detail/components/SelectFileModal.tsx index daf6a1abb..823b49bd3 100644 --- a/src/pages/model/detail/components/SelectFileModal.tsx +++ b/src/pages/model/detail/components/SelectFileModal.tsx @@ -100,40 +100,43 @@ const SelectFileModal = ({ }); return ( - + - + 文件导入 - - + + + 支持 {fileExtension} 文件。模型会自动对文本进行 QA 拆分,需要较长训练时间,拆分需要消耗 + tokens,大约0.04元/1k tokens,请确保账号余额充足。 + + + 一共 {fileText.length} 个字,{encode(fileText).length} 个tokens + + - - 支持 {fileExtension} 文件. 会先对文本进行拆分,需要时间较长。 - - 一共 {fileText.length} 个字,{encode(fileText).length} 个tokens - - - {fileText} - - + {fileText} + diff --git a/src/pages/model/detail/components/SelectJsonModal.tsx b/src/pages/model/detail/components/SelectJsonModal.tsx new file mode 100644 index 000000000..508abf5e7 --- /dev/null +++ b/src/pages/model/detail/components/SelectJsonModal.tsx @@ -0,0 +1,145 @@ +import React, { useState, useCallback } from 'react'; +import { + Box, + Flex, + Button, + Modal, + ModalOverlay, + ModalContent, + ModalHeader, + ModalCloseButton, + ModalBody +} from '@chakra-ui/react'; +import { useToast } from '@/hooks/useToast'; +import { useSelectFile } from '@/hooks/useSelectFile'; +import { useConfirm } from '@/hooks/useConfirm'; +import { readTxtContent } from '@/utils/tools'; +import { useMutation } from '@tanstack/react-query'; +import { postModelDataJsonData } from '@/api/model'; +import Markdown from '@/components/Markdown'; + +const SelectJsonModal = ({ + onClose, + onSuccess, + modelId +}: { + onClose: () => void; + onSuccess: () => void; + modelId: string; +}) => { + const [selecting, setSelecting] = useState(false); + const { toast } = useToast(); + const { File, onOpen } = useSelectFile({ fileType: '.json', multiple: true }); + const [fileData, setFileData] = useState< + { prompt: string; completion: string; vector?: number[] }[] + >([]); + const { openConfirm, ConfirmChild } = useConfirm({ + content: '确认导入该数据集?' + }); + + const onSelectFile = useCallback( + async (e: File[]) => { + setSelecting(true); + try { + const jsonData = ( + await Promise.all(e.map((item) => readTxtContent(item).then((text) => JSON.parse(text)))) + ).flat(); + // check 文件类型 + for (let i = 0; i < jsonData.length; i++) { + if (!jsonData[i]?.prompt || !jsonData[i]?.completion) { + throw new Error('缺少 prompt 或 completion'); + } + } + + setFileData(jsonData); + } catch (error: any) { + console.log(error); + toast({ + title: error?.message || 'JSON文件格式有误', + status: 'error' + }); + } + setSelecting(false); + }, + [setSelecting, toast] + ); + + const { mutate, isLoading } = useMutation({ + mutationFn: async () => { + if (!fileData) return; + const res = await postModelDataJsonData(modelId, fileData); + console.log(res); + toast({ + title: '导入数据成功,需要一段拆解和训练', + status: 'success' + }); + onClose(); + onSuccess(); + }, + onError() { + toast({ + title: '导入文件失败', + status: 'error' + }); + } + }); + + return ( + + + + JSON数据集 + + + + + + + + + 一共 {fileData.length} 组数据 + + + + {JSON.stringify(fileData)} + + + + + + + + + + + + + ); +}; + +export default SelectJsonModal; diff --git a/src/service/events/generateQA.ts b/src/service/events/generateQA.ts index 15e570978..70941ee86 100644 --- a/src/service/events/generateQA.ts +++ b/src/service/events/generateQA.ts @@ -1,10 +1,12 @@ -import { SplitData, ModelData } from '@/service/mongo'; +import { SplitData } from '@/service/mongo'; import { getOpenAIApi } from '@/service/utils/chat'; import { httpsAgent, getOpenApiKey } from '@/service/utils/tools'; import type { ChatCompletionRequestMessage } from 'openai'; import { ChatModelNameEnum } from '@/constants/model'; import { pushSplitDataBill } from '@/service/events/pushBill'; import { generateVector } from './generateVector'; +import { connectRedis } from '../redis'; +import { VecModelDataPrefix } from '@/constants/redis'; import { customAlphabet } from 'nanoid'; const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12); @@ -18,6 +20,7 @@ export async function generateQA(next = false): Promise { }; try { + const redis = await connectRedis(); // 找出一个需要生成的 dataItem const dataItem = await SplitData.findOne({ textList: { $exists: true, $ne: [] } @@ -29,8 +32,10 @@ export async function generateQA(next = false): Promise { return; } + // 源文本 const text = dataItem.textList[dataItem.textList.length - 1]; if (!text) { + await SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } }); // 弹出无效文本 throw new Error('无文本'); } @@ -63,7 +68,7 @@ export async function generateQA(next = false): Promise { .createChatCompletion( { model: ChatModelNameEnum.GPT35, - temperature: 0.2, + temperature: 0.4, n: 1, messages: [ systemPrompt, @@ -79,26 +84,29 @@ export async function generateQA(next = false): Promise { } ) .then((res) => ({ - rawContent: res?.data.choices[0].message?.content || '', - result: splitText(res?.data.choices[0].message?.content || '') - })); // 从 content 中提取 QA + rawContent: res?.data.choices[0].message?.content || '', // chatgpt原本的回复 + result: splitText(res?.data.choices[0].message?.content || '') // 格式化后的QA对 + })); await Promise.allSettled([ - SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } }), - ModelData.insertMany( - response.result.map((item) => ({ - modelId: dataItem.modelId, - userId: dataItem.userId, - text: item.a, - q: [ - { - id: nanoid(), - text: item.q - } - ], - status: 1 - })) - ) + SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } }), // 弹出已经拆分的文本 + ...response.result.map((item) => { + // 插入 redis + return redis.sendCommand([ + 'HMSET', + `${VecModelDataPrefix}:${nanoid()}`, + 'userId', + String(dataItem.userId), + 'modelId', + String(dataItem.modelId), + 'q', + item.q, + 'text', + item.a, + 'status', + 'waiting' + ]); + }) ]); console.log( diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts index 407429251..118377458 100644 --- a/src/service/events/generateVector.ts +++ b/src/service/events/generateVector.ts @@ -1,9 +1,9 @@ import { getOpenAIApi } from '@/service/utils/chat'; import { httpsAgent } from '@/service/utils/tools'; -import { ModelData } from '../models/modelData'; import { connectRedis } from '../redis'; -import { VecModelDataIndex } from '@/constants/redis'; +import { VecModelDataIdx } from '@/constants/redis'; import { vectorToBuffer } from '@/utils/tools'; +import { ModelDataStatusEnum } from '@/constants/redis'; export async function generateVector(next = false): Promise { if (global.generatingVector && !next) return; @@ -12,74 +12,71 @@ export async function generateVector(next = false): Promise { try { const redis = await connectRedis(); - // 找出一个需要生成的 dataItem - const dataItem = await ModelData.findOne({ - status: { $ne: 0 } - }); + // 从找出一个 status = waiting 的数据 + const searchRes = await redis.ft.search( + VecModelDataIdx, + `@status:{${ModelDataStatusEnum.waiting}}`, + { + RETURN: ['q'], + LIMIT: { + from: 0, + size: 1 + } + } + ); - if (!dataItem) { + if (searchRes.total === 0) { console.log('没有需要生成 【向量】 的数据'); global.generatingVector = false; return; } + const dataItem: { id: string; q: string } = { + id: searchRes.documents[0].id, + q: String(searchRes.documents[0]?.value?.q || '') + }; + // 获取 openapi Key const openAiKey = process.env.OPENAIKEY as string; // 获取 openai 请求实例 const chatAPI = getOpenAIApi(openAiKey); - const dataId = String(dataItem._id); - // 生成词向量 - const response = await Promise.allSettled( - dataItem.q.map((item, i) => - chatAPI - .createEmbedding( - { - model: 'text-embedding-ada-002', - input: item.text - }, - { - timeout: 120000, - httpsAgent - } - ) - .then((res) => res?.data?.data?.[0]?.embedding || []) - .then((vector) => - redis.sendCommand([ - 'HMSET', - `${VecModelDataIndex}:${item.id}`, - 'vector', - vectorToBuffer(vector), - 'modelId', - String(dataItem.modelId), - 'dataId', - String(dataId) - ]) - ) + const vector = await chatAPI + .createEmbedding( + { + model: 'text-embedding-ada-002', + input: dataItem.q + }, + { + timeout: 120000, + httpsAgent + } ) - ); + .then((res) => res?.data?.data?.[0]?.embedding || []); - if (response.filter((item) => item.status === 'fulfilled').length === 0) { - throw new Error(JSON.stringify(response)); - } - // 修改该数据状态 - await ModelData.findByIdAndUpdate(dataItem._id, { - status: 0 - }); + // 更新 redis 向量和状态数据 + await redis.sendCommand([ + 'HMSET', + dataItem.id, + 'vector', + vectorToBuffer(vector), + 'status', + ModelDataStatusEnum.ready + ]); - console.log(`生成向量成功: ${dataItem._id}`); + console.log(`生成向量成功: ${dataItem.id}`); setTimeout(() => { generateVector(true); - }, 3000); + }, 2000); } catch (error: any) { - console.log(error); - console.log('error: 生成向量错误', error?.response?.data); + console.log('error: 生成向量错误', error?.response?.statusText); + !error?.response && console.log(error); if (error?.response?.statusText === 'Too Many Requests') { - console.log('次数限制,1分钟后尝试'); + console.log('生成向量次数限制,1分钟后尝试'); // 限制次数,1分钟后再试 setTimeout(() => { generateVector(true); diff --git a/src/service/models/modelData.ts b/src/service/models/modelData.ts deleted file mode 100644 index d8e0e3c30..000000000 --- a/src/service/models/modelData.ts +++ /dev/null @@ -1,37 +0,0 @@ -/* 模型的知识库 */ -import { Schema, model, models, Model as MongoModel } from 'mongoose'; -import { ModelDataSchema as ModelDataType } from '@/types/mongoSchema'; - -const ModelDataSchema = new Schema({ - modelId: { - type: Schema.Types.ObjectId, - ref: 'model', - required: true - }, - userId: { - type: Schema.Types.ObjectId, - ref: 'user', - required: true - }, - text: { - type: String, - required: true - }, - q: { - type: [ - { - id: String, // 对应redis的key - text: String - } - ], - default: [] - }, - status: { - type: Number, - enum: [0, 1], // 1 训练ing - default: 1 - } -}); - -export const ModelData: MongoModel = - models['modelData'] || model('modelData', ModelDataSchema); diff --git a/src/service/mongo.ts b/src/service/mongo.ts index 9f3aef612..fda7f46f2 100644 --- a/src/service/mongo.ts +++ b/src/service/mongo.ts @@ -35,7 +35,6 @@ export async function connectToDatabase(): Promise { export * from './models/authCode'; export * from './models/chat'; export * from './models/model'; -export * from './models/modelData'; export * from './models/user'; export * from './models/training'; export * from './models/bill'; diff --git a/src/service/redis.ts b/src/service/redis.ts index f0cd0bf9e..9a1c6c783 100644 --- a/src/service/redis.ts +++ b/src/service/redis.ts @@ -29,8 +29,8 @@ export const connectRedis = async () => { await global.redisClient.connect(); - // 0 - 测试库,1 - 正式 - await global.redisClient.select(0); + // 1 - 测试库,0 - 正式 + await global.redisClient.select(process.env.NODE_ENV === 'development' ? 0 : 0); return global.redisClient; } catch (error) { diff --git a/src/types/mongoSchema.d.ts b/src/types/mongoSchema.d.ts index 33cd5a2e3..10c6a0971 100644 --- a/src/types/mongoSchema.d.ts +++ b/src/types/mongoSchema.d.ts @@ -60,7 +60,7 @@ export interface ModelDataSchema { q: { id: string; text: string; - }[]; + }; status: ModelDataType; } diff --git a/src/types/redis.d.ts b/src/types/redis.d.ts index d4b9cd136..dba1175e1 100644 --- a/src/types/redis.d.ts +++ b/src/types/redis.d.ts @@ -1,6 +1,7 @@ +import { ModelDataStatusEnum } from '@/constants/redis'; export interface RedisModelDataItemType { id: string; - vector: number[]; - dataId: string; - modelId: string; + q: string; + text: string; + status: `${ModelDataStatusEnum}`; } diff --git a/src/utils/tools.ts b/src/utils/tools.ts index fb8239b57..254236769 100644 --- a/src/utils/tools.ts +++ b/src/utils/tools.ts @@ -127,3 +127,9 @@ export const vectorToBuffer = (vector: number[]) => { return Buffer.from(npVector.buffer); }; +export function formatVector(vector: number[]) { + let formattedVector = vector.slice(0, 1536); // 截取前1536个元素 + formattedVector = formattedVector.concat(Array(1536 - formattedVector.length).fill(0)); // 在后面添加0 + + return formattedVector; +}