perf: model framework

archer committed on 2023-04-29 15:55:47 +08:00
parent cd9acab938
commit 78762498eb
30 changed files with 649 additions and 757 deletions

View File

@@ -12,8 +12,7 @@ export const getMyModels = () => GET<ModelSchema[]>('/model/list');
/**
* 创建一个模型
*/
export const postCreateModel = (data: { name: string; serviceModelName: string }) =>
POST<ModelSchema>('/model/create', data);
export const postCreateModel = (data: { name: string }) => POST<string>('/model/create', data);
/**
* 根据 ID 删除模型

View File

@@ -7,7 +7,6 @@ export type InitChatResponse = {
name: string;
avatar: string;
intro: string;
chatModel: ModelSchema.service.chatModel; // 对话模型名
modelName: ModelSchema.service.modelName; // 底层模型
chatModel: ModelSchema['chat']['chatModel']; // 对话模型名
history: ChatItemType[];
};

View File

@@ -1,50 +1,32 @@
import type { ModelSchema } from '@/types/mongoSchema';
export const embeddingModel = 'text-embedding-ada-002';
export enum ChatModelEnum {
'GPT35' = 'gpt-3.5-turbo',
'GPT4' = 'gpt-4',
'GPT432k' = 'gpt-4-32k'
}
export enum ModelNameEnum {
GPT35 = 'gpt-3.5-turbo',
VECTOR_GPT = 'VECTOR_GPT'
}
export const Model2ChatModelMap: Record<`${ModelNameEnum}`, `${ChatModelEnum}`> = {
[ModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo'
export const ChatModelMap = {
// ui name
[ChatModelEnum.GPT35]: 'ChatGpt',
[ChatModelEnum.GPT4]: 'Gpt4',
[ChatModelEnum.GPT432k]: 'Gpt4-32k'
};
export type ModelConstantsData = {
icon: 'model' | 'dbModel';
name: string;
model: `${ModelNameEnum}`;
trainName: string; // 空字符串代表不能训练
export type ChatModelConstantType = {
chatModel: `${ChatModelEnum}`;
contextMaxToken: number;
maxTemperature: number;
price: number; // 多少钱 / 1token单位: 0.00001元
};
export const modelList: ModelConstantsData[] = [
export const modelList: ChatModelConstantType[] = [
{
icon: 'model',
name: 'chatGPT',
model: ModelNameEnum.GPT35,
trainName: '',
chatModel: ChatModelEnum.GPT35,
contextMaxToken: 4096,
maxTemperature: 1.5,
price: 3
},
{
icon: 'dbModel',
name: '知识库',
model: ModelNameEnum.VECTOR_GPT,
trainName: 'vector',
contextMaxToken: 4096,
maxTemperature: 1,
price: 3
}
];
@@ -115,14 +97,16 @@ export const ModelVectorSearchModeMap: Record<
export const defaultModel: ModelSchema = {
_id: 'modelId',
userId: 'userId',
name: 'modelName',
name: '模型名称',
avatar: '/icon/logo.png',
status: ModelStatusEnum.pending,
updateTime: Date.now(),
systemPrompt: '',
temperature: 5,
search: {
mode: ModelVectorSearchModeEnum.hightSimilarity
chat: {
useKb: false,
searchMode: ModelVectorSearchModeEnum.hightSimilarity,
systemPrompt: '',
temperature: 0,
chatModel: ChatModelEnum.GPT35
},
share: {
isShare: false,
@@ -130,10 +114,6 @@ export const defaultModel: ModelSchema = {
intro: '',
collection: 0
},
service: {
chatModel: ModelNameEnum.GPT35,
modelName: ModelNameEnum.GPT35
},
security: {
domain: ['*'],
contextMaxLen: 1,

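A minimal consumption sketch (not part of the commit) of the reworked constants: per-model settings are now looked up by chat.chatModel instead of service.modelName, and the UI label comes from the new ChatModelMap. The concrete values in the comments are taken from the modelList entry above.

import { modelList, ChatModelMap, ChatModelEnum } from '@/constants/model';

// Resolve constants for a model's configured chat model (the same lookup the handlers below use).
const chatModel = ChatModelEnum.GPT35;
const modelConstantsData = modelList.find((item) => item.chatModel === chatModel);
if (!modelConstantsData) {
  throw new Error('model is undefined');
}

console.log(ChatModelMap[chatModel]); // UI name: 'ChatGpt'
console.log(modelConstantsData.contextMaxToken); // 4096
console.log(modelConstantsData.price); // 3 (per token, in units of 0.00001 CNY)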
View File

@@ -1,13 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -46,7 +47,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
authorization
});
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) {
throw new Error('模型加载异常');
}
@@ -54,31 +55,84 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 读取对话内容
const prompts = [...content, prompt];
// 如果有系统提示词,自动插入
if (model.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
value: model.systemPrompt
// 使用了知识库搜索
if (model.chat.useKb) {
const { systemPrompts } = await searchKb_openai({
apiKey: userApiKey || systemKey,
isPay: !userApiKey,
text: prompt.value,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
modelId,
userId
});
// filter system prompt
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
) {
return res.send('对不起,你的问题不在知识库中。');
}
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.noContext
) {
prompts.unshift({
obj: 'SYSTEM',
value: model.chat.systemPrompt
});
} else {
// 有匹配情况下system 添加知识库内容。
// 系统提示词过滤,最多 2500 tokens
const filterSystemPrompt = systemPromptFilter({
model: model.chat.chatModel,
prompts: systemPrompts,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `
${model.chat.systemPrompt}
${
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
? `不回答知识库外的内容.`
: ''
}
知识库内容为: ${filterSystemPrompt}'
`
});
}
} else {
// 没有用知识库搜索,仅用系统提示词
if (model.chat.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
value: model.chat.systemPrompt
});
}
}
// 控制 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// console.log(filterPrompts);
// 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature,
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
@@ -105,7 +159,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 只有使用平台的 key 才计费
pushChatBill({
isPay: !userApiKey,
modelName: model.service.modelName,
chatModel: model.chat.chatModel,
userId,
chatId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })

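A small worked example (values assumed, names as in the handler above) of the new temperature mapping: the 0-10 value stored in model.chat.temperature is scaled onto the model's maxTemperature, formatted to two decimals, then coerced back to a number before the OpenAI call.

// maxTemperature would come from modelConstantsData, the slider value from model.chat.temperature.
const maxTemperature = 1.5; // e.g. the gpt-3.5-turbo entry in modelList
const chatTemperature = 5; // slider value in [0, 10]
const temperature = (maxTemperature * (chatTemperature / 10)).toFixed(2); // '0.75' (a string)
const requestTemperature = Number(temperature) || 0; // 0.75, the value passed to createChatCompletion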
View File

@@ -59,8 +59,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
name: model.name,
avatar: model.avatar,
intro: model.share.intro,
modelName: model.service.modelName,
chatModel: model.service.chatModel,
chatModel: model.chat.chatModel,
history
}
});

View File

@@ -1,189 +0,0 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { authChat } from '@/service/utils/auth';
import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import {
modelList,
ModelVectorSearchModeMap,
ModelVectorSearchModeEnum,
ModelDataStatusEnum
} from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import dayjs from 'dayjs';
import { PgClient } from '@/service/pg';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
const stream = new PassThrough();
stream.on('error', () => {
console.log('error: ', 'stream error');
stream.destroy();
});
res.on('close', () => {
stream.destroy();
});
res.on('error', () => {
console.log('error: ', 'request error');
stream.destroy();
});
try {
const { modelId, chatId, prompt } = req.body as {
modelId: string;
chatId: '' | string;
prompt: ChatItemType;
};
const { authorization } = req.headers;
if (!modelId || !prompt) {
throw new Error('缺少参数');
}
await connectToDatabase();
let startTime = Date.now();
const { model, content, userApiKey, systemKey, userId } = await authChat({
modelId,
chatId,
authorization
});
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) {
throw new Error('模型加载异常');
}
// 读取对话内容
const prompts = [...content, prompt];
// 获取提示词的向量
const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({
isPay: !userApiKey,
apiKey: userApiKey || systemKey,
userId,
text: prompt.value
});
// 相似度搜素
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
where: [
['status', ModelDataStatusEnum.ready],
'AND',
['model_id', model._id],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: 20
});
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
/* 高相似度+退出,无法匹配时直接退出 */
if (
formatRedisPrompt.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
) {
return res.send('对不起,你的问题不在知识库中。');
}
/* 高相似度+无上下文,不添加额外知识 */
if (
formatRedisPrompt.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.noContext
) {
prompts.unshift({
obj: 'SYSTEM',
value: model.systemPrompt
});
} else {
// 有匹配情况下system 添加知识库内容。
// 系统提示词过滤,最多 2500 tokens
const systemPrompt = systemPromptFilter({
model: model.service.chatModel,
prompts: formatRedisPrompt,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `
${model.systemPrompt}
${
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
? `你只能从知识库选择内容回答.不在知识库内容拒绝回复`
: ''
}
知识库内容为: 当前时间为${dayjs().format('YYYY/MM/DD HH:mm:ss')}\n${systemPrompt}'
`
});
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: true,
stop: ['.!?。']
},
{
timeout: 40000,
responseType: 'stream',
...axiosConfig()
}
);
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
step = 1;
const { responseContent } = await gpt35StreamResponse({
res,
stream,
chatResponse
});
// 只有使用平台的 key 才计费
pushChatBill({
isPay: !userApiKey,
modelName: model.service.modelName,
userId,
chatId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
// jsonRes(res);
} catch (err: any) {
if (step === 1) {
// 直接结束流
console.log('error结束');
stream.destroy();
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}
}

View File

@@ -3,14 +3,13 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { ModelStatusEnum, modelList, ModelNameEnum, Model2ChatModelMap } from '@/constants/model';
import { ModelStatusEnum } from '@/constants/model';
import { Model } from '@/service/models/model';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { name, serviceModelName } = req.body as {
const { name } = req.body as {
name: string;
serviceModelName: `${ModelNameEnum}`;
};
const { authorization } = req.headers;
@@ -18,45 +17,32 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作');
}
if (!name || !serviceModelName) {
if (!name) {
throw new Error('缺少参数');
}
// 凭证校验
const userId = await authToken(authorization);
const modelItem = modelList.find((item) => item.model === serviceModelName);
if (!modelItem) {
throw new Error('模型不存在');
}
await connectToDatabase();
// 上限校验
const authCount = await Model.countDocuments({
userId
});
if (authCount >= 20) {
throw new Error('上限 20 个模型');
if (authCount >= 30) {
throw new Error('上限 30 个模型');
}
// 创建模型
const response = await Model.create({
name,
userId,
status: ModelStatusEnum.running,
service: {
chatModel: Model2ChatModelMap[modelItem.model], // 聊天时用的模型
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
}
status: ModelStatusEnum.running
});
// 根据 id 获取模型信息
const model = await Model.findById(response._id);
jsonRes(res, {
data: model
data: response._id
});
} catch (err) {
jsonRes(res, {

View File

@@ -9,8 +9,7 @@ import { authModel } from '@/service/utils/auth';
/* 获取我的模型 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { name, avatar, search, share, service, security, systemPrompt, temperature } =
req.body as ModelUpdateParams;
const { name, avatar, chat, share, security } = req.body as ModelUpdateParams;
const { modelId } = req.query as { modelId: string };
const { authorization } = req.headers;
@@ -18,7 +17,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作');
}
if (!name || !service || !security || !modelId) {
if (!name || !chat || !security || !modelId) {
throw new Error('参数错误');
}
@@ -41,12 +40,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
{
name,
avatar,
systemPrompt,
temperature,
chat,
'share.isShare': share.isShare,
'share.isShareDetail': share.isShareDetail,
'share.intro': share.intro,
search,
security
}
);

View File

@@ -0,0 +1,202 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authOpenApiKey, authModel } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
const stream = new PassThrough();
stream.on('error', () => {
console.log('error: ', 'stream error');
stream.destroy();
});
res.on('close', () => {
stream.destroy();
});
res.on('error', () => {
console.log('error: ', 'request error');
stream.destroy();
});
try {
const {
prompts,
modelId,
isStream = true
} = req.body as {
prompts: ChatItemType[];
modelId: string;
isStream: boolean;
};
if (!prompts || !modelId) {
throw new Error('缺少参数');
}
if (!Array.isArray(prompts)) {
throw new Error('prompts is not array');
}
if (prompts.length > 30 || prompts.length === 0) {
throw new Error('prompts length range 1-30');
}
await connectToDatabase();
let startTime = Date.now();
/* 凭证校验 */
const { apiKey, userId } = await authOpenApiKey(req);
const { model } = await authModel({
userId,
modelId
});
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) {
throw new Error('模型加载异常');
}
// 使用了知识库搜索
if (model.chat.useKb) {
const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22;
const { systemPrompts } = await searchKb_openai({
apiKey,
isPay: true,
text: prompts[prompts.length - 1].value,
similarity,
modelId,
userId
});
// filter system prompt
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
) {
return jsonRes(res, {
code: 500,
message: '对不起,你的问题不在知识库中。',
data: '对不起,你的问题不在知识库中。'
});
}
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.noContext
) {
prompts.unshift({
obj: 'SYSTEM',
value: model.chat.systemPrompt
});
} else {
// 有匹配情况下system 添加知识库内容。
// 系统提示词过滤,最多 2500 tokens
const filterSystemPrompt = systemPromptFilter({
model: model.chat.chatModel,
prompts: systemPrompts,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `
${model.chat.systemPrompt}
${
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
? `不回答知识库外的内容.`
: ''
}
知识库内容为: ${filterSystemPrompt}'
`
});
}
} else {
// 没有用知识库搜索,仅用系统提示词
if (model.chat.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
value: model.chat.systemPrompt
});
}
}
// 控制总 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// 计算温度
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// console.log(filterPrompts);
// 获取 chatAPI
const chatAPI = getOpenAIApi(apiKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: isStream,
stop: ['.!?。']
},
{
timeout: 180000,
responseType: isStream ? 'stream' : 'json',
...axiosConfig()
}
);
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
let responseContent = '';
if (isStream) {
step = 1;
const streamResponse = await gpt35StreamResponse({
res,
stream,
chatResponse
});
responseContent = streamResponse.responseContent;
} else {
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
jsonRes(res, {
data: responseContent
});
}
// 只有使用平台的 key 才计费
pushChatBill({
isPay: true,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});
} catch (err: any) {
if (step === 1) {
// 直接结束流
console.log('error结束');
stream.destroy();
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}
}

View File

@@ -1,7 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter, authOpenApiKey } from '@/service/utils/tools';
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
@@ -60,37 +60,38 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('无权使用该模型');
}
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) {
throw new Error('模型加载异常');
}
// 如果有系统提示词,自动插入
if (model.systemPrompt) {
if (model.chat.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
value: model.systemPrompt
value: model.chat.systemPrompt
});
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// 获取 chatAPI
const chatAPI = getOpenAIApi(apiKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature,
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
@@ -126,7 +127,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 只有使用平台的 key 才计费
pushChatBill({
isPay: true,
modelName: model.service.modelName,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});

View File

@@ -1,20 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/auth';
import { authOpenApiKey } from '@/service/utils/tools';
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import {
ModelNameEnum,
modelList,
ModelVectorSearchModeMap,
ChatModelEnum
} from '@/constants/model';
import { modelList, ModelVectorSearchModeMap, ChatModelEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import { PgClient } from '@/service/pg';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -59,10 +53,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('找不到模型');
}
const modelConstantsData = modelList.find((item) => item.model === ModelNameEnum.VECTOR_GPT);
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) {
throw new Error('模型已下架');
throw new Error('model is undefined');
}
console.log('laf gpt start');
// 获取 chatAPI
@@ -132,62 +127,48 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompt.value += ` ${promptResolve}`;
console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`);
// 获取提示词的向量
const { vector: promptVector } = await openaiCreateEmbedding({
isPay: true,
apiKey,
userId,
text: prompt.value
});
// 读取对话内容
const prompts = [prompt];
// 相似度搜索
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
where: [
['model_id', model._id],
'AND',
['user_id', userId],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
limit: 30
// 获取向量匹配到的提示词
const { systemPrompts } = await searchKb_openai({
isPay: true,
apiKey,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
text: prompt.value,
modelId,
userId
});
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// system 筛选,最多 2500 tokens
const systemPrompt = systemPromptFilter({
model: model.service.chatModel,
prompts: formatRedisPrompt,
const filterSystemPrompt = systemPromptFilter({
model: model.chat.chatModel,
prompts: systemPrompts,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}`
value: `${model.chat.systemPrompt} 知识库是最新的,下面是知识库内容:${filterSystemPrompt}`
});
// 控制上下文 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature,
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
@@ -223,7 +204,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
pushChatBill({
isPay: true,
modelName: model.service.modelName,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});

View File

@@ -1,24 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import {
axiosConfig,
systemPromptFilter,
authOpenApiKey,
openaiChatFilter
} from '@/service/utils/tools';
import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import {
modelList,
ModelVectorSearchModeMap,
ModelVectorSearchModeEnum,
ModelDataStatusEnum
} from '@/constants/model';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import dayjs from 'dayjs';
import { PgClient } from '@/service/pg';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -72,96 +62,86 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('无权使用该模型');
}
const modelConstantsData = modelList.find((item) => item.model === model?.service?.modelName);
const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) {
throw new Error('模型初始化异常');
}
// 获取提示词的向量
const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({
// 获取向量匹配到的提示词
const { systemPrompts } = await searchKb_openai({
isPay: true,
apiKey,
userId,
text: prompts[prompts.length - 1].value // 取最后一个
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
text: prompts[prompts.length - 1].value,
modelId,
userId
});
// 相似度搜素
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
where: [
['status', ModelDataStatusEnum.ready],
'AND',
['model_id', model._id],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: 20
});
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// system 合并
if (prompts[0].obj === 'SYSTEM') {
formatRedisPrompt.unshift(prompts.shift()?.value || '');
systemPrompts.unshift(prompts.shift()?.value || '');
}
/* 高相似度+退出,无法匹配时直接退出 */
if (
formatRedisPrompt.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
) {
return res.send('对不起,你的问题不在知识库中。');
return jsonRes(res, {
code: 500,
message: '对不起,你的问题不在知识库中。',
data: '对不起,你的问题不在知识库中。'
});
}
/* 高相似度+无上下文,不添加额外知识 */
if (
formatRedisPrompt.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.noContext
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.noContext
) {
prompts.unshift({
obj: 'SYSTEM',
value: model.systemPrompt
value: model.chat.systemPrompt
});
} else {
// 有匹配或者低匹配度模式情况下,添加知识库内容。
// 系统提示词过滤,最多 2500 tokens
const systemPrompt = systemPromptFilter({
model: model.service.chatModel,
prompts: formatRedisPrompt,
model: model.chat.chatModel,
prompts: systemPrompts,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `
${model.systemPrompt}
${model.chat.systemPrompt}
${
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
? `你只能从知识库选择内容回答.不在知识库内容拒绝回复`
: ''
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? `不回答知识库外的内容.` : ''
}
知识库内容为: 当前时间为${dayjs().format('YYYY/MM/DD HH:mm:ss')}\n${systemPrompt}'
知识库内容为: ${systemPrompt}'
`
});
}
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({
model: model.service.chatModel,
model: model.chat.chatModel,
prompts,
maxTokens: modelConstantsData.contextMaxToken - 500
});
// console.log(filterPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
const chatAPI = getOpenAIApi(apiKey);
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature,
model: model.chat.chatModel,
temperature: Number(temperature) || 0,
messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
@@ -196,7 +176,7 @@ ${
pushChatBill({
isPay: true,
modelName: model.service.modelName,
chatModel: model.chat.chatModel,
userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
});

View File

@@ -52,7 +52,7 @@ const SlideBar = ({
const myModelList = myModels.map((item) => ({
id: item._id,
name: item.name,
icon: modelList.find((model) => model.model === item?.service?.modelName)?.icon || 'model'
icon: 'model' as any
}));
const collectionList = collectionModels
.map((item) => ({

View File

@@ -1,6 +1,5 @@
import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router';
import Image from 'next/image';
import { getInitChatSiteInfo, delChatRecordByIndex, postSaveChat } from '@/api/chat';
import type { InitChatResponse } from '@/api/response/chat';
import type { ChatItemType } from '@/types/chat';
@@ -16,12 +15,13 @@ import {
Menu,
MenuButton,
MenuList,
MenuItem
MenuItem,
Image
} from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { useScreen } from '@/hooks/useScreen';
import { useQuery } from '@tanstack/react-query';
import { ModelNameEnum } from '@/constants/model';
import { ChatModelEnum } from '@/constants/model';
import dynamic from 'next/dynamic';
import { useGlobalStore } from '@/store/global';
import { useCopyData } from '@/utils/tools';
@@ -65,8 +65,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
name: '',
avatar: '/icon/logo.png',
intro: '',
chatModel: '',
modelName: '',
chatModel: ChatModelEnum.GPT35,
history: []
}); // 聊天框整体数据
@@ -193,13 +192,6 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
// gpt 对话
const gptChatPrompt = useCallback(
async (prompts: ChatSiteItemType) => {
const urlMap: Record<string, string> = {
[ModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt'
};
if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型');
// create abort obj
const abortSignal = new AbortController();
controller.current = abortSignal;
@@ -212,7 +204,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
// 流请求,获取数据
const responseText = await streamFetch({
url: urlMap[chatData.modelName],
url: '/api/chat/chat',
data: {
prompt,
chatId,
@@ -278,7 +270,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
})
}));
},
[chatData.modelName, chatId, generatingMessage, modelId, router, toast]
[chatId, generatingMessage, modelId, router, toast]
);
/**
@@ -393,7 +385,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
// 更新流中断对象
useEffect(() => {
return () => {
// eslint-disable-next-line react-hooks/exhaustive-deps
isResetPage.current = true;
controller.current?.abort();
};
}, []);
@@ -476,8 +468,9 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
: chatData.avatar || '/icon/logo.png'
}
alt="avatar"
width={media(30, 20)}
height={media(30, 20)}
w={['20px', '30px']}
maxH={'50px'}
objectFit={'contain'}
/>
</MenuButton>
<MenuList fontSize={'sm'}>

View File

@@ -45,9 +45,10 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean
const [searchText, setSearchText] = useState('');
const tdStyles = useRef<BoxProps>({
fontSize: 'xs',
minW: '150px',
maxW: '500px',
whiteSpace: 'pre-wrap',
maxH: '250px',
whiteSpace: 'pre-wrap',
overflowY: 'auto'
});
const {
@@ -132,7 +133,7 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean
<>
<Flex>
<Box fontWeight={'bold'} fontSize={'lg'} flex={1} mr={2}>
: {total}
: {total}
</Box>
{isOwner && (
<>

View File

@@ -21,7 +21,7 @@ import {
import { QuestionOutlineIcon } from '@chakra-ui/icons';
import type { ModelSchema } from '@/types/mongoSchema';
import { UseFormReturn } from 'react-hook-form';
import { modelList, ModelVectorSearchModeMap } from '@/constants/model';
import { ChatModelMap, modelList, ModelVectorSearchModeMap } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { useConfirm } from '@/hooks/useConfirm';
import { useSelectFile } from '@/hooks/useSelectFile';
@@ -30,12 +30,10 @@ import { fileToBase64 } from '@/utils/file';
const ModelEditForm = ({
formHooks,
canTrain,
isOwner,
handleDelModel
}: {
formHooks: UseFormReturn<ModelSchema>;
canTrain: boolean;
isOwner: boolean;
handleDelModel: () => void;
}) => {
@@ -73,6 +71,12 @@ const ModelEditForm = ({
<>
<Card p={4}>
<Box fontWeight={'bold'}></Box>
<Flex alignItems={'center'} mt={4}>
<Box flex={'0 0 80px'} w={0}>
modelId:
</Box>
<Box>{getValues('_id')}</Box>
</Flex>
<Flex mt={4} alignItems={'center'}>
<Box flex={'0 0 80px'} w={0}>
:
@@ -101,17 +105,12 @@ const ModelEditForm = ({
></Input>
</Flex>
</FormControl>
<Flex alignItems={'center'} mt={5}>
<Box flex={'0 0 80px'} w={0}>
modelId:
:
</Box>
<Box>{getValues('_id')}</Box>
</Flex>
<Flex alignItems={'center'} mt={5}>
<Box flex={'0 0 80px'} w={0}>
:
</Box>
<Box>{modelList.find((item) => item.model === getValues('service.modelName'))?.name}</Box>
<Box>{ChatModelMap[getValues('chat.chatModel')]}</Box>
</Flex>
<Flex alignItems={'center'} mt={5}>
<Box flex={'0 0 80px'} w={0}>
@@ -119,7 +118,7 @@ const ModelEditForm = ({
</Box>
<Box>
{formatPrice(
modelList.find((item) => item.model === getValues('service.modelName'))?.price || 0,
modelList.find((item) => item.chatModel === getValues('chat.chatModel'))?.price || 0,
1000
)}
/1K tokens()
@@ -163,15 +162,15 @@ const ModelEditForm = ({
min={0}
max={10}
step={1}
value={getValues('temperature')}
value={getValues('chat.temperature')}
isDisabled={!isOwner}
onChange={(e) => {
setValue('temperature', e);
setValue('chat.temperature', e);
setRefresh(!refresh);
}}
>
<SliderMark
value={getValues('temperature')}
value={getValues('chat.temperature')}
textAlign="center"
bg="blue.500"
color="white"
@@ -181,7 +180,7 @@ const ModelEditForm = ({
fontSize={'xs'}
transform={'translate(-50%, -200%)'}
>
{getValues('temperature')}
{getValues('chat.temperature')}
</SliderMark>
<SliderTrack>
<SliderFilledTrack />
@@ -190,35 +189,42 @@ const ModelEditForm = ({
</Slider>
</Flex>
</FormControl>
{canTrain && (
<FormControl mt={4}>
<Flex alignItems={'center'}>
<Box flex={'0 0 70px'}></Box>
<Select
isDisabled={!isOwner}
{...register('search.mode', { required: '搜索模式不能为空' })}
>
{Object.entries(ModelVectorSearchModeMap).map(([key, { text }]) => (
<option key={key} value={key}>
{text}
</option>
))}
</Select>
</Flex>
</FormControl>
<Flex mt={4} alignItems={'center'}>
<Box mr={4}></Box>
<Switch
isChecked={getValues('chat.useKb')}
onChange={() => {
setValue('chat.useKb', !getValues('chat.useKb'));
setRefresh(!refresh);
}}
/>
</Flex>
{getValues('chat.useKb') && (
<Flex mt={4} alignItems={'center'}>
<Box mr={4} whiteSpace={'nowrap'}>
&emsp;
</Box>
<Select
isDisabled={!isOwner}
{...register('chat.searchMode', { required: '搜索模式不能为空' })}
>
{Object.entries(ModelVectorSearchModeMap).map(([key, { text }]) => (
<option key={key} value={key}>
{text}
</option>
))}
</Select>
</Flex>
)}
<Box mt={4}>
<Box mb={1}></Box>
<Textarea
rows={8}
maxLength={-1}
isDisabled={!isOwner}
placeholder={
canTrain
? '训练的模型会根据知识库内容,生成一部分系统提示词,因此在对话时需要消耗更多的 tokens。你可以增加提示词让效果更符合预期。例如: \n1. 请根据知识库内容回答用户问题。\n2. 知识库是电影《铃芽之旅》的内容,根据知识库内容回答。无关问题,拒绝回复!'
: '模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n注意改功能会影响对话的整体朝向'
}
{...register('systemPrompt')}
placeholder={'模型默认的 prompt 词,通过调整该内容,可以引导模型聊天方向。'}
{...register('chat.systemPrompt')}
/>
</Box>
</Card>

View File

@@ -27,11 +27,6 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
defaultValues: model
});
const canTrain = useMemo(() => {
const openai = modelList.find((item) => item.model === model?.service.modelName);
return !!(openai && openai.trainName);
}, [model]);
const isOwner = useMemo(() => model.userId === userInfo?._id, [model.userId, userInfo?._id]);
/* 加载模型数据 */
@@ -86,11 +81,8 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
await putModelById(data._id, {
name: data.name,
avatar: data.avatar || '/icon/logo.png',
systemPrompt: data.systemPrompt,
temperature: data.temperature,
search: data.search,
chat: data.chat,
share: data.share,
service: data.service,
security: data.security
});
toast({
@@ -171,11 +163,15 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
</Tag>
</Flex>
<Box mt={4} textAlign={'right'}>
<Button variant={'outline'} onClick={handlePreviewChat}>
<Button variant={'outline'} size={'sm'} onClick={handlePreviewChat}>
</Button>
{isOwner && (
<Button ml={4} onClick={formHooks.handleSubmit(saveSubmitSuccess, saveSubmitError)}>
<Button
ml={4}
size={'sm'}
onClick={formHooks.handleSubmit(saveSubmitSuccess, saveSubmitError)}
>
</Button>
)}
@@ -184,16 +180,11 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
)}
</Card>
<Grid mt={5} gridTemplateColumns={['1fr', '1fr 1fr']} gridGap={5}>
<ModelEditForm
formHooks={formHooks}
handleDelModel={handleDelModel}
canTrain={canTrain}
isOwner={isOwner}
/>
<ModelEditForm formHooks={formHooks} handleDelModel={handleDelModel} isOwner={isOwner} />
{canTrain && !!model._id && (
{modelId && (
<Card p={4} gridColumnStart={[1, 1]} gridColumnEnd={[2, 3]}>
<ModelDataCard modelId={model._id} isOwner={isOwner} />
<ModelDataCard modelId={modelId} isOwner={isOwner} />
</Card>
)}
</Grid>

View File

@@ -1,138 +0,0 @@
import React, { Dispatch, useState, useCallback, useMemo } from 'react';
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalFooter,
ModalBody,
ModalCloseButton,
FormControl,
FormErrorMessage,
Button,
useToast,
Input,
Select,
Box
} from '@chakra-ui/react';
import { useForm } from 'react-hook-form';
import { postCreateModel } from '@/api/model';
import type { ModelSchema } from '@/types/mongoSchema';
import { modelList } from '@/constants/model';
import { formatPrice } from '@/utils/user';
interface CreateFormType {
name: string;
serviceModelName: string;
}
const CreateModel = ({
setCreateModelOpen,
onSuccess
}: {
setCreateModelOpen: Dispatch<boolean>;
onSuccess: Dispatch<ModelSchema>;
}) => {
const [requesting, setRequesting] = useState(false);
const [refresh, setRefresh] = useState(false);
const toast = useToast({
duration: 2000,
position: 'top'
});
const {
getValues,
register,
handleSubmit,
formState: { errors }
} = useForm<CreateFormType>({
defaultValues: {
serviceModelName: modelList[0].model
}
});
const handleCreateModel = useCallback(
async (data: CreateFormType) => {
setRequesting(true);
try {
const res = await postCreateModel(data);
toast({
title: '创建成功',
status: 'success'
});
onSuccess(res);
setCreateModelOpen(false);
} catch (err: any) {
toast({
title: typeof err === 'string' ? err : err.message || '出现了意外',
status: 'error'
});
}
setRequesting(false);
},
[onSuccess, setCreateModelOpen, toast]
);
return (
<>
<Modal isOpen={true} onClose={() => setCreateModelOpen(false)}>
<ModalOverlay />
<ModalContent>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<ModalBody>
<FormControl mb={8} isInvalid={!!errors.name}>
<Input
placeholder="模型名称"
{...register('name', {
required: '模型名不能为空'
})}
/>
<FormErrorMessage position={'absolute'} fontSize="xs">
{!!errors.name && errors.name.message}
</FormErrorMessage>
</FormControl>
<FormControl isInvalid={!!errors.serviceModelName}>
<Select
placeholder="选择基础模型类型"
{...register('serviceModelName', {
required: '底层模型不能为空',
onChange() {
setRefresh(!refresh);
}
})}
>
{modelList.map((item) => (
<option key={item.model} value={item.model}>
{item.name}
</option>
))}
</Select>
<FormErrorMessage position={'absolute'} fontSize="xs">
{!!errors.serviceModelName && errors.serviceModelName.message}
</FormErrorMessage>
</FormControl>
<Box mt={3} textAlign={'center'} fontSize={'sm'} color={'blackAlpha.600'}>
{formatPrice(
modelList.find((item) => item.model === getValues('serviceModelName'))?.price || 0,
1000
)}
/1K tokens()
</Box>
</ModalBody>
<ModalFooter>
<Button mr={3} colorScheme={'gray'} onClick={() => setCreateModelOpen(false)}>
</Button>
<Button isLoading={requesting} onClick={handleSubmit(handleCreateModel)}>
</Button>
</ModalFooter>
</ModalContent>
</Modal>
</>
);
};
export default CreateModel;

View File

@@ -2,8 +2,8 @@ import React, { useEffect } from 'react';
import { Box, Button, Flex, Tag } from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema';
import { formatModelStatus } from '@/constants/model';
import dayjs from 'dayjs';
import { useRouter } from 'next/router';
import { ChatModelMap } from '@/constants/model';
const ModelPhoneList = ({
models,
@@ -42,12 +42,12 @@ const ModelPhoneList = ({
</Tag>
</Flex>
<Flex mt={5}>
<Box flex={'0 0 100px'}>: </Box>
<Box color={'blackAlpha.500'}>{dayjs(model.updateTime).format('YYYY-MM-DD HH:mm')}</Box>
<Box flex={'0 0 100px'}>: </Box>
<Box color={'blackAlpha.500'}>{ChatModelMap[model.chat.chatModel]}</Box>
</Flex>
<Flex mt={5}>
<Box flex={'0 0 100px'}>AI模型: </Box>
<Box color={'blackAlpha.500'}>{model.service.modelName}</Box>
<Box flex={'0 0 100px'}>: </Box>
<Box color={'blackAlpha.500'}>{model.chat.temperature}</Box>
</Flex>
<Flex mt={5} justifyContent={'flex-end'}>
<Button

View File

@@ -13,10 +13,9 @@ import {
Box
} from '@chakra-ui/react';
import { formatModelStatus } from '@/constants/model';
import dayjs from 'dayjs';
import type { ModelSchema } from '@/types/mongoSchema';
import { useRouter } from 'next/router';
import { modelList } from '@/constants/model';
import { ChatModelMap } from '@/constants/model';
const ModelTable = ({
models = [],
@@ -33,18 +32,18 @@ const ModelTable = ({
dataIndex: 'name'
},
{
title: '模型类型',
title: '对话模型',
key: 'service',
render: (model: ModelSchema) => (
<Box fontWeight={'bold'} whiteSpace={'pre-wrap'} maxW={'200px'}>
{modelList.find((item) => item.model === model.service.modelName)?.name}
{ChatModelMap[model.chat.chatModel]}
</Box>
)
},
{
title: '最后更新时间',
key: 'updateTime',
render: (item: ModelSchema) => dayjs(item.updateTime).format('YYYY-MM-DD HH:mm')
title: '温度',
key: 'temperature',
render: (model: ModelSchema) => <>{model.chat.temperature}</>
},
{
title: '状态',

View File

@@ -1,4 +1,4 @@
import React, { useState, useCallback } from 'react';
import React, { useCallback } from 'react';
import { Box, Button, Flex, Card } from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema';
import { useRouter } from 'next/router';
@@ -7,30 +7,37 @@ import ModelPhoneList from './components/ModelPhoneList';
import { useScreen } from '@/hooks/useScreen';
import { useQuery } from '@tanstack/react-query';
import { useLoading } from '@/hooks/useLoading';
import dynamic from 'next/dynamic';
import { useToast } from '@/hooks/useToast';
import { useUserStore } from '@/store/user';
const CreateModel = dynamic(() => import('./components/CreateModel'));
import { postCreateModel } from '@/api/model';
const modelList = () => {
const { toast } = useToast();
const { isPc } = useScreen();
const router = useRouter();
const { myModels, setMyModels, getMyModels } = useUserStore();
const [openCreateModel, setOpenCreateModel] = useState(false);
const { myModels, getMyModels } = useUserStore();
const { Loading, setIsLoading } = useLoading();
/* 加载模型 */
const { isLoading } = useQuery(['loadModels'], getMyModels);
/* 创建成功回调 */
const createModelSuccess = useCallback(
(data: ModelSchema) => {
setMyModels([data, ...myModels]);
},
[myModels, setMyModels]
);
const handleCreateModel = useCallback(async () => {
setIsLoading(true);
try {
const id = await postCreateModel({ name: `模型${myModels.length}` });
toast({
title: '创建成功',
status: 'success'
});
router.push(`/model/detail?modelId=${id}`);
} catch (err: any) {
toast({
title: typeof err === 'string' ? err : err.message || '出现了意外',
status: 'error'
});
}
setIsLoading(false);
}, [myModels.length, router, setIsLoading, toast]);
/* 点前往聊天预览页 */
const handlePreviewChat = useCallback(
@@ -61,7 +68,7 @@ const modelList = () => {
</Box>
<Button flex={'0 0 145px'} variant={'outline'} onClick={() => setOpenCreateModel(true)}>
<Button flex={'0 0 145px'} variant={'outline'} onClick={handleCreateModel}>
</Button>
</Flex>
@@ -74,10 +81,6 @@ const modelList = () => {
<ModelPhoneList models={myModels} handlePreviewChat={handlePreviewChat} />
)}
</Box>
{/* 创建弹窗 */}
{openCreateModel && (
<CreateModel setCreateModelOpen={setOpenCreateModel} onSuccess={createModelSuccess} />
)}
<Loading loading={isLoading} />
</Box>

View File

@@ -1,23 +1,17 @@
import { connectToDatabase, Bill, User } from '../mongo';
import {
modelList,
ChatModelEnum,
ModelNameEnum,
Model2ChatModelMap,
embeddingModel
} from '@/constants/model';
import { modelList, ChatModelEnum, embeddingModel } from '@/constants/model';
import { BillTypeEnum } from '@/constants/user';
import { countChatTokens } from '@/utils/tools';
export const pushChatBill = async ({
isPay,
modelName,
chatModel,
userId,
chatId,
messages
}: {
isPay: boolean;
modelName: `${ModelNameEnum}`;
chatModel: `${ChatModelEnum}`;
userId: string;
chatId?: '' | string;
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
@@ -26,7 +20,7 @@ export const pushChatBill = async ({
try {
// 计算 token 数量
const tokens = countChatTokens({ model: Model2ChatModelMap[modelName] as any, messages });
const tokens = countChatTokens({ model: chatModel, messages });
const text = messages.map((item) => item.content).join('');
console.log(
@@ -37,7 +31,7 @@ export const pushChatBill = async ({
await connectToDatabase();
// 获取模型单价格
const modelItem = modelList.find((item) => item.model === modelName);
const modelItem = modelList.find((item) => item.chatModel === chatModel);
// 计算价格
const unitPrice = modelItem?.price || 5;
const price = unitPrice * tokens;
@@ -47,7 +41,7 @@ export const pushChatBill = async ({
const res = await Bill.create({
userId,
type: 'chat',
modelName,
modelName: chatModel,
chatId: chatId ? chatId : undefined,
textLen: text.length,
tokenLen: tokens,
@@ -94,7 +88,7 @@ export const pushSplitDataBill = async ({
if (isPay) {
try {
// 获取模型单价格, 都是用 gpt35 拆分
const modelItem = modelList.find((item) => item.model === ChatModelEnum.GPT35);
const modelItem = modelList.find((item) => item.chatModel === ChatModelEnum.GPT35);
const unitPrice = modelItem?.price || 3;
// 计算价格
const price = unitPrice * tokenLen;

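For reference, the billing arithmetic after switching the key from modelName to chatModel, as a standalone sketch (the token count is assumed; the fallback unit price mirrors the code above):

import { modelList, ChatModelEnum } from '@/constants/model';

const chatModel = ChatModelEnum.GPT35;
const modelItem = modelList.find((item) => item.chatModel === chatModel);
const unitPrice = modelItem?.price || 5; // price per token, in units of 0.00001 CNY
const tokens = 1200; // e.g. countChatTokens({ model: chatModel, messages })
const price = unitPrice * tokens; // amount written to the Bill document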
View File

@@ -1,5 +1,5 @@
import { Schema, model, models, Model } from 'mongoose';
import { modelList } from '@/constants/model';
import { ChatModelMap } from '@/constants/model';
import { BillSchema as BillType } from '@/types/mongoSchema';
import { BillTypeMap } from '@/constants/user';
@@ -16,7 +16,7 @@ const BillSchema = new Schema({
},
modelName: {
type: String,
enum: [...modelList.map((item) => item.model), 'text-embedding-ada-002'],
enum: [...Object.keys(ChatModelMap), 'text-embedding-ada-002'],
required: true
},
chatId: {

View File

@@ -1,6 +1,11 @@
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelSchema as ModelType } from '@/types/mongoSchema';
import { ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import {
ModelVectorSearchModeMap,
ModelVectorSearchModeEnum,
ChatModelMap,
ChatModelEnum
} from '@/constants/model';
const ModelSchema = new Schema({
userId: {
@@ -16,11 +21,6 @@ const ModelSchema = new Schema({
type: String,
default: '/icon/logo.png'
},
systemPrompt: {
// 系统提示词
type: String,
default: ''
},
status: {
type: String,
required: true,
@@ -30,17 +30,34 @@ const ModelSchema = new Schema({
type: Date,
default: () => new Date()
},
temperature: {
type: Number,
min: 0,
max: 10,
default: 4
},
search: {
mode: {
chat: {
useKb: {
// use knowledge base to search
type: Boolean,
default: false
},
searchMode: {
// knowledge base search mode
type: String,
enum: Object.keys(ModelVectorSearchModeMap),
default: ModelVectorSearchModeEnum.hightSimilarity
},
systemPrompt: {
// 系统提示词
type: String,
default: ''
},
temperature: {
type: Number,
min: 0,
max: 10,
default: 0
},
chatModel: {
// 聊天时使用的模型
type: String,
enum: Object.keys(ChatModelMap),
default: ChatModelEnum.GPT35
}
},
share: {
@@ -63,18 +80,6 @@ const ModelSchema = new Schema({
default: 0
}
},
service: {
chatModel: {
// 聊天时使用的模型
type: String,
required: true
},
modelName: {
// 底层模型的名称
type: String,
required: true
}
},
security: {
type: {
domain: {
@@ -100,8 +105,7 @@ const ModelSchema = new Schema({
default: -1
}
},
default: {},
required: true
default: {}
}
});

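Given the schema defaults above, a newly created model gets a chat sub-document shaped like this (illustrative object literal, not code from the commit):

import { ModelVectorSearchModeEnum, ChatModelEnum } from '@/constants/model';

const chatDefaults = {
  useKb: false, // knowledge-base search is off by default
  searchMode: ModelVectorSearchModeEnum.hightSimilarity,
  systemPrompt: '',
  temperature: 0,
  chatModel: ChatModelEnum.GPT35
};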
View File

@@ -0,0 +1,47 @@
import { openaiCreateEmbedding } from '../utils/openai';
import { PgClient } from '@/service/pg';
import { ModelDataStatusEnum } from '@/constants/model';
/**
* use openai embedding search kb
*/
export const searchKb_openai = async ({
apiKey,
isPay,
text,
similarity,
modelId,
userId
}: {
apiKey: string;
isPay: boolean;
text: string;
modelId: string;
userId: string;
similarity: number;
}) => {
// 获取提示词的向量
const { vector: promptVector } = await openaiCreateEmbedding({
isPay,
apiKey,
userId,
text
});
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
where: [
['status', ModelDataStatusEnum.ready],
'AND',
['model_id', modelId],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: 20
});
const systemPrompts: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
return { systemPrompts };
};

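A hypothetical wrapper (getKbSystemPrompts is not in the codebase) showing the inputs the chat handlers pass to searchKb_openai and the shape of what comes back:

import { searchKb_openai } from '@/service/tools/searchKb';

async function getKbSystemPrompts(apiKey: string, text: string, modelId: string, userId: string) {
  const { systemPrompts } = await searchKb_openai({
    apiKey,
    isPay: true,
    text,
    similarity: 0.22, // the fallback threshold the handlers use when the search mode defines none
    modelId,
    userId
  });
  return systemPrompts; // string[] of `${q}\n${a}` pairs from the matched modelData rows
}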
View File

@@ -1,10 +1,33 @@
import { Configuration, OpenAIApi } from 'openai';
import { Chat, Model } from '../mongo';
import type { NextApiRequest } from 'next';
import jwt from 'jsonwebtoken';
import { Chat, Model, OpenApi, User } from '../mongo';
import type { ModelSchema } from '@/types/mongoSchema';
import { authToken } from './tools';
import { getOpenApiKey } from './openai';
import type { ChatItemType } from '@/types/chat';
import mongoose from 'mongoose';
import { defaultModel } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
/* 校验 token */
export const authToken = (token?: string): Promise<string> => {
return new Promise((resolve, reject) => {
if (!token) {
reject('缺少登录凭证');
return;
}
const key = process.env.TOKEN_KEY as string;
jwt.verify(token, key, function (err, decoded: any) {
if (err || !decoded?.userId) {
reject('凭证无效');
return;
}
resolve(decoded.userId);
});
});
};
export const getOpenAIApi = (apiKey: string) => {
const configuration = new Configuration({
@@ -20,12 +43,14 @@ export const authModel = async ({
modelId,
userId,
authUser = true,
authOwner = true
authOwner = true,
reserveDetail = false
}: {
modelId: string;
userId: string;
authUser?: boolean;
authOwner?: boolean;
reserveDetail?: boolean; // focus reserve detail
}) => {
// 获取 model 数据
const model = await Model.findById<ModelSchema>(modelId);
@@ -33,15 +58,21 @@ export const authModel = async ({
return Promise.reject('模型不存在');
}
// 使用权限校验
/*
Access verification
1. authOwner=true or authUser = true , just owner can use
2. authUser = false and share, anyone can use
*/
if ((authOwner || (authUser && !model.share.isShare)) && userId !== String(model.userId)) {
return Promise.reject('无权操作该模型');
}
// detail 内容去除
if (!model.share.isShareDetail && userId !== String(model.userId)) {
model.systemPrompt = '';
model.temperature = 0;
// do not share detail info
if (!reserveDetail && !model.share.isShareDetail && userId !== String(model.userId)) {
model.chat = {
...defaultModel.chat,
chatModel: model.chat.chatModel
};
}
return { model };
@@ -60,7 +91,7 @@ export const authChat = async ({
const userId = await authToken(authorization);
// 获取 model 数据
const { model } = await authModel({ modelId, userId, authOwner: false });
const { model } = await authModel({ modelId, userId, authOwner: false, reserveDetail: true });
// 聊天内容
let content: ChatItemType[] = [];
@@ -91,3 +122,41 @@ export const authChat = async ({
model
};
};
/* 校验 open api key */
export const authOpenApiKey = async (req: NextApiRequest) => {
const { apikey: apiKey } = req.headers;
if (!apiKey) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
try {
const openApi = await OpenApi.findOne({ apiKey });
if (!openApi) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
const userId = String(openApi.userId);
// 余额校验
const user = await User.findById(userId);
if (!user) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
if (formatPrice(user.balance) <= 0) {
return Promise.reject(ERROR_ENUM.insufficientQuota);
}
// 更新使用的时间
await OpenApi.findByIdAndUpdate(openApi._id, {
lastUsedTime: new Date()
});
return {
apiKey: process.env.OPENAIKEY as string,
userId
};
} catch (error) {
return Promise.reject(error);
}
};

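With authOpenApiKey now exported from service/utils/auth, an open-API route authorizes in two steps, roughly as below (authorizeOpenApiChat is a hypothetical helper; the real routes inline these calls):

import type { NextApiRequest } from 'next';
import { authOpenApiKey, authModel } from '@/service/utils/auth';

async function authorizeOpenApiChat(req: NextApiRequest, modelId: string) {
  const { apiKey, userId } = await authOpenApiKey(req); // validates the apikey header, its owner and balance
  const { model } = await authModel({ userId, modelId }); // ownership / share permission check
  return { apiKey, userId, model };
}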
View File

@@ -1,6 +1,5 @@
import * as nodemailer from 'nodemailer';
import { UserAuthTypeEnum } from '@/constants/common';
import dayjs from 'dayjs';
import Dysmsapi, * as dysmsapi from '@alicloud/dysmsapi20170525';
// @ts-ignore
import * as OpenApi from '@alicloud/openapi-client';
@@ -48,25 +47,6 @@ export const sendEmailCode = (email: string, code: string, type: `${UserAuthType
});
};
export const sendTrainSucceed = (email: string, modelName: string) => {
return new Promise((resolve, reject) => {
const options = {
from: `"FastGPT" ${myEmail}`,
to: email,
subject: '模型训练完成通知',
html: `你的模型 ${modelName} 已于 ${dayjs().format('YYYY-MM-DD HH:mm')} 训练完成!`
};
mailTransport.sendMail(options, function (err, msg) {
if (err) {
console.log('send email error->', err);
reject('邮箱异常');
} else {
resolve('');
}
});
});
};
export const sendPhoneCode = async (phone: string, code: string) => {
const accessKeyId = process.env.aliAccessKeyId;
const accessKeySecret = process.env.aliAccessKeySecret;

View File

@@ -1,10 +1,6 @@
import type { NextApiRequest } from 'next';
import crypto from 'crypto';
import jwt from 'jsonwebtoken';
import { ChatItemType } from '@/types/chat';
import { OpenApi, User } from '../mongo';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
import { countChatTokens } from '@/utils/tools';
import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
import { ChatModelEnum } from '@/constants/model';
@@ -46,44 +42,6 @@ export const authToken = (token?: string): Promise<string> => {
});
};
/* 校验 open api key */
export const authOpenApiKey = async (req: NextApiRequest) => {
const { apikey: apiKey } = req.headers;
if (!apiKey) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
try {
const openApi = await OpenApi.findOne({ apiKey });
if (!openApi) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
const userId = String(openApi.userId);
// 余额校验
const user = await User.findById(userId);
if (!user) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
if (formatPrice(user.balance) <= 0) {
return Promise.reject('Insufficient account balance');
}
// 更新使用的时间
await OpenApi.findByIdAndUpdate(openApi._id, {
lastUsedTime: new Date()
});
return {
apiKey: process.env.OPENAIKEY as string,
userId
};
} catch (error) {
return Promise.reject(error);
}
};
/* openai axios config */
export const axiosConfig = () => ({
httpsAgent: global.httpsAgent,

View File

@@ -1,13 +1,11 @@
import { ModelStatusEnum } from '@/constants/model';
import type { ModelSchema } from './mongoSchema';
export interface ModelUpdateParams {
name: string;
avatar: string;
systemPrompt: string;
temperature: number;
search: ModelSchema['search'];
chat: ModelSchema['chat'];
share: ModelSchema['share'];
service: ModelSchema['service'];
security: ModelSchema['security'];
}

View File

@@ -31,15 +31,17 @@ export interface AuthCodeSchema {
export interface ModelSchema {
_id: string;
userId: string;
name: string;
avatar: string;
systemPrompt: string;
userId: string;
status: `${ModelStatusEnum}`;
updateTime: number;
temperature: number;
search: {
mode: `${ModelVectorSearchModeEnum}`;
chat: {
useKb: boolean;
searchMode: `${ModelVectorSearchModeEnum}`;
systemPrompt: string;
temperature: number;
chatModel: `${ChatModelEnum}`; // 聊天时用的模型,训练后就是训练的模型
};
share: {
isShare: boolean;
@@ -47,10 +49,6 @@ export interface ModelSchema {
intro: string;
collection: number;
};
service: {
chatModel: `${ChatModelEnum}`; // 聊天时用的模型,训练后就是训练的模型
modelName: `${ModelNameEnum}`; // 底层模型名称,不会变
};
security: {
domain: string[];
contextMaxLen: number;