From 73b1291851541cd57330c3605b09c598ebf60287 Mon Sep 17 00:00:00 2001 From: mks155 Date: Thu, 22 Jan 2026 22:20:48 +0800 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20=E4=B8=BB=E6=9C=BA=E9=BB=91?= =?UTF-8?q?=E7=99=BD=E5=90=8D=E5=8D=95=E6=94=AF=E6=8C=81=E9=80=9A=E9=85=8D?= =?UTF-8?q?=E7=AC=A6*=E5=8C=B9=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cn/keking/utils/DomainIpMatcherUtil.java | 265 ++++++++++++++++++ .../cn/keking/web/filter/TrustHostFilter.java | 6 +- .../utils/DomainIpMatcherUtilTests.java | 206 ++++++++++++++ 3 files changed, 475 insertions(+), 2 deletions(-) create mode 100644 server/src/main/java/cn/keking/utils/DomainIpMatcherUtil.java create mode 100644 server/src/test/java/cn/keking/utils/DomainIpMatcherUtilTests.java diff --git a/server/src/main/java/cn/keking/utils/DomainIpMatcherUtil.java b/server/src/main/java/cn/keking/utils/DomainIpMatcherUtil.java new file mode 100644 index 000000000..7c2caa4fb --- /dev/null +++ b/server/src/main/java/cn/keking/utils/DomainIpMatcherUtil.java @@ -0,0 +1,265 @@ +package cn.keking.utils; + +import java.util.Set; +import java.util.regex.Pattern; + +/** + * @author mks155 + * @date 2026/1/22 + * @description 域名/IP匹配工具 + * 支持:*.example.com, example.com, localhost, 127.0.0.1, 192.168.*, 172.16.*, 10.* + */ +public final class DomainIpMatcherUtil { + + // IPv4地址正则 + private static final Pattern IPV4_PATTERN = + Pattern.compile("^(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)$"); + + // 域名正则 + private static final Pattern DOMAIN_PATTERN = + Pattern.compile("^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$"); + + private static final String WILDCARD_PREFIX = "*."; + + /** + * 检查域名/IP是否在允许的配置列表中 + */ + public static boolean isAllowed(Set allowedPatterns, String host) { + if (allowedPatterns == null || allowedPatterns.isEmpty() || host == null) { + return false; + } + + String trimmedHost = host.trim(); + if (trimmedHost.isEmpty()) { + return false; + } + + for (String pattern : allowedPatterns) { + if (pattern == null) continue; + + String trimmedPattern = pattern.trim(); + if (trimmedPattern.isEmpty()) continue; + + if (matchPattern(trimmedPattern, trimmedHost)) { + return true; + } + } + + return false; + } + + /** + * 匹配单个模式 + */ + private static boolean matchPattern(String pattern, String host) { + // 1. 精确匹配(需要验证格式,防止无效域名/IP被精确匹配) + if (pattern.equals(host)) { + // 只有格式合法才允许精确匹配 + return isValidDomain(pattern) || isIpFormat(pattern) || isSpecialDomain(pattern); + } + + // 2. IP段匹配(10.*, 192.168.*, 192.168.1.*) + if (isIpSegmentPattern(pattern)) { + return matchIpSegment(pattern, host); + } + + // 3. 通配符域名(*.example.com) + if (pattern.startsWith(WILDCARD_PREFIX)) { + return matchWildcardDomain(pattern, host); + } + + // 4. 精确IP匹配(已在第一步处理) + if (isIpFormat(pattern) && isIpFormat(host)) { + return false; + } + + // 5. 特殊域名(localhost, 127.0.0.1) + if (isSpecialDomain(pattern)) { + return isSpecialDomain(host) && pattern.equalsIgnoreCase(host); + } + + return false; + } + + /** + * 匹配通配符域名:*.example.com + * 规则: + * - 必须以 .suffix 结尾(不区分大小写) + * - 不能等于suffix + * - 子域名部分必须有效(防止evil.com.example.com) + * - 只能匹配一级子域名(a.example.com 可以,a.b.example.com 不行) + * - 域名不区分大小写 + */ + private static boolean matchWildcardDomain(String pattern, String domain) { + // 去掉 "*." 并转换为小写 + String suffix = pattern.substring(2).toLowerCase(); + String domainLower = domain.toLowerCase(); + + // 验证后缀格式 + if (!isValidDomain(suffix)) { + return false; + } + + // 必须以 .suffix 结尾 + if (!domainLower.endsWith("." + suffix)) { + return false; + } + + // 不能等于suffix + if (domainLower.equals(suffix)) { + return false; + } + + // 提取子域名部分 + String subdomain = domainLower.substring(0, domainLower.length() - suffix.length() - 1); + + // 防止多级子域名:a.b.example.com 不应匹配 *.example.com + if (subdomain.contains(".")) { + return false; + } + + // 验证子域名(防止evil.com.example.com) + return isValidSubdomain(subdomain); + } + + /** + * 匹配IP段:10.*, 192.168.*, 192.168.1.* + * 支持2、3、4段格式 + */ + private static boolean matchIpSegment(String pattern, String host) { + if (!isIpFormat(host)) { + return false; + } + + String[] patternParts = pattern.split("\\."); + String[] hostParts = host.split("\\."); + + // 支持2、3、4段pattern匹配4段host + if (hostParts.length != 4) { + return false; + } + + // 只比较pattern的段数,pattern的每一段对应host的对应段 + for (int i = 0; i < patternParts.length; i++) { + if ("*".equals(patternParts[i])) { + continue; + } + if (!patternParts[i].equals(hostParts[i])) { + return false; + } + } + return true; + } + + /** + * 判断是否是IP段模式(10.*, 192.168.*, 192.168.1.*) + * 支持2、3、4段格式 + */ + private static boolean isIpSegmentPattern(String str) { + if (str == null || !str.contains("*")) { + return false; + } + + String[] parts = str.split("\\."); + // 支持2、3、4段:10.*, 192.168.*, 192.168.1.* + if (parts.length < 2 || parts.length > 4) { + return false; + } + + for (String part : parts) { + if ("*".equals(part)) { + continue; + } + if (!isNumeric(part)) { + return false; + } + int num = Integer.parseInt(part); + if (num < 0 || num > 255) { + return false; + } + } + return true; + } + + /** + * 判断是否是标准IPv4格式 + */ + private static boolean isIpFormat(String str) { + return str != null && IPV4_PATTERN.matcher(str).matches(); + } + + /** + * 判断是否是特殊域名(localhost, 127.0.0.1) + */ + private static boolean isSpecialDomain(String str) { + if (str == null) { + return false; + } + return "localhost".equalsIgnoreCase(str) || "127.0.0.1".equals(str); + } + + /** + * 验证域名格式 + */ + private static boolean isValidDomain(String domain) { + if (domain == null || domain.isEmpty() || domain.length() > 253) { + return false; + } + + // 特殊域名直接通过 + if (isSpecialDomain(domain)) { + return true; + } + + // 格式检查 + if (domain.startsWith(".") || domain.endsWith(".") || domain.contains("..")) { + return false; + } + + String[] parts = domain.split("\\."); + for (String part : parts) { + if (part.isEmpty() || part.length() > 63) { + return false; + } + if (part.startsWith("-") || part.endsWith("-")) { + return false; + } + if (!DOMAIN_PATTERN.matcher(part).matches()) { + return false; + } + } + return true; + } + + /** + * 验证子域名(防止绕过) + */ + private static boolean isValidSubdomain(String subdomain) { + if (subdomain == null || subdomain.isEmpty()) { + return false; + } + if (subdomain.contains("*") || subdomain.contains(" ") || subdomain.contains("@") || subdomain.contains(";")) { + return false; + } + if (subdomain.replace(".", "").isEmpty()) { + return false; + } + return isValidDomain(subdomain); + } + + /** + * 判断是否是数字 + */ + private static boolean isNumeric(String str) { + if (str == null || str.isEmpty()) { + return false; + } + for (char c : str.toCharArray()) { + if (!Character.isDigit(c)) { + return false; + } + } + return true; + } + +} \ No newline at end of file diff --git a/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java b/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java index e661844f4..2a418e242 100644 --- a/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java +++ b/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java @@ -1,10 +1,12 @@ package cn.keking.web.filter; import cn.keking.config.ConfigConstants; +import cn.keking.utils.DomainIpMatcherUtil; import cn.keking.utils.WebUtils; import java.io.IOException; import java.nio.charset.StandardCharsets; + import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.FilterConfig; @@ -56,7 +58,7 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha public boolean isNotTrustHost(String host) { // 如果配置了黑名单,优先检查黑名单 if (CollectionUtils.isNotEmpty(ConfigConstants.getNotTrustHostSet())) { - return ConfigConstants.getNotTrustHostSet().contains(host); + return DomainIpMatcherUtil.isAllowed(ConfigConstants.getNotTrustHostSet(), host); } // 如果配置了白名单,检查是否在白名单中 @@ -66,7 +68,7 @@ public boolean isNotTrustHost(String host) { logger.debug("允许所有主机访问(通配符模式): {}", host); return false; } - return !ConfigConstants.getTrustHostSet().contains(host); + return DomainIpMatcherUtil.isAllowed(ConfigConstants.getTrustHostSet(), host); } // 安全加固:默认拒绝所有未配置的主机(防止SSRF攻击) diff --git a/server/src/test/java/cn/keking/utils/DomainIpMatcherUtilTests.java b/server/src/test/java/cn/keking/utils/DomainIpMatcherUtilTests.java new file mode 100644 index 000000000..67dd99e17 --- /dev/null +++ b/server/src/test/java/cn/keking/utils/DomainIpMatcherUtilTests.java @@ -0,0 +1,206 @@ +package cn.keking.utils; + +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * @author mks155 + * @date 2026/1/22 + * @description 域名/IP匹配工具 + * 支持:*.example.com, example.com, localhost, 127.0.0.1, 192.168.*, 172.16.*, 10.* + */ +public final class DomainIpMatcherUtilTests { + private static final Logger log = LoggerFactory.getLogger(DomainIpMatcherUtilTests.class); + + @Test + public void runAllTests() { + log.info("=== DomainIpMatcher增强测试开始 ==="); + + List failures = new ArrayList<>(); + int testCount = 0; + + // ========== 1. 域名通配符测试 ========== + Set patterns1 = Set.of("*.example.com", "*.test.com", "example.com", "www.example.com"); + test("api.example.com", true, patterns1, ++testCount, failures); + test("localhost.example.com", true, patterns1, ++testCount, failures); + test("a.b.example.com", false, patterns1, ++testCount, failures); + test("evil.com", false, patterns1, ++testCount, failures); + test("example.com.evil.com", false, patterns1, ++testCount, failures); + test("example.test.com", true, patterns1, ++testCount, failures); + test("example.com.test.com", false, patterns1, ++testCount, failures); + test("example.com", true, patterns1, ++testCount, failures); + test("www.example.com", true, patterns1, ++testCount, failures); + + // ========== 2. IP段匹配测试 ========== + Set patterns2 = Set.of("192.168.*", "10.*", "172.16.*", "192.168.0.1", "172.17.*"); + test("192.168.1.1", true, patterns2, ++testCount, failures); + test("192.168.0.100", true, patterns2, ++testCount, failures); + test("192.168.255.255", true, patterns2, ++testCount, failures); + test("10.0.0.1", true, patterns2, ++testCount, failures); + test("172.16.0.1", true, patterns2, ++testCount, failures); + test("192.169.1.1", false, patterns2, ++testCount, failures); + test("11.0.0.1", false, patterns2, ++testCount, failures); + test("192.168.0.1", true, patterns2, ++testCount, failures); + + // ========== 3. 精确IP和特殊域名 ========== + Set patterns3 = Set.of("127.0.0.1", "localhost"); + test("127.0.0.1", true, patterns3, ++testCount, failures); + test("127.0.0.2", false, patterns3, ++testCount, failures); + test("128.0.0.1", false, patterns3, ++testCount, failures); + test("localhost", true, patterns3, ++testCount, failures); + test("LOCALHOST", true, patterns3, ++testCount, failures); + test("local", false, patterns3, ++testCount, failures); + + // ========== 4. 边界和空值测试 ========== + Set patterns4 = Set.of("*.example.com"); + test("", false, patterns4, ++testCount, failures); + test(" ", false, patterns4, ++testCount, failures); + test(null, false, patterns4, ++testCount, failures); + test("example.com", false, Set.of(), ++testCount, failures); + test("example.com", false, null, ++testCount, failures); + + // ========== 5. 恶意输入测试 ========== + Set patterns5 = Set.of("*.example.com", "192.168.*"); + test("example.com; DROP TABLE", false, patterns5, ++testCount, failures); + test("example.com/../evil.com", false, patterns5, ++testCount, failures); + test("example.com