mirror of
https://github.com/labring/FastGPT.git
synced 2026-05-10 01:08:08 +08:00
rerank config (#6892)
This commit is contained in:
@@ -110,7 +110,8 @@ export type EmbeddingModelItemType = z.infer<typeof EmbeddingModelItemSchema>;
|
||||
|
||||
export const RerankModelItemSchema = PriceTypeSchema.extend(BaseModelItemSchema.shape).extend({
|
||||
type: z.literal(ModelTypeEnum.rerank),
|
||||
maxToken: z.number().optional() // max input token for rerank query + one document
|
||||
maxToken: z.number().optional(), // max input token for rerank query + one document
|
||||
defaultConfig: z.record(z.string(), z.any()).optional() // post request config
|
||||
});
|
||||
export type RerankModelItemType = z.infer<typeof RerankModelItemSchema>;
|
||||
|
||||
|
||||
@@ -95,22 +95,21 @@ export async function reRankRecall({
|
||||
|
||||
// 模型的请求 url,允许是内网
|
||||
const requestUrl = model.requestUrl ? model.requestUrl : `${baseUrl}/rerank`;
|
||||
const requestBody = {
|
||||
model: model.model,
|
||||
query,
|
||||
documents: documentsTextArray,
|
||||
...model.defaultConfig
|
||||
};
|
||||
|
||||
const apiResult = await axiosWithoutSSRF
|
||||
.post<PostReRankResponse>(
|
||||
requestUrl,
|
||||
{
|
||||
model: model.model,
|
||||
query,
|
||||
documents: documentsTextArray
|
||||
.post<PostReRankResponse>(requestUrl, requestBody, {
|
||||
headers: {
|
||||
Authorization: model.requestAuth ? `Bearer ${model.requestAuth}` : authorization,
|
||||
...headers
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
Authorization: model.requestAuth ? `Bearer ${model.requestAuth}` : authorization,
|
||||
...headers
|
||||
},
|
||||
timeout: 30000
|
||||
}
|
||||
)
|
||||
timeout: 30000
|
||||
})
|
||||
.then((res) => res.data)
|
||||
.then(async (data) => {
|
||||
if (!data?.results || data?.results?.length === 0) {
|
||||
@@ -126,13 +125,15 @@ export async function reRankRecall({
|
||||
logger.info('Rerank completed', { durationMs: time });
|
||||
}
|
||||
|
||||
const providerResults = data.results ?? [];
|
||||
|
||||
const existsId = new Set<string>();
|
||||
const results: {
|
||||
id: string;
|
||||
score: number;
|
||||
}[] = [];
|
||||
|
||||
data.results.forEach((item) => {
|
||||
providerResults.forEach((item) => {
|
||||
const chunkId = expandedDocuments[item.index].id;
|
||||
const docId = chunkIdToDocIdMap.get(chunkId);
|
||||
// 因为 data.results 是从高到低的,如果高分的同一个docId,则低分的不用处理
|
||||
|
||||
@@ -264,6 +264,83 @@ describe('reRankRecall', () => {
|
||||
expect(url.endsWith('/rerank')).toBe(true);
|
||||
});
|
||||
|
||||
it('合并 defaultConfig 到请求体,透传 top_k 参数', async () => {
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [{ index: 0, relevance_score: 0.8 }],
|
||||
meta: { tokens: { input_tokens: 5, output_tokens: 0 } }
|
||||
});
|
||||
|
||||
const result = await reRankRecall({
|
||||
model: {
|
||||
...mockModel,
|
||||
defaultConfig: {
|
||||
top_k: 1
|
||||
}
|
||||
},
|
||||
query: 'q',
|
||||
documents: [
|
||||
{ id: 'doc1', text: 'hello' },
|
||||
{ id: 'doc2', text: 'world' }
|
||||
]
|
||||
});
|
||||
|
||||
expect(mockAxiosPost).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
model: mockModel.model,
|
||||
query: 'q',
|
||||
documents: ['hello', 'world'],
|
||||
top_k: 1
|
||||
}),
|
||||
expect.any(Object)
|
||||
);
|
||||
expect(result.results).toEqual([{ id: 'doc1', score: 0.8 }]);
|
||||
});
|
||||
|
||||
it('透传 topn,不在本地做截断', async () => {
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [
|
||||
{ index: 2, relevance_score: 0.9 },
|
||||
{ index: 1, relevance_score: 0.8 },
|
||||
{ index: 0, relevance_score: 0.7 }
|
||||
],
|
||||
meta: { tokens: { input_tokens: 5, output_tokens: 0 } }
|
||||
});
|
||||
|
||||
const result = await reRankRecall({
|
||||
model: {
|
||||
...mockModel,
|
||||
defaultConfig: {
|
||||
topn: 1
|
||||
}
|
||||
},
|
||||
query: 'q',
|
||||
documents: [
|
||||
{ id: 'doc1', text: 'hello' },
|
||||
{ id: 'doc2', text: 'world' },
|
||||
{ id: 'doc3', text: 'fastgpt' }
|
||||
]
|
||||
});
|
||||
|
||||
expect(mockAxiosPost).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
model: mockModel.model,
|
||||
query: 'q',
|
||||
documents: ['hello', 'world', 'fastgpt'],
|
||||
topn: 1
|
||||
}),
|
||||
expect.any(Object)
|
||||
);
|
||||
expect(result.results).toEqual([
|
||||
{ id: 'doc3', score: 0.9 },
|
||||
{ id: 'doc2', score: 0.8 },
|
||||
{ id: 'doc1', score: 0.7 }
|
||||
]);
|
||||
});
|
||||
|
||||
// ── 异常场景 ──────────────────────────────────────────────────────────────
|
||||
|
||||
it('model 为 undefined 时 reject', async () => {
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
"model.defaultConfig_tip": "Each request will carry this additional Body parameter.",
|
||||
"model.default_config": "Body extra fields",
|
||||
"model.default_config_tip": "When initiating a conversation request, merge this configuration. \nFor example:\n\"\"\"\n{\n \"temperature\": 1,\n \"max_tokens\": null\n}\n\"\"\"",
|
||||
"model.rerank_default_config_tip": "Merge this config when sending rerank requests. \nFor example:\n\"\"\"\n{\n \"topn\": 5\n}\n\"\"\"",
|
||||
"model.default_model": "Default model",
|
||||
"model.default_system_chat_prompt": "Default prompt",
|
||||
"model.default_system_chat_prompt_tip": "When the model talks, it will carry this default prompt word.",
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
"model.defaultConfig_tip": "每次请求时候,都会携带该额外 Body 参数",
|
||||
"model.default_config": "Body 额外字段",
|
||||
"model.default_config_tip": "发起对话请求时候,合并该配置。例如:\n\"\"\"\n{\n \"temperature\": 1,\n \"max_tokens\": null\n}\n\"\"\"",
|
||||
"model.rerank_default_config_tip": "发起重排请求时候,合并该配置。例如:\n\"\"\"\n{\n \"topn\": 5\n}\n\"\"\"",
|
||||
"model.default_model": "默认模型",
|
||||
"model.default_system_chat_prompt": "默认提示词",
|
||||
"model.default_system_chat_prompt_tip": "模型对话时,都会携带该默认提示词",
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
"model.defaultConfig_tip": "每次請求時候,都會攜帶該額外 Body 參數",
|
||||
"model.default_config": "Body 額外欄位",
|
||||
"model.default_config_tip": "發起對話請求時候,合併該設定。例如:\n\"\"\"\n{\n \"temperature\": 1,\n \"max_tokens\": null\n}\n\"\"\"",
|
||||
"model.rerank_default_config_tip": "發起重排請求時候,合併該設定。例如:\n\"\"\"\n{\n \"topn\": 5\n}\n\"\"\"",
|
||||
"model.default_model": "預設模型",
|
||||
"model.default_system_chat_prompt": "預設提示詞",
|
||||
"model.default_system_chat_prompt_tip": "模型對話時,都會攜帶該預設提示詞",
|
||||
|
||||
+1
-1
Submodule pro updated: 1dc6ead607...ee1b1d779d
@@ -1149,17 +1149,21 @@ export const ModelEditModal = ({
|
||||
/>
|
||||
</Field>
|
||||
)}
|
||||
{(isLLMModel || isEmbeddingModel) && (
|
||||
{(isLLMModel || isEmbeddingModel || isRerankModel) && (
|
||||
<DefaultConfigField
|
||||
control={control}
|
||||
setValue={setValue}
|
||||
label={
|
||||
isLLMModel ? t('account:model.default_config') : t('account:model.defaultConfig')
|
||||
isEmbeddingModel
|
||||
? t('account:model.defaultConfig')
|
||||
: t('account:model.default_config')
|
||||
}
|
||||
tip={
|
||||
isLLMModel
|
||||
? t('account:model.default_config_tip')
|
||||
: t('account:model.defaultConfig_tip')
|
||||
isEmbeddingModel
|
||||
? t('account:model.defaultConfig_tip')
|
||||
: isRerankModel
|
||||
? t('account:model.rerank_default_config_tip')
|
||||
: t('account:model.default_config_tip')
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
Reference in New Issue
Block a user