Skip to content

Commit b9f5d21

Browse files
committed
core: DelayedStream cancels provided stream if not using it.
Resolves grpc#1537 Also disallow cancel() before start(). DelayedClientTransport.shutdownNow() races with stream start(), thus it shouldn't call cancel() directly. It would delay the cancellation until the stream is started.
1 parent c4642f8 commit b9f5d21

File tree

6 files changed

+221
-51
lines changed

6 files changed

+221
-51
lines changed

core/src/main/java/io/grpc/internal/DelayedClientTransport.java

+32-2
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ public final void shutdownNow(Status status) {
218218
}
219219
if (savedPendingStreams != null) {
220220
for (PendingStream stream : savedPendingStreams) {
221-
stream.cancel(status);
221+
stream.cancelInternal(status);
222222
}
223223
listener.transportTerminated();
224224
}
@@ -477,6 +477,12 @@ private class PendingStream extends DelayedStream {
477477
private final Context context;
478478
private final StatsTraceContext statsTraceCtx;
479479

480+
private final Object pendingStreamLock = new Object();
481+
@GuardedBy("pendingStreamLock")
482+
private boolean started;
483+
@GuardedBy("pendingStreamLock")
484+
private Status pendingCancelReason;
485+
480486
private PendingStream(MethodDescriptor<?, ?> method, Metadata headers,
481487
CallOptions callOptions, StatsTraceContext statsTraceCtx) {
482488
this.method = method;
@@ -497,8 +503,32 @@ private void createRealStream(ClientTransport transport) {
497503
setStream(realStream);
498504
}
499505

506+
// This may be called concurrently with other methods on the stream
507+
private void cancelInternal(Status reason) {
508+
synchronized (pendingStreamLock) {
509+
if (!started) {
510+
pendingCancelReason = reason;
511+
return;
512+
}
513+
}
514+
cancel(reason);
515+
}
516+
517+
@Override
518+
public final void start(ClientStreamListener listener) {
519+
Status savedPendingCancelReason;
520+
synchronized (pendingStreamLock) {
521+
started = true;
522+
savedPendingCancelReason = pendingCancelReason;
523+
}
524+
super.start(listener);
525+
if (savedPendingCancelReason != null) {
526+
cancel(savedPendingCancelReason);
527+
}
528+
}
529+
500530
@Override
501-
public void cancel(Status reason) {
531+
public final void cancel(Status reason) {
502532
super.cancel(reason);
503533
synchronized (lock) {
504534
if (pendingStreams != null) {

core/src/main/java/io/grpc/internal/DelayedClientTransport2.java

+32-2
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ public final void shutdownNow(Status status) {
247247
}
248248
if (savedPendingStreams != null) {
249249
for (PendingStream stream : savedPendingStreams) {
250-
stream.cancel(status);
250+
stream.cancelInternal(status);
251251
}
252252
channelExecutor.executeLater(reportTransportTerminated).drain();
253253
}
@@ -353,6 +353,12 @@ private class PendingStream extends DelayedStream {
353353
private final Context context;
354354
private final StatsTraceContext statsTraceCtx;
355355

356+
private final Object pendingStreamLock = new Object();
357+
@GuardedBy("pendingStreamLock")
358+
private boolean started;
359+
@GuardedBy("pendingStreamLock")
360+
private Status pendingCancelReason;
361+
356362
private PendingStream(MethodDescriptor<?, ?> method, Metadata headers,
357363
CallOptions callOptions, StatsTraceContext statsTraceCtx) {
358364
this.method = method;
@@ -373,8 +379,32 @@ private void createRealStream(ClientTransport transport) {
373379
setStream(realStream);
374380
}
375381

382+
// This may be called concurrently with other methods on the stream
383+
private void cancelInternal(Status reason) {
384+
synchronized (pendingStreamLock) {
385+
if (!started) {
386+
pendingCancelReason = reason;
387+
return;
388+
}
389+
}
390+
cancel(reason);
391+
}
392+
393+
@Override
394+
public final void start(ClientStreamListener listener) {
395+
Status savedPendingCancelReason;
396+
synchronized (pendingStreamLock) {
397+
started = true;
398+
savedPendingCancelReason = pendingCancelReason;
399+
}
400+
super.start(listener);
401+
if (savedPendingCancelReason != null) {
402+
cancel(savedPendingCancelReason);
403+
}
404+
}
405+
376406
@Override
377-
public void cancel(Status reason) {
407+
public final void cancel(Status reason) {
378408
super.cancel(reason);
379409
synchronized (lock) {
380410
if (pendingStreams != null) {

core/src/main/java/io/grpc/internal/DelayedStream.java

+59-24
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,21 @@
5656
* necessary.
5757
*/
5858
class DelayedStream implements ClientStream {
59+
@VisibleForTesting
60+
static final ClientStreamListener NOOP_STREAM_LISTENER = new ClientStreamListener() {
61+
@Override
62+
public void messageRead(InputStream message) {}
63+
64+
@Override
65+
public void onReady() {}
66+
67+
@Override
68+
public void headersRead(Metadata headers) {}
69+
70+
@Override
71+
public void closed(Status status, Metadata trailers) {}
72+
};
73+
5974
/** {@code true} once realStream is valid and all pending calls have been drained. */
6075
private volatile boolean passThrough;
6176
/**
@@ -73,7 +88,7 @@ class DelayedStream implements ClientStream {
7388
private DelayedStreamListener delayedListener;
7489

7590
@Override
76-
public void setMaxInboundMessageSize(final int maxSize) {
91+
public final void setMaxInboundMessageSize(final int maxSize) {
7792
if (passThrough) {
7893
realStream.setMaxInboundMessageSize(maxSize);
7994
} else {
@@ -87,7 +102,7 @@ public void run() {
87102
}
88103

89104
@Override
90-
public void setMaxOutboundMessageSize(final int maxSize) {
105+
public final void setMaxOutboundMessageSize(final int maxSize) {
91106
if (passThrough) {
92107
realStream.setMaxOutboundMessageSize(maxSize);
93108
} else {
@@ -103,19 +118,40 @@ public void run() {
103118
/**
104119
* Transfers all pending and future requests and mutations to the given stream.
105120
*
106-
* <p>No-op if either this method or {@link #cancel} have already been called.
121+
* <p>This method must be called at most once. Extraneous calls will throw and end up cancelling
122+
* the given streams.
123+
*
124+
* <p>If {@link #cancelInternal} has been called, this method will cancel the given stream.
107125
*/
108126
// When this method returns, passThrough is guaranteed to be true
109127
final void setStream(ClientStream stream) {
128+
ClientStream savedRealStream;
129+
Status savedError;
110130
synchronized (this) {
111-
// If realStream != null, then either setStream() or cancel() has been called.
112-
if (realStream != null) {
113-
return;
131+
savedRealStream = realStream;
132+
savedError = error;
133+
if (savedRealStream == null) {
134+
realStream = checkNotNull(stream, "stream");
114135
}
115-
realStream = checkNotNull(stream, "stream");
116136
}
117137

118-
drainPendingCalls();
138+
if (savedRealStream == null) {
139+
drainPendingCalls();
140+
} else {
141+
// If realStream was not null, then either setStream() or cancel() must had been called,
142+
// we will cancel and discard the given stream.
143+
// ClientStream.cancel() must be called after start()
144+
stream.start(NOOP_STREAM_LISTENER);
145+
if (savedError != null) {
146+
stream.cancel(savedError);
147+
} else {
148+
// If cancel() were called, error must have been non-null.
149+
IllegalStateException exception = new IllegalStateException(
150+
"DelayedStream.setStream() is called more than once");
151+
stream.cancel(Status.CANCELLED.withCause(exception));
152+
throw exception;
153+
}
154+
}
119155
}
120156

121157
/**
@@ -173,7 +209,7 @@ private void delayOrExecute(Runnable runnable) {
173209
}
174210

175211
@Override
176-
public void setAuthority(final String authority) {
212+
public final void setAuthority(final String authority) {
177213
checkState(listener == null, "May only be called before start");
178214
checkNotNull(authority, "authority");
179215
delayOrExecute(new Runnable() {
@@ -192,7 +228,8 @@ public void start(ClientStreamListener listener) {
192228
boolean savedPassThrough;
193229
synchronized (this) {
194230
this.listener = checkNotNull(listener, "listener");
195-
// If error != null, then cancel() has been called and was unable to close the listener
231+
// If error != null, then cancelInternal() has been called and was unable to close the
232+
// listener
196233
savedError = error;
197234
savedPassThrough = passThrough;
198235
if (!savedPassThrough) {
@@ -218,7 +255,7 @@ public void run() {
218255
}
219256

220257
@Override
221-
public void writeMessage(final InputStream message) {
258+
public final void writeMessage(final InputStream message) {
222259
checkNotNull(message, "message");
223260
if (passThrough) {
224261
realStream.writeMessage(message);
@@ -233,7 +270,7 @@ public void run() {
233270
}
234271

235272
@Override
236-
public void flush() {
273+
public final void flush() {
237274
if (passThrough) {
238275
realStream.flush();
239276
} else {
@@ -253,12 +290,12 @@ public void cancel(final Status reason) {
253290
boolean delegateToRealStream = true;
254291
ClientStreamListener listenerToClose = null;
255292
synchronized (this) {
256-
// If realStream != null, then either setStream() or cancel() has been called
293+
if (listener == null) {
294+
throw new IllegalStateException("cancel() must be called after start()");
295+
}
257296
if (realStream == null) {
258297
realStream = NoopClientStream.INSTANCE;
259298
delegateToRealStream = false;
260-
261-
// If listener == null, then start() will later call listener with 'error'
262299
listenerToClose = listener;
263300
error = reason;
264301
}
@@ -271,15 +308,13 @@ public void run() {
271308
}
272309
});
273310
} else {
274-
if (listenerToClose != null) {
275-
listenerToClose.closed(reason, new Metadata());
276-
}
311+
listenerToClose.closed(reason, new Metadata());
277312
drainPendingCalls();
278313
}
279314
}
280315

281316
@Override
282-
public void halfClose() {
317+
public final void halfClose() {
283318
delayOrExecute(new Runnable() {
284319
@Override
285320
public void run() {
@@ -289,7 +324,7 @@ public void run() {
289324
}
290325

291326
@Override
292-
public void request(final int numMessages) {
327+
public final void request(final int numMessages) {
293328
if (passThrough) {
294329
realStream.request(numMessages);
295330
} else {
@@ -303,7 +338,7 @@ public void run() {
303338
}
304339

305340
@Override
306-
public void setCompressor(final Compressor compressor) {
341+
public final void setCompressor(final Compressor compressor) {
307342
checkNotNull(compressor, "compressor");
308343
delayOrExecute(new Runnable() {
309344
@Override
@@ -314,7 +349,7 @@ public void run() {
314349
}
315350

316351
@Override
317-
public void setDecompressor(Decompressor decompressor) {
352+
public final void setDecompressor(Decompressor decompressor) {
318353
checkNotNull(decompressor, "decompressor");
319354
// This method being called only makes sense after setStream() has been called (but not
320355
// necessarily returned), but there is not necessarily a happens-before relationship. This
@@ -327,7 +362,7 @@ public void setDecompressor(Decompressor decompressor) {
327362
}
328363

329364
@Override
330-
public boolean isReady() {
365+
public final boolean isReady() {
331366
if (passThrough) {
332367
return realStream.isReady();
333368
} else {
@@ -336,7 +371,7 @@ public boolean isReady() {
336371
}
337372

338373
@Override
339-
public void setMessageCompression(final boolean enable) {
374+
public final void setMessageCompression(final boolean enable) {
340375
if (passThrough) {
341376
realStream.setMessageCompression(enable);
342377
} else {

core/src/test/java/io/grpc/internal/DelayedClientTransport2Test.java

+14
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ public class DelayedClientTransport2Test {
210210
@Test public void cancelStreamWithoutSetTransport() {
211211
ClientStream stream = delayedTransport.newStream(method, new Metadata());
212212
assertEquals(1, delayedTransport.getPendingStreamsCount());
213+
stream.start(streamListener);
213214
stream.cancel(Status.CANCELLED);
215+
verify(streamListener).closed(any(Status.class), any(Metadata.class));
214216
assertEquals(0, delayedTransport.getPendingStreamsCount());
215217
verifyNoMoreInteractions(mockRealTransport);
216218
verifyNoMoreInteractions(mockRealStream);
@@ -265,7 +267,9 @@ public class DelayedClientTransport2Test {
265267
verify(transportListener).transportShutdown(any(Status.class));
266268
verify(transportListener, times(0)).transportTerminated();
267269
assertEquals(1, delayedTransport.getPendingStreamsCount());
270+
stream.start(streamListener);
268271
stream.cancel(Status.CANCELLED);
272+
verify(streamListener).closed(any(Status.class), any(Metadata.class));
269273
verify(transportListener).transportTerminated();
270274
assertEquals(0, delayedTransport.getPendingStreamsCount());
271275
verifyNoMoreInteractions(mockRealTransport);
@@ -282,6 +286,16 @@ public class DelayedClientTransport2Test {
282286
assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode());
283287
}
284288

289+
@Test public void newStreamThenShutdownNow() {
290+
ClientStream stream = delayedTransport.newStream(method, new Metadata());
291+
delayedTransport.shutdownNow(Status.UNAVAILABLE);
292+
verify(transportListener).transportShutdown(any(Status.class));
293+
verify(transportListener).transportTerminated();
294+
stream.start(streamListener);
295+
verify(streamListener).closed(statusCaptor.capture(), any(Metadata.class));
296+
assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode());
297+
}
298+
285299
@Test public void startStreamThenShutdownNow() {
286300
ClientStream stream = delayedTransport.newStream(method, new Metadata());
287301
stream.start(streamListener);

core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java

+12
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ public class DelayedClientTransportTest {
226226
@Test public void cancelStreamWithoutSetTransport() {
227227
ClientStream stream = delayedTransport.newStream(method, new Metadata());
228228
assertEquals(1, delayedTransport.getPendingStreamsCount());
229+
stream.start(streamListener);
229230
stream.cancel(Status.CANCELLED);
230231
assertEquals(0, delayedTransport.getPendingStreamsCount());
231232
verifyNoMoreInteractions(mockRealTransport);
@@ -249,6 +250,7 @@ public class DelayedClientTransportTest {
249250
verify(transportListener).transportShutdown(any(Status.class));
250251
verify(transportListener, times(0)).transportTerminated();
251252
assertEquals(1, delayedTransport.getPendingStreamsCount());
253+
stream.start(streamListener);
252254
stream.cancel(Status.CANCELLED);
253255
verify(transportListener).transportTerminated();
254256
assertEquals(0, delayedTransport.getPendingStreamsCount());
@@ -291,6 +293,16 @@ public class DelayedClientTransportTest {
291293
assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode());
292294
}
293295

296+
@Test public void newStreamThenShutdownNow() {
297+
ClientStream stream = delayedTransport.newStream(method, new Metadata());
298+
delayedTransport.shutdownNow(Status.UNAVAILABLE);
299+
verify(transportListener).transportShutdown(any(Status.class));
300+
verify(transportListener).transportTerminated();
301+
stream.start(streamListener);
302+
verify(streamListener).closed(statusCaptor.capture(), any(Metadata.class));
303+
assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode());
304+
}
305+
294306
@Test public void startStreamThenShutdownNow() {
295307
ClientStream stream = delayedTransport.newStream(method, new Metadata());
296308
stream.start(streamListener);

0 commit comments

Comments
 (0)