perf: chat framework
src/service/utils/auth.ts
@@ -1,14 +1,18 @@
import { Configuration, OpenAIApi } from 'openai';
import type { NextApiRequest } from 'next';
import jwt from 'jsonwebtoken';
import { Chat, Model, OpenApi, User } from '../mongo';
import type { ModelSchema } from '@/types/mongoSchema';
import { getOpenApiKey } from './openai';
import type { ChatItemSimpleType } from '@/types/chat';
import mongoose from 'mongoose';
import { defaultModel } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
import {
  ChatModelType,
  OpenAiChatEnum,
  embeddingModel,
  EmbeddingModelType
} from '@/constants/model';

/* validate token */
export const authToken = (token?: string): Promise<string> => {
@@ -29,13 +33,63 @@ export const authToken = (token?: string): Promise<string> => {
   });
 };

-export const getOpenAIApi = (apiKey: string) => {
-  const configuration = new Configuration({
-    apiKey,
-    basePath: process.env.OPENAI_BASE_URL
-  });
-
-  return new OpenAIApi(configuration);
-};
+/* get the key for an api request */
+export const getApiKey = async ({
+  model,
+  userId
+}: {
+  model: ChatModelType | EmbeddingModelType;
+  userId: string;
+}) => {
+  const user = await User.findById(userId);
+  if (!user) {
+    return Promise.reject({
+      code: 501,
+      message: '找不到用户'
+    });
+  }
+
+  const keyMap = {
+    [OpenAiChatEnum.GPT35]: {
+      userApiKey: user.openaiKey || '',
+      systemApiKey: process.env.OPENAIKEY as string
+    },
+    [OpenAiChatEnum.GPT4]: {
+      userApiKey: user.openaiKey || '',
+      systemApiKey: process.env.OPENAIKEY as string
+    },
+    [OpenAiChatEnum.GPT432k]: {
+      userApiKey: user.openaiKey || '',
+      systemApiKey: process.env.OPENAIKEY as string
+    },
+    [embeddingModel]: {
+      userApiKey: user.openaiKey || '',
+      systemApiKey: process.env.OPENAIKEY as string
+    }
+  };
+
+  // the user has their own key
+  if (keyMap[model].userApiKey) {
+    return {
+      user,
+      userApiKey: keyMap[model].userApiKey,
+      systemApiKey: ''
+    };
+  }
+
+  // platform account balance check
+  if (formatPrice(user.balance) <= 0) {
+    return Promise.reject({
+      code: 501,
+      message: '账号余额不足'
+    });
+  }
+
+  return {
+    user,
+    userApiKey: '',
+    systemApiKey: keyMap[model].systemApiKey
+  };
+};
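For orientation, a minimal calling sketch (illustrative, not part of the diff; assumes an already-authenticated userId). Exactly one of the two returned keys is non-empty, so the caller can fall through:

// illustrative: resolve the effective key for a gpt-3.5 request
const { userApiKey, systemApiKey } = await getApiKey({
  model: OpenAiChatEnum.GPT35,
  userId
});
const apiKey = userApiKey || systemApiKey; // a user key wins; the platform key gets billed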

 // model usage permission check
@@ -122,11 +176,11 @@ export const authChat = async ({
   ]);
   }
   // get the user's apiKey
-  const { userApiKey, systemKey } = await getOpenApiKey(userId);
+  const { userApiKey, systemApiKey } = await getApiKey({ model: model.chat.chatModel, userId });

   return {
     userApiKey,
-    systemKey,
+    systemApiKey,
     content,
     userId,
     model,
src/service/utils/chat/index.ts (new file, 155 lines)
@@ -0,0 +1,155 @@
import { ChatItemSimpleType } from '@/types/chat';
import { modelToolMap } from '@/utils/chat';
import type { ChatModelType } from '@/constants/model';
import { ChatRoleEnum, SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
import { OpenAiChatEnum } from '@/constants/model';
import { chatResponse, openAiStreamResponse } from './openai';
import type { NextApiResponse } from 'next';
import type { PassThrough } from 'stream';

export type ChatCompletionType = {
  apiKey: string;
  temperature: number;
  messages: ChatItemSimpleType[];
  stream: boolean;
};
export type StreamResponseType = {
  stream: PassThrough;
  chatResponse: any;
  prompts: ChatItemSimpleType[];
};

export const modelServiceToolMap = {
  [OpenAiChatEnum.GPT35]: {
    chatCompletion: (data: ChatCompletionType) =>
      chatResponse({ model: OpenAiChatEnum.GPT35, ...data }),
    streamResponse: (data: StreamResponseType) =>
      openAiStreamResponse({
        model: OpenAiChatEnum.GPT35,
        ...data
      })
  },
  [OpenAiChatEnum.GPT4]: {
    chatCompletion: (data: ChatCompletionType) =>
      chatResponse({ model: OpenAiChatEnum.GPT4, ...data }),
    streamResponse: (data: StreamResponseType) =>
      openAiStreamResponse({
        model: OpenAiChatEnum.GPT4,
        ...data
      })
  },
  [OpenAiChatEnum.GPT432k]: {
    chatCompletion: (data: ChatCompletionType) =>
      chatResponse({ model: OpenAiChatEnum.GPT432k, ...data }),
    streamResponse: (data: StreamResponseType) =>
      openAiStreamResponse({
        model: OpenAiChatEnum.GPT432k,
        ...data
      })
  }
};
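modelServiceToolMap keys each chat model to a uniform chatCompletion/streamResponse pair, so callers dispatch by enum value instead of branching per model. A minimal dispatch sketch (illustrative; apiKey is assumed, and the 'Human' role value follows the Human/AI/SYSTEM convention used elsewhere in this commit):

// illustrative: dispatch a non-streaming completion by model enum
const { responseText, totalTokens } = await modelServiceToolMap[
  OpenAiChatEnum.GPT35
].chatCompletion({
  apiKey,
  temperature: 0,
  messages: [{ obj: 'Human', value: 'hello' }],
  stream: false
});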
/* delete invalid symbols */
const simplifyStr = (str: string) =>
  str
    .replace(/\n+/g, '\n') // consecutive blank lines
    .replace(/[^\S\r\n]+/g, ' ') // consecutive whitespace (except newlines)
    .trim();
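Illustrative behavior of the two regexes: runs of newlines collapse to a single newline, while runs of other whitespace collapse to a single space, so line structure survives:

// illustrative
simplifyStr('a   b\n\n\nc\t d'); // => 'a b\nc d'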

/* truncate the chat context by tokens */
export const ChatContextFilter = ({
  model,
  prompts,
  maxTokens
}: {
  model: ChatModelType;
  prompts: ChatItemSimpleType[];
  maxTokens: number;
}) => {
  let rawTextLen = 0;
  const formatPrompts = prompts.map<ChatItemSimpleType>((item) => {
    const val = simplifyStr(item.value);
    rawTextLen += val.length;
    return {
      obj: item.obj,
      value: val
    };
  });

  // when the content is short, token truncation is unnecessary
  if (formatPrompts.length <= 2 || rawTextLen < maxTokens * 0.5) {
    return formatPrompts;
  }

  // truncate the content by token count
  const chats: ChatItemSimpleType[] = [];
  let systemPrompt: ChatItemSimpleType | null = null;

  // keep the System prompt
  if (formatPrompts[0].obj === ChatRoleEnum.System) {
    const prompt = formatPrompts.shift();
    if (prompt) {
      systemPrompt = prompt;
    }
  }

  let messages: ChatItemSimpleType[] = [];

  // take conversation content from newest to oldest
  for (let i = formatPrompts.length - 1; i >= 0; i--) {
    chats.unshift(formatPrompts[i]);

    messages = systemPrompt ? [systemPrompt, ...chats] : chats;

    const tokens = modelToolMap[model].countTokens({
      messages
    });

    /* total tokens exceed the limit */
    if (tokens >= maxTokens) {
      return systemPrompt ? [systemPrompt, ...chats.slice(1)] : chats.slice(1);
    }
  }

  return messages;
};
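The filter walks the history from newest to oldest, re-counting tokens as each message is admitted; when the total first reaches maxTokens it returns everything admitted minus the oldest non-system message, so the system prompt always survives truncation. An illustrative call (historyMessages and the 4096 window are hypothetical):

// illustrative: keep what fits in 90% of the context window
const kept = ChatContextFilter({
  model: OpenAiChatEnum.GPT35,
  prompts: historyMessages,
  maxTokens: Math.ceil(4096 * 0.9)
});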

/* stream response */
export const resStreamResponse = async ({
  model,
  res,
  stream,
  chatResponse,
  systemPrompt,
  prompts
}: StreamResponseType & {
  model: ChatModelType;
  res: NextApiResponse;
  systemPrompt?: string;
}) => {
  // create the response stream
  res.setHeader('Content-Type', 'text/event-stream;charset=utf-8');
  res.setHeader('Access-Control-Allow-Origin', '*');
  res.setHeader('X-Accel-Buffering', 'no');
  res.setHeader('Cache-Control', 'no-cache, no-transform');
  stream.pipe(res);

  const { responseContent, totalTokens, finishMessages } = await modelServiceToolMap[
    model
  ].streamResponse({
    chatResponse,
    stream,
    prompts
  });

  // push system prompt
  !stream.destroyed &&
    systemPrompt &&
    stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);

  // close stream
  !stream.destroyed && stream.push(null);
  stream.destroy();

  return { responseContent, totalTokens, finishMessages };
};
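An illustrative wiring of the two halves inside an API route: run chatCompletion with stream: true, then hand its raw axios response to resStreamResponse (the names res, model, apiKey, and prompts are assumed from the route's context):

// illustrative: stream a completion back to the client
import { PassThrough } from 'stream';

const stream = new PassThrough();
const { streamResponse } = await modelServiceToolMap[model].chatCompletion({
  apiKey,
  temperature: 0,
  messages: prompts,
  stream: true
});
const { responseContent } = await resStreamResponse({
  model,
  res,
  stream,
  chatResponse: streamResponse,
  prompts
});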
src/service/utils/chat/openai.ts (new file, 174 lines)
@@ -0,0 +1,174 @@
import { Configuration, OpenAIApi } from 'openai';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { axiosConfig } from '../tools';
import { ChatModelMap, embeddingModel, OpenAiChatEnum } from '@/constants/model';
import { pushGenerateVectorBill } from '../../events/pushBill';
import { adaptChatItem_openAI } from '@/utils/chat/openai';
import { modelToolMap } from '@/utils/chat';
import { ChatCompletionType, ChatContextFilter, StreamResponseType } from './index';
import { ChatRoleEnum } from '@/constants/chat';

export const getOpenAIApi = (apiKey: string) => {
  const configuration = new Configuration({
    apiKey,
    basePath: process.env.OPENAI_BASE_URL
  });

  return new OpenAIApi(configuration);
};

/* get the embedding vector */
export const openaiCreateEmbedding = async ({
  userApiKey,
  systemApiKey,
  userId,
  text
}: {
  userApiKey?: string;
  systemApiKey: string;
  userId: string;
  text: string;
}) => {
  // get the chatAPI
  const chatAPI = getOpenAIApi(userApiKey || systemApiKey);

  // convert the input text into a vector
  const res = await chatAPI
    .createEmbedding(
      {
        model: embeddingModel,
        input: text
      },
      {
        timeout: 60000,
        ...axiosConfig()
      }
    )
    .then((res) => ({
      tokenLen: res.data.usage.total_tokens || 0,
      vector: res.data.data?.[0]?.embedding || []
    }));

  pushGenerateVectorBill({
    isPay: !userApiKey,
    userId,
    text,
    tokenLen: res.tokenLen
  });

  return {
    vector: res.vector,
    chatAPI
  };
};

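Billing hinges on whose key was used: isPay: !userApiKey charges the platform account only when the user supplied no key of their own. An illustrative call (user and userId assumed from the caller):

// illustrative: embed a query string for vector search
const { vector } = await openaiCreateEmbedding({
  userApiKey: user.openaiKey,
  systemApiKey: process.env.OPENAIKEY as string,
  userId,
  text: 'hello world'
});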
/* model chat */
export const chatResponse = async ({
  model,
  apiKey,
  temperature,
  messages,
  stream
}: ChatCompletionType & { model: `${OpenAiChatEnum}` }) => {
  const filterMessages = ChatContextFilter({
    model,
    prompts: messages,
    maxTokens: Math.ceil(ChatModelMap[model].contextMaxToken * 0.9)
  });

  const adaptMessages = adaptChatItem_openAI({ messages: filterMessages });
  const chatAPI = getOpenAIApi(apiKey);

  const response = await chatAPI.createChatCompletion(
    {
      model,
      temperature: Number(temperature) || 0,
      messages: adaptMessages,
      frequency_penalty: 0.5, // higher values reduce repetition
      presence_penalty: -0.5, // higher values encourage new topics
      stream,
      stop: ['.!?。']
    },
    {
      timeout: stream ? 40000 : 240000,
      responseType: stream ? 'stream' : 'json',
      ...axiosConfig()
    }
  );

  let responseText = '';
  let totalTokens = 0;

  // adapt data
  if (!stream) {
    responseText = response.data.choices[0].message?.content || '';
    totalTokens = response.data.usage?.total_tokens || 0;
  }

  return {
    streamResponse: response,
    responseMessages: filterMessages.concat({ obj: 'AI', value: responseText }),
    responseText,
    totalTokens
  };
};

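The maxTokens = Math.ceil(contextMaxToken * 0.9) headroom reserves about a tenth of the window for the completion itself. For instance (illustrative numbers), with a 4096-token context the history is capped at Math.ceil(4096 * 0.9) = 3687 tokens, leaving roughly 400 tokens for the reply.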
/* openai stream response */
export const openAiStreamResponse = async ({
  model,
  stream,
  chatResponse,
  prompts
}: StreamResponseType & {
  model: `${OpenAiChatEnum}`;
}) => {
  try {
    let responseContent = '';

    const onParse = async (event: ParsedEvent | ReconnectInterval) => {
      if (event.type !== 'event') return;
      const data = event.data;
      if (data === '[DONE]') return;
      try {
        const json = JSON.parse(data);
        const content: string = json?.choices?.[0].delta.content || '';
        responseContent += content;

        !stream.destroyed && content && stream.push(content.replace(/\n/g, '<br/>'));
      } catch (error) {
        // ignore malformed SSE chunks
      }
    };

    try {
      const decoder = new TextDecoder();
      const parser = createParser(onParse);
      for await (const chunk of chatResponse.data as any) {
        if (stream.destroyed) {
          // the stream was aborted; ignore the remaining content
          break;
        }
        parser.feed(decoder.decode(chunk, { stream: true }));
      }
    } catch (error) {
      console.log('pipe error', error);
    }

    // count tokens
    const finishMessages = prompts.concat({
      obj: ChatRoleEnum.AI,
      value: responseContent
    });
    const totalTokens = modelToolMap[model].countTokens({
      messages: finishMessages
    });

    return {
      responseContent,
      totalTokens,
      finishMessages
    };
  } catch (error) {
    return Promise.reject(error);
  }
};
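Each server-sent event carries one JSON delta, e.g. data: {"choices":[{"delta":{"content":"Hel"},"index":0}]} (illustrative), so concatenating delta.content rebuilds the reply; the literal [DONE] event marks the end of the stream, and eventsource-parser reassembles frames that were split across network chunks.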
@@ -1,179 +0,0 @@
import type { NextApiResponse } from 'next';
import type { PassThrough } from 'stream';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { getOpenAIApi } from '@/service/utils/auth';
import { axiosConfig } from './tools';
import { User } from '../models/user';
import { formatPrice } from '@/utils/user';
import { embeddingModel } from '@/constants/model';
import { pushGenerateVectorBill } from '../events/pushBill';
import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat';

/* get the openai client for the user's own api key */
export const getUserApiOpenai = async (userId: string) => {
  const user = await User.findById(userId);

  const userApiKey = user?.openaiKey;

  if (!userApiKey) {
    return Promise.reject('缺少ApiKey, 无法请求');
  }

  return {
    user,
    openai: getOpenAIApi(userApiKey),
    apiKey: userApiKey
  };
};

/* get an open api key: if the user has no key of their own, use the platform's, and remember to bill platform usage */
export const getOpenApiKey = async (userId: string) => {
  const user = await User.findById(userId);
  if (!user) {
    return Promise.reject({
      code: 501,
      message: '找不到用户'
    });
  }

  const userApiKey = user?.openaiKey;

  // the user has their own key
  if (userApiKey) {
    return {
      user,
      userApiKey,
      systemKey: ''
    };
  }

  // platform account balance check
  if (formatPrice(user.balance) <= 0) {
    return Promise.reject({
      code: 501,
      message: '账号余额不足'
    });
  }

  return {
    user,
    userApiKey: '',
    systemKey: process.env.OPENAIKEY as string
  };
};

/* get the embedding vector */
export const openaiCreateEmbedding = async ({
  isPay,
  userId,
  apiKey,
  text
}: {
  isPay: boolean;
  userId: string;
  apiKey: string;
  text: string;
}) => {
  // get the chatAPI
  const chatAPI = getOpenAIApi(apiKey);

  // convert the input text into a vector
  const res = await chatAPI
    .createEmbedding(
      {
        model: embeddingModel,
        input: text
      },
      {
        timeout: 60000,
        ...axiosConfig()
      }
    )
    .then((res) => ({
      tokenLen: res.data.usage.total_tokens || 0,
      vector: res.data.data?.[0]?.embedding || []
    }));

  pushGenerateVectorBill({
    isPay,
    userId,
    text,
    tokenLen: res.tokenLen
  });

  return {
    vector: res.vector,
    chatAPI
  };
};

/* gpt35 response */
export const gpt35StreamResponse = ({
  res,
  stream,
  chatResponse,
  systemPrompt = ''
}: {
  res: NextApiResponse;
  stream: PassThrough;
  chatResponse: any;
  systemPrompt?: string;
}) =>
  new Promise<{ responseContent: string }>(async (resolve, reject) => {
    try {
      // create the response stream
      res.setHeader('Content-Type', 'text/event-stream;charset=utf-8');
      res.setHeader('Access-Control-Allow-Origin', '*');
      res.setHeader('X-Accel-Buffering', 'no');
      res.setHeader('Cache-Control', 'no-cache, no-transform');
      stream.pipe(res);

      let responseContent = '';

      const onParse = async (event: ParsedEvent | ReconnectInterval) => {
        if (event.type !== 'event') return;
        const data = event.data;
        if (data === '[DONE]') return;
        try {
          const json = JSON.parse(data);
          const content: string = json?.choices?.[0].delta.content || '';
          responseContent += content;

          if (!stream.destroyed && content) {
            stream.push(content.replace(/\n/g, '<br/>'));
          }
        } catch (error) {
          // ignore malformed SSE chunks
        }
      };

      try {
        const decoder = new TextDecoder();
        const parser = createParser(onParse);
        for await (const chunk of chatResponse.data as any) {
          if (stream.destroyed) {
            // the stream was aborted; ignore the remaining content
            break;
          }
          parser.feed(decoder.decode(chunk, { stream: true }));
        }
      } catch (error) {
        console.log('pipe error', error);
      }

      // push system prompt
      !stream.destroyed &&
        systemPrompt &&
        stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);

      // close stream
      !stream.destroyed && stream.push(null);
      stream.destroy();

      resolve({
        responseContent
      });
    } catch (error) {
      reject(error);
    }
  });
src/service/utils/tools.ts
@@ -1,9 +1,5 @@
 import crypto from 'crypto';
 import jwt from 'jsonwebtoken';
-import { ChatItemSimpleType } from '@/types/chat';
-import { countChatTokens, sliceTextByToken } from '@/utils/tools';
-import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
-import type { ChatModelType } from '@/constants/model';

 /* password hashing */
 export const hashPassword = (psw: string) => {
@@ -30,92 +26,3 @@ export const axiosConfig = () => ({
     auth: process.env.OPENAI_BASE_URL_AUTH || ''
   }
 });
-
-/* delete invalid symbols */
-const simplifyStr = (str: string) =>
-  str
-    .replace(/\n+/g, '\n') // consecutive blank lines
-    .replace(/[^\S\r\n]+/g, ' ') // consecutive whitespace (except newlines)
-    .trim();
-
-/* truncate chat content by tokens */
-export const openaiChatFilter = ({
-  model,
-  prompts,
-  maxTokens
-}: {
-  model: ChatModelType;
-  prompts: ChatItemSimpleType[];
-  maxTokens: number;
-}) => {
-  // role map
-  const map = {
-    Human: ChatCompletionRequestMessageRoleEnum.User,
-    AI: ChatCompletionRequestMessageRoleEnum.Assistant,
-    SYSTEM: ChatCompletionRequestMessageRoleEnum.System
-  };
-
-  let rawTextLen = 0;
-  const formatPrompts = prompts.map((item) => {
-    const val = simplifyStr(item.value);
-    rawTextLen += val.length;
-    return {
-      role: map[item.obj],
-      content: val
-    };
-  });
-
-  // when the content is short, token truncation is unnecessary
-  if (rawTextLen < maxTokens * 0.5) {
-    return formatPrompts;
-  }
-
-  // truncate the content by token count
-  const chats: ChatCompletionRequestMessage[] = [];
-  let systemPrompt: ChatCompletionRequestMessage | null = null;
-
-  // keep the System prompt
-  if (formatPrompts[0]?.role === 'system') {
-    systemPrompt = formatPrompts.shift() as ChatCompletionRequestMessage;
-  }
-
-  let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];
-
-  // take conversation content from newest to oldest
-  for (let i = formatPrompts.length - 1; i >= 0; i--) {
-    chats.unshift(formatPrompts[i]);
-
-    messages = systemPrompt ? [systemPrompt, ...chats] : chats;
-
-    const tokens = countChatTokens({
-      model,
-      messages
-    });
-
-    /* total tokens exceed the limit */
-    if (tokens >= maxTokens) {
-      return systemPrompt ? [systemPrompt, ...chats.slice(1)] : chats.slice(1);
-    }
-  }
-
-  return messages;
-};
-
-/* truncate system content, sorted from high to low similarity */
-export const systemPromptFilter = ({
-  model,
-  prompts,
-  maxTokens
-}: {
-  model: 'gpt-4' | 'gpt-4-32k' | 'gpt-3.5-turbo';
-  prompts: string[];
-  maxTokens: number;
-}) => {
-  const systemPrompt = prompts.join('\n');
-
-  return sliceTextByToken({
-    model,
-    text: systemPrompt,
-    length: maxTokens
-  });
-};