diff --git a/src/api/model.ts b/src/api/model.ts index ad5ee7a4b..1c17d0c2f 100644 --- a/src/api/model.ts +++ b/src/api/model.ts @@ -49,8 +49,8 @@ export const postModelDataInput = (data: { data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[]; }) => POST(`/model/data/pushModelDataInput`, data); -export const postModelDataFileText = (modelId: string, text: string) => - POST(`/model/data/splitData`, { modelId, text }); +export const postModelDataFileText = (data: { modelId: string; text: string; prompt: string }) => + POST(`/model/data/splitData`, data); export const postModelDataJsonData = ( modelId: string, diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts index 0fe940f7e..387baeee0 100644 --- a/src/pages/api/chat/vectorGpt.ts +++ b/src/pages/api/chat/vectorGpt.ts @@ -118,7 +118,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) prompts.unshift({ obj: 'SYSTEM', - value: `${model.systemPrompt} 知识库内容: "${systemPrompt}"` + value: `${model.systemPrompt} 知识库内容是最新的,知识库内容为: "${systemPrompt}"` }); // 控制在 tokens 数量,防止超出 diff --git a/src/pages/api/model/data/splitData.ts b/src/pages/api/model/data/splitData.ts index 13407e20a..4bc834c34 100644 --- a/src/pages/api/model/data/splitData.ts +++ b/src/pages/api/model/data/splitData.ts @@ -8,8 +8,8 @@ import { encode } from 'gpt-token-utils'; /* 拆分数据成QA */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { text, modelId } = req.body as { text: string; modelId: string }; - if (!text || !modelId) { + const { text, modelId, prompt } = req.body as { text: string; modelId: string; prompt: string }; + if (!text || !modelId || !prompt) { throw new Error('参数错误'); } await connectToDatabase(); @@ -62,7 +62,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) userId, modelId, rawText: text, - textList + textList, + prompt }); generateQA(); diff --git a/src/pages/model/detail/components/SelectFileModal.tsx b/src/pages/model/detail/components/SelectFileModal.tsx index 823b49bd3..1166d4769 100644 --- a/src/pages/model/detail/components/SelectFileModal.tsx +++ b/src/pages/model/detail/components/SelectFileModal.tsx @@ -8,7 +8,8 @@ import { ModalContent, ModalHeader, ModalCloseButton, - ModalBody + ModalBody, + Input } from '@chakra-ui/react'; import { useToast } from '@/hooks/useToast'; import { useSelectFile } from '@/hooks/useSelectFile'; @@ -34,6 +35,7 @@ const SelectFileModal = ({ }) => { const [selecting, setSelecting] = useState(false); const { toast } = useToast(); + const [prompt, setPrompt] = useState(''); const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true }); const [fileText, setFileText] = useState(''); const { openConfirm, ConfirmChild } = useConfirm({ @@ -83,7 +85,11 @@ const SelectFileModal = ({ const { mutate, isLoading } = useMutation({ mutationFn: async () => { if (!fileText) return; - await postModelDataFileText(modelId, fileText); + await postModelDataFileText({ + modelId, + text: fileText, + prompt: `下面是${prompt || '一段长文本'}` + }); toast({ title: '导入数据成功,需要一段拆解和训练', status: 'success' @@ -102,7 +108,7 @@ const SelectFileModal = ({ return ( - + 文件导入 @@ -125,6 +131,17 @@ const SelectFileModal = ({ 一共 {fileText.length} 个字,{encode(fileText).length} 个tokens + + + 下面是 + + setPrompt(e.target.value)} + size={'sm'} + /> + { if (global.generatingQA && !next) return; global.generatingQA = true; - const systemPrompt: ChatCompletionRequestMessage = { - role: 'system', - content: `总结助手。我会向你发送一段长文本,请从中总结出5至30个问题和答案,答案请尽量详细,并按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n` - }; - try { const redis = await connectRedis(); // 找出一个需要生成的 dataItem @@ -63,6 +58,13 @@ export async function generateQA(next = false): Promise { // 获取 openai 请求实例 const chatAPI = getOpenAIApi(userApiKey || systemKey); + const systemPrompt: ChatCompletionRequestMessage = { + role: 'system', + content: `${ + dataItem.prompt || '下面是一段长文本' + },请从中总结出5至30个问题和答案,答案尽量详细,并按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n` + }; + // 请求 chatgpt 获取回答 const response = await chatAPI .createChatCompletion( diff --git a/src/service/models/splitData.ts b/src/service/models/splitData.ts index 1e5a7fcf8..78999ff69 100644 --- a/src/service/models/splitData.ts +++ b/src/service/models/splitData.ts @@ -8,6 +8,11 @@ const SplitDataSchema = new Schema({ ref: 'user', required: true }, + prompt: { + // 拆分时的提示词 + type: String, + required: true + }, modelId: { type: Schema.Types.ObjectId, ref: 'model', diff --git a/src/types/mongoSchema.d.ts b/src/types/mongoSchema.d.ts index 10c6a0971..a0f30578f 100644 --- a/src/types/mongoSchema.d.ts +++ b/src/types/mongoSchema.d.ts @@ -69,6 +69,7 @@ export interface ModelSplitDataSchema { userId: string; modelId: string; rawText: string; + prompt: string; errorText: string; textList: string[]; }