Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions server/src/main/config/application.properties
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ base.url = ${KK_BASE_URL:default}
# trust.host = *
#
# 当前配置:
trust.host = ${KK_TRUST_HOST:default}
trust.host = localhost

# 不信任站点黑名单配置,多个用','隔开
# 黑名单优先级高于白名单,设置后将禁止预览来自这些站点的文件
Expand Down Expand Up @@ -183,7 +183,7 @@ watermark.angle = ${WATERMARK_ANGLE:10}

#首页功能设置
#是否禁用首页文件上传
file.upload.disable = ${KK_FILE_UPLOAD_DISABLE:true}
file.upload.disable = ${KK_FILE_UPLOAD_DISABLE:false}
# 备案信息,默认为空
beian = ${KK_BEIAN:default}
#禁止上传类型
Expand Down
14 changes: 14 additions & 0 deletions server/src/main/java/cn/keking/config/WebConfig.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cn.keking.config;

import cn.keking.web.filter.*;
import cn.keking.web.filter.RateLimitFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
Expand Down Expand Up @@ -86,6 +87,19 @@ public FilterRegistrationBean<UrlCheckFilter> getUrlCheckFilter() {
return registrationBean;
}

@Bean
public FilterRegistrationBean<RateLimitFilter> getRateLimitFilter() {
Set<String> filterUri = new HashSet<>();
filterUri.add("/onlinePreview");
RateLimitFilter filter = new RateLimitFilter();
FilterRegistrationBean<RateLimitFilter> registrationBean = new FilterRegistrationBean<>();
registrationBean.setFilter(filter);
registrationBean.setUrlPatterns(filterUri);
// 设置限流过滤器的顺序,确保它在其他过滤器之前执行
registrationBean.setOrder(5);
return registrationBean;
}

@Bean
public FilterRegistrationBean<AttributeSetFilter> getWatermarkConfigFilter() {
Set<String> filterUri = new HashSet<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package cn.keking.service.cache;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
* @author: keking
* @since: 2023/10/10 10:10
* @description: 基于内存的限流器实现
*/
public class MemoryRateLimiter implements RateLimiter {

private static final Logger LOGGER = LoggerFactory.getLogger(MemoryRateLimiter.class);

// 存储IP地址的访问次数和时间戳
private final ConcurrentMap<String, RateLimitInfo> ipAccessMap = new ConcurrentHashMap<>();

// 限流周期(毫秒)
private final long period;

// 每个周期内的最大访问次数
private final int maxCount;

public MemoryRateLimiter(long period, int maxCount) {
this.period = period;
this.maxCount = maxCount;
}

@Override
public boolean allowAccess(String key) {
try {
long currentTime = System.currentTimeMillis();

// 使用compute方法确保线程安全
RateLimitInfo rateLimitInfo = ipAccessMap.compute(key, (k, v) -> {
if (v == null) {
// 首次访问,创建记录
return new RateLimitInfo(currentTime, 1);
} else {
if (currentTime - v.getLastAccessTime() > period) {
// 超过周期,重置访问次数,直接允许访问
return new RateLimitInfo(currentTime, 1);
} else if (v.getAccessCount() < maxCount) {
// 未超过最大访问次数,增加访问次数
return new RateLimitInfo(v.getLastAccessTime(), v.getAccessCount() + 1);
} else {
// 超过最大访问次数,保持原有记录不变
return v;
}
}
});

// 检查是否允许访问
if (currentTime - rateLimitInfo.getLastAccessTime() > period) {
// 超过周期,允许访问
return true;
} else {
// 未超过周期,检查访问次数
return rateLimitInfo.getAccessCount() <= maxCount;
}
} catch (Exception e) {
LOGGER.error("限流检查发生异常: {}, 不限流直接允许访问", e.getMessage(), e);
// 异常时不限流
return true;
}
}

// 限流信息类,使用volatile关键字确保线程可见性
private static class RateLimitInfo {
private volatile long lastAccessTime;
private volatile int accessCount;

public RateLimitInfo(long lastAccessTime, int accessCount) {
this.lastAccessTime = lastAccessTime;
this.accessCount = accessCount;
}

public long getLastAccessTime() {
return lastAccessTime;
}

public void setLastAccessTime(long lastAccessTime) {
this.lastAccessTime = lastAccessTime;
}

public int getAccessCount() {
return accessCount;
}

public void setAccessCount(int accessCount) {
this.accessCount = accessCount;
}
}
}
16 changes: 16 additions & 0 deletions server/src/main/java/cn/keking/service/cache/RateLimiter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package cn.keking.service.cache;

/**
* @author: keking
* @since: 2023/10/10 10:00
* @description: 限流器接口
*/
public interface RateLimiter {

/**
* 检查是否允许访问
* @param key 限流键,通常是IP地址
* @return true 允许访问,false 不允许访问
*/
boolean allowAccess(String key);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package cn.keking.service.cache;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

/**
* @author: keking
* @since: 2023/10/10 10:20
* @description: 限流器工厂类,使用工厂模式创建限流器实例
*/
@Component
public class RateLimiterFactory {

// 限流周期(毫秒)
@Value("${rate.limit.period:60000}")
private long period;

// 每个周期内的最大访问次数
@Value("${rate.limit.max.count:10}")
private int maxCount;

/**
* 创建限流器实例
* @return 限流器实例
*/
public RateLimiter createRateLimiter() {
// 目前只支持基于内存的限流器
// 后续支持Redis等第三方缓存时,只需要修改这里的实现
return new MemoryRateLimiter(period, maxCount);
}
}
23 changes: 23 additions & 0 deletions server/src/main/java/cn/keking/utils/WebUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,29 @@ public static void setSessionAttr(HttpServletRequest request, String key, Object
session.setAttribute(key, value);
}

/**
* 获取用户的IP地址
* @param request 请求
* @return 用户的IP地址
*/
public static String getIpAddress(HttpServletRequest request) {
String ipAddress = request.getHeader("x-forwarded-for");
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("WL-Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getRemoteAddr();
}
// 如果通过多个代理,第一个IP为客户端真实IP,多个IP以逗号分隔
if (ipAddress != null && ipAddress.contains(",")) {
ipAddress = ipAddress.split(",")[0].trim();
}
return ipAddress;
}

/**
* 移除 session 中的属性
* @param request 请求
Expand Down
68 changes: 68 additions & 0 deletions server/src/main/java/cn/keking/web/filter/RateLimitFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package cn.keking.web.filter;

import cn.keking.service.cache.RateLimiter;
import cn.keking.service.cache.RateLimiterFactory;
import cn.keking.utils.WebUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;

/**
* @author: keking
* @since: 2023/10/10 10:30
* @description: 基于IP地址的限流过滤器
*/
@Component
public class RateLimitFilter implements Filter {

private static final Logger LOGGER = LoggerFactory.getLogger(RateLimitFilter.class);

private RateLimiter rateLimiter;

@Autowired
private RateLimiterFactory rateLimiterFactory;

@Override
public void init(FilterConfig filterConfig) throws ServletException {
// 初始化限流器
rateLimiter = rateLimiterFactory.createRateLimiter();
}

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;

// 获取用户的IP地址
String ipAddress = WebUtils.getIpAddress(request);
LOGGER.debug("用户IP地址: {}", ipAddress);

// 检查是否允许访问
boolean allowAccess = rateLimiter.allowAccess(ipAddress);
if (allowAccess) {
// 允许访问,继续执行后续过滤器
filterChain.doFilter(request, response);
} else {
// 拒绝访问,返回提示信息
LOGGER.warn("用户IP地址: {} 请求太频繁,已拒绝访问", ipAddress);
response.setContentType("text/plain;charset=UTF-8");
response.setStatus(HttpServletResponse.SC_TOO_MANY_REQUESTS);
PrintWriter writer = response.getWriter();
writer.write("请求太频繁,请稍后再试");
writer.flush();
writer.close();
}
}

@Override
public void destroy() {
// 销毁方法,目前不需要实现
}
}