Skip to content

Commit dc66bec

Browse files
OlegDokukarobertroeser
authored andcommitted
fixes uncontrolled data sending in case of direct propagation of request from requester (#595) (#596)
* fixes uncontrolled data sending in case of direct propagation of request from requester * fixes timeout typo * replaces forEach with explicit loop * optimize access to limitableRequestPublisher Signed-off-by: Oleh Dokuka <[email protected]>
1 parent 87223c6 commit dc66bec

File tree

7 files changed

+212
-26
lines changed

7 files changed

+212
-26
lines changed

rsocket-core/src/main/java/io/rsocket/RSocketClient.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,13 @@ class RSocketClient implements RSocket {
8282
connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer);
8383

8484
connection
85-
.send(sendProcessor)
85+
.send(
86+
sendProcessor.doOnRequest(
87+
r -> {
88+
for (LimitableRequestPublisher lrp : senders.values()) {
89+
lrp.increaseInternalLimit(r);
90+
}
91+
}))
8692
.doFinally(this::handleSendProcessorCancel)
8793
.subscribe(null, this::handleSendProcessorError);
8894

@@ -294,7 +300,8 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
294300
.transform(
295301
f -> {
296302
LimitableRequestPublisher<Payload> wrapped =
297-
LimitableRequestPublisher.wrap(f);
303+
LimitableRequestPublisher.wrap(
304+
f, sendProcessor.available());
298305
// Need to set this to one for first the frame
299306
wrapped.increaseRequestLimit(1);
300307
senders.put(streamId, wrapped);

rsocket-core/src/main/java/io/rsocket/RSocketServer.java

+47-5
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class RSocketServer implements ResponderRSocket {
5050
private final Function<Frame, ? extends Payload> frameDecoder;
5151
private final Consumer<Throwable> errorConsumer;
5252

53+
private final Map<Integer, LimitableRequestPublisher> sendingLimitableSubscriptions;
5354
private final Map<Integer, Subscription> sendingSubscriptions;
5455
private final Map<Integer, Processor<Payload, Payload>> channelProcessors;
5556

@@ -81,6 +82,7 @@ class RSocketServer implements ResponderRSocket {
8182
this.connection = connection;
8283
this.frameDecoder = frameDecoder;
8384
this.errorConsumer = errorConsumer;
85+
this.sendingLimitableSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>());
8486
this.sendingSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>());
8587
this.channelProcessors = Collections.synchronizedMap(new IntObjectHashMap<>());
8688

@@ -89,7 +91,13 @@ class RSocketServer implements ResponderRSocket {
8991
this.sendProcessor = new UnboundedProcessor<>();
9092

9193
connection
92-
.send(sendProcessor)
94+
.send(
95+
sendProcessor.doOnRequest(
96+
r -> {
97+
for (LimitableRequestPublisher lrp : sendingLimitableSubscriptions.values()) {
98+
lrp.increaseInternalLimit(r);
99+
}
100+
}))
93101
.doFinally(this::handleSendProcessorCancel)
94102
.subscribe(null, this::handleSendProcessorError);
95103

@@ -135,6 +143,17 @@ private void handleSendProcessorError(Throwable t) {
135143
}
136144
});
137145

146+
sendingLimitableSubscriptions
147+
.values()
148+
.forEach(
149+
subscription -> {
150+
try {
151+
subscription.cancel();
152+
} catch (Throwable e) {
153+
errorConsumer.accept(e);
154+
}
155+
});
156+
138157
channelProcessors
139158
.values()
140159
.forEach(
@@ -163,6 +182,17 @@ private void handleSendProcessorCancel(SignalType t) {
163182
}
164183
});
165184

185+
sendingLimitableSubscriptions
186+
.values()
187+
.forEach(
188+
subscription -> {
189+
try {
190+
subscription.cancel();
191+
} catch (Throwable e) {
192+
errorConsumer.accept(e);
193+
}
194+
});
195+
166196
channelProcessors
167197
.values()
168198
.forEach(
@@ -258,6 +288,9 @@ private void cleanup() {
258288
private synchronized void cleanUpSendingSubscriptions() {
259289
sendingSubscriptions.values().forEach(Subscription::cancel);
260290
sendingSubscriptions.clear();
291+
292+
sendingLimitableSubscriptions.values().forEach(Subscription::cancel);
293+
sendingLimitableSubscriptions.clear();
261294
}
262295

263296
private synchronized void cleanUpChannelProcessors() {
@@ -373,12 +406,12 @@ private void handleStream(int streamId, Flux<Payload> response, int initialReque
373406
.transform(
374407
frameFlux -> {
375408
LimitableRequestPublisher<Payload> payloads =
376-
LimitableRequestPublisher.wrap(frameFlux);
377-
sendingSubscriptions.put(streamId, payloads);
409+
LimitableRequestPublisher.wrap(frameFlux, sendProcessor.available());
410+
sendingLimitableSubscriptions.put(streamId, payloads);
378411
payloads.increaseRequestLimit(initialRequestN);
379412
return payloads;
380413
})
381-
.doFinally(signalType -> sendingSubscriptions.remove(streamId))
414+
.doFinally(signalType -> sendingLimitableSubscriptions.remove(streamId))
382415
.subscribe(
383416
payload -> {
384417
final Frame frame = Frame.PayloadFrame.from(streamId, FrameType.NEXT, payload);
@@ -423,6 +456,11 @@ private void handleKeepAliveFrame(Frame frame) {
423456

424457
private void handleCancelFrame(int streamId) {
425458
Subscription subscription = sendingSubscriptions.remove(streamId);
459+
460+
if (subscription == null) {
461+
subscription = sendingLimitableSubscriptions.get(streamId);
462+
}
463+
426464
if (subscription != null) {
427465
subscription.cancel();
428466
}
@@ -434,7 +472,11 @@ private void handleError(int streamId, Throwable t) {
434472
}
435473

436474
private void handleRequestN(int streamId, Frame frame) {
437-
final Subscription subscription = sendingSubscriptions.get(streamId);
475+
Subscription subscription = sendingSubscriptions.get(streamId);
476+
477+
if (subscription == null) {
478+
subscription = sendingLimitableSubscriptions.get(streamId);
479+
}
438480
if (subscription != null) {
439481
int n = Frame.RequestN.requestN(frame);
440482
subscription.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n);

rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java

+27-13
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ public class LimitableRequestPublisher<T> extends Flux<T> implements Subscriptio
3131

3232
private final AtomicBoolean canceled;
3333

34+
private final long prefetch;
35+
3436
private long internalRequested;
3537

3638
private long externalRequested;
@@ -39,13 +41,14 @@ public class LimitableRequestPublisher<T> extends Flux<T> implements Subscriptio
3941

4042
private volatile @Nullable Subscription internalSubscription;
4143

42-
private LimitableRequestPublisher(Publisher<T> source) {
44+
private LimitableRequestPublisher(Publisher<T> source, long prefetch) {
4345
this.source = source;
46+
this.prefetch = prefetch;
4447
this.canceled = new AtomicBoolean();
4548
}
4649

47-
public static <T> LimitableRequestPublisher<T> wrap(Publisher<T> source) {
48-
return new LimitableRequestPublisher<>(source);
50+
public static <T> LimitableRequestPublisher<T> wrap(Publisher<T> source, long prefetch) {
51+
return new LimitableRequestPublisher<>(source, prefetch);
4952
}
5053

5154
@Override
@@ -60,6 +63,7 @@ public void subscribe(CoreSubscriber<? super T> destination) {
6063

6164
destination.onSubscribe(new InnerSubscription());
6265
source.subscribe(new InnerSubscriber(destination));
66+
increaseInternalLimit(prefetch);
6367
}
6468

6569
public void increaseRequestLimit(long n) {
@@ -70,6 +74,14 @@ public void increaseRequestLimit(long n) {
7074
requestN();
7175
}
7276

77+
public void increaseInternalLimit(long n) {
78+
synchronized (this) {
79+
internalRequested = Operators.addCap(n, internalRequested);
80+
}
81+
82+
requestN();
83+
}
84+
7385
@Override
7486
public void request(long n) {
7587
increaseRequestLimit(n);
@@ -82,9 +94,17 @@ private void requestN() {
8294
return;
8395
}
8496

85-
r = Math.min(internalRequested, externalRequested);
86-
externalRequested -= r;
87-
internalRequested -= r;
97+
if (externalRequested != Long.MAX_VALUE || internalRequested != Long.MAX_VALUE) {
98+
r = Math.min(internalRequested, externalRequested);
99+
if (externalRequested != Long.MAX_VALUE) {
100+
externalRequested -= r;
101+
}
102+
if (internalRequested != Long.MAX_VALUE) {
103+
internalRequested -= r;
104+
}
105+
} else {
106+
r = Long.MAX_VALUE;
107+
}
88108
}
89109

90110
if (r > 0) {
@@ -144,13 +164,7 @@ public void onComplete() {
144164

145165
private class InnerSubscription implements Subscription {
146166
@Override
147-
public void request(long n) {
148-
synchronized (LimitableRequestPublisher.this) {
149-
internalRequested = Operators.addCap(n, internalRequested);
150-
}
151-
152-
requestN();
153-
}
167+
public void request(long n) {}
154168

155169
@Override
156170
public void cancel() {

rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package io.rsocket.internal;
1818

1919
import io.netty.util.ReferenceCountUtil;
20-
import io.netty.util.internal.shaded.org.jctools.queues.MpscUnboundedArrayQueue;
2120
import org.reactivestreams.Subscriber;
2221
import org.reactivestreams.Subscription;
2322
import reactor.core.CoreSubscriber;
@@ -221,6 +220,10 @@ public void onSubscribe(Subscription s) {
221220
}
222221
}
223222

223+
public long available() {
224+
return requested;
225+
}
226+
224227
@Override
225228
public int getPrefetch() {
226229
return Integer.MAX_VALUE;

rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java

+25
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,15 @@
3333
import io.rsocket.exceptions.RejectedSetupException;
3434
import io.rsocket.frame.RequestFrameFlyweight;
3535
import io.rsocket.framing.FrameType;
36+
import io.rsocket.test.util.TestDuplexConnection;
3637
import io.rsocket.test.util.TestSubscriber;
3738
import io.rsocket.util.DefaultPayload;
3839
import io.rsocket.util.EmptyPayload;
3940
import java.time.Duration;
4041
import java.util.ArrayList;
4142
import java.util.List;
43+
import java.util.Queue;
44+
import java.util.concurrent.ConcurrentLinkedQueue;
4245
import java.util.stream.Collectors;
4346
import org.assertj.core.api.Assertions;
4447
import org.junit.Rule;
@@ -215,6 +218,28 @@ public void testChannelRequestServerSideCancellation() {
215218
Assertions.assertThat(request.isDisposed()).isTrue();
216219
}
217220

221+
@Test(timeout = 2_000)
222+
@SuppressWarnings("unchecked")
223+
public void
224+
testClientSideRequestChannelShouldNotHangInfinitelySendingElementsAndShouldProduceDataValuingConnectionBackpressure() {
225+
final Queue<Long> requests = new ConcurrentLinkedQueue<>();
226+
rule.connection.dispose();
227+
rule.connection = new TestDuplexConnection();
228+
rule.connection.setInitialSendRequestN(256);
229+
rule.init();
230+
231+
rule.socket
232+
.requestChannel(
233+
Flux.<Payload>generate(s -> s.next(EmptyPayload.INSTANCE)).doOnRequest(requests::add))
234+
.subscribe();
235+
236+
int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL);
237+
238+
rule.connection.addToReceivedBuffer(Frame.RequestN.from(streamId, 2));
239+
rule.connection.addToReceivedBuffer(Frame.RequestN.from(streamId, Integer.MAX_VALUE));
240+
Assertions.assertThat(requests).containsOnly(1L, 2L, 253L);
241+
}
242+
218243
public int sendRequestResponse(Publisher<Payload> response) {
219244
Subscriber<Payload> sub = TestSubscriber.create();
220245
response.subscribe(sub);

0 commit comments

Comments
 (0)