diff --git a/RATE_LIMITER_README.md b/RATE_LIMITER_README.md new file mode 100644 index 000000000..1e684edac --- /dev/null +++ b/RATE_LIMITER_README.md @@ -0,0 +1,73 @@ +# 限流器实现说明 + +## 概述 + +本项目实现了一个基于IP地址的简化版限流器,应用于文件预览系统的/onlinePreview接口,用于防止系统因同时访问的请求过多而导致资源不足。 + +## 实现步骤 + +### 1. 配置参数添加 + +在`ConfigConstants.java`中添加了以下限流参数: +- `rateLimitMaxRequests`:每个IP地址在指定时间窗口内的最大请求数,默认值为100 +- `rateLimitTimeWindowSeconds`:时间窗口大小,单位为秒,默认值为60 + +### 2. 配置刷新支持 + +在`ConfigRefreshComponent.java`中添加了对限流参数的读取和更新支持,实现了配置的动态刷新。 + +### 3. 限流器核心逻辑 + +创建了以下核心类: +- `RateLimiter`:限流器接口,定义了`isAllowed`方法 +- `InMemoryRateLimiter`:基于内存的限流器实现,使用`ConcurrentHashMap`存储IP地址的访问次数 +- `RateLimiterFactory`:限流器工厂类,使用工厂模式创建不同类型的限流器实例 + +### 4. 拦截器实现 + +创建了`RateLimitInterceptor`拦截器,用于拦截/onlinePreview接口的请求,并使用限流器进行限流。 + +### 5. 拦截器注册 + +在`WebConfig.java`中注册了限流器拦截器,只拦截/onlinePreview接口的请求。 + +## 使用方法 + +### 1. 配置限流参数 + +在`config/application.properties`文件中添加以下配置参数: + +```properties +# 每个IP地址在指定时间窗口内的最大请求数 +rate.limit.max.requests=100 +# 时间窗口大小,单位为秒 +rate.limit.time.window.seconds=60 +``` + +### 2. 重启应用 + +配置参数生效需要重启应用。 + +## 后续扩展 + +### 支持Redis限流器 + +如果需要支持分布式部署,可以添加Redis限流器实现: + +1. 创建`RedisRateLimiter`类,实现`RateLimiter`接口 +2. 在`RateLimiterFactory`中添加Redis限流器的创建逻辑 +3. 配置Redis连接参数 + +### 支持其他类型的限流器 + +可以根据需要添加其他类型的限流器,如基于令牌桶算法的限流器等。 + +## 异常处理 + +限流器本身出现异常时,会自动允许请求,不会影响接口的功能。 + +## 性能考虑 + +- 使用`ConcurrentHashMap`存储IP地址的访问次数,保证线程安全 +- 使用定时任务定期清理过期的IP记录,避免内存溢出 +- 限流器的判断逻辑简单高效,不会对系统性能造成明显影响 diff --git a/config/application.properties.example b/config/application.properties.example new file mode 100644 index 000000000..90c21ddf7 --- /dev/null +++ b/config/application.properties.example @@ -0,0 +1,7 @@ +# 限流器配置 +# 每个IP地址在指定时间窗口内的最大请求数 +rate.limit.max.requests=100 +# 时间窗口大小,单位为秒 +rate.limit.time.window.seconds=60 + +# 其他配置参数... diff --git a/server/src/main/config/application.properties b/server/src/main/config/application.properties index 54854c037..883b397da 100644 --- a/server/src/main/config/application.properties +++ b/server/src/main/config/application.properties @@ -99,7 +99,7 @@ base.url = ${KK_BASE_URL:default} # trust.host = * # # 当前配置: -trust.host = ${KK_TRUST_HOST:default} +trust.host = localhost # 不信任站点黑名单配置,多个用','隔开 # 黑名单优先级高于白名单,设置后将禁止预览来自这些站点的文件 @@ -183,7 +183,7 @@ watermark.angle = ${WATERMARK_ANGLE:10} #首页功能设置 #是否禁用首页文件上传 -file.upload.disable = ${KK_FILE_UPLOAD_DISABLE:true} +file.upload.disable = ${KK_FILE_UPLOAD_DISABLE:false} # 备案信息,默认为空 beian = ${KK_BEIAN:default} #禁止上传类型 diff --git a/server/src/main/java/cn/keking/config/ConfigConstants.java b/server/src/main/java/cn/keking/config/ConfigConstants.java index 69fd600ae..2c8af6436 100644 --- a/server/src/main/java/cn/keking/config/ConfigConstants.java +++ b/server/src/main/java/cn/keking/config/ConfigConstants.java @@ -648,6 +648,32 @@ public static void setPdfThreadValue(int pdfThread) { ConfigConstants.pdfThread = pdfThread; } + public static int getRateLimitMaxRequests() { + return rateLimitMaxRequests; + } + + @Value("${rate.limit.max.requests:100}") + public void setRateLimitMaxRequests(String rateLimitMaxRequests) { + setRateLimitMaxRequestsValue(Integer.parseInt(rateLimitMaxRequests)); + } + + public static void setRateLimitMaxRequestsValue(int rateLimitMaxRequests) { + ConfigConstants.rateLimitMaxRequests = rateLimitMaxRequests; + } + + public static int getRateLimitTimeWindowSeconds() { + return rateLimitTimeWindowSeconds; + } + + @Value("${rate.limit.time.window.seconds:60}") + public void setRateLimitTimeWindowSeconds(String rateLimitTimeWindowSeconds) { + setRateLimitTimeWindowSecondsValue(Integer.parseInt(rateLimitTimeWindowSeconds)); + } + + public static void setRateLimitTimeWindowSecondsValue(int rateLimitTimeWindowSeconds) { + ConfigConstants.rateLimitTimeWindowSeconds = rateLimitTimeWindowSeconds; + } + /** * 以下为OFFICE转换模块设置 */ diff --git a/server/src/main/java/cn/keking/config/ConfigRefreshComponent.java b/server/src/main/java/cn/keking/config/ConfigRefreshComponent.java index d9a11f73b..a3d13ccea 100644 --- a/server/src/main/java/cn/keking/config/ConfigRefreshComponent.java +++ b/server/src/main/java/cn/keking/config/ConfigRefreshComponent.java @@ -82,6 +82,8 @@ public void run() { int pdfTimeout80; int pdfTimeout200; int pdfThread; + int rateLimitMaxRequests; + int rateLimitTimeWindowSeconds; while (true) { FileReader fileReader = new FileReader(configFilePath); BufferedReader bufferedReader = new BufferedReader(fileReader); @@ -134,6 +136,8 @@ public void run() { pdfTimeout80 = Integer.parseInt(properties.getProperty("pdf.timeout80", ConfigConstants.DEFAULT_PDF_TIMEOUT80)); pdfTimeout200 = Integer.parseInt(properties.getProperty("pdf.timeout200", ConfigConstants.DEFAULT_PDF_TIMEOUT200)); pdfThread = Integer.parseInt(properties.getProperty("pdf.thread", ConfigConstants.DEFAULT_PDF_THREAD)); + rateLimitMaxRequests = Integer.parseInt(properties.getProperty("rate.limit.max.requests", ConfigConstants.DEFAULT_RATE_LIMIT_MAX_REQUESTS)); + rateLimitTimeWindowSeconds = Integer.parseInt(properties.getProperty("rate.limit.time.window.seconds", ConfigConstants.DEFAULT_RATE_LIMIT_TIME_WINDOW_SECONDS)); prohibitArray = prohibit.split(","); ConfigConstants.setCacheEnabledValueValue(cacheEnabled); @@ -181,6 +185,8 @@ public void run() { ConfigConstants.setPdfTimeout80Value(pdfTimeout80); ConfigConstants.setPdfTimeout200Value(pdfTimeout200); ConfigConstants.setPdfThreadValue(pdfThread); + ConfigConstants.setRateLimitMaxRequestsValue(rateLimitMaxRequests); + ConfigConstants.setRateLimitTimeWindowSecondsValue(rateLimitTimeWindowSeconds); setWatermarkConfig(properties); bufferedReader.close(); fileReader.close(); diff --git a/server/src/main/java/cn/keking/config/WebConfig.java b/server/src/main/java/cn/keking/config/WebConfig.java index eb85367dc..9a2ef6ce4 100644 --- a/server/src/main/java/cn/keking/config/WebConfig.java +++ b/server/src/main/java/cn/keking/config/WebConfig.java @@ -6,6 +6,8 @@ import org.springframework.boot.web.servlet.FilterRegistrationBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import cn.keking.interceptor.RateLimitInterceptor; +import org.springframework.web.servlet.config.annotation.InterceptorRegistry; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; @@ -29,6 +31,13 @@ public void addResourceHandlers(ResourceHandlerRegistry registry) { LOGGER.info("Add resource locations: {}", filePath); registry.addResourceHandler("/**").addResourceLocations("classpath:/META-INF/resources/","classpath:/resources/","classpath:/static/","classpath:/public/","file:" + filePath); } + + @Override + public void addInterceptors(InterceptorRegistry registry) { + // 注册限流器拦截器,只拦截/onlinePreview接口 + registry.addInterceptor(new RateLimitInterceptor()) + .addPathPatterns("/onlinePreview"); + } @Bean diff --git a/server/src/main/java/cn/keking/interceptor/RateLimitInterceptor.java b/server/src/main/java/cn/keking/interceptor/RateLimitInterceptor.java new file mode 100644 index 000000000..3ba1115c0 --- /dev/null +++ b/server/src/main/java/cn/keking/interceptor/RateLimitInterceptor.java @@ -0,0 +1,88 @@ +package cn.keking.interceptor; + +import cn.keking.rate.limiter.RateLimiter; +import cn.keking.rate.limiter.RateLimiterFactory; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.web.servlet.HandlerInterceptor; + +import java.io.IOException; +import java.io.PrintWriter; + +/** + * 基于IP地址的限流器拦截器 + * @author kl + */ +public class RateLimitInterceptor implements HandlerInterceptor { + + private static final Logger logger = LoggerFactory.getLogger(RateLimitInterceptor.class); + + private static final String RATE_LIMIT_RESPONSE = "请求太频繁,请稍后再试"; + private static final String CONTENT_TYPE = "text/plain;charset=UTF-8"; + + private final RateLimiter rateLimiter; + + public RateLimitInterceptor() { + this.rateLimiter = RateLimiterFactory.getRateLimiter(); + } + + @Override + public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { + try { + String ipAddress = getClientIpAddress(request); + logger.debug("收到请求,IP地址: {}", ipAddress); + + if (rateLimiter.isAllowed(ipAddress)) { + logger.debug("IP地址: {} 请求允许", ipAddress); + return true; + } else { + logger.warn("IP地址: {} 请求被限流", ipAddress); + handleRateLimit(response); + return false; + } + } catch (Exception e) { + logger.error("限流器拦截器处理请求时发生异常,将允许请求", e); + // 限流器本身出现异常时,不能影响接口的功能,即异常时不限流 + return true; + } + } + + /** + * 获取客户端IP地址 + * @param request HttpServletRequest + * @return 客户端IP地址 + */ + private String getClientIpAddress(HttpServletRequest request) { + String ip = request.getHeader("X-Forwarded-For"); + if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { + ip = request.getHeader("Proxy-Client-IP"); + } + if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { + ip = request.getHeader("WL-Proxy-Client-IP"); + } + if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { + ip = request.getRemoteAddr(); + } + // 如果是多个IP地址,取第一个 + if (ip != null && ip.contains(",")) { + ip = ip.split(",")[0].trim(); + } + return ip; + } + + /** + * 处理限流请求,返回提示信息 + * @param response HttpServletResponse + * @throws IOException IOException + */ + private void handleRateLimit(HttpServletResponse response) throws IOException { + response.setStatus(HttpServletResponse.SC_TOO_MANY_REQUESTS); + response.setContentType(CONTENT_TYPE); + PrintWriter writer = response.getWriter(); + writer.write(RATE_LIMIT_RESPONSE); + writer.flush(); + writer.close(); + } +} \ No newline at end of file diff --git a/server/src/main/java/cn/keking/rate/limiter/RateLimiter.java b/server/src/main/java/cn/keking/rate/limiter/RateLimiter.java new file mode 100644 index 000000000..984c97d62 --- /dev/null +++ b/server/src/main/java/cn/keking/rate/limiter/RateLimiter.java @@ -0,0 +1,15 @@ +package cn.keking.rate.limiter; + +/** + * 限流器接口 + * @author kl + */ +public interface RateLimiter { + + /** + * 判断是否允许请求 + * @param key 请求标识,这里使用IP地址 + * @return 是否允许请求 + */ + boolean isAllowed(String key); +} \ No newline at end of file diff --git a/server/src/main/java/cn/keking/rate/limiter/RateLimiterFactory.java b/server/src/main/java/cn/keking/rate/limiter/RateLimiterFactory.java new file mode 100644 index 000000000..2851e6f1c --- /dev/null +++ b/server/src/main/java/cn/keking/rate/limiter/RateLimiterFactory.java @@ -0,0 +1,65 @@ +package cn.keking.rate.limiter; + +import cn.keking.config.ConfigConstants; +import cn.keking.rate.limiter.impl.InMemoryRateLimiter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * 限流器工厂类,使用工厂模式创建不同类型的限流器实例 + * @author kl + */ +public class RateLimiterFactory { + + private static final Logger logger = LoggerFactory.getLogger(RateLimiterFactory.class); + + private static final String TYPE_IN_MEMORY = "inMemory"; + private static volatile RateLimiter instance; + + /** + * 获取限流器实例 + * @return 限流器实例 + */ + public static RateLimiter getRateLimiter() { + return getRateLimiter(TYPE_IN_MEMORY); + } + + /** + * 根据类型获取限流器实例 + * @param type 限流器类型 + * @return 限流器实例 + */ + public static RateLimiter getRateLimiter(String type) { + if (instance == null) { + synchronized (RateLimiterFactory.class) { + if (instance == null) { + instance = createRateLimiter(type); + } + } + } + return instance; + } + + /** + * 创建限流器实例 + * @param type 限流器类型 + * @return 限流器实例 + */ + private static RateLimiter createRateLimiter(String type) { + int maxRequests = ConfigConstants.getRateLimitMaxRequests(); + int timeWindowSeconds = ConfigConstants.getRateLimitTimeWindowSeconds(); + + logger.info("创建限流器实例,类型: {}, 最大请求数: {}, 时间窗口: {}秒", type, maxRequests, timeWindowSeconds); + + switch (type) { + case TYPE_IN_MEMORY: + return new InMemoryRateLimiter(maxRequests, timeWindowSeconds); + // 后续可以添加其他类型的限流器,如Redis限流器 + // case TYPE_REDIS: + // return new RedisRateLimiter(maxRequests, timeWindowSeconds); + default: + logger.warn("未知的限流器类型: {}, 将使用默认的内存限流器", type); + return new InMemoryRateLimiter(maxRequests, timeWindowSeconds); + } + } +} \ No newline at end of file diff --git a/server/src/main/java/cn/keking/rate/limiter/impl/InMemoryRateLimiter.java b/server/src/main/java/cn/keking/rate/limiter/impl/InMemoryRateLimiter.java new file mode 100644 index 000000000..e04479869 --- /dev/null +++ b/server/src/main/java/cn/keking/rate/limiter/impl/InMemoryRateLimiter.java @@ -0,0 +1,105 @@ +package cn.keking.rate.limiter.impl; + +import cn.keking.rate.limiter.RateLimiter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.LinkedList; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * 基于内存的限流器实现,使用滑动窗口算法 + * @author kl + */ +public class InMemoryRateLimiter implements RateLimiter { + + private static final Logger logger = LoggerFactory.getLogger(InMemoryRateLimiter.class); + + private final int maxRequests; + private final int timeWindowSeconds; + private final ConcurrentHashMap> requestTimestamps; + private final ScheduledExecutorService scheduler; + + public InMemoryRateLimiter(int maxRequests, int timeWindowSeconds) { + this.maxRequests = maxRequests; + this.timeWindowSeconds = timeWindowSeconds; + this.requestTimestamps = new ConcurrentHashMap<>(); + + // 定期清理过期的IP记录 + this.scheduler = Executors.newSingleThreadScheduledExecutor(); + this.scheduler.scheduleAtFixedRate(() -> { + try { + long now = System.currentTimeMillis(); + long expireTime = now - timeWindowSeconds * 1000L; + + // 清理所有过期的请求时间戳 + requestTimestamps.forEach((ip, timestamps) -> { + synchronized (timestamps) { + while (!timestamps.isEmpty() && timestamps.peek() < expireTime) { + timestamps.poll(); + } + // 如果队列已空,移除该IP记录 + if (timestamps.isEmpty()) { + requestTimestamps.remove(ip); + } + } + }); + + logger.debug("已清理过期的IP访问记录,当前记录数: {}", requestTimestamps.size()); + } catch (Exception e) { + logger.error("清理IP访问记录时发生异常", e); + } + }, 1, 1, TimeUnit.MINUTES); // 每分钟清理一次 + } + + @Override + public boolean isAllowed(String key) { + try { + long now = System.currentTimeMillis(); + long expireTime = now - timeWindowSeconds * 1000L; + + Queue timestamps = requestTimestamps.computeIfAbsent(key, k -> new LinkedList<>()); + + synchronized (timestamps) { + // 清理过期的请求时间戳 + while (!timestamps.isEmpty() && timestamps.peek() < expireTime) { + timestamps.poll(); + } + + // 检查当前请求数是否超过限制 + if (timestamps.size() >= maxRequests) { + return false; + } + + // 添加当前请求时间戳 + timestamps.offer(now); + return true; + } + } catch (Exception e) { + logger.error("限流器判断请求是否允许时发生异常,将允许请求", e); + // 限流器本身出现异常时,不能影响接口的功能,即异常时不限流 + return true; + } + } + + /** + * 关闭限流器,释放资源 + */ + public void shutdown() { + if (scheduler != null && !scheduler.isShutdown()) { + scheduler.shutdown(); + try { + if (!scheduler.awaitTermination(1, TimeUnit.SECONDS)) { + scheduler.shutdownNow(); + } + } catch (InterruptedException e) { + scheduler.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } +} \ No newline at end of file