Preface
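I meant to write this post back when I was working on the UT project, but one thing after another pushed it back until now. Most posts online about integrating Netty with Spring Boot, or about building a WebSocket server on Netty, say almost nothing about how to authenticate connections. Fortunately my Netty fundamentals are fairly solid, so solving connection authentication didn't take long.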
If you want to learn how to build a WebSocket service with Spring Boot and Netty, and you are stuck on how to authenticate WebSocket connections, this post should suit you well.
PS: This post focuses on building the service; theory is only touched on briefly.
Source code
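For lack of time I haven't written a separate demo; the complete code is part of the project: https://gitee.com/wenjie2018/UT-APP
The core code lives in the run.ut.app.netty package; have a look there for the details.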
Dependency version:
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.42.Final</version>
</dependency>
Core code
Utility class
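Because of the way I integrate Netty, the Netty server is started after the Spring context has been refreshed, and some of its classes (such as the channel initializer below) are created with new rather than managed by Spring. To get beans from the container they therefore need a utility class that looks beans up.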
SpringUtils.java
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationContextException;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;
/**
* Spring utilities
*
* @author wenjie
*/
@Component
public class SpringUtils implements ApplicationContextAware {
private static ApplicationContext applicationContext;
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
if (ObjectUtils.isEmpty(applicationContext)) {
throw new ApplicationContextException("applicationContext must not be null");
}
SpringUtils.applicationContext = applicationContext;
}
/**
* Gets ApplicationContext
* @return ApplicationContext
*/
public static ApplicationContext getApplicationContext() {
return applicationContext;
}
/**
* Gets bean by bean's name
* @param name bean's name
* @return bean
*/
public static Object getBean(String name) {
return getApplicationContext().getBean(name);
}
/**
* Gets bean by bean's class(java.lang.Class)
* @param clazz bean Class
* @param <T> Class type
* @return bean
*/
public static <T> T getBean(Class<T> clazz) {
return getApplicationContext().getBean(clazz);
}
/**
* Gets bean by bean's class and bean's name
* @param name bean name
* @param clazz bean Class
* @param <T> Class type
* @return bean
*/
public static <T> T getBean(String name, Class<T> clazz) {
return getApplicationContext().getBean(name, clazz);
}
}
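The static holder exists because an object created with new never gets @Autowired fields, so it has to reach the container through a static entry point that Spring fills in once. A JDK-only sketch of that idea follows; BeanRegistry and StaticBeans are hypothetical stand-ins for ApplicationContext and SpringUtils, and the sketch adds a fail-fast check so a lookup made before the container has set the context produces a clear error instead of a NullPointerException:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-in for Spring's ApplicationContext: a type-keyed bean map.
final class BeanRegistry {
    private final Map<Class<?>, Object> beans = new ConcurrentHashMap<>();

    <T> void register(Class<T> type, T bean) {
        beans.put(type, bean);
    }

    <T> T getBean(Class<T> type) {
        Object bean = beans.get(type);
        if (bean == null) {
            throw new IllegalArgumentException("No bean of type " + type.getName());
        }
        return type.cast(bean);
    }
}

// Static holder in the spirit of SpringUtils: written once by the container, read from anywhere.
final class StaticBeans {
    // volatile so a registry published by one thread is visible to threads created later
    private static volatile BeanRegistry registry;

    private StaticBeans() {
    }

    static void setRegistry(BeanRegistry r) {
        if (r == null) {
            throw new IllegalArgumentException("registry must not be null");
        }
        registry = r;
    }

    static <T> T getBean(Class<T> type) {
        BeanRegistry r = registry;
        if (r == null) {
            // Fail fast with a clear message instead of a NullPointerException
            throw new IllegalStateException("Registry not initialized yet");
        }
        return r.getBean(type);
    }
}
```

In SpringUtils the role of setRegistry is played by setApplicationContext, which Spring calls during context refresh, before the Netty server is started.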
Handlers
Connection authentication handler
import io.jsonwebtoken.Claims;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
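import io.netty.util.ReferenceCountUtil;
import run.ut.app.security.util.JwtOperator;
/**
* Authenticates the WebSocket handshake request (a FullHttpRequest) using the JWT in its "token" header.
*/
@ChannelHandler.Sharable
@Slf4j
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class AuthHandler extends ChannelInboundHandlerAdapter {
private final UserChannelManager userChannelManager;
private final JwtOperator jwtOperator;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof FullHttpRequest) {
FullHttpRequest request = (FullHttpRequest) msg;
HttpHeaders headers = request.headers();
String token = headers.get("token");
Long uid = null;
try {
if (token != null && !token.isEmpty()) {
Claims claims = jwtOperator.getClaimsFromToken(token);
uid = Long.valueOf(claims.get("uid") + "");
}
} catch (Exception e) {
// Invalid or expired token, or no usable uid claim
log.debug("Authentication failed: {}", e.getMessage());
}
if (uid == null) {
// Reject: the request is not passed on, so release it before closing the connection
ReferenceCountUtil.release(msg);
ctx.channel().close();
return;
}
userChannelManager.add(uid, ctx.channel());
log.debug("Authentication success. uid: {}", uid);
// Authentication is only needed once per connection, so drop this handler
ctx.pipeline().remove(this);
// Propagate the event until the WebSocket handshake is complete.
ctx.fireChannelRead(msg);
} else {
ReferenceCountUtil.release(msg);
ctx.channel().close();
}
}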
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
userChannelManager.remove(ctx.channel());
ctx.channel().close();
}
}
- The business code is not included here; if you are interested, see the full project source.
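- Also note the line ctx.pipeline().remove(this);: after a successful check the handler removes itself, so only the handshake request is authenticated. If you also need to authenticate every message received afterwards, delete this line, and also adapt the else branch: after the handshake this handler receives WebSocketFrame objects rather than FullHttpRequest, so as written it would close the connection on the first frame.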
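JwtOperator.getClaimsFromToken belongs to the project (it returns io.jsonwebtoken.Claims, i.e. jjwt), so its internals aren't shown. Assuming the token is an HS256-signed JWT with a numeric uid claim, the core of what it must do can be sketched with the JDK alone: recompute HMAC-SHA256 over "header.payload", compare it with the signature in constant time, and only then read uid. The class name, the secret, and the regex-based claim extraction are hypothetical simplifications; a real service should keep using a JWT library, which also validates alg and exp.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Base64;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

// Hypothetical HS256 check; a simplified stand-in for JwtOperator.getClaimsFromToken.
final class Hs256UidVerifier {
    // Naive claim lookup for the sketch; a JSON parser would be used in practice
    private static final Pattern UID = Pattern.compile("\"uid\"\\s*:\\s*\"?(\\d+)\"?");
    private final byte[] secret;

    Hs256UidVerifier(byte[] secret) {
        this.secret = secret.clone();
    }

    /** Returns the uid claim, or throws IllegalArgumentException if the token is invalid. */
    long verifyAndGetUid(String token) {
        if (token == null || token.isEmpty()) {
            throw new IllegalArgumentException("missing token");
        }
        String[] parts = token.split("\\.");
        if (parts.length != 3) {
            throw new IllegalArgumentException("malformed token");
        }
        byte[] expected = sign(parts[0] + "." + parts[1]);
        // JWTs use base64url without padding, which the JDK URL decoder accepts
        byte[] actual = Base64.getUrlDecoder().decode(parts[2]);
        // Constant-time comparison so timing does not reveal how many bytes matched
        if (!MessageDigest.isEqual(expected, actual)) {
            throw new IllegalArgumentException("bad signature");
        }
        String payload = new String(Base64.getUrlDecoder().decode(parts[1]), StandardCharsets.UTF_8);
        Matcher m = UID.matcher(payload);
        if (!m.find()) {
            throw new IllegalArgumentException("uid claim missing");
        }
        return Long.parseLong(m.group(1));
    }

    /** HMAC-SHA256 over the signing input "header.payload". */
    byte[] sign(String signingInput) {
        try {
            Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(secret, "HmacSHA256"));
            return mac.doFinal(signingInput.getBytes(StandardCharsets.US_ASCII));
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

One design point worth keeping in mind: the token travels in a custom "token" header, which suits clients like the mini program (its socket API accepts headers), but the browser WebSocket constructor cannot set custom headers, so browser clients usually have to pass the token some other way.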
Handler for WebSocket messages
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import run.ut.app.model.enums.WebSocketMsgTypeEnum;
import run.ut.app.model.support.WebSocketMsg;
import run.ut.app.utils.JsonUtils;
/**
* @author wenjie
*/
@ChannelHandler.Sharable
@Slf4j
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class ClientMsgHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
private final UserChannelManager userChannelManager;
@Override
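protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg)
throws Exception {
String json = msg.text();
WebSocketMsg webSocketMsg = JsonUtils.jsonToObject(json, WebSocketMsg.class);
WebSocketMsgTypeEnum type = WebSocketMsgTypeEnum.getByType(webSocketMsg.getType());
if (type == null) {
// Assumes getByType returns null for unknown codes; a switch on null would throw NullPointerException
log.debug("Unknown message type: {}", webSocketMsg.getType());
return;
}
switch (type) {
case KEEPALIVE:
log.debug("Received keepalive frame");
break;
default:
break;
}
}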
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
userChannelManager.remove(ctx.channel());
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
ctx.channel().close();
userChannelManager.remove(ctx.channel());
}
}
- For now only heartbeat (keepalive) messages are handled; add handling for more message types as your business requires.
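WebSocketMsgTypeEnum.getByType is not shown in the post. If it returns null for an unknown code, the dispatch has to check for that first, because a Java switch on a null enum throws NullPointerException. A JDK-only sketch with a hypothetical MsgType enum (the names and numeric codes are invented) and a second type, to show how the switch grows as message types are added:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical message types; the real codes live in WebSocketMsgTypeEnum.
enum MsgType {
    KEEPALIVE(0), CHAT(1);

    private static final Map<Integer, MsgType> BY_TYPE = new HashMap<>();

    static {
        for (MsgType t : values()) {
            BY_TYPE.put(t.type, t);
        }
    }

    private final int type;

    MsgType(int type) {
        this.type = type;
    }

    /** Returns null for unknown codes, which is what getByType is assumed to do. */
    static MsgType getByType(Integer type) {
        return type == null ? null : BY_TYPE.get(type);
    }
}

final class MsgDispatcher {
    /** Returns a short description of how the frame was handled (for illustration). */
    String dispatch(Integer typeCode) {
        MsgType type = MsgType.getByType(typeCode);
        if (type == null) {
            // Reject unknown types before the switch, which cannot handle null
            return "ignored unknown type " + typeCode;
        }
        switch (type) {
            case KEEPALIVE:
                return "keepalive";
            case CHAT:
                return "chat";
            default:
                return "unhandled " + type;
        }
    }
}
```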
Idle event handler
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
/**
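* Heart Beat Handler. Closes the channel once it has seen no reads or writes for the all-idle time configured on IdleStateHandler, e.g. because the client stopped sending heartbeat frames.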
*
* @author wenjie
*/
@ChannelHandler.Sharable
@Slf4j
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class HeartBeatHandler extends ChannelInboundHandlerAdapter {
private final UserChannelManager userChannelManager;
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof IdleStateEvent) {
IdleStateEvent event = (IdleStateEvent)evt;
if (event.state() == IdleState.READER_IDLE) {
log.debug("READER_IDLE...");
} else if (event.state() == IdleState.WRITER_IDLE) {
log.debug("WRITER_IDLE...");
} else if (event.state() == IdleState.ALL_IDLE) {
// Close the channel once it reaches ALL_IDLE
Channel channel = ctx.channel();
// Clear cache
userChannelManager.remove(channel);
channel.close();
log.debug("close channel: {}", channel.id().asLongText());
}
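} else {
// Not an idle event: pass it on to the next handler instead of swallowing it
super.userEventTriggered(ctx, evt);
}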
}
}
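- If the channel sees no reads or writes for longer than the configured time (set on IdleStateHandler, covered below), the connection is closed; READER_IDLE and WRITER_IDLE events are only logged.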
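IdleStateHandler takes three thresholds (reader, writer, all) and a time unit, and 0 disables a check. To make the three conditions concrete, here is a JDK-only model of when each holds at a given moment; the Idle enum only copies Netty's IdleState names, and IdleModel is hypothetical. It is a simplification: Netty fires each condition as an event (once, then again every further period), while this model only answers "is it idle right now".

```java
import java.util.EnumSet;
import java.util.Set;

// Mirrors the names of io.netty.handler.timeout.IdleState, redeclared to stay JDK-only.
enum Idle { READER_IDLE, WRITER_IDLE, ALL_IDLE }

final class IdleModel {
    private IdleModel() {
    }

    /**
     * Which idle conditions hold at time 'now' (all values in the same unit).
     * A threshold of 0 disables that check, as with IdleStateHandler.
     */
    static Set<Idle> idleStates(long now, long lastRead, long lastWrite,
                                long readerIdle, long writerIdle, long allIdle) {
        Set<Idle> result = EnumSet.noneOf(Idle.class);
        if (readerIdle > 0 && now - lastRead >= readerIdle) {
            result.add(Idle.READER_IDLE);
        }
        if (writerIdle > 0 && now - lastWrite >= writerIdle) {
            result.add(Idle.WRITER_IDLE);
        }
        // ALL_IDLE: neither a read nor a write within allIdle
        if (allIdle > 0 && now - Math.max(lastRead, lastWrite) >= allIdle) {
            result.add(Idle.ALL_IDLE);
        }
        return result;
    }
}
```

With the pipeline's IdleStateHandler(10, 10, 30, TimeUnit.MINUTES), a client whose keepalive frames keep arriving never reaches ALL_IDLE, because every read restarts the all-idle clock; idleStates(30, 25, 0, 10, 10, 30) yields only WRITER_IDLE.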
IP rate-limiting handler
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import run.ut.app.service.RedisService;
import io.netty.util.ReferenceCountUtil;
import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;
/**
* @author wenjie
*/
@Slf4j
@ChannelHandler.Sharable
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class WebSocketRateLimitHandler extends ChannelInboundHandlerAdapter {
private final RedisService redisService;
private static final int EXPIRE_TIME = 5;
private static final int MAX = 10;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
InetSocketAddress insocket = (InetSocketAddress) ctx.channel().remoteAddress();
String ip = insocket.getAddress().getHostAddress();
// Generates key
String key = String.format("ut_wss_limit_rate_%s", ip);
// Checks
boolean over = redisService.overRequestRateLimit(key, MAX, EXPIRE_TIME, TimeUnit.SECONDS, "websocket");
if (over) {
log.debug("IP: {} hit the rate limit", ip);
// The message is dropped instead of passed on, so release it
ReferenceCountUtil.release(msg);
ctx.channel().close();
return;
}
ctx.fireChannelRead(msg);
}
}
- The rate-limiting code was covered in one of my earlier posts, so I won't repeat it here.
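The Redis-backed overRequestRateLimit isn't reproduced here. Its parameters (at most MAX = 10 per EXPIRE_TIME = 5 seconds per IP key) suggest a fixed-window counter, which in Redis is typically INCR plus EXPIRE on the key; that reading is an assumption. Purely to illustrate the idea, below is an in-memory fixed-window limiter; it is a single-JVM stand-in, not the author's Redis code, the class name is hypothetical, and time is passed in so the logic is testable:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// In-memory fixed-window counter (hypothetical; single JVM only, unlike a Redis version).
final class FixedWindowLimiter {
    private static final class Window {
        final long start;
        int count; // only mutated inside ConcurrentHashMap.compute, which is atomic per key

        Window(long start) {
            this.start = start;
        }
    }

    private final int max;
    private final long windowMillis;
    private final Map<String, Window> windows = new ConcurrentHashMap<>();

    FixedWindowLimiter(int max, long windowMillis) {
        this.max = max;
        this.windowMillis = windowMillis;
    }

    /** Records one request for key at nowMillis; returns true if the limit is exceeded. */
    boolean overLimit(String key, long nowMillis) {
        int[] countAfter = new int[1];
        windows.compute(key, (k, old) -> {
            // Start a new window if there is none or the current one has expired
            Window cur = (old == null || nowMillis - old.start >= windowMillis) ? new Window(nowMillis) : old;
            cur.count++;
            countAfter[0] = cur.count;
            return cur;
        });
        return countAfter[0] > max;
    }
}
```

Unlike Redis EXPIRE, this sketch never evicts old keys. Also note where the real handler sits: it is added first in the pipeline, before HttpServerCodec, so it counts raw socket reads (ByteBufs), including WebSocket traffic after the handshake, not logical messages.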
Container class managing the user-to-Channel mapping
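Since a user may be online on several devices at once, the map is a ConcurrentHashMap<Long, Set<Channel>>; if your business allows only one device online at a time, a ConcurrentHashMap<Long, Channel> is enough.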
import com.fasterxml.jackson.core.JsonProcessingException;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;
import run.ut.app.model.enums.WebSocketMsgTypeEnum;
import run.ut.app.model.support.WebSocketMsg;
import run.ut.app.utils.JsonUtils;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
* It is used to manage the mapping between user and channel.
*
* @author wenjie
*/
@Slf4j
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class UserChannelManager {
private final ConcurrentHashMap<Long, Set<Channel>> userChannelMap = new ConcurrentHashMap<>(1 << 8);
/**
* Save the mapping of uid and channel
*
* @param uid uid
* @param channel channel
*/
public void add(@NonNull Long uid, @NonNull Channel channel) {
// computeIfAbsent is atomic per key, and the per-user set must be a concurrent set because
// remove(Channel) and writeAndFlush() access it from other threads without a lock
userChannelMap.computeIfAbsent(uid, k -> ConcurrentHashMap.newKeySet()).add(channel);
}
/**
* Remove the element by uid
*
* @param uid uid
*/
public void remove(@NonNull Long uid) {
userChannelMap.remove(uid);
}
/**
* Remove the cache by channel
*
* @param channel channel
*/
public void remove(@NonNull Channel channel) {
userChannelMap.entrySet().stream().filter(entry -> entry.getValue().contains(channel))
.forEach(entry -> entry.getValue().remove(channel));
}
/**
* Get channel by uid
* @param uid uid
* @return channel
*/
@Nullable
public Set<Channel> get(@NonNull Long uid) {
return userChannelMap.get(uid);
}
/**
* Clear cache
*/
public void clearAll() {
userChannelMap.clear();
}
/**
* Write and flush by uid
*
* @param uid uid
* @param msgObj msg object, it will be automatically converted to json.
* @throws JsonProcessingException If msgObj fails to convert to json.
*/
public void writeAndFlush(@NonNull Long uid, @NonNull Object msgObj, @NonNull WebSocketMsgTypeEnum typeEnum) throws JsonProcessingException {
Set<Channel> channelSet = userChannelMap.get(uid);
if (ObjectUtils.isEmpty(channelSet) || channelSet.size() == 0) {
return;
}
for (Channel channel : channelSet) {
if (channel.isActive()) {
WebSocketMsg webSocketMsg = new WebSocketMsg()
.setType(typeEnum.getType())
.setMsg(msgObj);
String json = JsonUtils.objectToJson(webSocketMsg);
TextWebSocketFrame textWebSocketFrame = new TextWebSocketFrame(json);
ChannelFuture channelFuture = channel.writeAndFlush(textWebSocketFrame);
channelFuture.addListener((ChannelFutureListener)future -> {
log.debug("Sent WebSocket message to uid: {}, message: {}", uid, json);
});
}
}
}
/**
* Write and flush to every user
* @param msgObj msg object, it will be automatically converted to json.
* @throws JsonProcessingException If msgObj fails to convert to json.
*/
public void writeAndFlush(@NonNull Object msgObj, @NonNull WebSocketMsgTypeEnum typeEnum) throws JsonProcessingException {
WebSocketMsg webSocketMsg = new WebSocketMsg()
.setType(typeEnum.getType())
.setMsg(msgObj);
String json = JsonUtils.objectToJson(webSocketMsg);
userChannelMap.forEach((uid, channels) -> {
for (Channel channel : channels) {
if (channel.isActive()) {
// A frame's buffer is released once it has been written, so every channel needs its own frame
ChannelFuture channelFuture = channel.writeAndFlush(new TextWebSocketFrame(json));
channelFuture.addListener((ChannelFutureListener)future -> {
log.debug("Sent WebSocket message to uid: {}, message: {}", uid, json);
});
}
}
});
}
}
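Thread safety of the multi-device mapping comes down to how each user's set is mutated. A JDK-only sketch of the same uid -> Set<Channel> idea follows, generic over the channel type so it runs without Netty; the class name is hypothetical. Both add and remove change a user's set only inside compute/computeIfPresent, which ConcurrentHashMap runs atomically per key, so no external lock is needed, and users left with no channels are dropped from the map:

```java
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Generic over the channel type so it runs without Netty; C would be io.netty.channel.Channel.
final class UserChannelRegistry<C> {
    private final Map<Long, Set<C>> userChannels = new ConcurrentHashMap<>();

    /** Adds a channel for a user; the mutation happens inside compute, atomically per uid. */
    void add(long uid, C channel) {
        userChannels.compute(uid, (k, set) -> {
            Set<C> s = (set != null) ? set : ConcurrentHashMap.<C>newKeySet();
            s.add(channel);
            return s;
        });
    }

    /** Removes the channel from whichever user owns it and drops users left with no channels. */
    void remove(C channel) {
        for (Long uid : userChannels.keySet()) {
            userChannels.computeIfPresent(uid, (k, set) -> {
                set.remove(channel);
                return set.isEmpty() ? null : set; // returning null removes the mapping
            });
        }
    }

    /** Read-only view of a user's channels; empty if the user is offline. */
    Set<C> get(long uid) {
        Set<C> set = userChannels.get(uid);
        return set == null ? Collections.emptySet() : Collections.unmodifiableSet(set);
    }

    int onlineUsers() {
        return userChannels.size();
    }
}
```

Returning an empty set instead of null also lets callers such as a broadcast loop skip the null check.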
Configuration class
package run.ut.app.config.netty;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
/**
* @author wenjie
*/
@Component
@ConfigurationProperties(prefix = "netty.websocket")
@Data
public class WebSocketConfiguration {
private int port;
private String contextPath = "/ws";
/**
* 0 for automatic setting (the default is CPU * 2)
*/
private int workerThreads = 0;
/**
* 0 for automatic setting (The default is CPU * 2)
*/
private int bossThreads = 0;
/**
* Only in Linux environments can this be set to true
* @see <a href="https://stackoverflow.com/questions/35568365/netty-epolleventloopgroup-vs-nioeventloopgroup-which-should-i-choose-on-centos">link</a>
*/
private boolean epoll = false;
}
A sample configuration:
##########################################################
##
## WebSocket configuration
##
##########################################################
netty:
  websocket:
    # Server port
    port: 8088
    context-path: /ut/ws
    # 0 means automatic; the automatic value is CPU cores * 2
    boss-threads: 0
    # 0 means automatic; the automatic value is CPU cores * 2
    worker-threads: 0
    # Can only be set to true on Linux (native epoll transport), where it can give better performance
    epoll: false
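How the 0 setting resolves: when an event loop group is constructed with 0 threads, Netty falls back to its default of CPU cores * 2 (at least 1), which can also be overridden with the io.netty.eventLoopThreads system property. A small JDK-only sketch of that resolution; the helper name is hypothetical and the system property is deliberately left out:

```java
// Hypothetical helper reproducing how "0" is resolved, as described in the config comments.
final class EventLoopThreads {
    private EventLoopThreads() {
    }

    /**
     * 0 means "use the default" (CPU cores * 2, at least 1). Netty also honours the
     * io.netty.eventLoopThreads system property, which this sketch leaves out.
     */
    static int effective(int configured, int availableProcessors) {
        if (configured < 0) {
            throw new IllegalArgumentException("threads must be >= 0: " + configured);
        }
        return configured > 0 ? configured : Math.max(1, availableProcessors * 2);
    }
}
```

For example, EventLoopThreads.effective(0, Runtime.getRuntime().availableProcessors()) gives the worker count used on the current machine. For the boss group the default is usually more than needed: a single bound port is served by one event loop, so boss-threads: 1 is typically enough and the extra boss threads would stay idle.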
Channel initializer
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import run.ut.app.config.netty.WebSocketConfiguration;
import run.ut.app.utils.SpringUtils;
import java.util.concurrent.TimeUnit;
/**
* Init Channel
*
* @author wenjie
*/
public class WSServerInitialzer extends ChannelInitializer<SocketChannel> {
private AuthHandler authHandler;
private ClientMsgHandler clientMsgHandler;
private HeartBeatHandler heartBeatHandler;
private WebSocketRateLimitHandler webSocketRateLimitHandler;
WSServerInitialzer() {
authHandler = SpringUtils.getBean(AuthHandler.class);
clientMsgHandler = SpringUtils.getBean(ClientMsgHandler.class);
heartBeatHandler = SpringUtils.getBean(HeartBeatHandler.class);
webSocketRateLimitHandler = SpringUtils.getBean(WebSocketRateLimitHandler.class);
}
@Override
protected void initChannel(SocketChannel ch) throws Exception {
WebSocketConfiguration webSocketConfiguration = SpringUtils.getBean(WebSocketConfiguration.class);
ChannelPipeline pipeline = ch.pipeline();
pipeline
.addLast(webSocketRateLimitHandler)
.addLast(new HttpServerCodec())
.addLast(new ChunkedWriteHandler())
.addLast(new HttpObjectAggregator(1024 * 64))
.addLast(authHandler)
// .addLast(new IdleStateHandler(30, 30, 5, TimeUnit.SECONDS)) // test
.addLast(new IdleStateHandler(10, 10, 30, TimeUnit.MINUTES))
.addLast(heartBeatHandler)
.addLast(new WebSocketServerProtocolHandler(webSocketConfiguration.getContextPath()))
.addLast(clientMsgHandler);
}
}
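The order of the addLast calls is the order inbound data flows through: the rate limiter sees raw bytes first, the HTTP codec and aggregator turn the handshake into a FullHttpRequest, AuthHandler checks it and removes itself, and then WebSocketServerProtocolHandler completes the handshake. The self-removal step can be modelled without Netty. In this toy pipeline (hypothetical types, no real I/O, and a simpler propagation model than Netty's contexts) a handler may remove itself while handling a message, so the first message must carry a valid token and later ones skip the check:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

// Toy inbound handler: returns the message to pass on, or null to stop propagation.
interface ToyHandler {
    String handle(ToyPipeline pipeline, String msg);
}

final class ToyPipeline {
    // Copy-on-write so a handler can remove itself while a message is being processed
    private final List<ToyHandler> handlers = new CopyOnWriteArrayList<>();
    final List<String> delivered = new ArrayList<>();

    ToyPipeline addLast(ToyHandler h) {
        handlers.add(h);
        return this;
    }

    void remove(ToyHandler h) {
        handlers.remove(h);
    }

    int size() {
        return handlers.size();
    }

    void fireRead(String msg) {
        String current = msg;
        for (ToyHandler h : handlers) { // iterates a snapshot of the list
            current = h.handle(this, current);
            if (current == null) {
                return;
            }
        }
        delivered.add(current);
    }
}

// Auth step: accept the first message only if it carries the token, then step aside.
final class ToyAuth implements ToyHandler {
    private final String expectedToken;

    ToyAuth(String expectedToken) {
        this.expectedToken = expectedToken;
    }

    @Override
    public String handle(ToyPipeline pipeline, String msg) {
        if (!msg.startsWith("token=" + expectedToken + ";")) {
            return null; // in Netty this is where the channel would be closed
        }
        pipeline.remove(this); // analogous to ctx.pipeline().remove(this)
        return msg.substring(msg.indexOf(';') + 1);
    }
}
```

For instance, with new ToyAuth("abc") followed by an upper-casing handler, firing "token=abc;hello" and then "world" delivers HELLO and WORLD, and the pipeline is left with one handler.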
Initialization & startup class
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import run.ut.app.config.netty.WebSocketConfiguration;
import run.ut.app.exception.WebSocketException;
/**
* WebSocket Server
*
* @author wenjie
*/
@Component
@Slf4j
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class WebSocketServer {
private final WebSocketConfiguration webSocketConfiguration;
private ServerBootstrap server;
public void init() {
EventLoopGroup bossGroup, workerGroup;
server = new ServerBootstrap();
int bossThreads = webSocketConfiguration.getBossThreads();
int workerThreads = webSocketConfiguration.getWorkerThreads();
boolean epoll = webSocketConfiguration.isEpoll();
if (epoll) {
bossGroup = new EpollEventLoopGroup(bossThreads,
new DefaultThreadFactory("WebSocketBossGroup", true));
workerGroup = new EpollEventLoopGroup(workerThreads,
new DefaultThreadFactory("WebSocketWorkerGroup", true));
server.channel(EpollServerSocketChannel.class);
} else {
bossGroup = new NioEventLoopGroup(bossThreads);
workerGroup = new NioEventLoopGroup(workerThreads);
server.channel(NioServerSocketChannel.class);
}
server.group(bossGroup, workerGroup)
.childHandler(new WSServerInitialzer())
.childOption(ChannelOption.TCP_NODELAY, true)
.childOption(ChannelOption.SO_KEEPALIVE, true);
}
public void start() throws Exception {
log.info("WebSocketServer - Starting...");
// sync() waits for the bind and rethrows the failure cause if it fails, so the error reaches
// the caller; an exception thrown inside a listener would only be logged by Netty
server.bind(webSocketConfiguration.getPort()).sync();
log.info("WebSocketServer - Start completed.");
}
}
Starting the server
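Once Spring Boot has started, the WebSocket server is started when the listener receives the ContextRefreshedEvent (by which point the IoC container has finished initializing).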
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import run.ut.app.netty.WebSocketServer;
/**
* The method executed after the application context is refreshed.
*
* @author wenjie
*/
@Configuration
@Order(Ordered.HIGHEST_PRECEDENCE + 1)
@Slf4j
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class ContextRefreshedEventListener implements ApplicationListener<ContextRefreshedEvent> {
private final WebSocketServer webSocketServer;
private final ConfigurableApplicationContext context;
@Override
public void onApplicationEvent(ContextRefreshedEvent event) {
try {
webSocketServerBoot();
} catch (Exception e) {
log.error("Failed to start the WebSocket server", e);
context.close();
System.exit(-1);
}
}
private void webSocketServerBoot() throws Exception {
webSocketServer.init();
webSocketServer.start();
}
}
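One caveat with starting from ContextRefreshedEvent: the event is published on every refresh, and events published in a child context are also delivered to listeners in its parent, so in setups with context hierarchies this listener can run more than once and the second bind would fail. A small guard avoids that. Below is a JDK-only sketch with a hypothetical StartOnce helper; its Action stands in for webSocketServerBoot():

```java
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical helper: runs a start-up action at most once, even if triggered repeatedly.
final class StartOnce {
    @FunctionalInterface
    interface Action {
        void run() throws Exception;
    }

    private final AtomicBoolean started = new AtomicBoolean(false);

    /** Returns true if this call performed the start, false if it had already been done. */
    boolean start(Action action) throws Exception {
        // compareAndSet makes the check-and-mark atomic across threads
        if (!started.compareAndSet(false, true)) {
            return false;
        }
        try {
            action.run();
            return true;
        } catch (Exception e) {
            started.set(false); // allow a retry if start-up failed
            throw e;
        }
    }
}
```

In the listener this would read startOnce.start(this::webSocketServerBoot) inside the existing try block. An alternative some projects use is to ignore events whose getApplicationContext() is not the injected context.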
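Finally, if you want to test it and see the integration in action, run the UT project's mini-program front end locally: the WebSocket connection code is already there, and the mini-program console prints the test packets the server sends over WebSocket.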
If you have questions about the code, or spot anything wrong, feel free to leave a comment.