Skip to content

Commit 32de987

Browse files
committed
websocket encoder: add support for external frame masking keys
1 parent 55dc3ed commit 32de987

File tree

9 files changed

+265
-29
lines changed

9 files changed

+265
-29
lines changed

netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketHandshakeTest.java

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,10 @@
5959
import java.util.List;
6060
import java.util.concurrent.CompletableFuture;
6161
import java.util.concurrent.CopyOnWriteArrayList;
62+
import java.util.concurrent.ThreadLocalRandom;
6263
import java.util.concurrent.TimeUnit;
6364
import java.util.function.Consumer;
65+
import java.util.function.IntSupplier;
6466
import org.assertj.core.api.Assertions;
6567
import org.junit.jupiter.api.AfterAll;
6668
import org.junit.jupiter.api.AfterEach;
@@ -388,6 +390,42 @@ void serverHandshakeEvents() throws InterruptedException {
388390
Assertions.assertThat(completeEvent.selectedSubprotocol()).isEqualTo(subprotocol);
389391
}
390392

393+
@Timeout(15)
394+
@Test
395+
void externalMaskEnabledMasking() throws Exception {
396+
WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig(true, false, 125);
397+
TestWebSocketHandler serverHandler = new TestWebSocketHandler();
398+
Channel s = server = testServer("/", decoderConfig, serverHandler);
399+
400+
TestWebSocketHandler clientHandler = new TestWebSocketHandler();
401+
IntSupplier externalMask = () -> ThreadLocalRandom.current().nextInt();
402+
Channel client =
403+
testClient(s.localAddress(), "/", null, true, externalMask, false, 125, clientHandler);
404+
405+
clientHandler.onOpen.join();
406+
Assertions.assertThat(clientHandler.webSocketFrameFactory)
407+
.isNotNull()
408+
.isExactlyInstanceOf(WebSocketMaskedEncoder.ExternalMaskFrameFactory.class);
409+
}
410+
411+
@Timeout(15)
412+
@Test
413+
void externalMaskDisabledMasking() throws Exception {
414+
WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig(false, false, 125);
415+
TestWebSocketHandler serverHandler = new TestWebSocketHandler();
416+
Channel s = server = testServer("/", decoderConfig, serverHandler);
417+
418+
TestWebSocketHandler clientHandler = new TestWebSocketHandler();
419+
IntSupplier externalMask = () -> ThreadLocalRandom.current().nextInt();
420+
Channel client =
421+
testClient(s.localAddress(), "/", null, false, externalMask, false, 125, clientHandler);
422+
423+
clientHandler.onOpen.join();
424+
Assertions.assertThat(clientHandler.webSocketFrameFactory)
425+
.isNotNull()
426+
.isExactlyInstanceOf(WebSocketNonMaskedEncoder.FrameFactory.class);
427+
}
428+
391429
@Timeout(15)
392430
@CsvSource(
393431
value = {"true:false", "false:false", "false:true"},
@@ -479,6 +517,38 @@ void nomaskingExtensionAccepted() throws Exception {
479517
.isTrue();
480518
}
481519

520+
@Timeout(15)
521+
@Test
522+
void nomaskingExtensionAcceptedExternalMask() throws Exception {
523+
WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig(true, false, 125);
524+
TestWebSocketHandler serverHandler = new TestWebSocketHandler();
525+
ChannelState serverChannelState = new ChannelState();
526+
Channel s =
527+
server =
528+
testNomaskingExtensionServer(
529+
true, serverSslContext, decoderConfig, serverHandler, serverChannelState);
530+
531+
IntSupplier externalMask = () -> ThreadLocalRandom.current().nextInt();
532+
TestWebSocketHandler clientHandler = new TestWebSocketHandler();
533+
ChannelState clientChannelState = new ChannelState();
534+
Channel client =
535+
testNomaskingExtensionClient(
536+
s.localAddress(),
537+
true,
538+
clientSslContext,
539+
true,
540+
externalMask,
541+
false,
542+
125,
543+
clientHandler,
544+
clientChannelState);
545+
546+
clientHandler.onOpen.get();
547+
548+
Assertions.assertThat(clientHandler.webSocketFrameFactory)
549+
.isExactlyInstanceOf(WebSocketNonMaskedEncoder.FrameFactory.class);
550+
}
551+
482552
@Timeout(15)
483553
@ParameterizedTest
484554
@ValueSource(booleans = {true, false})
@@ -524,6 +594,38 @@ void nomaskingExtensionRejected(boolean expectMasked) throws Exception {
524594
.isTrue();
525595
}
526596

597+
@Timeout(15)
598+
@Test
599+
void nomaskingExtensionRejectedExternalMask() throws Exception {
600+
WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig(true, false, 125);
601+
TestWebSocketHandler serverHandler = new TestWebSocketHandler();
602+
ChannelState serverChannelState = new ChannelState();
603+
Channel s =
604+
server =
605+
testNomaskingExtensionServer(
606+
false, serverSslContext, decoderConfig, serverHandler, serverChannelState);
607+
608+
IntSupplier externalMask = () -> ThreadLocalRandom.current().nextInt();
609+
TestWebSocketHandler clientHandler = new TestWebSocketHandler();
610+
ChannelState clientChannelState = new ChannelState();
611+
Channel client =
612+
testNomaskingExtensionClient(
613+
s.localAddress(),
614+
true,
615+
clientSslContext,
616+
true,
617+
externalMask,
618+
false,
619+
125,
620+
clientHandler,
621+
clientChannelState);
622+
623+
clientHandler.onOpen.get();
624+
625+
Assertions.assertThat(clientHandler.webSocketFrameFactory)
626+
.isExactlyInstanceOf(WebSocketMaskedEncoder.ExternalMaskFrameFactory.class);
627+
}
628+
527629
@Timeout(15)
528630
@ValueSource(booleans = {true, false})
529631
@ParameterizedTest
@@ -582,6 +684,27 @@ static Channel testClient(
582684
int maxFramePayloadLength,
583685
WebSocketCallbacksHandler webSocketCallbacksHandler)
584686
throws InterruptedException {
687+
return testClient(
688+
address,
689+
path,
690+
subprotocol,
691+
mask,
692+
null,
693+
allowMaskMismatch,
694+
maxFramePayloadLength,
695+
webSocketCallbacksHandler);
696+
}
697+
698+
static Channel testClient(
699+
SocketAddress address,
700+
String path,
701+
String subprotocol,
702+
boolean mask,
703+
IntSupplier externalMask,
704+
boolean allowMaskMismatch,
705+
int maxFramePayloadLength,
706+
WebSocketCallbacksHandler webSocketCallbacksHandler)
707+
throws InterruptedException {
585708
return new Bootstrap()
586709
.group(new NioEventLoopGroup(1))
587710
.channel(NioSocketChannel.class)
@@ -597,6 +720,7 @@ protected void initChannel(SocketChannel ch) {
597720
WebSocketClientProtocolHandler.create()
598721
.path(path)
599722
.mask(mask)
723+
.mask(externalMask)
600724
.allowMaskMismatch(allowMaskMismatch)
601725
.maxFramePayloadLength(maxFramePayloadLength)
602726
.webSocketHandler(webSocketCallbacksHandler)
@@ -680,6 +804,29 @@ static Channel testNomaskingExtensionClient(
680804
WebSocketCallbacksHandler webSocketCallbacksHandler,
681805
ChannelState channelState)
682806
throws InterruptedException {
807+
return testNomaskingExtensionClient(
808+
address,
809+
nomaskingExtension,
810+
sslContext,
811+
mask,
812+
null,
813+
allowMaskMismatch,
814+
maxFramePayloadLength,
815+
webSocketCallbacksHandler,
816+
channelState);
817+
}
818+
819+
static Channel testNomaskingExtensionClient(
820+
SocketAddress address,
821+
boolean nomaskingExtension,
822+
SslContext sslContext,
823+
boolean mask,
824+
IntSupplier externalMask,
825+
boolean allowMaskMismatch,
826+
int maxFramePayloadLength,
827+
WebSocketCallbacksHandler webSocketCallbacksHandler,
828+
ChannelState channelState)
829+
throws InterruptedException {
683830
return new Bootstrap()
684831
.group(new NioEventLoopGroup(1))
685832
.channel(NioSocketChannel.class)
@@ -696,6 +843,7 @@ protected void initChannel(SocketChannel ch) {
696843
.path("/")
697844
.nomaskingExtension(nomaskingExtension)
698845
.mask(mask)
846+
.mask(externalMask)
699847
.allowMaskMismatch(allowMaskMismatch)
700848
.maxFramePayloadLength(maxFramePayloadLength)
701849
.webSocketHandler(webSocketCallbacksHandler)

netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketCallbacksFrameEncoder.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,20 @@
1818

1919
import io.netty.channel.ChannelHandlerContext;
2020
import io.netty.handler.codec.http.websocketx.WebSocketFrameEncoder;
21+
import java.util.function.IntSupplier;
2122

2223
interface WebSocketCallbacksFrameEncoder extends WebSocketFrameEncoder {
2324

2425
WebSocketFrameFactory frameFactory(ChannelHandlerContext ctx);
2526

26-
static WebSocketCallbacksFrameEncoder frameEncoder(boolean performMasking) {
27-
return performMasking ? WebSocketMaskedEncoder.INSTANCE : WebSocketNonMaskedEncoder.INSTANCE;
27+
static WebSocketCallbacksFrameEncoder frameEncoder(
28+
boolean performMasking, IntSupplier externalMask) {
29+
if (performMasking) {
30+
if (externalMask != null) {
31+
return new WebSocketMaskedEncoder(externalMask);
32+
}
33+
return WebSocketMaskedEncoder.INSTANCE;
34+
}
35+
return WebSocketNonMaskedEncoder.INSTANCE;
2836
}
2937
}

netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,20 @@
2323
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
2424
import java.net.URI;
2525
import java.util.Objects;
26+
import java.util.function.IntSupplier;
27+
import javax.annotation.Nullable;
2628

2729
public class WebSocketClientHandshaker extends WebSocketClientHandshaker13 {
2830
private final boolean expectMaskedFrames;
31+
private final IntSupplier externalMask;
2932

3033
public WebSocketClientHandshaker(
3134
URI webSocketURL,
3235
String subprotocol,
3336
HttpHeaders customHeaders,
3437
int maxFramePayloadLength,
3538
boolean performMasking,
39+
@Nullable IntSupplier externalMask,
3640
boolean expectMaskedFrames,
3741
boolean allowMaskMismatch) {
3842
super(
@@ -46,6 +50,26 @@ public WebSocketClientHandshaker(
4650
allowMaskMismatch,
4751
/*unused*/ -1);
4852
this.expectMaskedFrames = expectMaskedFrames;
53+
this.externalMask = externalMask;
54+
}
55+
56+
public WebSocketClientHandshaker(
57+
URI webSocketURL,
58+
String subprotocol,
59+
HttpHeaders customHeaders,
60+
int maxFramePayloadLength,
61+
boolean performMasking,
62+
boolean expectMaskedFrames,
63+
boolean allowMaskMismatch) {
64+
this(
65+
webSocketURL,
66+
subprotocol,
67+
customHeaders,
68+
maxFramePayloadLength,
69+
performMasking,
70+
null,
71+
expectMaskedFrames,
72+
allowMaskMismatch);
4973
}
5074

5175
@Override
@@ -56,6 +80,6 @@ protected WebSocketFrameDecoder newWebsocketDecoder() {
5680

5781
@Override
5882
protected WebSocketFrameEncoder newWebSocketEncoder() {
59-
return WebSocketCallbacksFrameEncoder.frameEncoder(isPerformMasking());
83+
return WebSocketCallbacksFrameEncoder.frameEncoder(isPerformMasking(), externalMask);
6084
}
6185
}

netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketClientNomaskingHandshaker.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
import io.netty.handler.codec.http.websocketx.WebSocketFrameEncoder;
2828
import java.net.URI;
2929
import java.util.Objects;
30+
import java.util.function.IntSupplier;
3031

3132
final class WebSocketClientNomaskingHandshaker extends WebSocketClientHandshaker {
33+
private final IntSupplier externalMask;
3234
private Channel channel;
3335

3436
WebSocketClientNomaskingHandshaker(
@@ -37,7 +39,8 @@ final class WebSocketClientNomaskingHandshaker extends WebSocketClientHandshaker
3739
HttpHeaders customHeaders,
3840
int maxFramePayloadLength,
3941
boolean expectMaskedFrames,
40-
boolean allowMaskMismatch) {
42+
boolean allowMaskMismatch,
43+
IntSupplier externalMask) {
4144
super(
4245
Objects.requireNonNull(webSocketURL, "webSocketURL"),
4346
subprotocol,
@@ -46,6 +49,7 @@ final class WebSocketClientNomaskingHandshaker extends WebSocketClientHandshaker
4649
false,
4750
expectMaskedFrames,
4851
allowMaskMismatch);
52+
this.externalMask = externalMask;
4953
}
5054

5155
public static boolean supportsNoMaskingExtension(URI webSocketURL) {
@@ -59,7 +63,7 @@ protected WebSocketFrameDecoder newWebsocketDecoder() {
5963

6064
@Override
6165
protected WebSocketFrameEncoder newWebSocketEncoder() {
62-
return WebSocketCallbacksFrameEncoder.frameEncoder(false);
66+
return WebSocketCallbacksFrameEncoder.frameEncoder(false, externalMask);
6367
}
6468

6569
@Override
@@ -92,7 +96,7 @@ private void nomaskingExtensionComplete(FullHttpResponse handshakeResponse) {
9296
pipeline.replace(
9397
WebSocketFrameEncoder.class,
9498
"ws-encoder",
95-
WebSocketCallbacksFrameEncoder.frameEncoder(true));
99+
WebSocketCallbacksFrameEncoder.frameEncoder(true, externalMask));
96100
}
97101
}
98102
}

0 commit comments

Comments
 (0)