基于Redis的分布式服务限流方案

由于API接口无法控制调用方的行为,因此当遇到瞬时请求量激增时,会导致接口占用过多服务器资源,使得其他请求响应速度降低或是超时,更有甚者可能导致服务器宕机。 

限流指对应用服务接口的请求调用次数进行限制,对超过限制次数的请求则进行快速失败或丢弃。

限流可以应对:

1、热点业务带来的高并发请求;

2、客户端异常重试导致的并发请求;

3、恶意攻击请求;

限流算法多种多样,比如常见的:固定窗口计数器、滑动窗口计数器、漏桶、令牌桶等。本文通过Redis的Lua脚本实现令牌桶算法(脚本中的current_token即桶内剩余令牌数,令牌按固定速率补充)。

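1、Redis Lua脚本如下。其中KEYS[1]为限流key;ARGV[1]为桶容量(即每个周期允许的请求数rate),ARGV[2]为周期(单位秒),ARGV[3]为调用方的当前时间戳(毫秒)。脚本返回'1'表示放行,'0'表示拒绝: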

-- Token-bucket rate limiter.
--   KEYS[1]  hash holding the bucket state: fields 'last_time' and 'current_token'
--   ARGV[1]  max_token    - bucket capacity (the Java callers pass RateLimit.rate())
--   ARGV[2]  token_rate   - refill period in seconds (the Java callers pass RateLimit.period())
--   ARGV[3]  current_time - caller's clock in milliseconds (System.currentTimeMillis())
-- Returns '1' if a token was taken (request allowed), '0' otherwise.
-- NOTE(review): RedisRateLimitScript.SCRIPT embeds this exact script on one line (and its
-- SHA1 is what Redis caches); keep both copies identical.
-- NOTE(review): max_token must be > 0; with 0 the refill interval is infinite and the
-- pexpire TTL below becomes NaN, which presumably makes the script error out - confirm.
local ratelimit_info = redis.pcall('HMGET',KEYS[1],'last_time','current_token') 
local last_time = ratelimit_info[1] 
-- nil when the key/field does not exist yet (or has expired)
local current_token = tonumber(ratelimit_info[2]) 
local max_token = tonumber(ARGV[1]) 
local token_rate = tonumber(ARGV[2]) 
local current_time = tonumber(ARGV[3]) 
-- Milliseconds needed to regenerate one token: the period (s -> ms) spread over max_token tokens.
local reverse_time = token_rate*1000/max_token 
if current_token == nil then 
  -- No state yet: start with a full bucket.
  current_token = max_token 
  last_time = current_time 
else 
  local past_time = current_time-last_time 
  -- Whole tokens regenerated since last_time.
  -- NOTE(review): ARGV[3] comes from the calling server's clock. If a server with a faster
  -- clock wrote last_time, past_time is negative and math.floor yields a negative count that
  -- removes tokens; consider clamping past_time at 0 or reading the time via redis TIME.
  local reverse_token = math.floor(past_time/reverse_time)
  current_token = current_token+reverse_token 
  -- Advance last_time only by the time spent on whole tokens, so partial progress towards
  -- the next token is carried over to the next call.
  last_time = reverse_time*reverse_token+last_time 
  -- Never exceed the bucket capacity.
  if current_token>max_token then 
    current_token = max_token 
  end 
end 

-- Take one token if any is left.
local result = '0' 
if(current_token>0) then 
  result = '1' 
  current_token = current_token-1 
end 

-- Persist the bucket state.
redis.call('HMSET',KEYS[1],'last_time',last_time,'current_token',current_token) 
-- Expire idle buckets. The TTL is reverse_time * (missing tokens) plus the elapsed partial
-- interval, i.e. never shorter than the time needed to refill the bucket (assuming
-- current_time >= last_time), so an expired key behaves like a full bucket.
redis.call('pexpire',KEYS[1],math.ceil(reverse_time*(max_token-current_token)+(current_time-last_time))) 

return result

2、项目中引入spring-data-redis和commons-codec,相关配置请自行google。

3、RedisRateLimitScript类

package com.huatech.support.limit;

import org.apache.commons.codec.digest.DigestUtils;
import org.springframework.data.redis.core.script.RedisScript;

public class RedisRateLimitScript implements RedisScript<String> {

  /**
   * Token-bucket Lua script, kept on a single line.
   *
   * <p>KEYS[1] is the limiter key. ARGV[1] is the bucket capacity, ARGV[2] is the refill period
   * in seconds, and ARGV[3] is the caller's current time in milliseconds. The script returns "1"
   * when a token was taken and "0" when the request must be rejected.
   */
  private static final String SCRIPT = 
      "local ratelimit_info = redis.pcall('HMGET',KEYS[1],'last_time','current_token') local last_time = ratelimit_info[1] local current_token = tonumber(ratelimit_info[2]) local max_token = tonumber(ARGV[1]) local token_rate = tonumber(ARGV[2]) local current_time = tonumber(ARGV[3]) local reverse_time = token_rate*1000/max_token if current_token == nil then current_token = max_token last_time = current_time else local past_time = current_time-last_time local reverse_token = math.floor(past_time/reverse_time) current_token = current_token+reverse_token last_time = reverse_time*reverse_token+last_time if current_token>max_token then current_token = max_token end end local result = '0' if(current_token>0) then result = '1' current_token = current_token-1 end redis.call('HMSET',KEYS[1],'last_time',last_time,'current_token',current_token) redis.call('pexpire',KEYS[1],math.ceil(reverse_time*(max_token-current_token)+(current_time-last_time))) return result"; 

  /**
   * SHA1 of {@link #SCRIPT}, computed once.
   *
   * <p>The script is a constant, yet the previous implementation rehashed it on every
   * {@link #getSha1()} call. Callers create a new instance per request, and the script
   * executor asks for the SHA1 on each execution.
   */
  private static final String SHA1 = DigestUtils.sha1Hex(SCRIPT);

  /** Returns the precomputed SHA1 of the script, used for EVALSHA. */
  @Override
  public String getSha1() {
    return SHA1;
  }

  /** The script returns the string "1" (allowed) or "0" (rejected). */
  @Override
  public Class<String> getResultType() {
    return String.class;
  }

  /** Returns the Lua source of the token-bucket script. */
  @Override
  public String getScriptAsString() {
    return SCRIPT;
  }
}

4、添加RateLimit注解

package com.huatech.support.limit;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Declares a Redis token-bucket rate limit on a controller endpoint.
 *
 * <p>Each period of {@link #period()} seconds allows up to {@link #rate()} calls. Tokens are
 * regenerated evenly over the period.
 *
 * <p>NOTE(review): although {@code TYPE} is a permitted target, the LimitAspect and
 * RateLimitInterceptor in this document only look the annotation up on the handler method.
 * Placing it on a class has no effect there.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.TYPE, ElementType.METHOD})
public @interface RateLimit {

	/**
	 * Identifier of the limited endpoint; it is the base of the Redis key.
	 * Leave blank to let the limiter derive a key from the class and method name.
	 */
	String value() default "";

	/** Length of one limiting period, in seconds. */
	int period() default 1;

	/** Maximum number of calls allowed per period (the bucket capacity). */
	int rate() default 100;

	/** What the counter is shared across; by default one bucket for the whole endpoint. */
	LimitType limitType() default LimitType.GLOBAL;

	/**
	 * Reaction when the limit is exceeded; by default the request is rejected.
	 * NOTE(review): not read by the LimitAspect / RateLimitInterceptor in this document,
	 * which always reject.
	 */
	LimitedMethod method() default LimitedMethod.ACCESS_DENIED;
}

基于Redis的分布式服务限流有两种落地方案:

一种是基于aop的切面实现,另一种是基于interceptor的拦截器实现,下面分别做介绍。

方案一:基于AspectJ的AOP切面实现方案

1、添加LimitAspect类

package com.huatech.common.aop;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.util.WebUtils;

import com.alibaba.fastjson.JSONObject;
import com.huatech.common.constant.Constants;
import com.huatech.common.util.IpUtil;
import com.huatech.support.limit.RateLimit;
import com.huatech.support.limit.RedisRateLimitScript;


@Aspect
@Component
public class LimitAspect {
	
	private static final Logger LOGGER = LoggerFactory.getLogger(LimitAspect.class);

	/** The script object holds no per-call state, so one instance is shared by all requests. */
	private static final RedisRateLimitScript RATE_LIMIT_SCRIPT = new RedisRateLimitScript();

	@Autowired
	private StringRedisTemplate redisTemplate;

	/**
	 * Applies the {@link RateLimit} token bucket to annotated controller methods.
	 *
	 * <p>When the bucket still has a token, the call proceeds. Otherwise a JSON error body is
	 * written to the current response, and {@code null} is returned without invoking the controller.
	 *
	 * @param joinPoint the intercepted controller invocation
	 * @return the controller's return value, or {@code null} when the call was rejected
	 * @throws Throwable whatever the controller, Redis or the response writer throws
	 */
	@Around("execution(* com.huatech.core.controller..*(..) ) && @annotation(com.huatech.support.limit.RateLimit)")
	public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable {
		MethodSignature signature = (MethodSignature) joinPoint.getSignature();
		Method method = signature.getMethod();
		RateLimit rateLimit = method.getAnnotation(RateLimit.class);
		if (rateLimit == null) {
			return joinPoint.proceed();
		}
		ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
		if (requestAttributes == null) {
			// Invoked outside an HTTP request: there is no client to key on or to answer, so
			// do not limit (previously this path failed with a NullPointerException).
			return joinPoint.proceed();
		}

		HttpServletRequest request = requestAttributes.getRequest();
		String key = buildKey(rateLimit, method, request);
		List<String> keyList = new ArrayList<>(1);
		keyList.add(key);
		boolean pass = "1".equals(redisTemplate.execute(RATE_LIMIT_SCRIPT, keyList,
				Integer.toString(rateLimit.rate()), Integer.toString(rateLimit.period()),
				Long.toString(System.currentTimeMillis())));
		if (pass) {
			return joinPoint.proceed();
		}
		LOGGER.warn("接口key:{}, 周期:{}, 频率:{}", key, rateLimit.period(), rateLimit.rate());
		writeLimitedResponse(requestAttributes.getResponse());
		return null;
	}

	/**
	 * Builds the Redis key for one call.
	 *
	 * <p>The base is {@code value()}, or "declaringClass-methodName" when that is blank. It is
	 * prefixed by the client IP for {@code IP} limits and by "uid:&lt;userId&gt;" for {@code USER}
	 * limits. The "uid:" prefix matches RateLimitInterceptor and keeps user keys apart from IP keys.
	 */
	private static String buildKey(RateLimit rateLimit, Method method, HttpServletRequest request) {
		String key = rateLimit.value();
		if (StringUtils.isBlank(key)) {
			key = method.getDeclaringClass().getName() + "-" + method.getName();
		}
		switch (rateLimit.limitType()) {
			case IP:
				return IpUtil.getRemortIP(request) + "-" + key;
			case USER:
				Object userId = WebUtils.getSessionAttribute(request, Constants.SESSION_USER_ID);
				if (userId == null) {
					// No logged-in user: limit by client IP instead of failing with a NullPointerException.
					return IpUtil.getRemortIP(request) + "-" + key;
				}
				return "uid:" + userId + "-" + key;
			default:
				return key;
		}
	}

	/**
	 * Writes the JSON "limit exceeded" body to the response.
	 * NOTE(review): the HTTP status is left unchanged (200 unless set elsewhere); clients that
	 * rely on status codes might prefer 429.
	 */
	private static void writeLimitedResponse(HttpServletResponse response) throws java.io.IOException {
		Map<String, Object> result = new HashMap<>();
		result.put("code", "400");
		result.put("msg", "访问超过次数限制!");
		response.setContentType("application/json");
		response.setCharacterEncoding("utf-8");
		response.getWriter().print(JSONObject.toJSON(result));
	}
}

2、在spring-mvc配置文件中开启AspectJ自动代理,使@Aspect切面生效。注意:自动代理只作用于声明它的Spring上下文中的bean,因此该配置应放在创建controller的spring-mvc上下文中:

<!-- Enables @AspectJ auto-proxying so @Aspect beans such as LimitAspect advise matching beans.
     NOTE(review): auto-proxying is applied per application context; presumably this must live in
     the same (spring-mvc) context that creates the controllers - confirm. -->
<aop:aspectj-autoproxy/>

3、通过组件扫描将LimitAspect(以及controller)注册为bean,或者在spring配置文件中手动配置LimitAspect bean:

<!-- Registers the @Component LimitAspect (com.huatech.common.aop) and the controllers
     (com.huatech.core.controller) as beans in this context. -->
<context:component-scan base-package="com.huatech.common.aop,com.huatech.core.controller"/>

方案二:基于interceptor的拦截器实现方案

1、添加RateLimitInterceptor类

public class RateLimitInterceptor extends HandlerInterceptorAdapter {
	
	private static final Logger LOGGER = LoggerFactory.getLogger(RateLimitInterceptor.class);

	/** The script object holds no per-call state, so one instance is shared by all requests. */
	private static final RedisRateLimitScript RATE_LIMIT_SCRIPT = new RedisRateLimitScript();

	@Autowired StringRedisTemplate redisTemplate;

	/**
	 * Rejects the request when the handler method carries {@link RateLimit} and its token bucket is empty.
	 *
	 * @param request  current request, used to derive IP / session-user keys
	 * @param response response that receives the JSON error body on rejection
	 * @param handler  the selected handler; only {@link HandlerMethod}s are rate limited
	 * @return {@code true} to continue processing, {@code false} when the limit was exceeded
	 * @throws Exception failures from Redis or from writing the response are propagated
	 */
	@Override
	public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
		if (!(handler instanceof HandlerMethod)) {
			return true;
		}
		HandlerMethod handlerMethod = (HandlerMethod) handler;
		RateLimit rateLimit = handlerMethod.getMethodAnnotation(RateLimit.class);
		if (rateLimit == null) {
			return true;
		}

		String key = buildKey(rateLimit, handlerMethod, request);
		List<String> keyList = new ArrayList<>(1);
		keyList.add(key);
		boolean pass = "1".equals(redisTemplate.execute(RATE_LIMIT_SCRIPT, keyList,
				Integer.toString(rateLimit.rate()), Integer.toString(rateLimit.period()),
				Long.toString(System.currentTimeMillis())));
		if (pass) {
			return true;
		}
		LOGGER.warn("接口key:{}, 周期:{}, 频率:{}", key, rateLimit.period(), rateLimit.rate());
		writeLimitedResponse(response);
		return false;
	}

	/**
	 * Builds the Redis key for one call.
	 *
	 * <p>The base is {@code value()}, or "declaringClass-methodName" when that is blank. It is
	 * prefixed by the client IP for {@code IP} limits and by "uid:&lt;userId&gt;" for {@code USER} limits.
	 */
	private static String buildKey(RateLimit rateLimit, HandlerMethod handlerMethod, HttpServletRequest request) {
		String key = rateLimit.value();
		if (StringUtils.isBlank(key)) {
			// Use the controller's class. handlerMethod.getClass() is always HandlerMethod,
			// which made same-named methods of different controllers share one bucket.
			key = handlerMethod.getMethod().getDeclaringClass().getName() + "-" + handlerMethod.getMethod().getName();
		}
		switch (rateLimit.limitType()) {
			case IP:
				return IpUtil.getRemortIP(request) + "-" + key;
			case USER:
				Object userId = WebUtils.getSessionAttribute(request, Constants.SESSION_USER_ID);
				if (userId == null) {
					// No logged-in user: limit by client IP instead of failing with a NullPointerException.
					return IpUtil.getRemortIP(request) + "-" + key;
				}
				return "uid:" + userId + "-" + key;
			default:
				return key;
		}
	}

	/**
	 * Writes the JSON "limit exceeded" body to the response.
	 * NOTE(review): the HTTP status is left unchanged (200 unless set elsewhere); clients that
	 * rely on status codes might prefer 429.
	 */
	private static void writeLimitedResponse(HttpServletResponse response) throws java.io.IOException {
		Map<String, Object> result = new HashMap<>();
		result.put("code", "400");
		result.put("msg", "访问超过次数限制!");
		response.setContentType("application/json");
		response.setCharacterEncoding("utf-8");
		response.getWriter().print(JSONObject.toJSON(result));
	}
}

2、在spring-mvc配置文件中配置拦截器

<!-- 拦截器配置 -->
<mvc:interceptors>
	<!-- 其他拦截器配置: other <mvc:interceptor> entries go here (a literal placeholder such as
	     "****" is text content, which the mvc:interceptors schema does not allow) -->
	<!-- 限速拦截器配置: rate-limit interceptor, applied to every path -->
	<mvc:interceptor>
		<mvc:mapping path="/**"/>
		<bean class="com.huatech.common.interceptor.RateLimitInterceptor"/>
	</mvc:interceptor>
</mvc:interceptors>

使用@RateLimit

  在controller类的方法上添加RateLimit注解(切面和拦截器都只读取方法上的注解,标注在类上不会生效)。下例表示ping接口每5秒最多允许5次请求:

/**
 * Server-side ping endpoint.
 *
 * <p>Hands a map {"time": current server time in ms} to ServletUtils.successData.
 * At most 5 calls per 5-second period are allowed, under the limiter key "ping".
 *
 * @param request  the incoming request (not used)
 * @param response the response the payload is written to
 * @throws Exception propagated from writing the response
 */
@RateLimit(value = "ping", period = 5, rate = 5)
@RequestMapping(value = "/api/app/open/ping.htm")
public void ping(HttpServletRequest request, HttpServletResponse response) throws Exception {
    long now = System.currentTimeMillis();
    Map<String, Object> payload = new HashMap<>();
    payload.put("time", now);
    ServletUtils.successData(response, payload);
}

 另外两个枚举类

package com.huatech.support.limit;
/**
 * How a request is handled once its rate limit is exceeded.
 * NOTE(review): the LimitAspect / RateLimitInterceptor in this document never read
 * RateLimit.method(); they always reject, i.e. only ACCESS_DENIED is implemented there.
 * @author [email protected]
 * @since 2019-08-28
 * @version 1.0
 */
public enum LimitedMethod {

	/** Reject the request, without sending an alert. */
	ACCESS_DENIED,

	/** Send an alert SMS, but still let the request through. */
	WARN_SMS,

	/** Reject the request and send an alert. */
	DENIED_AND_SMS;
}
package com.huatech.support.limit;
/**
 * Scope across which a rate-limit bucket is shared.
 * @author [email protected]
 * @since 2019-08-29
 * @version 1.0
 */
public enum LimitType {

	/** One bucket for the whole endpoint, shared by all callers. */
	GLOBAL("接口"),
	/** One bucket per client IP address. */
	IP("ip"),
	/** One bucket per session user. */
	USER("用户");

	/**
	 * Display label of the limit type.
	 * Final because enum constants are shared singletons; a mutable public field
	 * could be reassigned globally by any caller.
	 */
	public final String value;

	LimitType(String value) {
		this.value = value;
	}

	/** Returns the display label (same as reading {@link #value}). */
	public String getValue() {
		return value;
	}
}

 IpUtil工具类

package com.huatech.common.util;

import java.net.InetAddress;
import java.net.UnknownHostException;

import javax.servlet.http.HttpServletRequest;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Helpers for resolving the client IP address of an HTTP request.
 * @author [email protected]
 * @since 2019-08-29
 * @version 1.0
 */
public class IpUtil {
	
	public static final Logger logger = LoggerFactory.getLogger(IpUtil.class);

	/** Placeholder that proxies may put into forwarding headers when the client is unknown. */
	private static final String UNKNOWN = "unknown";

	/**
	 * Returns the client IP of a request.
	 *
	 * <p>Sources are tried in order: x-forwarded-for, X-Real-IP, WL-Proxy-Client-IP, and finally
	 * the socket's remote address. From each source, the first non-empty, non-"unknown" entry of
	 * its comma-separated list is taken, trimmed. A loopback address is replaced by the local
	 * host's configured address when that can be resolved.
	 *
	 * <p>NOTE(review): the forwarding headers are client-controlled unless a trusted proxy
	 * overwrites them. Without that, callers can rotate x-forwarded-for to evade IP-based
	 * rate limits.
	 *
	 * @param request the current request
	 * @return the resolved IP, or "" if no source yields one
	 */
	public static String getRemortIP(HttpServletRequest request) {
		String ip = firstKnownAddress(request.getHeader("x-forwarded-for"));
		if (ip == null) {
			ip = firstKnownAddress(request.getHeader("X-Real-IP"));
		}
		if (ip == null) {
			ip = firstKnownAddress(request.getHeader("WL-Proxy-Client-IP"));
		}
		if (ip == null) {
			// getRemoteAddr() may be null (e.g. mock requests); previously this caused an NPE below.
			ip = firstKnownAddress(request.getRemoteAddr());
		}
		if (ip == null) {
			return "";
		}

		// Report the machine's configured address instead of loopback (optional behavior).
		if (isLoopback(ip)) {
			try {
				return InetAddress.getLocalHost().getHostAddress();
			} catch (UnknownHostException e) {
				logger.error(e.getMessage(), e);
			}
		}
		return ip;
	}

	/** Returns the first trimmed, non-empty, non-"unknown" entry of a comma-separated list, or null. */
	private static String firstKnownAddress(String value) {
		if (value == null) {
			return null;
		}
		for (String candidate : value.split(",")) {
			String trimmed = candidate.trim();
			if (!trimmed.isEmpty() && !UNKNOWN.equalsIgnoreCase(trimmed)) {
				return trimmed;
			}
		}
		return null;
	}

	/** True for IPv4 127.0.0.1 and the IPv6 loopback in either compressed or expanded form. */
	private static boolean isLoopback(String ip) {
		return "127.0.0.1".equals(ip) || "::1".equals(ip) || ip.endsWith("0:0:0:0:0:0:1");
	}
}

相关推荐