自定义注解+Redis实现简单IP限流

自定义注解+Redis实现简单IP限流

Scroll Down

前言

标题其实也是我被问到的一道面试题,只不过整段原话是这样子的(并不是问IP限流):

  • Q:假如你的系统QPS只有300,而未来你的应用有比较火爆的活动,可能会达到1W请求/s,那么你会怎么处理?
  • A:对应用进行扩容,负载均衡balabala...
  • Q:如果不扩容呢?
  • A:(空气凝结了一阵子,然后面试官给了些提示,我总算get到题意了)
    ...
  • Q:你如何将限流算法整合到项目中,从而实现可以配置指定接口限流?
  • A:可以自定义注解,结合Spring的AOP,对注解进行前置拦截,拦截代码的处理逻辑就是令牌桶算法(这里随便举了一个例子),如果能申请到令牌就进入真正的业务逻辑,如果不能就失败/降级返回。

关于这个问题,其实你还可以有更高级的回答,比如阿里有个开源的微服务治理组件:Sentinel。它就支持对接口进行限流(并且支持在线配置的),而且限流策略还挺多,我虽然有使用经验,但对其原理还是不太熟悉,就不敢贸然回答这个了。

其实一开始面试官就是想问限流,只是我把问题理解成如何提升系统QPS了,后来才get到面试官想问的是限流55555555。

下面废话不多说,直接上代码,代码逻辑不过多解释,相信不难看懂,如果有不懂的可以在评论区留言。

实验代码

注解代码

import run.ut.app.model.enums.RateLimitEnum;

import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;

/**
 * Annotation for current limiting
 *
 * @author wenjie
 */

@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequestRateLimit {

    RateLimitEnum limit();

    TimeUnit timeUnit() default TimeUnit.SECONDS;
}

枚举代码

/**
 * @author wenjie
 */
public enum RateLimitEnum {

    /**
     * M/N means that only M times can be requested in N time units
     */
    RRLimit_1_5("1/5"),
    RRLimit_1_10("1/10"),
    RRLimit_1_60("1/60"),;

    private String limit;

    RateLimitEnum(final String limit) {
        this.limit = limit;
    }

    public String limit() {
        return this.limit;
    }
}

切面代码

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import run.ut.app.exception.FrequentAccessException;
import run.ut.app.model.enums.RateLimitEnum;
import run.ut.app.service.RedisService;
import run.ut.app.utils.ServletUtils;

import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;

/**
 * RequestRateLimit's aspect (only IP current limiting is implemented)
 * @author wenjie
 */

@Component
@Aspect
@Slf4j
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class RequestRateLimitAspect {

    private final RedisService redisService;

    @Around("@annotation(run.ut.app.cache.lock.RequestRateLimit)")
    public Object requestRateLimit(ProceedingJoinPoint point) throws Throwable {
        // Gets request URI
        String requestURI = ServletUtils.getRequestURI();
        // Gets user-agent from header
        String userAgent = ServletUtils.getHeaderIgnoreCase("user-agent");
        // Gets IP
        String requestIp = ServletUtils.getRequestIp();

        // Gets method
        final Method method = ((MethodSignature) point.getSignature()).getMethod();
        // Gets annotation
        RequestRateLimit requestRateLimit = method.getAnnotation(RequestRateLimit.class);
        // Gets annotation params
        RateLimitEnum limitEnum = requestRateLimit.limit();
        TimeUnit timeUnit = requestRateLimit.timeUnit();
        int[] limitParams = getLimitParams(limitEnum);

        // Generates key
        String key = String.format("ut_api_request_limit_rate_%s_%s_%s", requestIp, method.getName(), requestURI);
        // Checks
        boolean over = redisService.overRequestRateLimit(key, limitParams[0], limitParams[1], timeUnit, userAgent);
        if (over) {
            throw new FrequentAccessException("请求过于频繁,请稍后重试。");
        }

        return point.proceed();
    }

    /***
     * In the returned array,
     *
     * @return ↓
     *        elements[0] is the time limit,
     *        elements[1] is the number of times that can be requested within the time limit.
     */
    private static int[] getLimitParams(RateLimitEnum rateLimitEnum) {
        String limit = rateLimitEnum.limit();
        int[] result = new int[2];
        String[] limits = limit.split("/");
        result[0] = Integer.parseInt(limits[0]);
        result[1] = Integer.parseInt(limits[1]);
        return result;
    }

}
  • 不同的限流策略,其实就只需要更改上面的redisService.overRequestRateLimit的逻辑就可以了,比如你用令牌桶算法,那就可以把这段逻辑改成尝试获取一个令牌。

redis代码

service接口层

    boolean overRequestRateLimit(@NonNull String key, final int expireTime, final int max,
                                 @NonNull TimeUnit timeUnit, String userAgent);

service实现层

    @Override
    public boolean overRequestRateLimit(String key, int max, int expireTime, TimeUnit timeUnit, String userAgent) {
        Assert.hasText(key, "redis key must not be blank");

        long count = increment(key, 1);
        long time = stringRedisTemplate.getExpire(key);

        /*
         * count == 1 means that redis key is set for the first time
         */
        if (count == 1 || time == -1) {
            expire(key, expireTime, timeUnit);
        }

        log.debug("UT api request limit rate:too many requests: key={}, redis count={}, max count={}, " +
            "expire time= {} s, user-agent={} ", key, count, max, expireTime, userAgent);

        return count > max;
    }
  • 如果你担心操作原子性的问题,这段逻辑也可以换成LUA脚本。

Serverlet工具类

用到了hutool,如果使用这个代码,注意引入相关依赖。

下面代码是基于halo的源码扩展的。

import cn.hutool.extra.servlet.ServletUtil;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.util.Optional;

/**
 * Servlet utilities.
 *
 * @author johnniang
 * @author wenjie
 * @date 20-4-28
 */
public class ServletUtils {

    private ServletUtils() {
    }

    /**
     * Gets current http servlet request.
     *
     * @return an optional http servlet request
     */
    @NonNull
    public static Optional<HttpServletRequest> getCurrentRequest() {
        return Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                .filter(requestAttributes -> requestAttributes instanceof ServletRequestAttributes)
                .map(requestAttributes -> ((ServletRequestAttributes) requestAttributes))
                .map(ServletRequestAttributes::getRequest);
    }

    /**
     * Gets request ip.
     *
     * @return ip address or null
     */
    @Nullable
    public static String getRequestIp() {
        return getCurrentRequest().map(ServletUtil::getClientIP).orElse(null);
    }

    /**
     * Gets request header.
     *
     * @param header http header name
     * @return http header of null
     */
    @Nullable
    public static String getHeaderIgnoreCase(String header) {
        return getCurrentRequest().map(request -> ServletUtil.getHeaderIgnoreCase(request, header)).orElse(null);
    }

    /**
     * Gets request URI
     */
    @Nullable
    public static String getRequestURI() {
        return getCurrentRequest().map(HttpServletRequest::getRequestURI).orElse(null);
    }

}

注解效果测试

比如一个发送邮箱验证码接口,现在要求你根据请求的IP限速,限制一个IP在60秒内只能请求一次接口,那么你就可以像下面这个样子使用注解(其它代码略):

image.png

第一次发送,返回成功了:

image.png

来看看redis的key:

ut_api_request_limit_rate_0:0:0:0:0:0:0:1_sendEmailCode_/ut/admin/sendEmailCode

来看看控制台debug信息:

Express api request limit rate:too many requests: key=ut_api_request_limit_rate_0:0:0:0:0:0:0:1_sendEmailCode_/ut/admin/sendEmailCode, redis count=1, max count=1, expire time= 60 s, user-agent=PostmanRuntime/7.24.1 

之后如果在60秒内再请求发送,就会返回类似请求过于频繁了的信息:
image.png

好了,本文暂时就讲这么多,最后再给读者扩展一些方案,比如缓存除了redis外,你还可以考虑使用Guava的本地缓存(同样可以设置有效期),又或者你可以学学halo的做法,自己用Java实现缓存。

如果你觉得本博客的实现仍有可以改进的地方,或者错误的地方,希望能在评论区指出。