feat: config vector model and qa model

This commit is contained in:
archer
2023-08-25 15:00:51 +08:00
parent a9970dd694
commit 6d93059e25
35 changed files with 337 additions and 196 deletions

View File

@@ -1,5 +1,5 @@
import { TrainingData } from '@/service/mongo';
import { pushSplitDataBill } from '@/service/events/pushBill';
import { pushQABill } from '@/service/events/pushBill';
import { pushDataToKb } from '@/pages/api/openapi/kb/pushData';
import { TrainingModeEnum } from '@/constants/plugin';
import { ERROR_ENUM } from '../errorCode';
@@ -60,14 +60,13 @@ export async function generateQA(): Promise<any> {
// 请求 chatgpt 获取回答
const response = await Promise.all(
[data.q].map((text) => {
const modelTokenLimit =
chatModels.find((item) => item.model === data.model)?.contextMaxToken || 16000;
const modelTokenLimit = global.qaModel.maxToken || 16000;
const messages: ChatCompletionRequestMessage[] = [
{
role: 'system',
content: `你是出题人.
${data.prompt || '我会发送一段长文本'}.
从中提取出 25 个问题和答案. 答案详细完整. 按下面格式返回:
content: `你是出题人${
data.prompt || '我会发送一段长文本'
},请从中提取出 25 个问题和答案. 答案详细完整,并按下面格式返回:
Q1:
A1:
Q2:
@@ -88,7 +87,7 @@ A2:
return chatAPI
.createChatCompletion(
{
model: data.model,
model: global.qaModel.model,
temperature: 0.8,
messages,
stream: false,
@@ -106,10 +105,9 @@ A2:
const result = formatSplitText(answer || ''); // 格式化后的QA对
console.log(`split result length: `, result.length);
// 计费
pushSplitDataBill({
pushQABill({
userId: data.userId,
totalTokens,
model: data.model,
appName: 'QA 拆分'
});
return {
@@ -135,7 +133,6 @@ A2:
source: data.source
})),
userId,
model: global.vectorModels[0].model,
mode: TrainingModeEnum.index
});

View File

@@ -38,7 +38,7 @@ export async function generateVector(): Promise<any> {
q: 1,
a: 1,
source: 1,
model: 1
vectorModel: 1
});
// task preemption
@@ -61,7 +61,7 @@ export async function generateVector(): Promise<any> {
// 生成词向量
const { vectors } = await getVector({
model: data.model,
model: data.vectorModel,
input: dataItems.map((item) => item.q),
userId
});

View File

@@ -76,13 +76,11 @@ export const updateShareChatBill = async ({
}
};
export const pushSplitDataBill = async ({
export const pushQABill = async ({
userId,
totalTokens,
model,
appName
}: {
model: string;
userId: string;
totalTokens: number;
appName: string;
@@ -95,7 +93,7 @@ export const pushSplitDataBill = async ({
await connectToDatabase();
// 获取模型单价格, 都是用 gpt35 拆分
const unitPrice = global.chatModels.find((item) => item.model === model)?.price || 3;
const unitPrice = global.qaModel.price || 3;
// 计算价格
const total = unitPrice * totalTokens;

View File

@@ -19,7 +19,7 @@ const kbSchema = new Schema({
type: String,
required: true
},
model: {
vectorModel: {
type: String,
required: true,
default: 'text-embedding-ada-002'

View File

@@ -28,9 +28,10 @@ const TrainingDataSchema = new Schema({
enum: Object.keys(TrainingTypeMap),
required: true
},
model: {
vectorModel: {
type: String,
required: true
required: true,
default: 'text-embedding-ada-002'
},
prompt: {
// qa split prompt

View File

@@ -181,7 +181,7 @@ export const dispatchChatCompletion = async (props: Record<string, any>): Promis
tokens: totalTokens,
question: userChatInput,
answer: answerText,
maxToken,
maxToken: max_tokens,
quoteList: filterQuoteQA,
completeMessages
},
@@ -237,7 +237,7 @@ function getChatMessages({
}) {
const limitText = (() => {
if (limitPrompt)
return `Use the provided content delimited by triple quotes to answer questions.${limitPrompt}`;
return `Use the provided content delimited by triple quotes to answer questions. ${limitPrompt}`;
if (quotePrompt && !limitPrompt) {
return `Use the provided content delimited by triple quotes to answer questions.Your task is to answer the question using only the provided content. If the content does not contain the information needed to answer this question then simply write: "你的问题没有在知识库中体现".`;
}

View File

@@ -4,11 +4,7 @@ export const getChatModel = (model?: string) => {
export const getVectorModel = (model?: string) => {
return global.vectorModels.find((item) => item.model === model);
};
export const getQAModel = (model?: string) => {
return global.qaModels.find((item) => item.model === model);
};
export const getModel = (model?: string) => {
return [...global.chatModels, ...global.vectorModels, ...global.qaModels].find(
(item) => item.model === model
);
return [...global.chatModels, ...global.vectorModels].find((item) => item.model === model);
};