feat: function call prompt version (#331)

This commit is contained in:
Archer
2023-09-21 12:27:48 +08:00
committed by GitHub
parent 7e0deb29e0
commit e367265dbb
12 changed files with 364 additions and 120 deletions

View File

@@ -4,40 +4,69 @@ import type { ChatHistoryItemResType, ChatItemType } from '@/types/chat';
import { ChatRoleEnum, TaskResponseKeyEnum } from '@/constants/chat';
import { getAIChatApi, axiosConfig } from '@/service/lib/openai';
import type { ClassifyQuestionAgentItemType } from '@/types/app';
import { countModelPrice } from '@/service/events/pushBill';
import { getModel } from '@/service/utils/data';
import { SystemInputEnum } from '@/constants/app';
import { SpecialInputKeyEnum } from '@/constants/flow';
import { FlowModuleTypeEnum } from '@/constants/flow';
import { ModuleDispatchProps } from '@/types/core/modules';
import { replaceVariable } from '@/utils/common/tools/text';
import { Prompt_CQJson } from '@/prompts/core/agent';
export type CQProps = ModuleDispatchProps<{
type Props = ModuleDispatchProps<{
systemPrompt?: string;
history?: ChatItemType[];
[SystemInputEnum.userChatInput]: string;
[SpecialInputKeyEnum.agents]: ClassifyQuestionAgentItemType[];
}>;
export type CQResponse = {
type CQResponse = {
[TaskResponseKeyEnum.responseData]: ChatHistoryItemResType;
[key: string]: any;
};
const agentModel = 'gpt-3.5-turbo';
const agentFunName = 'agent_user_question';
const maxTokens = 3000;
/* request openai chat */
export const dispatchClassifyQuestion = async (props: Record<string, any>): Promise<CQResponse> => {
export const dispatchClassifyQuestion = async (props: Props): Promise<CQResponse> => {
const {
moduleName,
userOpenaiAccount,
inputs: { agents, systemPrompt, history = [], userChatInput }
} = props as CQProps;
inputs: { agents, userChatInput }
} = props as Props;
if (!userChatInput) {
return Promise.reject('Input is empty');
}
const cqModel = global.cqModel;
const { arg, tokens } = await (async () => {
if (cqModel.functionCall) {
return functionCall(props);
}
return completions(props);
})();
const result = agents.find((item) => item.key === arg?.type) || agents[0];
return {
[result.key]: 1,
[TaskResponseKeyEnum.responseData]: {
moduleType: FlowModuleTypeEnum.classifyQuestion,
moduleName,
price: userOpenaiAccount?.key ? 0 : cqModel.price * tokens,
model: cqModel.name || '',
tokens,
cqList: agents,
cqResult: result.value
}
};
};
async function functionCall({
userOpenaiAccount,
inputs: { agents, systemPrompt, history = [], userChatInput }
}: Props) {
const cqModel = global.cqModel;
const messages: ChatItemType[] = [
...(systemPrompt
? [
@@ -55,14 +84,14 @@ export const dispatchClassifyQuestion = async (props: Record<string, any>): Prom
];
const filterMessages = ChatContextFilter({
messages,
maxTokens
maxTokens: cqModel.maxToken
});
const adaptMessages = adaptChat2GptMessages({ messages: filterMessages, reserveId: false });
// function body
const agentFunction = {
name: agentFunName,
description: '判断用户问题的类型属于哪方面,返回对应的枚举字段',
description: '判断用户问题的类型属于哪方面,返回对应的字段',
parameters: {
type: 'object',
properties: {
@@ -79,7 +108,7 @@ export const dispatchClassifyQuestion = async (props: Record<string, any>): Prom
const response = await chatAPI.createChatCompletion(
{
model: agentModel,
model: cqModel.model,
temperature: 0,
messages: [...adaptMessages],
function_call: { name: agentFunName },
@@ -92,20 +121,51 @@ export const dispatchClassifyQuestion = async (props: Record<string, any>): Prom
const arg = JSON.parse(response.data.choices?.[0]?.message?.function_call?.arguments || '');
const tokens = response.data.usage?.total_tokens || 0;
return {
arg,
tokens: response.data.usage?.total_tokens || 0
};
}
const result = agents.find((item) => item.key === arg?.type) || agents[0];
async function completions({
userOpenaiAccount,
inputs: { agents, systemPrompt = '', history = [], userChatInput }
}: Props) {
const extractModel = global.extractModel;
const messages: ChatItemType[] = [
{
obj: ChatRoleEnum.Human,
value: replaceVariable(extractModel.prompt || Prompt_CQJson, {
systemPrompt,
typeList: agents.map((item) => `ID: "${item.key}", 问题类型:${item.value}`).join('\n'),
text: `${history.map((item) => `${item.obj}:${item.value}`).join('\n')}
Human:${userChatInput}`
})
}
];
const chatAPI = getAIChatApi(userOpenaiAccount);
const { data } = await chatAPI.createChatCompletion(
{
model: extractModel.model,
temperature: 0.01,
messages: adaptChat2GptMessages({ messages, reserveId: false }),
stream: false
},
{
timeout: 480000,
...axiosConfig()
}
);
const answer = data.choices?.[0].message?.content || '';
const totalTokens = data.usage?.total_tokens || 0;
const id = agents.find((item) => answer.includes(item.key))?.key || '';
return {
[result.key]: 1,
[TaskResponseKeyEnum.responseData]: {
moduleType: FlowModuleTypeEnum.classifyQuestion,
moduleName,
price: userOpenaiAccount?.key ? 0 : countModelPrice({ model: agentModel, tokens }),
model: getModel(agentModel)?.name || agentModel,
tokens,
cqList: agents,
cqResult: result.value
}
tokens: totalTokens,
arg: { type: id }
};
};
}