perf: ssrf check (#6852)

This commit is contained in:
Archer
2026-04-29 13:49:52 +08:00
committed by GitHub
parent 225cb7e62b
commit 1fd5eed88a
19 changed files with 387 additions and 280 deletions
+31
View File
@@ -2,6 +2,8 @@ import _, { type AxiosInstance, type AxiosRequestConfig } from 'axios';
import { ProxyAgent } from 'proxy-agent';
import { isDevEnv } from '@fastgpt/global/common/system/constants';
import { isInternalAddress, PRIVATE_URL_TEXT } from '../system/utils';
import { isAbsoluteUrl } from '../security/network';
import { SERVICE_LOCAL_HOST } from '../system/tools';
const addSSRFInterceptor = (instance: AxiosInstance) => {
instance.interceptors.request.use(async (config) => {
@@ -41,3 +43,32 @@ export function createProxyAxios(config?: AxiosRequestConfig, ssrfCheck = true)
/** @see https://github.com/axios/axios/issues/4531 */
export const axios = createProxyAxios();
/**
* 内部相对路径请求专用的 axios 实例:
* - baseURL 固定为本机 NextJS API
* - 不带 SSRF 拦截器(本机调用必然解析到 localhost,装拦截会把所有合法请求拦死)
* - 不复用 safe axios 的 ProxyAgent,保证内部回环不会被外部代理转走
*
* 仅在 url 是相对路径时使用;绝对 URL 必须走 safe `axios`。
*/
const internalAxios: AxiosInstance = _.create({
baseURL: `http://${SERVICE_LOCAL_HOST}`
});
/**
* 根据 URL 类型自动选择合适的 axios 实例,避免每个调用点重复
* `isAbsoluteUrl ? safe : raw` 三元。
*
* - 绝对 URL(`http(s)://...` 或 `//...`)→ safe `axios`(SSRF 拦截,拒绝内网/metadata)
* - 相对路径(`/api/...` 等)→ `internalAxios`(本机 baseURL,可信内部 API)
*
* 用法:
* ```ts
* const client = pickOutboundAxios(url);
* const res = await client.get(url, { responseType: 'arraybuffer' });
* ```
*/
export const pickOutboundAxios = (url: string): AxiosInstance => {
return isAbsoluteUrl(url) ? axios : internalAxios;
};
+17 -2
View File
@@ -8,6 +8,7 @@ import { FastGPTProUrl } from '../system/constants';
import { UserError } from '@fastgpt/global/common/error/utils';
import { createProxyAxios } from './axios';
import { getLogger, LogCategories } from '../logger';
import { assertRelativePath } from '../security/network';
const logger = getLogger(LogCategories.HTTP.ERROR);
@@ -92,6 +93,14 @@ export function request(url: string, data: any, config: ConfigType, method: Meth
return Promise.reject(new UserError('The request was denied...'));
}
// plusRequest 仅用于访问商业版 Pro 服务,会自动携带 rootkey,SSRF 拦截已被显式关闭。
// 强制要求相对路径,防止调用方传入绝对 URL 覆盖 baseURL 形成带高权限头的 SSRF。
try {
assertRelativePath(url, 'plusRequest');
} catch (err) {
return Promise.reject(err);
}
/* 去空 */
for (const key in data) {
if (data[key] === null || data[key] === undefined) {
@@ -135,8 +144,14 @@ export function DELETE<T = undefined>(url: string, data = {}, config: ConfigType
return request(url, data, config, 'DELETE');
}
export const plusRequest = (config: AxiosRequestConfig) =>
instance.request({
export const plusRequest = (config: AxiosRequestConfig) => {
try {
assertRelativePath(config.url, 'plusRequest');
} catch (err) {
return Promise.reject(err);
}
return instance.request({
...config,
baseURL: FastGPTProUrl
});
};
@@ -2,6 +2,7 @@ import { SERVICE_LOCAL_HOST } from '../system/tools';
import { type Method, type InternalAxiosRequestConfig, type AxiosResponse } from 'axios';
import { createProxyAxios } from './axios';
import { getLogger, LogCategories } from '../logger';
import { assertRelativePath } from '../security/network';
const logger = getLogger(LogCategories.HTTP.ERROR);
@@ -78,6 +79,14 @@ instance.interceptors.request.use(requestStart, (err) => Promise.reject(err));
instance.interceptors.response.use(responseSuccess, (err) => Promise.reject(err));
export function request(url: string, data: any, config: ConfigType, method: Method): any {
// serverRequest 仅用于访问本机 NextJS API,SSRF 拦截已被显式关闭。
// 强制要求相对路径,防止调用方传入绝对 URL 覆盖 baseURL 形成 SSRF。
try {
assertRelativePath(url, 'serverRequest');
} catch (err) {
return Promise.reject(err);
}
/* 去空 */
for (const key in data) {
if (data[key] === null || data[key] === undefined) {
@@ -0,0 +1,55 @@
/**
* 网络出站安全校验工具集。
*
* 这里集中导出 URL 字符串层面的轻量校验,供 serverRequest/plusRequest 等
* 内部 helper 在调用前快速短路。更复杂的 SSRF 校验(协议白名单 + DNS +
* 内网/metadata 拦截)请使用 `common/system/utils.ts` 中的 `checkUrlSafety`。
*/
/**
* 判断给定字符串是否是"绝对 URL"。
* 命中条件:
* - 以 `scheme://` 形式开头(http://、https://、ws://、ftp:// ...)
* - 以 `//` 开头(protocol-relative,会被 axios/new URL 当成绝对 URL 处理)
*
* 校验严格的目的是阻止 helper 调用方意外把绝对 URL 传进来覆盖 baseURL,
* 即使该 helper 已经显式关闭 SSRF 拦截器也不会形成 SSRF。
*/
export const isAbsoluteUrl = (url: unknown): boolean => {
if (typeof url !== 'string') return false;
return /^[a-z][a-z0-9+.-]*:\/\//i.test(url) || url.startsWith('//');
};
/**
* 强制要求传入的 URL 是相对路径,否则 reject。
* 适用于"按设计只访问内部固定 baseURL"的内部 helper。
*
* 例: serverRequest(本机 NextJS API)、plusRequest(商业版 Pro 服务)。
*/
export const assertRelativePath = (url: unknown, helperName = 'request'): void => {
if (typeof url !== 'string' || isAbsoluteUrl(url)) {
throw new Error(`${helperName} only accepts relative paths, absolute URLs are not allowed`);
}
};
/**
* 在用 `new URL(path, base)` 构造目标 URL 后,强制校验最终 origin 与 base 一致。
*
* 防御 "protocol-relative URL" 主机覆盖:
* - `new URL('//169.254.169.254/foo', 'http://internal:3000')` → host 被替换
* - NextJS catch-all `[...path]` 中,`/api//evil/...` 会被拆成 `['', 'evil', ...]`,
* join 回去就构造出 `//evil/...` 这种 protocol-relative path
*
* 用法:
* const target = buildSameOriginUrl(requestPath, baseUrl); // 抛错 = 攻击
*/
export const buildSameOriginUrl = (path: string, base: string): URL => {
const baseUrl = new URL(base);
const target = new URL(path, baseUrl);
if (target.origin !== baseUrl.origin) {
throw new Error(
`Refused: target URL origin (${target.origin}) does not match base (${baseUrl.origin})`
);
}
return target;
};
+27
View File
@@ -178,3 +178,30 @@ export const isInternalAddress = async (url: string): Promise<boolean> => {
};
export const PRIVATE_URL_TEXT = 'Request to private network not allowed';
/**
* 用于"保存配置 URL"或"调用前校验"的统一安全检查:
* - 必须是合法 URL
* - 协议必须是 http/https
* - 不能指向内部地址(loopback/metadata,以及在 CHECK_INTERNAL_IP=true 时的私网)
*
* 注意:`isInternalAddress` 在 dev 环境直接放行;为了让保存入口
* 在 dev 也能拒绝明显错误的 URL(localhost / metadata),
* 这里**不依赖 isDevEnv**,而是用同一套规则做轻量校验。
*/
export const checkUrlSafety = async (url: string, fieldName = 'URL'): Promise<void> => {
let parsed: URL;
try {
parsed = new URL(url);
} catch {
return Promise.reject(new Error(`${fieldName} must be a valid URL`));
}
if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') {
return Promise.reject(new Error(`${fieldName} must use http or https protocol`));
}
if (await isInternalAddress(url)) {
return Promise.reject(new Error(`${fieldName}: ${PRIVATE_URL_TEXT}`));
}
};
+18 -15
View File
@@ -1,4 +1,4 @@
import { POST } from '../../../common/api/serverRequest';
import { axios } from '../../../common/api/axios';
import { getDefaultRerankModel } from '../model';
import { getAxiosConfig } from '../config';
import { type RerankModelItemType } from '@fastgpt/global/core/ai/model.schema';
@@ -93,21 +93,24 @@ export async function reRankRecall({
const { baseUrl, authorization } = getAxiosConfig();
const start = Date.now();
const apiResult = await POST<PostReRankResponse>(
model.requestUrl ? model.requestUrl : `${baseUrl}/rerank`,
{
model: model.model,
query,
documents: documentsTextArray
},
{
headers: {
Authorization: model.requestAuth ? `Bearer ${model.requestAuth}` : authorization,
...headers
const requestUrl = model.requestUrl ? model.requestUrl : `${baseUrl}/rerank`;
const apiResult = await axios
.post<PostReRankResponse>(
requestUrl,
{
model: model.model,
query,
documents: documentsTextArray
},
timeout: 30000
}
)
{
headers: {
Authorization: model.requestAuth ? `Bearer ${model.requestAuth}` : authorization,
...headers
},
timeout: 30000
}
)
.then((res) => res.data)
.then(async (data) => {
if (!data?.results || data?.results?.length === 0) {
logger.error('Rerank returned empty results', { data });
@@ -6,8 +6,7 @@ import { toolMap as getFileUrlToolMap } from './getFileUrl.tool';
import { toolMap as shellToolMap } from './shell.tool';
import { getSandboxClient } from '../controller';
import { parseJsonArgs } from '../../utils';
import { axios } from '../../../../common/api/axios';
import { serverRequestBaseUrl } from '../../../../common/api/serverRequest';
import { pickOutboundAxios } from '../../../../common/api/axios';
import type { FileWriteEntry } from '@fastgpt-sdk/sandbox-adapter';
const ToolMap = {
@@ -95,8 +94,7 @@ export const injectSandboxFiles = async ({
files
.filter((file) => file.path)
.map(async ({ path, url }): Promise<FileWriteEntry> => {
const response = await axios.get<ArrayBuffer>(url, {
baseURL: serverRequestBaseUrl,
const response = await pickOutboundAxios(url).get<ArrayBuffer>(url, {
responseType: 'arraybuffer'
});
@@ -1,6 +1,5 @@
import { isInternalAddress, PRIVATE_URL_TEXT } from '../../../../../../../common/system/utils';
import axios from 'axios';
import { serverRequestBaseUrl } from '../../../../../../../common/api/serverRequest';
import { pickOutboundAxios } from '../../../../../../../common/api/axios';
import { parseFileExtensionFromUrl } from '@fastgpt/global/common/string/tools';
import {
detectFileEncoding,
@@ -62,9 +61,7 @@ export const dispatchFileRead = async ({
content: Promise.reject(PRIVATE_URL_TEXT)
};
}
// Get file buffer data
const response = await axios.get(url, {
baseURL: serverRequestBaseUrl,
const response = await pickOutboundAxios(url).get(url, {
responseType: 'arraybuffer'
});
@@ -15,8 +15,7 @@ import type {
SandboxSearchSchema,
SandboxFetchUserFileSchema
} from '@fastgpt/global/core/workflow/node/agent/skillTools';
import axios from 'axios';
import { serverRequestBaseUrl } from '../../../../../../../common/api/serverRequest';
import { pickOutboundAxios } from '../../../../../../../common/api/axios';
import path from 'path';
type DispatchResult = {
@@ -200,9 +199,16 @@ export async function dispatchSandboxFetchUserFile(
};
}
// 拒绝 ws/wss 协议进入文件下载链路
if (/^wss?:/i.test(fileEntry.url)) {
return {
response: `Failed: ws/wss protocol is not allowed for file URL`,
usages: []
};
}
try {
const response = await axios.get(fileEntry.url, {
baseURL: serverRequestBaseUrl,
const response = await pickOutboundAxios(fileEntry.url).get(fileEntry.url, {
responseType: 'arraybuffer'
});
const buffer: ArrayBuffer = response.data;
+2 -5
View File
@@ -3,8 +3,7 @@ import type { UserChatItemValueItemType } from '@fastgpt/global/core/chat/type';
import { parseUrlToFileType } from './context';
import { getS3RawTextSource } from '../../../common/s3/sources/rawText';
import { isInternalAddress, PRIVATE_URL_TEXT } from '../../../common/system/utils';
import { axios } from '../../../common/api/axios';
import { serverRequestBaseUrl } from '../../../common/api/serverRequest';
import { pickOutboundAxios } from '../../../common/api/axios';
import { S3Buckets } from '../../../common/s3/config/constants';
import { S3Sources } from '../../../common/s3/contracts/type';
import {
@@ -112,9 +111,7 @@ export const normalizeReadableFileUrl = ({
};
export const getFileInfoFromUrl = async ({ teamId, url }: { teamId: string; url: string }) => {
// Get file buffer data
const response = await axios.get(url, {
baseURL: serverRequestBaseUrl,
const response = await pickOutboundAxios(url).get(url, {
responseType: 'arraybuffer'
});
@@ -163,4 +163,33 @@ describe('axios.ts', () => {
expect(axios.defaults.httpsAgent).toBeDefined();
});
});
describe('pickOutboundAxios', () => {
it.each([
'http://example.com',
'https://example.com/path',
'http://169.254.169.254/latest/meta-data/',
'//attacker.example/probe' // protocol-relative 也按绝对处理
])('绝对 URL %j 返回 safe axios 实例', async (url) => {
const { axios, pickOutboundAxios } = await import('@fastgpt/service/common/api/axios');
expect(pickOutboundAxios(url)).toBe(axios);
});
it.each(['/api/foo', 'api/foo', '/support/outLink/feishu/abc'])(
'相对路径 %j 返回内部 axios(baseURL 固定到本机)',
async (url) => {
const { axios, pickOutboundAxios } = await import('@fastgpt/service/common/api/axios');
const client = pickOutboundAxios(url);
expect(client).not.toBe(axios);
expect(client.defaults.baseURL).toMatch(/^http:\/\//);
}
);
it('多次调用同一类型的 URL,内部 client 应被复用(避免每次新建实例)', async () => {
const { pickOutboundAxios } = await import('@fastgpt/service/common/api/axios');
const a = pickOutboundAxios('/api/a');
const b = pickOutboundAxios('/api/b');
expect(a).toBe(b);
});
});
});
@@ -0,0 +1,121 @@
import { describe, it, expect } from 'vitest';
import {
isAbsoluteUrl,
assertRelativePath,
buildSameOriginUrl
} from '@fastgpt/service/common/security/network';
describe('common/security/network', () => {
describe('isAbsoluteUrl', () => {
it.each([
['http://example.com', true],
['https://example.com/path', true],
['HTTP://EXAMPLE.COM', true], // 协议大小写不敏感
['ws://example.com', true],
['wss://example.com', true],
['ftp://example.com', true],
['file:///etc/passwd', true],
['javascript:alert(1)', false], // 没有 :// 不算
['//example.com/path', true], // protocol-relative
['//169.254.169.254/latest/meta-data/', true],
['/api/foo', false],
['api/foo', false],
['', false],
['?query=1', false],
['#hash', false]
])('isAbsoluteUrl(%j) === %s', (input, expected) => {
expect(isAbsoluteUrl(input)).toBe(expected);
});
it('non-string 输入一律返回 false', () => {
expect(isAbsoluteUrl(undefined)).toBe(false);
expect(isAbsoluteUrl(null)).toBe(false);
expect(isAbsoluteUrl(123)).toBe(false);
expect(isAbsoluteUrl({})).toBe(false);
});
});
describe('assertRelativePath', () => {
it('相对路径不抛错', () => {
expect(() => assertRelativePath('/api/foo')).not.toThrow();
expect(() => assertRelativePath('api/foo')).not.toThrow();
expect(() => assertRelativePath('support/outLink/wecom/abc')).not.toThrow();
});
it.each([
'http://example.com',
'https://169.254.169.254/latest/meta-data/',
'//attacker.example/probe',
'ws://internal/socket'
])('绝对 URL 抛错: %j', (url) => {
expect(() => assertRelativePath(url)).toThrow(/only accepts relative paths/i);
});
it('non-string 抛错', () => {
expect(() => assertRelativePath(undefined)).toThrow(/only accepts relative paths/i);
expect(() => assertRelativePath(null)).toThrow(/only accepts relative paths/i);
});
it('错误信息包含调用者名称,便于定位', () => {
expect(() => assertRelativePath('http://x', 'plusRequest')).toThrow(/plusRequest/);
expect(() => assertRelativePath('http://x', 'serverRequest')).toThrow(/serverRequest/);
});
});
describe('buildSameOriginUrl', () => {
const base = 'http://internal-service:3000';
it('普通相对路径正常拼接', () => {
const u = buildSameOriginUrl('/api/foo', base);
expect(u.href).toBe('http://internal-service:3000/api/foo');
});
it('保留 query 与 hash', () => {
const u = buildSameOriginUrl('/api/foo?x=1#bar', base);
expect(u.href).toBe('http://internal-service:3000/api/foo?x=1#bar');
});
it('保留 base 自带 path 的相对解析行为', () => {
const u = buildSameOriginUrl('foo', 'http://h:3000/api/');
expect(u.href).toBe('http://h:3000/api/foo');
});
it.each([
// protocol-relative URL 直接覆盖主机
'//169.254.169.254/latest/meta-data/',
'//attacker.example/probe',
// NextJS catch-all 拼接产物: requestPath = `/${['', 'evil', 'x'].join('/')}` = `//evil/x`
'//evil.example/path',
// 绝对 URL 也会替换主机
'http://attacker.example/x',
'https://169.254.169.254/',
// 协议 + 主机 + 不同端口
'http://internal-service:9999/'
])('protocol-relative / 绝对 URL 改写主机时抛错: %j', (path) => {
expect(() => buildSameOriginUrl(path, base)).toThrow(/does not match base/i);
});
it('host 相同但端口不同也算不同 origin', () => {
expect(() => buildSameOriginUrl('//internal-service:9999/x', base)).toThrow(
/does not match base/i
);
});
it('host 相同但协议不同也算不同 origin', () => {
expect(() => buildSameOriginUrl('https://internal-service:3000/', base)).toThrow(
/does not match base/i
);
});
it('base 非法 URL 时抛错', () => {
expect(() => buildSameOriginUrl('/api/foo', 'not a url')).toThrow();
});
it('NextJS catch-all 真实场景: path 含空段产生 protocol-relative', () => {
// 模拟 `req.query.path = ['', '169.254.169.254', 'latest']` (来源: /aiproxy//169.254.169.254/latest)
const requestPath = `/${['', '169.254.169.254', 'latest'].join('/')}`;
expect(requestPath).toBe('//169.254.169.254/latest');
expect(() => buildSameOriginUrl(requestPath, base)).toThrow(/does not match base/i);
});
});
});
@@ -3,17 +3,21 @@ import { ModelTypeEnum } from '@fastgpt/global/core/ai/constants';
import type { RerankModelItemType } from '@fastgpt/global/core/ai/model.schema';
// hoisted:让 mock 实例可在 beforeEach 中重设
const { mockCountPromptTokens, mockPOST } = vi.hoisted(() => ({
const { mockCountPromptTokens, mockAxiosPost } = vi.hoisted(() => ({
mockCountPromptTokens: vi.fn(),
mockPOST: vi.fn()
// mockAxiosPost 接收原始 payload(即 axios response 的 .data),包装成 { data }
mockAxiosPost: vi.fn()
}));
vi.mock('@fastgpt/service/common/string/tiktoken', () => ({
countPromptTokens: mockCountPromptTokens
}));
vi.mock('@fastgpt/service/common/api/serverRequest', () => ({
POST: (...args: any[]) => mockPOST(...args)
// rerank 现在改用统一 axios(带 SSRF 拦截),mock axios.post 返回 axios 风格的 { data, ... }
vi.mock('@fastgpt/service/common/api/axios', () => ({
axios: {
post: (...args: any[]) => Promise.resolve(mockAxiosPost(...args)).then((data) => ({ data }))
}
}));
// Mock text2Chunks:按 chunkSize 字符切分,保证测试确定性
@@ -40,7 +44,7 @@ const mockModel: RerankModelItemType = {
describe('reRankRecall', () => {
beforeEach(() => {
mockPOST.mockReset();
mockAxiosPost.mockReset();
mockCountPromptTokens.mockReset();
mockCountPromptTokens.mockImplementation(async (text: string) => text.length);
});
@@ -48,7 +52,7 @@ describe('reRankRecall', () => {
// ── 基础场景 ──────────────────────────────────────────────────────────────
it('正常场景:多文档返回正确 id 和 score', async () => {
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
results: [
{ index: 1, relevance_score: 0.9 },
@@ -73,7 +77,7 @@ describe('reRankRecall', () => {
});
it('单文档正常召回', async () => {
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
results: [{ index: 0, relevance_score: 0.75 }],
meta: { tokens: { input_tokens: 10, output_tokens: 0 } }
@@ -99,7 +103,7 @@ describe('reRankRecall', () => {
});
expect(result).toEqual({ results: [], inputTokens: 0 });
expect(mockPOST).not.toHaveBeenCalled();
expect(mockAxiosPost).not.toHaveBeenCalled();
});
it('所有文档 text 为空或空白时,返回空结果,不发请求', async () => {
@@ -113,7 +117,7 @@ describe('reRankRecall', () => {
});
expect(result).toEqual({ results: [], inputTokens: 0 });
expect(mockPOST).not.toHaveBeenCalled();
expect(mockAxiosPost).not.toHaveBeenCalled();
});
// ── 复杂场景:文档切分 ────────────────────────────────────────────────────
@@ -125,7 +129,7 @@ describe('reRankRecall', () => {
// doc2 'short' length=5 <= 599 → 不切分 (index 3)
const longText = 'a'.repeat(1100);
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
// API 按 score 降序返回
results: [
@@ -158,7 +162,7 @@ describe('reRankRecall', () => {
// maxToken=600, query='q'(1), docBudget=599, chunkSize=539
const longText = 'b'.repeat(1100);
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
results: [
{ index: 1, relevance_score: 0.95 }, // chunk_1 最高
@@ -181,7 +185,7 @@ describe('reRankRecall', () => {
// ── inputTokens 计算 ──────────────────────────────────────────────────────
it('API 未返回 meta tokens 时,通过 countPromptTokens 估算', async () => {
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
results: [{ index: 0, relevance_score: 0.5 }]
// 无 meta
@@ -198,7 +202,7 @@ describe('reRankRecall', () => {
});
it('API 返回 meta tokens 时直接使用', async () => {
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
results: [{ index: 0, relevance_score: 0.5 }],
meta: { tokens: { input_tokens: 42, output_tokens: 0 } }
@@ -216,7 +220,7 @@ describe('reRankRecall', () => {
// ── requestUrl / requestAuth ──────────────────────────────────────────────
it('有 requestUrl 和 requestAuth 时,使用自定义地址和认证头', async () => {
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
results: [{ index: 0, relevance_score: 0.5 }],
meta: { tokens: { input_tokens: 5, output_tokens: 0 } }
@@ -232,7 +236,7 @@ describe('reRankRecall', () => {
documents: [{ id: 'doc1', text: 'hello' }]
});
expect(mockPOST).toHaveBeenCalledWith(
expect(mockAxiosPost).toHaveBeenCalledWith(
'https://custom.rerank.io/rerank',
expect.any(Object),
expect.objectContaining({
@@ -244,7 +248,7 @@ describe('reRankRecall', () => {
});
it('未设置 requestUrl 时,使用 baseUrl/rerank', async () => {
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
results: [{ index: 0, relevance_score: 0.5 }],
meta: { tokens: { input_tokens: 5, output_tokens: 0 } }
@@ -256,7 +260,7 @@ describe('reRankRecall', () => {
documents: [{ id: 'doc1', text: 'hello' }]
});
const url: string = mockPOST.mock.calls[0][0];
const url: string = mockAxiosPost.mock.calls[0][0];
expect(url.endsWith('/rerank')).toBe(true);
});
@@ -297,7 +301,7 @@ describe('reRankRecall', () => {
it('docBudget === 501 时不因 query 过长 reject', async () => {
// maxToken=502, query='q'(length=1) → docBudget = 502-1 = 501 > 500 → 正常发请求
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
results: [{ index: 0, relevance_score: 0.5 }],
meta: { tokens: { input_tokens: 5, output_tokens: 0 } }
@@ -310,11 +314,11 @@ describe('reRankRecall', () => {
});
expect(result.results).toHaveLength(1);
expect(mockPOST).toHaveBeenCalledOnce();
expect(mockAxiosPost).toHaveBeenCalledOnce();
});
it('API 请求失败时,reject 并传递原始错误', async () => {
mockPOST.mockRejectedValueOnce(new Error('Network error'));
mockAxiosPost.mockRejectedValueOnce(new Error('Network error'));
await expect(
reRankRecall({
@@ -326,7 +330,7 @@ describe('reRankRecall', () => {
});
it('API 返回空 results 时,返回空 results', async () => {
mockPOST.mockResolvedValueOnce({
mockAxiosPost.mockResolvedValueOnce({
id: 'r1',
results: []
});
@@ -338,7 +342,7 @@ describe('reRankRecall', () => {
});
expect(result.results).toHaveLength(0);
expect(mockPOST).toHaveBeenCalledOnce();
expect(mockAxiosPost).toHaveBeenCalledOnce();
// 空 results 时提前返回,inputTokens 固定为 0
expect(result.inputTokens).toBe(0);
});
@@ -26,10 +26,32 @@ vi.mock('@fastgpt/service/common/system/utils', async (importOriginal) => {
vi.mock('@fastgpt/service/common/api/axios', async (importOriginal) => {
const mod = await importOriginal<typeof import('@fastgpt/service/common/api/axios')>();
// 把 axios 和 pickOutboundAxios 一起 mock:
// - axios: 直接换成 mock(供绝对 URL 路径)
// - pickOutboundAxios: 不论 URL 类型都返回同一个 mock client(供测试统一断言 .get 调用)
const mockClient = {
get: mockAxiosGet,
defaults: { baseURL: 'http://localhost:3000' }
};
return {
...mod,
axios: {
get: mockAxiosGet
axios: mockClient,
pickOutboundAxios: () => mockClient
};
});
// 文件下载链路对相对路径走 raw axios + serverRequestBaseUrl,这里也 mock 住,
// 保证测试不会真发网络请求,且保留对 mockAxiosGet 调用次数的断言能力。
// outbound.ts 用 axios.create() 创建内部 client,所以 mock 必须提供 create 方法。
vi.mock('axios', () => {
const internalClient = {
get: mockAxiosGet,
defaults: { baseURL: 'http://localhost:3000' }
};
return {
default: {
get: mockAxiosGet,
create: vi.fn(() => internalClient)
}
};
});
@@ -409,8 +431,9 @@ describe('parseFileInfoFromUrls', () => {
});
expect(mockAxiosGet).toHaveBeenCalledTimes(1);
// 注:相对路径走 axios.create({ baseURL }) 创建的内部 client,
// baseURL 在 client 上而不在 .get() 调用参数里。
expect(mockAxiosGet).toHaveBeenCalledWith('/report.pdf', {
baseURL: expect.any(String),
responseType: 'arraybuffer'
});
expect(mockReadFileContentByBuffer).not.toHaveBeenCalled();