feat: 替换redis搜索

This commit is contained in:
archer
2023-04-19 12:00:28 +08:00
parent 867d69659f
commit 1e5714da1b
12 changed files with 147 additions and 228 deletions

View File

@@ -1,7 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { authChat } from '@/service/utils/chat';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { httpsAgent, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
@@ -9,11 +9,9 @@ import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import dayjs from 'dayjs';
import { PgClient } from '@/service/pg';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -43,7 +41,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}
await connectToDatabase();
const redis = await connectRedis();
let startTime = Date.now();
const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);
@@ -65,38 +62,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
text: prompt.value
});
// 相似度搜素
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
const redisData: any[] = await redis.sendCommand([
'FT.SEARCH',
`idx:${VecModelDataPrefix}:hash`,
`@modelId:{${String(
chat.modelId._id
)}} @vector:[VECTOR_RANGE ${similarity} $blob]=>{$YIELD_DISTANCE_AS: score}`,
'RETURN',
'1',
'text',
'SORTBY',
'score',
'PARAMS',
'2',
'blob',
vectorToBuffer(promptVector),
'LIMIT',
'0',
'30',
'DIALECT',
'2'
]);
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
where: [
['model_id', model._id],
'AND',
['user_id', userId],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
limit: 30
});
const formatRedisPrompt: string[] = [];
// 格式化响应值,获取 qa
for (let i = 2; i < 61; i += 2) {
const text = redisData[i]?.[1];
if (text) {
formatRedisPrompt.push(text);
}
}
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
/* 高相似度+退出,无法匹配时直接退出 */
if (
@@ -121,9 +102,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompts.unshift({
obj: 'SYSTEM',
value: `${model.systemPrompt} 知识库内容回答,知识库内容为: "当前时间:${dayjs().format(
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:当前时间${dayjs().format(
'YYYY/MM/DD HH:mm:ss'
)} ${systemPrompt}"`
)}\n${systemPrompt}`
});
}

View File

@@ -2,8 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { VecModelDataIdx } from '@/constants/redis';
import { PgClient } from '@/service/pg';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -25,28 +24,23 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const userId = await authToken(authorization);
await connectToDatabase();
const redis = await connectRedis();
// 从 redis 中获取数据
const searchRes = await redis.ft.search(
VecModelDataIdx,
`@modelId:{${modelId}} @userId:{${userId}}`,
{
RETURN: ['q', 'text'],
LIMIT: {
from: 0,
size: 10000
}
}
);
const data: [string, string][] = [];
searchRes.documents.forEach((item: any) => {
if (item.value.q && item.value.text) {
data.push([item.value.q.replace(/\n/g, '\\n'), item.value.text.replace(/\n/g, '\\n')]);
}
// 统计数据
const count = await PgClient.count('modelData', {
where: [['model_id', modelId], 'AND', ['user_id', userId]]
});
// 从 pg 中获取所有数据
const pgData = await PgClient.select<{ q: string; a: string }>('modelData', {
where: [['model_id', modelId], 'AND', ['user_id', userId]],
fields: ['q', 'a'],
order: [{ field: 'id', mode: 'DESC' }],
limit: count
});
const data: [string, string][] = pgData.rows.map((item) => [
item.q.replace(/\n/g, '\\n'),
item.a.replace(/\n/g, '\\n')
]);
jsonRes(res, {
data

View File

@@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase();
const searchRes = await PgClient.select<PgModelDataItemType>('modelData', {
field: ['id', 'q', 'a', 'status'],
fields: ['id', 'q', 'a', 'status'],
where: [['user_id', userId], 'AND', ['model_id', modelId]],
order: [{ field: 'id', mode: 'DESC' }],
limit: pageSize,

View File

@@ -3,11 +3,8 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateVector } from '@/service/events/generateVector';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis';
import { VecModelDataIdx } from '@/constants/redis';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
import { ModelDataStatusEnum } from '@/constants/model';
import { PgClient } from '@/service/pg';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
@@ -29,7 +26,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const userId = await authToken(authorization);
await connectToDatabase();
const redis = await connectRedis();
// 验证是否是该用户的 model
const model = await Model.findOne({
@@ -47,10 +43,18 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
try {
q = q.replace(/\\n/g, '\n');
a = a.replace(/\\n/g, '\n');
const redisSearch = await redis.ft.search(VecModelDataIdx, `@q:${q} @text:${a}`, {
RETURN: ['q', 'text']
const count = await PgClient.count('modelData', {
where: [
['user_id', userId],
'AND',
['model_id', modelId],
'AND',
['q', q],
'AND',
['a', a]
]
});
if (redisSearch.total > 0) {
if (count > 0) {
return Promise.reject('已经存在');
}
} catch (error) {
@@ -62,35 +66,26 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
});
})
);
// 过滤重复的内容
const filterData = searchRes
.filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value);
// 插入 redis
const insertRedisRes = await Promise.allSettled(
filterData.map((item) => {
return redis.sendCommand([
'HMSET',
`${VecModelDataPrefix}:${nanoid()}`,
'userId',
userId,
'modelId',
String(modelId),
'q',
item.q,
'text',
item.a,
'status',
ModelDataStatusEnum.waiting
]);
})
);
// 插入 pg
const insertRes = await PgClient.insert('modelData', {
values: filterData.map((item) => [
{ key: 'user_id', value: userId },
{ key: 'model_id', value: modelId },
{ key: 'q', value: item.q },
{ key: 'a', value: item.a },
{ key: 'status', value: ModelDataStatusEnum.waiting }
])
});
generateVector();
jsonRes(res, {
data: insertRedisRes.filter((item) => item.status === 'fulfilled').length
data: insertRes.rowCount
});
} catch (err) {
jsonRes(res, {

View File

@@ -1,13 +1,13 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { ModelDataStatusEnum } from '@/constants/redis';
import { generateVector } from '@/service/events/generateVector';
import { PgClient } from '@/service/pg';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { dataId, text, q } = req.body as { dataId: string; text: string; q?: string };
const { dataId, a, q } = req.body as { dataId: string; a: string; q?: string };
const { authorization } = req.headers;
if (!authorization) {
@@ -21,26 +21,21 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// 凭证校验
const userId = await authToken(authorization);
const redis = await connectRedis();
// 更新 pg 内容
await PgClient.update('modelData', {
where: [['id', dataId], 'AND', ['user_id', userId]],
values: [
{ key: 'a', value: a },
...(q
? [
{ key: 'q', value: q },
{ key: 'status', value: ModelDataStatusEnum.waiting }
]
: [])
]
});
// 校验是否为该用户的数据
const dataItemUserId = await redis.hGet(dataId, 'userId');
if (dataItemUserId !== userId) {
throw new Error('无权操作');
}
// 更新
await redis.sendCommand([
'HMSET',
dataId,
...(q ? ['q', q, 'status', ModelDataStatusEnum.waiting] : []),
'text',
text
]);
if (q) {
generateVector();
}
q && generateVector();
jsonRes(res);
} catch (err) {

View File

@@ -6,13 +6,12 @@ import { getUserApiOpenai } from '@/service/utils/openai';
import { TrainingStatusEnum } from '@/constants/model';
import { TrainingItemType } from '@/types/training';
import { httpsAgent } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { VecModelDataIdx } from '@/constants/redis';
import { PgClient } from '@/service/pg';
/* 获取我的模型 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { modelId } = req.query;
const { modelId } = req.query as { modelId: string };
const { authorization } = req.headers;
if (!authorization) {
@@ -37,21 +36,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
}
await connectToDatabase();
const redis = await connectRedis();
// 获取 redis 中模型关联的所有数据
const searchRes = await redis.ft.search(
VecModelDataIdx,
`@modelId:{${modelId}} @userId:{${userId}}`,
{
LIMIT: {
from: 0,
size: 10000
}
}
);
// 删除 redis 内容
await Promise.all(searchRes.documents.map((item) => redis.del(item.id)));
// 删除 pg 中所有该模型的数据
await PgClient.delete('modelData', {
where: [['user_id', userId], 'AND', ['model_id', modelId]]
});
// 删除对应的聊天
await Chat.deleteMany({

View File

@@ -7,12 +7,15 @@ import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } fr
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { ChatModelNameEnum, modelList, ChatModelNameMap } from '@/constants/model';
import {
ChatModelNameEnum,
modelList,
ChatModelNameMap,
ModelVectorSearchModeMap
} from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import { PgClient } from '@/service/pg';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -46,7 +49,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}
await connectToDatabase();
const redis = await connectRedis();
let startTime = Date.now();
/* 凭证校验 */
@@ -144,39 +146,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 读取对话内容
const prompts = [prompt];
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
const redisData: any[] = await redis.sendCommand([
'FT.SEARCH',
`idx:${VecModelDataPrefix}:hash`,
`@modelId:{${String(model._id)}}=>[KNN 20 @vector $blob AS score]`,
'RETURN',
'1',
'text',
'SORTBY',
'score',
'PARAMS',
'2',
'blob',
vectorToBuffer(promptVector),
'DIALECT',
'2'
]);
// 相似度搜索
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
where: [
['model_id', model._id],
'AND',
['user_id', userId],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
limit: 30
});
// 格式化响应值,获取 qa
const formatRedisPrompt: string[] = [];
for (let i = 2; i < 42; i += 2) {
const text = redisData[i]?.[1];
if (text) {
formatRedisPrompt.push(text);
}
}
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// textArr 筛选,最多 3000 tokens
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000);
prompts.unshift({
obj: 'SYSTEM',
value: `${model.systemPrompt} 知识库内容是最新的知识库内容为: "${systemPrompt}"`
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}`
});
// 控制在 tokens 数量,防止超出

View File

@@ -1,22 +1,15 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo';
import {
httpsAgent,
openaiChatFilter,
systemPromptFilter,
authOpenApiKey
} from '@/service/utils/tools';
import { httpsAgent, systemPromptFilter, authOpenApiKey } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import dayjs from 'dayjs';
import { PgClient } from '@/service/pg';
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@@ -56,7 +49,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}
await connectToDatabase();
const redis = await connectRedis();
let startTime = Date.now();
/* 凭证校验 */
@@ -84,38 +76,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
text: prompts[prompts.length - 1].value // 取最后一个
});
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
// 相似度搜素
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
const redisData: any[] = await redis.sendCommand([
'FT.SEARCH',
`idx:${VecModelDataPrefix}:hash`,
`@modelId:{${modelId}} @vector:[VECTOR_RANGE ${similarity} $blob]=>{$YIELD_DISTANCE_AS: score}`,
'RETURN',
'1',
'text',
'SORTBY',
'score',
'PARAMS',
'2',
'blob',
vectorToBuffer(promptVector),
'LIMIT',
'0',
'30',
'DIALECT',
'2'
]);
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
where: [
['model_id', model._id],
'AND',
['user_id', userId],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
limit: 30
});
const formatRedisPrompt: string[] = [];
// 格式化响应值,获取 qa
for (let i = 2; i < 61; i += 2) {
const text = redisData[i]?.[1];
if (text) {
formatRedisPrompt.push(text);
}
}
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// system 合并
if (prompts[0].obj === 'SYSTEM') {
@@ -145,9 +121,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompts.unshift({
obj: 'SYSTEM',
value: `${model.systemPrompt} 知识库内容回答,知识库内容为: "当前时间:${dayjs().format(
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:当前时间${dayjs().format(
'YYYY/MM/DD HH:mm:ss'
)} ${systemPrompt}"`
)}\n${systemPrompt}`
});
}

View File

@@ -16,7 +16,7 @@ export async function generateVector(next = false): Promise<any> {
try {
// 从找出一个 status = waiting 的数据
const searchRes = await PgClient.select('modelData', {
field: ['id', 'q', 'user_id'],
fields: ['id', 'q', 'user_id'],
where: [['status', 'waiting']],
limit: 1
});

View File

@@ -34,9 +34,9 @@ export const connectPg = async () => {
type WhereProps = (string | [string, string | number])[];
type GetProps = {
field?: string[];
fields?: string[];
where?: WhereProps;
order?: { field: string; mode: 'DESC' | 'ASC' }[];
order?: { field: string; mode: 'DESC' | 'ASC' | string }[];
limit?: number;
offset?: number;
};
@@ -62,7 +62,7 @@ class Pg {
if (typeof item === 'string') {
return item;
}
const val = typeof item[1] === 'string' ? `'${item[1]}'` : item[1];
const val = typeof item[1] === 'number' ? item[1] : `'${String(item[1])}'`;
return `${item[0]}=${val}`;
})
.join(' ')}`
@@ -95,7 +95,9 @@ class Pg {
.join(',');
}
async select<T extends QueryResultRow = any>(table: string, props: GetProps) {
const sql = `SELECT ${!props.field || props.field?.length === 0 ? '*' : props.field?.join(',')}
const sql = `SELECT ${
!props.fields || props.fields?.length === 0 ? '*' : props.fields?.join(',')
}
FROM ${table}
${this.getWhereStr(props.where)}
${
@@ -123,19 +125,34 @@ class Pg {
return pg.query(sql);
}
async update(table: string, props: UpdateProps) {
if (props.values.length === 0) {
return {
rowCount: 0
};
}
const sql = `UPDATE ${table} SET ${this.getUpdateValStr(props.values)} ${this.getWhereStr(
props.where
)}`;
const pg = await connectPg();
return pg.query(sql);
}
async insert(table: string, props: InsertProps) {
if (props.values.length === 0) {
return {
rowCount: 0
};
}
const fields = props.values[0].map((item) => item.key).join(',');
const sql = `INSERT INTO ${table} (${fields}) VALUES ${this.getInsertValStr(props.values)} `;
const pg = await connectPg();
return pg.query(sql);
}
async query<T extends QueryResultRow = any>(sql: string) {
const pg = await connectPg();
return pg.query<T>(sql);
}
}
export const PgClient = new Pg();

View File

@@ -137,5 +137,5 @@ export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
}
}
return splitText.slice(0, splitText.length - 1);
return splitText.slice(0, splitText.length - 1).replace(/\n+/g, '\n');
};

View File

@@ -51,23 +51,3 @@ export const Obj2Query = (obj: Record<string, string | number>) => {
}
return queryParams.toString();
};
/**
* 向量转成 float32 buffer 格式
*/
export const vectorToBuffer = (vector: number[]) => {
const npVector = new Float32Array(vector);
const buffer = Buffer.from(npVector.buffer);
return buffer;
};
export const formatVector = (vector: number[]) => {
let formattedVector = vector.slice(0, 1536); // 截取前1536个元素
if (vector.length > 1536) {
formattedVector = formattedVector.concat(Array(1536 - formattedVector.length).fill(0)); // 在后面添加0
}
return formattedVector;
};