feat: 模型数据管理

feat: 模型数据导入

feat: redis 向量入库

feat: 向量索引

feat: 文件导入模型

perf: 交互

perf: prompt
This commit is contained in:
archer
2023-03-29 00:22:48 +08:00
parent 713332522f
commit 2099a87908
45 changed files with 1522 additions and 284 deletions

View File

@@ -3,4 +3,6 @@ AXIOS_PROXY_PORT=33210
MONGODB_URI=
MY_MAIL=
MAILE_CODE=
TOKEN_KEY=
TOKEN_KEY=
OPENAIKEY=
REDIS_URL=

View File

@@ -41,6 +41,7 @@
"react-hook-form": "^7.43.1",
"react-markdown": "^8.0.5",
"react-syntax-highlighter": "^15.5.0",
"redis": "^4.6.5",
"rehype-katex": "^6.0.2",
"remark-gfm": "^3.0.1",
"remark-math": "^5.1.1",

95
pnpm-lock.yaml generated
View File

@@ -47,6 +47,7 @@ specifiers:
react-hook-form: ^7.43.1
react-markdown: ^8.0.5
react-syntax-highlighter: ^15.5.0
redis: ^4.6.5
rehype-katex: ^6.0.2
remark-gfm: ^3.0.1
remark-math: ^5.1.1
@@ -87,6 +88,7 @@ dependencies:
react-hook-form: registry.npmmirror.com/react-hook-form/7.43.1_react@18.2.0
react-markdown: registry.npmmirror.com/react-markdown/8.0.5_pmekkgnqduwlme35zpnqhenc34
react-syntax-highlighter: registry.npmmirror.com/react-syntax-highlighter/15.5.0_react@18.2.0
redis: registry.npmmirror.com/redis/4.6.5
rehype-katex: registry.npmmirror.com/rehype-katex/6.0.2
remark-gfm: registry.npmmirror.com/remark-gfm/3.0.1
remark-math: registry.npmmirror.com/remark-math/5.1.1
@@ -4504,6 +4506,72 @@ packages:
version: 2.11.6
dev: false
registry.npmmirror.com/@redis/bloom/1.2.0_@redis+client@1.5.6:
resolution: {integrity: sha512-HG2DFjYKbpNmVXsa0keLHp/3leGJz1mjh09f2RLGGLQZzSHpkmZWuwJbAvo3QcRY8p80m5+ZdXZdYOSBLlp7Cg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/bloom/-/bloom-1.2.0.tgz}
id: registry.npmmirror.com/@redis/bloom/1.2.0
name: '@redis/bloom'
version: 1.2.0
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@redis/client/1.5.6:
resolution: {integrity: sha512-dFD1S6je+A47Lj22jN/upVU2fj4huR7S9APd7/ziUXsIXDL+11GPYti4Suv5y8FuXaN+0ZG4JF+y1houEJ7ToA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/client/-/client-1.5.6.tgz}
name: '@redis/client'
version: 1.5.6
engines: {node: '>=14'}
dependencies:
cluster-key-slot: registry.npmmirror.com/cluster-key-slot/1.1.2
generic-pool: registry.npmmirror.com/generic-pool/3.9.0
yallist: registry.npmmirror.com/yallist/4.0.0
dev: false
registry.npmmirror.com/@redis/graph/1.1.0_@redis+client@1.5.6:
resolution: {integrity: sha512-16yZWngxyXPd+MJxeSr0dqh2AIOi8j9yXKcKCwVaKDbH3HTuETpDVPcLujhFYVPtYrngSco31BUcSa9TH31Gqg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/graph/-/graph-1.1.0.tgz}
id: registry.npmmirror.com/@redis/graph/1.1.0
name: '@redis/graph'
version: 1.1.0
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@redis/json/1.0.4_@redis+client@1.5.6:
resolution: {integrity: sha512-LUZE2Gdrhg0Rx7AN+cZkb1e6HjoSKaeeW8rYnt89Tly13GBI5eP4CwDVr+MY8BAYfCg4/N15OUrtLoona9uSgw==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/json/-/json-1.0.4.tgz}
id: registry.npmmirror.com/@redis/json/1.0.4
name: '@redis/json'
version: 1.0.4
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@redis/search/1.1.2_@redis+client@1.5.6:
resolution: {integrity: sha512-/cMfstG/fOh/SsE+4/BQGeuH/JJloeWuH+qJzM8dbxuWvdWibWAOAHHCZTMPhV3xIlH4/cUEIA8OV5QnYpaVoA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/search/-/search-1.1.2.tgz}
id: registry.npmmirror.com/@redis/search/1.1.2
name: '@redis/search'
version: 1.1.2
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@redis/time-series/1.0.4_@redis+client@1.5.6:
resolution: {integrity: sha512-ThUIgo2U/g7cCuZavucQTQzA9g9JbDDY2f64u3AbAoz/8vE2lt2U37LamDUVChhaDA3IRT9R6VvJwqnUfTJzng==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/time-series/-/time-series-1.0.4.tgz}
id: registry.npmmirror.com/@redis/time-series/1.0.4
name: '@redis/time-series'
version: 1.0.4
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@rushstack/eslint-patch/1.2.0:
resolution: {integrity: sha512-sXo/qW2/pAcmT43VoRKOJbDOfV3cYpq3szSVfIThQXNt+E4DfKj361vaAt3c88U5tPUxzEswam7GW48PJqtKAg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@rushstack/eslint-patch/-/eslint-patch-1.2.0.tgz}
name: '@rushstack/eslint-patch'
@@ -5562,6 +5630,13 @@ packages:
version: 0.0.1
dev: false
registry.npmmirror.com/cluster-key-slot/1.1.2:
resolution: {integrity: sha512-RMr0FhtfXemyinomL4hrWcYJxmX6deFdCxpJzhDttxgO1+bcCnkk+9drydLVDmAMG7NE6aN/fl4F7ucU/90gAA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/cluster-key-slot/-/cluster-key-slot-1.1.2.tgz}
name: cluster-key-slot
version: 1.1.2
engines: {node: '>=0.10.0'}
dev: false
registry.npmmirror.com/color-convert/1.9.3:
resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/color-convert/-/color-convert-1.9.3.tgz}
name: color-convert
@@ -6799,6 +6874,13 @@ packages:
version: 1.2.3
dev: true
registry.npmmirror.com/generic-pool/3.9.0:
resolution: {integrity: sha512-hymDOu5B53XvN4QT9dBmZxPX4CWhBPPLguTZ9MMFeFa/Kg0xWVfylOVNlJji/E7yTZWFd/q9GO5TxDLq156D7g==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/generic-pool/-/generic-pool-3.9.0.tgz}
name: generic-pool
version: 3.9.0
engines: {node: '>= 4'}
dev: false
registry.npmmirror.com/gensync/1.0.0-beta.2:
resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/gensync/-/gensync-1.0.0-beta.2.tgz}
name: gensync
@@ -9367,6 +9449,19 @@ packages:
picomatch: registry.npmmirror.com/picomatch/2.3.1
dev: false
registry.npmmirror.com/redis/4.6.5:
resolution: {integrity: sha512-O0OWA36gDQbswOdUuAhRL6mTZpHFN525HlgZgDaVNgCJIAZR3ya06NTESb0R+TUZ+BFaDpz6NnnVvoMx9meUFg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/redis/-/redis-4.6.5.tgz}
name: redis
version: 4.6.5
dependencies:
'@redis/bloom': registry.npmmirror.com/@redis/bloom/1.2.0_@redis+client@1.5.6
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
'@redis/graph': registry.npmmirror.com/@redis/graph/1.1.0_@redis+client@1.5.6
'@redis/json': registry.npmmirror.com/@redis/json/1.0.4_@redis+client@1.5.6
'@redis/search': registry.npmmirror.com/@redis/search/1.1.2_@redis+client@1.5.6
'@redis/time-series': registry.npmmirror.com/@redis/time-series/1.0.4_@redis+client@1.5.6
dev: false
registry.npmmirror.com/refractor/3.6.0:
resolution: {integrity: sha512-MY9W41IOWxxk31o+YvFCNyNzdkc9M20NoZK5vq6jkv4I/uh2zkWcfudj0Q1fovjUQJrNewS9NMzeTtqPf+n5EA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/refractor/-/refractor-3.6.0.tgz}
name: refractor

View File

@@ -1,7 +1,10 @@
import { GET, POST, DELETE, PUT } from './request';
import type { ModelSchema } from '@/types/mongoSchema';
import type { ModelSchema, ModelDataSchema } from '@/types/mongoSchema';
import { ModelUpdateParams } from '@/types/model';
import { TrainingItemType } from '../types/training';
import { PagingData } from '@/types';
import { RequestPaging } from '../types/index';
import { Obj2Query } from '@/utils/tools';
export const getMyModels = () => GET<ModelSchema[]>('/model/list');
@@ -16,13 +19,35 @@ export const putModelById = (id: string, data: ModelUpdateParams) =>
PUT(`/model/update?modelId=${id}`, data);
export const postTrainModel = (id: string, form: FormData) =>
POST(`/model/train?modelId=${id}`, form, {
POST(`/model/train/train?modelId=${id}`, form, {
headers: {
'content-type': 'multipart/form-data'
}
});
export const putModelTrainingStatus = (id: string) => PUT(`/model/putTrainStatus?modelId=${id}`);
export const putModelTrainingStatus = (id: string) =>
PUT(`/model/train/putTrainStatus?modelId=${id}`);
export const getModelTrainings = (id: string) =>
GET<TrainingItemType[]>(`/model/getTrainings?modelId=${id}`);
GET<TrainingItemType[]>(`/model/train/getTrainings?modelId=${id}`);
/* 模型 data */
type GetModelDataListProps = RequestPaging & {
modelId: string;
};
export const getModelDataList = (props: GetModelDataListProps) =>
GET(`/model/data/getModelData?${Obj2Query(props)}`);
export const postModelDataInput = (data: {
modelId: string;
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
}) => POST(`/model/data/pushModelDataInput`, data);
export const postModelDataFileText = (modelId: string, text: string) =>
POST(`/model/data/splitData`, { modelId, text });
export const putModelDataById = (data: { dataId: string; text: string }) =>
PUT('/model/data/putModelData', data);
export const delOneModelData = (dataId: string) =>
DELETE(`/model/data/delModelDataById?dataId=${dataId}`);

View File

@@ -26,12 +26,12 @@ const navbarList = [
link: '/model/list',
activeLink: ['/model/list', '/model/detail']
},
{
label: '数据',
icon: 'icon-datafull',
link: '/data/list',
activeLink: ['/data/list', '/data/detail']
},
// {
// label: '数据',
// icon: 'icon-datafull',
// link: '/data/list',
// activeLink: ['/data/list', '/data/detail']
// },
{
label: '账号',
icon: 'icon-yonghu-yuan',

View File

@@ -1,11 +1,17 @@
import type { ServiceName } from '@/types/mongoSchema';
import { ModelSchema } from '../types/mongoSchema';
import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema';
export enum ChatModelNameEnum {
GPT35 = 'gpt-3.5-turbo',
VECTOR_GPT = 'VECTOR_GPT',
GPT3 = 'text-davinci-003'
}
export const ChatModelNameMap = {
[ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo',
[ChatModelNameEnum.GPT3]: 'text-davinci-003'
};
export type ModelConstantsData = {
serviceCompany: `${ServiceName}`;
name: string;
@@ -29,6 +35,17 @@ export const modelList: ModelConstantsData[] = [
trainedMaxToken: 2000,
maxTemperature: 2,
price: 3
},
{
serviceCompany: 'openai',
name: '知识库',
model: ChatModelNameEnum.VECTOR_GPT,
trainName: 'vector',
maxToken: 4000,
contextMaxToken: 7500,
trainedMaxToken: 2000,
maxTemperature: 1,
price: 3
}
// {
// serviceCompany: 'openai',
@@ -76,6 +93,11 @@ export const formatModelStatus = {
}
};
export const ModelDataStatusMap = {
0: '训练完成',
1: '训练中'
};
export const defaultModel: ModelSchema = {
_id: '',
userId: '',

1
src/constants/redis.ts Normal file
View File

@@ -0,0 +1 @@
export const VecModelDataIndex = 'model:data';

View File

@@ -8,7 +8,7 @@ export const usePaging = <T = any>({
pageSize = 10,
params = {}
}: {
api: (data: any) => Promise<PagingData<T>>;
api: (data: any) => any;
pageSize?: number;
params?: Record<string, any>;
}) => {
@@ -30,7 +30,7 @@ export const usePaging = <T = any>({
setRequesting(true);
try {
const res = await api({
const res: PagingData<T> = await api({
pageNum: num,
pageSize,
...params
@@ -75,6 +75,7 @@ export const usePaging = <T = any>({
requesting,
isLoadAll,
nextPage,
initRequesting
initRequesting,
setData
};
};

View File

@@ -46,7 +46,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const model: ModelSchema = chat.modelId;
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) {
throw new Error('模型异常,请用 chatgpt 模型');
throw new Error('模型加载异常');
}
// 读取对话内容

View File

@@ -0,0 +1,241 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/chat';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
let vectorData = [
-0.025028639, -0.010407282, 0.026523087, -0.0107438695, -0.006967359, 0.010043768, -0.012043097,
0.008724345, -0.028919589, -0.0117738275, 0.0050690062, 0.02961969
].concat(new Array(1524).fill(0));
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
const stream = new PassThrough();
stream.on('error', () => {
console.log('error: ', 'stream error');
stream.destroy();
});
res.on('close', () => {
stream.destroy();
});
res.on('error', () => {
console.log('error: ', 'request error');
stream.destroy();
});
try {
const { chatId, prompt } = req.body as {
prompt: ChatItemType;
chatId: string;
};
const { authorization } = req.headers;
if (!chatId || !prompt) {
throw new Error('缺少参数');
}
await connectToDatabase();
const redis = await connectRedis();
const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);
const model: ModelSchema = chat.modelId;
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) {
throw new Error('模型加载异常');
}
// 读取对话内容
const prompts = [...chat.content, prompt];
// 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 把输入的内容转成向量
const promptVector = await chatAPI
.createEmbedding(
{
model: 'text-embedding-ada-002',
input: prompt.value
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => res?.data?.data?.[0]?.embedding || []);
const binary = vectorToBuffer(promptVector);
// 搜索系统提示词, 按相似度从 redis 中搜出前3条不同 dataId 的数据
const redisData: any[] = await redis.sendCommand([
'FT.SEARCH',
`idx:${VecModelDataIndex}`,
`@modelId:{${String(chat.modelId._id)}} @vector:[VECTOR_RANGE 0.2 $blob]`,
// `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`,
'RETURN',
'1',
'dataId',
// 'SORTBY',
// 'score',
'PARAMS',
'2',
'blob',
binary,
'DIALECT',
'2'
]);
// 格式化响应值获取去重后的id
let formatIds = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
.map((i) => {
if (!redisData[i] || !redisData[i][1]) return '';
return redisData[i][1];
})
.filter((item) => item);
formatIds = Array.from(new Set(formatIds));
if (formatIds.length === 0) {
throw new Error('对不起,我没有找到你的问题');
}
// 从 mongo 中取出原文作为提示词
const textArr = (
await Promise.all(
[2, 4, 6, 8, 10, 12, 14, 16, 18, 20].map((i) => {
if (!redisData[i] || !redisData[i][1]) return '';
return ModelData.findById(redisData[i][1])
.select('text')
.then((res) => res?.text || '');
})
)
).filter((item) => item);
// textArr 筛选,最多 3000 tokens
const systemPrompt = systemPromptFilter(textArr, 2800);
prompts.unshift({
obj: 'SYSTEM',
value: `请根据下面的知识回答问题: ${systemPrompt}`
});
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
// 格式化文本内容成 chatgpt 格式
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
(item: ChatItemType) => ({
role: map[item.obj],
content: item.value
})
);
// console.log(formatPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
let startTime = Date.now();
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature: temperature,
// max_tokens: modelConstantsData.maxToken,
messages: formatPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: true
},
{
timeout: 40000,
responseType: 'stream',
httpsAgent
}
);
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
// 创建响应流
res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('X-Accel-Buffering', 'no');
res.setHeader('Cache-Control', 'no-cache, no-transform');
step = 1;
let responseContent = '';
stream.pipe(res);
const onParse = async (event: ParsedEvent | ReconnectInterval) => {
if (event.type !== 'event') return;
const data = event.data;
if (data === '[DONE]') return;
try {
const json = JSON.parse(data);
const content: string = json?.choices?.[0].delta.content || '';
if (!content || (responseContent === '' && content === '\n')) return;
responseContent += content;
// console.log('content:', content)
!stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
} catch (error) {
error;
}
};
const decoder = new TextDecoder();
try {
for await (const chunk of chatResponse.data as any) {
if (stream.destroyed) {
// 流被中断了,直接忽略后面的内容
break;
}
const parser = createParser(onParse);
parser.feed(decoder.decode(chunk));
}
} catch (error) {
console.log('pipe error', error);
}
// close stream
!stream.destroyed && stream.push(null);
stream.destroy();
const promptsContent = formatPrompts.map((item) => item.content).join('');
// 只有使用平台的 key 才计费
pushChatBill({
isPay: !userApiKey,
modelName: model.service.modelName,
userId,
chatId,
text: promptsContent + responseContent
});
// jsonRes(res);
} catch (err: any) {
if (step === 1) {
// 直接结束流
console.log('error结束');
stream.destroy();
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}
}

View File

@@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
if (!DataRecord) {
throw new Error('找不到数据集');
}
const replaceText = text.replace(/[\r\n\\n]+/g, ' ');
const replaceText = text.replace(/[\\n]+/g, ' ');
// 文本拆分成 chunk
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
@@ -35,7 +35,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
chunks.forEach((chunk) => {
splitText += chunk;
const tokens = encode(splitText).length;
if (tokens >= 980) {
if (tokens >= 780) {
dataItems.push({
userId,
dataId,

View File

@@ -3,7 +3,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { ModelStatusEnum, modelList, ChatModelNameEnum } from '@/constants/model';
import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model';
import { Model } from '@/service/models/model';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
@@ -33,15 +33,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase();
// 重名校验
const authRepeatName = await Model.findOne({
name,
userId
});
if (authRepeatName) {
throw new Error('模型名重复');
}
// 上限校验
const authCount = await Model.countDocuments({
userId
@@ -57,9 +48,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
status: ModelStatusEnum.running,
service: {
company: modelItem.serviceCompany,
trainId: modelItem.trainName,
chatModel: modelItem.model,
modelName: modelItem.model
trainId: '',
chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
}
});

View File

@@ -5,8 +5,8 @@ import { authToken } from '@/service/utils/tools';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
let { modelId } = req.query as {
modelId: string;
let { dataId } = req.query as {
dataId: string;
};
const { authorization } = req.headers;
@@ -14,7 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作');
}
if (!modelId) {
if (!dataId) {
throw new Error('缺少参数');
}
@@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase();
await ModelData.deleteOne({
modelId,
_id: dataId,
userId
});

View File

@@ -14,6 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
pageNum: string;
pageSize: string;
};
const { authorization } = req.headers;
pageNum = +pageNum;
@@ -41,7 +42,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
.limit(pageSize);
jsonRes(res, {
data
data: {
pageNum,
pageSize,
data,
total: await ModelData.countDocuments({
modelId,
userId
})
}
});
} catch (err) {
jsonRes(res, {

View File

@@ -2,12 +2,14 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { ModelDataSchema } from '@/types/mongoSchema';
import { generateVector } from '@/service/events/generateVector';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { modelId, data } = req.body as {
modelId: string;
data: { q: string; a: string }[];
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
};
const { authorization } = req.headers;
@@ -43,6 +45,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
}))
);
generateVector(true);
jsonRes(res, {
data: model
});

View File

@@ -0,0 +1,57 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, DataItem, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
let { dataIds, modelId } = req.body as { dataIds: string[]; modelId: string };
if (!dataIds) {
throw new Error('参数错误');
}
await connectToDatabase();
const { authorization } = req.headers;
const userId = await authToken(authorization);
const dataItems = (
await Promise.all(
dataIds.map((dataId) =>
DataItem.find<{ _id: string; result: { q: string }[]; text: string }>(
{
userId,
dataId
},
'result text'
)
)
)
).flat();
// push data
await ModelData.insertMany(
dataItems.map((item) => ({
modelId: modelId,
userId,
text: item.text,
q: item.result.map((item) => ({
id: nanoid(),
text: item.q
}))
}))
);
jsonRes(res, {
data: dataItems
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -5,9 +5,9 @@ import { authToken } from '@/service/utils/tools';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
let { modelId, answer } = req.body as {
modelId: string;
answer: string;
let { dataId, text } = req.body as {
dataId: string;
text: string;
};
const { authorization } = req.headers;
@@ -15,7 +15,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作');
}
if (!modelId) {
if (!dataId) {
throw new Error('缺少参数');
}
@@ -26,11 +26,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await ModelData.updateOne(
{
modelId,
_id: dataId,
userId
},
{
a: answer
text
}
);

View File

@@ -0,0 +1,67 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, SplitData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateQA } from '@/service/events/generateQA';
import { encode } from 'gpt-token-utils';
/* 拆分数据成QA */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { text, modelId } = req.body as { text: string; modelId: string };
if (!text || !modelId) {
throw new Error('参数错误');
}
await connectToDatabase();
const { authorization } = req.headers;
const userId = await authToken(authorization);
// 验证是否是该用户的 model
const model = await Model.findOne({
_id: modelId,
userId
});
if (!model) {
throw new Error('无权操作该模型');
}
const replaceText = text.replace(/(\\n|\n)+/g, ' ');
// 文本拆分成 chunk
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
const textList: string[] = [];
let splitText = '';
chunks.forEach((chunk) => {
splitText += chunk;
const tokens = encode(splitText).length;
if (tokens >= 980) {
textList.push(splitText);
splitText = '';
}
});
// 批量插入数据
await SplitData.create({
userId,
modelId,
rawText: text,
textList
});
// generateQA();
jsonRes(res, {
data: { chunks, replaceText }
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { Chat, Model, Training, connectToDatabase } from '@/service/mongo';
import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo';
import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
import { TrainingStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat';
@@ -26,16 +26,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase();
// 删除模型
await Model.deleteOne({
_id: modelId,
userId
});
let requestQueue: any[] = [];
// 删除对应的聊天
await Chat.deleteMany({
modelId
});
requestQueue.push(
Chat.deleteMany({
modelId
})
);
// 删除数据集
requestQueue.push(
ModelData.deleteMany({
modelId
})
);
// 查看是否正在训练
const training: TrainingItemType | null = await Training.findOne({
@@ -56,9 +60,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
}
// 删除对应训练记录
await Training.deleteMany({
modelId
});
requestQueue.push(
Training.deleteMany({
modelId
})
);
// 删除模型
requestQueue.push(
Model.deleteOne({
_id: modelId,
userId
})
);
await requestQueue;
jsonRes(res);
} catch (err) {

View File

@@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
systemPrompt,
intro,
temperature,
service,
// service,
security
}
);

View File

@@ -119,6 +119,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
async (prompts: ChatSiteItemType) => {
const urlMap: Record<string, string> = {
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt',
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
};

View File

@@ -184,7 +184,7 @@ const DataList = () => {
>
</Button>
<Menu>
{/* <Menu>
<MenuButton as={Button} mr={2} size={'sm'} isLoading={isExporting}>
导出
</MenuButton>
@@ -200,7 +200,7 @@ const DataList = () => {
</MenuItem>
)}
</MenuList>
</Menu>
</Menu> */}
<Button
size={'sm'}

View File

@@ -0,0 +1,141 @@
import React, { useState, useCallback } from 'react';
import {
Box,
IconButton,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
Input,
Textarea
} from '@chakra-ui/react';
import { useForm, useFieldArray } from 'react-hook-form';
import { postModelDataInput } from '@/api/model';
import { useToast } from '@/hooks/useToast';
import { DeleteIcon } from '@chakra-ui/icons';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
type FormData = { text: string; q: { val: string }[] };
const InputDataModal = ({
onClose,
onSuccess,
modelId
}: {
onClose: () => void;
onSuccess: () => void;
modelId: string;
}) => {
const [importing, setImporting] = useState(false);
const { toast } = useToast();
const { register, handleSubmit, control } = useForm<FormData>({
defaultValues: {
text: '',
q: [{ val: '' }]
}
});
const {
fields: inputQ,
append: appendQ,
remove: removeQ
} = useFieldArray({
control,
name: 'q'
});
const sureImportData = useCallback(
async (e: FormData) => {
setImporting(true);
try {
await postModelDataInput({
modelId: modelId,
data: [
{
text: e.text,
q: e.q.map((item) => ({
id: nanoid(),
text: item.val
}))
}
]
});
toast({
title: '导入数据成功,需要一段时间训练',
status: 'success'
});
onClose();
onSuccess();
} catch (err) {
console.log(err);
}
setImporting(false);
},
[modelId, onClose, onSuccess, toast]
);
return (
<Modal isOpen={true} onClose={onClose}>
<ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} maxH={'80vh'} position={'relative'}>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<Box px={6} pb={2} overflowY={'auto'}>
<Box mb={2}>:</Box>
<Textarea
mb={4}
placeholder="知识点"
rows={3}
maxH={'200px'}
{...register(`text`, {
required: '知识点'
})}
/>
{inputQ.map((item, index) => (
<Box key={item.id} mb={5}>
<Box mb={2}>{index + 1}:</Box>
<Flex>
<Input
placeholder="问法"
{...register(`q.${index}.val`, {
required: '问法不能为空'
})}
></Input>
{inputQ.length > 1 && (
<IconButton
icon={<DeleteIcon />}
aria-label={'delete'}
colorScheme={'gray'}
variant={'unstyled'}
onClick={() => removeQ(index)}
/>
)}
</Flex>
</Box>
))}
</Box>
<Flex px={6} pt={2} pb={4}>
<Button alignSelf={'flex-start'} variant={'outline'} onClick={() => appendQ({ val: '' })}>
</Button>
<Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onClose}>
</Button>
<Button isLoading={importing} onClick={handleSubmit(sureImportData)}>
</Button>
</Flex>
</ModalContent>
</Modal>
);
};
export default InputDataModal;

View File

@@ -0,0 +1,202 @@
import React, { useCallback } from 'react';
import {
Box,
TableContainer,
Table,
Thead,
Tbody,
Tr,
Th,
Td,
IconButton,
Flex,
Button,
useDisclosure,
Textarea,
Menu,
MenuButton,
MenuList,
MenuItem
} from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema';
import { ModelDataSchema } from '@/types/mongoSchema';
import { ModelDataStatusMap } from '@/constants/model';
import { usePaging } from '@/hooks/usePaging';
import ScrollData from '@/components/ScrollData';
import { getModelDataList, delOneModelData, putModelDataById } from '@/api/model';
import { DeleteIcon, RepeatIcon } from '@chakra-ui/icons';
import { useToast } from '@/hooks/useToast';
import { useLoading } from '@/hooks/useLoading';
import dynamic from 'next/dynamic';
const InputModel = dynamic(() => import('./InputDataModal'));
const SelectModel = dynamic(() => import('./SelectFileModal'));
const ModelDataCard = ({ model }: { model: ModelSchema }) => {
const { toast } = useToast();
const { Loading } = useLoading();
const {
nextPage,
isLoadAll,
requesting,
data: modelDataList,
total,
setData,
getData
} = usePaging<ModelDataSchema>({
api: getModelDataList,
pageSize: 20,
params: {
modelId: model._id
}
});
const updateAnswer = useCallback(
async (dataId: string, text: string) => {
await putModelDataById({
dataId,
text
});
toast({
title: '修改回答成功',
status: 'success'
});
},
[toast]
);
const {
isOpen: isOpenInputModal,
onOpen: onOpenInputModal,
onClose: onCloseInputModal
} = useDisclosure();
const {
isOpen: isOpenSelectModal,
onOpen: onOpenSelectModal,
onClose: onCloseSelectModal
} = useDisclosure();
return (
<>
<Flex>
<Box fontWeight={'bold'} fontSize={'lg'} flex={1}>
: {total}{' '}
<Box as={'span'} fontSize={'sm'}>
</Box>
</Box>
<IconButton
icon={<RepeatIcon />}
aria-label={'refresh'}
variant={'outline'}
mr={4}
onClick={() => getData(1, true)}
/>
<Menu>
<MenuButton as={Button}></MenuButton>
<MenuList>
<MenuItem onClick={onOpenInputModal}></MenuItem>
<MenuItem onClick={onOpenSelectModal}></MenuItem>
</MenuList>
</Menu>
</Flex>
<ScrollData
h={'100%'}
px={6}
mt={3}
isLoadAll={isLoadAll}
requesting={requesting}
nextPage={nextPage}
position={'relative'}
>
<TableContainer mt={4}>
<Table variant={'simple'}>
<Thead>
<Tr>
<Th>Question</Th>
<Th>Text</Th>
<Th>Status</Th>
<Th></Th>
</Tr>
</Thead>
<Tbody>
{modelDataList.map((item) => (
<Tr key={item._id}>
<Td w={'350px'}>
{item.q.map((item, i) => (
<Box
key={item.id}
fontSize={'xs'}
w={'100%'}
whiteSpace={'pre-wrap'}
_notLast={{ mb: 1 }}
>
Q{i + 1}:{' '}
<Box as={'span'} userSelect={'all'}>
{item.text}
</Box>
</Box>
))}
</Td>
<Td minW={'200px'}>
<Textarea
w={'100%'}
h={'100%'}
defaultValue={item.text}
fontSize={'xs'}
resize={'both'}
onBlur={(e) => {
const oldVal = modelDataList.find((data) => item._id === data._id)?.text;
if (oldVal !== e.target.value) {
updateAnswer(item._id, e.target.value);
setData((state) =>
state.map((data) => ({
...data,
text: data._id === item._id ? e.target.value : data.text
}))
);
}
}}
></Textarea>
</Td>
<Td w={'100px'}>{ModelDataStatusMap[item.status]}</Td>
<Td>
<IconButton
icon={<DeleteIcon />}
variant={'outline'}
colorScheme={'gray'}
aria-label={'delete'}
size={'sm'}
onClick={async () => {
delOneModelData(item._id);
setData((state) => state.filter((data) => data._id !== item._id));
}}
/>
</Td>
</Tr>
))}
</Tbody>
</Table>
</TableContainer>
<Loading loading={requesting} fixed={false} />
</ScrollData>
{isOpenInputModal && (
<InputModel
modelId={model._id}
onClose={onCloseInputModal}
onSuccess={() => getData(1, true)}
/>
)}
{isOpenSelectModal && (
<SelectModel
modelId={model._id}
onClose={onCloseSelectModal}
onSuccess={() => getData(1, true)}
/>
)}
</>
);
};
export default ModelDataCard;

View File

@@ -11,13 +11,28 @@ import {
SliderFilledTrack,
SliderThumb,
SliderMark,
Tooltip
Tooltip,
Button
} from '@chakra-ui/react';
import { QuestionOutlineIcon } from '@chakra-ui/icons';
import type { ModelSchema } from '@/types/mongoSchema';
import { UseFormReturn } from 'react-hook-form';
import { modelList } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { useConfirm } from '@/hooks/useConfirm';
const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> }) => {
const ModelEditForm = ({
formHooks,
canTrain,
handleDelModel
}: {
formHooks: UseFormReturn<ModelSchema>;
canTrain: boolean;
handleDelModel: () => void;
}) => {
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认删除该模型?'
});
const { register, setValue, getValues } = formHooks;
const [refresh, setRefresh] = useState(false);
@@ -29,7 +44,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
</Flex>
<FormControl mt={4}>
<Flex alignItems={'center'}>
<Box flex={'0 0 50px'} w={0}>
<Box flex={'0 0 80px'} w={0}>
:
</Box>
<Input
@@ -39,7 +54,36 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
></Input>
</Flex>
</FormControl>
<FormControl mt={4}>
<Flex alignItems={'center'} mt={4}>
<Box flex={'0 0 80px'} w={0}>
:
</Box>
<Box>{getValues('service.modelName')}</Box>
</Flex>
<Flex alignItems={'center'} mt={4}>
<Box flex={'0 0 80px'} w={0}>
:
</Box>
<Box>
{formatPrice(
modelList.find((item) => item.model === getValues('service.modelName'))?.price || 0,
1000
)}
/1K tokens()
</Box>
</Flex>
<Flex mt={5} alignItems={'center'}>
<Box flex={'0 0 80px'}>:</Box>
<Button
colorScheme={'gray'}
variant={'outline'}
size={'sm'}
onClick={openConfirm(handleDelModel)}
>
</Button>
</Flex>
{/* <FormControl mt={4}>
<Box mb={1}>:</Box>
<Textarea
rows={5}
@@ -47,7 +91,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
{...register('intro')}
placeholder={'模型的介绍,仅做展示,不影响模型的效果'}
/>
</FormControl>
</FormControl> */}
</Card>
<Card p={4}>
<Box fontWeight={'bold'}></Box>
@@ -94,15 +138,24 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
</Flex>
</FormControl>
<Box mt={4}>
<Box mb={1}></Box>
<Textarea
rows={6}
maxLength={-1}
{...register('systemPrompt')}
placeholder={
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向'
}
/>
{canTrain ? (
<Box fontWeight={'bold'}>
prompt
使 tokens
</Box>
) : (
<>
<Box mb={1}></Box>
<Textarea
rows={6}
maxLength={-1}
{...register('systemPrompt')}
placeholder={
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向'
}
/>
</>
)}
</Box>
</Card>
{/* <Card p={4}>
@@ -202,6 +255,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
</Flex>
</FormControl>
</Card> */}
<ConfirmChild />
</>
);
};

View File

@@ -0,0 +1,155 @@
import React, { useState, useCallback } from 'react';
import {
Box,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody
} from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { useSelectFile } from '@/hooks/useSelectFile';
import { customAlphabet } from 'nanoid';
import { encode } from 'gpt-token-utils';
import { useConfirm } from '@/hooks/useConfirm';
import { readTxtContent, readPdfContent, readDocContent } from '@/utils/tools';
import { useMutation } from '@tanstack/react-query';
import { postModelDataFileText } from '@/api/model';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
const fileExtension = '.txt,.doc,.docx,.pdf,.md';
const SelectFileModal = ({
onClose,
onSuccess,
modelId
}: {
onClose: () => void;
onSuccess: () => void;
modelId: string;
}) => {
const [selecting, setSelecting] = useState(false);
const { toast } = useToast();
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
const [fileText, setFileText] = useState('');
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认导入该文件,需要一定时间进行拆解,该任务无法终止!'
});
const onSelectFile = useCallback(
async (e: File[]) => {
setSelecting(true);
try {
const fileTexts = (
await Promise.all(
e.map((file) => {
// @ts-ignore
const extension = file?.name?.split('.').pop().toLowerCase();
switch (extension) {
case 'txt':
case 'md':
return readTxtContent(file);
case 'pdf':
return readPdfContent(file);
case 'doc':
case 'docx':
return readDocContent(file);
default:
return '';
}
})
)
)
.join('\n')
.replace(/\n+/g, '\n');
setFileText(fileTexts);
console.log(encode(fileTexts));
} catch (error: any) {
console.log(error);
toast({
title: typeof error === 'string' ? error : '解析文件失败',
status: 'error'
});
}
setSelecting(false);
},
[setSelecting, toast]
);
const { mutate, isLoading } = useMutation({
mutationFn: async () => {
if (!fileText) return;
await postModelDataFileText(modelId, fileText);
toast({
title: '导入数据成功,需要一段拆解和训练',
status: 'success'
});
onClose();
onSuccess();
},
onError() {
toast({
title: '导入文件失败',
status: 'error'
});
}
});
return (
<Modal isOpen={true} onClose={onClose}>
<ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} position={'relative'}>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex
flexDirection={'column'}
p={2}
h={'100%'}
alignItems={'center'}
justifyContent={'center'}
fontSize={'sm'}
>
<Button isLoading={selecting} onClick={onOpen}>
</Button>
<Box mt={2}> {fileExtension} . </Box>
<Box mt={2}>
{fileText.length} {encode(fileText).length} tokens
</Box>
<Box
h={'300px'}
w={'100%'}
overflow={'auto'}
p={2}
backgroundColor={'blackAlpha.50'}
whiteSpace={'pre'}
fontSize={'xs'}
>
{fileText}
</Box>
</Flex>
</ModalBody>
<Flex px={6} pt={2} pb={4}>
<Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onClose}>
</Button>
<Button isLoading={isLoading} isDisabled={fileText === ''} onClick={openConfirm(mutate)}>
</Button>
</Flex>
</ModalContent>
<ConfirmChild />
<File onSelect={onSelectFile} />
</Modal>
);
};
export default SelectFileModal;

View File

@@ -1,37 +1,27 @@
import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router';
import {
getModelById,
delModelById,
postTrainModel,
putModelTrainingStatus,
putModelById
} from '@/api/model';
import { getModelById, delModelById, putModelTrainingStatus, putModelById } from '@/api/model';
import { getChatSiteId } from '@/api/chat';
import type { ModelSchema } from '@/types/mongoSchema';
import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { useConfirm } from '@/hooks/useConfirm';
import { useForm } from 'react-hook-form';
import { formatModelStatus, ModelStatusEnum, modelList, defaultModel } from '@/constants/model';
import { useGlobalStore } from '@/store/global';
import { useScreen } from '@/hooks/useScreen';
import ModelEditForm from './components/ModelEditForm';
import Icon from '@/components/Iconfont';
import { useQuery } from '@tanstack/react-query';
import dynamic from 'next/dynamic';
const Training = dynamic(() => import('./components/Training'));
const ModelDataCard = dynamic(() => import('./components/ModelDataCard'));
const ModelDetail = ({ modelId }: { modelId: string }) => {
const { toast } = useToast();
const router = useRouter();
const { isPc, media } = useScreen();
const { setLoading } = useGlobalStore();
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认删除该模型?'
});
const SelectFileDom = useRef<HTMLInputElement>(null);
// const SelectFileDom = useRef<HTMLInputElement>(null);
const [model, setModel] = useState<ModelSchema>(defaultModel);
const formHooks = useForm<ModelSchema>({
defaultValues: model
@@ -39,7 +29,7 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
const canTrain = useMemo(() => {
const openai = modelList.find((item) => item.model === model?.service.modelName);
return openai && openai.trainName;
return !!(openai && openai.trainName);
}, [model]);
/* 加载模型数据 */
@@ -91,34 +81,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
}, [setLoading, model, router]);
/* 上传数据集,触发微调 */
const startTraining = useCallback(
async (e: React.ChangeEvent<HTMLInputElement>) => {
if (!modelId || !e.target.files || e.target.files?.length === 0) return;
setLoading(true);
try {
const file = e.target.files[0];
const formData = new FormData();
formData.append('file', file);
await postTrainModel(modelId, formData);
// const startTraining = useCallback(
// async (e: React.ChangeEvent<HTMLInputElement>) => {
// if (!modelId || !e.target.files || e.target.files?.length === 0) return;
// setLoading(true);
// try {
// const file = e.target.files[0];
// const formData = new FormData();
// formData.append('file', file);
// await postTrainModel(modelId, formData);
toast({
title: '开始训练...',
status: 'success'
});
// toast({
// title: '开始训练...',
// status: 'success'
// });
// 重新获取模型
loadModel();
} catch (err: any) {
toast({
title: err?.message || '上传文件失败',
status: 'error'
});
console.log('error->', err);
}
setLoading(false);
},
[setLoading, loadModel, modelId, toast]
);
// // 重新获取模型
// loadModel();
// } catch (err: any) {
// toast({
// title: err?.message || '上传文件失败',
// status: 'error'
// });
// console.log('error->', err);
// }
// setLoading(false);
// },
// [setLoading, loadModel, modelId, toast]
// );
/* 点击更新模型状态 */
const handleClickUpdateStatus = useCallback(async () => {
@@ -250,87 +240,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
)}
</Card>
<Grid mt={5} gridTemplateColumns={media('1fr 1fr', '1fr')} gridGap={5}>
<ModelEditForm formHooks={formHooks} />
<ModelEditForm formHooks={formHooks} handleDelModel={handleDelModel} canTrain={canTrain} />
{canTrain && (
{/* {canTrain && (
<Card p={4}>
<Training model={model} />
</Card>
)} */}
{canTrain && model._id && (
<Card
p={4}
height={'700px'}
{...media(
{
gridColumnStart: 1,
gridColumnEnd: 3
},
{}
)}
>
<ModelDataCard model={model} />
</Card>
)}
<Card p={4}>
<Box fontWeight={'bold'} fontSize={'lg'}>
</Box>
<Flex mt={5} alignItems={'center'}>
<Box flex={'0 0 80px'}>:</Box>
<Button
size={'sm'}
onClick={() => {
SelectFileDom.current?.click();
}}
title={!canTrain ? '模型不支持微调' : ''}
isDisabled={!canTrain}
>
</Button>
<Flex
as={'a'}
href="/TrainingTemplate.jsonl"
download
ml={5}
cursor={'pointer'}
alignItems={'center'}
color={'blue.500'}
>
<Icon name={'icon-yunxiazai'} color={'#3182ce'} />
</Flex>
</Flex>
{/* 提示 */}
<Box mt={3} py={3} color={'blackAlpha.600'}>
<Box as={'li'} lineHeight={1.9}>
使openai key
</Box>
<Box as={'li'} lineHeight={1.9}>
使
<Box
as={'span'}
fontWeight={'bold'}
textDecoration={'underline'}
color={'blackAlpha.800'}
mx={2}
cursor={'pointer'}
onClick={() => router.push('/data/list')}
>
</Box>
</Box>
<Box as={'li'} lineHeight={1.9}>
prompt completion
</Box>
<Box as={'li'} lineHeight={1.9}>
prompt {'</s>'}
</Box>
<Box as={'li'} lineHeight={1.9}>
completion {'</s>'}
</Box>
</Box>
<Flex mt={5} alignItems={'center'}>
<Box flex={'0 0 80px'}>:</Box>
<Button colorScheme={'red'} size={'sm'} onClick={openConfirm(handleDelModel)}>
</Button>
</Flex>
</Card>
</Grid>
{/* 文件选择 */}
<Box position={'absolute'} w={0} h={0} overflow={'hidden'}>
{/* <Box position={'absolute'} w={0} h={0} overflow={'hidden'}>
<input ref={SelectFileDom} type="file" accept=".jsonl" onChange={startTraining} />
</Box>
<ConfirmChild />
</Box> */}
</>
);
};

View File

@@ -1,29 +1,26 @@
import { DataItem } from '@/service/mongo';
import { SplitData, ModelData } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
import type { ChatCompletionRequestMessage } from 'openai';
import { DataItemSchema } from '@/types/mongoSchema';
import { ChatModelNameEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export async function generateQA(next = false): Promise<any> {
if (process.env.NODE_ENV === 'development') return;
if (global.generatingQA && !next) return;
global.generatingQA = true;
const systemPrompt: ChatCompletionRequestMessage = {
role: 'system',
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
};
let dataItem: DataItemSchema | null = null;
try {
// 找出一个需要生成的 dataItem
dataItem = await DataItem.findOne({
status: { $ne: 0 },
times: { $gt: 0 },
type: 'QA'
const dataItem = await SplitData.findOne({
textList: { $exists: true, $ne: [] }
});
if (!dataItem) {
@@ -32,10 +29,13 @@ export async function generateQA(next = false): Promise<any> {
return;
}
// 更新状态为生成中
await DataItem.findByIdAndUpdate(dataItem._id, {
status: 2
});
// 弹出文本
await SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } });
const text = dataItem.textList[dataItem.textList.length - 1];
if (!text) {
throw new Error('无文本');
}
// 获取 openapi Key
let userApiKey, systemKey;
@@ -44,10 +44,10 @@ export async function generateQA(next = false): Promise<any> {
userApiKey = key.userApiKey;
systemKey = key.systemKey;
} catch (error) {
// 余额不够了, 把用户所有记录改成闲置
await DataItem.updateMany({
userId: dataItem.userId,
status: 0
// 余额不够了, 清空该记录
await SplitData.findByIdAndUpdate(dataItem._id, {
textList: [],
errorText: '余额不足,生成数据集任务终止'
});
throw new Error('获取 openai key 失败');
}
@@ -59,84 +59,71 @@ export async function generateQA(next = false): Promise<any> {
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 请求 chatgpt 获取回答
const response = await Promise.allSettled(
[0.2, 0.8].map(
(temperature) =>
chatAPI
.createChatCompletion(
{
model: ChatModelNameEnum.GPT35,
temperature: temperature,
n: 1,
messages: [
systemPrompt,
{
role: 'user',
content: dataItem?.text || ''
}
]
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => ({
rawContent: res?.data.choices[0].message?.content || '',
result: splitText(res?.data.choices[0].message?.content || '')
})) // 从 content 中提取 QA
)
);
// 过滤出成功的响应
const successResponse: {
rawContent: string;
result: { q: string; a: string }[];
}[] = response.filter((item) => item.status === 'fulfilled').map((item: any) => item.value);
const rawContents = successResponse.map((item) => item.rawContent);
const results = successResponse.map((item) => item.result).flat();
// 插入数据库,并修改状态
await DataItem.findByIdAndUpdate(dataItem._id, {
status: 0,
$push: {
rawResponse: {
$each: successResponse.map((item) => item.rawContent)
const response = await chatAPI
.createChatCompletion(
{
model: ChatModelNameEnum.GPT35,
temperature: 0.2,
n: 1,
messages: [
systemPrompt,
{
role: 'user',
content: text
}
]
},
result: {
$each: results
{
timeout: 120000,
httpsAgent
}
}
});
)
.then((res) => ({
rawContent: res?.data.choices[0].message?.content || '',
result: splitText(res?.data.choices[0].message?.content || '')
})); // 从 content 中提取 QA
// 插入 modelData 表,生成向量
await ModelData.insertMany(
response.result.map((item) => ({
modelId: dataItem.modelId,
userId: dataItem.userId,
text: item.a,
q: [
{
id: nanoid(),
text: item.q
}
],
status: 1
}))
);
console.log(
'生成QA成功time:',
`${(Date.now() - startTime) / 1000}s`,
'QA数量',
results.length
response.result.length
);
// 计费
pushSplitDataBill({
isPay: !userApiKey && results.length > 0,
isPay: !userApiKey && response.result.length > 0,
userId: dataItem.userId,
type: 'QA',
text: systemPrompt.content + dataItem.text + rawContents.join('')
text: systemPrompt.content + text + response.rawContent
});
} catch (error: any) {
console.log('error: 生成QA错误', dataItem?._id);
console.log('response:', error?.response);
if (dataItem?._id) {
await DataItem.findByIdAndUpdate(dataItem._id, {
status: dataItem.times > 0 ? 1 : 0, // 还有重试次数则可以继续进行
$inc: {
// 剩余尝试次数-1
times: -1
}
});
}
}
generateQA(true);
generateQA(true);
generateVector(true);
} catch (error: any) {
console.log(error);
console.log('生成QA错误:', error?.response);
setTimeout(() => {
generateQA(true);
}, 10000);
}
}
/**

View File

@@ -0,0 +1,88 @@
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent } from '@/service/utils/tools';
import { ModelData } from '../models/modelData';
import { connectRedis } from '../redis';
import { VecModelDataIndex } from '@/constants/redis';
export async function generateVector(next = false): Promise<any> {
if (global.generatingVector && !next) return;
global.generatingVector = true;
try {
const redis = await connectRedis();
// 找出一个需要生成的 dataItem
const dataItem = await ModelData.findOne({
status: { $ne: 0 }
});
if (!dataItem) {
console.log('没有需要生成 【向量】 的数据');
global.generatingVector = false;
return;
}
// 获取 openapi Key
const openAiKey = process.env.OPENAIKEY as string;
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(openAiKey);
const dataId = String(dataItem._id);
// 生成词向量
const response = await Promise.allSettled(
dataItem.q.map((item, i) =>
chatAPI
.createEmbedding(
{
model: 'text-embedding-ada-002',
input: item.text
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => res?.data?.data?.[0]?.embedding || [])
.then((vector) =>
redis.sendCommand([
'JSON.SET',
`${VecModelDataIndex}:${dataId}:${i}`,
'$',
JSON.stringify({
dataId,
modelId: String(dataItem.modelId),
vector
})
])
)
)
);
if (response.filter((item) => item.status === 'fulfilled').length === 0) {
throw new Error(JSON.stringify(response));
}
// 修改该数据状态
await ModelData.findByIdAndUpdate(dataItem._id, {
status: 0
});
console.log(`生成向量成功: ${dataItem._id}`);
setTimeout(() => {
generateVector(true);
}, 3000);
} catch (error: any) {
console.log(error);
console.log('error: 生成向量错误', error?.response?.data);
if (error?.response?.statusText === 'Too Many Requests') {
console.log('次数限制1分钟后尝试');
// 限制次数1分钟后再试
setTimeout(() => {
generateVector(true);
}, 60000);
}
}
}

View File

@@ -34,7 +34,7 @@ export const pushChatBill = async ({
// 计算价格
const unitPrice = modelItem?.price || 5;
const price = unitPrice * tokens.length;
console.log(`chat bill, price: ${formatPrice(price)}`);
console.log(`chat bill, unit price: ${unitPrice}, price: ${formatPrice(price)}`);
try {
// 插入 Bill 记录

View File

@@ -13,22 +13,23 @@ const ModelDataSchema = new Schema({
ref: 'user',
required: true
},
q: {
text: {
type: String,
required: true
},
a: {
type: String,
default: ''
q: {
type: [
{
id: String, // 对应redis的key
text: String
}
],
default: []
},
status: {
type: Number,
enum: [0, 1, 2],
enum: [0, 1], // 1 训练ing
default: 1
},
createTime: {
type: Date,
default: () => new Date()
}
});

View File

@@ -0,0 +1,31 @@
/* 模型的知识库 */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelSplitDataSchema as SplitDataType } from '@/types/mongoSchema';
const SplitDataSchema = new Schema({
userId: {
type: Schema.Types.ObjectId,
ref: 'user',
required: true
},
modelId: {
type: Schema.Types.ObjectId,
ref: 'model',
required: true
},
rawText: {
type: String,
required: true
},
textList: {
type: [String],
default: []
},
errorText: {
type: String,
default: ''
}
});
export const SplitData: MongoModel<SplitDataType> =
models['splitData'] || model('splitData', SplitDataSchema);

View File

@@ -1,6 +1,7 @@
import mongoose from 'mongoose';
import { generateQA } from './events/generateQA';
import { generateAbstract } from './events/generateAbstract';
import { generateVector } from './events/generateVector';
/**
* 连接 MongoDB 数据库
@@ -27,7 +28,8 @@ export async function connectToDatabase(): Promise<void> {
}
generateQA();
generateAbstract();
// generateAbstract();
generateVector();
}
export * from './models/authCode';
@@ -40,3 +42,4 @@ export * from './models/bill';
export * from './models/pay';
export * from './models/data';
export * from './models/dataItem';
export * from './models/splitData';

45
src/service/redis.ts Normal file
View File

@@ -0,0 +1,45 @@
import { createClient } from 'redis';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 10);
export const connectRedis = async () => {
// 断开了,重连
if (global.redisClient && !global.redisClient.isOpen) {
await global.redisClient.disconnect();
} else if (global.redisClient) {
// 没断开,不再连接
return global.redisClient;
}
try {
global.redisClient = createClient({
url: process.env.REDIS_URL
});
global.redisClient.on('error', (err) => {
console.log('Redis Client Error', err);
global.redisClient = null;
});
global.redisClient.on('end', () => {
global.redisClient = null;
});
global.redisClient.on('ready', () => {
console.log('redis connected');
});
await global.redisClient.connect();
// 0 - 测试库1 - 正式
await global.redisClient.select(0);
return global.redisClient;
} catch (error) {
console.log(error, '==');
global.redisClient = null;
return Promise.reject('redis 连接失败');
}
};
export const getKey = (prefix = '') => {
return `${prefix}:${nanoid()}`;
};

View File

@@ -119,3 +119,21 @@ export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) =>
return systemPrompt ? [systemPrompt, ...res] : res;
};
/* system 内容截断 */
export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
let splitText = '';
// 从前往前截取
for (let i = 0; i < prompts.length; i++) {
const prompt = prompts[i];
splitText += `${prompt}\n`;
const tokens = encode(splitText).length;
if (tokens >= maxTokens) {
break;
}
}
return splitText;
};

View File

@@ -1,9 +1,12 @@
import type { Mongoose } from 'mongoose';
import type { RedisClientType } from 'redis';
declare global {
var mongodb: Mongoose | string | null;
var redisClient: RedisClientType | null;
var generatingQA: boolean;
var generatingAbstract: boolean;
var generatingVector: boolean;
var QRCode: any;
interface Window {
['pdfjs-dist/build/pdf']: any;

View File

@@ -8,3 +8,12 @@ export interface ModelUpdateParams {
service: ModelSchema.service;
security: ModelSchema.security;
}
export interface ModelDataItemType {
id: string;
status: 0 | 1; // 1代表向量生成完毕
q: string; // 提问词
a: string; // 原文
modelId: string;
userId: string;
}

View File

@@ -51,12 +51,26 @@ export interface ModelPopulate extends ModelSchema {
userId: UserModelSchema;
}
export type ModelDataType = 0 | 1;
export interface ModelDataSchema {
_id: string;
q: string;
a: string;
status: 0 | 1 | 2;
createTime: Date;
modelId: string;
userId: string;
text: string;
q: {
id: string;
text: string;
}[];
status: ModelDataType;
}
export interface ModelSplitDataSchema {
_id: string;
userId: string;
modelId: string;
rawText: string;
errorText: string;
textList: string[];
}
export interface TrainingSchema {

6
src/types/redis.d.ts vendored Normal file
View File

@@ -0,0 +1,6 @@
export interface RedisModelDataItemType {
id: string;
vector: number[];
dataId: string;
modelId: string;
}

View File

@@ -124,3 +124,15 @@ export const readDocContent = (file: File) =>
reject('读取 doc 文件失败');
};
});
export const vectorToBuffer = (vector: number[]) => {
const float32Arr = new Float32Array(vector);
const myBuffer = new ArrayBuffer(float32Arr.length * Float32Array.BYTES_PER_ELEMENT);
const myView = new DataView(myBuffer);
for (let i = 0; i < float32Arr.length; i++) {
myView.setFloat32(i * Float32Array.BYTES_PER_ELEMENT, float32Arr[i], true);
}
return Buffer.from(myBuffer);
};