42
42
import io .netty .handler .codec .http .HttpVersion ;
43
43
import io .netty .handler .codec .http .websocketx .WebSocketClientHandshakeException ;
44
44
import io .netty .handler .codec .http .websocketx .WebSocketDecoderConfig ;
45
+ import io .netty .handler .codec .http .websocketx .WebSocketServerProtocolHandler .HandshakeComplete ;
46
+ import io .netty .handler .codec .http .websocketx .WebSocketServerProtocolHandler .ServerHandshakeStateEvent ;
45
47
import io .netty .util .AttributeKey ;
46
48
import io .netty .util .ReferenceCountUtil ;
47
49
import io .netty .util .concurrent .DefaultPromise ;
48
50
import io .netty .util .concurrent .Future ;
49
51
import io .netty .util .concurrent .Promise ;
50
52
import java .net .SocketAddress ;
51
53
import java .nio .channels .ClosedChannelException ;
54
+ import java .util .List ;
52
55
import java .util .concurrent .CompletableFuture ;
56
+ import java .util .concurrent .CopyOnWriteArrayList ;
53
57
import java .util .function .Consumer ;
54
58
import org .assertj .core .api .Assertions ;
55
59
import org .junit .jupiter .api .AfterEach ;
@@ -305,9 +309,36 @@ protected void initChannel(SocketChannel ch) {
305
309
Assertions .assertThat (client .isOpen ()).isFalse ();
306
310
}
307
311
312
+ @ SuppressWarnings ("deprecation" )
313
+ @ Timeout (15 )
314
+ @ Test
315
+ void serverHandshakeEvents () throws InterruptedException {
316
+ WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig (false , true , 125 );
317
+ TestWebSocketHandler serverHandler = new TestWebSocketHandler ();
318
+ TestWebSocketHandler clientHandler = new TestWebSocketHandler ();
319
+ String subprotocol = "subprotocol" ;
320
+ String path = "/" ;
321
+ Channel s = server = testServer (path , subprotocol , decoderConfig , serverHandler , null );
322
+ Channel client =
323
+ testClient (s .localAddress (), path , subprotocol , true , true , 65_535 , clientHandler );
324
+ serverHandler .onOpen .join ();
325
+ client .close ();
326
+ serverHandler .onClose .join ();
327
+ List <Object > events = serverHandler .events ;
328
+ Assertions .assertThat (events ).hasSize (2 );
329
+ Assertions .assertThat (events .get (0 )).isEqualTo (ServerHandshakeStateEvent .HANDSHAKE_COMPLETE );
330
+ Object event = serverHandler .events .get (1 );
331
+ Assertions .assertThat (event ).isExactlyInstanceOf (HandshakeComplete .class );
332
+ HandshakeComplete completeEvent = (HandshakeComplete ) event ;
333
+ Assertions .assertThat (completeEvent .requestUri ()).isEqualTo (path );
334
+ Assertions .assertThat (completeEvent .requestHeaders ()).isNotNull ().isNotEmpty ();
335
+ Assertions .assertThat (completeEvent .selectedSubprotocol ()).isEqualTo (subprotocol );
336
+ }
337
+
308
338
static Channel testClient (
309
339
SocketAddress address ,
310
340
String path ,
341
+ String subprotocol ,
311
342
boolean mask ,
312
343
boolean allowMaskMismatch ,
313
344
int maxFramePayloadLength ,
@@ -331,6 +362,7 @@ protected void initChannel(SocketChannel ch) {
331
362
.allowMaskMismatch (allowMaskMismatch )
332
363
.maxFramePayloadLength (maxFramePayloadLength )
333
364
.webSocketHandler (webSocketCallbacksHandler )
365
+ .subprotocol (subprotocol )
334
366
.build ();
335
367
336
368
ChannelPipeline pipeline = ch .pipeline ();
@@ -342,6 +374,24 @@ protected void initChannel(SocketChannel ch) {
342
374
.channel ();
343
375
}
344
376
377
+ static Channel testClient (
378
+ SocketAddress address ,
379
+ String path ,
380
+ boolean mask ,
381
+ boolean allowMaskMismatch ,
382
+ int maxFramePayloadLength ,
383
+ WebSocketCallbacksHandler webSocketCallbacksHandler )
384
+ throws InterruptedException {
385
+ return testClient (
386
+ address ,
387
+ path ,
388
+ null ,
389
+ mask ,
390
+ allowMaskMismatch ,
391
+ maxFramePayloadLength ,
392
+ webSocketCallbacksHandler );
393
+ }
394
+
345
395
static Channel testServer (
346
396
String path ,
347
397
WebSocketDecoderConfig decoderConfig ,
@@ -356,12 +406,27 @@ static Channel testServer(
356
406
WebSocketCallbacksHandler webSocketCallbacksHandler ,
357
407
Consumer <Object > nonHandledMessageConsumer )
358
408
throws InterruptedException {
409
+ return testServer (
410
+ path , null , decoderConfig , webSocketCallbacksHandler , nonHandledMessageConsumer );
411
+ }
412
+
413
+ static Channel testServer (
414
+ String path ,
415
+ String subprotocol ,
416
+ WebSocketDecoderConfig decoderConfig ,
417
+ WebSocketCallbacksHandler webSocketCallbacksHandler ,
418
+ Consumer <Object > nonHandledMessageConsumer )
419
+ throws InterruptedException {
359
420
return new ServerBootstrap ()
360
421
.group (new NioEventLoopGroup (1 ))
361
422
.channel (NioServerSocketChannel .class )
362
423
.childHandler (
363
424
new TestAcceptor (
364
- path , decoderConfig , webSocketCallbacksHandler , nonHandledMessageConsumer ))
425
+ path ,
426
+ subprotocol ,
427
+ decoderConfig ,
428
+ webSocketCallbacksHandler ,
429
+ nonHandledMessageConsumer ))
365
430
.bind ("localhost" , 0 )
366
431
.sync ()
367
432
.channel ();
@@ -414,16 +479,19 @@ Future<FullHttpResponse> response() {
414
479
415
480
static class TestAcceptor extends ChannelInitializer <SocketChannel > {
416
481
private final String path ;
482
+ private final String subprotocol ;
417
483
private final WebSocketDecoderConfig webSocketDecoderConfig ;
418
484
private final WebSocketCallbacksHandler webSocketCallbacksHandler ;
419
485
private final Consumer <Object > nonHandledMessageConsumer ;
420
486
421
487
TestAcceptor (
422
488
String path ,
489
+ String subprotocol ,
423
490
WebSocketDecoderConfig decoderConfig ,
424
491
WebSocketCallbacksHandler webSocketCallbacksHandler ,
425
492
Consumer <Object > nonHandledMessageConsumer ) {
426
493
this .path = path ;
494
+ this .subprotocol = subprotocol ;
427
495
this .webSocketDecoderConfig = decoderConfig ;
428
496
this .webSocketCallbacksHandler = webSocketCallbacksHandler ;
429
497
this .nonHandledMessageConsumer = nonHandledMessageConsumer ;
@@ -436,6 +504,7 @@ protected void initChannel(SocketChannel ch) {
436
504
WebSocketServerProtocolHandler webSocketProtocolHandler =
437
505
WebSocketServerProtocolHandler .create ()
438
506
.path (path )
507
+ .subprotocols (subprotocol )
439
508
.decoderConfig (webSocketDecoderConfig )
440
509
.webSocketCallbacksHandler (webSocketCallbacksHandler )
441
510
.build ();
@@ -458,6 +527,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
458
527
static class TestWebSocketHandler implements WebSocketCallbacksHandler {
459
528
final CompletableFuture <Void > onOpen = new CompletableFuture <>();
460
529
final CompletableFuture <Void > onClose = new CompletableFuture <>();
530
+ final List <Object > events = new CopyOnWriteArrayList <>();
461
531
462
532
volatile WebSocketFrameFactory webSocketFrameFactory ;
463
533
volatile Channel channel ;
@@ -477,6 +547,11 @@ public void onChannelRead(
477
547
int opcode ,
478
548
ByteBuf payload ) {}
479
549
550
+ @ Override
551
+ public void onUserEventTriggered (ChannelHandlerContext ctx , Object evt ) {
552
+ events .add (evt );
553
+ }
554
+
480
555
@ Override
481
556
public void onOpen (ChannelHandlerContext ctx ) {
482
557
onOpen .complete (null );
0 commit comments