Preface

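I wanted to write this post back when I was working on the UT project, but one distraction after another kept pushing it back until now. Almost none of the many posts online about integrating Netty with Spring Boot, or building a WebSocket server on Netty, explain how to authenticate connections. Fortunately I have a fairly solid grounding in Netty, so solving connection authentication did not take me long.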

If you want to learn how to build a WebSocket service with Spring Boot and Netty, and you are stuck on how to authenticate WebSocket connections, this post should suit you well.

PS: This post focuses on building the service; theory is only touched on briefly.

Source code

For lack of time I did not write a separate demo; the complete code is integrated into the project: https://gitee.com/wenjie2018/UT-APP

The core code is all under the run.ut.app.netty package; have a look there.

Dependency version:

        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.42.Final</version>
        </dependency>

Core code

Utility class

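Because of the way I integrated Netty, the Netty server is started after the application context has been refreshed, and the channel initializer (WSServerInitialzer) is created with new rather than managed by Spring. To get beans from the container, it therefore needs a utility class that looks them up.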

SpringUtils.java


import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationContextException;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;

/**
 * Spring utilities
 *
 * @author wenjie
 */
@Component
public class SpringUtils implements ApplicationContextAware {

    private static ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        if (ObjectUtils.isEmpty(applicationContext)) {
            throw new ApplicationContextException("applicationContext must not be null");
        }
        SpringUtils.applicationContext = applicationContext;
    }


    /**
     * Gets ApplicationContext
     * @return ApplicationContext
     */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    /**
     * Gets bean by bean's name
     * @param name    bean's name
     * @return bean
     */
    public static Object getBean(String name) {
        return getApplicationContext().getBean(name);
    }


    /**
     * Gets bean by bean's class(java.lang.Class)
     * @param clazz    bean Class
     * @param <T>      Class type
     * @return bean
     */
    public static <T> T getBean(Class<T> clazz) {
        return getApplicationContext().getBean(clazz);
    }


    /**
     * Gets bean by bean's class and bean's name
     * @param name    bean name
     * @param clazz   bean Class
     * @param <T>     Class type
     * @return bean
     */
    public static <T> T getBean(String name, Class<T> clazz) {
        return getApplicationContext().getBean(name, clazz);
    }
}

Handlers

Connection authentication handler


import io.jsonwebtoken.Claims;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.util.ReferenceCountUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import run.ut.app.security.util.JwtOperator;

/**
 * Authenticates the WebSocket handshake request (a FullHttpRequest) using its "token" header.
 */

@ChannelHandler.Sharable
@Slf4j
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class AuthHandler extends ChannelInboundHandlerAdapter {

    private final UserChannelManager userChannelManager;
    private final JwtOperator jwtOperator;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof FullHttpRequest)) {
            // Only the handshake request is expected before authentication
            reject(ctx, msg);
            return;
        }
        String token = ((FullHttpRequest) msg).headers().get("token");
        if (!StringUtils.hasText(token)) {
            reject(ctx, msg);
            return;
        }
        Long uid;
        try {
            Claims claims = jwtOperator.getClaimsFromToken(token);
            uid = Long.valueOf(String.valueOf(claims.get("uid")));
        } catch (Exception e) {
            // Invalid or expired token, or a missing / non-numeric uid claim
            log.debug("Authentication failed: {}", e.getMessage());
            reject(ctx, msg);
            return;
        }
        userChannelManager.add(uid, ctx.channel());
        log.debug("Authentication success. uid: {}", uid);
        // Authenticate once per connection: remove this handler from the pipeline
        ctx.pipeline().remove(this);
        // Propagate the request until WebSocketServerProtocolHandler completes the handshake
        ctx.fireChannelRead(msg);
    }

    /**
     * The message is not passed on, so release it here to avoid leaking its buffer.
     */
    private void reject(ChannelHandlerContext ctx, Object msg) {
        ReferenceCountUtil.release(msg);
        ctx.channel().close();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.error("Exception in AuthHandler", cause);
        userChannelManager.remove(ctx.channel());
        ctx.channel().close();
    }
}

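  • Project-specific business code (such as JwtOperator) is not reproduced here; if you are interested, see the full project source.
  • Also note the line ctx.pipeline().remove(this): once a connection is authenticated, the handler removes itself, so the check runs only on the handshake. If you also need to authenticate every message, deleting this line is not enough on its own: after the handshake, the messages reaching this handler are WebSocket frames rather than FullHttpRequest objects, and the non-FullHttpRequest branch closes the channel, so channelRead would also have to validate those frames and pass them on.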
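The decision AuthHandler makes on the handshake can be pulled out into a plain method, which is easy to unit-test without a running server. The sketch below is JDK-only and rests on assumptions: `HandshakeAuth` is a hypothetical name, and the `uidClaimOf` function stands in for `jwtOperator.getClaimsFromToken(token).get("uid")`, returning null for an invalid token. It rejects a missing token, an invalid token, or a non-numeric uid instead of letting an exception escape.

```java
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;

/**
 * Hypothetical, JDK-only sketch of the check AuthHandler runs on the handshake request.
 * uidClaimOf stands in for JwtOperator: it returns the "uid" claim, or null if the token is invalid.
 */
final class HandshakeAuth {

    private HandshakeAuth() {
    }

    /** Returns the authenticated uid, or empty if the connection should be closed. */
    static Optional<Long> authenticate(Map<String, String> headers,
                                       Function<String, Object> uidClaimOf) {
        // Netty's HttpHeaders lookups are case-insensitive; mimic that with a case-insensitive map
        Map<String, String> caseInsensitive = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        caseInsensitive.putAll(headers);
        String token = caseInsensitive.get("token");
        if (token == null || token.isEmpty()) {
            return Optional.empty(); // no token: reject up front
        }
        Object uid = uidClaimOf.apply(token);
        if (uid == null) {
            return Optional.empty(); // invalid token or missing uid claim
        }
        try {
            return Optional.of(Long.valueOf(uid.toString()));
        } catch (NumberFormatException e) {
            return Optional.empty(); // uid claim is not a number
        }
    }
}
```

In the handler, an empty result corresponds to releasing the request and closing the channel, and a present uid to registering the channel and firing the request onward.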

Handler for WebSocket messages


import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import run.ut.app.model.enums.WebSocketMsgTypeEnum;
import run.ut.app.model.support.WebSocketMsg;
import run.ut.app.utils.JsonUtils;


/**
 * @author wenjie
 */

@ChannelHandler.Sharable
@Slf4j
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class ClientMsgHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {

    private final UserChannelManager userChannelManager;

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg)
        throws Exception {
        String json = msg.text();
        WebSocketMsg webSocketMsg = JsonUtils.jsonToObject(json, WebSocketMsg.class);
        WebSocketMsgTypeEnum type = WebSocketMsgTypeEnum.getByType(webSocketMsg.getType());
        if (type == null) {
            // A switch on a null enum would throw a NullPointerException
            log.debug("Unknown message type: {}", webSocketMsg.getType());
            return;
        }
        switch (type) {
            case KEEPALIVE:
                log.debug("Get keepalive frame");
                break;
            default:
                log.debug("Unhandled message type: {}", type);
                break;
        }
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        userChannelManager.remove(ctx.channel());
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.error("Exception while handling a WebSocket message", cause);
        userChannelManager.remove(ctx.channel());
        ctx.channel().close();
    }
}

  • For now only keepalive (heartbeat) messages are handled; add more message handling as your business requires.
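As more message types are added, a table keyed by message type keeps `channelRead0` from growing into a long switch. This is a JDK-only sketch with hypothetical names: `MsgType` (including the invented `CHAT` constant) stands in for WebSocketMsgTypeEnum, and handlers receive the raw payload string rather than a WebSocketMsg.

```java
import java.util.EnumMap;
import java.util.Map;
import java.util.function.Consumer;

/** Hypothetical message types; only KEEPALIVE appears in the original handler. */
enum MsgType { KEEPALIVE, CHAT }

/** JDK-only sketch: a per-type handler table instead of a growing switch statement. */
final class MsgDispatcher {

    private final Map<MsgType, Consumer<String>> handlers = new EnumMap<>(MsgType.class);

    /** Registers the handler for one message type and returns this for chaining. */
    MsgDispatcher on(MsgType type, Consumer<String> handler) {
        handlers.put(type, handler);
        return this;
    }

    /** Runs the matching handler; returns false for a null type or a type with no handler. */
    boolean dispatch(MsgType type, String payload) {
        if (type == null) {
            return false; // e.g. the type lookup found no matching enum constant
        }
        Consumer<String> handler = handlers.get(type);
        if (handler == null) {
            return false;
        }
        handler.accept(payload);
        return true;
    }
}
```

In a handler, the dispatcher would be built once and `channelRead0` would only look up the type and call `dispatch`, so adding a message type means registering one more handler.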

Idle event handler


import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * Heart Beat Handler. If neither a read (such as the client's heartbeat frames) nor a write
 * happens on the channel within the configured all-idle time, the channel is closed.
 *
 * @author wenjie
 */

@ChannelHandler.Sharable
@Slf4j
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class HeartBeatHandler extends ChannelInboundHandlerAdapter {

    private final UserChannelManager userChannelManager;

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {

        if (evt instanceof IdleStateEvent) {
            IdleStateEvent event = (IdleStateEvent)evt;

            if (event.state() == IdleState.READER_IDLE) {
                log.debug("READER_IDLE...");
            } else if (event.state() == IdleState.WRITER_IDLE) {
                log.debug("WRITER_IDLE...");
            } else if (event.state() == IdleState.ALL_IDLE) {
                // Closes the channel which state is ALL_IDLE
                Channel channel = ctx.channel();
                // Clear cache
                userChannelManager.remove(channel);
                channel.close();
                log.debug("close channel: {}", channel.id().asLongText());
            }
        } else {
            // Not an idle event: pass it on so later handlers still receive it
            super.userEventTriggered(ctx, evt);
        }

    }

}
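  • If the channel sees neither a read nor a write for the all-idle time (set by IdleStateHandler, covered below), the connection is closed; reader-idle and writer-idle events are only logged.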
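To make the timing concrete, here is a JDK-only approximation of the three conditions IdleStateHandler checks, evaluated with the `(10, 10, 30, TimeUnit.MINUTES)` values the initializer uses. It is a simplification under stated assumptions: Netty tracks read and write activity itself and fires each event from a scheduled timer, whereas this sketch only reports which conditions hold at a given instant. `Idle` and `IdleCheck` are hypothetical names.

```java
import java.time.Duration;
import java.util.EnumSet;
import java.util.Set;

/** Mirrors the names of Netty's IdleState values; this sketch does not depend on Netty. */
enum Idle { READER_IDLE, WRITER_IDLE, ALL_IDLE }

/**
 * JDK-only approximation of IdleStateHandler(reader, writer, all): which idle conditions hold
 * at nowMillis. A zero duration disables that check, as 0 does in Netty.
 */
final class IdleCheck {

    private IdleCheck() {
    }

    static Set<Idle> idleStates(long nowMillis, long lastReadMillis, long lastWriteMillis,
                                Duration reader, Duration writer, Duration all) {
        Set<Idle> states = EnumSet.noneOf(Idle.class);
        if (reader.toMillis() > 0 && nowMillis - lastReadMillis >= reader.toMillis()) {
            states.add(Idle.READER_IDLE); // no read for the reader-idle period
        }
        if (writer.toMillis() > 0 && nowMillis - lastWriteMillis >= writer.toMillis()) {
            states.add(Idle.WRITER_IDLE); // no write for the writer-idle period
        }
        long lastActivity = Math.max(lastReadMillis, lastWriteMillis);
        if (all.toMillis() > 0 && nowMillis - lastActivity >= all.toMillis()) {
            states.add(Idle.ALL_IDLE); // neither read nor write for the all-idle period
        }
        return states;
    }
}
```

With these values, HeartBeatHandler closes a channel only after 30 minutes with no reads and no writes, so any client heartbeat sent more often than that keeps the connection open.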

Per-IP rate-limiting handler


import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import run.ut.app.service.RedisService;

import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;

/**
 * Per-IP rate limiting (the exact rule lives in RedisService.overRequestRateLimit).
 *
 * @author wenjie
 */

@Slf4j
@ChannelHandler.Sharable
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class WebSocketRateLimitHandler extends ChannelInboundHandlerAdapter {

    private static final int EXPIRE_TIME = 5;
    private static final int MAX = 10;

    private final RedisService redisService;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {

        InetSocketAddress insocket = (InetSocketAddress) ctx.channel().remoteAddress();
        String ip = insocket.getAddress().getHostAddress();
        // Generates key
        String key = String.format("ut_wss_limit_rate_%s", ip);
        // Checks
        boolean over = redisService.overRequestRateLimit(key, MAX, EXPIRE_TIME, TimeUnit.SECONDS, "websocket");
        if (over) {
            log.debug("IP: {} hit the rate limit", ip);
            // The message is dropped here, so release it to avoid a buffer leak
            ReferenceCountUtil.release(msg);
            ctx.channel().close();
            return;
        }
        ctx.fireChannelRead(msg);
    }
}

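  • The rate-limiting code (RedisService.overRequestRateLimit) was covered in an earlier post of mine, so I won't repeat it here. Because this handler sits first in the pipeline, before HttpServerCodec, it counts every raw read on the connection, not just the handshake.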
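Since RedisService.overRequestRateLimit is not shown here, its exact algorithm is unknown in this post. As an assumption, the sketch below reads the call as "more than MAX hits per key within EXPIRE_TIME seconds" and implements that as an in-memory fixed window. It only works inside a single JVM, whereas a Redis counter is shared by every server instance; `FixedWindowRateLimiter` is a hypothetical name, and the injected clock exists only to make it testable.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.LongSupplier;

/** Hypothetical in-memory fixed-window limiter: at most `max` hits per key per window. */
final class FixedWindowRateLimiter {

    private static final class Window {
        final long start;
        int count;

        Window(long start) {
            this.start = start;
        }
    }

    private final int max;
    private final long windowMillis;
    private final LongSupplier clock;
    private final Map<String, Window> windows = new ConcurrentHashMap<>();

    FixedWindowRateLimiter(int max, long windowMillis, LongSupplier clock) {
        this.max = max;
        this.windowMillis = windowMillis;
        this.clock = clock;
    }

    /** Records one hit for the key and returns true if the key is now over the limit. */
    boolean overLimit(String key) {
        long now = clock.getAsLong();
        int[] hits = new int[1];
        // compute() runs atomically per key, so the window reset and the increment cannot race
        windows.compute(key, (k, current) -> {
            Window window = (current == null || now - current.start >= windowMillis)
                ? new Window(now) : current;
            window.count++;
            hits[0] = window.count;
            return window;
        });
        return hits[0] > max;
    }
}
```

Unlike Redis keys with a TTL, entries here are never evicted, so a production in-memory version would also need periodic cleanup.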

Container class managing the user-to-Channel mapping

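To support a user being online on several devices at once, the map is a ConcurrentHashMap<Long, Set<Channel>>; if your business allows only one device per user, a ConcurrentHashMap<Long, Channel> is enough.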


import com.fasterxml.jackson.core.JsonProcessingException;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import run.ut.app.model.enums.WebSocketMsgTypeEnum;
import run.ut.app.model.support.WebSocketMsg;
import run.ut.app.utils.JsonUtils;

import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * It is used to manage the mapping between user and channel.
 *
 * @author wenjie
 */

@Slf4j
@Component
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class UserChannelManager {

    /**
     * uid -> that user's channels (one per connected device). The sets are concurrent too,
     * because they are read and modified from different event-loop threads.
     */
    private final ConcurrentHashMap<Long, Set<Channel>> userChannelMap = new ConcurrentHashMap<>(1 << 8);

    /**
     * Save the mapping of uid and channel
     *
     * @param uid        uid
     * @param channel    channel
     */
    public void add(@NonNull Long uid, @NonNull Channel channel) {
        // compute() is atomic per key, so no explicit lock is needed
        userChannelMap.compute(uid, (key, channels) -> {
            Set<Channel> set = channels != null ? channels : ConcurrentHashMap.newKeySet();
            set.add(channel);
            return set;
        });
    }

    /**
     * Remove the element by uid
     *
     * @param uid    uid
     */
    public void remove(@NonNull Long uid) {
        userChannelMap.remove(uid);
    }

    /**
     *  Remove the cache by channel
     *
     * @param channel    channel
     */
    public void remove(@NonNull Channel channel) {
        userChannelMap.forEach((uid, channels) -> {
            if (channels.remove(channel)) {
                // Drop the entry once the user has no channels left (re-checked atomically)
                userChannelMap.computeIfPresent(uid, (key, set) -> set.isEmpty() ? null : set);
            }
        });
    }

    /**
     * Get channel by uid
     * @param uid     uid
     * @return channel
     */
    @Nullable
    public Set<Channel> get(@NonNull Long uid) {
        return userChannelMap.get(uid);
    }

    /**
     * Clear cache
     */
    public void clearAll() {
        userChannelMap.clear();
    }


    /**
     * Write and flush by uid
     *
     * @param uid      uid
     * @param msgObj   msg object, it will be automatically converted to json.
     * @param typeEnum message type
     * @throws JsonProcessingException If msgObj fails to convert to json.
     */
    public void writeAndFlush(@NonNull Long uid, @NonNull Object msgObj, @NonNull WebSocketMsgTypeEnum typeEnum) throws JsonProcessingException {
        Set<Channel> channelSet = userChannelMap.get(uid);
        if (channelSet == null || channelSet.isEmpty()) {
            return;
        }
        // Serialize once, then send to each of the user's channels
        String json = toJson(msgObj, typeEnum);
        for (Channel channel : channelSet) {
            send(uid, channel, json);
        }
    }

    /**
     * Write and flush to every user
     * @param msgObj   msg object, it will be automatically converted to json.
     * @param typeEnum message type
     * @throws JsonProcessingException If msgObj fails to convert to json.
     */
    public void writeAndFlush(@NonNull Object msgObj, @NonNull WebSocketMsgTypeEnum typeEnum) throws JsonProcessingException {
        String json = toJson(msgObj, typeEnum);
        userChannelMap.forEach((uid, channels) -> {
            for (Channel channel : channels) {
                send(uid, channel, json);
            }
        });
    }

    private String toJson(Object msgObj, WebSocketMsgTypeEnum typeEnum) throws JsonProcessingException {
        WebSocketMsg webSocketMsg = new WebSocketMsg()
            .setType(typeEnum.getType())
            .setMsg(msgObj);
        return JsonUtils.objectToJson(webSocketMsg);
    }

    private void send(Long uid, Channel channel, String json) {
        if (!channel.isActive()) {
            return;
        }
        // A frame's buffer is released once it has been written, so one TextWebSocketFrame
        // must not be shared between channels: create a new frame for every write.
        ChannelFuture channelFuture = channel.writeAndFlush(new TextWebSocketFrame(json));
        channelFuture.addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                log.debug("Sent WebSocket message to uid: {}, msg: {}", uid, json);
            } else {
                log.debug("Failed to send WebSocket message to uid: {}", uid, future.cause());
            }
        });
    }
}
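For the single-device case (a ConcurrentHashMap<Long, Channel>), two details matter: a new login should replace the old channel so the caller can close it, and the late close of that old channel must not remove the new mapping. A JDK-only sketch, with a type parameter `C` standing in for Channel and the hypothetical name `SingleDeviceRegistry`:

```java
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * Hypothetical single-device variant of UserChannelManager.
 * C stands in for io.netty.channel.Channel so the sketch runs without Netty.
 */
final class SingleDeviceRegistry<C> {

    private final ConcurrentMap<Long, C> byUid = new ConcurrentHashMap<>();

    /** Binds uid to the new connection; returns the previous one, which the caller should close. */
    Optional<C> bind(long uid, C connection) {
        return Optional.ofNullable(byUid.put(uid, connection));
    }

    /** Unbinds only if uid still maps to this exact connection, so a stale close cannot evict a newer login. */
    boolean unbind(long uid, C connection) {
        return byUid.remove(uid, connection);
    }

    Optional<C> get(long uid) {
        return Optional.ofNullable(byUid.get(uid));
    }
}
```

`ConcurrentMap.remove(key, value)` removes the entry only if it still maps to that value, which is what turns the old connection's unbind into a no-op.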

Configuration class

package run.ut.app.config.netty;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
 * @author wenjie
 */

@Component
@ConfigurationProperties(prefix = "netty.websocket")
@Data
public class WebSocketConfiguration {

    private int port;

    private String contextPath = "/ws";

    /**
     * 0 for automatic setting(The default is CPU * 2)
     */
    private int workerThreads = 0;

    /**
     * 0 for automatic setting (The default is CPU * 2)
     */
    private int bossThreads = 0;

    /**
     * Can be set to true only on Linux (Netty's epoll transport is Linux-only)
     * @see <a href="https://stackoverflow.com/questions/35568365/netty-epolleventloopgroup-vs-nioeventloopgroup-which-should-i-choose-on-centos">link</a>
     */
    private boolean epoll = false;
}

A sample configuration:

##########################################################
##
## WebSocket configuration
##
##########################################################
netty:
  websocket:
    # Server port
    port: 8088
    # WebSocket handshake path
    context-path: /ut/ws
    # 0 means automatic: number of CPU cores * 2
    boss-threads: 0
    # 0 means automatic: number of CPU cores * 2
    worker-threads: 0
    # Can be true only on Linux (epoll transport), where it can give better performance
    epoll: false
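When a thread count is 0, Netty picks the default itself: MultithreadEventLoopGroup uses max(1, availableProcessors * 2), unless the io.netty.eventLoopThreads system property overrides it. The JDK-only sketch below reproduces that rule (ignoring the system property) so the effective values of the sample configuration are easy to see; `EventLoopThreads` is a hypothetical name.

```java
/**
 * Hypothetical JDK-only sketch of how a thread setting resolves: 0 means Netty's default of
 * max(1, availableProcessors * 2). The io.netty.eventLoopThreads override is ignored here.
 */
final class EventLoopThreads {

    private EventLoopThreads() {
    }

    static int resolve(int configured, int availableProcessors) {
        if (configured < 0) {
            // Netty also rejects a negative thread count
            throw new IllegalArgumentException("threads must be >= 0: " + configured);
        }
        return configured == 0 ? Math.max(1, availableProcessors * 2) : configured;
    }
}
```

For example, `EventLoopThreads.resolve(0, Runtime.getRuntime().availableProcessors())` gives the effective value for a setting of 0. A server channel is registered with a single event loop, so a boss group serving one port only ever uses one thread; setting boss-threads to 1 avoids creating threads that stay idle.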

Channel initializer


import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import run.ut.app.config.netty.WebSocketConfiguration;
import run.ut.app.utils.SpringUtils;

import java.util.concurrent.TimeUnit;

/**
 * Init Channel
 *
 * @author wenjie
 */
public class WSServerInitialzer extends ChannelInitializer<SocketChannel> {

    private AuthHandler authHandler;
    private ClientMsgHandler clientMsgHandler;
    private HeartBeatHandler heartBeatHandler;
    private WebSocketRateLimitHandler webSocketRateLimitHandler;

    WSServerInitialzer() {
        authHandler = SpringUtils.getBean(AuthHandler.class);
        clientMsgHandler = SpringUtils.getBean(ClientMsgHandler.class);
        heartBeatHandler = SpringUtils.getBean(HeartBeatHandler.class);
        webSocketRateLimitHandler = SpringUtils.getBean(WebSocketRateLimitHandler.class);
    }

    @Override
    protected void initChannel(SocketChannel ch) throws Exception {

        WebSocketConfiguration webSocketConfiguration = SpringUtils.getBean(WebSocketConfiguration.class);

        ChannelPipeline pipeline = ch.pipeline();

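        pipeline
            // Per-IP rate limiting on raw inbound data
            .addLast(webSocketRateLimitHandler)
            // HTTP codec for the handshake; replaced by WebSocket frame codecs after the upgrade
            .addLast(new HttpServerCodec())
            .addLast(new ChunkedWriteHandler())
            // Aggregates the handshake into a FullHttpRequest (at most 64 KB)
            .addLast(new HttpObjectAggregator(1024 * 64))
            // Checks the token on the handshake, then removes itself
            .addLast(authHandler)
//            .addLast(new IdleStateHandler(30, 30, 5, TimeUnit.SECONDS)) // test
            // readerIdle = 10 min, writerIdle = 10 min, allIdle = 30 min
            .addLast(new IdleStateHandler(10, 10, 30, TimeUnit.MINUTES))
            // Closes the channel on ALL_IDLE
            .addLast(heartBeatHandler)
            // Handshake on the configured path, plus ping/pong and close frames
            .addLast(new WebSocketServerProtocolHandler(webSocketConfiguration.getContextPath()))
            // Business messages (TextWebSocketFrame)
            .addLast(clientMsgHandler);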
    }

}

Initialization & startup class


import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import run.ut.app.config.netty.WebSocketConfiguration;

/**
 * WebSocket Server
 *
 * @author wenjie
 */

@Component
@Slf4j
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class WebSocketServer {

    private final WebSocketConfiguration webSocketConfiguration;

    private ServerBootstrap server;

    public void init() {
        EventLoopGroup bossGroup, workerGroup;
        server = new ServerBootstrap();

        int bossThreads = webSocketConfiguration.getBossThreads();
        int workerThreads = webSocketConfiguration.getWorkerThreads();
        boolean epoll = webSocketConfiguration.isEpoll();

        if (epoll) {
            bossGroup = new EpollEventLoopGroup(bossThreads,
                new DefaultThreadFactory("WebSocketBossGroup", true));
            workerGroup = new EpollEventLoopGroup(workerThreads,
                new DefaultThreadFactory("WebSocketWorkerGroup", true));
            server.channel(EpollServerSocketChannel.class);
        } else {
            bossGroup = new NioEventLoopGroup(bossThreads);
            workerGroup = new NioEventLoopGroup(workerThreads);
            server.channel(NioServerSocketChannel.class);
        }

        server.group(bossGroup, workerGroup)
            .childHandler(new WSServerInitialzer())
            .childOption(ChannelOption.TCP_NODELAY, true)
            .childOption(ChannelOption.SO_KEEPALIVE, true);
    }

    public void start() throws Exception {
        log.info("WebSocketServer - Starting...");
        // sync() waits for the bind to finish and rethrows its failure cause (e.g. port in use).
        // A listener added after sync() would only ever see a successful future, and an
        // exception thrown inside a listener is merely logged by Netty, not propagated.
        server.bind(webSocketConfiguration.getPort()).sync();
        log.info("WebSocketServer - Start completed.");
    }
}

Starting the service

Once Spring Boot has started, the WebSocket server is started when the ContextRefreshedEvent is received (by then the IoC container has finished initializing).


import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import run.ut.app.netty.WebSocketServer;

import java.util.concurrent.atomic.AtomicBoolean;

/**
 * The method executed after the application context is refreshed.
 *
 * @author wenjie
 */

@Configuration
@Order(Ordered.HIGHEST_PRECEDENCE + 1)
@Slf4j
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class ContextRefreshedEventListener implements ApplicationListener<ContextRefreshedEvent> {

    private final WebSocketServer webSocketServer;
    private final ConfigurableApplicationContext context;

    /**
     * ContextRefreshedEvent can be published more than once (e.g. on a re-refresh, or by child
     * contexts), but the port can only be bound once.
     */
    private final AtomicBoolean started = new AtomicBoolean(false);

    @Override
    public void onApplicationEvent(ContextRefreshedEvent event) {
        if (!started.compareAndSet(false, true)) {
            return;
        }
        try {
            webSocketServerBoot();
        } catch (Exception e) {
            log.error("Failed to start the WebSocket server", e);
            context.close();
            System.exit(-1);
        }
    }

    private void webSocketServerBoot() throws Exception {
        webSocketServer.init();
        webSocketServer.start();
    }
}

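Finally, if you want to test the integration and see it in action, you can run the mini program front end of my UT project locally. It already contains the WebSocket connection code, and the mini program console prints the test packets the server sends over WebSocket.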
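If you would rather test from the JVM, below is a minimal Java 11+ client sketch using the JDK's java.net.http.WebSocket. It sends the token in a `token` handshake header (the header AuthHandler reads), sends one keepalive message, and prints whatever text frames the server pushes. The URL follows the sample configuration; `WsTestClient`, `KEEPALIVE_TYPE`, and the `type`/`msg` JSON field names are assumptions, since WebSocketMsg and WebSocketMsgTypeEnum are not shown in this post.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;

/** Hypothetical Java 11+ test client for the server described above. */
final class WsTestClient {

    /** Placeholder: replace with the value WebSocketMsgTypeEnum.KEEPALIVE actually uses. */
    static final int KEEPALIVE_TYPE = 0;

    /** Keepalive body in the assumed {"type": ..., "msg": ...} shape of WebSocketMsg. */
    static String keepaliveJson() {
        return "{\"type\":" + KEEPALIVE_TYPE + ",\"msg\":null}";
    }

    static CompletableFuture<WebSocket> connect(String url, String token) {
        WebSocket.Listener listener = new WebSocket.Listener() {
            @Override
            public CompletionStage<?> onText(WebSocket webSocket, CharSequence data, boolean last) {
                System.out.println("server -> " + data);
                webSocket.request(1); // ask for the next message
                return null;
            }
        };
        return HttpClient.newHttpClient()
            .newWebSocketBuilder()
            .header("token", token) // read by AuthHandler during the handshake
            .buildAsync(URI.create(url), listener);
    }

    public static void main(String[] args) throws InterruptedException {
        String token = args.length > 0 ? args[0] : "<your-jwt>";
        WebSocket webSocket = connect("ws://localhost:8088/ut/ws", token).join();
        webSocket.sendText(keepaliveJson(), true).join();
        Thread.sleep(3_000); // give the server time to push test messages
        webSocket.sendClose(WebSocket.NORMAL_CLOSURE, "bye").join();
    }
}
```

With a missing or invalid token, AuthHandler closes the connection before the upgrade, so the `join()` on `connect` fails instead of returning a WebSocket.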

If you have questions about the code, or spot anything wrong, feel free to leave a comment.