mirror of
https://github.com/labring/FastGPT.git
synced 2026-05-16 01:09:01 +08:00
perf: ssrf check (#6852)
This commit is contained in:
@@ -2,6 +2,8 @@ import _, { type AxiosInstance, type AxiosRequestConfig } from 'axios';
|
||||
import { ProxyAgent } from 'proxy-agent';
|
||||
import { isDevEnv } from '@fastgpt/global/common/system/constants';
|
||||
import { isInternalAddress, PRIVATE_URL_TEXT } from '../system/utils';
|
||||
import { isAbsoluteUrl } from '../security/network';
|
||||
import { SERVICE_LOCAL_HOST } from '../system/tools';
|
||||
|
||||
const addSSRFInterceptor = (instance: AxiosInstance) => {
|
||||
instance.interceptors.request.use(async (config) => {
|
||||
@@ -41,3 +43,32 @@ export function createProxyAxios(config?: AxiosRequestConfig, ssrfCheck = true)
|
||||
|
||||
/** @see https://github.com/axios/axios/issues/4531 */
|
||||
export const axios = createProxyAxios();
|
||||
|
||||
/**
|
||||
* 内部相对路径请求专用的 axios 实例:
|
||||
* - baseURL 固定为本机 NextJS API
|
||||
* - 不带 SSRF 拦截器(本机调用必然解析到 localhost,装拦截会把所有合法请求拦死)
|
||||
* - 不复用 safe axios 的 ProxyAgent,保证内部回环不会被外部代理转走
|
||||
*
|
||||
* 仅在 url 是相对路径时使用;绝对 URL 必须走 safe `axios`。
|
||||
*/
|
||||
const internalAxios: AxiosInstance = _.create({
|
||||
baseURL: `http://${SERVICE_LOCAL_HOST}`
|
||||
});
|
||||
|
||||
/**
|
||||
* 根据 URL 类型自动选择合适的 axios 实例,避免每个调用点重复
|
||||
* `isAbsoluteUrl ? safe : raw` 三元。
|
||||
*
|
||||
* - 绝对 URL(`http(s)://...` 或 `//...`)→ safe `axios`(SSRF 拦截,拒绝内网/metadata)
|
||||
* - 相对路径(`/api/...` 等)→ `internalAxios`(本机 baseURL,可信内部 API)
|
||||
*
|
||||
* 用法:
|
||||
* ```ts
|
||||
* const client = pickOutboundAxios(url);
|
||||
* const res = await client.get(url, { responseType: 'arraybuffer' });
|
||||
* ```
|
||||
*/
|
||||
export const pickOutboundAxios = (url: string): AxiosInstance => {
|
||||
return isAbsoluteUrl(url) ? axios : internalAxios;
|
||||
};
|
||||
|
||||
@@ -8,6 +8,7 @@ import { FastGPTProUrl } from '../system/constants';
|
||||
import { UserError } from '@fastgpt/global/common/error/utils';
|
||||
import { createProxyAxios } from './axios';
|
||||
import { getLogger, LogCategories } from '../logger';
|
||||
import { assertRelativePath } from '../security/network';
|
||||
|
||||
const logger = getLogger(LogCategories.HTTP.ERROR);
|
||||
|
||||
@@ -92,6 +93,14 @@ export function request(url: string, data: any, config: ConfigType, method: Meth
|
||||
return Promise.reject(new UserError('The request was denied...'));
|
||||
}
|
||||
|
||||
// plusRequest 仅用于访问商业版 Pro 服务,会自动携带 rootkey,SSRF 拦截已被显式关闭。
|
||||
// 强制要求相对路径,防止调用方传入绝对 URL 覆盖 baseURL 形成带高权限头的 SSRF。
|
||||
try {
|
||||
assertRelativePath(url, 'plusRequest');
|
||||
} catch (err) {
|
||||
return Promise.reject(err);
|
||||
}
|
||||
|
||||
/* 去空 */
|
||||
for (const key in data) {
|
||||
if (data[key] === null || data[key] === undefined) {
|
||||
@@ -135,8 +144,14 @@ export function DELETE<T = undefined>(url: string, data = {}, config: ConfigType
|
||||
return request(url, data, config, 'DELETE');
|
||||
}
|
||||
|
||||
export const plusRequest = (config: AxiosRequestConfig) =>
|
||||
instance.request({
|
||||
export const plusRequest = (config: AxiosRequestConfig) => {
|
||||
try {
|
||||
assertRelativePath(config.url, 'plusRequest');
|
||||
} catch (err) {
|
||||
return Promise.reject(err);
|
||||
}
|
||||
return instance.request({
|
||||
...config,
|
||||
baseURL: FastGPTProUrl
|
||||
});
|
||||
};
|
||||
|
||||
@@ -2,6 +2,7 @@ import { SERVICE_LOCAL_HOST } from '../system/tools';
|
||||
import { type Method, type InternalAxiosRequestConfig, type AxiosResponse } from 'axios';
|
||||
import { createProxyAxios } from './axios';
|
||||
import { getLogger, LogCategories } from '../logger';
|
||||
import { assertRelativePath } from '../security/network';
|
||||
|
||||
const logger = getLogger(LogCategories.HTTP.ERROR);
|
||||
|
||||
@@ -78,6 +79,14 @@ instance.interceptors.request.use(requestStart, (err) => Promise.reject(err));
|
||||
instance.interceptors.response.use(responseSuccess, (err) => Promise.reject(err));
|
||||
|
||||
export function request(url: string, data: any, config: ConfigType, method: Method): any {
|
||||
// serverRequest 仅用于访问本机 NextJS API,SSRF 拦截已被显式关闭。
|
||||
// 强制要求相对路径,防止调用方传入绝对 URL 覆盖 baseURL 形成 SSRF。
|
||||
try {
|
||||
assertRelativePath(url, 'serverRequest');
|
||||
} catch (err) {
|
||||
return Promise.reject(err);
|
||||
}
|
||||
|
||||
/* 去空 */
|
||||
for (const key in data) {
|
||||
if (data[key] === null || data[key] === undefined) {
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
/**
|
||||
* 网络出站安全校验工具集。
|
||||
*
|
||||
* 这里集中导出 URL 字符串层面的轻量校验,供 serverRequest/plusRequest 等
|
||||
* 内部 helper 在调用前快速短路。更复杂的 SSRF 校验(协议白名单 + DNS +
|
||||
* 内网/metadata 拦截)请使用 `common/system/utils.ts` 中的 `checkUrlSafety`。
|
||||
*/
|
||||
|
||||
/**
|
||||
* 判断给定字符串是否是"绝对 URL"。
|
||||
* 命中条件:
|
||||
* - 以 `scheme://` 形式开头(http://、https://、ws://、ftp:// ...)
|
||||
* - 以 `//` 开头(protocol-relative,会被 axios/new URL 当成绝对 URL 处理)
|
||||
*
|
||||
* 校验严格的目的是阻止 helper 调用方意外把绝对 URL 传进来覆盖 baseURL,
|
||||
* 即使该 helper 已经显式关闭 SSRF 拦截器也不会形成 SSRF。
|
||||
*/
|
||||
export const isAbsoluteUrl = (url: unknown): boolean => {
|
||||
if (typeof url !== 'string') return false;
|
||||
return /^[a-z][a-z0-9+.-]*:\/\//i.test(url) || url.startsWith('//');
|
||||
};
|
||||
|
||||
/**
|
||||
* 强制要求传入的 URL 是相对路径,否则 reject。
|
||||
* 适用于"按设计只访问内部固定 baseURL"的内部 helper。
|
||||
*
|
||||
* 例: serverRequest(本机 NextJS API)、plusRequest(商业版 Pro 服务)。
|
||||
*/
|
||||
export const assertRelativePath = (url: unknown, helperName = 'request'): void => {
|
||||
if (typeof url !== 'string' || isAbsoluteUrl(url)) {
|
||||
throw new Error(`${helperName} only accepts relative paths, absolute URLs are not allowed`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* 在用 `new URL(path, base)` 构造目标 URL 后,强制校验最终 origin 与 base 一致。
|
||||
*
|
||||
* 防御 "protocol-relative URL" 主机覆盖:
|
||||
* - `new URL('//169.254.169.254/foo', 'http://internal:3000')` → host 被替换
|
||||
* - NextJS catch-all `[...path]` 中,`/api//evil/...` 会被拆成 `['', 'evil', ...]`,
|
||||
* join 回去就构造出 `//evil/...` 这种 protocol-relative path
|
||||
*
|
||||
* 用法:
|
||||
* const target = buildSameOriginUrl(requestPath, baseUrl); // 抛错 = 攻击
|
||||
*/
|
||||
export const buildSameOriginUrl = (path: string, base: string): URL => {
|
||||
const baseUrl = new URL(base);
|
||||
const target = new URL(path, baseUrl);
|
||||
if (target.origin !== baseUrl.origin) {
|
||||
throw new Error(
|
||||
`Refused: target URL origin (${target.origin}) does not match base (${baseUrl.origin})`
|
||||
);
|
||||
}
|
||||
return target;
|
||||
};
|
||||
@@ -178,3 +178,30 @@ export const isInternalAddress = async (url: string): Promise<boolean> => {
|
||||
};
|
||||
|
||||
export const PRIVATE_URL_TEXT = 'Request to private network not allowed';
|
||||
|
||||
/**
|
||||
* 用于"保存配置 URL"或"调用前校验"的统一安全检查:
|
||||
* - 必须是合法 URL
|
||||
* - 协议必须是 http/https
|
||||
* - 不能指向内部地址(loopback/metadata,以及在 CHECK_INTERNAL_IP=true 时的私网)
|
||||
*
|
||||
* 注意:`isInternalAddress` 在 dev 环境直接放行;为了让保存入口
|
||||
* 在 dev 也能拒绝明显错误的 URL(localhost / metadata),
|
||||
* 这里**不依赖 isDevEnv**,而是用同一套规则做轻量校验。
|
||||
*/
|
||||
export const checkUrlSafety = async (url: string, fieldName = 'URL'): Promise<void> => {
|
||||
let parsed: URL;
|
||||
try {
|
||||
parsed = new URL(url);
|
||||
} catch {
|
||||
return Promise.reject(new Error(`${fieldName} must be a valid URL`));
|
||||
}
|
||||
|
||||
if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') {
|
||||
return Promise.reject(new Error(`${fieldName} must use http or https protocol`));
|
||||
}
|
||||
|
||||
if (await isInternalAddress(url)) {
|
||||
return Promise.reject(new Error(`${fieldName}: ${PRIVATE_URL_TEXT}`));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { POST } from '../../../common/api/serverRequest';
|
||||
import { axios } from '../../../common/api/axios';
|
||||
import { getDefaultRerankModel } from '../model';
|
||||
import { getAxiosConfig } from '../config';
|
||||
import { type RerankModelItemType } from '@fastgpt/global/core/ai/model.schema';
|
||||
@@ -93,21 +93,24 @@ export async function reRankRecall({
|
||||
const { baseUrl, authorization } = getAxiosConfig();
|
||||
const start = Date.now();
|
||||
|
||||
const apiResult = await POST<PostReRankResponse>(
|
||||
model.requestUrl ? model.requestUrl : `${baseUrl}/rerank`,
|
||||
{
|
||||
model: model.model,
|
||||
query,
|
||||
documents: documentsTextArray
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
Authorization: model.requestAuth ? `Bearer ${model.requestAuth}` : authorization,
|
||||
...headers
|
||||
const requestUrl = model.requestUrl ? model.requestUrl : `${baseUrl}/rerank`;
|
||||
const apiResult = await axios
|
||||
.post<PostReRankResponse>(
|
||||
requestUrl,
|
||||
{
|
||||
model: model.model,
|
||||
query,
|
||||
documents: documentsTextArray
|
||||
},
|
||||
timeout: 30000
|
||||
}
|
||||
)
|
||||
{
|
||||
headers: {
|
||||
Authorization: model.requestAuth ? `Bearer ${model.requestAuth}` : authorization,
|
||||
...headers
|
||||
},
|
||||
timeout: 30000
|
||||
}
|
||||
)
|
||||
.then((res) => res.data)
|
||||
.then(async (data) => {
|
||||
if (!data?.results || data?.results?.length === 0) {
|
||||
logger.error('Rerank returned empty results', { data });
|
||||
|
||||
@@ -6,8 +6,7 @@ import { toolMap as getFileUrlToolMap } from './getFileUrl.tool';
|
||||
import { toolMap as shellToolMap } from './shell.tool';
|
||||
import { getSandboxClient } from '../controller';
|
||||
import { parseJsonArgs } from '../../utils';
|
||||
import { axios } from '../../../../common/api/axios';
|
||||
import { serverRequestBaseUrl } from '../../../../common/api/serverRequest';
|
||||
import { pickOutboundAxios } from '../../../../common/api/axios';
|
||||
import type { FileWriteEntry } from '@fastgpt-sdk/sandbox-adapter';
|
||||
|
||||
const ToolMap = {
|
||||
@@ -95,8 +94,7 @@ export const injectSandboxFiles = async ({
|
||||
files
|
||||
.filter((file) => file.path)
|
||||
.map(async ({ path, url }): Promise<FileWriteEntry> => {
|
||||
const response = await axios.get<ArrayBuffer>(url, {
|
||||
baseURL: serverRequestBaseUrl,
|
||||
const response = await pickOutboundAxios(url).get<ArrayBuffer>(url, {
|
||||
responseType: 'arraybuffer'
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { isInternalAddress, PRIVATE_URL_TEXT } from '../../../../../../../common/system/utils';
|
||||
import axios from 'axios';
|
||||
import { serverRequestBaseUrl } from '../../../../../../../common/api/serverRequest';
|
||||
import { pickOutboundAxios } from '../../../../../../../common/api/axios';
|
||||
import { parseFileExtensionFromUrl } from '@fastgpt/global/common/string/tools';
|
||||
import {
|
||||
detectFileEncoding,
|
||||
@@ -62,9 +61,7 @@ export const dispatchFileRead = async ({
|
||||
content: Promise.reject(PRIVATE_URL_TEXT)
|
||||
};
|
||||
}
|
||||
// Get file buffer data
|
||||
const response = await axios.get(url, {
|
||||
baseURL: serverRequestBaseUrl,
|
||||
const response = await pickOutboundAxios(url).get(url, {
|
||||
responseType: 'arraybuffer'
|
||||
});
|
||||
|
||||
|
||||
@@ -15,8 +15,7 @@ import type {
|
||||
SandboxSearchSchema,
|
||||
SandboxFetchUserFileSchema
|
||||
} from '@fastgpt/global/core/workflow/node/agent/skillTools';
|
||||
import axios from 'axios';
|
||||
import { serverRequestBaseUrl } from '../../../../../../../common/api/serverRequest';
|
||||
import { pickOutboundAxios } from '../../../../../../../common/api/axios';
|
||||
import path from 'path';
|
||||
|
||||
type DispatchResult = {
|
||||
@@ -200,9 +199,16 @@ export async function dispatchSandboxFetchUserFile(
|
||||
};
|
||||
}
|
||||
|
||||
// 拒绝 ws/wss 协议进入文件下载链路
|
||||
if (/^wss?:/i.test(fileEntry.url)) {
|
||||
return {
|
||||
response: `Failed: ws/wss protocol is not allowed for file URL`,
|
||||
usages: []
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios.get(fileEntry.url, {
|
||||
baseURL: serverRequestBaseUrl,
|
||||
const response = await pickOutboundAxios(fileEntry.url).get(fileEntry.url, {
|
||||
responseType: 'arraybuffer'
|
||||
});
|
||||
const buffer: ArrayBuffer = response.data;
|
||||
|
||||
@@ -3,8 +3,7 @@ import type { UserChatItemValueItemType } from '@fastgpt/global/core/chat/type';
|
||||
import { parseUrlToFileType } from './context';
|
||||
import { getS3RawTextSource } from '../../../common/s3/sources/rawText';
|
||||
import { isInternalAddress, PRIVATE_URL_TEXT } from '../../../common/system/utils';
|
||||
import { axios } from '../../../common/api/axios';
|
||||
import { serverRequestBaseUrl } from '../../../common/api/serverRequest';
|
||||
import { pickOutboundAxios } from '../../../common/api/axios';
|
||||
import { S3Buckets } from '../../../common/s3/config/constants';
|
||||
import { S3Sources } from '../../../common/s3/contracts/type';
|
||||
import {
|
||||
@@ -112,9 +111,7 @@ export const normalizeReadableFileUrl = ({
|
||||
};
|
||||
|
||||
export const getFileInfoFromUrl = async ({ teamId, url }: { teamId: string; url: string }) => {
|
||||
// Get file buffer data
|
||||
const response = await axios.get(url, {
|
||||
baseURL: serverRequestBaseUrl,
|
||||
const response = await pickOutboundAxios(url).get(url, {
|
||||
responseType: 'arraybuffer'
|
||||
});
|
||||
|
||||
|
||||
@@ -163,4 +163,33 @@ describe('axios.ts', () => {
|
||||
expect(axios.defaults.httpsAgent).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('pickOutboundAxios', () => {
|
||||
it.each([
|
||||
'http://example.com',
|
||||
'https://example.com/path',
|
||||
'http://169.254.169.254/latest/meta-data/',
|
||||
'//attacker.example/probe' // protocol-relative 也按绝对处理
|
||||
])('绝对 URL %j 返回 safe axios 实例', async (url) => {
|
||||
const { axios, pickOutboundAxios } = await import('@fastgpt/service/common/api/axios');
|
||||
expect(pickOutboundAxios(url)).toBe(axios);
|
||||
});
|
||||
|
||||
it.each(['/api/foo', 'api/foo', '/support/outLink/feishu/abc'])(
|
||||
'相对路径 %j 返回内部 axios(baseURL 固定到本机)',
|
||||
async (url) => {
|
||||
const { axios, pickOutboundAxios } = await import('@fastgpt/service/common/api/axios');
|
||||
const client = pickOutboundAxios(url);
|
||||
expect(client).not.toBe(axios);
|
||||
expect(client.defaults.baseURL).toMatch(/^http:\/\//);
|
||||
}
|
||||
);
|
||||
|
||||
it('多次调用同一类型的 URL,内部 client 应被复用(避免每次新建实例)', async () => {
|
||||
const { pickOutboundAxios } = await import('@fastgpt/service/common/api/axios');
|
||||
const a = pickOutboundAxios('/api/a');
|
||||
const b = pickOutboundAxios('/api/b');
|
||||
expect(a).toBe(b);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
isAbsoluteUrl,
|
||||
assertRelativePath,
|
||||
buildSameOriginUrl
|
||||
} from '@fastgpt/service/common/security/network';
|
||||
|
||||
describe('common/security/network', () => {
|
||||
describe('isAbsoluteUrl', () => {
|
||||
it.each([
|
||||
['http://example.com', true],
|
||||
['https://example.com/path', true],
|
||||
['HTTP://EXAMPLE.COM', true], // 协议大小写不敏感
|
||||
['ws://example.com', true],
|
||||
['wss://example.com', true],
|
||||
['ftp://example.com', true],
|
||||
['file:///etc/passwd', true],
|
||||
['javascript:alert(1)', false], // 没有 :// 不算
|
||||
['//example.com/path', true], // protocol-relative
|
||||
['//169.254.169.254/latest/meta-data/', true],
|
||||
['/api/foo', false],
|
||||
['api/foo', false],
|
||||
['', false],
|
||||
['?query=1', false],
|
||||
['#hash', false]
|
||||
])('isAbsoluteUrl(%j) === %s', (input, expected) => {
|
||||
expect(isAbsoluteUrl(input)).toBe(expected);
|
||||
});
|
||||
|
||||
it('non-string 输入一律返回 false', () => {
|
||||
expect(isAbsoluteUrl(undefined)).toBe(false);
|
||||
expect(isAbsoluteUrl(null)).toBe(false);
|
||||
expect(isAbsoluteUrl(123)).toBe(false);
|
||||
expect(isAbsoluteUrl({})).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('assertRelativePath', () => {
|
||||
it('相对路径不抛错', () => {
|
||||
expect(() => assertRelativePath('/api/foo')).not.toThrow();
|
||||
expect(() => assertRelativePath('api/foo')).not.toThrow();
|
||||
expect(() => assertRelativePath('support/outLink/wecom/abc')).not.toThrow();
|
||||
});
|
||||
|
||||
it.each([
|
||||
'http://example.com',
|
||||
'https://169.254.169.254/latest/meta-data/',
|
||||
'//attacker.example/probe',
|
||||
'ws://internal/socket'
|
||||
])('绝对 URL 抛错: %j', (url) => {
|
||||
expect(() => assertRelativePath(url)).toThrow(/only accepts relative paths/i);
|
||||
});
|
||||
|
||||
it('non-string 抛错', () => {
|
||||
expect(() => assertRelativePath(undefined)).toThrow(/only accepts relative paths/i);
|
||||
expect(() => assertRelativePath(null)).toThrow(/only accepts relative paths/i);
|
||||
});
|
||||
|
||||
it('错误信息包含调用者名称,便于定位', () => {
|
||||
expect(() => assertRelativePath('http://x', 'plusRequest')).toThrow(/plusRequest/);
|
||||
expect(() => assertRelativePath('http://x', 'serverRequest')).toThrow(/serverRequest/);
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildSameOriginUrl', () => {
|
||||
const base = 'http://internal-service:3000';
|
||||
|
||||
it('普通相对路径正常拼接', () => {
|
||||
const u = buildSameOriginUrl('/api/foo', base);
|
||||
expect(u.href).toBe('http://internal-service:3000/api/foo');
|
||||
});
|
||||
|
||||
it('保留 query 与 hash', () => {
|
||||
const u = buildSameOriginUrl('/api/foo?x=1#bar', base);
|
||||
expect(u.href).toBe('http://internal-service:3000/api/foo?x=1#bar');
|
||||
});
|
||||
|
||||
it('保留 base 自带 path 的相对解析行为', () => {
|
||||
const u = buildSameOriginUrl('foo', 'http://h:3000/api/');
|
||||
expect(u.href).toBe('http://h:3000/api/foo');
|
||||
});
|
||||
|
||||
it.each([
|
||||
// protocol-relative URL 直接覆盖主机
|
||||
'//169.254.169.254/latest/meta-data/',
|
||||
'//attacker.example/probe',
|
||||
// NextJS catch-all 拼接产物: requestPath = `/${['', 'evil', 'x'].join('/')}` = `//evil/x`
|
||||
'//evil.example/path',
|
||||
// 绝对 URL 也会替换主机
|
||||
'http://attacker.example/x',
|
||||
'https://169.254.169.254/',
|
||||
// 协议 + 主机 + 不同端口
|
||||
'http://internal-service:9999/'
|
||||
])('protocol-relative / 绝对 URL 改写主机时抛错: %j', (path) => {
|
||||
expect(() => buildSameOriginUrl(path, base)).toThrow(/does not match base/i);
|
||||
});
|
||||
|
||||
it('host 相同但端口不同也算不同 origin', () => {
|
||||
expect(() => buildSameOriginUrl('//internal-service:9999/x', base)).toThrow(
|
||||
/does not match base/i
|
||||
);
|
||||
});
|
||||
|
||||
it('host 相同但协议不同也算不同 origin', () => {
|
||||
expect(() => buildSameOriginUrl('https://internal-service:3000/', base)).toThrow(
|
||||
/does not match base/i
|
||||
);
|
||||
});
|
||||
|
||||
it('base 非法 URL 时抛错', () => {
|
||||
expect(() => buildSameOriginUrl('/api/foo', 'not a url')).toThrow();
|
||||
});
|
||||
|
||||
it('NextJS catch-all 真实场景: path 含空段产生 protocol-relative', () => {
|
||||
// 模拟 `req.query.path = ['', '169.254.169.254', 'latest']` (来源: /aiproxy//169.254.169.254/latest)
|
||||
const requestPath = `/${['', '169.254.169.254', 'latest'].join('/')}`;
|
||||
expect(requestPath).toBe('//169.254.169.254/latest');
|
||||
expect(() => buildSameOriginUrl(requestPath, base)).toThrow(/does not match base/i);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -3,17 +3,21 @@ import { ModelTypeEnum } from '@fastgpt/global/core/ai/constants';
|
||||
import type { RerankModelItemType } from '@fastgpt/global/core/ai/model.schema';
|
||||
|
||||
// hoisted:让 mock 实例可在 beforeEach 中重设
|
||||
const { mockCountPromptTokens, mockPOST } = vi.hoisted(() => ({
|
||||
const { mockCountPromptTokens, mockAxiosPost } = vi.hoisted(() => ({
|
||||
mockCountPromptTokens: vi.fn(),
|
||||
mockPOST: vi.fn()
|
||||
// mockAxiosPost 接收原始 payload(即 axios response 的 .data),包装成 { data }
|
||||
mockAxiosPost: vi.fn()
|
||||
}));
|
||||
|
||||
vi.mock('@fastgpt/service/common/string/tiktoken', () => ({
|
||||
countPromptTokens: mockCountPromptTokens
|
||||
}));
|
||||
|
||||
vi.mock('@fastgpt/service/common/api/serverRequest', () => ({
|
||||
POST: (...args: any[]) => mockPOST(...args)
|
||||
// rerank 现在改用统一 axios(带 SSRF 拦截),mock axios.post 返回 axios 风格的 { data, ... }
|
||||
vi.mock('@fastgpt/service/common/api/axios', () => ({
|
||||
axios: {
|
||||
post: (...args: any[]) => Promise.resolve(mockAxiosPost(...args)).then((data) => ({ data }))
|
||||
}
|
||||
}));
|
||||
|
||||
// Mock text2Chunks:按 chunkSize 字符切分,保证测试确定性
|
||||
@@ -40,7 +44,7 @@ const mockModel: RerankModelItemType = {
|
||||
|
||||
describe('reRankRecall', () => {
|
||||
beforeEach(() => {
|
||||
mockPOST.mockReset();
|
||||
mockAxiosPost.mockReset();
|
||||
mockCountPromptTokens.mockReset();
|
||||
mockCountPromptTokens.mockImplementation(async (text: string) => text.length);
|
||||
});
|
||||
@@ -48,7 +52,7 @@ describe('reRankRecall', () => {
|
||||
// ── 基础场景 ──────────────────────────────────────────────────────────────
|
||||
|
||||
it('正常场景:多文档返回正确 id 和 score', async () => {
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [
|
||||
{ index: 1, relevance_score: 0.9 },
|
||||
@@ -73,7 +77,7 @@ describe('reRankRecall', () => {
|
||||
});
|
||||
|
||||
it('单文档正常召回', async () => {
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [{ index: 0, relevance_score: 0.75 }],
|
||||
meta: { tokens: { input_tokens: 10, output_tokens: 0 } }
|
||||
@@ -99,7 +103,7 @@ describe('reRankRecall', () => {
|
||||
});
|
||||
|
||||
expect(result).toEqual({ results: [], inputTokens: 0 });
|
||||
expect(mockPOST).not.toHaveBeenCalled();
|
||||
expect(mockAxiosPost).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('所有文档 text 为空或空白时,返回空结果,不发请求', async () => {
|
||||
@@ -113,7 +117,7 @@ describe('reRankRecall', () => {
|
||||
});
|
||||
|
||||
expect(result).toEqual({ results: [], inputTokens: 0 });
|
||||
expect(mockPOST).not.toHaveBeenCalled();
|
||||
expect(mockAxiosPost).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// ── 复杂场景:文档切分 ────────────────────────────────────────────────────
|
||||
@@ -125,7 +129,7 @@ describe('reRankRecall', () => {
|
||||
// doc2 'short' length=5 <= 599 → 不切分 (index 3)
|
||||
const longText = 'a'.repeat(1100);
|
||||
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
// API 按 score 降序返回
|
||||
results: [
|
||||
@@ -158,7 +162,7 @@ describe('reRankRecall', () => {
|
||||
// maxToken=600, query='q'(1), docBudget=599, chunkSize=539
|
||||
const longText = 'b'.repeat(1100);
|
||||
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [
|
||||
{ index: 1, relevance_score: 0.95 }, // chunk_1 最高
|
||||
@@ -181,7 +185,7 @@ describe('reRankRecall', () => {
|
||||
// ── inputTokens 计算 ──────────────────────────────────────────────────────
|
||||
|
||||
it('API 未返回 meta tokens 时,通过 countPromptTokens 估算', async () => {
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [{ index: 0, relevance_score: 0.5 }]
|
||||
// 无 meta
|
||||
@@ -198,7 +202,7 @@ describe('reRankRecall', () => {
|
||||
});
|
||||
|
||||
it('API 返回 meta tokens 时直接使用', async () => {
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [{ index: 0, relevance_score: 0.5 }],
|
||||
meta: { tokens: { input_tokens: 42, output_tokens: 0 } }
|
||||
@@ -216,7 +220,7 @@ describe('reRankRecall', () => {
|
||||
// ── requestUrl / requestAuth ──────────────────────────────────────────────
|
||||
|
||||
it('有 requestUrl 和 requestAuth 时,使用自定义地址和认证头', async () => {
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [{ index: 0, relevance_score: 0.5 }],
|
||||
meta: { tokens: { input_tokens: 5, output_tokens: 0 } }
|
||||
@@ -232,7 +236,7 @@ describe('reRankRecall', () => {
|
||||
documents: [{ id: 'doc1', text: 'hello' }]
|
||||
});
|
||||
|
||||
expect(mockPOST).toHaveBeenCalledWith(
|
||||
expect(mockAxiosPost).toHaveBeenCalledWith(
|
||||
'https://custom.rerank.io/rerank',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
@@ -244,7 +248,7 @@ describe('reRankRecall', () => {
|
||||
});
|
||||
|
||||
it('未设置 requestUrl 时,使用 baseUrl/rerank', async () => {
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [{ index: 0, relevance_score: 0.5 }],
|
||||
meta: { tokens: { input_tokens: 5, output_tokens: 0 } }
|
||||
@@ -256,7 +260,7 @@ describe('reRankRecall', () => {
|
||||
documents: [{ id: 'doc1', text: 'hello' }]
|
||||
});
|
||||
|
||||
const url: string = mockPOST.mock.calls[0][0];
|
||||
const url: string = mockAxiosPost.mock.calls[0][0];
|
||||
expect(url.endsWith('/rerank')).toBe(true);
|
||||
});
|
||||
|
||||
@@ -297,7 +301,7 @@ describe('reRankRecall', () => {
|
||||
|
||||
it('docBudget === 501 时不因 query 过长 reject', async () => {
|
||||
// maxToken=502, query='q'(length=1) → docBudget = 502-1 = 501 > 500 → 正常发请求
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: [{ index: 0, relevance_score: 0.5 }],
|
||||
meta: { tokens: { input_tokens: 5, output_tokens: 0 } }
|
||||
@@ -310,11 +314,11 @@ describe('reRankRecall', () => {
|
||||
});
|
||||
|
||||
expect(result.results).toHaveLength(1);
|
||||
expect(mockPOST).toHaveBeenCalledOnce();
|
||||
expect(mockAxiosPost).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('API 请求失败时,reject 并传递原始错误', async () => {
|
||||
mockPOST.mockRejectedValueOnce(new Error('Network error'));
|
||||
mockAxiosPost.mockRejectedValueOnce(new Error('Network error'));
|
||||
|
||||
await expect(
|
||||
reRankRecall({
|
||||
@@ -326,7 +330,7 @@ describe('reRankRecall', () => {
|
||||
});
|
||||
|
||||
it('API 返回空 results 时,返回空 results', async () => {
|
||||
mockPOST.mockResolvedValueOnce({
|
||||
mockAxiosPost.mockResolvedValueOnce({
|
||||
id: 'r1',
|
||||
results: []
|
||||
});
|
||||
@@ -338,7 +342,7 @@ describe('reRankRecall', () => {
|
||||
});
|
||||
|
||||
expect(result.results).toHaveLength(0);
|
||||
expect(mockPOST).toHaveBeenCalledOnce();
|
||||
expect(mockAxiosPost).toHaveBeenCalledOnce();
|
||||
// 空 results 时提前返回,inputTokens 固定为 0
|
||||
expect(result.inputTokens).toBe(0);
|
||||
});
|
||||
|
||||
@@ -26,10 +26,32 @@ vi.mock('@fastgpt/service/common/system/utils', async (importOriginal) => {
|
||||
|
||||
vi.mock('@fastgpt/service/common/api/axios', async (importOriginal) => {
|
||||
const mod = await importOriginal<typeof import('@fastgpt/service/common/api/axios')>();
|
||||
// 把 axios 和 pickOutboundAxios 一起 mock:
|
||||
// - axios: 直接换成 mock(供绝对 URL 路径)
|
||||
// - pickOutboundAxios: 不论 URL 类型都返回同一个 mock client(供测试统一断言 .get 调用)
|
||||
const mockClient = {
|
||||
get: mockAxiosGet,
|
||||
defaults: { baseURL: 'http://localhost:3000' }
|
||||
};
|
||||
return {
|
||||
...mod,
|
||||
axios: {
|
||||
get: mockAxiosGet
|
||||
axios: mockClient,
|
||||
pickOutboundAxios: () => mockClient
|
||||
};
|
||||
});
|
||||
|
||||
// 文件下载链路对相对路径走 raw axios + serverRequestBaseUrl,这里也 mock 住,
|
||||
// 保证测试不会真发网络请求,且保留对 mockAxiosGet 调用次数的断言能力。
|
||||
// outbound.ts 用 axios.create() 创建内部 client,所以 mock 必须提供 create 方法。
|
||||
vi.mock('axios', () => {
|
||||
const internalClient = {
|
||||
get: mockAxiosGet,
|
||||
defaults: { baseURL: 'http://localhost:3000' }
|
||||
};
|
||||
return {
|
||||
default: {
|
||||
get: mockAxiosGet,
|
||||
create: vi.fn(() => internalClient)
|
||||
}
|
||||
};
|
||||
});
|
||||
@@ -409,8 +431,9 @@ describe('parseFileInfoFromUrls', () => {
|
||||
});
|
||||
|
||||
expect(mockAxiosGet).toHaveBeenCalledTimes(1);
|
||||
// 注:相对路径走 axios.create({ baseURL }) 创建的内部 client,
|
||||
// baseURL 在 client 上而不在 .get() 调用参数里。
|
||||
expect(mockAxiosGet).toHaveBeenCalledWith('/report.pdf', {
|
||||
baseURL: expect.any(String),
|
||||
responseType: 'arraybuffer'
|
||||
});
|
||||
expect(mockReadFileContentByBuffer).not.toHaveBeenCalled();
|
||||
|
||||
Reference in New Issue
Block a user