feat: 知识库匹配模式选择

This commit is contained in:
archer
2023-04-12 00:44:01 +08:00
parent 1fe5cd751a
commit c605964fa8
11 changed files with 102 additions and 41 deletions

View File

@@ -8,3 +8,4 @@ README.md
.yalc/
yalc.lock
testApi/

View File

@@ -4,14 +4,12 @@ import type { RedisModelDataItemType } from '@/types/redis';
/**
 * String identifiers for the model options the app knows about.
 *
 * The members are used as keys of ChatModelNameMap. Note that VECTOR_GPT's
 * value is the literal 'VECTOR_GPT' and not a provider model name; the
 * provider-facing name for it comes from ChatModelNameMap.
 */
export enum ChatModelNameEnum {
GPT35 = 'gpt-3.5-turbo',
VECTOR_GPT = 'VECTOR_GPT',
GPT3 = 'text-davinci-003',
VECTOR = 'text-embedding-ada-002'
}
/**
 * Lookup table from a ChatModelNameEnum member to a model name string.
 *
 * Every entry maps to the same string as the enum value except VECTOR_GPT,
 * which maps to 'gpt-3.5-turbo'.
 * NOTE(review): presumably the value is the model name sent to the provider
 * API — confirm against the call sites.
 */
export const ChatModelNameMap = {
[ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo',
[ChatModelNameEnum.GPT3]: 'text-davinci-003',
[ChatModelNameEnum.VECTOR]: 'text-embedding-ada-002'
};
@@ -34,7 +32,7 @@ export const modelList: ModelConstantsData[] = [
trainName: '',
maxToken: 4000,
contextMaxToken: 7500,
maxTemperature: 2,
maxTemperature: 1.5,
price: 3
},
{
@@ -47,16 +45,6 @@ export const modelList: ModelConstantsData[] = [
maxTemperature: 1,
price: 3
}
// {
// serviceCompany: 'openai',
// name: 'GPT3',
// model: ChatModelNameEnum.GPT3,
// trainName: 'davinci',
// maxToken: 4000,
// contextMaxToken: 7500,
// maxTemperature: 2,
// price: 30
// }
];
export enum TrainingStatusEnum {
@@ -97,6 +85,34 @@ export const ModelDataStatusMap: Record<RedisModelDataItemType['status'], string
waiting: '训练中'
};
/* Configuration used when searching the knowledge base */
/**
 * Knowledge-base search modes.
 *
 * The string values (not the member names) are what end up as the keys of
 * ModelVectorSearchModeMap, so they are runtime data.
 *
 * NOTE(review): the value of `noContext` is spelled 'noContex' (no trailing
 * "t"), and the member name `hightSimilarity` is also misspelled. Both are
 * kept as-is here; presumably these values are persisted with the model
 * record, so renaming them would need a data migration — confirm first.
 */
export enum ModelVectorSearchModeEnum {
hightSimilarity = 'hightSimilarity', // high similarity + refuse to reply when nothing matches
lowSimilarity = 'lowSimilarity', // low similarity
noContext = 'noContex' // high similarity + reply without knowledge-base context
}
/**
 * Per-mode settings for knowledge-base search, keyed by the string value of
 * each ModelVectorSearchModeEnum member.
 *
 * - `text`: human-readable label for the mode.
 * - `similarity`: numeric threshold for the mode.
 *
 * NOTE(review): `similarity` appears to be used as the radius of a
 * VECTOR_RANGE query whose distance is yielded as the score, which would make
 * a smaller number the stricter match (hence 0.2 for the "high similarity"
 * modes and 0.8 for the "low similarity" one) — confirm against the search
 * handler.
 *
 * The insertion order of the entries is observable to any caller that
 * iterates this object (Object.keys / Object.entries), so keep it stable.
 */
export const ModelVectorSearchModeMap: Record<
`${ModelVectorSearchModeEnum}`,
{
text: string;
similarity: number;
}
> = {
// Strict matching; the label says the reply is refused when nothing matches.
[ModelVectorSearchModeEnum.hightSimilarity]: {
text: '高相似度, 无匹配时拒绝回复',
similarity: 0.2
},
// Strict matching; the label says the model answers directly when nothing matches.
[ModelVectorSearchModeEnum.noContext]: {
text: '高相似度,无匹配时直接回复',
similarity: 0.2
},
// Loose matching.
[ModelVectorSearchModeEnum.lowSimilarity]: {
text: '低相似度匹配',
similarity: 0.8
}
};
export const defaultModel: ModelSchema = {
_id: '',
userId: '',
@@ -108,6 +124,9 @@ export const defaultModel: ModelSchema = {
systemPrompt: '',
intro: '',
temperature: 5,
search: {
mode: ModelVectorSearchModeEnum.hightSimilarity
},
service: {
company: 'openai',
trainId: '',

View File

@@ -7,7 +7,7 @@ import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix } from '@/constants/redis';
@@ -65,13 +65,14 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
text: prompt.value
});
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
const redisData: any[] = await redis.sendCommand([
'FT.SEARCH',
`idx:${VecModelDataPrefix}:hash`,
`@modelId:{${String(
chat.modelId._id
)}} @vector:[VECTOR_RANGE 0.22 $blob]=>{$YIELD_DISTANCE_AS: score}`,
)}} @vector:[VECTOR_RANGE ${similarity} $blob]=>{$YIELD_DISTANCE_AS: score}`,
'RETURN',
'1',
'text',
@@ -97,7 +98,24 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}
}
if (formatRedisPrompt.length > 0) {
/* 高相似度+退出,无法匹配时直接退出 */
if (
formatRedisPrompt.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
) {
return res.send('对不起,你的问题不在知识库中。');
}
/* 高相似度+无上下文,不添加额外知识 */
if (
formatRedisPrompt.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.noContext
) {
prompts.unshift({
obj: 'SYSTEM',
value: model.systemPrompt
});
} else {
// 有匹配情况下,添加知识库内容。
// 系统提示词过滤,最多 2800 tokens
const systemPrompt = systemPromptFilter(formatRedisPrompt, 2800);
@@ -107,8 +125,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
'YYYY/MM/DD HH:mm:ss'
)} ${systemPrompt}"`
});
} else {
return res.send('对不起,你的问题不在知识库中。');
}
// 控制在 tokens 数量,防止超出

View File

@@ -8,7 +8,7 @@ import type { ModelUpdateParams } from '@/types/model';
/* Update one of my models (reads ModelUpdateParams from the request body) */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { name, service, security, systemPrompt, intro, temperature } =
const { name, search, service, security, systemPrompt, intro, temperature } =
req.body as ModelUpdateParams;
const { modelId } = req.query as { modelId: string };
const { authorization } = req.headers;
@@ -37,6 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
systemPrompt,
intro,
temperature,
search,
// service,
security
}

View File

@@ -83,22 +83,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
下面是一些例子:
实现一个手机号发生注册验证码方法.
1. 从 query 中获取 phone.
2. 校验手机号格式是否正确,不正确返回{error: "手机号格式错误"}.
2. 校验手机号格式是否正确,不正确返回错误响应,消息为:手机号格式错误.
3. 给 phone 发送一个短信验证码,验证码长度为6位字符串,内容为:你正在注册laf,验证码为:code.
4. 数据库添加数据,表为"codes",内容为 {phone, code}.
实现根据手机号注册账号,需要验证手机验证码.
1. 从 body 中获取 phone 和 code.
2. 校验手机号格式是否正确,不正确返回{error: "手机号格式错误"}.
2. 获取数据库数据,表为"codes",查找是否有符合 phone, code 等于body参数的记录,没有的话返回 {error:"验证码不正确"}.
2. 校验手机号格式是否正确,不正确返回错误响应,消息为:手机号格式错误.
2. 获取数据库数据,表为"codes",查找是否有符合 phone, code 等于body参数的记录,没有的话错误响应,消息为:验证码不正确.
4. 添加数据库数据,表为"users" ,内容为{phone, code, createTime}.
5. 删除数据库数据,删除 code 记录.
更新博客记录。传入blogId,blogText,tags,还需要记录更新的时间.
1. 从 body 中获取 blogId,blogText 和 tags.
2. 校验 blogId 是否为空,为空则返回 {error: "博客ID不能为空"}.
3. 校验 blogText 是否为空,为空则返回 {error: "博客内容不能为空"}.
4. 校验 tags 是否为数组,不是则返回 {error: "标签必须为数组"}.
2. 校验 blogId 是否为空,为空则错误响应,消息为:博客ID不能为空.
3. 校验 blogText 是否为空,为空则错误响应,消息为:博客内容不能为空.
4. 校验 tags 是否为数组,不是则错误响应,消息为:标签必须为数组.
5. 获取当前时间,记录为 updateTime.
6. 更新数据库数据,表为"blogs",更新符合 blogId 的记录的内容为{blogText, tags, updateTime}.
7. 返回结果 {message: "更新博客记录成功"}.`

View File

@@ -114,8 +114,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
async (prompts: ChatSiteItemType) => {
const urlMap: Record<string, string> = {
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt',
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
[ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt'
};
if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型');

View File

@@ -12,12 +12,13 @@ import {
SliderThumb,
SliderMark,
Tooltip,
Button
Button,
Select
} from '@chakra-ui/react';
import { QuestionOutlineIcon } from '@chakra-ui/icons';
import type { ModelSchema } from '@/types/mongoSchema';
import { UseFormReturn } from 'react-hook-form';
import { modelList } from '@/constants/model';
import { modelList, ModelVectorSearchModeMap } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { useConfirm } from '@/hooks/useConfirm';
@@ -89,15 +90,6 @@ const ModelEditForm = ({
</Button>
</Flex>
{/* <FormControl mt={4}>
<Box mb={1}>介绍:</Box>
<Textarea
rows={5}
maxLength={500}
{...register('intro')}
placeholder={'模型的介绍,仅做展示,不影响模型的效果'}
/>
</FormControl> */}
</Card>
<Card p={4}>
<Box fontWeight={'bold'}></Box>
@@ -143,6 +135,20 @@ const ModelEditForm = ({
</Slider>
</Flex>
</FormControl>
{canTrain && (
<FormControl mt={4}>
<Flex alignItems={'center'}>
<Box flex={'0 0 70px'}></Box>
<Select {...register('search.mode', { required: '搜索模式不能为空' })}>
{Object.entries(ModelVectorSearchModeMap).map(([key, { text }]) => (
<option key={key} value={key}>
{text}
</option>
))}
</Select>
</Flex>
</FormControl>
)}
<Box mt={4}>
<Box mb={1}></Box>
<Textarea

View File

@@ -143,6 +143,7 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
systemPrompt: data.systemPrompt,
intro: data.intro,
temperature: data.temperature,
search: data.search,
service: data.service,
security: data.security
});

View File

@@ -1,5 +1,7 @@
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelSchema as ModelType } from '@/types/mongoSchema';
import { ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
const ModelSchema = new Schema({
userId: {
type: Schema.Types.ObjectId,
@@ -43,6 +45,13 @@ const ModelSchema = new Schema({
max: 10,
default: 4
},
search: {
mode: {
type: String,
enum: Object.keys(ModelVectorSearchModeMap),
default: ModelVectorSearchModeEnum.hightSimilarity
}
},
service: {
company: {
type: String,

View File

@@ -5,8 +5,9 @@ export interface ModelUpdateParams {
systemPrompt: string;
intro: string;
temperature: number;
service: ModelSchema.service;
security: ModelSchema.security;
search: ModelSchema['search'];
service: ModelSchema['service'];
security: ModelSchema['security'];
}
export interface ModelDataItemType {

View File

@@ -1,5 +1,10 @@
import type { ChatItemType } from './chat';
import { ModelStatusEnum, TrainingStatusEnum, ChatModelNameEnum } from '@/constants/model';
import {
ModelStatusEnum,
TrainingStatusEnum,
ChatModelNameEnum,
ModelVectorSearchModeEnum
} from '@/constants/model';
import type { DataType } from './data';
export type ServiceName = 'openai';
@@ -32,6 +37,9 @@ export interface ModelSchema {
updateTime: number;
trainingTimes: number;
temperature: number;
search: {
mode: `${ModelVectorSearchModeEnum}`;
};
service: {
company: ServiceName;
trainId: string; // 训练的模型训练后就是训练的模型id