59
59
import java .util .List ;
60
60
import java .util .concurrent .CompletableFuture ;
61
61
import java .util .concurrent .CopyOnWriteArrayList ;
62
+ import java .util .concurrent .ThreadLocalRandom ;
62
63
import java .util .concurrent .TimeUnit ;
63
64
import java .util .function .Consumer ;
65
+ import java .util .function .IntSupplier ;
64
66
import org .assertj .core .api .Assertions ;
65
67
import org .junit .jupiter .api .AfterAll ;
66
68
import org .junit .jupiter .api .AfterEach ;
@@ -388,6 +390,42 @@ void serverHandshakeEvents() throws InterruptedException {
388
390
Assertions .assertThat (completeEvent .selectedSubprotocol ()).isEqualTo (subprotocol );
389
391
}
390
392
393
+ @ Timeout (15 )
394
+ @ Test
395
+ void externalMaskEnabledMasking () throws Exception {
396
+ WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig (true , false , 125 );
397
+ TestWebSocketHandler serverHandler = new TestWebSocketHandler ();
398
+ Channel s = server = testServer ("/" , decoderConfig , serverHandler );
399
+
400
+ TestWebSocketHandler clientHandler = new TestWebSocketHandler ();
401
+ IntSupplier externalMask = () -> ThreadLocalRandom .current ().nextInt ();
402
+ Channel client =
403
+ testClient (s .localAddress (), "/" , null , true , externalMask , false , 125 , clientHandler );
404
+
405
+ clientHandler .onOpen .join ();
406
+ Assertions .assertThat (clientHandler .webSocketFrameFactory )
407
+ .isNotNull ()
408
+ .isExactlyInstanceOf (WebSocketMaskedEncoder .ExternalMaskFrameFactory .class );
409
+ }
410
+
411
+ @ Timeout (15 )
412
+ @ Test
413
+ void externalMaskDisabledMasking () throws Exception {
414
+ WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig (false , false , 125 );
415
+ TestWebSocketHandler serverHandler = new TestWebSocketHandler ();
416
+ Channel s = server = testServer ("/" , decoderConfig , serverHandler );
417
+
418
+ TestWebSocketHandler clientHandler = new TestWebSocketHandler ();
419
+ IntSupplier externalMask = () -> ThreadLocalRandom .current ().nextInt ();
420
+ Channel client =
421
+ testClient (s .localAddress (), "/" , null , false , externalMask , false , 125 , clientHandler );
422
+
423
+ clientHandler .onOpen .join ();
424
+ Assertions .assertThat (clientHandler .webSocketFrameFactory )
425
+ .isNotNull ()
426
+ .isExactlyInstanceOf (WebSocketNonMaskedEncoder .FrameFactory .class );
427
+ }
428
+
391
429
@ Timeout (15 )
392
430
@ CsvSource (
393
431
value = {"true:false" , "false:false" , "false:true" },
@@ -479,6 +517,38 @@ void nomaskingExtensionAccepted() throws Exception {
479
517
.isTrue ();
480
518
}
481
519
520
+ @ Timeout (15 )
521
+ @ Test
522
+ void nomaskingExtensionAcceptedExternalMask () throws Exception {
523
+ WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig (true , false , 125 );
524
+ TestWebSocketHandler serverHandler = new TestWebSocketHandler ();
525
+ ChannelState serverChannelState = new ChannelState ();
526
+ Channel s =
527
+ server =
528
+ testNomaskingExtensionServer (
529
+ true , serverSslContext , decoderConfig , serverHandler , serverChannelState );
530
+
531
+ IntSupplier externalMask = () -> ThreadLocalRandom .current ().nextInt ();
532
+ TestWebSocketHandler clientHandler = new TestWebSocketHandler ();
533
+ ChannelState clientChannelState = new ChannelState ();
534
+ Channel client =
535
+ testNomaskingExtensionClient (
536
+ s .localAddress (),
537
+ true ,
538
+ clientSslContext ,
539
+ true ,
540
+ externalMask ,
541
+ false ,
542
+ 125 ,
543
+ clientHandler ,
544
+ clientChannelState );
545
+
546
+ clientHandler .onOpen .get ();
547
+
548
+ Assertions .assertThat (clientHandler .webSocketFrameFactory )
549
+ .isExactlyInstanceOf (WebSocketNonMaskedEncoder .FrameFactory .class );
550
+ }
551
+
482
552
@ Timeout (15 )
483
553
@ ParameterizedTest
484
554
@ ValueSource (booleans = {true , false })
@@ -524,6 +594,38 @@ void nomaskingExtensionRejected(boolean expectMasked) throws Exception {
524
594
.isTrue ();
525
595
}
526
596
597
+ @ Timeout (15 )
598
+ @ Test
599
+ void nomaskingExtensionRejectedExternalMask () throws Exception {
600
+ WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig (true , false , 125 );
601
+ TestWebSocketHandler serverHandler = new TestWebSocketHandler ();
602
+ ChannelState serverChannelState = new ChannelState ();
603
+ Channel s =
604
+ server =
605
+ testNomaskingExtensionServer (
606
+ false , serverSslContext , decoderConfig , serverHandler , serverChannelState );
607
+
608
+ IntSupplier externalMask = () -> ThreadLocalRandom .current ().nextInt ();
609
+ TestWebSocketHandler clientHandler = new TestWebSocketHandler ();
610
+ ChannelState clientChannelState = new ChannelState ();
611
+ Channel client =
612
+ testNomaskingExtensionClient (
613
+ s .localAddress (),
614
+ true ,
615
+ clientSslContext ,
616
+ true ,
617
+ externalMask ,
618
+ false ,
619
+ 125 ,
620
+ clientHandler ,
621
+ clientChannelState );
622
+
623
+ clientHandler .onOpen .get ();
624
+
625
+ Assertions .assertThat (clientHandler .webSocketFrameFactory )
626
+ .isExactlyInstanceOf (WebSocketMaskedEncoder .ExternalMaskFrameFactory .class );
627
+ }
628
+
527
629
@ Timeout (15 )
528
630
@ ValueSource (booleans = {true , false })
529
631
@ ParameterizedTest
@@ -582,6 +684,27 @@ static Channel testClient(
582
684
int maxFramePayloadLength ,
583
685
WebSocketCallbacksHandler webSocketCallbacksHandler )
584
686
throws InterruptedException {
687
+ return testClient (
688
+ address ,
689
+ path ,
690
+ subprotocol ,
691
+ mask ,
692
+ null ,
693
+ allowMaskMismatch ,
694
+ maxFramePayloadLength ,
695
+ webSocketCallbacksHandler );
696
+ }
697
+
698
+ static Channel testClient (
699
+ SocketAddress address ,
700
+ String path ,
701
+ String subprotocol ,
702
+ boolean mask ,
703
+ IntSupplier externalMask ,
704
+ boolean allowMaskMismatch ,
705
+ int maxFramePayloadLength ,
706
+ WebSocketCallbacksHandler webSocketCallbacksHandler )
707
+ throws InterruptedException {
585
708
return new Bootstrap ()
586
709
.group (new NioEventLoopGroup (1 ))
587
710
.channel (NioSocketChannel .class )
@@ -597,6 +720,7 @@ protected void initChannel(SocketChannel ch) {
597
720
WebSocketClientProtocolHandler .create ()
598
721
.path (path )
599
722
.mask (mask )
723
+ .mask (externalMask )
600
724
.allowMaskMismatch (allowMaskMismatch )
601
725
.maxFramePayloadLength (maxFramePayloadLength )
602
726
.webSocketHandler (webSocketCallbacksHandler )
@@ -680,6 +804,29 @@ static Channel testNomaskingExtensionClient(
680
804
WebSocketCallbacksHandler webSocketCallbacksHandler ,
681
805
ChannelState channelState )
682
806
throws InterruptedException {
807
+ return testNomaskingExtensionClient (
808
+ address ,
809
+ nomaskingExtension ,
810
+ sslContext ,
811
+ mask ,
812
+ null ,
813
+ allowMaskMismatch ,
814
+ maxFramePayloadLength ,
815
+ webSocketCallbacksHandler ,
816
+ channelState );
817
+ }
818
+
819
+ static Channel testNomaskingExtensionClient (
820
+ SocketAddress address ,
821
+ boolean nomaskingExtension ,
822
+ SslContext sslContext ,
823
+ boolean mask ,
824
+ IntSupplier externalMask ,
825
+ boolean allowMaskMismatch ,
826
+ int maxFramePayloadLength ,
827
+ WebSocketCallbacksHandler webSocketCallbacksHandler ,
828
+ ChannelState channelState )
829
+ throws InterruptedException {
683
830
return new Bootstrap ()
684
831
.group (new NioEventLoopGroup (1 ))
685
832
.channel (NioSocketChannel .class )
@@ -696,6 +843,7 @@ protected void initChannel(SocketChannel ch) {
696
843
.path ("/" )
697
844
.nomaskingExtension (nomaskingExtension )
698
845
.mask (mask )
846
+ .mask (externalMask )
699
847
.allowMaskMismatch (allowMaskMismatch )
700
848
.maxFramePayloadLength (maxFramePayloadLength )
701
849
.webSocketHandler (webSocketCallbacksHandler )
0 commit comments