diff --git a/server/src/main/config/application.properties b/server/src/main/config/application.properties index 54854c037..883b397da 100644 --- a/server/src/main/config/application.properties +++ b/server/src/main/config/application.properties @@ -99,7 +99,7 @@ base.url = ${KK_BASE_URL:default} # trust.host = * # # 当前配置: -trust.host = ${KK_TRUST_HOST:default} +trust.host = localhost # 不信任站点黑名单配置,多个用','隔开 # 黑名单优先级高于白名单,设置后将禁止预览来自这些站点的文件 @@ -183,7 +183,7 @@ watermark.angle = ${WATERMARK_ANGLE:10} #首页功能设置 #是否禁用首页文件上传 -file.upload.disable = ${KK_FILE_UPLOAD_DISABLE:true} +file.upload.disable = ${KK_FILE_UPLOAD_DISABLE:false} # 备案信息,默认为空 beian = ${KK_BEIAN:default} #禁止上传类型 diff --git a/server/src/main/java/cn/keking/config/WebConfig.java b/server/src/main/java/cn/keking/config/WebConfig.java index eb85367dc..81fce4aaa 100644 --- a/server/src/main/java/cn/keking/config/WebConfig.java +++ b/server/src/main/java/cn/keking/config/WebConfig.java @@ -1,6 +1,7 @@ package cn.keking.config; import cn.keking.web.filter.*; +import cn.keking.web.filter.RateLimitFilter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.boot.web.servlet.FilterRegistrationBean; @@ -86,6 +87,19 @@ public FilterRegistrationBean getUrlCheckFilter() { return registrationBean; } + @Bean + public FilterRegistrationBean getRateLimitFilter() { + Set filterUri = new HashSet<>(); + filterUri.add("/onlinePreview"); + RateLimitFilter filter = new RateLimitFilter(); + FilterRegistrationBean registrationBean = new FilterRegistrationBean<>(); + registrationBean.setFilter(filter); + registrationBean.setUrlPatterns(filterUri); + // 设置限流过滤器的顺序,确保它在其他过滤器之前执行 + registrationBean.setOrder(5); + return registrationBean; + } + @Bean public FilterRegistrationBean getWatermarkConfigFilter() { Set filterUri = new HashSet<>(); diff --git a/server/src/main/java/cn/keking/service/cache/MemoryRateLimiter.java b/server/src/main/java/cn/keking/service/cache/MemoryRateLimiter.java new file mode 100644 index 000000000..5237d1638 --- /dev/null +++ b/server/src/main/java/cn/keking/service/cache/MemoryRateLimiter.java @@ -0,0 +1,97 @@ +package cn.keking.service.cache; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * @author: keking + * @since: 2023/10/10 10:10 + * @description: 基于内存的限流器实现 + */ +public class MemoryRateLimiter implements RateLimiter { + + private static final Logger LOGGER = LoggerFactory.getLogger(MemoryRateLimiter.class); + + // 存储IP地址的访问次数和时间戳 + private final ConcurrentMap ipAccessMap = new ConcurrentHashMap<>(); + + // 限流周期(毫秒) + private final long period; + + // 每个周期内的最大访问次数 + private final int maxCount; + + public MemoryRateLimiter(long period, int maxCount) { + this.period = period; + this.maxCount = maxCount; + } + + @Override + public boolean allowAccess(String key) { + try { + long currentTime = System.currentTimeMillis(); + + // 使用compute方法确保线程安全 + RateLimitInfo rateLimitInfo = ipAccessMap.compute(key, (k, v) -> { + if (v == null) { + // 首次访问,创建记录 + return new RateLimitInfo(currentTime, 1); + } else { + if (currentTime - v.getLastAccessTime() > period) { + // 超过周期,重置访问次数,直接允许访问 + return new RateLimitInfo(currentTime, 1); + } else if (v.getAccessCount() < maxCount) { + // 未超过最大访问次数,增加访问次数 + return new RateLimitInfo(v.getLastAccessTime(), v.getAccessCount() + 1); + } else { + // 超过最大访问次数,保持原有记录不变 + return v; + } + } + }); + + // 检查是否允许访问 + if (currentTime - rateLimitInfo.getLastAccessTime() > period) { + // 超过周期,允许访问 + return true; + } else { + // 未超过周期,检查访问次数 + return rateLimitInfo.getAccessCount() <= maxCount; + } + } catch (Exception e) { + LOGGER.error("限流检查发生异常: {}, 不限流直接允许访问", e.getMessage(), e); + // 异常时不限流 + return true; + } + } + + // 限流信息类,使用volatile关键字确保线程可见性 + private static class RateLimitInfo { + private volatile long lastAccessTime; + private volatile int accessCount; + + public RateLimitInfo(long lastAccessTime, int accessCount) { + this.lastAccessTime = lastAccessTime; + this.accessCount = accessCount; + } + + public long getLastAccessTime() { + return lastAccessTime; + } + + public void setLastAccessTime(long lastAccessTime) { + this.lastAccessTime = lastAccessTime; + } + + public int getAccessCount() { + return accessCount; + } + + public void setAccessCount(int accessCount) { + this.accessCount = accessCount; + } + } +} diff --git a/server/src/main/java/cn/keking/service/cache/RateLimiter.java b/server/src/main/java/cn/keking/service/cache/RateLimiter.java new file mode 100644 index 000000000..8e92a1564 --- /dev/null +++ b/server/src/main/java/cn/keking/service/cache/RateLimiter.java @@ -0,0 +1,16 @@ +package cn.keking.service.cache; + +/** + * @author: keking + * @since: 2023/10/10 10:00 + * @description: 限流器接口 + */ +public interface RateLimiter { + + /** + * 检查是否允许访问 + * @param key 限流键,通常是IP地址 + * @return true 允许访问,false 不允许访问 + */ + boolean allowAccess(String key); +} diff --git a/server/src/main/java/cn/keking/service/cache/RateLimiterFactory.java b/server/src/main/java/cn/keking/service/cache/RateLimiterFactory.java new file mode 100644 index 000000000..b85cb9d9a --- /dev/null +++ b/server/src/main/java/cn/keking/service/cache/RateLimiterFactory.java @@ -0,0 +1,31 @@ +package cn.keking.service.cache; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +/** + * @author: keking + * @since: 2023/10/10 10:20 + * @description: 限流器工厂类,使用工厂模式创建限流器实例 + */ +@Component +public class RateLimiterFactory { + + // 限流周期(毫秒) + @Value("${rate.limit.period:60000}") + private long period; + + // 每个周期内的最大访问次数 + @Value("${rate.limit.max.count:10}") + private int maxCount; + + /** + * 创建限流器实例 + * @return 限流器实例 + */ + public RateLimiter createRateLimiter() { + // 目前只支持基于内存的限流器 + // 后续支持Redis等第三方缓存时,只需要修改这里的实现 + return new MemoryRateLimiter(period, maxCount); + } +} diff --git a/server/src/main/java/cn/keking/utils/WebUtils.java b/server/src/main/java/cn/keking/utils/WebUtils.java index f62862cfb..5a6eb199e 100644 --- a/server/src/main/java/cn/keking/utils/WebUtils.java +++ b/server/src/main/java/cn/keking/utils/WebUtils.java @@ -359,6 +359,29 @@ public static void setSessionAttr(HttpServletRequest request, String key, Object session.setAttribute(key, value); } + /** + * 获取用户的IP地址 + * @param request 请求 + * @return 用户的IP地址 + */ + public static String getIpAddress(HttpServletRequest request) { + String ipAddress = request.getHeader("x-forwarded-for"); + if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) { + ipAddress = request.getHeader("Proxy-Client-IP"); + } + if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) { + ipAddress = request.getHeader("WL-Proxy-Client-IP"); + } + if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) { + ipAddress = request.getRemoteAddr(); + } + // 如果通过多个代理,第一个IP为客户端真实IP,多个IP以逗号分隔 + if (ipAddress != null && ipAddress.contains(",")) { + ipAddress = ipAddress.split(",")[0].trim(); + } + return ipAddress; + } + /** * 移除 session 中的属性 * @param request 请求 diff --git a/server/src/main/java/cn/keking/web/filter/RateLimitFilter.java b/server/src/main/java/cn/keking/web/filter/RateLimitFilter.java new file mode 100644 index 000000000..0d3aeaa09 --- /dev/null +++ b/server/src/main/java/cn/keking/web/filter/RateLimitFilter.java @@ -0,0 +1,68 @@ +package cn.keking.web.filter; + +import cn.keking.service.cache.RateLimiter; +import cn.keking.service.cache.RateLimiterFactory; +import cn.keking.utils.WebUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import jakarta.servlet.*; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.io.PrintWriter; + +/** + * @author: keking + * @since: 2023/10/10 10:30 + * @description: 基于IP地址的限流过滤器 + */ +@Component +public class RateLimitFilter implements Filter { + + private static final Logger LOGGER = LoggerFactory.getLogger(RateLimitFilter.class); + + private RateLimiter rateLimiter; + + @Autowired + private RateLimiterFactory rateLimiterFactory; + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + // 初始化限流器 + rateLimiter = rateLimiterFactory.createRateLimiter(); + } + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { + HttpServletRequest request = (HttpServletRequest) servletRequest; + HttpServletResponse response = (HttpServletResponse) servletResponse; + + // 获取用户的IP地址 + String ipAddress = WebUtils.getIpAddress(request); + LOGGER.debug("用户IP地址: {}", ipAddress); + + // 检查是否允许访问 + boolean allowAccess = rateLimiter.allowAccess(ipAddress); + if (allowAccess) { + // 允许访问,继续执行后续过滤器 + filterChain.doFilter(request, response); + } else { + // 拒绝访问,返回提示信息 + LOGGER.warn("用户IP地址: {} 请求太频繁,已拒绝访问", ipAddress); + response.setContentType("text/plain;charset=UTF-8"); + response.setStatus(HttpServletResponse.SC_TOO_MANY_REQUESTS); + PrintWriter writer = response.getWriter(); + writer.write("请求太频繁,请稍后再试"); + writer.flush(); + writer.close(); + } + } + + @Override + public void destroy() { + // 销毁方法,目前不需要实现 + } +}