Skip to content

Commit

Permalink
Do not shutdown output channel until client receives the full response (
Browse files Browse the repository at this point in the history
#1090)

Motivation:

When client observes "Connection: close" header it shutdowns the output
channel as soon as request is written. Some servers interpret FIN from
the client as an indicator that it lost interest in their data and
therefore, server just closes the second half of the connection asap.
As the result, connection may be closed before client receives the full
response from the server.

Modifications:

- Reproduce this scenario using a simple proxy tunnel that is not aware
of HTTP protocol semantics;
- Defer connection closure on the client-side until it completes the
full request-response iteration;
- Fail all subsequent or pipelined requests on the connection that moves
to the "closing" state;
- Adjust `RequestResponseCloseHandlerTest` to not expecting outbound
half-closure on the client side;
- Add more tests to verify that client handles "Connection: close" header
correctly;

Result:

Response is not aborted when client observes "Connection: close" header.
  • Loading branch information
idelpivnitskiy authored Jul 1, 2020
1 parent ef94190 commit 39fab80
Show file tree
Hide file tree
Showing 8 changed files with 489 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import static io.servicetalk.concurrent.api.Single.collectUnordered;
import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer;
import static io.servicetalk.http.netty.HttpProtocolConfigs.h1;
import static java.net.InetAddress.getLoopbackAddress;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.util.concurrent.TimeUnit.SECONDS;

Expand Down Expand Up @@ -76,7 +77,7 @@ public static void afterClass() {
@Before
public void startServer() throws Exception {
assert executor != null;
serverSocket = new ServerSocket(0);
serverSocket = new ServerSocket(0, 50, getLoopbackAddress());

executor.submit(() -> {
while (!executor.isShutdown()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
/*
* Copyright © 2020 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.http.netty;

import io.servicetalk.buffer.api.Buffer;
import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout;
import io.servicetalk.http.api.HttpPayloadWriter;
import io.servicetalk.http.api.ReservedStreamingHttpConnection;
import io.servicetalk.http.api.StreamingHttpClient;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.transport.api.IoExecutor;
import io.servicetalk.transport.api.ServerContext;
import io.servicetalk.transport.netty.internal.IoThreadFactory;

import org.junit.After;
import org.junit.Rule;
import org.junit.Test;

import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable;
import static io.servicetalk.concurrent.api.Completable.never;
import static io.servicetalk.concurrent.api.Publisher.from;
import static io.servicetalk.http.api.HttpHeaderNames.CONNECTION;
import static io.servicetalk.http.api.HttpHeaderNames.CONTENT_LENGTH;
import static io.servicetalk.http.api.HttpHeaderValues.CLOSE;
import static io.servicetalk.http.api.HttpHeaderValues.ZERO;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer;
import static io.servicetalk.http.api.Matchers.contentEqualTo;
import static io.servicetalk.transport.netty.NettyIoExecutors.createIoExecutor;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static io.servicetalk.utils.internal.PlatformDependent.throwException;
import static java.lang.String.valueOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThrows;

public class ConnectionCloseHeaderHandlingTest {

@Rule
public final ServiceTalkTestTimeout timeout = new ServiceTalkTestTimeout();

private final IoExecutor serverIoExecutor;
private final ServerContext serverContext;
private final StreamingHttpClient client;
private final ReservedStreamingHttpConnection connection;

private final CountDownLatch sendResponse = new CountDownLatch(1);
private final CountDownLatch responseReceived = new CountDownLatch(1);
private final CountDownLatch requestReceived = new CountDownLatch(1);
private final CountDownLatch connectionClosed = new CountDownLatch(1);
private final AtomicInteger requestPayloadSize = new AtomicInteger();

public ConnectionCloseHeaderHandlingTest() throws Exception {
serverIoExecutor = createIoExecutor(new IoThreadFactory("server-io-executor"));
serverContext = HttpServers.forAddress(localAddress(0))
.ioExecutor(serverIoExecutor)
.listenBlockingStreamingAndAwait((ctx, request, response) -> {
requestReceived.countDown();
String content = "server_content";
response.addHeader(CONTENT_LENGTH, valueOf(content.length()))
.addHeader(CONNECTION, CLOSE);

sendResponse.await();
try (HttpPayloadWriter<String> writer = response.sendMetaData(textSerializer())) {
// Defer payload body to see how client processes "Connection: close" header
request.payloadBody().forEach(chunk -> requestPayloadSize.addAndGet(chunk.readableBytes()));
responseReceived.await();
writer.write(content);
}
});

client = HttpClients.forSingleAddress(serverHostAndPort(serverContext))
.buildStreaming();
connection = client.reserveConnection(client.get("/")).toFuture().get();
connection.onClose().whenFinally(connectionClosed::countDown).subscribe();
}

@After
public void tearDown() throws Exception {
newCompositeCloseable().appendAll(client, serverContext, serverIoExecutor).close();
}

@Test
public void serverCloseNoRequestPayloadBody() throws Exception {
sendRequestAndAssertResponse(connection.get("/first")
.addHeader(CONTENT_LENGTH, ZERO));
}

@Test
public void serverCloseRequestWithPayloadBody() throws Exception {
String content = "request_content";
sendRequestAndAssertResponse(connection.post("/first")
.addHeader(CONTENT_LENGTH, valueOf(content.length()))
.payloadBody(client.executionContext().executor().submit(() -> {
try {
responseReceived.await();
} catch (InterruptedException e) {
throwException(e);
}
}).concat(from(content)), textSerializer()));
}

private void sendRequestAndAssertResponse(StreamingHttpRequest request) throws Exception {
sendResponse.countDown();
StreamingHttpResponse response = connection.request(request).toFuture().get();
assertResponse(response);
responseReceived.countDown();

assertResponsePayloadBody(response);
assertThat(request.headers().get(CONTENT_LENGTH), contentEqualTo(valueOf(requestPayloadSize.get())));

connectionClosed.await();
assertClosedChannelException("/second");
}

@Test
public void serverCloseTwoPipelinedRequestsSentBeforeFirstResponse() throws Exception {
AtomicReference<StreamingHttpResponse> firstResponse = new AtomicReference<>();
AtomicReference<Throwable> secondRequestError = new AtomicReference<>();
CountDownLatch secondResponseReceived = new CountDownLatch(1);

connection.request(connection.get("/first")
.addHeader(CONTENT_LENGTH, ZERO)).subscribe(first -> {
firstResponse.set(first);
responseReceived.countDown();
});
connection.request(connection.get("/second")
.addHeader(CONTENT_LENGTH, ZERO))
.whenOnError(secondRequestError::set)
.whenFinally(secondResponseReceived::countDown)
.subscribe(second -> { });
requestReceived.await();
sendResponse.countDown();
responseReceived.await();

StreamingHttpResponse response = firstResponse.get();
assertResponse(response);
assertResponsePayloadBody(response);

connectionClosed.await();
secondResponseReceived.await();
assertThat(secondRequestError.get(), instanceOf(ClosedChannelException.class));
assertClosedChannelException("/third");
}

@Test
public void serverCloseSecondPipelinedRequestWriteAborted() throws Exception {
AtomicReference<StreamingHttpResponse> firstResponse = new AtomicReference<>();
AtomicReference<Throwable> secondRequestError = new AtomicReference<>();
CountDownLatch secondResponseReceived = new CountDownLatch(1);

connection.request(connection.get("/first")
.addHeader(CONTENT_LENGTH, ZERO)).subscribe(first -> {
firstResponse.set(first);
responseReceived.countDown();
});
String content = "request_content";
connection.request(connection.get("/second")
.addHeader(CONTENT_LENGTH, valueOf(content.length()))
.payloadBody(from(content).concat(never()), textSerializer()))
.whenOnError(secondRequestError::set)
.whenFinally(secondResponseReceived::countDown)
.subscribe(second -> { });
requestReceived.await();
sendResponse.countDown();
responseReceived.await();

StreamingHttpResponse response = firstResponse.get();
assertResponse(response);
assertResponsePayloadBody(response);

connectionClosed.await();
secondResponseReceived.await();
assertThat(secondRequestError.get(), instanceOf(ClosedChannelException.class));
assertClosedChannelException("/third");
}

@Test
public void serverCloseTwoPipelinedRequestsInSequence() throws Exception {
sendResponse.countDown();
StreamingHttpResponse response = connection.request(connection.get("/first")
.addHeader(CONTENT_LENGTH, ZERO)).toFuture().get();
assertResponse(response);

// Send another request before client reads payload body of the first request:
assertClosedChannelException("/second");

responseReceived.countDown();
assertResponsePayloadBody(response);
connectionClosed.await();
}

@Test
public void clientCloseTwoPipelinedRequestsSentBeforeFirstResponse() throws Exception {
AtomicReference<StreamingHttpResponse> firstResponse = new AtomicReference<>();

connection.request(connection.get("/first")
.addHeader(CONTENT_LENGTH, ZERO)
// Request connection closure:
.addHeader(CONNECTION, CLOSE)).subscribe(first -> {
firstResponse.set(first);
responseReceived.countDown();
});
// Send another request before client receives a response for the first request:
assertClosedChannelException("/second");
sendResponse.countDown();
responseReceived.await();

StreamingHttpResponse response = firstResponse.get();
assertResponse(response);
assertResponsePayloadBody(response);
connectionClosed.await();
}

private static void assertResponse(StreamingHttpResponse response) {
assertThat(response.status(), is(OK));
assertThat(response.headers().get(CONNECTION), contentEqualTo(CLOSE));
}

private static void assertResponsePayloadBody(StreamingHttpResponse response) throws Exception {
int actualContentLength = response.payloadBody().map(Buffer::readableBytes)
.collect(AtomicInteger::new, (total, current) -> {
total.addAndGet(current);
return total;
}).toFuture().get().get();
assertThat(response.headers().get(CONTENT_LENGTH), contentEqualTo(valueOf(actualContentLength)));
}

private void assertClosedChannelException(String path) {
Exception e = assertThrows(ExecutionException.class,
() -> connection.request(connection.get(path).addHeader(CONTENT_LENGTH, ZERO)).toFuture().get());
assertThat(e.getCause(), instanceOf(ClosedChannelException.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,48 @@
package io.servicetalk.http.netty;

import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout;
import io.servicetalk.http.api.BlockingHttpClient;
import io.servicetalk.http.api.HttpClient;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.ServerContext;

import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;

import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer;
import static io.servicetalk.http.netty.HttpsProxyTest.safeClose;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;

public class HttpProxyTest {

@Rule
public final ServiceTalkTestTimeout timeout = new ServiceTalkTestTimeout();

private int proxyPort;
private int serverPort;
private HttpClient client;
@Nullable
private HttpClient proxyClient;
@Nullable
private ServerContext proxyContext;
@Nullable
private HostAndPort proxyAddress;
@Nullable
private ServerContext serverContext;
@Nullable
private HostAndPort serverAddress;
@Nullable
private BlockingHttpClient client;
private final AtomicInteger proxyRequestCount = new AtomicInteger();

@Before
Expand All @@ -53,33 +67,43 @@ public void setup() throws Exception {
createClient();
}

@After
public void tearDown() throws Exception {
safeClose(client);
safeClose(proxyClient);
safeClose(proxyContext);
safeClose(serverContext);
}

public void startProxy() throws Exception {
final HttpClient proxyClient = HttpClients.forMultiAddressUrl().build();
final ServerContext serverContext = HttpServers.forAddress(localAddress(0))
proxyClient = HttpClients.forMultiAddressUrl().build();
proxyContext = HttpServers.forAddress(localAddress(0))
.listenAndAwait((ctx, request, responseFactory) -> {
proxyRequestCount.incrementAndGet();
return proxyClient.request(request);
});
proxyPort = serverHostAndPort(serverContext).port();
proxyAddress = serverHostAndPort(proxyContext);
}

public void startServer() throws Exception {
final ServerContext serverContext = HttpServers.forAddress(localAddress(0))
serverContext = HttpServers.forAddress(localAddress(0))
.listenAndAwait((ctx, request, responseFactory) -> succeeded(responseFactory.ok()
.payloadBody("host: " + request.headers().get(HOST), textSerializer())));
serverPort = serverHostAndPort(serverContext).port();
serverAddress = serverHostAndPort(serverContext);
}

public void createClient() {
client = HttpClients.forSingleAddressViaProxy("localhost", serverPort, "localhost", proxyPort)
.build();
assert serverAddress != null && proxyAddress != null;
client = HttpClients.forSingleAddressViaProxy(serverAddress, proxyAddress)
.buildBlocking();
}

@Test
public void testRequest() throws Exception {
final HttpResponse httpResponse = client.request(client.get("/path")).toFuture().get();
assert client != null;
final HttpResponse httpResponse = client.request(client.get("/path"));
assertThat(httpResponse.status(), is(OK));
assertThat(proxyRequestCount.get(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: localhost:" + serverPort));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
}
}
Loading

0 comments on commit 39fab80

Please sign in to comment.