From 1561bbaeaecfd5ef790d47cb2dfcdc435ef26d66 Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Thu, 21 Dec 2023 06:25:40 -0500 Subject: [PATCH] Convert some tests to Kotlin (#8152) * Rename .java to .kt without conversion * Convert to Kotlin --- .../kotlin/okhttp3/RecordingEventListener.kt | 2 + okhttp/src/test/java/okhttp3/CacheTest.java | 3014 --------------- okhttp/src/test/java/okhttp3/CacheTest.kt | 3432 +++++++++++++++++ .../okhttp3/CertificateChainCleanerTest.java | 293 -- .../okhttp3/CertificateChainCleanerTest.kt | 307 ++ .../java/okhttp3/CertificatePinnerTest.java | 320 -- .../java/okhttp3/CertificatePinnerTest.kt | 333 ++ .../okhttp3/ConnectionCoalescingTest.java | 556 --- .../java/okhttp3/ConnectionCoalescingTest.kt | 534 +++ .../test/java/okhttp3/ConnectionSpecTest.java | 379 -- .../test/java/okhttp3/ConnectionSpecTest.kt | 396 ++ .../src/test/java/okhttp3/DispatcherTest.java | 333 -- .../src/test/java/okhttp3/DispatcherTest.kt | 350 ++ okhttp/src/test/java/okhttp3/DuplexTest.java | 751 ---- okhttp/src/test/java/okhttp3/DuplexTest.kt | 761 ++++ .../test/java/okhttp3/EventListenerTest.java | 1657 -------- .../test/java/okhttp3/EventListenerTest.kt | 1850 +++++++++ .../test/java/okhttp3/InterceptorTest.java | 918 ----- .../src/test/java/okhttp3/InterceptorTest.kt | 854 ++++ .../test/java/okhttp3/MultipartBodyTest.java | 281 -- .../test/java/okhttp3/MultipartBodyTest.kt | 301 ++ .../okhttp3/WholeOperationTimeoutTest.java | 348 -- .../java/okhttp3/WholeOperationTimeoutTest.kt | 362 ++ .../internal/cache2/FileOperatorTest.java | 198 - .../internal/cache2/FileOperatorTest.kt | 209 + .../okhttp3/internal/cache2/RelayTest.java | 251 -- .../java/okhttp3/internal/cache2/RelayTest.kt | 233 ++ .../internal/http2/BaseTestHandler.java | 74 - .../okhttp3/internal/http2/BaseTestHandler.kt | 109 + .../okhttp3/internal/http2/HpackTest.java | 1107 ------ .../java/okhttp3/internal/http2/HpackTest.kt | 1103 ++++++ .../CertificatePinnerChainValidationTest.java | 632 --- .../CertificatePinnerChainValidationTest.kt | 653 ++++ .../okhttp3/internal/tls/ClientAuthTest.java | 374 -- .../okhttp3/internal/tls/ClientAuthTest.kt | 364 ++ .../internal/ws/RealWebSocketTest.java | 490 --- .../okhttp3/internal/ws/RealWebSocketTest.kt | 497 +++ .../internal/ws/WebSocketHttpTest.java | 1100 ------ .../okhttp3/internal/ws/WebSocketHttpTest.kt | 1143 ++++++ 39 files changed, 13793 insertions(+), 13076 deletions(-) delete mode 100644 okhttp/src/test/java/okhttp3/CacheTest.java create mode 100644 okhttp/src/test/java/okhttp3/CacheTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/CertificateChainCleanerTest.java create mode 100644 okhttp/src/test/java/okhttp3/CertificateChainCleanerTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/CertificatePinnerTest.java create mode 100644 okhttp/src/test/java/okhttp3/CertificatePinnerTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/ConnectionCoalescingTest.java create mode 100644 okhttp/src/test/java/okhttp3/ConnectionCoalescingTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/ConnectionSpecTest.java create mode 100644 okhttp/src/test/java/okhttp3/ConnectionSpecTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/DispatcherTest.java create mode 100644 okhttp/src/test/java/okhttp3/DispatcherTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/DuplexTest.java create mode 100644 okhttp/src/test/java/okhttp3/DuplexTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/EventListenerTest.java create mode 100644 okhttp/src/test/java/okhttp3/EventListenerTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/InterceptorTest.java create mode 100644 okhttp/src/test/java/okhttp3/InterceptorTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/MultipartBodyTest.java create mode 100644 okhttp/src/test/java/okhttp3/MultipartBodyTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/WholeOperationTimeoutTest.java create mode 100644 okhttp/src/test/java/okhttp3/WholeOperationTimeoutTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/internal/cache2/FileOperatorTest.java create mode 100644 okhttp/src/test/java/okhttp3/internal/cache2/FileOperatorTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/internal/cache2/RelayTest.java create mode 100644 okhttp/src/test/java/okhttp3/internal/cache2/RelayTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/internal/http2/BaseTestHandler.java create mode 100644 okhttp/src/test/java/okhttp3/internal/http2/BaseTestHandler.kt delete mode 100644 okhttp/src/test/java/okhttp3/internal/http2/HpackTest.java create mode 100644 okhttp/src/test/java/okhttp3/internal/http2/HpackTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/internal/tls/CertificatePinnerChainValidationTest.java create mode 100644 okhttp/src/test/java/okhttp3/internal/tls/CertificatePinnerChainValidationTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/internal/tls/ClientAuthTest.java create mode 100644 okhttp/src/test/java/okhttp3/internal/tls/ClientAuthTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/internal/ws/RealWebSocketTest.java create mode 100644 okhttp/src/test/java/okhttp3/internal/ws/RealWebSocketTest.kt delete mode 100644 okhttp/src/test/java/okhttp3/internal/ws/WebSocketHttpTest.java create mode 100644 okhttp/src/test/java/okhttp3/internal/ws/WebSocketHttpTest.kt diff --git a/okhttp-testing-support/src/main/kotlin/okhttp3/RecordingEventListener.kt b/okhttp-testing-support/src/main/kotlin/okhttp3/RecordingEventListener.kt index 013b4ac26168..14f858ac0a20 100644 --- a/okhttp-testing-support/src/main/kotlin/okhttp3/RecordingEventListener.kt +++ b/okhttp-testing-support/src/main/kotlin/okhttp3/RecordingEventListener.kt @@ -95,6 +95,8 @@ open class RecordingEventListener( } } + inline fun removeUpToEvent(): T = removeUpToEvent(T::class.java) + /** * Remove and return the next event from the recorded sequence. * diff --git a/okhttp/src/test/java/okhttp3/CacheTest.java b/okhttp/src/test/java/okhttp3/CacheTest.java deleted file mode 100644 index 2dceb96cf480..000000000000 --- a/okhttp/src/test/java/okhttp3/CacheTest.java +++ /dev/null @@ -1,3014 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package okhttp3; - -import java.io.IOException; -import java.net.CookieManager; -import java.net.HttpURLConnection; -import java.net.ResponseCache; -import java.security.Principal; -import java.security.cert.Certificate; -import java.text.DateFormat; -import java.text.SimpleDateFormat; -import java.util.ArrayList; -import java.util.Date; -import java.util.Iterator; -import java.util.List; -import java.util.Locale; -import java.util.NoSuchElementException; -import java.util.TimeZone; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import javax.net.ssl.HostnameVerifier; -import mockwebserver3.MockResponse; -import mockwebserver3.MockWebServer; -import mockwebserver3.RecordedRequest; -import mockwebserver3.junit5.internal.MockWebServerInstance; -import okhttp3.internal.Internal; -import okhttp3.internal.platform.Platform; -import okhttp3.java.net.cookiejar.JavaNetCookieJar; -import okhttp3.testing.PlatformRule; -import okhttp3.tls.HandshakeCertificates; -import okio.Buffer; -import okio.BufferedSink; -import okio.BufferedSource; -import okio.FileSystem; -import okio.ForwardingFileSystem; -import okio.GzipSink; -import okio.Okio; -import okio.Path; -import okio.fakefilesystem.FakeFileSystem; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; -import static mockwebserver3.SocketPolicy.DisconnectAtEnd; -import static okhttp3.internal.Internal.cacheGet; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.data.Offset.offset; -import static org.junit.jupiter.api.Assertions.fail; - -@Tag("Slow") -public final class CacheTest { - private static final HostnameVerifier NULL_HOSTNAME_VERIFIER = (name, session) -> true; - - public final FakeFileSystem fileSystem = new FakeFileSystem(); - @RegisterExtension public final OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - @RegisterExtension public final PlatformRule platform = new PlatformRule(); - - private MockWebServer server; - private MockWebServer server2; - private final HandshakeCertificates handshakeCertificates - = platform.localhostHandshakeCertificates(); - private OkHttpClient client; - private Cache cache; - private final CookieManager cookieManager = new CookieManager(); - - @BeforeEach - public void setUp( - @MockWebServerInstance(name = "1") MockWebServer server, - @MockWebServerInstance(name = "2") MockWebServer server2 - ) throws Exception { - this.server = server; - this.server2 = server2; - - platform.assumeNotOpenJSSE(); - - server.setProtocolNegotiationEnabled(false); - fileSystem.emulateUnix(); - cache = new Cache(Path.get("/cache/"), Integer.MAX_VALUE, fileSystem); - client = clientTestRule.newClientBuilder() - .cache(cache) - .cookieJar(new JavaNetCookieJar(cookieManager)) - .build(); - } - - @AfterEach public void tearDown() throws Exception { - ResponseCache.setDefault(null); - - if (cache != null) { - cache.delete(); - } - } - - /** - * Test that response caching is consistent with the RI and the spec. - * http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.4 - */ - @Test public void responseCachingByResponseCode() throws Exception { - // Test each documented HTTP/1.1 code, plus the first unused value in each range. - // http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html - - // We can't test 100 because it's not really a response. - // assertCached(false, 100); - assertCached(false, 101); - assertCached(true, 200); - assertCached(false, 201); - assertCached(false, 202); - assertCached(true, 203); - assertCached(true, 204); - assertCached(false, 205); - assertCached(false, 206); //Electing to not cache partial responses - assertCached(false, 207); - assertCached(true, 300); - assertCached(true, 301); - assertCached(true, 302); - assertCached(false, 303); - assertCached(false, 304); - assertCached(false, 305); - assertCached(false, 306); - assertCached(true, 307); - assertCached(true, 308); - assertCached(false, 400); - assertCached(false, 401); - assertCached(false, 402); - assertCached(false, 403); - assertCached(true, 404); - assertCached(true, 405); - assertCached(false, 406); - assertCached(false, 408); - assertCached(false, 409); - // the HTTP spec permits caching 410s, but the RI doesn't. - assertCached(true, 410); - assertCached(false, 411); - assertCached(false, 412); - assertCached(false, 413); - assertCached(true, 414); - assertCached(false, 415); - assertCached(false, 416); - assertCached(false, 417); - assertCached(false, 418); - - assertCached(false, 500); - assertCached(true, 501); - assertCached(false, 502); - assertCached(false, 503); - assertCached(false, 504); - assertCached(false, 505); - assertCached(false, 506); - } - - @Test public void responseCachingWith1xxInformationalResponse() throws Exception { - assertSubsequentResponseCached( 102, 200); - assertSubsequentResponseCached( 103, 200); - } - - private void assertCached(boolean shouldWriteToCache, int responseCode) throws Exception { - int expectedResponseCode = responseCode; - - server = new MockWebServer(); - MockResponse.Builder builder = new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .code(responseCode) - .body("ABCDE") - .addHeader("WWW-Authenticate: challenge"); - if (responseCode == HttpURLConnection.HTTP_PROXY_AUTH) { - builder.addHeader("Proxy-Authenticate: Basic realm=\"protected area\""); - } else if (responseCode == HttpURLConnection.HTTP_UNAUTHORIZED) { - builder.addHeader("WWW-Authenticate: Basic realm=\"protected area\""); - } else if (responseCode == HttpURLConnection.HTTP_NO_CONTENT - || responseCode == HttpURLConnection.HTTP_RESET) { - builder.body(""); // We forbid bodies for 204 and 205. - } - server.enqueue(builder.build()); - - if (responseCode == HttpURLConnection.HTTP_CLIENT_TIMEOUT) { - // 408's are a bit of an outlier because we may repeat the request if we encounter this - // response code. In this scenario, there are 2 responses: the initial 408 and then the 200 - // because of the retry. We just want to ensure the initial 408 isn't cached. - expectedResponseCode = 200; - server.enqueue(new MockResponse.Builder() - .setHeader("Cache-Control", "no-store") - .body("FGHIJ") - .build()); - } - - server.start(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.code()).isEqualTo(expectedResponseCode); - - // Exhaust the content stream. - response.body().string(); - - Response cached = cacheGet(cache, request); - if (shouldWriteToCache) { - assertThat(cached).isNotNull(); - cached.body().close(); - } else { - assertThat(cached).isNull(); - } - server.shutdown(); // tearDown() isn't sufficient; this test starts multiple servers - } - - private void assertSubsequentResponseCached(int initialResponseCode, int finalResponseCode) throws Exception { - server = new MockWebServer(); - MockResponse.Builder builder = new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .code(finalResponseCode) - .body("ABCDE") - .addInformationalResponse(new MockResponse(initialResponseCode)); - - server.enqueue(builder.build()); - server.start(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.code()).isEqualTo(finalResponseCode); - - // Exhaust the content stream. - response.body().string(); - - Response cached = cacheGet(cache, request); - assertThat(cached).isNotNull(); - cached.body().close(); - server.shutdown(); // tearDown() isn't sufficient; this test starts multiple servers - } - - @Test public void responseCachingAndInputStreamSkipWithFixedLength() throws IOException { - testResponseCaching(TransferKind.FIXED_LENGTH); - } - - @Test public void responseCachingAndInputStreamSkipWithChunkedEncoding() throws IOException { - testResponseCaching(TransferKind.CHUNKED); - } - - @Test public void responseCachingAndInputStreamSkipWithNoLengthHeaders() throws IOException { - testResponseCaching(TransferKind.END_OF_STREAM); - } - - /** - * Skipping bytes in the input stream caused ResponseCache corruption. - * http://code.google.com/p/android/issues/detail?id=8175 - */ - private void testResponseCaching(TransferKind transferKind) throws IOException { - MockResponse.Builder mockResponse = new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .status("HTTP/1.1 200 Fantastic"); - transferKind.setBody(mockResponse, "I love puppies but hate spiders", 1); - server.enqueue(mockResponse.build()); - - // Make sure that calling skip() doesn't omit bytes from the cache. - Request request = new Request.Builder().url(server.url("/")).build(); - Response response1 = client.newCall(request).execute(); - - BufferedSource in1 = response1.body().source(); - assertThat(in1.readUtf8("I love ".length())).isEqualTo("I love "); - in1.skip("puppies but hate ".length()); - assertThat(in1.readUtf8("spiders".length())).isEqualTo("spiders"); - assertThat(in1.exhausted()).isTrue(); - in1.close(); - assertThat(cache.writeSuccessCount()).isEqualTo(1); - assertThat(cache.writeAbortCount()).isEqualTo(0); - - Response response2 = client.newCall(request).execute(); - BufferedSource in2 = response2.body().source(); - assertThat(in2.readUtf8("I love puppies but hate spiders".length())).isEqualTo( - "I love puppies but hate spiders"); - assertThat(response2.code()).isEqualTo(200); - assertThat(response2.message()).isEqualTo("Fantastic"); - - assertThat(in2.exhausted()).isTrue(); - in2.close(); - assertThat(cache.writeSuccessCount()).isEqualTo(1); - assertThat(cache.writeAbortCount()).isEqualTo(0); - assertThat(cache.requestCount()).isEqualTo(2); - assertThat(cache.hitCount()).isEqualTo(1); - } - - @Test public void secureResponseCaching() throws IOException { - server.useHttps(handshakeCertificates.sslSocketFactory()); - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .body("ABC") - .build()); - - client = client.newBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(NULL_HOSTNAME_VERIFIER) - .build(); - - Request request = new Request.Builder().url(server.url("/")).build(); - Response response1 = client.newCall(request).execute(); - BufferedSource in = response1.body().source(); - assertThat(in.readUtf8()).isEqualTo("ABC"); - - // OpenJDK 6 fails on this line, complaining that the connection isn't open yet - CipherSuite cipherSuite = response1.handshake().cipherSuite(); - List localCerts = response1.handshake().localCertificates(); - List serverCerts = response1.handshake().peerCertificates(); - Principal peerPrincipal = response1.handshake().peerPrincipal(); - Principal localPrincipal = response1.handshake().localPrincipal(); - - Response response2 = client.newCall(request).execute(); // Cached! - assertThat(response2.body().string()).isEqualTo("ABC"); - - assertThat(cache.requestCount()).isEqualTo(2); - assertThat(cache.networkCount()).isEqualTo(1); - assertThat(cache.hitCount()).isEqualTo(1); - - assertThat(response2.handshake().cipherSuite()).isEqualTo(cipherSuite); - assertThat(response2.handshake().localCertificates()).isEqualTo(localCerts); - assertThat(response2.handshake().peerCertificates()).isEqualTo(serverCerts); - assertThat(response2.handshake().peerPrincipal()).isEqualTo(peerPrincipal); - assertThat(response2.handshake().localPrincipal()).isEqualTo(localPrincipal); - } - - @Test public void secureResponseCachingWithCorruption() throws IOException { - server.useHttps(handshakeCertificates.sslSocketFactory()); - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .body("ABC") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-5, TimeUnit.MINUTES)) - .addHeader("Expires: " + formatDate(2, TimeUnit.HOURS)) - .body("DEF") - .build()); - - client = client.newBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(NULL_HOSTNAME_VERIFIER) - .build(); - - Request request = new Request.Builder().url(server.url("/")).build(); - Response response1 = client.newCall(request).execute(); - assertThat(response1.body().string()).isEqualTo("ABC"); - - Path cacheEntry = fileSystem.allPaths().stream() - .filter((e) -> e.name().endsWith(".0")) - .findFirst() - .orElseThrow(NoSuchElementException::new); - corruptCertificate(cacheEntry); - - Response response2 = client.newCall(request).execute(); // Not Cached! - assertThat(response2.body().string()).isEqualTo("DEF"); - - assertThat(cache.requestCount()).isEqualTo(2); - assertThat(cache.networkCount()).isEqualTo(2); - assertThat(cache.hitCount()).isEqualTo(0); - } - - private void corruptCertificate(Path cacheEntry) throws IOException { - String content = Okio.buffer(fileSystem.source(cacheEntry)).readUtf8(); - content = content.replace("MII", "!!!"); - Okio.buffer(fileSystem.sink(cacheEntry)).writeUtf8(content).close(); - } - - @Test public void responseCachingAndRedirects() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .code(HttpURLConnection.HTTP_MOVED_PERM) - .addHeader("Location: /foo") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .body("ABC") - .build()); - server.enqueue(new MockResponse.Builder() - .body("DEF") - .build()); - - Request request = new Request.Builder().url(server.url("/")).build(); - Response response1 = client.newCall(request).execute(); - assertThat(response1.body().string()).isEqualTo("ABC"); - - Response response2 = client.newCall(request).execute(); // Cached! - assertThat(response2.body().string()).isEqualTo("ABC"); - - // 2 requests + 2 redirects - assertThat(cache.requestCount()).isEqualTo(4); - assertThat(cache.networkCount()).isEqualTo(2); - assertThat(cache.hitCount()).isEqualTo(2); - } - - @Test public void redirectToCachedResult() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .body("ABC") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_MOVED_PERM) - .addHeader("Location: /foo") - .build()); - server.enqueue(new MockResponse.Builder() - .body("DEF") - .build()); - - Request request1 = new Request.Builder().url(server.url("/foo")).build(); - Response response1 = client.newCall(request1).execute(); - assertThat(response1.body().string()).isEqualTo("ABC"); - RecordedRequest recordedRequest1 = server.takeRequest(); - assertThat(recordedRequest1.getRequestLine()).isEqualTo("GET /foo HTTP/1.1"); - assertThat(recordedRequest1.getSequenceNumber()).isEqualTo(0); - - Request request2 = new Request.Builder().url(server.url("/bar")).build(); - Response response2 = client.newCall(request2).execute(); - assertThat(response2.body().string()).isEqualTo("ABC"); - RecordedRequest recordedRequest2 = server.takeRequest(); - assertThat(recordedRequest2.getRequestLine()).isEqualTo("GET /bar HTTP/1.1"); - assertThat(recordedRequest2.getSequenceNumber()).isEqualTo(1); - - // an unrelated request should reuse the pooled connection - Request request3 = new Request.Builder().url(server.url("/baz")).build(); - Response response3 = client.newCall(request3).execute(); - assertThat(response3.body().string()).isEqualTo("DEF"); - RecordedRequest recordedRequest3 = server.takeRequest(); - assertThat(recordedRequest3.getRequestLine()).isEqualTo("GET /baz HTTP/1.1"); - assertThat(recordedRequest3.getSequenceNumber()).isEqualTo(2); - } - - @Test public void secureResponseCachingAndRedirects() throws IOException { - server.useHttps(handshakeCertificates.sslSocketFactory()); - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .code(HttpURLConnection.HTTP_MOVED_PERM) - .addHeader("Location: /foo") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .body("ABC") - .build()); - server.enqueue(new MockResponse.Builder() - .body("DEF") - .build()); - - client = client.newBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(NULL_HOSTNAME_VERIFIER) - .build(); - - Response response1 = get(server.url("/")); - assertThat(response1.body().string()).isEqualTo("ABC"); - assertThat(response1.handshake().cipherSuite()).isNotNull(); - - // Cached! - Response response2 = get(server.url("/")); - assertThat(response2.body().string()).isEqualTo("ABC"); - assertThat(response2.handshake().cipherSuite()).isNotNull(); - - // 2 direct + 2 redirect = 4 - assertThat(cache.requestCount()).isEqualTo(4); - assertThat(cache.hitCount()).isEqualTo(2); - assertThat(response2.handshake().cipherSuite()).isEqualTo( - response1.handshake().cipherSuite()); - } - - /** - * We've had bugs where caching and cross-protocol redirects yield class cast exceptions internal - * to the cache because we incorrectly assumed that HttpsURLConnection was always HTTPS and - * HttpURLConnection was always HTTP; in practice redirects mean that each can do either. - * - * https://github.com/square/okhttp/issues/214 - */ - @Test public void secureResponseCachingAndProtocolRedirects() throws IOException { - server2.useHttps(handshakeCertificates.sslSocketFactory()); - server2.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .body("ABC") - .build()); - server2.enqueue(new MockResponse.Builder() - .body("DEF") - .build()); - - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .code(HttpURLConnection.HTTP_MOVED_PERM) - .addHeader("Location: " + server2.url("/")) - .build()); - - client = client.newBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(NULL_HOSTNAME_VERIFIER) - .build(); - - Response response1 = get(server.url("/")); - assertThat(response1.body().string()).isEqualTo("ABC"); - - // Cached! - Response response2 = get(server.url("/")); - assertThat(response2.body().string()).isEqualTo("ABC"); - - // 2 direct + 2 redirect = 4 - assertThat(cache.requestCount()).isEqualTo(4); - assertThat(cache.hitCount()).isEqualTo(2); - } - - @Test public void foundCachedWithExpiresHeader() throws Exception { - temporaryRedirectCachedWithCachingHeader(302, "Expires", formatDate(1, TimeUnit.HOURS)); - } - - @Test public void foundCachedWithCacheControlHeader() throws Exception { - temporaryRedirectCachedWithCachingHeader(302, "Cache-Control", "max-age=60"); - } - - @Test public void temporaryRedirectCachedWithExpiresHeader() throws Exception { - temporaryRedirectCachedWithCachingHeader(307, "Expires", formatDate(1, TimeUnit.HOURS)); - } - - @Test public void temporaryRedirectCachedWithCacheControlHeader() throws Exception { - temporaryRedirectCachedWithCachingHeader(307, "Cache-Control", "max-age=60"); - } - - @Test public void foundNotCachedWithoutCacheHeader() throws Exception { - temporaryRedirectNotCachedWithoutCachingHeader(302); - } - - @Test public void temporaryRedirectNotCachedWithoutCacheHeader() throws Exception { - temporaryRedirectNotCachedWithoutCachingHeader(307); - } - - private void temporaryRedirectCachedWithCachingHeader( - int responseCode, String headerName, String headerValue) throws Exception { - server.enqueue(new MockResponse.Builder() - .code(responseCode) - .addHeader(headerName, headerValue) - .addHeader("Location", "/a") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader(headerName, headerValue) - .body("a") - .build()); - server.enqueue(new MockResponse.Builder() - .body("b") - .build()); - server.enqueue(new MockResponse.Builder() - .body("c") - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("a"); - assertThat(get(url).body().string()).isEqualTo("a"); - } - - private void temporaryRedirectNotCachedWithoutCachingHeader(int responseCode) throws Exception { - server.enqueue(new MockResponse.Builder() - .code(responseCode) - .addHeader("Location", "/a") - .build()); - server.enqueue(new MockResponse.Builder() - .body("a") - .build()); - server.enqueue(new MockResponse.Builder() - .body("b") - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("a"); - assertThat(get(url).body().string()).isEqualTo("b"); - } - - /** https://github.com/square/okhttp/issues/2198 */ - @Test public void cachedRedirect() throws IOException { - server.enqueue(new MockResponse.Builder() - .code(301) - .addHeader("Cache-Control: max-age=60") - .addHeader("Location: /bar") - .build()); - server.enqueue(new MockResponse.Builder() - .body("ABC") - .build()); - server.enqueue(new MockResponse.Builder() - .body("ABC") - .build()); - - Request request1 = new Request.Builder().url(server.url("/")).build(); - Response response1 = client.newCall(request1).execute(); - assertThat(response1.body().string()).isEqualTo("ABC"); - - Request request2 = new Request.Builder().url(server.url("/")).build(); - Response response2 = client.newCall(request2).execute(); - assertThat(response2.body().string()).isEqualTo("ABC"); - } - - @Test public void serverDisconnectsPrematurelyWithContentLengthHeader() throws IOException { - testServerPrematureDisconnect(TransferKind.FIXED_LENGTH); - } - - @Test public void serverDisconnectsPrematurelyWithChunkedEncoding() throws IOException { - testServerPrematureDisconnect(TransferKind.CHUNKED); - } - - @Test public void serverDisconnectsPrematurelyWithNoLengthHeaders() throws IOException { - // Intentionally empty. This case doesn't make sense because there's no - // such thing as a premature disconnect when the disconnect itself - // indicates the end of the data stream. - } - - private void testServerPrematureDisconnect(TransferKind transferKind) throws IOException { - MockResponse.Builder mockResponse = new MockResponse.Builder(); - transferKind.setBody(mockResponse, "ABCDE\nFGHIJKLMNOPQRSTUVWXYZ", 16); - server.enqueue(truncateViolently(mockResponse, 16).build()); - server.enqueue(new MockResponse.Builder() - .body("Request #2") - .build()); - - BufferedSource bodySource = get(server.url("/")).body().source(); - assertThat(bodySource.readUtf8Line()).isEqualTo("ABCDE"); - try { - bodySource.readUtf8(21); - fail("This implementation silently ignored a truncated HTTP body."); - } catch (IOException expected) { - } finally { - bodySource.close(); - } - - assertThat(cache.writeAbortCount()).isEqualTo(1); - assertThat(cache.writeSuccessCount()).isEqualTo(0); - Response response = get(server.url("/")); - assertThat(response.body().string()).isEqualTo("Request #2"); - assertThat(cache.writeAbortCount()).isEqualTo(1); - assertThat(cache.writeSuccessCount()).isEqualTo(1); - } - - @Test public void clientPrematureDisconnectWithContentLengthHeader() throws IOException { - testClientPrematureDisconnect(TransferKind.FIXED_LENGTH); - } - - @Test public void clientPrematureDisconnectWithChunkedEncoding() throws IOException { - testClientPrematureDisconnect(TransferKind.CHUNKED); - } - - @Test public void clientPrematureDisconnectWithNoLengthHeaders() throws IOException { - testClientPrematureDisconnect(TransferKind.END_OF_STREAM); - } - - private void testClientPrematureDisconnect(TransferKind transferKind) throws IOException { - // Setting a low transfer speed ensures that stream discarding will time out. - MockResponse.Builder builder = new MockResponse.Builder() - .throttleBody(6, 1, TimeUnit.SECONDS); - transferKind.setBody(builder, "ABCDE\nFGHIJKLMNOPQRSTUVWXYZ", 1024); - server.enqueue(builder.build()); - server.enqueue(new MockResponse.Builder() - .body("Request #2") - .build()); - - Response response1 = get(server.url("/")); - BufferedSource in = response1.body().source(); - assertThat(in.readUtf8(5)).isEqualTo("ABCDE"); - in.close(); - try { - in.readByte(); - fail("Expected an IllegalStateException because the source is closed."); - } catch (IllegalStateException expected) { - } - - assertThat(cache.writeAbortCount()).isEqualTo(1); - assertThat(cache.writeSuccessCount()).isEqualTo(0); - Response response2 = get(server.url("/")); - assertThat(response2.body().string()).isEqualTo("Request #2"); - assertThat(cache.writeAbortCount()).isEqualTo(1); - assertThat(cache.writeSuccessCount()).isEqualTo(1); - } - - @Test public void defaultExpirationDateFullyCachedForLessThan24Hours() throws Exception { - // last modified: 105 seconds ago - // served: 5 seconds ago - // default lifetime: (105 - 5) / 10 = 10 seconds - // expires: 10 seconds from served date = 5 seconds from now - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-105, TimeUnit.SECONDS)) - .addHeader("Date: " + formatDate(-5, TimeUnit.SECONDS)) - .body("A") - .build()); - - HttpUrl url = server.url("/"); - Response response1 = get(url); - assertThat(response1.body().string()).isEqualTo("A"); - - Response response2 = get(url); - assertThat(response2.body().string()).isEqualTo("A"); - assertThat(response2.header("Warning")).isNull(); - } - - @Test public void defaultExpirationDateConditionallyCached() throws Exception { - // last modified: 115 seconds ago - // served: 15 seconds ago - // default lifetime: (115 - 15) / 10 = 10 seconds - // expires: 10 seconds from served date = 5 seconds ago - String lastModifiedDate = formatDate(-115, TimeUnit.SECONDS); - RecordedRequest conditionalRequest = assertConditionallyCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + lastModifiedDate) - .addHeader("Date: " + formatDate(-15, TimeUnit.SECONDS)) - .build()); - assertThat(conditionalRequest.getHeaders().get("If-Modified-Since")) - .isEqualTo(lastModifiedDate); - } - - @Test public void defaultExpirationDateFullyCachedForMoreThan24Hours() throws Exception { - // last modified: 105 days ago - // served: 5 days ago - // default lifetime: (105 - 5) / 10 = 10 days - // expires: 10 days from served date = 5 days from now - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-105, TimeUnit.DAYS)) - .addHeader("Date: " + formatDate(-5, TimeUnit.DAYS)) - .body("A") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - Response response = get(server.url("/")); - assertThat(response.body().string()).isEqualTo("A"); - assertThat(response.header("Warning")).isEqualTo( - "113 HttpURLConnection \"Heuristic expiration\""); - } - - @Test public void noDefaultExpirationForUrlsWithQueryString() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-105, TimeUnit.SECONDS)) - .addHeader("Date: " + formatDate(-5, TimeUnit.SECONDS)) - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/").newBuilder().addQueryParameter("foo", "bar").build(); - assertThat(get(url).body().string()).isEqualTo("A"); - assertThat(get(url).body().string()).isEqualTo("B"); - } - - @Test public void expirationDateInThePastWithLastModifiedHeader() throws Exception { - String lastModifiedDate = formatDate(-2, TimeUnit.HOURS); - RecordedRequest conditionalRequest = assertConditionallyCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + lastModifiedDate) - .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) - .build()); - assertThat(conditionalRequest.getHeaders().get("If-Modified-Since")) - .isEqualTo(lastModifiedDate); - } - - @Test public void expirationDateInThePastWithNoLastModifiedHeader() throws Exception { - assertNotCached(new MockResponse.Builder() - .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) - .build()); - } - - @Test public void expirationDateInTheFuture() throws Exception { - assertFullyCached(new MockResponse.Builder() - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .build()); - } - - @Test public void maxAgePreferredWithMaxAgeAndExpires() throws Exception { - assertFullyCached(new MockResponse.Builder() - .addHeader("Date: " + formatDate(0, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=60") - .build()); - } - - @Test public void maxAgeInThePastWithDateAndLastModifiedHeaders() throws Exception { - String lastModifiedDate = formatDate(-2, TimeUnit.HOURS); - RecordedRequest conditionalRequest = assertConditionallyCached(new MockResponse.Builder() - .addHeader("Date: " + formatDate(-120, TimeUnit.SECONDS)) - .addHeader("Last-Modified: " + lastModifiedDate) - .addHeader("Cache-Control: max-age=60") - .build()); - assertThat(conditionalRequest.getHeaders().get("If-Modified-Since")) - .isEqualTo(lastModifiedDate); - } - - @Test public void maxAgeInThePastWithDateHeaderButNoLastModifiedHeader() throws Exception { - // Chrome interprets max-age relative to the local clock. Both our cache - // and Firefox both use the earlier of the local and server's clock. - assertNotCached(new MockResponse.Builder() - .addHeader("Date: " + formatDate(-120, TimeUnit.SECONDS)) - .addHeader("Cache-Control: max-age=60") - .build()); - } - - @Test public void maxAgeInTheFutureWithDateHeader() throws Exception { - assertFullyCached(new MockResponse.Builder() - .addHeader("Date: " + formatDate(0, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=60") - .build()); - } - - @Test public void maxAgeInTheFutureWithNoDateHeader() throws Exception { - assertFullyCached(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .build()); - } - - @Test public void maxAgeWithLastModifiedButNoServedDate() throws Exception { - assertFullyCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-120, TimeUnit.SECONDS)) - .addHeader("Cache-Control: max-age=60") - .build()); - } - - @Test public void maxAgeInTheFutureWithDateAndLastModifiedHeaders() throws Exception { - assertFullyCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-120, TimeUnit.SECONDS)) - .addHeader("Date: " + formatDate(0, TimeUnit.SECONDS)) - .addHeader("Cache-Control: max-age=60") - .build()); - } - - @Test public void maxAgePreferredOverLowerSharedMaxAge() throws Exception { - assertFullyCached(new MockResponse.Builder() - .addHeader("Date: " + formatDate(-2, TimeUnit.MINUTES)) - .addHeader("Cache-Control: s-maxage=60") - .addHeader("Cache-Control: max-age=180") - .build()); - } - - @Test public void maxAgePreferredOverHigherMaxAge() throws Exception { - assertNotCached(new MockResponse.Builder() - .addHeader("Date: " + formatDate(-2, TimeUnit.MINUTES)) - .addHeader("Cache-Control: s-maxage=180") - .addHeader("Cache-Control: max-age=60") - .build()); - } - - @Test public void requestMethodOptionsIsNotCached() throws Exception { - testRequestMethod("OPTIONS", false); - } - - @Test public void requestMethodGetIsCached() throws Exception { - testRequestMethod("GET", true); - } - - @Test public void requestMethodHeadIsNotCached() throws Exception { - // We could support this but choose not to for implementation simplicity - testRequestMethod("HEAD", false); - } - - @Test public void requestMethodPostIsNotCached() throws Exception { - // We could support this but choose not to for implementation simplicity - testRequestMethod("POST", false); - } - - @Test public void requestMethodPutIsNotCached() throws Exception { - testRequestMethod("PUT", false); - } - - @Test public void requestMethodDeleteIsNotCached() throws Exception { - testRequestMethod("DELETE", false); - } - - @Test public void requestMethodTraceIsNotCached() throws Exception { - testRequestMethod("TRACE", false); - } - - private void testRequestMethod(String requestMethod, boolean expectCached) throws Exception { - // 1. Seed the cache (potentially). - // 2. Expect a cache hit or miss. - server.enqueue(new MockResponse.Builder() - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .addHeader("X-Response-ID: 1") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("X-Response-ID: 2") - .build()); - - HttpUrl url = server.url("/"); - - Request request = new Request.Builder() - .url(url) - .method(requestMethod, requestBodyOrNull(requestMethod)) - .build(); - Response response1 = client.newCall(request).execute(); - response1.body().close(); - assertThat(response1.header("X-Response-ID")).isEqualTo("1"); - - Response response2 = get(url); - response2.body().close(); - if (expectCached) { - assertThat(response2.header("X-Response-ID")).isEqualTo("1"); - } else { - assertThat(response2.header("X-Response-ID")).isEqualTo("2"); - } - } - - private RequestBody requestBodyOrNull(String requestMethod) { - return (requestMethod.equals("POST") || requestMethod.equals("PUT")) - ? RequestBody.create("foo", MediaType.get("text/plain")) - : null; - } - - @Test public void postInvalidatesCache() throws Exception { - testMethodInvalidates("POST"); - } - - @Test public void putInvalidatesCache() throws Exception { - testMethodInvalidates("PUT"); - } - - @Test public void deleteMethodInvalidatesCache() throws Exception { - testMethodInvalidates("DELETE"); - } - - private void testMethodInvalidates(String requestMethod) throws Exception { - // 1. Seed the cache. - // 2. Invalidate it. - // 3. Expect a cache miss. - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - server.enqueue(new MockResponse.Builder() - .body("C") - .build()); - - HttpUrl url = server.url("/"); - - assertThat(get(url).body().string()).isEqualTo("A"); - - Request request = new Request.Builder() - .url(url) - .method(requestMethod, requestBodyOrNull(requestMethod)) - .build(); - Response invalidate = client.newCall(request).execute(); - assertThat(invalidate.body().string()).isEqualTo("B"); - - assertThat(get(url).body().string()).isEqualTo("C"); - } - - @Test public void postInvalidatesCacheWithUncacheableResponse() throws Exception { - // 1. Seed the cache. - // 2. Invalidate it with an uncacheable response. - // 3. Expect a cache miss. - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .code(500) - .build()); - server.enqueue(new MockResponse.Builder() - .body("C") - .build()); - - HttpUrl url = server.url("/"); - - assertThat(get(url).body().string()).isEqualTo("A"); - - Request request = new Request.Builder() - .url(url) - .method("POST", requestBodyOrNull("POST")) - .build(); - Response invalidate = client.newCall(request).execute(); - assertThat(invalidate.body().string()).isEqualTo("B"); - - assertThat(get(url).body().string()).isEqualTo("C"); - } - - @Test public void putInvalidatesWithNoContentResponse() throws Exception { - // 1. Seed the cache. - // 2. Invalidate it. - // 3. Expect a cache miss. - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .build()); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .code(HttpURLConnection.HTTP_NO_CONTENT) - .build()); - server.enqueue(new MockResponse.Builder() - .body("C") - .build()); - - HttpUrl url = server.url("/"); - - assertThat(get(url).body().string()).isEqualTo("A"); - - Request request = new Request.Builder() - .url(url) - .put(RequestBody.create("foo", MediaType.get("text/plain"))) - .build(); - Response invalidate = client.newCall(request).execute(); - assertThat(invalidate.body().string()).isEqualTo(""); - - assertThat(get(url).body().string()).isEqualTo("C"); - } - - @Test public void etag() throws Exception { - RecordedRequest conditionalRequest = assertConditionallyCached(new MockResponse.Builder() - .addHeader("ETag: v1") - .build()); - assertThat(conditionalRequest.getHeaders().get("If-None-Match")).isEqualTo("v1"); - } - - /** If both If-Modified-Since and If-None-Match conditions apply, send only If-None-Match. */ - @Test public void etagAndExpirationDateInThePast() throws Exception { - String lastModifiedDate = formatDate(-2, TimeUnit.HOURS); - RecordedRequest conditionalRequest = assertConditionallyCached(new MockResponse.Builder() - .addHeader("ETag: v1") - .addHeader("Last-Modified: " + lastModifiedDate) - .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) - .build()); - assertThat(conditionalRequest.getHeaders().get("If-None-Match")).isEqualTo("v1"); - assertThat(conditionalRequest.getHeaders().get("If-Modified-Since")).isNull(); - } - - @Test public void etagAndExpirationDateInTheFuture() throws Exception { - assertFullyCached(new MockResponse.Builder() - .addHeader("ETag: v1") - .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .build()); - } - - @Test public void cacheControlNoCache() throws Exception { - assertNotCached(new MockResponse.Builder() - .addHeader("Cache-Control: no-cache") - .build()); - } - - @Test public void cacheControlNoCacheAndExpirationDateInTheFuture() throws Exception { - String lastModifiedDate = formatDate(-2, TimeUnit.HOURS); - RecordedRequest conditionalRequest = assertConditionallyCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + lastModifiedDate) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .addHeader("Cache-Control: no-cache") - .build()); - assertThat(conditionalRequest.getHeaders().get("If-Modified-Since")) - .isEqualTo(lastModifiedDate); - } - - @Test public void pragmaNoCache() throws Exception { - assertNotCached(new MockResponse.Builder() - .addHeader("Pragma: no-cache") - .build()); - } - - @Test public void pragmaNoCacheAndExpirationDateInTheFuture() throws Exception { - String lastModifiedDate = formatDate(-2, TimeUnit.HOURS); - RecordedRequest conditionalRequest = assertConditionallyCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + lastModifiedDate) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .addHeader("Pragma: no-cache") - .build()); - assertThat(conditionalRequest.getHeaders().get("If-Modified-Since")) - .isEqualTo(lastModifiedDate); - } - - @Test public void cacheControlNoStore() throws Exception { - assertNotCached(new MockResponse.Builder() - .addHeader("Cache-Control: no-store") - .build()); - } - - @Test public void cacheControlNoStoreAndExpirationDateInTheFuture() throws Exception { - assertNotCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .addHeader("Cache-Control: no-store") - .build()); - } - - @Test public void partialRangeResponsesDoNotCorruptCache() throws Exception { - // 1. Request a range. - // 2. Request a full document, expecting a cache miss. - server.enqueue(new MockResponse.Builder() - .body("AA") - .code(HttpURLConnection.HTTP_PARTIAL) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .addHeader("Content-Range: bytes 1000-1001/2000") - .build()); - server.enqueue(new MockResponse.Builder() - .body("BB") - .build()); - - HttpUrl url = server.url("/"); - - Request request = new Request.Builder() - .url(url) - .header("Range", "bytes=1000-1001") - .build(); - Response range = client.newCall(request).execute(); - assertThat(range.body().string()).isEqualTo("AA"); - - assertThat(get(url).body().string()).isEqualTo("BB"); - } - - /** - * When the server returns a full response body we will store it and return it regardless of what - * its Last-Modified date is. This behavior was different prior to OkHttp 3.5 when we would prefer - * the response with the later Last-Modified date. - * - * https://github.com/square/okhttp/issues/2886 - */ - @Test public void serverReturnsDocumentOlderThanCache() throws Exception { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .addHeader("Last-Modified: " + formatDate(-4, TimeUnit.HOURS)) - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - HttpUrl url = server.url("/"); - - assertThat(get(url).body().string()).isEqualTo("A"); - assertThat(get(url).body().string()).isEqualTo("B"); - assertThat(get(url).body().string()).isEqualTo("B"); - } - - @Test public void clientSideNoStore() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .body("B") - .build()); - - Request request1 = new Request.Builder() - .url(server.url("/")) - .cacheControl(new CacheControl.Builder().noStore().build()) - .build(); - Response response1 = client.newCall(request1).execute(); - assertThat(response1.body().string()).isEqualTo("A"); - - Request request2 = new Request.Builder() - .url(server.url("/")) - .build(); - Response response2 = client.newCall(request2).execute(); - assertThat(response2.body().string()).isEqualTo("B"); - } - - @Test public void nonIdentityEncodingAndConditionalCache() throws Exception { - assertNonIdentityEncodingCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) - .build()); - } - - @Test public void nonIdentityEncodingAndFullCache() throws Exception { - assertNonIdentityEncodingCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .build()); - } - - private void assertNonIdentityEncodingCached(MockResponse response) throws Exception { - server.enqueue(response.newBuilder() - .body(gzip("ABCABCABC")) - .addHeader("Content-Encoding: gzip") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - // At least three request/response pairs are required because after the first request is cached - // a different execution path might be taken. Thus modifications to the cache applied during - // the second request might not be visible until another request is performed. - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - } - - @Test public void previouslyNotGzippedContentIsNotModifiedAndSpecifiesGzipEncoding() throws Exception { - server.enqueue(new MockResponse.Builder() - .body("ABCABCABC") - .addHeader("Content-Type: text/plain") - .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .addHeader("Content-Type: text/plain") - .addHeader("Content-Encoding: gzip") - .build()); - server.enqueue(new MockResponse.Builder() - .body("DEFDEFDEF") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - assertThat(get(server.url("/")).body().string()).isEqualTo("DEFDEFDEF"); - } - - @Test public void changedGzippedContentIsNotModifiedAndSpecifiesNewEncoding() throws Exception { - server.enqueue(new MockResponse.Builder() - .body(gzip("ABCABCABC")) - .addHeader("Content-Type: text/plain") - .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Content-Encoding: gzip") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .addHeader("Content-Type: text/plain") - .addHeader("Content-Encoding: identity") - .build()); - server.enqueue(new MockResponse.Builder() - .body("DEFDEFDEF") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - assertThat(get(server.url("/")).body().string()).isEqualTo("DEFDEFDEF"); - } - - @Test public void notModifiedSpecifiesEncoding() throws Exception { - server.enqueue(new MockResponse.Builder() - .body(gzip("ABCABCABC")) - .addHeader("Content-Encoding: gzip") - .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .addHeader("Content-Encoding: gzip") - .build()); - server.enqueue(new MockResponse.Builder() - .body("DEFDEFDEF") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - assertThat(get(server.url("/")).body().string()).isEqualTo("DEFDEFDEF"); - } - - /** https://github.com/square/okhttp/issues/947 */ - @Test public void gzipAndVaryOnAcceptEncoding() throws Exception { - server.enqueue(new MockResponse.Builder() - .body(gzip("ABCABCABC")) - .addHeader("Content-Encoding: gzip") - .addHeader("Vary: Accept-Encoding") - .addHeader("Cache-Control: max-age=60") - .build()); - server.enqueue(new MockResponse.Builder() - .body("FAIL") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - assertThat(get(server.url("/")).body().string()).isEqualTo("ABCABCABC"); - } - - @Test public void conditionalCacheHitIsNotDoublePooled() throws Exception { - clientTestRule.ensureAllConnectionsReleased(); - - server.enqueue(new MockResponse.Builder() - .addHeader("ETag: v1") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(client.connectionPool().idleConnectionCount()).isEqualTo(1); - } - - @Test public void expiresDateBeforeModifiedDate() throws Exception { - assertConditionallyCached(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Expires: " + formatDate(-2, TimeUnit.HOURS)) - .build()); - } - - @Test public void requestMaxAge() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) - .addHeader("Date: " + formatDate(-1, TimeUnit.MINUTES)) - .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - - Request request = new Request.Builder() - .url(server.url("/")) - .header("Cache-Control", "max-age=30") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("B"); - } - - @Test public void requestMinFresh() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Cache-Control: max-age=60") - .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - - Request request = new Request.Builder() - .url(server.url("/")) - .header("Cache-Control", "min-fresh=120") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("B"); - } - - @Test public void requestMaxStale() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Cache-Control: max-age=120") - .addHeader("Date: " + formatDate(-4, TimeUnit.MINUTES)) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - - Request request = new Request.Builder() - .url(server.url("/")) - .header("Cache-Control", "max-stale=180") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("A"); - assertThat(response.header("Warning")).isEqualTo( - "110 HttpURLConnection \"Response is stale\""); - } - - @Test public void requestMaxStaleDirectiveWithNoValue() throws IOException { - // Add a stale response to the cache. - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Cache-Control: max-age=120") - .addHeader("Date: " + formatDate(-4, TimeUnit.MINUTES)) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - - // With max-stale, we'll return that stale response. - Request request = new Request.Builder() - .url(server.url("/")) - .header("Cache-Control", "max-stale") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("A"); - assertThat(response.header("Warning")).isEqualTo( - "110 HttpURLConnection \"Response is stale\""); - } - - @Test public void requestMaxStaleNotHonoredWithMustRevalidate() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Cache-Control: max-age=120, must-revalidate") - .addHeader("Date: " + formatDate(-4, TimeUnit.MINUTES)) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - - Request request = new Request.Builder() - .url(server.url("/")) - .header("Cache-Control", "max-stale=180") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("B"); - } - - @Test public void requestOnlyIfCachedWithNoResponseCached() throws IOException { - // (no responses enqueued) - - Request request = new Request.Builder() - .url(server.url("/")) - .header("Cache-Control", "only-if-cached") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().source().exhausted()).isTrue(); - assertThat(response.code()).isEqualTo(504); - assertThat(cache.requestCount()).isEqualTo(1); - assertThat(cache.networkCount()).isEqualTo(0); - assertThat(cache.hitCount()).isEqualTo(0); - } - - @Test public void requestOnlyIfCachedWithFullResponseCached() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Cache-Control: max-age=30") - .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - Request request = new Request.Builder() - .url(server.url("/")) - .header("Cache-Control", "only-if-cached") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("A"); - assertThat(cache.requestCount()).isEqualTo(2); - assertThat(cache.networkCount()).isEqualTo(1); - assertThat(cache.hitCount()).isEqualTo(1); - } - - @Test public void requestOnlyIfCachedWithConditionalResponseCached() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Cache-Control: max-age=30") - .addHeader("Date: " + formatDate(-1, TimeUnit.MINUTES)) - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - Request request = new Request.Builder() - .url(server.url("/")) - .header("Cache-Control", "only-if-cached") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().source().exhausted()).isTrue(); - assertThat(response.code()).isEqualTo(504); - assertThat(cache.requestCount()).isEqualTo(2); - assertThat(cache.networkCount()).isEqualTo(1); - assertThat(cache.hitCount()).isEqualTo(0); - } - - @Test public void requestOnlyIfCachedWithUnhelpfulResponseCached() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - Request request = new Request.Builder() - .url(server.url("/")) - .header("Cache-Control", "only-if-cached") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().source().exhausted()).isTrue(); - assertThat(response.code()).isEqualTo(504); - assertThat(cache.requestCount()).isEqualTo(2); - assertThat(cache.networkCount()).isEqualTo(1); - assertThat(cache.hitCount()).isEqualTo(0); - } - - @Test public void requestCacheControlNoCache() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-120, TimeUnit.SECONDS)) - .addHeader("Date: " + formatDate(0, TimeUnit.SECONDS)) - .addHeader("Cache-Control: max-age=60") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - Request request = new Request.Builder() - .url(url) - .header("Cache-Control", "no-cache") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("B"); - } - - @Test public void requestPragmaNoCache() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-120, TimeUnit.SECONDS)) - .addHeader("Date: " + formatDate(0, TimeUnit.SECONDS)) - .addHeader("Cache-Control: max-age=60") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - Request request = new Request.Builder() - .url(url) - .header("Pragma", "no-cache") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("B"); - } - - @Test public void clientSuppliedIfModifiedSinceWithCachedResult() throws Exception { - MockResponse response = new MockResponse.Builder() - .addHeader("ETag: v3") - .addHeader("Cache-Control: max-age=0") - .build(); - String ifModifiedSinceDate = formatDate(-24, TimeUnit.HOURS); - RecordedRequest request = - assertClientSuppliedCondition(response, "If-Modified-Since", ifModifiedSinceDate); - assertThat(request.getHeaders().get("If-Modified-Since")).isEqualTo(ifModifiedSinceDate); - assertThat(request.getHeaders().get("If-None-Match")).isNull(); - } - - @Test public void clientSuppliedIfNoneMatchSinceWithCachedResult() throws Exception { - String lastModifiedDate = formatDate(-3, TimeUnit.MINUTES); - MockResponse response = new MockResponse.Builder() - .addHeader("Last-Modified: " + lastModifiedDate) - .addHeader("Date: " + formatDate(-2, TimeUnit.MINUTES)) - .addHeader("Cache-Control: max-age=0") - .build(); - RecordedRequest request = assertClientSuppliedCondition(response, "If-None-Match", "v1"); - assertThat(request.getHeaders().get("If-None-Match")).isEqualTo("v1"); - assertThat(request.getHeaders().get("If-Modified-Since")).isNull(); - } - - private RecordedRequest assertClientSuppliedCondition(MockResponse seed, String conditionName, - String conditionValue) throws Exception { - server.enqueue(seed.newBuilder() - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - - Request request = new Request.Builder() - .url(url) - .header(conditionName, conditionValue) - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.code()).isEqualTo(HttpURLConnection.HTTP_NOT_MODIFIED); - assertThat(response.body().string()).isEqualTo(""); - - server.takeRequest(); // seed - return server.takeRequest(); - } - - /** - * For Last-Modified and Date headers, we should echo the date back in the exact format we were - * served. - */ - @Test public void retainServedDateFormat() throws Exception { - // Serve a response with a non-standard date format that OkHttp supports. - Date lastModifiedDate = new Date(System.currentTimeMillis() + TimeUnit.HOURS.toMillis(-1)); - Date servedDate = new Date(System.currentTimeMillis() + TimeUnit.HOURS.toMillis(-2)); - DateFormat dateFormat = new SimpleDateFormat("EEE dd-MMM-yyyy HH:mm:ss z", Locale.US); - dateFormat.setTimeZone(TimeZone.getTimeZone("America/New_York")); - String lastModifiedString = dateFormat.format(lastModifiedDate); - String servedString = dateFormat.format(servedDate); - - // This response should be conditionally cached. - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + lastModifiedString) - .addHeader("Expires: " + servedString) - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - - // The first request has no conditions. - RecordedRequest request1 = server.takeRequest(); - assertThat(request1.getHeaders().get("If-Modified-Since")).isNull(); - - // The 2nd request uses the server's date format. - RecordedRequest request2 = server.takeRequest(); - assertThat(request2.getHeaders().get("If-Modified-Since")).isEqualTo(lastModifiedString); - } - - @Test public void clientSuppliedConditionWithoutCachedResult() throws Exception { - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .header("If-Modified-Since", formatDate(-24, TimeUnit.HOURS)) - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.code()).isEqualTo(HttpURLConnection.HTTP_NOT_MODIFIED); - assertThat(response.body().string()).isEqualTo(""); - } - - @Test public void authorizationRequestFullyCached() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - Request request = new Request.Builder() - .url(url) - .header("Authorization", "password") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("A"); - assertThat(get(url).body().string()).isEqualTo("A"); - } - - @Test public void contentLocationDoesNotPopulateCache() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Content-Location: /bar") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/foo")).body().string()).isEqualTo("A"); - assertThat(get(server.url("/bar")).body().string()).isEqualTo("B"); - } - - @Test public void connectionIsReturnedToPoolAfterConditionalSuccess() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=0") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/a")).body().string()).isEqualTo("A"); - assertThat(get(server.url("/a")).body().string()).isEqualTo("A"); - assertThat(get(server.url("/b")).body().string()).isEqualTo("B"); - - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(0); - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(1); - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(2); - } - - @Test public void statisticsConditionalCacheMiss() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=0") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - server.enqueue(new MockResponse.Builder() - .body("C") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(cache.requestCount()).isEqualTo(1); - assertThat(cache.networkCount()).isEqualTo(1); - assertThat(cache.hitCount()).isEqualTo(0); - assertThat(get(server.url("/")).body().string()).isEqualTo("B"); - assertThat(get(server.url("/")).body().string()).isEqualTo("C"); - assertThat(cache.requestCount()).isEqualTo(3); - assertThat(cache.networkCount()).isEqualTo(3); - assertThat(cache.hitCount()).isEqualTo(0); - } - - @Test public void statisticsConditionalCacheHit() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=0") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(cache.requestCount()).isEqualTo(1); - assertThat(cache.networkCount()).isEqualTo(1); - assertThat(cache.hitCount()).isEqualTo(0); - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(cache.requestCount()).isEqualTo(3); - assertThat(cache.networkCount()).isEqualTo(3); - assertThat(cache.hitCount()).isEqualTo(2); - } - - @Test public void statisticsFullCacheHit() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .body("A") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(cache.requestCount()).isEqualTo(1); - assertThat(cache.networkCount()).isEqualTo(1); - assertThat(cache.hitCount()).isEqualTo(0); - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(cache.requestCount()).isEqualTo(3); - assertThat(cache.networkCount()).isEqualTo(1); - assertThat(cache.hitCount()).isEqualTo(2); - } - - @Test public void varyMatchesChangedRequestHeaderField() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Accept-Language") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - Request frRequest = new Request.Builder() - .url(url) - .header("Accept-Language", "fr-CA") - .build(); - Response frResponse = client.newCall(frRequest).execute(); - assertThat(frResponse.body().string()).isEqualTo("A"); - - Request enRequest = new Request.Builder() - .url(url) - .header("Accept-Language", "en-US") - .build(); - Response enResponse = client.newCall(enRequest).execute(); - assertThat(enResponse.body().string()).isEqualTo("B"); - } - - @Test public void varyMatchesUnchangedRequestHeaderField() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Accept-Language") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - Request request = new Request.Builder() - .url(url) - .header("Accept-Language", "fr-CA") - .build(); - Response response1 = client.newCall(request).execute(); - assertThat(response1.body().string()).isEqualTo("A"); - Request request1 = new Request.Builder() - .url(url) - .header("Accept-Language", "fr-CA") - .build(); - Response response2 = client.newCall(request1).execute(); - assertThat(response2.body().string()).isEqualTo("A"); - } - - @Test public void varyMatchesAbsentRequestHeaderField() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Foo") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - } - - @Test public void varyMatchesAddedRequestHeaderField() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Foo") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - Request request = new Request.Builder() - .url(server.url("/")).header("Foo", "bar") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("B"); - } - - @Test public void varyMatchesRemovedRequestHeaderField() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Foo") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - Request request = new Request.Builder() - .url(server.url("/")).header("Foo", "bar") - .build(); - Response fooresponse = client.newCall(request).execute(); - assertThat(fooresponse.body().string()).isEqualTo("A"); - assertThat(get(server.url("/")).body().string()).isEqualTo("B"); - } - - @Test public void varyFieldsAreCaseInsensitive() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: ACCEPT-LANGUAGE") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - Request request = new Request.Builder() - .url(url) - .header("Accept-Language", "fr-CA") - .build(); - Response response1 = client.newCall(request).execute(); - assertThat(response1.body().string()).isEqualTo("A"); - Request request1 = new Request.Builder() - .url(url) - .header("accept-language", "fr-CA") - .build(); - Response response2 = client.newCall(request1).execute(); - assertThat(response2.body().string()).isEqualTo("A"); - } - - @Test public void varyMultipleFieldsWithMatch() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Accept-Language, Accept-Charset") - .addHeader("Vary: Accept-Encoding") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - Request request = new Request.Builder() - .url(url) - .header("Accept-Language", "fr-CA") - .header("Accept-Charset", "UTF-8") - .header("Accept-Encoding", "identity") - .build(); - Response response1 = client.newCall(request).execute(); - assertThat(response1.body().string()).isEqualTo("A"); - Request request1 = new Request.Builder() - .url(url) - .header("Accept-Language", "fr-CA") - .header("Accept-Charset", "UTF-8") - .header("Accept-Encoding", "identity") - .build(); - Response response2 = client.newCall(request1).execute(); - assertThat(response2.body().string()).isEqualTo("A"); - } - - @Test public void varyMultipleFieldsWithNoMatch() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Accept-Language, Accept-Charset") - .addHeader("Vary: Accept-Encoding") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - Request frRequest = new Request.Builder() - .url(url) - .header("Accept-Language", "fr-CA") - .header("Accept-Charset", "UTF-8") - .header("Accept-Encoding", "identity") - .build(); - Response frResponse = client.newCall(frRequest).execute(); - assertThat(frResponse.body().string()).isEqualTo("A"); - Request enRequest = new Request.Builder() - .url(url) - .header("Accept-Language", "en-CA") - .header("Accept-Charset", "UTF-8") - .header("Accept-Encoding", "identity") - .build(); - Response enResponse = client.newCall(enRequest).execute(); - assertThat(enResponse.body().string()).isEqualTo("B"); - } - - @Test public void varyMultipleFieldValuesWithMatch() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Accept-Language") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - Request request1 = new Request.Builder() - .url(url) - .addHeader("Accept-Language", "fr-CA, fr-FR") - .addHeader("Accept-Language", "en-US") - .build(); - Response response1 = client.newCall(request1).execute(); - assertThat(response1.body().string()).isEqualTo("A"); - - Request request2 = new Request.Builder() - .url(url) - .addHeader("Accept-Language", "fr-CA, fr-FR") - .addHeader("Accept-Language", "en-US") - .build(); - Response response2 = client.newCall(request2).execute(); - assertThat(response2.body().string()).isEqualTo("A"); - } - - @Test public void varyMultipleFieldValuesWithNoMatch() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Accept-Language") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - Request request1 = new Request.Builder() - .url(url) - .addHeader("Accept-Language", "fr-CA, fr-FR") - .addHeader("Accept-Language", "en-US") - .build(); - Response response1 = client.newCall(request1).execute(); - assertThat(response1.body().string()).isEqualTo("A"); - - Request request2 = new Request.Builder() - .url(url) - .addHeader("Accept-Language", "fr-CA") - .addHeader("Accept-Language", "en-US") - .build(); - Response response2 = client.newCall(request2).execute(); - assertThat(response2.body().string()).isEqualTo("B"); - } - - @Test public void varyAsterisk() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: *") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - assertThat(get(server.url("/")).body().string()).isEqualTo("B"); - } - - @Test public void varyAndHttps() throws Exception { - server.useHttps(handshakeCertificates.sslSocketFactory()); - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .addHeader("Vary: Accept-Language") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - client = client.newBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(NULL_HOSTNAME_VERIFIER) - .build(); - - HttpUrl url = server.url("/"); - Request request1 = new Request.Builder() - .url(url) - .header("Accept-Language", "en-US") - .build(); - Response response1 = client.newCall(request1).execute(); - assertThat(response1.body().string()).isEqualTo("A"); - - Request request2 = new Request.Builder() - .url(url) - .header("Accept-Language", "en-US") - .build(); - Response response2 = client.newCall(request2).execute(); - assertThat(response2.body().string()).isEqualTo("A"); - } - - @Test public void cachePlusCookies() throws Exception { - RecordingCookieJar cookieJar = new RecordingCookieJar(); - client = client.newBuilder() - .cookieJar(cookieJar) - .build(); - - server.enqueue(new MockResponse.Builder() - .addHeader("Set-Cookie: a=FIRST") - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=0") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Set-Cookie: a=SECOND") - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - cookieJar.assertResponseCookies("a=FIRST; path=/"); - assertThat(get(url).body().string()).isEqualTo("A"); - cookieJar.assertResponseCookies("a=SECOND; path=/"); - } - - @Test public void getHeadersReturnsNetworkEndToEndHeaders() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Allow: GET, HEAD") - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=0") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Allow: GET, HEAD, PUT") - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - Response response1 = get(server.url("/")); - assertThat(response1.body().string()).isEqualTo("A"); - assertThat(response1.header("Allow")).isEqualTo("GET, HEAD"); - - Response response2 = get(server.url("/")); - assertThat(response2.body().string()).isEqualTo("A"); - assertThat(response2.header("Allow")).isEqualTo("GET, HEAD, PUT"); - } - - @Test public void getHeadersReturnsCachedHopByHopHeaders() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Transfer-Encoding: identity") - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=0") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Transfer-Encoding: none") - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - Response response1 = get(server.url("/")); - assertThat(response1.body().string()).isEqualTo("A"); - assertThat(response1.header("Transfer-Encoding")).isEqualTo("identity"); - - Response response2 = get(server.url("/")); - assertThat(response2.body().string()).isEqualTo("A"); - assertThat(response2.header("Transfer-Encoding")).isEqualTo("identity"); - } - - @Test public void getHeadersDeletesCached100LevelWarnings() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Warning: 199 test danger") - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=0") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - Response response1 = get(server.url("/")); - assertThat(response1.body().string()).isEqualTo("A"); - assertThat(response1.header("Warning")).isEqualTo("199 test danger"); - - Response response2 = get(server.url("/")); - assertThat(response2.body().string()).isEqualTo("A"); - assertThat(response2.header("Warning")).isNull(); - } - - @Test public void getHeadersRetainsCached200LevelWarnings() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Warning: 299 test danger") - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=0") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - Response response1 = get(server.url("/")); - assertThat(response1.body().string()).isEqualTo("A"); - assertThat(response1.header("Warning")).isEqualTo("299 test danger"); - - Response response2 = get(server.url("/")); - assertThat(response2.body().string()).isEqualTo("A"); - assertThat(response2.header("Warning")).isEqualTo("299 test danger"); - } - - @Test public void doNotCachePartialResponse() throws Exception { - assertNotCached(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_PARTIAL) - .addHeader("Date: " + formatDate(0, TimeUnit.HOURS)) - .addHeader("Content-Range: bytes 100-100/200") - .addHeader("Cache-Control: max-age=60") - .build()); - } - - @Test public void conditionalHitUpdatesCache() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(0, TimeUnit.SECONDS)) - .addHeader("Cache-Control: max-age=0") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=30") - .addHeader("Allow: GET, HEAD") - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - // A cache miss writes the cache. - long t0 = System.currentTimeMillis(); - Response response1 = get(server.url("/a")); - assertThat(response1.body().string()).isEqualTo("A"); - assertThat(response1.header("Allow")).isNull(); - assertThat((double) (response1.receivedResponseAtMillis() - t0)).isCloseTo(0, offset(250.0)); - - // A conditional cache hit updates the cache. - Thread.sleep(500); // Make sure t0 and t1 are distinct. - long t1 = System.currentTimeMillis(); - Response response2 = get(server.url("/a")); - assertThat(response2.code()).isEqualTo(HttpURLConnection.HTTP_OK); - assertThat(response2.body().string()).isEqualTo("A"); - assertThat(response2.header("Allow")).isEqualTo("GET, HEAD"); - Long updatedTimestamp = response2.receivedResponseAtMillis(); - assertThat((double) (updatedTimestamp - t1)).isCloseTo(0, offset(250.0)); - - // A full cache hit reads the cache. - Thread.sleep(10); - Response response3 = get(server.url("/a")); - assertThat(response3.body().string()).isEqualTo("A"); - assertThat(response3.header("Allow")).isEqualTo("GET, HEAD"); - assertThat(response3.receivedResponseAtMillis()).isEqualTo(updatedTimestamp); - - assertThat(server.getRequestCount()).isEqualTo(2); - } - - @Test public void responseSourceHeaderCached() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Cache-Control: max-age=30") - .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - Request request = new Request.Builder() - .url(server.url("/")).header("Cache-Control", "only-if-cached") - .build(); - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("A"); - } - - @Test public void responseSourceHeaderConditionalCacheFetched() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Cache-Control: max-age=30") - .addHeader("Date: " + formatDate(-31, TimeUnit.MINUTES)) - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .addHeader("Cache-Control: max-age=30") - .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - Response response = get(server.url("/")); - assertThat(response.body().string()).isEqualTo("B"); - } - - @Test public void responseSourceHeaderConditionalCacheNotFetched() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .addHeader("Cache-Control: max-age=0") - .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) - .build()); - server.enqueue(new MockResponse.Builder() - .code(304) - .build()); - - assertThat(get(server.url("/")).body().string()).isEqualTo("A"); - Response response = get(server.url("/")); - assertThat(response.body().string()).isEqualTo("A"); - } - - @Test public void responseSourceHeaderFetched() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("A") - .build()); - - Response response = get(server.url("/")); - assertThat(response.body().string()).isEqualTo("A"); - } - - @Test public void emptyResponseHeaderNameFromCacheIsLenient() throws Exception { - Headers.Builder headers = new Headers.Builder() - .add("Cache-Control: max-age=120"); - Internal.addHeaderLenient(headers, ": A"); - server.enqueue(new MockResponse.Builder() - .headers(headers.build()) - .body("body") - .build()); - - Response response = get(server.url("/")); - assertThat(response.header("")).isEqualTo("A"); - assertThat(response.body().string()).isEqualTo("body"); - } - - /** - * Old implementations of OkHttp's response cache wrote header fields like ":status: 200 OK". This - * broke our cached response parser because it split on the first colon. This regression test - * exists to help us read these old bad cache entries. - * - * https://github.com/square/okhttp/issues/227 - */ - @Test public void testGoldenCacheResponse() throws Exception { - cache.close(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - HttpUrl url = server.url("/"); - String urlKey = Cache.key(url); - String entryMetadata = "" - + "" + url + "\n" - + "GET\n" - + "0\n" - + "HTTP/1.1 200 OK\n" - + "7\n" - + ":status: 200 OK\n" - + ":version: HTTP/1.1\n" - + "etag: foo\n" - + "content-length: 3\n" - + "OkHttp-Received-Millis: " + System.currentTimeMillis() + "\n" - + "X-Android-Response-Source: NETWORK 200\n" - + "OkHttp-Sent-Millis: " + System.currentTimeMillis() + "\n" - + "\n" - + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA\n" - + "1\n" - + "MIIBpDCCAQ2gAwIBAgIBATANBgkqhkiG9w0BAQsFADAYMRYwFAYDVQQDEw1qd2lsc29uLmxvY2FsMB4XDTEzMDgy" - + "OTA1MDE1OVoXDTEzMDgzMDA1MDE1OVowGDEWMBQGA1UEAxMNandpbHNvbi5sb2NhbDCBnzANBgkqhkiG9w0BAQEF" - + "AAOBjQAwgYkCgYEAlFW+rGo/YikCcRghOyKkJanmVmJSce/p2/jH1QvNIFKizZdh8AKNwojt3ywRWaDULA/RlCUc" - + "ltF3HGNsCyjQI/+Lf40x7JpxXF8oim1E6EtDoYtGWAseelawus3IQ13nmo6nWzfyCA55KhAWf4VipelEy8DjcuFK" - + "v6L0xwXnI0ECAwEAATANBgkqhkiG9w0BAQsFAAOBgQAuluNyPo1HksU3+Mr/PyRQIQS4BI7pRXN8mcejXmqyscdP" - + "7S6J21FBFeRR8/XNjVOp4HT9uSc2hrRtTEHEZCmpyoxixbnM706ikTmC7SN/GgM+SmcoJ1ipJcNcl8N0X6zym4dm" - + "yFfXKHu2PkTo7QFdpOJFvP3lIigcSZXozfmEDg==\n" - + "-1\n"; - String entryBody = "abc"; - String journalBody = "" - + "libcore.io.DiskLruCache\n" - + "1\n" - + "201105\n" - + "2\n" - + "\n" - + "CLEAN " + urlKey + " " + entryMetadata.length() + " " + entryBody.length() + "\n"; - fileSystem.createDirectory(cache.directoryPath()); - writeFile(cache.directoryPath(), urlKey + ".0", entryMetadata); - writeFile(cache.directoryPath(), urlKey + ".1", entryBody); - writeFile(cache.directoryPath(), "journal", journalBody); - cache = new Cache(Path.get(cache.directory().getPath()), Integer.MAX_VALUE, fileSystem); - client = client.newBuilder() - .cache(cache) - .build(); - - Response response = get(url); - assertThat(response.body().string()).isEqualTo(entryBody); - assertThat(response.header("Content-Length")).isEqualTo("3"); - assertThat(response.header("etag")).isEqualTo("foo"); - } - - /** Exercise the cache format in OkHttp 2.7 and all earlier releases. */ - @Test public void testGoldenCacheHttpsResponseOkHttp27() throws Exception { - HttpUrl url = server.url("/"); - String urlKey = Cache.key(url); - String prefix = Platform.get().getPrefix(); - String entryMetadata = "" - + "" + url + "\n" - + "GET\n" - + "0\n" - + "HTTP/1.1 200 OK\n" - + "4\n" - + "Content-Length: 3\n" - + prefix + "-Received-Millis: " + System.currentTimeMillis() + "\n" - + prefix + "-Sent-Millis: " + System.currentTimeMillis() + "\n" - + "Cache-Control: max-age=60\n" - + "\n" - + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256\n" - + "1\n" - + "MIIBnDCCAQWgAwIBAgIBATANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwHhcNMTUxMjIyMDEx" - + "MTQwWhcNMTUxMjIzMDExMTQwWjAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJ" - + "AoGBAJTn2Dh8xYmegvpOSmsKb2Os6Cxf1L4fYbnHr/turInUD5r1P7ZAuxurY880q3GT5bUDoirS3IfucddrT1Ac" - + "AmUzEmk/FDjggiP8DlxFkY/XwXBlhRDVIp/mRuASPMGInckc0ZaixOkRFyrxADj+r1eaSmXCIvV5yTY6IaIokLj1" - + "AgMBAAEwDQYJKoZIhvcNAQELBQADgYEAFblnedqtfRqI9j2WDyPPoG0NTZf9xwjeUu+ju+Ktty8u9k7Lgrrd/DH2" - + "mQEtBD1Ctvp91MJfAClNg3faZzwClUyu5pd0QXRZEUwSwZQNen2QWDHRlVsItclBJ4t+AJLqTbwofWi4m4K8REOl" - + "593hD55E4+lY22JZiVQyjsQhe6I=\n" - + "0\n"; - String entryBody = "abc"; - String journalBody = "" - + "libcore.io.DiskLruCache\n" - + "1\n" - + "201105\n" - + "2\n" - + "\n" - + "DIRTY " + urlKey + "\n" - + "CLEAN " + urlKey + " " + entryMetadata.length() + " " + entryBody.length() + "\n"; - fileSystem.createDirectory(cache.directoryPath()); - writeFile(cache.directoryPath(), urlKey + ".0", entryMetadata); - writeFile(cache.directoryPath(), urlKey + ".1", entryBody); - writeFile(cache.directoryPath(), "journal", journalBody); - cache.close(); - cache = new Cache(Path.get(cache.directory().getPath()), Integer.MAX_VALUE, fileSystem); - client = client.newBuilder() - .cache(cache) - .build(); - - Response response = get(url); - assertThat(response.body().string()).isEqualTo(entryBody); - assertThat(response.header("Content-Length")).isEqualTo("3"); - } - - /** The TLS version is present in OkHttp 3.0 and beyond. */ - @Test public void testGoldenCacheHttpsResponseOkHttp30() throws Exception { - HttpUrl url = server.url("/"); - String urlKey = Cache.key(url); - String prefix = Platform.get().getPrefix(); - String entryMetadata = "" - + "" + url + "\n" - + "GET\n" - + "0\n" - + "HTTP/1.1 200 OK\n" - + "4\n" - + "Content-Length: 3\n" - + prefix + "-Received-Millis: " + System.currentTimeMillis() + "\n" - + prefix + "-Sent-Millis: " + System.currentTimeMillis() + "\n" - + "Cache-Control: max-age=60\n" - + "\n" - + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256\n" - + "1\n" - + "MIIBnDCCAQWgAwIBAgIBATANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwHhcNMTUxMjIyMDEx" - + "MTQwWhcNMTUxMjIzMDExMTQwWjAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJ" - + "AoGBAJTn2Dh8xYmegvpOSmsKb2Os6Cxf1L4fYbnHr/turInUD5r1P7ZAuxurY880q3GT5bUDoirS3IfucddrT1Ac" - + "AmUzEmk/FDjggiP8DlxFkY/XwXBlhRDVIp/mRuASPMGInckc0ZaixOkRFyrxADj+r1eaSmXCIvV5yTY6IaIokLj1" - + "AgMBAAEwDQYJKoZIhvcNAQELBQADgYEAFblnedqtfRqI9j2WDyPPoG0NTZf9xwjeUu+ju+Ktty8u9k7Lgrrd/DH2" - + "mQEtBD1Ctvp91MJfAClNg3faZzwClUyu5pd0QXRZEUwSwZQNen2QWDHRlVsItclBJ4t+AJLqTbwofWi4m4K8REOl" - + "593hD55E4+lY22JZiVQyjsQhe6I=\n" - + "0\n" - + "TLSv1.2\n"; - String entryBody = "abc"; - String journalBody = "" - + "libcore.io.DiskLruCache\n" - + "1\n" - + "201105\n" - + "2\n" - + "\n" - + "DIRTY " + urlKey + "\n" - + "CLEAN " + urlKey + " " + entryMetadata.length() + " " + entryBody.length() + "\n"; - fileSystem.createDirectory(cache.directoryPath()); - writeFile(cache.directoryPath(), urlKey + ".0", entryMetadata); - writeFile(cache.directoryPath(), urlKey + ".1", entryBody); - writeFile(cache.directoryPath(), "journal", journalBody); - cache.close(); - cache = new Cache(Path.get(cache.directory().getPath()), Integer.MAX_VALUE, fileSystem); - client = client.newBuilder() - .cache(cache) - .build(); - - Response response = get(url); - assertThat(response.body().string()).isEqualTo(entryBody); - assertThat(response.header("Content-Length")).isEqualTo("3"); - } - - @Test public void testGoldenCacheHttpResponseOkHttp30() throws Exception { - HttpUrl url = server.url("/"); - String urlKey = Cache.key(url); - String prefix = Platform.get().getPrefix(); - String entryMetadata = "" - + "" + url + "\n" - + "GET\n" - + "0\n" - + "HTTP/1.1 200 OK\n" - + "4\n" - + "Cache-Control: max-age=60\n" - + "Content-Length: 3\n" - + prefix + "-Received-Millis: " + System.currentTimeMillis() + "\n" - + prefix + "-Sent-Millis: " + System.currentTimeMillis() + "\n"; - String entryBody = "abc"; - String journalBody = "" - + "libcore.io.DiskLruCache\n" - + "1\n" - + "201105\n" - + "2\n" - + "\n" - + "DIRTY " + urlKey + "\n" - + "CLEAN " + urlKey + " " + entryMetadata.length() + " " + entryBody.length() + "\n"; - fileSystem.createDirectory(cache.directoryPath()); - writeFile(cache.directoryPath(), urlKey + ".0", entryMetadata); - writeFile(cache.directoryPath(), urlKey + ".1", entryBody); - writeFile(cache.directoryPath(), "journal", journalBody); - cache.close(); - cache = new Cache(Path.get(cache.directory().getPath()), Integer.MAX_VALUE, fileSystem); - client = client.newBuilder() - .cache(cache) - .build(); - - Response response = get(url); - assertThat(response.body().string()).isEqualTo(entryBody); - assertThat(response.header("Content-Length")).isEqualTo("3"); - } - - @Test public void evictAll() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - client.cache().evictAll(); - assertThat(client.cache().size()).isEqualTo(0); - assertThat(get(url).body().string()).isEqualTo("B"); - } - - @Test public void networkInterceptorInvokedForConditionalGet() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("ETag: v1") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - // Seed the cache. - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - - final AtomicReference ifNoneMatch = new AtomicReference<>(); - client = client.newBuilder() - .addNetworkInterceptor(chain -> { - ifNoneMatch.compareAndSet(null, chain.request().header("If-None-Match")); - return chain.proceed(chain.request()); - }) - .build(); - - // Confirm the value is cached and intercepted. - assertThat(get(url).body().string()).isEqualTo("A"); - assertThat(ifNoneMatch.get()).isEqualTo("v1"); - } - - @Test public void networkInterceptorNotInvokedForFullyCached() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .body("A") - .build()); - - // Seed the cache. - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - - // Confirm the interceptor isn't exercised. - client = client.newBuilder() - .addNetworkInterceptor(chain -> { throw new AssertionError(); }) - .build(); - assertThat(get(url).body().string()).isEqualTo("A"); - } - - @Test public void iterateCache() throws Exception { - // Put some responses in the cache. - server.enqueue(new MockResponse.Builder() - .body("a") - .build()); - HttpUrl urlA = server.url("/a"); - assertThat(get(urlA).body().string()).isEqualTo("a"); - - server.enqueue(new MockResponse.Builder() - .body("b") - .build()); - HttpUrl urlB = server.url("/b"); - assertThat(get(urlB).body().string()).isEqualTo("b"); - - server.enqueue(new MockResponse.Builder() - .body("c") - .build()); - HttpUrl urlC = server.url("/c"); - assertThat(get(urlC).body().string()).isEqualTo("c"); - - // Confirm the iterator returns those responses... - Iterator i = cache.urls(); - assertThat(i.hasNext()).isTrue(); - assertThat(i.next()).isEqualTo(urlA.toString()); - assertThat(i.hasNext()).isTrue(); - assertThat(i.next()).isEqualTo(urlB.toString()); - assertThat(i.hasNext()).isTrue(); - assertThat(i.next()).isEqualTo(urlC.toString()); - - // ... and nothing else. - assertThat(i.hasNext()).isFalse(); - try { - i.next(); - fail(); - } catch (NoSuchElementException expected) { - } - } - - @Test public void iteratorRemoveFromCache() throws Exception { - // Put a response in the cache. - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control: max-age=60") - .body("a") - .build()); - HttpUrl url = server.url("/a"); - assertThat(get(url).body().string()).isEqualTo("a"); - - // Remove it with iteration. - Iterator i = cache.urls(); - assertThat(i.next()).isEqualTo(url.toString()); - i.remove(); - - // Confirm that subsequent requests suffer a cache miss. - server.enqueue(new MockResponse.Builder() - .body("b") - .build()); - assertThat(get(url).body().string()).isEqualTo("b"); - } - - @Test public void iteratorRemoveWithoutNextThrows() throws Exception { - // Put a response in the cache. - server.enqueue(new MockResponse.Builder() - .body("a") - .build()); - HttpUrl url = server.url("/a"); - assertThat(get(url).body().string()).isEqualTo("a"); - - Iterator i = cache.urls(); - assertThat(i.hasNext()).isTrue(); - try { - i.remove(); - fail(); - } catch (IllegalStateException expected) { - } - } - - @Test public void iteratorRemoveOncePerCallToNext() throws Exception { - // Put a response in the cache. - server.enqueue(new MockResponse.Builder() - .body("a") - .build()); - HttpUrl url = server.url("/a"); - assertThat(get(url).body().string()).isEqualTo("a"); - - Iterator i = cache.urls(); - assertThat(i.next()).isEqualTo(url.toString()); - i.remove(); - - // Too many calls to remove(). - try { - i.remove(); - fail(); - } catch (IllegalStateException expected) { - } - } - - @Test public void elementEvictedBetweenHasNextAndNext() throws Exception { - // Put a response in the cache. - server.enqueue(new MockResponse.Builder() - .body("a") - .build()); - HttpUrl url = server.url("/a"); - assertThat(get(url).body().string()).isEqualTo("a"); - - // The URL will remain available if hasNext() returned true... - Iterator i = cache.urls(); - assertThat(i.hasNext()).isTrue(); - - // ...so even when we evict the element, we still get something back. - cache.evictAll(); - assertThat(i.next()).isEqualTo(url.toString()); - - // Remove does nothing. But most importantly, it doesn't throw! - i.remove(); - } - - @Test public void elementEvictedBeforeHasNextIsOmitted() throws Exception { - // Put a response in the cache. - server.enqueue(new MockResponse.Builder() - .body("a") - .build()); - HttpUrl url = server.url("/a"); - assertThat(get(url).body().string()).isEqualTo("a"); - - Iterator i = cache.urls(); - cache.evictAll(); - - // The URL was evicted before hasNext() made any promises. - assertThat(i.hasNext()).isFalse(); - try { - i.next(); - fail(); - } catch (NoSuchElementException expected) { - } - } - - /** Test https://github.com/square/okhttp/issues/1712. */ - @Test public void conditionalMissUpdatesCache() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("ETag: v1") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("ETag: v2") - .body("B") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - assertThat(get(url).body().string()).isEqualTo("A"); - assertThat(get(url).body().string()).isEqualTo("B"); - assertThat(get(url).body().string()).isEqualTo("B"); - - assertThat(server.takeRequest().getHeaders().get("If-None-Match")).isNull(); - assertThat(server.takeRequest().getHeaders().get("If-None-Match")).isEqualTo("v1"); - assertThat(server.takeRequest().getHeaders().get("If-None-Match")).isEqualTo("v1"); - assertThat(server.takeRequest().getHeaders().get("If-None-Match")).isEqualTo("v2"); - } - - @Test public void combinedCacheHeadersCanBeNonAscii() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) - .addHeader("Cache-Control: max-age=0") - .addHeaderLenient("Alpha", "α") - .addHeaderLenient("β", "Beta") - .body("abcd") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Transfer-Encoding: none") - .addHeaderLenient("Gamma", "Γ") - .addHeaderLenient("Δ", "Delta") - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - Response response1 = get(server.url("/")); - assertThat(response1.header("Alpha")).isEqualTo("α"); - assertThat(response1.header("β")).isEqualTo("Beta"); - assertThat(response1.body().string()).isEqualTo("abcd"); - - Response response2 = get(server.url("/")); - assertThat(response2.header("Alpha")).isEqualTo("α"); - assertThat(response2.header("β")).isEqualTo("Beta"); - assertThat(response2.header("Gamma")).isEqualTo("Γ"); - assertThat(response2.header("Δ")).isEqualTo("Delta"); - assertThat(response2.body().string()).isEqualTo("abcd"); - } - - @Test public void etagConditionCanBeNonAscii() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeaderLenient("Etag", "α") - .addHeader("Cache-Control: max-age=0") - .body("abcd") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - Response response1 = get(server.url("/")); - assertThat(response1.body().string()).isEqualTo("abcd"); - - Response response2 = get(server.url("/")); - assertThat(response2.body().string()).isEqualTo("abcd"); - - assertThat(server.takeRequest().getHeaders().get("If-None-Match")).isNull(); - assertThat(server.takeRequest().getHeaders().get("If-None-Match")).isEqualTo("α"); - } - - @Test public void conditionalHitHeadersCombined() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Etag", "a") - .addHeader("Cache-Control: max-age=0") - .addHeader("A: a1") - .addHeader("B: b2") - .addHeader("B: b3") - .body("abcd") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .addHeader("B: b4") - .addHeader("B: b5") - .addHeader("C: c6") - .build()); - - Response response1 = get(server.url("/")); - assertThat(response1.body().string()).isEqualTo("abcd"); - assertThat(response1.headers()).isEqualTo(Headers.of("Etag", "a", "Cache-Control", "max-age=0", - "A", "a1", "B", "b2", "B", "b3", "Content-Length", "4")); - - // The original 'A' header is retained because the network response doesn't have one. - // The original 'B' headers are replaced by the network response. - // The network's 'C' header is added. - Response response2 = get(server.url("/")); - assertThat(response2.body().string()).isEqualTo("abcd"); - assertThat(response2.headers()).isEqualTo(Headers.of("Etag", "a", "Cache-Control", "max-age=0", - "A", "a1", "Content-Length", "4", "B", "b4", "B", "b5", "C", "c6")); - } - - private Response get(HttpUrl url) throws IOException { - Request request = new Request.Builder() - .url(url) - .build(); - return client.newCall(request).execute(); - } - - private void writeFile(Path directory, String file, String content) throws IOException { - BufferedSink sink = Okio.buffer(fileSystem.sink(directory.resolve(file))); - sink.writeUtf8(content); - sink.close(); - } - - /** - * @param delta the offset from the current date to use. Negative values yield dates in the past; - * positive values yield dates in the future. - */ - private String formatDate(long delta, TimeUnit timeUnit) { - return formatDate(new Date(System.currentTimeMillis() + timeUnit.toMillis(delta))); - } - - private String formatDate(Date date) { - DateFormat rfc1123 = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US); - rfc1123.setTimeZone(TimeZone.getTimeZone("GMT")); - return rfc1123.format(date); - } - - private void assertNotCached(MockResponse response) throws Exception { - server.enqueue(response.newBuilder() - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - assertThat(get(url).body().string()).isEqualTo("B"); - } - - /** @return the request with the conditional get headers. */ - private RecordedRequest assertConditionallyCached(MockResponse response) throws Exception { - // scenario 1: condition succeeds - server.enqueue(response.newBuilder() - .body("A") - .status("HTTP/1.1 200 A-OK") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - // scenario 2: condition fails - server.enqueue(response.newBuilder() - .body("B") - .status("HTTP/1.1 200 B-OK") - .build()); - server.enqueue(new MockResponse.Builder() - .status("HTTP/1.1 200 C-OK") - .body("C") - .build()); - - HttpUrl valid = server.url("/valid"); - Response response1 = get(valid); - assertThat(response1.body().string()).isEqualTo("A"); - assertThat(response1.code()).isEqualTo(HttpURLConnection.HTTP_OK); - assertThat(response1.message()).isEqualTo("A-OK"); - Response response2 = get(valid); - assertThat(response2.body().string()).isEqualTo("A"); - assertThat(response2.code()).isEqualTo(HttpURLConnection.HTTP_OK); - assertThat(response2.message()).isEqualTo("A-OK"); - - HttpUrl invalid = server.url("/invalid"); - Response response3 = get(invalid); - assertThat(response3.body().string()).isEqualTo("B"); - assertThat(response3.code()).isEqualTo(HttpURLConnection.HTTP_OK); - assertThat(response3.message()).isEqualTo("B-OK"); - Response response4 = get(invalid); - assertThat(response4.body().string()).isEqualTo("C"); - assertThat(response4.code()).isEqualTo(HttpURLConnection.HTTP_OK); - assertThat(response4.message()).isEqualTo("C-OK"); - - server.takeRequest(); // regular get - return server.takeRequest(); // conditional get - } - - @Test public void immutableIsCached() throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control", "immutable, max-age=10") - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .body("B") - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - assertThat(get(url).body().string()).isEqualTo("A"); - } - - @Test public void immutableIsCachedAfterMultipleCalls() throws Exception { - server.enqueue(new MockResponse.Builder() - .body("A") - .build()); - server.enqueue(new MockResponse.Builder() - .addHeader("Cache-Control", "immutable, max-age=10") - .body("B") - .build()); - server.enqueue(new MockResponse.Builder() - .body("C") - .build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - assertThat(get(url).body().string()).isEqualTo("B"); - assertThat(get(url).body().string()).isEqualTo("B"); - } - - @Test public void immutableIsNotCachedBeyondFreshnessLifetime() throws Exception { - // last modified: 115 seconds ago - // served: 15 seconds ago - // default lifetime: (115 - 15) / 10 = 10 seconds - // expires: 10 seconds from served date = 5 seconds ago - String lastModifiedDate = formatDate(-115, TimeUnit.SECONDS); - RecordedRequest conditionalRequest = assertConditionallyCached(new MockResponse.Builder() - .addHeader("Cache-Control: immutable") - .addHeader("Last-Modified: " + lastModifiedDate) - .addHeader("Date: " + formatDate(-15, TimeUnit.SECONDS)) - .build()); - assertThat(conditionalRequest.getHeaders().get("If-Modified-Since")) - .isEqualTo(lastModifiedDate); - } - - @Test - public void testPublicPathConstructor() throws IOException { - List events = new ArrayList<>(); - - fileSystem.createDirectories(cache.directoryPath()); - - fileSystem.createDirectories(cache.directoryPath()); - - FileSystem loggingFileSystem = new ForwardingFileSystem(fileSystem) { - @Override - public Path onPathParameter(Path path, java.lang.String functionName, java.lang.String parameterName) { - events.add(functionName + ":" + path); - return path; - } - - @Override - public Path onPathResult(Path path, java.lang.String functionName) { - events.add(functionName + ":" + path); - return path; - } - }; - Path path = Path.get("/cache"); - Cache c = new Cache(path, 100000L, loggingFileSystem); - - assertThat(c.directoryPath()).isEqualTo(path); - - c.size(); - - assertThat(events).containsExactly("metadataOrNull:/cache/journal.bkp", - "metadataOrNull:/cache", - "sink:/cache/journal.bkp", - "delete:/cache/journal.bkp", - "metadataOrNull:/cache/journal", - "metadataOrNull:/cache", - "sink:/cache/journal.tmp", - "metadataOrNull:/cache/journal", - "atomicMove:/cache/journal.tmp", - "atomicMove:/cache/journal", - "appendingSink:/cache/journal"); - - events.clear(); - - c.size(); - - assertThat(events).isEmpty(); - } - - private void assertFullyCached(MockResponse response) throws Exception { - server.enqueue(response.newBuilder().body("A").build()); - server.enqueue(response.newBuilder().body("B").build()); - - HttpUrl url = server.url("/"); - assertThat(get(url).body().string()).isEqualTo("A"); - assertThat(get(url).body().string()).isEqualTo("A"); - } - - /** - * Shortens the body of {@code response} but not the corresponding headers. Only useful to test - * how clients respond to the premature conclusion of the HTTP body. - */ - private MockResponse.Builder truncateViolently( - MockResponse.Builder builder, int numBytesToKeep) throws IOException { - MockResponse response = builder.build(); - builder.socketPolicy(DisconnectAtEnd.INSTANCE); - Headers headers = response.getHeaders(); - Buffer fullBody = new Buffer(); - response.getBody().writeTo(fullBody); - Buffer truncatedBody = new Buffer(); - truncatedBody.write(fullBody, numBytesToKeep); - builder.body(truncatedBody); - builder.headers(headers); - return builder; - } - - enum TransferKind { - CHUNKED { - @Override void setBody(MockResponse.Builder response, Buffer content, int chunkSize) { - response.chunkedBody(content, chunkSize); - } - }, - FIXED_LENGTH { - @Override void setBody(MockResponse.Builder response, Buffer content, int chunkSize) { - response.body(content); - } - }, - END_OF_STREAM { - @Override void setBody(MockResponse.Builder response, Buffer content, int chunkSize) { - response.body(content); - response.socketPolicy(DisconnectAtEnd.INSTANCE); - response.removeHeader("Content-Length"); - } - }; - - abstract void setBody(MockResponse.Builder response, Buffer content, int chunkSize) throws IOException; - - void setBody(MockResponse.Builder response, String content, int chunkSize) throws IOException { - setBody(response, new Buffer().writeUtf8(content), chunkSize); - } - } - - /** Returns a gzipped copy of {@code bytes}. */ - public Buffer gzip(String data) throws IOException { - Buffer result = new Buffer(); - BufferedSink sink = Okio.buffer(new GzipSink(result)); - sink.writeUtf8(data); - sink.close(); - return result; - } -} diff --git a/okhttp/src/test/java/okhttp3/CacheTest.kt b/okhttp/src/test/java/okhttp3/CacheTest.kt new file mode 100644 index 000000000000..38dc45247778 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/CacheTest.kt @@ -0,0 +1,3432 @@ +/* + * Copyright (C) 2011 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.io.IOException +import java.net.CookieManager +import java.net.HttpURLConnection +import java.net.ResponseCache +import java.text.DateFormat +import java.text.SimpleDateFormat +import java.util.Date +import java.util.Locale +import java.util.TimeZone +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference +import javax.net.ssl.HostnameVerifier +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.RecordedRequest +import mockwebserver3.SocketPolicy.DisconnectAtEnd +import mockwebserver3.junit5.internal.MockWebServerInstance +import okhttp3.Cache.Companion.key +import okhttp3.Headers.Companion.headersOf +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.RequestBody.Companion.toRequestBody +import okhttp3.internal.addHeaderLenient +import okhttp3.internal.cacheGet +import okhttp3.internal.platform.Platform.Companion.get +import okhttp3.java.net.cookiejar.JavaNetCookieJar +import okhttp3.testing.PlatformRule +import okio.Buffer +import okio.FileSystem +import okio.ForwardingFileSystem +import okio.GzipSink +import okio.Path +import okio.Path.Companion.toPath +import okio.buffer +import okio.fakefilesystem.FakeFileSystem +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.data.Offset +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +@Tag("Slow") +class CacheTest { + val fileSystem = FakeFileSystem() + + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + + @RegisterExtension + val platform = PlatformRule() + private lateinit var server: MockWebServer + private lateinit var server2: MockWebServer + private val handshakeCertificates = platform.localhostHandshakeCertificates() + private lateinit var client: OkHttpClient + private lateinit var cache: Cache + private val cookieManager = CookieManager() + + @BeforeEach + fun setUp( + @MockWebServerInstance(name = "1") server: MockWebServer, + @MockWebServerInstance(name = "2") server2: MockWebServer, + ) { + this.server = server + this.server2 = server2 + platform.assumeNotOpenJSSE() + server.protocolNegotiationEnabled = false + fileSystem.emulateUnix() + cache = Cache("/cache/".toPath(), Long.MAX_VALUE, fileSystem) + client = clientTestRule.newClientBuilder() + .cache(cache) + .cookieJar(JavaNetCookieJar(cookieManager)) + .build() + } + + @AfterEach + fun tearDown() { + ResponseCache.setDefault(null) + cache.delete() + } + + /** + * Test that response caching is consistent with the RI and the spec. + * http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.4 + */ + @Test + fun responseCachingByResponseCode() { + // Test each documented HTTP/1.1 code, plus the first unused value in each range. + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html + + // We can't test 100 because it's not really a response. + // assertCached(false, 100); + assertCached(false, 101) + assertCached(true, 200) + assertCached(false, 201) + assertCached(false, 202) + assertCached(true, 203) + assertCached(true, 204) + assertCached(false, 205) + assertCached(false, 206) //Electing to not cache partial responses + assertCached(false, 207) + assertCached(true, 300) + assertCached(true, 301) + assertCached(true, 302) + assertCached(false, 303) + assertCached(false, 304) + assertCached(false, 305) + assertCached(false, 306) + assertCached(true, 307) + assertCached(true, 308) + assertCached(false, 400) + assertCached(false, 401) + assertCached(false, 402) + assertCached(false, 403) + assertCached(true, 404) + assertCached(true, 405) + assertCached(false, 406) + assertCached(false, 408) + assertCached(false, 409) + // the HTTP spec permits caching 410s, but the RI doesn't. + assertCached(true, 410) + assertCached(false, 411) + assertCached(false, 412) + assertCached(false, 413) + assertCached(true, 414) + assertCached(false, 415) + assertCached(false, 416) + assertCached(false, 417) + assertCached(false, 418) + assertCached(false, 500) + assertCached(true, 501) + assertCached(false, 502) + assertCached(false, 503) + assertCached(false, 504) + assertCached(false, 505) + assertCached(false, 506) + } + + @Test + fun responseCachingWith1xxInformationalResponse() { + assertSubsequentResponseCached(102, 200) + assertSubsequentResponseCached(103, 200) + } + + private fun assertCached(shouldWriteToCache: Boolean, responseCode: Int) { + var expectedResponseCode = responseCode + server = MockWebServer() + val builder = MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .code(responseCode) + .body("ABCDE") + .addHeader("WWW-Authenticate: challenge") + when (responseCode) { + HttpURLConnection.HTTP_PROXY_AUTH -> { + builder.addHeader("Proxy-Authenticate: Basic realm=\"protected area\"") + } + + HttpURLConnection.HTTP_UNAUTHORIZED -> { + builder.addHeader("WWW-Authenticate: Basic realm=\"protected area\"") + } + + HttpURLConnection.HTTP_NO_CONTENT, HttpURLConnection.HTTP_RESET -> { + builder.body("") // We forbid bodies for 204 and 205. + } + } + server.enqueue(builder.build()) + if (responseCode == HttpURLConnection.HTTP_CLIENT_TIMEOUT) { + // 408's are a bit of an outlier because we may repeat the request if we encounter this + // response code. In this scenario, there are 2 responses: the initial 408 and then the 200 + // because of the retry. We just want to ensure the initial 408 isn't cached. + expectedResponseCode = 200 + server.enqueue( + MockResponse.Builder() + .setHeader("Cache-Control", "no-store") + .body("FGHIJ") + .build() + ) + } + server.start() + val request = Request.Builder() + .url(server.url("/")) + .build() + val response = client.newCall(request).execute() + assertThat(response.code).isEqualTo(expectedResponseCode) + + // Exhaust the content stream. + response.body.string() + val cached = cacheGet(cache, request) + if (shouldWriteToCache) { + assertThat(cached).isNotNull() + cached!!.body.close() + } else { + assertThat(cached).isNull() + } + server.shutdown() // tearDown() isn't sufficient; this test starts multiple servers + } + + private fun assertSubsequentResponseCached(initialResponseCode: Int, finalResponseCode: Int) { + server = MockWebServer() + val builder = MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .code(finalResponseCode) + .body("ABCDE") + .addInformationalResponse(MockResponse(initialResponseCode)) + server.enqueue(builder.build()) + server.start() + val request = Request.Builder() + .url(server.url("/")) + .build() + val response = client.newCall(request).execute() + assertThat(response.code).isEqualTo(finalResponseCode) + + // Exhaust the content stream. + response.body.string() + val cached = cacheGet(cache, request) + assertThat(cached).isNotNull() + cached!!.body.close() + server.shutdown() // tearDown() isn't sufficient; this test starts multiple servers + } + + @Test + fun responseCachingAndInputStreamSkipWithFixedLength() { + testResponseCaching(TransferKind.FIXED_LENGTH) + } + + @Test + fun responseCachingAndInputStreamSkipWithChunkedEncoding() { + testResponseCaching(TransferKind.CHUNKED) + } + + @Test + fun responseCachingAndInputStreamSkipWithNoLengthHeaders() { + testResponseCaching(TransferKind.END_OF_STREAM) + } + + /** + * Skipping bytes in the input stream caused ResponseCache corruption. + * http://code.google.com/p/android/issues/detail?id=8175 + */ + private fun testResponseCaching(transferKind: TransferKind) { + val mockResponse = MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .status("HTTP/1.1 200 Fantastic") + transferKind.setBody(mockResponse, "I love puppies but hate spiders", 1) + server.enqueue(mockResponse.build()) + + // Make sure that calling skip() doesn't omit bytes from the cache. + val request = Request.Builder().url(server.url("/")).build() + val response1 = client.newCall(request).execute() + val in1 = response1.body.source() + assertThat(in1.readUtf8("I love ".length.toLong())).isEqualTo("I love ") + in1.skip("puppies but hate ".length.toLong()) + assertThat(in1.readUtf8("spiders".length.toLong())).isEqualTo("spiders") + assertThat(in1.exhausted()).isTrue() + in1.close() + assertThat(cache.writeSuccessCount()).isEqualTo(1) + assertThat(cache.writeAbortCount()).isEqualTo(0) + val response2 = client.newCall(request).execute() + val in2 = response2.body.source() + assertThat(in2.readUtf8("I love puppies but hate spiders".length.toLong())) + .isEqualTo( + "I love puppies but hate spiders" + ) + assertThat(response2.code).isEqualTo(200) + assertThat(response2.message).isEqualTo("Fantastic") + assertThat(in2.exhausted()).isTrue() + in2.close() + assertThat(cache.writeSuccessCount()).isEqualTo(1) + assertThat(cache.writeAbortCount()).isEqualTo(0) + assertThat(cache.requestCount()).isEqualTo(2) + assertThat(cache.hitCount()).isEqualTo(1) + } + + @Test + fun secureResponseCaching() { + server.useHttps(handshakeCertificates.sslSocketFactory()) + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .body("ABC") + .build() + ) + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(NULL_HOSTNAME_VERIFIER) + .build() + val request = Request.Builder().url(server.url("/")).build() + val response1 = client.newCall(request).execute() + val `in` = response1.body.source() + assertThat(`in`.readUtf8()).isEqualTo("ABC") + + // OpenJDK 6 fails on this line, complaining that the connection isn't open yet + val cipherSuite = response1.handshake!!.cipherSuite + val localCerts = response1.handshake!!.localCertificates + val serverCerts = response1.handshake!!.peerCertificates + val peerPrincipal = response1.handshake!!.peerPrincipal + val localPrincipal = response1.handshake!!.localPrincipal + val response2 = client.newCall(request).execute() // Cached! + assertThat(response2.body.string()).isEqualTo("ABC") + assertThat(cache.requestCount()).isEqualTo(2) + assertThat(cache.networkCount()).isEqualTo(1) + assertThat(cache.hitCount()).isEqualTo(1) + assertThat(response2.handshake!!.cipherSuite).isEqualTo(cipherSuite) + assertThat(response2.handshake!!.localCertificates).isEqualTo(localCerts) + assertThat(response2.handshake!!.peerCertificates).isEqualTo(serverCerts) + assertThat(response2.handshake!!.peerPrincipal).isEqualTo(peerPrincipal) + assertThat(response2.handshake!!.localPrincipal).isEqualTo(localPrincipal) + } + + @Test + fun secureResponseCachingWithCorruption() { + server.useHttps(handshakeCertificates.sslSocketFactory()) + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .body("ABC") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-5, TimeUnit.MINUTES)) + .addHeader("Expires: " + formatDate(2, TimeUnit.HOURS)) + .body("DEF") + .build() + ) + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(NULL_HOSTNAME_VERIFIER) + .build() + val request = Request.Builder().url(server.url("/")).build() + val response1 = client.newCall(request).execute() + assertThat(response1.body.string()).isEqualTo("ABC") + val cacheEntry = fileSystem.allPaths.stream() + .filter { e: Path -> e.name.endsWith(".0") } + .findFirst() + .orElseThrow { NoSuchElementException() } + corruptCertificate(cacheEntry) + val response2 = client.newCall(request).execute() // Not Cached! + assertThat(response2.body.string()).isEqualTo("DEF") + assertThat(cache.requestCount()).isEqualTo(2) + assertThat(cache.networkCount()).isEqualTo(2) + assertThat(cache.hitCount()).isEqualTo(0) + } + + private fun corruptCertificate(cacheEntry: Path) { + var content = fileSystem.source(cacheEntry).buffer().readUtf8() + content = content.replace("MII", "!!!") + fileSystem.sink(cacheEntry).buffer().writeUtf8(content).close() + } + + @Test + fun responseCachingAndRedirects() { + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .code(HttpURLConnection.HTTP_MOVED_PERM) + .addHeader("Location: /foo") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .body("ABC") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("DEF") + .build() + ) + val request = Request.Builder().url(server.url("/")).build() + val response1 = client.newCall(request).execute() + assertThat(response1.body.string()).isEqualTo("ABC") + val response2 = client.newCall(request).execute() // Cached! + assertThat(response2.body.string()).isEqualTo("ABC") + + // 2 requests + 2 redirects + assertThat(cache.requestCount()).isEqualTo(4) + assertThat(cache.networkCount()).isEqualTo(2) + assertThat(cache.hitCount()).isEqualTo(2) + } + + @Test + fun redirectToCachedResult() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .body("ABC") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_MOVED_PERM) + .addHeader("Location: /foo") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("DEF") + .build() + ) + val request1 = Request.Builder().url(server.url("/foo")).build() + val response1 = client.newCall(request1).execute() + assertThat(response1.body.string()).isEqualTo("ABC") + val recordedRequest1 = server.takeRequest() + assertThat(recordedRequest1.requestLine).isEqualTo("GET /foo HTTP/1.1") + assertThat(recordedRequest1.sequenceNumber).isEqualTo(0) + val request2 = Request.Builder().url(server.url("/bar")).build() + val response2 = client.newCall(request2).execute() + assertThat(response2.body.string()).isEqualTo("ABC") + val recordedRequest2 = server.takeRequest() + assertThat(recordedRequest2.requestLine).isEqualTo("GET /bar HTTP/1.1") + assertThat(recordedRequest2.sequenceNumber).isEqualTo(1) + + // an unrelated request should reuse the pooled connection + val request3 = Request.Builder().url(server.url("/baz")).build() + val response3 = client.newCall(request3).execute() + assertThat(response3.body.string()).isEqualTo("DEF") + val recordedRequest3 = server.takeRequest() + assertThat(recordedRequest3.requestLine).isEqualTo("GET /baz HTTP/1.1") + assertThat(recordedRequest3.sequenceNumber).isEqualTo(2) + } + + @Test + fun secureResponseCachingAndRedirects() { + server.useHttps(handshakeCertificates.sslSocketFactory()) + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .code(HttpURLConnection.HTTP_MOVED_PERM) + .addHeader("Location: /foo") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .body("ABC") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("DEF") + .build() + ) + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(NULL_HOSTNAME_VERIFIER) + .build() + val response1 = get(server.url("/")) + assertThat(response1.body.string()).isEqualTo("ABC") + assertThat(response1.handshake!!.cipherSuite).isNotNull() + + // Cached! + val response2 = get(server.url("/")) + assertThat(response2.body.string()).isEqualTo("ABC") + assertThat(response2.handshake!!.cipherSuite).isNotNull() + + // 2 direct + 2 redirect = 4 + assertThat(cache.requestCount()).isEqualTo(4) + assertThat(cache.hitCount()).isEqualTo(2) + assertThat(response2.handshake!!.cipherSuite).isEqualTo( + response1.handshake!!.cipherSuite + ) + } + + /** + * We've had bugs where caching and cross-protocol redirects yield class cast exceptions internal + * to the cache because we incorrectly assumed that HttpsURLConnection was always HTTPS and + * HttpURLConnection was always HTTP; in practice redirects mean that each can do either. + * + * https://github.com/square/okhttp/issues/214 + */ + @Test + fun secureResponseCachingAndProtocolRedirects() { + server2.useHttps(handshakeCertificates.sslSocketFactory()) + server2.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .body("ABC") + .build() + ) + server2.enqueue( + MockResponse.Builder() + .body("DEF") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .code(HttpURLConnection.HTTP_MOVED_PERM) + .addHeader("Location: " + server2.url("/")) + .build() + ) + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(NULL_HOSTNAME_VERIFIER) + .build() + val response1 = get(server.url("/")) + assertThat(response1.body.string()).isEqualTo("ABC") + + // Cached! + val response2 = get(server.url("/")) + assertThat(response2.body.string()).isEqualTo("ABC") + + // 2 direct + 2 redirect = 4 + assertThat(cache.requestCount()).isEqualTo(4) + assertThat(cache.hitCount()).isEqualTo(2) + } + + @Test + fun foundCachedWithExpiresHeader() { + temporaryRedirectCachedWithCachingHeader(302, "Expires", formatDate(1, TimeUnit.HOURS)) + } + + @Test + fun foundCachedWithCacheControlHeader() { + temporaryRedirectCachedWithCachingHeader(302, "Cache-Control", "max-age=60") + } + + @Test + fun temporaryRedirectCachedWithExpiresHeader() { + temporaryRedirectCachedWithCachingHeader(307, "Expires", formatDate(1, TimeUnit.HOURS)) + } + + @Test + fun temporaryRedirectCachedWithCacheControlHeader() { + temporaryRedirectCachedWithCachingHeader(307, "Cache-Control", "max-age=60") + } + + @Test + fun foundNotCachedWithoutCacheHeader() { + temporaryRedirectNotCachedWithoutCachingHeader(302) + } + + @Test + fun temporaryRedirectNotCachedWithoutCacheHeader() { + temporaryRedirectNotCachedWithoutCachingHeader(307) + } + + private fun temporaryRedirectCachedWithCachingHeader( + responseCode: Int, + headerName: String, + headerValue: String, + ) { + server.enqueue( + MockResponse.Builder() + .code(responseCode) + .addHeader(headerName, headerValue) + .addHeader("Location", "/a") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader(headerName, headerValue) + .body("a") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("b") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("c") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("a") + assertThat(get(url).body.string()).isEqualTo("a") + } + + private fun temporaryRedirectNotCachedWithoutCachingHeader(responseCode: Int) { + server.enqueue( + MockResponse.Builder() + .code(responseCode) + .addHeader("Location", "/a") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("a") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("b") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("a") + assertThat(get(url).body.string()).isEqualTo("b") + } + + /** https://github.com/square/okhttp/issues/2198 */ + @Test + fun cachedRedirect() { + server.enqueue( + MockResponse.Builder() + .code(301) + .addHeader("Cache-Control: max-age=60") + .addHeader("Location: /bar") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("ABC") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("ABC") + .build() + ) + val request1 = Request.Builder().url(server.url("/")).build() + val response1 = client.newCall(request1).execute() + assertThat(response1.body.string()).isEqualTo("ABC") + val request2 = Request.Builder().url(server.url("/")).build() + val response2 = client.newCall(request2).execute() + assertThat(response2.body.string()).isEqualTo("ABC") + } + + @Test + fun serverDisconnectsPrematurelyWithContentLengthHeader() { + testServerPrematureDisconnect(TransferKind.FIXED_LENGTH) + } + + @Test + fun serverDisconnectsPrematurelyWithChunkedEncoding() { + testServerPrematureDisconnect(TransferKind.CHUNKED) + } + + @Test + fun serverDisconnectsPrematurelyWithNoLengthHeaders() { + // Intentionally empty. This case doesn't make sense because there's no + // such thing as a premature disconnect when the disconnect itself + // indicates the end of the data stream. + } + + private fun testServerPrematureDisconnect(transferKind: TransferKind) { + val mockResponse = MockResponse.Builder() + transferKind.setBody(mockResponse, "ABCDE\nFGHIJKLMNOPQRSTUVWXYZ", 16) + server.enqueue(truncateViolently(mockResponse, 16).build()) + server.enqueue( + MockResponse.Builder() + .body("Request #2") + .build() + ) + val bodySource = get(server.url("/")).body.source() + assertThat(bodySource.readUtf8Line()).isEqualTo("ABCDE") + try { + bodySource.readUtf8(21) + fail("This implementation silently ignored a truncated HTTP body.") + } catch (expected: IOException) { + } finally { + bodySource.close() + } + assertThat(cache.writeAbortCount()).isEqualTo(1) + assertThat(cache.writeSuccessCount()).isEqualTo(0) + val response = get(server.url("/")) + assertThat(response.body.string()).isEqualTo("Request #2") + assertThat(cache.writeAbortCount()).isEqualTo(1) + assertThat(cache.writeSuccessCount()).isEqualTo(1) + } + + @Test + fun clientPrematureDisconnectWithContentLengthHeader() { + testClientPrematureDisconnect(TransferKind.FIXED_LENGTH) + } + + @Test + fun clientPrematureDisconnectWithChunkedEncoding() { + testClientPrematureDisconnect(TransferKind.CHUNKED) + } + + @Test + fun clientPrematureDisconnectWithNoLengthHeaders() { + testClientPrematureDisconnect(TransferKind.END_OF_STREAM) + } + + private fun testClientPrematureDisconnect(transferKind: TransferKind) { + // Setting a low transfer speed ensures that stream discarding will time out. + val builder = MockResponse.Builder() + .throttleBody(6, 1, TimeUnit.SECONDS) + transferKind.setBody(builder, "ABCDE\nFGHIJKLMNOPQRSTUVWXYZ", 1024) + server.enqueue(builder.build()) + server.enqueue( + MockResponse.Builder() + .body("Request #2") + .build() + ) + val response1 = get(server.url("/")) + val `in` = response1.body.source() + assertThat(`in`.readUtf8(5)).isEqualTo("ABCDE") + `in`.close() + try { + `in`.readByte() + fail("Expected an IllegalStateException because the source is closed.") + } catch (expected: IllegalStateException) { + } + assertThat(cache.writeAbortCount()).isEqualTo(1) + assertThat(cache.writeSuccessCount()).isEqualTo(0) + val response2 = get(server.url("/")) + assertThat(response2.body.string()).isEqualTo("Request #2") + assertThat(cache.writeAbortCount()).isEqualTo(1) + assertThat(cache.writeSuccessCount()).isEqualTo(1) + } + + @Test + fun defaultExpirationDateFullyCachedForLessThan24Hours() { + // last modified: 105 seconds ago + // served: 5 seconds ago + // default lifetime: (105 - 5) / 10 = 10 seconds + // expires: 10 seconds from served date = 5 seconds from now + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-105, TimeUnit.SECONDS)) + .addHeader("Date: " + formatDate(-5, TimeUnit.SECONDS)) + .body("A") + .build() + ) + val url = server.url("/") + val response1 = get(url) + assertThat(response1.body.string()).isEqualTo("A") + val response2 = get(url) + assertThat(response2.body.string()).isEqualTo("A") + assertThat(response2.header("Warning")).isNull() + } + + @Test + fun defaultExpirationDateConditionallyCached() { + // last modified: 115 seconds ago + // served: 15 seconds ago + // default lifetime: (115 - 15) / 10 = 10 seconds + // expires: 10 seconds from served date = 5 seconds ago + val lastModifiedDate = formatDate(-115, TimeUnit.SECONDS) + val conditionalRequest = assertConditionallyCached( + MockResponse.Builder() + .addHeader("Last-Modified: $lastModifiedDate") + .addHeader("Date: " + formatDate(-15, TimeUnit.SECONDS)) + .build() + ) + assertThat(conditionalRequest.headers["If-Modified-Since"]) + .isEqualTo(lastModifiedDate) + } + + @Test + fun defaultExpirationDateFullyCachedForMoreThan24Hours() { + // last modified: 105 days ago + // served: 5 days ago + // default lifetime: (105 - 5) / 10 = 10 days + // expires: 10 days from served date = 5 days from now + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-105, TimeUnit.DAYS)) + .addHeader("Date: " + formatDate(-5, TimeUnit.DAYS)) + .body("A") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val response = get(server.url("/")) + assertThat(response.body.string()).isEqualTo("A") + assertThat(response.header("Warning")).isEqualTo( + "113 HttpURLConnection \"Heuristic expiration\"" + ) + } + + @Test + fun noDefaultExpirationForUrlsWithQueryString() { + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-105, TimeUnit.SECONDS)) + .addHeader("Date: " + formatDate(-5, TimeUnit.SECONDS)) + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/").newBuilder().addQueryParameter("foo", "bar").build() + assertThat(get(url).body.string()).isEqualTo("A") + assertThat(get(url).body.string()).isEqualTo("B") + } + + @Test + fun expirationDateInThePastWithLastModifiedHeader() { + val lastModifiedDate = formatDate(-2, TimeUnit.HOURS) + val conditionalRequest = assertConditionallyCached( + MockResponse.Builder() + .addHeader("Last-Modified: $lastModifiedDate") + .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) + .build() + ) + assertThat(conditionalRequest.headers["If-Modified-Since"]) + .isEqualTo(lastModifiedDate) + } + + @Test + fun expirationDateInThePastWithNoLastModifiedHeader() { + assertNotCached( + MockResponse.Builder() + .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) + .build() + ) + } + + @Test + fun expirationDateInTheFuture() { + assertFullyCached( + MockResponse.Builder() + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .build() + ) + } + + @Test + fun maxAgePreferredWithMaxAgeAndExpires() { + assertFullyCached( + MockResponse.Builder() + .addHeader("Date: " + formatDate(0, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=60") + .build() + ) + } + + @Test + fun maxAgeInThePastWithDateAndLastModifiedHeaders() { + val lastModifiedDate = formatDate(-2, TimeUnit.HOURS) + val conditionalRequest = assertConditionallyCached( + MockResponse.Builder() + .addHeader("Date: " + formatDate(-120, TimeUnit.SECONDS)) + .addHeader("Last-Modified: $lastModifiedDate") + .addHeader("Cache-Control: max-age=60") + .build() + ) + assertThat(conditionalRequest.headers["If-Modified-Since"]) + .isEqualTo(lastModifiedDate) + } + + @Test + fun maxAgeInThePastWithDateHeaderButNoLastModifiedHeader() { + // Chrome interprets max-age relative to the local clock. Both our cache + // and Firefox both use the earlier of the local and server's clock. + assertNotCached( + MockResponse.Builder() + .addHeader("Date: " + formatDate(-120, TimeUnit.SECONDS)) + .addHeader("Cache-Control: max-age=60") + .build() + ) + } + + @Test + fun maxAgeInTheFutureWithDateHeader() { + assertFullyCached( + MockResponse.Builder() + .addHeader("Date: " + formatDate(0, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=60") + .build() + ) + } + + @Test + fun maxAgeInTheFutureWithNoDateHeader() { + assertFullyCached( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .build() + ) + } + + @Test + fun maxAgeWithLastModifiedButNoServedDate() { + assertFullyCached( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-120, TimeUnit.SECONDS)) + .addHeader("Cache-Control: max-age=60") + .build() + ) + } + + @Test + fun maxAgeInTheFutureWithDateAndLastModifiedHeaders() { + assertFullyCached( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-120, TimeUnit.SECONDS)) + .addHeader("Date: " + formatDate(0, TimeUnit.SECONDS)) + .addHeader("Cache-Control: max-age=60") + .build() + ) + } + + @Test + fun maxAgePreferredOverLowerSharedMaxAge() { + assertFullyCached( + MockResponse.Builder() + .addHeader("Date: " + formatDate(-2, TimeUnit.MINUTES)) + .addHeader("Cache-Control: s-maxage=60") + .addHeader("Cache-Control: max-age=180") + .build() + ) + } + + @Test + fun maxAgePreferredOverHigherMaxAge() { + assertNotCached( + MockResponse.Builder() + .addHeader("Date: " + formatDate(-2, TimeUnit.MINUTES)) + .addHeader("Cache-Control: s-maxage=180") + .addHeader("Cache-Control: max-age=60") + .build() + ) + } + + @Test + fun requestMethodOptionsIsNotCached() { + testRequestMethod("OPTIONS", false) + } + + @Test + fun requestMethodGetIsCached() { + testRequestMethod("GET", true) + } + + @Test + fun requestMethodHeadIsNotCached() { + // We could support this but choose not to for implementation simplicity + testRequestMethod("HEAD", false) + } + + @Test + fun requestMethodPostIsNotCached() { + // We could support this but choose not to for implementation simplicity + testRequestMethod("POST", false) + } + + @Test + fun requestMethodPutIsNotCached() { + testRequestMethod("PUT", false) + } + + @Test + fun requestMethodDeleteIsNotCached() { + testRequestMethod("DELETE", false) + } + + @Test + fun requestMethodTraceIsNotCached() { + testRequestMethod("TRACE", false) + } + + private fun testRequestMethod(requestMethod: String, expectCached: Boolean) { + // 1. Seed the cache (potentially). + // 2. Expect a cache hit or miss. + server.enqueue( + MockResponse.Builder() + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .addHeader("X-Response-ID: 1") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("X-Response-ID: 2") + .build() + ) + val url = server.url("/") + val request = Request.Builder() + .url(url) + .method(requestMethod, requestBodyOrNull(requestMethod)) + .build() + val response1 = client.newCall(request).execute() + response1.body.close() + assertThat(response1.header("X-Response-ID")).isEqualTo("1") + val response2 = get(url) + response2.body.close() + if (expectCached) { + assertThat(response2.header("X-Response-ID")).isEqualTo("1") + } else { + assertThat(response2.header("X-Response-ID")).isEqualTo("2") + } + } + + private fun requestBodyOrNull(requestMethod: String): RequestBody? { + return if (requestMethod == "POST" || requestMethod == "PUT") "foo".toRequestBody("text/plain".toMediaType()) else null + } + + @Test + fun postInvalidatesCache() { + testMethodInvalidates("POST") + } + + @Test + fun putInvalidatesCache() { + testMethodInvalidates("PUT") + } + + @Test + fun deleteMethodInvalidatesCache() { + testMethodInvalidates("DELETE") + } + + private fun testMethodInvalidates(requestMethod: String) { + // 1. Seed the cache. + // 2. Invalidate it. + // 3. Expect a cache miss. + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("C") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(url) + .method(requestMethod, requestBodyOrNull(requestMethod)) + .build() + val invalidate = client.newCall(request).execute() + assertThat(invalidate.body.string()).isEqualTo("B") + assertThat(get(url).body.string()).isEqualTo("C") + } + + @Test + fun postInvalidatesCacheWithUncacheableResponse() { + // 1. Seed the cache. + // 2. Invalidate it with an uncacheable response. + // 3. Expect a cache miss. + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .code(500) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("C") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(url) + .method("POST", requestBodyOrNull("POST")) + .build() + val invalidate = client.newCall(request).execute() + assertThat(invalidate.body.string()).isEqualTo("B") + assertThat(get(url).body.string()).isEqualTo("C") + } + + @Test + fun putInvalidatesWithNoContentResponse() { + // 1. Seed the cache. + // 2. Invalidate it. + // 3. Expect a cache miss. + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .code(HttpURLConnection.HTTP_NO_CONTENT) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("C") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(url) + .put("foo".toRequestBody("text/plain".toMediaType())) + .build() + val invalidate = client.newCall(request).execute() + assertThat(invalidate.body.string()).isEqualTo("") + assertThat(get(url).body.string()).isEqualTo("C") + } + + @Test + fun etag() { + val conditionalRequest = assertConditionallyCached( + MockResponse.Builder() + .addHeader("ETag: v1") + .build() + ) + assertThat(conditionalRequest.headers["If-None-Match"]).isEqualTo("v1") + } + + /** If both If-Modified-Since and If-None-Match conditions apply, send only If-None-Match. */ + @Test + fun etagAndExpirationDateInThePast() { + val lastModifiedDate = formatDate(-2, TimeUnit.HOURS) + val conditionalRequest = assertConditionallyCached( + MockResponse.Builder() + .addHeader("ETag: v1") + .addHeader("Last-Modified: $lastModifiedDate") + .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) + .build() + ) + assertThat(conditionalRequest.headers["If-None-Match"]).isEqualTo("v1") + assertThat(conditionalRequest.headers["If-Modified-Since"]).isNull() + } + + @Test + fun etagAndExpirationDateInTheFuture() { + assertFullyCached( + MockResponse.Builder() + .addHeader("ETag: v1") + .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .build() + ) + } + + @Test + fun cacheControlNoCache() { + assertNotCached( + MockResponse.Builder() + .addHeader("Cache-Control: no-cache") + .build() + ) + } + + @Test + fun cacheControlNoCacheAndExpirationDateInTheFuture() { + val lastModifiedDate = formatDate(-2, TimeUnit.HOURS) + val conditionalRequest = assertConditionallyCached( + MockResponse.Builder() + .addHeader("Last-Modified: $lastModifiedDate") + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .addHeader("Cache-Control: no-cache") + .build() + ) + assertThat(conditionalRequest.headers["If-Modified-Since"]) + .isEqualTo(lastModifiedDate) + } + + @Test + fun pragmaNoCache() { + assertNotCached( + MockResponse.Builder() + .addHeader("Pragma: no-cache") + .build() + ) + } + + @Test + fun pragmaNoCacheAndExpirationDateInTheFuture() { + val lastModifiedDate = formatDate(-2, TimeUnit.HOURS) + val conditionalRequest = assertConditionallyCached( + MockResponse.Builder() + .addHeader("Last-Modified: $lastModifiedDate") + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .addHeader("Pragma: no-cache") + .build() + ) + assertThat(conditionalRequest.headers["If-Modified-Since"]) + .isEqualTo(lastModifiedDate) + } + + @Test + fun cacheControlNoStore() { + assertNotCached( + MockResponse.Builder() + .addHeader("Cache-Control: no-store") + .build() + ) + } + + @Test + fun cacheControlNoStoreAndExpirationDateInTheFuture() { + assertNotCached( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .addHeader("Cache-Control: no-store") + .build() + ) + } + + @Test + fun partialRangeResponsesDoNotCorruptCache() { + // 1. Request a range. + // 2. Request a full document, expecting a cache miss. + server.enqueue( + MockResponse.Builder() + .body("AA") + .code(HttpURLConnection.HTTP_PARTIAL) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .addHeader("Content-Range: bytes 1000-1001/2000") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("BB") + .build() + ) + val url = server.url("/") + val request = Request.Builder() + .url(url) + .header("Range", "bytes=1000-1001") + .build() + val range = client.newCall(request).execute() + assertThat(range.body.string()).isEqualTo("AA") + assertThat(get(url).body.string()).isEqualTo("BB") + } + + /** + * When the server returns a full response body we will store it and return it regardless of what + * its Last-Modified date is. This behavior was different prior to OkHttp 3.5 when we would prefer + * the response with the later Last-Modified date. + * + * https://github.com/square/okhttp/issues/2886 + */ + @Test + fun serverReturnsDocumentOlderThanCache() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .addHeader("Last-Modified: " + formatDate(-4, TimeUnit.HOURS)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + assertThat(get(url).body.string()).isEqualTo("B") + assertThat(get(url).body.string()).isEqualTo("B") + } + + @Test + fun clientSideNoStore() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .body("B") + .build() + ) + val request1 = Request.Builder() + .url(server.url("/")) + .cacheControl(CacheControl.Builder().noStore().build()) + .build() + val response1 = client.newCall(request1).execute() + assertThat(response1.body.string()).isEqualTo("A") + val request2 = Request.Builder() + .url(server.url("/")) + .build() + val response2 = client.newCall(request2).execute() + assertThat(response2.body.string()).isEqualTo("B") + } + + @Test + fun nonIdentityEncodingAndConditionalCache() { + assertNonIdentityEncodingCached( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) + .build() + ) + } + + @Test + fun nonIdentityEncodingAndFullCache() { + assertNonIdentityEncodingCached( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .build() + ) + } + + private fun assertNonIdentityEncodingCached(response: MockResponse) { + server.enqueue( + response.newBuilder() + .body(gzip("ABCABCABC")) + .addHeader("Content-Encoding: gzip") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + + // At least three request/response pairs are required because after the first request is cached + // a different execution path might be taken. Thus modifications to the cache applied during + // the second request might not be visible until another request is performed. + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + } + + @Test + fun previouslyNotGzippedContentIsNotModifiedAndSpecifiesGzipEncoding() { + server.enqueue( + MockResponse.Builder() + .body("ABCABCABC") + .addHeader("Content-Type: text/plain") + .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .addHeader("Content-Type: text/plain") + .addHeader("Content-Encoding: gzip") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("DEFDEFDEF") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + assertThat(get(server.url("/")).body.string()).isEqualTo("DEFDEFDEF") + } + + @Test + fun changedGzippedContentIsNotModifiedAndSpecifiesNewEncoding() { + server.enqueue( + MockResponse.Builder() + .body(gzip("ABCABCABC")) + .addHeader("Content-Type: text/plain") + .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Content-Encoding: gzip") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .addHeader("Content-Type: text/plain") + .addHeader("Content-Encoding: identity") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("DEFDEFDEF") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + assertThat(get(server.url("/")).body.string()).isEqualTo("DEFDEFDEF") + } + + @Test + fun notModifiedSpecifiesEncoding() { + server.enqueue( + MockResponse.Builder() + .body(gzip("ABCABCABC")) + .addHeader("Content-Encoding: gzip") + .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(-1, TimeUnit.HOURS)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .addHeader("Content-Encoding: gzip") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("DEFDEFDEF") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + assertThat(get(server.url("/")).body.string()).isEqualTo("DEFDEFDEF") + } + + /** https://github.com/square/okhttp/issues/947 */ + @Test + fun gzipAndVaryOnAcceptEncoding() { + server.enqueue( + MockResponse.Builder() + .body(gzip("ABCABCABC")) + .addHeader("Content-Encoding: gzip") + .addHeader("Vary: Accept-Encoding") + .addHeader("Cache-Control: max-age=60") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("FAIL") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + assertThat(get(server.url("/")).body.string()).isEqualTo("ABCABCABC") + } + + @Test + fun conditionalCacheHitIsNotDoublePooled() { + clientTestRule.ensureAllConnectionsReleased() + server.enqueue( + MockResponse.Builder() + .addHeader("ETag: v1") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(client.connectionPool.idleConnectionCount()).isEqualTo(1) + } + + @Test + fun expiresDateBeforeModifiedDate() { + assertConditionallyCached( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(-2, TimeUnit.HOURS)) + .build() + ) + } + + @Test + fun requestMaxAge() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Last-Modified: " + formatDate(-2, TimeUnit.HOURS)) + .addHeader("Date: " + formatDate(-1, TimeUnit.MINUTES)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(server.url("/")) + .header("Cache-Control", "max-age=30") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("B") + } + + @Test + fun requestMinFresh() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Cache-Control: max-age=60") + .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(server.url("/")) + .header("Cache-Control", "min-fresh=120") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("B") + } + + @Test + fun requestMaxStale() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Cache-Control: max-age=120") + .addHeader("Date: " + formatDate(-4, TimeUnit.MINUTES)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(server.url("/")) + .header("Cache-Control", "max-stale=180") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("A") + assertThat(response.header("Warning")).isEqualTo( + "110 HttpURLConnection \"Response is stale\"" + ) + } + + @Test + fun requestMaxStaleDirectiveWithNoValue() { + // Add a stale response to the cache. + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Cache-Control: max-age=120") + .addHeader("Date: " + formatDate(-4, TimeUnit.MINUTES)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + + // With max-stale, we'll return that stale response. + val request = Request.Builder() + .url(server.url("/")) + .header("Cache-Control", "max-stale") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("A") + assertThat(response.header("Warning")).isEqualTo( + "110 HttpURLConnection \"Response is stale\"" + ) + } + + @Test + fun requestMaxStaleNotHonoredWithMustRevalidate() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Cache-Control: max-age=120, must-revalidate") + .addHeader("Date: " + formatDate(-4, TimeUnit.MINUTES)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(server.url("/")) + .header("Cache-Control", "max-stale=180") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("B") + } + + @Test + fun requestOnlyIfCachedWithNoResponseCached() { + // (no responses enqueued) + val request = Request.Builder() + .url(server.url("/")) + .header("Cache-Control", "only-if-cached") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.source().exhausted()).isTrue() + assertThat(response.code).isEqualTo(504) + assertThat(cache.requestCount()).isEqualTo(1) + assertThat(cache.networkCount()).isEqualTo(0) + assertThat(cache.hitCount()).isEqualTo(0) + } + + @Test + fun requestOnlyIfCachedWithFullResponseCached() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Cache-Control: max-age=30") + .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(server.url("/")) + .header("Cache-Control", "only-if-cached") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("A") + assertThat(cache.requestCount()).isEqualTo(2) + assertThat(cache.networkCount()).isEqualTo(1) + assertThat(cache.hitCount()).isEqualTo(1) + } + + @Test + fun requestOnlyIfCachedWithConditionalResponseCached() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Cache-Control: max-age=30") + .addHeader("Date: " + formatDate(-1, TimeUnit.MINUTES)) + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(server.url("/")) + .header("Cache-Control", "only-if-cached") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.source().exhausted()).isTrue() + assertThat(response.code).isEqualTo(504) + assertThat(cache.requestCount()).isEqualTo(2) + assertThat(cache.networkCount()).isEqualTo(1) + assertThat(cache.hitCount()).isEqualTo(0) + } + + @Test + fun requestOnlyIfCachedWithUnhelpfulResponseCached() { + server.enqueue( + MockResponse.Builder() + .body("A") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(server.url("/")) + .header("Cache-Control", "only-if-cached") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.source().exhausted()).isTrue() + assertThat(response.code).isEqualTo(504) + assertThat(cache.requestCount()).isEqualTo(2) + assertThat(cache.networkCount()).isEqualTo(1) + assertThat(cache.hitCount()).isEqualTo(0) + } + + @Test + fun requestCacheControlNoCache() { + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-120, TimeUnit.SECONDS)) + .addHeader("Date: " + formatDate(0, TimeUnit.SECONDS)) + .addHeader("Cache-Control: max-age=60") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(url) + .header("Cache-Control", "no-cache") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("B") + } + + @Test + fun requestPragmaNoCache() { + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-120, TimeUnit.SECONDS)) + .addHeader("Date: " + formatDate(0, TimeUnit.SECONDS)) + .addHeader("Cache-Control: max-age=60") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(url) + .header("Pragma", "no-cache") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("B") + } + + @Test + fun clientSuppliedIfModifiedSinceWithCachedResult() { + val response = MockResponse.Builder() + .addHeader("ETag: v3") + .addHeader("Cache-Control: max-age=0") + .build() + val ifModifiedSinceDate = formatDate(-24, TimeUnit.HOURS) + val request = + assertClientSuppliedCondition(response, "If-Modified-Since", ifModifiedSinceDate) + assertThat(request.headers["If-Modified-Since"]).isEqualTo(ifModifiedSinceDate) + assertThat(request.headers["If-None-Match"]).isNull() + } + + @Test + fun clientSuppliedIfNoneMatchSinceWithCachedResult() { + val lastModifiedDate = formatDate(-3, TimeUnit.MINUTES) + val response = MockResponse.Builder() + .addHeader("Last-Modified: $lastModifiedDate") + .addHeader("Date: " + formatDate(-2, TimeUnit.MINUTES)) + .addHeader("Cache-Control: max-age=0") + .build() + val request = assertClientSuppliedCondition(response, "If-None-Match", "v1") + assertThat(request.headers["If-None-Match"]).isEqualTo("v1") + assertThat(request.headers["If-Modified-Since"]).isNull() + } + + private fun assertClientSuppliedCondition( + seed: MockResponse, conditionName: String, + conditionValue: String + ): RecordedRequest { + server.enqueue( + seed.newBuilder() + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(url) + .header(conditionName, conditionValue) + .build() + val response = client.newCall(request).execute() + assertThat(response.code).isEqualTo(HttpURLConnection.HTTP_NOT_MODIFIED) + assertThat(response.body.string()).isEqualTo("") + server.takeRequest() // seed + return server.takeRequest() + } + + /** + * For Last-Modified and Date headers, we should echo the date back in the exact format we were + * served. + */ + @Test + fun retainServedDateFormat() { + // Serve a response with a non-standard date format that OkHttp supports. + val lastModifiedDate = Date(System.currentTimeMillis() + TimeUnit.HOURS.toMillis(-1)) + val servedDate = Date(System.currentTimeMillis() + TimeUnit.HOURS.toMillis(-2)) + val dateFormat: DateFormat = SimpleDateFormat("EEE dd-MMM-yyyy HH:mm:ss z", Locale.US) + dateFormat.timeZone = TimeZone.getTimeZone("America/New_York") + val lastModifiedString = dateFormat.format(lastModifiedDate) + val servedString = dateFormat.format(servedDate) + + // This response should be conditionally cached. + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: $lastModifiedString") + .addHeader("Expires: $servedString") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + + // The first request has no conditions. + val request1 = server.takeRequest() + assertThat(request1.headers["If-Modified-Since"]).isNull() + + // The 2nd request uses the server's date format. + val request2 = server.takeRequest() + assertThat(request2.headers["If-Modified-Since"]).isEqualTo(lastModifiedString) + } + + @Test + fun clientSuppliedConditionWithoutCachedResult() { + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .header("If-Modified-Since", formatDate(-24, TimeUnit.HOURS)) + .build() + val response = client.newCall(request).execute() + assertThat(response.code).isEqualTo(HttpURLConnection.HTTP_NOT_MODIFIED) + assertThat(response.body.string()).isEqualTo("") + } + + @Test + fun authorizationRequestFullyCached() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + val request = Request.Builder() + .url(url) + .header("Authorization", "password") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("A") + assertThat(get(url).body.string()).isEqualTo("A") + } + + @Test + fun contentLocationDoesNotPopulateCache() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Content-Location: /bar") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/foo")).body.string()).isEqualTo("A") + assertThat(get(server.url("/bar")).body.string()).isEqualTo("B") + } + + @Test + fun connectionIsReturnedToPoolAfterConditionalSuccess() { + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=0") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/a")).body.string()).isEqualTo("A") + assertThat(get(server.url("/a")).body.string()).isEqualTo("A") + assertThat(get(server.url("/b")).body.string()).isEqualTo("B") + assertThat(server.takeRequest().sequenceNumber).isEqualTo(0) + assertThat(server.takeRequest().sequenceNumber).isEqualTo(1) + assertThat(server.takeRequest().sequenceNumber).isEqualTo(2) + } + + @Test + fun statisticsConditionalCacheMiss() { + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=0") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("C") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(cache.requestCount()).isEqualTo(1) + assertThat(cache.networkCount()).isEqualTo(1) + assertThat(cache.hitCount()).isEqualTo(0) + assertThat(get(server.url("/")).body.string()).isEqualTo("B") + assertThat(get(server.url("/")).body.string()).isEqualTo("C") + assertThat(cache.requestCount()).isEqualTo(3) + assertThat(cache.networkCount()).isEqualTo(3) + assertThat(cache.hitCount()).isEqualTo(0) + } + + @Test + fun statisticsConditionalCacheHit() { + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=0") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(cache.requestCount()).isEqualTo(1) + assertThat(cache.networkCount()).isEqualTo(1) + assertThat(cache.hitCount()).isEqualTo(0) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(cache.requestCount()).isEqualTo(3) + assertThat(cache.networkCount()).isEqualTo(3) + assertThat(cache.hitCount()).isEqualTo(2) + } + + @Test + fun statisticsFullCacheHit() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .body("A") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(cache.requestCount()).isEqualTo(1) + assertThat(cache.networkCount()).isEqualTo(1) + assertThat(cache.hitCount()).isEqualTo(0) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(cache.requestCount()).isEqualTo(3) + assertThat(cache.networkCount()).isEqualTo(1) + assertThat(cache.hitCount()).isEqualTo(2) + } + + @Test + fun varyMatchesChangedRequestHeaderField() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Accept-Language") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + val frRequest = Request.Builder() + .url(url) + .header("Accept-Language", "fr-CA") + .build() + val frResponse = client.newCall(frRequest).execute() + assertThat(frResponse.body.string()).isEqualTo("A") + val enRequest = Request.Builder() + .url(url) + .header("Accept-Language", "en-US") + .build() + val enResponse = client.newCall(enRequest).execute() + assertThat(enResponse.body.string()).isEqualTo("B") + } + + @Test + fun varyMatchesUnchangedRequestHeaderField() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Accept-Language") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + val request = Request.Builder() + .url(url) + .header("Accept-Language", "fr-CA") + .build() + val response1 = client.newCall(request).execute() + assertThat(response1.body.string()).isEqualTo("A") + val request1 = Request.Builder() + .url(url) + .header("Accept-Language", "fr-CA") + .build() + val response2 = client.newCall(request1).execute() + assertThat(response2.body.string()).isEqualTo("A") + } + + @Test + fun varyMatchesAbsentRequestHeaderField() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Foo") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + } + + @Test + fun varyMatchesAddedRequestHeaderField() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Foo") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(server.url("/")).header("Foo", "bar") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("B") + } + + @Test + fun varyMatchesRemovedRequestHeaderField() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Foo") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val request = Request.Builder() + .url(server.url("/")).header("Foo", "bar") + .build() + val fooresponse = client.newCall(request).execute() + assertThat(fooresponse.body.string()).isEqualTo("A") + assertThat(get(server.url("/")).body.string()).isEqualTo("B") + } + + @Test + fun varyFieldsAreCaseInsensitive() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: ACCEPT-LANGUAGE") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + val request = Request.Builder() + .url(url) + .header("Accept-Language", "fr-CA") + .build() + val response1 = client.newCall(request).execute() + assertThat(response1.body.string()).isEqualTo("A") + val request1 = Request.Builder() + .url(url) + .header("accept-language", "fr-CA") + .build() + val response2 = client.newCall(request1).execute() + assertThat(response2.body.string()).isEqualTo("A") + } + + @Test + fun varyMultipleFieldsWithMatch() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Accept-Language, Accept-Charset") + .addHeader("Vary: Accept-Encoding") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + val request = Request.Builder() + .url(url) + .header("Accept-Language", "fr-CA") + .header("Accept-Charset", "UTF-8") + .header("Accept-Encoding", "identity") + .build() + val response1 = client.newCall(request).execute() + assertThat(response1.body.string()).isEqualTo("A") + val request1 = Request.Builder() + .url(url) + .header("Accept-Language", "fr-CA") + .header("Accept-Charset", "UTF-8") + .header("Accept-Encoding", "identity") + .build() + val response2 = client.newCall(request1).execute() + assertThat(response2.body.string()).isEqualTo("A") + } + + @Test + fun varyMultipleFieldsWithNoMatch() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Accept-Language, Accept-Charset") + .addHeader("Vary: Accept-Encoding") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + val frRequest = Request.Builder() + .url(url) + .header("Accept-Language", "fr-CA") + .header("Accept-Charset", "UTF-8") + .header("Accept-Encoding", "identity") + .build() + val frResponse = client.newCall(frRequest).execute() + assertThat(frResponse.body.string()).isEqualTo("A") + val enRequest = Request.Builder() + .url(url) + .header("Accept-Language", "en-CA") + .header("Accept-Charset", "UTF-8") + .header("Accept-Encoding", "identity") + .build() + val enResponse = client.newCall(enRequest).execute() + assertThat(enResponse.body.string()).isEqualTo("B") + } + + @Test + fun varyMultipleFieldValuesWithMatch() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Accept-Language") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + val request1 = Request.Builder() + .url(url) + .addHeader("Accept-Language", "fr-CA, fr-FR") + .addHeader("Accept-Language", "en-US") + .build() + val response1 = client.newCall(request1).execute() + assertThat(response1.body.string()).isEqualTo("A") + val request2 = Request.Builder() + .url(url) + .addHeader("Accept-Language", "fr-CA, fr-FR") + .addHeader("Accept-Language", "en-US") + .build() + val response2 = client.newCall(request2).execute() + assertThat(response2.body.string()).isEqualTo("A") + } + + @Test + fun varyMultipleFieldValuesWithNoMatch() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Accept-Language") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + val request1 = Request.Builder() + .url(url) + .addHeader("Accept-Language", "fr-CA, fr-FR") + .addHeader("Accept-Language", "en-US") + .build() + val response1 = client.newCall(request1).execute() + assertThat(response1.body.string()).isEqualTo("A") + val request2 = Request.Builder() + .url(url) + .addHeader("Accept-Language", "fr-CA") + .addHeader("Accept-Language", "en-US") + .build() + val response2 = client.newCall(request2).execute() + assertThat(response2.body.string()).isEqualTo("B") + } + + @Test + fun varyAsterisk() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: *") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + assertThat(get(server.url("/")).body.string()).isEqualTo("B") + } + + @Test + fun varyAndHttps() { + server.useHttps(handshakeCertificates.sslSocketFactory()) + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .addHeader("Vary: Accept-Language") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(NULL_HOSTNAME_VERIFIER) + .build() + val url = server.url("/") + val request1 = Request.Builder() + .url(url) + .header("Accept-Language", "en-US") + .build() + val response1 = client.newCall(request1).execute() + assertThat(response1.body.string()).isEqualTo("A") + val request2 = Request.Builder() + .url(url) + .header("Accept-Language", "en-US") + .build() + val response2 = client.newCall(request2).execute() + assertThat(response2.body.string()).isEqualTo("A") + } + + @Test + fun cachePlusCookies() { + val cookieJar = RecordingCookieJar() + client = client.newBuilder() + .cookieJar(cookieJar) + .build() + server.enqueue( + MockResponse.Builder() + .addHeader("Set-Cookie: a=FIRST") + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=0") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Set-Cookie: a=SECOND") + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + cookieJar.assertResponseCookies("a=FIRST; path=/") + assertThat(get(url).body.string()).isEqualTo("A") + cookieJar.assertResponseCookies("a=SECOND; path=/") + } + + @get:Throws(Exception::class) + @get:Test + val headersReturnsNetworkEndToEndHeaders: Unit + get() { + server.enqueue( + MockResponse.Builder() + .addHeader("Allow: GET, HEAD") + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=0") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Allow: GET, HEAD, PUT") + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val response1 = get(server.url("/")) + assertThat(response1.body.string()).isEqualTo("A") + assertThat(response1.header("Allow")).isEqualTo("GET, HEAD") + val response2 = get(server.url("/")) + assertThat(response2.body.string()).isEqualTo("A") + assertThat(response2.header("Allow")).isEqualTo("GET, HEAD, PUT") + } + + @get:Throws(Exception::class) + @get:Test + val headersReturnsCachedHopByHopHeaders: Unit + get() { + server.enqueue( + MockResponse.Builder() + .addHeader("Transfer-Encoding: identity") + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=0") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Transfer-Encoding: none") + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val response1 = get(server.url("/")) + assertThat(response1.body.string()).isEqualTo("A") + assertThat(response1.header("Transfer-Encoding")).isEqualTo("identity") + val response2 = get(server.url("/")) + assertThat(response2.body.string()).isEqualTo("A") + assertThat(response2.header("Transfer-Encoding")).isEqualTo("identity") + } + + @get:Throws(Exception::class) + @get:Test + val headersDeletesCached100LevelWarnings: Unit + get() { + server.enqueue( + MockResponse.Builder() + .addHeader("Warning: 199 test danger") + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=0") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val response1 = get(server.url("/")) + assertThat(response1.body.string()).isEqualTo("A") + assertThat(response1.header("Warning")).isEqualTo("199 test danger") + val response2 = get(server.url("/")) + assertThat(response2.body.string()).isEqualTo("A") + assertThat(response2.header("Warning")).isNull() + } + + @get:Throws(Exception::class) + @get:Test + val headersRetainsCached200LevelWarnings: Unit + get() { + server.enqueue( + MockResponse.Builder() + .addHeader("Warning: 299 test danger") + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=0") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val response1 = get(server.url("/")) + assertThat(response1.body.string()).isEqualTo("A") + assertThat(response1.header("Warning")).isEqualTo("299 test danger") + val response2 = get(server.url("/")) + assertThat(response2.body.string()).isEqualTo("A") + assertThat(response2.header("Warning")).isEqualTo("299 test danger") + } + + @Test + fun doNotCachePartialResponse() { + assertNotCached( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_PARTIAL) + .addHeader("Date: " + formatDate(0, TimeUnit.HOURS)) + .addHeader("Content-Range: bytes 100-100/200") + .addHeader("Cache-Control: max-age=60") + .build() + ) + } + + @Test + fun conditionalHitUpdatesCache() { + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(0, TimeUnit.SECONDS)) + .addHeader("Cache-Control: max-age=0") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=30") + .addHeader("Allow: GET, HEAD") + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + + // A cache miss writes the cache. + val t0 = System.currentTimeMillis() + val response1 = get(server.url("/a")) + assertThat(response1.body.string()).isEqualTo("A") + assertThat(response1.header("Allow")).isNull() + assertThat((response1.receivedResponseAtMillis - t0).toDouble()) + .isCloseTo(0.0, Offset.offset(250.0)) + + // A conditional cache hit updates the cache. + Thread.sleep(500) // Make sure t0 and t1 are distinct. + val t1 = System.currentTimeMillis() + val response2 = get(server.url("/a")) + assertThat(response2.code).isEqualTo(HttpURLConnection.HTTP_OK) + assertThat(response2.body.string()).isEqualTo("A") + assertThat(response2.header("Allow")).isEqualTo("GET, HEAD") + val updatedTimestamp = response2.receivedResponseAtMillis + assertThat((updatedTimestamp - t1).toDouble()) + .isCloseTo(0.0, Offset.offset(250.0)) + + // A full cache hit reads the cache. + Thread.sleep(10) + val response3 = get(server.url("/a")) + assertThat(response3.body.string()).isEqualTo("A") + assertThat(response3.header("Allow")).isEqualTo("GET, HEAD") + assertThat(response3.receivedResponseAtMillis).isEqualTo(updatedTimestamp) + assertThat(server.requestCount).isEqualTo(2) + } + + @Test + fun responseSourceHeaderCached() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Cache-Control: max-age=30") + .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val request = Request.Builder() + .url(server.url("/")).header("Cache-Control", "only-if-cached") + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("A") + } + + @Test + fun responseSourceHeaderConditionalCacheFetched() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Cache-Control: max-age=30") + .addHeader("Date: " + formatDate(-31, TimeUnit.MINUTES)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .addHeader("Cache-Control: max-age=30") + .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val response = get(server.url("/")) + assertThat(response.body.string()).isEqualTo("B") + } + + @Test + fun responseSourceHeaderConditionalCacheNotFetched() { + server.enqueue( + MockResponse.Builder() + .body("A") + .addHeader("Cache-Control: max-age=0") + .addHeader("Date: " + formatDate(0, TimeUnit.MINUTES)) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(304) + .build() + ) + assertThat(get(server.url("/")).body.string()).isEqualTo("A") + val response = get(server.url("/")) + assertThat(response.body.string()).isEqualTo("A") + } + + @Test + fun responseSourceHeaderFetched() { + server.enqueue( + MockResponse.Builder() + .body("A") + .build() + ) + val response = get(server.url("/")) + assertThat(response.body.string()).isEqualTo("A") + } + + @Test + fun emptyResponseHeaderNameFromCacheIsLenient() { + val headers = Headers.Builder() + .add("Cache-Control: max-age=120") + addHeaderLenient(headers, ": A") + server.enqueue( + MockResponse.Builder() + .headers(headers.build()) + .body("body") + .build() + ) + val response = get(server.url("/")) + assertThat(response.header("")).isEqualTo("A") + assertThat(response.body.string()).isEqualTo("body") + } + + /** + * Old implementations of OkHttp's response cache wrote header fields like ":status: 200 OK". This + * broke our cached response parser because it split on the first colon. This regression test + * exists to help us read these old bad cache entries. + * + * https://github.com/square/okhttp/issues/227 + */ + @Test + fun testGoldenCacheResponse() { + cache.close() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val url = server.url("/") + val urlKey = key(url) + val entryMetadata = """ + $url + GET + 0 + HTTP/1.1 200 OK + 7 + :status: 200 OK + :version: HTTP/1.1 + etag: foo + content-length: 3 + OkHttp-Received-Millis: ${System.currentTimeMillis()} + X-Android-Response-Source: NETWORK 200 + OkHttp-Sent-Millis: ${System.currentTimeMillis()} + + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA + 1 + MIIBpDCCAQ2gAwIBAgIBATANBgkqhkiG9w0BAQsFADAYMRYwFAYDVQQDEw1qd2lsc29uLmxvY2FsMB4XDTEzMDgyOTA1MDE1OVoXDTEzMDgzMDA1MDE1OVowGDEWMBQGA1UEAxMNandpbHNvbi5sb2NhbDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAlFW+rGo/YikCcRghOyKkJanmVmJSce/p2/jH1QvNIFKizZdh8AKNwojt3ywRWaDULA/RlCUcltF3HGNsCyjQI/+Lf40x7JpxXF8oim1E6EtDoYtGWAseelawus3IQ13nmo6nWzfyCA55KhAWf4VipelEy8DjcuFKv6L0xwXnI0ECAwEAATANBgkqhkiG9w0BAQsFAAOBgQAuluNyPo1HksU3+Mr/PyRQIQS4BI7pRXN8mcejXmqyscdP7S6J21FBFeRR8/XNjVOp4HT9uSc2hrRtTEHEZCmpyoxixbnM706ikTmC7SN/GgM+SmcoJ1ipJcNcl8N0X6zym4dmyFfXKHu2PkTo7QFdpOJFvP3lIigcSZXozfmEDg== + -1 + + """.trimIndent() + val entryBody = "abc" + val journalBody = """libcore.io.DiskLruCache +1 +201105 +2 + +CLEAN $urlKey ${entryMetadata.length} ${entryBody.length} +""" + fileSystem.createDirectory(cache.directoryPath) + writeFile(cache.directoryPath, "$urlKey.0", entryMetadata) + writeFile(cache.directoryPath, "$urlKey.1", entryBody) + writeFile(cache.directoryPath, "journal", journalBody) + cache = Cache(cache.directory.path.toPath(), Int.MAX_VALUE.toLong(), fileSystem) + client = client.newBuilder() + .cache(cache) + .build() + val response = get(url) + assertThat(response.body.string()).isEqualTo(entryBody) + assertThat(response.header("Content-Length")).isEqualTo("3") + assertThat(response.header("etag")).isEqualTo("foo") + } + + /** Exercise the cache format in OkHttp 2.7 and all earlier releases. */ + @Test + fun testGoldenCacheHttpsResponseOkHttp27() { + val url = server.url("/") + val urlKey = key(url) + val prefix = get().getPrefix() + val entryMetadata = """ + $url + GET + 0 + HTTP/1.1 200 OK + 4 + Content-Length: 3 + $prefix-Received-Millis: ${System.currentTimeMillis()} + $prefix-Sent-Millis: ${System.currentTimeMillis()} + Cache-Control: max-age=60 + + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + 1 + MIIBnDCCAQWgAwIBAgIBATANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwHhcNMTUxMjIyMDExMTQwWhcNMTUxMjIzMDExMTQwWjAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAJTn2Dh8xYmegvpOSmsKb2Os6Cxf1L4fYbnHr/turInUD5r1P7ZAuxurY880q3GT5bUDoirS3IfucddrT1AcAmUzEmk/FDjggiP8DlxFkY/XwXBlhRDVIp/mRuASPMGInckc0ZaixOkRFyrxADj+r1eaSmXCIvV5yTY6IaIokLj1AgMBAAEwDQYJKoZIhvcNAQELBQADgYEAFblnedqtfRqI9j2WDyPPoG0NTZf9xwjeUu+ju+Ktty8u9k7Lgrrd/DH2mQEtBD1Ctvp91MJfAClNg3faZzwClUyu5pd0QXRZEUwSwZQNen2QWDHRlVsItclBJ4t+AJLqTbwofWi4m4K8REOl593hD55E4+lY22JZiVQyjsQhe6I= + 0 + + """.trimIndent() + val entryBody = "abc" + val journalBody = """libcore.io.DiskLruCache +1 +201105 +2 + +DIRTY $urlKey +CLEAN $urlKey ${entryMetadata.length} ${entryBody.length} +""" + fileSystem.createDirectory(cache.directoryPath) + writeFile(cache.directoryPath, "$urlKey.0", entryMetadata) + writeFile(cache.directoryPath, "$urlKey.1", entryBody) + writeFile(cache.directoryPath, "journal", journalBody) + cache.close() + cache = Cache(cache.directory.path.toPath(), Int.MAX_VALUE.toLong(), fileSystem) + client = client.newBuilder() + .cache(cache) + .build() + val response = get(url) + assertThat(response.body.string()).isEqualTo(entryBody) + assertThat(response.header("Content-Length")).isEqualTo("3") + } + + /** The TLS version is present in OkHttp 3.0 and beyond. */ + @Test + fun testGoldenCacheHttpsResponseOkHttp30() { + val url = server.url("/") + val urlKey = key(url) + val prefix = get().getPrefix() + val entryMetadata = """ + |$url + |GET + |0 + |HTTP/1.1 200 OK + |4 + |Content-Length: 3 + |$prefix-Received-Millis: ${System.currentTimeMillis()} + |$prefix-Sent-Millis: ${System.currentTimeMillis()} + |Cache-Control: max-age=60 + | + |TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + |1 + |MIIBnDCCAQWgAwIBAgIBATANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwHhcNMTUxMjIyMDExMTQwWhcNMTUxMjIzMDExMTQwWjAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAJTn2Dh8xYmegvpOSmsKb2Os6Cxf1L4fYbnHr/turInUD5r1P7ZAuxurY880q3GT5bUDoirS3IfucddrT1AcAmUzEmk/FDjggiP8DlxFkY/XwXBlhRDVIp/mRuASPMGInckc0ZaixOkRFyrxADj+r1eaSmXCIvV5yTY6IaIokLj1AgMBAAEwDQYJKoZIhvcNAQELBQADgYEAFblnedqtfRqI9j2WDyPPoG0NTZf9xwjeUu+ju+Ktty8u9k7Lgrrd/DH2mQEtBD1Ctvp91MJfAClNg3faZzwClUyu5pd0QXRZEUwSwZQNen2QWDHRlVsItclBJ4t+AJLqTbwofWi4m4K8REOl593hD55E4+lY22JZiVQyjsQhe6I= + |0 + |TLSv1.2 + | + |""".trimMargin() + val entryBody = "abc" + val journalBody = """ + |libcore.io.DiskLruCache + |1 + |201105 + |2 + | + |DIRTY $urlKey + |CLEAN $urlKey ${entryMetadata.length} ${entryBody.length} + |""".trimMargin() + fileSystem.createDirectory(cache.directoryPath) + writeFile(cache.directoryPath, "$urlKey.0", entryMetadata) + writeFile(cache.directoryPath, "$urlKey.1", entryBody) + writeFile(cache.directoryPath, "journal", journalBody) + cache.close() + cache = Cache(cache.directory.path.toPath(), Int.MAX_VALUE.toLong(), fileSystem) + client = client.newBuilder() + .cache(cache) + .build() + val response = get(url) + assertThat(response.body.string()).isEqualTo(entryBody) + assertThat(response.header("Content-Length")).isEqualTo("3") + } + + @Test + fun testGoldenCacheHttpResponseOkHttp30() { + val url = server.url("/") + val urlKey = key(url) + val prefix = get().getPrefix() + val entryMetadata = """ + |$url + |GET + |0 + |HTTP/1.1 200 OK + |4 + |Cache-Control: max-age=60 + |Content-Length: 3 + |$prefix-Received-Millis: ${System.currentTimeMillis()} + |$prefix-Sent-Millis: ${System.currentTimeMillis()} + | + """.trimMargin() + val entryBody = "abc" + val journalBody = """ + |libcore.io.DiskLruCache + |1 + |201105 + |2 + | + |DIRTY $urlKey + |CLEAN $urlKey ${entryMetadata.length} ${entryBody.length} + | + """.trimMargin() + fileSystem.createDirectory(cache.directoryPath) + writeFile(cache.directoryPath, "$urlKey.0", entryMetadata) + writeFile(cache.directoryPath, "$urlKey.1", entryBody) + writeFile(cache.directoryPath, "journal", journalBody) + cache.close() + cache = Cache(cache.directory.path.toPath(), Int.MAX_VALUE.toLong(), fileSystem) + client = client.newBuilder() + .cache(cache) + .build() + val response = get(url) + assertThat(response.body.string()).isEqualTo(entryBody) + assertThat(response.header("Content-Length")).isEqualTo("3") + } + + @Test + fun evictAll() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + client.cache!!.evictAll() + assertThat(client.cache!!.size()).isEqualTo(0) + assertThat(get(url).body.string()).isEqualTo("B") + } + + @Test + fun networkInterceptorInvokedForConditionalGet() { + server.enqueue( + MockResponse.Builder() + .addHeader("ETag: v1") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + + // Seed the cache. + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + val ifNoneMatch = AtomicReference() + client = client.newBuilder() + .addNetworkInterceptor(Interceptor { chain: Interceptor.Chain -> + ifNoneMatch.compareAndSet(null, chain.request().header("If-None-Match")) + chain.proceed(chain.request()) + }) + .build() + + // Confirm the value is cached and intercepted. + assertThat(get(url).body.string()).isEqualTo("A") + assertThat(ifNoneMatch.get()).isEqualTo("v1") + } + + @Test + fun networkInterceptorNotInvokedForFullyCached() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .body("A") + .build() + ) + + // Seed the cache. + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + + // Confirm the interceptor isn't exercised. + client = client.newBuilder() + .addNetworkInterceptor(Interceptor { chain: Interceptor.Chain? -> throw AssertionError() }) + .build() + assertThat(get(url).body.string()).isEqualTo("A") + } + + @Test + fun iterateCache() { + // Put some responses in the cache. + server.enqueue( + MockResponse.Builder() + .body("a") + .build() + ) + val urlA = server.url("/a") + assertThat(get(urlA).body.string()).isEqualTo("a") + server.enqueue( + MockResponse.Builder() + .body("b") + .build() + ) + val urlB = server.url("/b") + assertThat(get(urlB).body.string()).isEqualTo("b") + server.enqueue( + MockResponse.Builder() + .body("c") + .build() + ) + val urlC = server.url("/c") + assertThat(get(urlC).body.string()).isEqualTo("c") + + // Confirm the iterator returns those responses... + val i: Iterator = cache.urls() + assertThat(i.hasNext()).isTrue() + assertThat(i.next()).isEqualTo(urlA.toString()) + assertThat(i.hasNext()).isTrue() + assertThat(i.next()).isEqualTo(urlB.toString()) + assertThat(i.hasNext()).isTrue() + assertThat(i.next()).isEqualTo(urlC.toString()) + + // ... and nothing else. + assertThat(i.hasNext()).isFalse() + try { + i.next() + fail() + } catch (expected: NoSuchElementException) { + } + } + + @Test + fun iteratorRemoveFromCache() { + // Put a response in the cache. + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control: max-age=60") + .body("a") + .build() + ) + val url = server.url("/a") + assertThat(get(url).body.string()).isEqualTo("a") + + // Remove it with iteration. + val i = cache.urls() + assertThat(i.next()).isEqualTo(url.toString()) + i.remove() + + // Confirm that subsequent requests suffer a cache miss. + server.enqueue( + MockResponse.Builder() + .body("b") + .build() + ) + assertThat(get(url).body.string()).isEqualTo("b") + } + + @Test + fun iteratorRemoveWithoutNextThrows() { + // Put a response in the cache. + server.enqueue( + MockResponse.Builder() + .body("a") + .build() + ) + val url = server.url("/a") + assertThat(get(url).body.string()).isEqualTo("a") + val i = cache.urls() + assertThat(i.hasNext()).isTrue() + try { + i.remove() + fail() + } catch (expected: IllegalStateException) { + } + } + + @Test + fun iteratorRemoveOncePerCallToNext() { + // Put a response in the cache. + server.enqueue( + MockResponse.Builder() + .body("a") + .build() + ) + val url = server.url("/a") + assertThat(get(url).body.string()).isEqualTo("a") + val i = cache.urls() + assertThat(i.next()).isEqualTo(url.toString()) + i.remove() + + // Too many calls to remove(). + try { + i.remove() + fail() + } catch (expected: IllegalStateException) { + } + } + + @Test + fun elementEvictedBetweenHasNextAndNext() { + // Put a response in the cache. + server.enqueue( + MockResponse.Builder() + .body("a") + .build() + ) + val url = server.url("/a") + assertThat(get(url).body.string()).isEqualTo("a") + + // The URL will remain available if hasNext() returned true... + val i = cache.urls() + assertThat(i.hasNext()).isTrue() + + // ...so even when we evict the element, we still get something back. + cache.evictAll() + assertThat(i.next()).isEqualTo(url.toString()) + + // Remove does nothing. But most importantly, it doesn't throw! + i.remove() + } + + @Test + fun elementEvictedBeforeHasNextIsOmitted() { + // Put a response in the cache. + server.enqueue( + MockResponse.Builder() + .body("a") + .build() + ) + val url = server.url("/a") + assertThat(get(url).body.string()).isEqualTo("a") + val i: Iterator = cache.urls() + cache.evictAll() + + // The URL was evicted before hasNext() made any promises. + assertThat(i.hasNext()).isFalse() + try { + i.next() + fail() + } catch (expected: NoSuchElementException) { + } + } + + /** Test https://github.com/square/okhttp/issues/1712. */ + @Test + fun conditionalMissUpdatesCache() { + server.enqueue( + MockResponse.Builder() + .addHeader("ETag: v1") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("ETag: v2") + .body("B") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + assertThat(get(url).body.string()).isEqualTo("A") + assertThat(get(url).body.string()).isEqualTo("B") + assertThat(get(url).body.string()).isEqualTo("B") + assertThat(server.takeRequest().headers["If-None-Match"]).isNull() + assertThat(server.takeRequest().headers["If-None-Match"]).isEqualTo("v1") + assertThat(server.takeRequest().headers["If-None-Match"]).isEqualTo("v1") + assertThat(server.takeRequest().headers["If-None-Match"]).isEqualTo("v2") + } + + @Test + fun combinedCacheHeadersCanBeNonAscii() { + server.enqueue( + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Cache-Control: max-age=0") + .addHeaderLenient("Alpha", "α") + .addHeaderLenient("β", "Beta") + .body("abcd") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Transfer-Encoding: none") + .addHeaderLenient("Gamma", "Γ") + .addHeaderLenient("Δ", "Delta") + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val response1 = get(server.url("/")) + assertThat(response1.header("Alpha")).isEqualTo("α") + assertThat(response1.header("β")).isEqualTo("Beta") + assertThat(response1.body.string()).isEqualTo("abcd") + val response2 = get(server.url("/")) + assertThat(response2.header("Alpha")).isEqualTo("α") + assertThat(response2.header("β")).isEqualTo("Beta") + assertThat(response2.header("Gamma")).isEqualTo("Γ") + assertThat(response2.header("Δ")).isEqualTo("Delta") + assertThat(response2.body.string()).isEqualTo("abcd") + } + + @Test + fun etagConditionCanBeNonAscii() { + server.enqueue( + MockResponse.Builder() + .addHeaderLenient("Etag", "α") + .addHeader("Cache-Control: max-age=0") + .body("abcd") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + val response1 = get(server.url("/")) + assertThat(response1.body.string()).isEqualTo("abcd") + val response2 = get(server.url("/")) + assertThat(response2.body.string()).isEqualTo("abcd") + assertThat(server.takeRequest().headers["If-None-Match"]).isNull() + assertThat(server.takeRequest().headers["If-None-Match"]).isEqualTo("α") + } + + @Test + fun conditionalHitHeadersCombined() { + server.enqueue( + MockResponse.Builder() + .addHeader("Etag", "a") + .addHeader("Cache-Control: max-age=0") + .addHeader("A: a1") + .addHeader("B: b2") + .addHeader("B: b3") + .body("abcd") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .addHeader("B: b4") + .addHeader("B: b5") + .addHeader("C: c6") + .build() + ) + val response1 = get(server.url("/")) + assertThat(response1.body.string()).isEqualTo("abcd") + assertThat(response1.headers).isEqualTo( + headersOf( + "Etag", "a", "Cache-Control", "max-age=0", + "A", "a1", "B", "b2", "B", "b3", "Content-Length", "4" + ) + ) + + // The original 'A' header is retained because the network response doesn't have one. + // The original 'B' headers are replaced by the network response. + // The network's 'C' header is added. + val response2 = get(server.url("/")) + assertThat(response2.body.string()).isEqualTo("abcd") + assertThat(response2.headers).isEqualTo( + headersOf( + "Etag", "a", "Cache-Control", "max-age=0", + "A", "a1", "Content-Length", "4", "B", "b4", "B", "b5", "C", "c6" + ) + ) + } + + private operator fun get(url: HttpUrl): Response { + val request = Request.Builder() + .url(url) + .build() + return client.newCall(request).execute() + } + + private fun writeFile(directory: Path, file: String, content: String) { + val sink = fileSystem.sink(directory.div(file)).buffer() + sink.writeUtf8(content) + sink.close() + } + + /** + * @param delta the offset from the current date to use. Negative values yield dates in the past; + * positive values yield dates in the future. + */ + private fun formatDate(delta: Long, timeUnit: TimeUnit): String { + return formatDate(Date(System.currentTimeMillis() + timeUnit.toMillis(delta))) + } + + private fun formatDate(date: Date): String { + val rfc1123: DateFormat = SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US) + rfc1123.timeZone = TimeZone.getTimeZone("GMT") + return rfc1123.format(date) + } + + private fun assertNotCached(response: MockResponse) { + server.enqueue( + response.newBuilder() + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + assertThat(get(url).body.string()).isEqualTo("B") + } + + /** @return the request with the conditional get headers. */ + private fun assertConditionallyCached(response: MockResponse): RecordedRequest { + // scenario 1: condition succeeds + server.enqueue( + response.newBuilder() + .body("A") + .status("HTTP/1.1 200 A-OK") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + + // scenario 2: condition fails + server.enqueue( + response.newBuilder() + .body("B") + .status("HTTP/1.1 200 B-OK") + .build() + ) + server.enqueue( + MockResponse.Builder() + .status("HTTP/1.1 200 C-OK") + .body("C") + .build() + ) + val valid = server.url("/valid") + val response1 = get(valid) + assertThat(response1.body.string()).isEqualTo("A") + assertThat(response1.code).isEqualTo(HttpURLConnection.HTTP_OK) + assertThat(response1.message).isEqualTo("A-OK") + val response2 = get(valid) + assertThat(response2.body.string()).isEqualTo("A") + assertThat(response2.code).isEqualTo(HttpURLConnection.HTTP_OK) + assertThat(response2.message).isEqualTo("A-OK") + val invalid = server.url("/invalid") + val response3 = get(invalid) + assertThat(response3.body.string()).isEqualTo("B") + assertThat(response3.code).isEqualTo(HttpURLConnection.HTTP_OK) + assertThat(response3.message).isEqualTo("B-OK") + val response4 = get(invalid) + assertThat(response4.body.string()).isEqualTo("C") + assertThat(response4.code).isEqualTo(HttpURLConnection.HTTP_OK) + assertThat(response4.message).isEqualTo("C-OK") + server.takeRequest() // regular get + return server.takeRequest() // conditional get + } + + @Test + fun immutableIsCached() { + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control", "immutable, max-age=10") + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("B") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + assertThat(get(url).body.string()).isEqualTo("A") + } + + @Test + fun immutableIsCachedAfterMultipleCalls() { + server.enqueue( + MockResponse.Builder() + .body("A") + .build() + ) + server.enqueue( + MockResponse.Builder() + .addHeader("Cache-Control", "immutable, max-age=10") + .body("B") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("C") + .build() + ) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + assertThat(get(url).body.string()).isEqualTo("B") + assertThat(get(url).body.string()).isEqualTo("B") + } + + @Test + fun immutableIsNotCachedBeyondFreshnessLifetime() { + // last modified: 115 seconds ago + // served: 15 seconds ago + // default lifetime: (115 - 15) / 10 = 10 seconds + // expires: 10 seconds from served date = 5 seconds ago + val lastModifiedDate = formatDate(-115, TimeUnit.SECONDS) + val conditionalRequest = assertConditionallyCached( + MockResponse.Builder() + .addHeader("Cache-Control: immutable") + .addHeader("Last-Modified: $lastModifiedDate") + .addHeader("Date: " + formatDate(-15, TimeUnit.SECONDS)) + .build() + ) + assertThat(conditionalRequest.headers["If-Modified-Since"]) + .isEqualTo(lastModifiedDate) + } + + @Test + fun testPublicPathConstructor() { + val events: MutableList = ArrayList() + fileSystem.createDirectories(cache.directoryPath) + fileSystem.createDirectories(cache.directoryPath) + val loggingFileSystem: FileSystem = object : ForwardingFileSystem(fileSystem) { + override fun onPathParameter( + path: Path, + functionName: String, + parameterName: String + ): Path { + events.add("$functionName:$path") + return path + } + + override fun onPathResult(path: Path, functionName: String): Path { + events.add("$functionName:$path") + return path + } + } + val path: Path = "/cache".toPath() + val c = Cache(path, 100000L, loggingFileSystem) + assertThat(c.directoryPath).isEqualTo(path) + c.size() + assertThat(events).containsExactly( + "metadataOrNull:/cache/journal.bkp", + "metadataOrNull:/cache", + "sink:/cache/journal.bkp", + "delete:/cache/journal.bkp", + "metadataOrNull:/cache/journal", + "metadataOrNull:/cache", + "sink:/cache/journal.tmp", + "metadataOrNull:/cache/journal", + "atomicMove:/cache/journal.tmp", + "atomicMove:/cache/journal", + "appendingSink:/cache/journal" + ) + events.clear() + c.size() + assertThat(events).isEmpty() + } + + private fun assertFullyCached(response: MockResponse) { + server.enqueue(response.newBuilder().body("A").build()) + server.enqueue(response.newBuilder().body("B").build()) + val url = server.url("/") + assertThat(get(url).body.string()).isEqualTo("A") + assertThat(get(url).body.string()).isEqualTo("A") + } + + /** + * Shortens the body of `response` but not the corresponding headers. Only useful to test + * how clients respond to the premature conclusion of the HTTP body. + */ + private fun truncateViolently( + builder: MockResponse.Builder, numBytesToKeep: Int + ): MockResponse.Builder { + val response = builder.build() + builder.socketPolicy(DisconnectAtEnd) + val headers = response.headers + val fullBody = Buffer() + response.body!!.writeTo(fullBody) + val truncatedBody = Buffer() + truncatedBody.write(fullBody, numBytesToKeep.toLong()) + builder.body(truncatedBody) + builder.headers(headers) + return builder + } + + internal enum class TransferKind { + CHUNKED { + override fun setBody(response: MockResponse.Builder, content: Buffer, chunkSize: Int) { + response.chunkedBody(content, chunkSize) + } + }, + FIXED_LENGTH { + override fun setBody(response: MockResponse.Builder, content: Buffer, chunkSize: Int) { + response.body(content) + } + }, + END_OF_STREAM { + override fun setBody(response: MockResponse.Builder, content: Buffer, chunkSize: Int) { + response.body(content) + response.socketPolicy(DisconnectAtEnd) + response.removeHeader("Content-Length") + } + }; + + abstract fun setBody(response: MockResponse.Builder, content: Buffer, chunkSize: Int) + + fun setBody(response: MockResponse.Builder, content: String, chunkSize: Int) { + setBody(response, Buffer().writeUtf8(content), chunkSize) + } + } + + /** Returns a gzipped copy of `bytes`. */ + fun gzip(data: String): Buffer { + val result = Buffer() + val sink = GzipSink(result).buffer() + sink.writeUtf8(data) + sink.close() + return result + } + + companion object { + private val NULL_HOSTNAME_VERIFIER = HostnameVerifier { hostname, session -> true } + } +} diff --git a/okhttp/src/test/java/okhttp3/CertificateChainCleanerTest.java b/okhttp/src/test/java/okhttp3/CertificateChainCleanerTest.java deleted file mode 100644 index 905f4549f54a..000000000000 --- a/okhttp/src/test/java/okhttp3/CertificateChainCleanerTest.java +++ /dev/null @@ -1,293 +0,0 @@ -/* - * Copyright (C) 2016 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3; - -import java.security.cert.Certificate; -import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.List; -import javax.net.ssl.SSLPeerUnverifiedException; -import javax.net.ssl.X509TrustManager; -import okhttp3.internal.tls.CertificateChainCleaner; -import okhttp3.tls.HandshakeCertificates; -import okhttp3.tls.HeldCertificate; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -public final class CertificateChainCleanerTest { - @Test public void equalsFromCertificate() { - HeldCertificate rootA = new HeldCertificate.Builder() - .serialNumber(1L) - .build(); - HeldCertificate rootB = new HeldCertificate.Builder() - .serialNumber(2L) - .build(); - assertThat(CertificateChainCleaner.Companion.get(rootB.certificate(), rootA.certificate())) - .isEqualTo(CertificateChainCleaner.Companion.get(rootA.certificate(), rootB.certificate())); - } - - @Test public void equalsFromTrustManager() { - HandshakeCertificates handshakeCertificates = new HandshakeCertificates.Builder().build(); - X509TrustManager x509TrustManager = handshakeCertificates.trustManager(); - assertThat(CertificateChainCleaner.Companion.get(x509TrustManager)).isEqualTo( - CertificateChainCleaner.Companion.get(x509TrustManager)); - } - - @Test public void normalizeSingleSelfSignedCertificate() throws Exception { - HeldCertificate root = new HeldCertificate.Builder() - .serialNumber(1L) - .build(); - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(root.certificate()); - assertThat(cleaner.clean(list(root), "hostname")).isEqualTo(list(root)); - } - - @Test public void normalizeUnknownSelfSignedCertificate() { - HeldCertificate root = new HeldCertificate.Builder() - .serialNumber(1L) - .build(); - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(); - - try { - cleaner.clean(list(root), "hostname"); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - } - - @Test public void orderedChainOfCertificatesWithRoot() throws Exception { - HeldCertificate root = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(1) - .build(); - HeldCertificate certA = new HeldCertificate.Builder() - .serialNumber(2L) - .certificateAuthority(0) - .signedBy(root) - .build(); - HeldCertificate certB = new HeldCertificate.Builder() - .serialNumber(3L) - .signedBy(certA) - .build(); - - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(root.certificate()); - assertThat(cleaner.clean(list(certB, certA, root), "hostname")).isEqualTo( - list(certB, certA, root)); - } - - @Test public void orderedChainOfCertificatesWithoutRoot() throws Exception { - HeldCertificate root = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(1) - .build(); - HeldCertificate certA = new HeldCertificate.Builder() - .serialNumber(2L) - .certificateAuthority(0) - .signedBy(root) - .build(); - HeldCertificate certB = new HeldCertificate.Builder() - .serialNumber(3L) - .signedBy(certA) - .build(); - - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(root.certificate()); - // Root is added! - assertThat(cleaner.clean(list(certB, certA), "hostname")).isEqualTo( - list(certB, certA, root)); - } - - @Test public void unorderedChainOfCertificatesWithRoot() throws Exception { - HeldCertificate root = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(2) - .build(); - HeldCertificate certA = new HeldCertificate.Builder() - .serialNumber(2L) - .certificateAuthority(1) - .signedBy(root) - .build(); - HeldCertificate certB = new HeldCertificate.Builder() - .serialNumber(3L) - .certificateAuthority(0) - .signedBy(certA) - .build(); - HeldCertificate certC = new HeldCertificate.Builder() - .serialNumber(4L) - .signedBy(certB) - .build(); - - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(root.certificate()); - assertThat(cleaner.clean(list(certC, certA, root, certB), "hostname")).isEqualTo( - list(certC, certB, certA, root)); - } - - @Test public void unorderedChainOfCertificatesWithoutRoot() throws Exception { - HeldCertificate root = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(2) - .build(); - HeldCertificate certA = new HeldCertificate.Builder() - .serialNumber(2L) - .certificateAuthority(1) - .signedBy(root) - .build(); - HeldCertificate certB = new HeldCertificate.Builder() - .serialNumber(3L) - .certificateAuthority(0) - .signedBy(certA) - .build(); - HeldCertificate certC = new HeldCertificate.Builder() - .serialNumber(4L) - .signedBy(certB) - .build(); - - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(root.certificate()); - assertThat(cleaner.clean(list(certC, certA, certB), "hostname")).isEqualTo( - list(certC, certB, certA, root)); - } - - @Test public void unrelatedCertificatesAreOmitted() throws Exception { - HeldCertificate root = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(1) - .build(); - HeldCertificate certA = new HeldCertificate.Builder() - .serialNumber(2L) - .certificateAuthority(0) - .signedBy(root) - .build(); - HeldCertificate certB = new HeldCertificate.Builder() - .serialNumber(3L) - .signedBy(certA) - .build(); - HeldCertificate certUnnecessary = new HeldCertificate.Builder() - .serialNumber(4L) - .build(); - - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(root.certificate()); - assertThat(cleaner.clean(list(certB, certUnnecessary, certA, root), "hostname")).isEqualTo( - list(certB, certA, root)); - } - - @Test public void chainGoesAllTheWayToSelfSignedRoot() throws Exception { - HeldCertificate selfSigned = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(2) - .build(); - HeldCertificate trusted = new HeldCertificate.Builder() - .serialNumber(2L) - .signedBy(selfSigned) - .certificateAuthority(1) - .build(); - HeldCertificate certA = new HeldCertificate.Builder() - .serialNumber(3L) - .certificateAuthority(0) - .signedBy(trusted) - .build(); - HeldCertificate certB = new HeldCertificate.Builder() - .serialNumber(4L) - .signedBy(certA) - .build(); - - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get( - selfSigned.certificate(), trusted.certificate()); - assertThat(cleaner.clean(list(certB, certA), "hostname")).isEqualTo( - list(certB, certA, trusted, selfSigned)); - assertThat(cleaner.clean(list(certB, certA, trusted), "hostname")).isEqualTo( - list(certB, certA, trusted, selfSigned)); - assertThat(cleaner.clean(list(certB, certA, trusted, selfSigned), "hostname")).isEqualTo( - list(certB, certA, trusted, selfSigned)); - } - - @Test public void trustedRootNotSelfSigned() throws Exception { - HeldCertificate unknownSigner = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(2) - .build(); - HeldCertificate trusted = new HeldCertificate.Builder() - .signedBy(unknownSigner) - .certificateAuthority(1) - .serialNumber(2L) - .build(); - HeldCertificate intermediateCa = new HeldCertificate.Builder() - .signedBy(trusted) - .certificateAuthority(0) - .serialNumber(3L) - .build(); - HeldCertificate certificate = new HeldCertificate.Builder() - .signedBy(intermediateCa) - .serialNumber(4L) - .build(); - - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(trusted.certificate()); - assertThat(cleaner.clean(list(certificate, intermediateCa), "hostname")).isEqualTo( - list(certificate, intermediateCa, trusted)); - assertThat(cleaner.clean(list(certificate, intermediateCa, trusted), "hostname")).isEqualTo( - list(certificate, intermediateCa, trusted)); - } - - @Test public void chainMaxLength() throws Exception { - List heldCertificates = chainOfLength(10); - List certificates = new ArrayList<>(); - for (HeldCertificate heldCertificate : heldCertificates) { - certificates.add(heldCertificate.certificate()); - } - - X509Certificate root = heldCertificates.get(heldCertificates.size() - 1).certificate(); - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(root); - assertThat(cleaner.clean(certificates, "hostname")).isEqualTo(certificates); - assertThat(cleaner.clean(certificates.subList(0, 9), "hostname")).isEqualTo( - certificates); - } - - @Test public void chainTooLong() { - List heldCertificates = chainOfLength(11); - List certificates = new ArrayList<>(); - for (HeldCertificate heldCertificate : heldCertificates) { - certificates.add(heldCertificate.certificate()); - } - - X509Certificate root = heldCertificates.get(heldCertificates.size() - 1).certificate(); - CertificateChainCleaner cleaner = CertificateChainCleaner.Companion.get(root); - try { - cleaner.clean(certificates, "hostname"); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - } - - /** Returns a chain starting at the leaf certificate and progressing to the root. */ - private List chainOfLength(int length) { - List result = new ArrayList<>(); - for (int i = 1; i <= length; i++) { - result.add(0, new HeldCertificate.Builder() - .signedBy(!result.isEmpty() ? result.get(0) : null) - .certificateAuthority(length - i) - .serialNumber(i) - .build()); - } - return result; - } - - private List list(HeldCertificate... heldCertificates) { - List result = new ArrayList<>(); - for (HeldCertificate heldCertificate : heldCertificates) { - result.add(heldCertificate.certificate()); - } - return result; - } -} diff --git a/okhttp/src/test/java/okhttp3/CertificateChainCleanerTest.kt b/okhttp/src/test/java/okhttp3/CertificateChainCleanerTest.kt new file mode 100644 index 000000000000..01f827637fe0 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/CertificateChainCleanerTest.kt @@ -0,0 +1,307 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.security.cert.Certificate +import javax.net.ssl.SSLPeerUnverifiedException +import okhttp3.internal.tls.CertificateChainCleaner.Companion.get +import okhttp3.tls.HandshakeCertificates +import okhttp3.tls.HeldCertificate +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.Test + +class CertificateChainCleanerTest { + @Test + fun equalsFromCertificate() { + val rootA = HeldCertificate.Builder() + .serialNumber(1L) + .build() + val rootB = HeldCertificate.Builder() + .serialNumber(2L) + .build() + assertThat(get(rootB.certificate, rootA.certificate)) + .isEqualTo(get(rootA.certificate, rootB.certificate)) + } + + @Test + fun equalsFromTrustManager() { + val handshakeCertificates = HandshakeCertificates.Builder().build() + val x509TrustManager = handshakeCertificates.trustManager + assertThat(get(x509TrustManager)).isEqualTo(get(x509TrustManager)) + } + + @Test + fun normalizeSingleSelfSignedCertificate() { + val root = HeldCertificate.Builder() + .serialNumber(1L) + .build() + val cleaner = get(root.certificate) + assertThat(cleaner.clean(list(root), "hostname")).isEqualTo(list(root)) + } + + @Test + fun normalizeUnknownSelfSignedCertificate() { + val root = HeldCertificate.Builder() + .serialNumber(1L) + .build() + val cleaner = get() + try { + cleaner.clean(list(root), "hostname") + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + } + + @Test + fun orderedChainOfCertificatesWithRoot() { + val root = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(1) + .build() + val certA = HeldCertificate.Builder() + .serialNumber(2L) + .certificateAuthority(0) + .signedBy(root) + .build() + val certB = HeldCertificate.Builder() + .serialNumber(3L) + .signedBy(certA) + .build() + val cleaner = get(root.certificate) + assertThat(cleaner.clean(list(certB, certA, root), "hostname")) + .isEqualTo(list(certB, certA, root)) + } + + @Test + fun orderedChainOfCertificatesWithoutRoot() { + val root = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(1) + .build() + val certA = HeldCertificate.Builder() + .serialNumber(2L) + .certificateAuthority(0) + .signedBy(root) + .build() + val certB = HeldCertificate.Builder() + .serialNumber(3L) + .signedBy(certA) + .build() + val cleaner = get(root.certificate) + // Root is added! + assertThat(cleaner.clean(list(certB, certA), "hostname")).isEqualTo( + list(certB, certA, root) + ) + } + + @Test + fun unorderedChainOfCertificatesWithRoot() { + val root = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(2) + .build() + val certA = HeldCertificate.Builder() + .serialNumber(2L) + .certificateAuthority(1) + .signedBy(root) + .build() + val certB = HeldCertificate.Builder() + .serialNumber(3L) + .certificateAuthority(0) + .signedBy(certA) + .build() + val certC = HeldCertificate.Builder() + .serialNumber(4L) + .signedBy(certB) + .build() + val cleaner = get(root.certificate) + assertThat(cleaner.clean(list(certC, certA, root, certB), "hostname")).isEqualTo( + list(certC, certB, certA, root) + ) + } + + @Test + fun unorderedChainOfCertificatesWithoutRoot() { + val root = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(2) + .build() + val certA = HeldCertificate.Builder() + .serialNumber(2L) + .certificateAuthority(1) + .signedBy(root) + .build() + val certB = HeldCertificate.Builder() + .serialNumber(3L) + .certificateAuthority(0) + .signedBy(certA) + .build() + val certC = HeldCertificate.Builder() + .serialNumber(4L) + .signedBy(certB) + .build() + val cleaner = get(root.certificate) + assertThat(cleaner.clean(list(certC, certA, certB), "hostname")).isEqualTo( + list(certC, certB, certA, root) + ) + } + + @Test + fun unrelatedCertificatesAreOmitted() { + val root = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(1) + .build() + val certA = HeldCertificate.Builder() + .serialNumber(2L) + .certificateAuthority(0) + .signedBy(root) + .build() + val certB = HeldCertificate.Builder() + .serialNumber(3L) + .signedBy(certA) + .build() + val certUnnecessary = HeldCertificate.Builder() + .serialNumber(4L) + .build() + val cleaner = get(root.certificate) + assertThat(cleaner.clean(list(certB, certUnnecessary, certA, root), "hostname")) + .isEqualTo( + list(certB, certA, root) + ) + } + + @Test + fun chainGoesAllTheWayToSelfSignedRoot() { + val selfSigned = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(2) + .build() + val trusted = HeldCertificate.Builder() + .serialNumber(2L) + .signedBy(selfSigned) + .certificateAuthority(1) + .build() + val certA = HeldCertificate.Builder() + .serialNumber(3L) + .certificateAuthority(0) + .signedBy(trusted) + .build() + val certB = HeldCertificate.Builder() + .serialNumber(4L) + .signedBy(certA) + .build() + val cleaner = get( + selfSigned.certificate, trusted.certificate + ) + assertThat(cleaner.clean(list(certB, certA), "hostname")).isEqualTo( + list(certB, certA, trusted, selfSigned) + ) + assertThat(cleaner.clean(list(certB, certA, trusted), "hostname")).isEqualTo( + list(certB, certA, trusted, selfSigned) + ) + assertThat(cleaner.clean(list(certB, certA, trusted, selfSigned), "hostname")) + .isEqualTo( + list(certB, certA, trusted, selfSigned) + ) + } + + @Test + fun trustedRootNotSelfSigned() { + val unknownSigner = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(2) + .build() + val trusted = HeldCertificate.Builder() + .signedBy(unknownSigner) + .certificateAuthority(1) + .serialNumber(2L) + .build() + val intermediateCa = HeldCertificate.Builder() + .signedBy(trusted) + .certificateAuthority(0) + .serialNumber(3L) + .build() + val certificate = HeldCertificate.Builder() + .signedBy(intermediateCa) + .serialNumber(4L) + .build() + val cleaner = get(trusted.certificate) + assertThat(cleaner.clean(list(certificate, intermediateCa), "hostname")) + .isEqualTo( + list(certificate, intermediateCa, trusted) + ) + assertThat(cleaner.clean(list(certificate, intermediateCa, trusted), "hostname")) + .isEqualTo( + list(certificate, intermediateCa, trusted) + ) + } + + @Test + fun chainMaxLength() { + val heldCertificates = chainOfLength(10) + val certificates: MutableList = ArrayList() + for (heldCertificate in heldCertificates) { + certificates.add(heldCertificate.certificate) + } + val root = heldCertificates[heldCertificates.size - 1].certificate + val cleaner = get(root) + assertThat(cleaner.clean(certificates, "hostname")).isEqualTo(certificates) + assertThat(cleaner.clean(certificates.subList(0, 9), "hostname")).isEqualTo( + certificates + ) + } + + @Test + fun chainTooLong() { + val heldCertificates = chainOfLength(11) + val certificates: MutableList = ArrayList() + for (heldCertificate in heldCertificates) { + certificates.add(heldCertificate.certificate) + } + val root = heldCertificates[heldCertificates.size - 1].certificate + val cleaner = get(root) + try { + cleaner.clean(certificates, "hostname") + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + } + + /** Returns a chain starting at the leaf certificate and progressing to the root. */ + private fun chainOfLength(length: Int): List { + val result = mutableListOf() + for (i in 1..length) { + result.add( + 0, HeldCertificate.Builder() + .signedBy(if (result.isNotEmpty()) result[0] else null) + .certificateAuthority(length - i) + .serialNumber(i.toLong()) + .build() + ) + } + return result + } + + private fun list(vararg heldCertificates: HeldCertificate): List { + val result: MutableList = ArrayList() + for (heldCertificate in heldCertificates) { + result.add(heldCertificate.certificate) + } + return result + } +} diff --git a/okhttp/src/test/java/okhttp3/CertificatePinnerTest.java b/okhttp/src/test/java/okhttp3/CertificatePinnerTest.java deleted file mode 100644 index b25c573d6cee..000000000000 --- a/okhttp/src/test/java/okhttp3/CertificatePinnerTest.java +++ /dev/null @@ -1,320 +0,0 @@ -/* - * Copyright (C) 2014 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3; - -import java.util.HashSet; -import java.util.List; -import javax.net.ssl.SSLPeerUnverifiedException; -import okhttp3.tls.HeldCertificate; -import org.junit.jupiter.api.Test; - -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; -import static okhttp3.CertificatePinner.sha1Hash; -import static okio.ByteString.decodeBase64; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; - -public final class CertificatePinnerTest { - static final HeldCertificate certA1 = new HeldCertificate.Builder() - .serialNumber(100L) - .build(); - static final String certA1Sha256Pin = CertificatePinner.pin(certA1.certificate()); - - static final HeldCertificate certB1 = new HeldCertificate.Builder() - .serialNumber(200L) - .build(); - static final String certB1Sha256Pin = CertificatePinner.pin(certB1.certificate()); - - static final HeldCertificate certC1 = new HeldCertificate.Builder() - .serialNumber(300L) - .build(); - static final String certC1Sha1Pin = "sha1/" + sha1Hash(certC1.certificate()).base64(); - - @Test public void malformedPin() throws Exception { - CertificatePinner.Builder builder = new CertificatePinner.Builder(); - try { - builder.add("example.com", "md5/DmxUShsZuNiqPQsX2Oi9uv2sCnw="); - fail(); - } catch (IllegalArgumentException expected) { - } - } - - @Test public void malformedBase64() throws Exception { - CertificatePinner.Builder builder = new CertificatePinner.Builder(); - try { - builder.add("example.com", "sha1/DmxUShsZuNiqPQsX2Oi9uv2sCnw*"); - fail(); - } catch (IllegalArgumentException expected) { - } - } - - /** Multiple certificates generated from the same keypair have the same pin. */ - @Test public void sameKeypairSamePin() throws Exception { - HeldCertificate heldCertificateA2 = new HeldCertificate.Builder() - .keyPair(certA1.keyPair()) - .serialNumber(101L) - .build(); - String keypairACertificate2Pin = CertificatePinner.pin(heldCertificateA2.certificate()); - - HeldCertificate heldCertificateB2 = new HeldCertificate.Builder() - .keyPair(certB1.keyPair()) - .serialNumber(201L) - .build(); - String keypairBCertificate2Pin = CertificatePinner.pin(heldCertificateB2.certificate()); - - assertThat(keypairACertificate2Pin).isEqualTo(certA1Sha256Pin); - assertThat(keypairBCertificate2Pin).isEqualTo(certB1Sha256Pin); - assertThat(certB1Sha256Pin).isNotEqualTo(certA1Sha256Pin); - } - - @Test public void successfulCheck() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("example.com", certA1Sha256Pin) - .build(); - - certificatePinner.check("example.com", singletonList(certA1.certificate())); - } - - @Test public void successfulMatchAcceptsAnyMatchingCertificateOld() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("example.com", certB1Sha256Pin) - .build(); - - certificatePinner.check("example.com", certA1.certificate(), certB1.certificate()); - } - - @Test public void successfulMatchAcceptsAnyMatchingCertificate() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("example.com", certB1Sha256Pin) - .build(); - - certificatePinner.check("example.com", asList(certA1.certificate(), certB1.certificate())); - } - - @Test public void unsuccessfulCheck() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("example.com", certA1Sha256Pin) - .build(); - - try { - certificatePinner.check("example.com", certB1.certificate()); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - } - - @Test public void multipleCertificatesForOneHostname() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("example.com", certA1Sha256Pin, certB1Sha256Pin) - .build(); - - certificatePinner.check("example.com", singletonList(certA1.certificate())); - certificatePinner.check("example.com", singletonList(certB1.certificate())); - } - - @Test public void multipleHostnamesForOneCertificate() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("example.com", certA1Sha256Pin) - .add("www.example.com", certA1Sha256Pin) - .build(); - - certificatePinner.check("example.com", singletonList(certA1.certificate())); - certificatePinner.check("www.example.com", singletonList(certA1.certificate())); - } - - @Test public void absentHostnameMatches() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder().build(); - certificatePinner.check("example.com", singletonList(certA1.certificate())); - } - - @Test public void successfulCheckForWildcardHostname() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("*.example.com", certA1Sha256Pin) - .build(); - - certificatePinner.check("a.example.com", singletonList(certA1.certificate())); - } - - @Test public void successfulMatchAcceptsAnyMatchingCertificateForWildcardHostname() - throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("*.example.com", certB1Sha256Pin) - .build(); - - certificatePinner.check("a.example.com", asList(certA1.certificate(), certB1.certificate())); - } - - @Test public void unsuccessfulCheckForWildcardHostname() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("*.example.com", certA1Sha256Pin) - .build(); - - try { - certificatePinner.check("a.example.com", singletonList(certB1.certificate())); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - } - - @Test public void multipleCertificatesForOneWildcardHostname() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("*.example.com", certA1Sha256Pin, certB1Sha256Pin) - .build(); - - certificatePinner.check("a.example.com", singletonList(certA1.certificate())); - certificatePinner.check("a.example.com", singletonList(certB1.certificate())); - } - - @Test public void successfulCheckForOneHostnameWithWildcardAndDirectCertificate() - throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("*.example.com", certA1Sha256Pin) - .add("a.example.com", certB1Sha256Pin) - .build(); - - certificatePinner.check("a.example.com", singletonList(certA1.certificate())); - certificatePinner.check("a.example.com", singletonList(certB1.certificate())); - } - - @Test public void unsuccessfulCheckForOneHostnameWithWildcardAndDirectCertificate() - throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("*.example.com", certA1Sha256Pin) - .add("a.example.com", certB1Sha256Pin) - .build(); - - try { - certificatePinner.check("a.example.com", singletonList(certC1.certificate())); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - } - - @Test public void checkForHostnameWithDoubleAsterisk() throws Exception { - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add("**.example.co.uk", certA1Sha256Pin) - .build(); - - // Should be pinned: - try { - certificatePinner.check("example.co.uk", singletonList(certB1.certificate())); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - try { - certificatePinner.check("foo.example.co.uk", singletonList(certB1.certificate())); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - try { - certificatePinner.check("foo.bar.example.co.uk", singletonList(certB1.certificate())); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - try { - certificatePinner.check("foo.bar.baz.example.co.uk", singletonList(certB1.certificate())); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - - // Should not be pinned: - certificatePinner.check("uk", singletonList(certB1.certificate())); - certificatePinner.check("co.uk", singletonList(certB1.certificate())); - certificatePinner.check("anotherexample.co.uk", singletonList(certB1.certificate())); - certificatePinner.check("foo.anotherexample.co.uk", singletonList(certB1.certificate())); - } - - @Test - public void testBadPin() { - try { - new CertificatePinner.Pin("example.co.uk", - "sha256/a"); - fail(); - } catch (IllegalArgumentException iae) { - // expected - } - } - - @Test - public void testBadAlgorithm() { - try { - new CertificatePinner.Pin("example.co.uk", - "sha512/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="); - fail(); - } catch (IllegalArgumentException iae) { - // expected - } - } - - @Test - public void testBadHost() { - try { - new CertificatePinner.Pin("example.*", - "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="); - fail(); - } catch (IllegalArgumentException iae) { - // expected - } - } - - @Test - public void testGoodPin() { - CertificatePinner.Pin pin = new CertificatePinner.Pin("**.example.co.uk", - "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="); - - assertEquals(decodeBase64("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="), pin.getHash()); - assertEquals("sha256", pin.getHashAlgorithm()); - assertEquals("**.example.co.uk", pin.getPattern()); - - assertTrue(pin.matchesHostname("www.example.co.uk")); - assertTrue(pin.matchesHostname("gopher.example.co.uk")); - assertFalse(pin.matchesHostname("www.example.com")); - } - - @Test - public void testMatchesSha256() { - CertificatePinner.Pin pin = new CertificatePinner.Pin("example.com", certA1Sha256Pin); - - assertTrue(pin.matchesCertificate(certA1.certificate())); - assertFalse(pin.matchesCertificate(certB1.certificate())); - } - - @Test - public void testMatchesSha1() { - CertificatePinner.Pin pin = new CertificatePinner.Pin("example.com", certC1Sha1Pin); - - assertTrue(pin.matchesCertificate(certC1.certificate())); - assertFalse(pin.matchesCertificate(certB1.certificate())); - } - - @Test public void pinList() { - CertificatePinner.Builder builder = new CertificatePinner.Builder() - .add("example.com", certA1Sha256Pin) - .add("www.example.com", certA1Sha256Pin); - CertificatePinner certificatePinner = builder.build(); - - List expectedPins = - asList(new CertificatePinner.Pin("example.com", certA1Sha256Pin), - new CertificatePinner.Pin("www.example.com", certA1Sha256Pin)); - - assertEquals(expectedPins, builder.getPins()); - assertEquals(new HashSet<>(expectedPins), certificatePinner.getPins()); - } -} diff --git a/okhttp/src/test/java/okhttp3/CertificatePinnerTest.kt b/okhttp/src/test/java/okhttp3/CertificatePinnerTest.kt new file mode 100644 index 000000000000..0250032b7103 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/CertificatePinnerTest.kt @@ -0,0 +1,333 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.util.Arrays +import javax.net.ssl.SSLPeerUnverifiedException +import okhttp3.CertificatePinner.Companion.pin +import okhttp3.CertificatePinner.Companion.sha1Hash +import okhttp3.tls.HeldCertificate +import okio.ByteString.Companion.decodeBase64 +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.Test + +class CertificatePinnerTest { + @Test + fun malformedPin() { + val builder = CertificatePinner.Builder() + try { + builder.add("example.com", "md5/DmxUShsZuNiqPQsX2Oi9uv2sCnw=") + fail() + } catch (expected: IllegalArgumentException) { + } + } + + @Test + fun malformedBase64() { + val builder = CertificatePinner.Builder() + try { + builder.add("example.com", "sha1/DmxUShsZuNiqPQsX2Oi9uv2sCnw*") + fail() + } catch (expected: IllegalArgumentException) { + } + } + + /** Multiple certificates generated from the same keypair have the same pin. */ + @Test + fun sameKeypairSamePin() { + val heldCertificateA2 = HeldCertificate.Builder() + .keyPair(certA1.keyPair) + .serialNumber(101L) + .build() + val keypairACertificate2Pin = pin(heldCertificateA2.certificate) + val heldCertificateB2 = HeldCertificate.Builder() + .keyPair(certB1.keyPair) + .serialNumber(201L) + .build() + val keypairBCertificate2Pin = pin(heldCertificateB2.certificate) + assertThat(keypairACertificate2Pin).isEqualTo( + certA1Sha256Pin + ) + assertThat(keypairBCertificate2Pin).isEqualTo( + certB1Sha256Pin + ) + assertThat(certB1Sha256Pin).isNotEqualTo(certA1Sha256Pin) + } + + @Test + fun successfulCheck() { + val certificatePinner = CertificatePinner.Builder() + .add("example.com", certA1Sha256Pin) + .build() + certificatePinner.check("example.com", listOf(certA1.certificate)) + } + + @Test + fun successfulMatchAcceptsAnyMatchingCertificateOld() { + val certificatePinner = CertificatePinner.Builder() + .add("example.com", certB1Sha256Pin) + .build() + certificatePinner.check("example.com", certA1.certificate, certB1.certificate) + } + + @Test + fun successfulMatchAcceptsAnyMatchingCertificate() { + val certificatePinner = CertificatePinner.Builder() + .add("example.com", certB1Sha256Pin) + .build() + certificatePinner.check( + "example.com", + Arrays.asList(certA1.certificate, certB1.certificate) + ) + } + + @Test + fun unsuccessfulCheck() { + val certificatePinner = CertificatePinner.Builder() + .add("example.com", certA1Sha256Pin) + .build() + try { + certificatePinner.check("example.com", certB1.certificate) + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + } + + @Test + fun multipleCertificatesForOneHostname() { + val certificatePinner = CertificatePinner.Builder() + .add("example.com", certA1Sha256Pin, certB1Sha256Pin) + .build() + certificatePinner.check("example.com", listOf(certA1.certificate)) + certificatePinner.check("example.com", listOf(certB1.certificate)) + } + + @Test + fun multipleHostnamesForOneCertificate() { + val certificatePinner = CertificatePinner.Builder() + .add("example.com", certA1Sha256Pin) + .add("www.example.com", certA1Sha256Pin) + .build() + certificatePinner.check("example.com", listOf(certA1.certificate)) + certificatePinner.check("www.example.com", listOf(certA1.certificate)) + } + + @Test + fun absentHostnameMatches() { + val certificatePinner = CertificatePinner.Builder().build() + certificatePinner.check("example.com", listOf(certA1.certificate)) + } + + @Test + fun successfulCheckForWildcardHostname() { + val certificatePinner = CertificatePinner.Builder() + .add("*.example.com", certA1Sha256Pin) + .build() + certificatePinner.check("a.example.com", listOf(certA1.certificate)) + } + + @Test + fun successfulMatchAcceptsAnyMatchingCertificateForWildcardHostname() { + val certificatePinner = CertificatePinner.Builder() + .add("*.example.com", certB1Sha256Pin) + .build() + certificatePinner.check( + "a.example.com", + Arrays.asList(certA1.certificate, certB1.certificate) + ) + } + + @Test + fun unsuccessfulCheckForWildcardHostname() { + val certificatePinner = CertificatePinner.Builder() + .add("*.example.com", certA1Sha256Pin) + .build() + try { + certificatePinner.check("a.example.com", listOf(certB1.certificate)) + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + } + + @Test + fun multipleCertificatesForOneWildcardHostname() { + val certificatePinner = CertificatePinner.Builder() + .add("*.example.com", certA1Sha256Pin, certB1Sha256Pin) + .build() + certificatePinner.check("a.example.com", listOf(certA1.certificate)) + certificatePinner.check("a.example.com", listOf(certB1.certificate)) + } + + @Test + fun successfulCheckForOneHostnameWithWildcardAndDirectCertificate() { + val certificatePinner = CertificatePinner.Builder() + .add("*.example.com", certA1Sha256Pin) + .add("a.example.com", certB1Sha256Pin) + .build() + certificatePinner.check("a.example.com", listOf(certA1.certificate)) + certificatePinner.check("a.example.com", listOf(certB1.certificate)) + } + + @Test + fun unsuccessfulCheckForOneHostnameWithWildcardAndDirectCertificate() { + val certificatePinner = CertificatePinner.Builder() + .add("*.example.com", certA1Sha256Pin) + .add("a.example.com", certB1Sha256Pin) + .build() + try { + certificatePinner.check("a.example.com", listOf(certC1.certificate)) + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + } + + @Test + fun checkForHostnameWithDoubleAsterisk() { + val certificatePinner = CertificatePinner.Builder() + .add("**.example.co.uk", certA1Sha256Pin) + .build() + + // Should be pinned: + try { + certificatePinner.check("example.co.uk", listOf(certB1.certificate)) + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + try { + certificatePinner.check("foo.example.co.uk", listOf(certB1.certificate)) + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + try { + certificatePinner.check("foo.bar.example.co.uk", listOf(certB1.certificate)) + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + try { + certificatePinner.check("foo.bar.baz.example.co.uk", listOf(certB1.certificate)) + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + + // Should not be pinned: + certificatePinner.check("uk", listOf(certB1.certificate)) + certificatePinner.check("co.uk", listOf(certB1.certificate)) + certificatePinner.check("anotherexample.co.uk", listOf(certB1.certificate)) + certificatePinner.check("foo.anotherexample.co.uk", listOf(certB1.certificate)) + } + + @Test + fun testBadPin() { + try { + CertificatePinner.Pin( + "example.co.uk", + "sha256/a" + ) + fail() + } catch (iae: IllegalArgumentException) { + // expected + } + } + + @Test + fun testBadAlgorithm() { + try { + CertificatePinner.Pin( + "example.co.uk", + "sha512/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" + ) + fail() + } catch (iae: IllegalArgumentException) { + // expected + } + } + + @Test + fun testBadHost() { + try { + CertificatePinner.Pin( + "example.*", + "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" + ) + fail() + } catch (iae: IllegalArgumentException) { + // expected + } + } + + @Test + fun testGoodPin() { + val pin = CertificatePinner.Pin( + "**.example.co.uk", + "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" + ) + assertEquals( + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=".decodeBase64(), + pin.hash + ) + assertEquals("sha256", pin.hashAlgorithm) + assertEquals("**.example.co.uk", pin.pattern) + Assertions.assertTrue(pin.matchesHostname("www.example.co.uk")) + Assertions.assertTrue(pin.matchesHostname("gopher.example.co.uk")) + Assertions.assertFalse(pin.matchesHostname("www.example.com")) + } + + @Test + fun testMatchesSha256() { + val pin = CertificatePinner.Pin("example.com", certA1Sha256Pin) + Assertions.assertTrue(pin.matchesCertificate(certA1.certificate)) + Assertions.assertFalse(pin.matchesCertificate(certB1.certificate)) + } + + @Test + fun testMatchesSha1() { + val pin = CertificatePinner.Pin("example.com", certC1Sha1Pin) + Assertions.assertTrue(pin.matchesCertificate(certC1.certificate)) + Assertions.assertFalse(pin.matchesCertificate(certB1.certificate)) + } + + @Test + fun pinList() { + val builder = CertificatePinner.Builder() + .add("example.com", certA1Sha256Pin) + .add("www.example.com", certA1Sha256Pin) + val certificatePinner = builder.build() + val expectedPins = Arrays.asList( + CertificatePinner.Pin("example.com", certA1Sha256Pin), + CertificatePinner.Pin("www.example.com", certA1Sha256Pin) + ) + assertEquals(expectedPins, builder.pins) + assertEquals(HashSet(expectedPins), certificatePinner.pins) + } + + companion object { + val certA1 = HeldCertificate.Builder() + .serialNumber(100L) + .build() + val certA1Sha256Pin = pin(certA1.certificate) + val certB1 = HeldCertificate.Builder() + .serialNumber(200L) + .build() + val certB1Sha256Pin = pin(certB1.certificate) + val certC1 = HeldCertificate.Builder() + .serialNumber(300L) + .build() + val certC1Sha1Pin = "sha1/" + certC1.certificate.sha1Hash().base64() + } +} diff --git a/okhttp/src/test/java/okhttp3/ConnectionCoalescingTest.java b/okhttp/src/test/java/okhttp3/ConnectionCoalescingTest.java deleted file mode 100644 index ecc78c624d0d..000000000000 --- a/okhttp/src/test/java/okhttp3/ConnectionCoalescingTest.java +++ /dev/null @@ -1,556 +0,0 @@ -/* - * Copyright (C) 2017 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3; - -import java.io.IOException; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.Proxy; -import java.security.cert.X509Certificate; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import javax.net.ssl.HostnameVerifier; -import javax.net.ssl.SSLPeerUnverifiedException; -import javax.net.ssl.X509TrustManager; -import mockwebserver3.MockResponse; -import mockwebserver3.MockWebServer; -import okhttp3.testing.PlatformRule; -import okhttp3.tls.HandshakeCertificates; -import okhttp3.tls.HeldCertificate; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -@Tag("Slowish") -public final class ConnectionCoalescingTest { - @RegisterExtension public final PlatformRule platform = new PlatformRule(); - @RegisterExtension public final OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - - private MockWebServer server; - private OkHttpClient client; - private HeldCertificate rootCa; - private HeldCertificate certificate; - private final FakeDns dns = new FakeDns(); - private HttpUrl url; - private List serverIps; - - @BeforeEach public void setUp(MockWebServer server) throws Exception { - this.server = server; - - platform.assumeHttp2Support(); - platform.assumeNotBouncyCastle(); - - rootCa = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(0) - .commonName("root") - .build(); - certificate = new HeldCertificate.Builder() - .signedBy(rootCa) - .serialNumber(2L) - .commonName(server.getHostName()) - .addSubjectAlternativeName(server.getHostName()) - .addSubjectAlternativeName("san.com") - .addSubjectAlternativeName("*.wildcard.com") - .addSubjectAlternativeName("differentdns.com") - .build(); - - serverIps = Dns.SYSTEM.lookup(server.getHostName()); - - dns.set(server.getHostName(), serverIps); - dns.set("san.com", serverIps); - dns.set("nonsan.com", serverIps); - dns.set("www.wildcard.com", serverIps); - dns.set("differentdns.com", Collections.emptyList()); - - HandshakeCertificates handshakeCertificates = new HandshakeCertificates.Builder() - .addTrustedCertificate(rootCa.certificate()) - .build(); - - client = clientTestRule.newClientBuilder() - .fastFallback(false) // Avoid data races. - .dns(dns) - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .build(); - - HandshakeCertificates serverHandshakeCertificates = new HandshakeCertificates.Builder() - .heldCertificate(certificate) - .build(); - server.useHttps(serverHandshakeCertificates.sslSocketFactory()); - - url = server.url("/robots.txt"); - } - - /** - * Test connecting to the main host then an alternative, although only subject alternative names - * are used if present no special consideration of common name. - */ - @Test public void commonThenAlternative() throws Exception { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - assert200Http2Response(execute(sanUrl), "san.com"); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(1); - } - - /** - * Test connecting to an alternative host then common name, although only subject alternative - * names are used if present no special consideration of common name. - */ - @Test public void alternativeThenCommon() throws Exception { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - assert200Http2Response(execute(sanUrl), "san.com"); - - assert200Http2Response(execute(url), server.getHostName()); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(1); - } - - /** Test a previously coalesced connection that's no longer healthy. */ - @Test public void staleCoalescedConnection() throws Exception { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - AtomicReference connection = new AtomicReference<>(); - client = client.newBuilder() - .addNetworkInterceptor(chain -> { - connection.set(chain.connection()); - return chain.proceed(chain.request()); - }) - .build(); - dns.set("san.com", Dns.SYSTEM.lookup(server.getHostName()).subList(0, 1)); - - assert200Http2Response(execute(url), server.getHostName()); - - // Simulate a stale connection in the pool. - connection.get().socket().close(); - - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - assert200Http2Response(execute(sanUrl), "san.com"); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(1); - } - - /** - * This is an extraordinary test case. Here's what it's trying to simulate. - * - 2 requests happen concurrently to a host that can be coalesced onto a single connection. - * - Both request discover no existing connection. They both make a connection. - * - The first request "wins the race". - * - The second request discovers it "lost the race" and closes the connection it just opened. - * - The second request uses the coalesced connection from request1. - * - The coalesced connection is violently closed after servicing the first request. - * - The second request discovers the coalesced connection is unhealthy just after acquiring it. - */ - @Test public void coalescedConnectionDestroyedAfterAcquire() throws Exception { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - dns.set("san.com", Dns.SYSTEM.lookup(server.getHostName()).subList(0, 1)); - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - - CountDownLatch latch1 = new CountDownLatch(1); - CountDownLatch latch2 = new CountDownLatch(1); - CountDownLatch latch3 = new CountDownLatch(1); - CountDownLatch latch4 = new CountDownLatch(1); - EventListener listener1 = new EventListener() { - @Override public void connectStart(Call call, InetSocketAddress inetSocketAddress, - Proxy proxy) { - try { - // Wait for request2 to guarantee we make 2 separate connections to the server. - latch1.await(); - } catch (InterruptedException e) { - throw new AssertionError(e); - } - } - - @Override public void connectionAcquired(Call call, Connection connection) { - // We have the connection and it's in the pool. Let request2 proceed to make a connection. - latch2.countDown(); - } - }; - - EventListener request2Listener = new EventListener() { - @Override public void connectStart(Call call, InetSocketAddress inetSocketAddress, - Proxy proxy) { - // Let request1 proceed to make a connection. - latch1.countDown(); - try { - // Wait until request1 makes the connection and puts it in the connection pool. - latch2.await(); - } catch (InterruptedException e) { - throw new AssertionError(e); - } - } - - @Override public void connectionAcquired(Call call, Connection connection) { - // We obtained the coalesced connection. Let request1 violently destroy it. - latch3.countDown(); - try { - latch4.await(); - } catch (InterruptedException e) { - throw new AssertionError(e); - } - } - }; - - // Get a reference to the connection so we can violently destroy it. - AtomicReference connection = new AtomicReference<>(); - OkHttpClient client1 = client.newBuilder() - .addNetworkInterceptor(chain -> { - connection.set(chain.connection()); - return chain.proceed(chain.request()); - }) - .eventListenerFactory(clientTestRule.wrap(listener1)) - .build(); - - Request request = new Request.Builder().url(sanUrl).build(); - Call call1 = client1.newCall(request); - call1.enqueue(new Callback() { - @Override public void onResponse(Call call, Response response) throws IOException { - try { - // Wait until request2 acquires the connection before we destroy it violently. - latch3.await(); - } catch (InterruptedException e) { - throw new AssertionError(e); - } - assert200Http2Response(response, "san.com"); - connection.get().socket().close(); - latch4.countDown(); - } - - @Override public void onFailure(Call call, IOException e) { - fail(); - } - }); - - OkHttpClient client2 = client.newBuilder() - .eventListenerFactory(clientTestRule.wrap(request2Listener)) - .build(); - Call call2 = client2.newCall(request); - Response response = call2.execute(); - - assert200Http2Response(response, "san.com"); - } - - /** If the existing connection matches a SAN but not a match for DNS then skip. */ - @Test public void skipsWhenDnsDontMatch() throws Exception { - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl differentDnsUrl = url.newBuilder().host("differentdns.com").build(); - try { - execute(differentDnsUrl); - fail("expected a failed attempt to connect"); - } catch (IOException expected) { - } - } - - @Test public void skipsOnRedirectWhenDnsDontMatch() throws Exception { - server.enqueue(new MockResponse.Builder() - .code(301) - .addHeader("Location", url.newBuilder().host("differentdns.com").build()) - .build()); - server.enqueue(new MockResponse.Builder() - .body("unexpected call") - .build()); - - try { - Response response = execute(url); - response.close(); - fail("expected a failed attempt to connect"); - } catch (IOException expected) { - } - } - - /** Not in the certificate SAN. */ - @Test public void skipsWhenNotSubjectAltName() throws Exception { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl nonsanUrl = url.newBuilder().host("nonsan.com").build(); - - try { - execute(nonsanUrl); - fail("expected a failed attempt to connect"); - } catch (SSLPeerUnverifiedException expected) { - } - } - - @Test public void skipsOnRedirectWhenNotSubjectAltName() throws Exception { - server.enqueue(new MockResponse.Builder() - .code(301) - .addHeader("Location", url.newBuilder().host("nonsan.com").build()) - .build()); - server.enqueue(new MockResponse()); - - try { - Response response = execute(url); - response.close(); - fail("expected a failed attempt to connect"); - } catch (SSLPeerUnverifiedException expected) { - } - } - - /** Can still coalesce when pinning is used if pins match. */ - @Test public void coalescesWhenCertificatePinsMatch() throws Exception { - CertificatePinner pinner = new CertificatePinner.Builder() - .add("san.com", CertificatePinner.pin(certificate.certificate())) - .build(); - client = client.newBuilder().certificatePinner(pinner).build(); - - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - - assert200Http2Response(execute(sanUrl), "san.com"); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(1); - } - - /** Certificate pinning used and not a match will avoid coalescing and try to connect. */ - @Test public void skipsWhenCertificatePinningFails() throws Exception { - CertificatePinner pinner = new CertificatePinner.Builder() - .add("san.com", "sha1/afwiKY3RxoMmLkuRW1l7QsPZTJPwDS2pdDROQjXw8ig=") - .build(); - client = client.newBuilder().certificatePinner(pinner).build(); - - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - - try { - execute(sanUrl); - fail("expected a failed attempt to connect"); - } catch (IOException expected) { - } - } - - @Test public void skipsOnRedirectWhenCertificatePinningFails() throws Exception { - CertificatePinner pinner = new CertificatePinner.Builder() - .add("san.com", "sha1/afwiKY3RxoMmLkuRW1l7QsPZTJPwDS2pdDROQjXw8ig=") - .build(); - client = client.newBuilder().certificatePinner(pinner).build(); - - server.enqueue(new MockResponse.Builder() - .code(301) - .addHeader("Location", url.newBuilder().host("san.com").build()) - .build()); - server.enqueue(new MockResponse()); - - try { - execute(url); - fail("expected a failed attempt to connect"); - } catch (SSLPeerUnverifiedException expected) { - } - } - - /** - * Skips coalescing when hostname verifier is overridden since the intention of the hostname - * verification is a black box. - */ - @Test public void skipsWhenHostnameVerifierUsed() throws Exception { - HostnameVerifier verifier = (name, session) -> true; - client = client.newBuilder().hostnameVerifier(verifier).build(); - - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - - assert200Http2Response(execute(sanUrl), "san.com"); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(2); - } - - @Test public void skipsOnRedirectWhenHostnameVerifierUsed() throws Exception { - HostnameVerifier verifier = (name, session) -> true; - client = client.newBuilder().hostnameVerifier(verifier).build(); - - server.enqueue(new MockResponse.Builder() - .code(301) - .addHeader("Location", url.newBuilder().host("san.com").build()) - .build()); - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), "san.com"); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(2); - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(0); // Fresh connection. - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(0); // Fresh connection. - } - - /** - * Check we would use an existing connection to a later DNS result instead of connecting to the - * first DNS result for the first time. - */ - @Test public void prefersExistingCompatible() throws Exception { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - AtomicInteger connectCount = new AtomicInteger(); - EventListener listener = new EventListener() { - @Override public void connectStart( - Call call, InetSocketAddress inetSocketAddress, Proxy proxy) { - connectCount.getAndIncrement(); - } - }; - client = client.newBuilder() - .eventListenerFactory(clientTestRule.wrap(listener)) - .build(); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - dns.set("san.com", - asList(InetAddress.getByAddress("san.com", new byte[] {0, 0, 0, 0}), - serverIps.get(0))); - assert200Http2Response(execute(sanUrl), "san.com"); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(1); - assertThat(connectCount.get()).isEqualTo(1); - } - - /** Check that wildcard SANs are supported. */ - @Test public void commonThenWildcard() throws Exception { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl sanUrl = url.newBuilder().host("www.wildcard.com").build(); - assert200Http2Response(execute(sanUrl), "www.wildcard.com"); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(1); - } - - /** Network interceptors check for changes to target. */ - @Test public void worksWithNetworkInterceptors() throws Exception { - client = client.newBuilder() - .addNetworkInterceptor(chain -> chain.proceed(chain.request())) - .build(); - - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - assert200Http2Response(execute(sanUrl), "san.com"); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(1); - } - - @Test public void misdirectedRequestResponseCode() throws Exception { - server.enqueue(new MockResponse.Builder() - .body("seed connection") - .build()); - server.enqueue(new MockResponse.Builder() - .code(421) - .body("misdirected!") - .build()); - server.enqueue(new MockResponse.Builder() - .body("after misdirect") - .build()); - - // Seed the connection pool. - assert200Http2Response(execute(url), server.getHostName()); - - // Use the coalesced connection which should retry on a fresh connection. - HttpUrl sanUrl = url.newBuilder() - .host("san.com") - .build(); - try (Response response = execute(sanUrl)) { - assertThat(response.code()).isEqualTo(200); - assertThat(response.priorResponse().code()).isEqualTo(421); - assertThat(response.body().string()).isEqualTo("after misdirect"); - } - - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(0); - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(1); - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(0); // Fresh connection. - - assertThat(client.connectionPool().connectionCount()).isEqualTo(2); - } - - /** - * Won't coalesce if we can't clean certs e.g. a dev setup. - */ - @Test public void redirectWithDevSetup() throws Exception { - X509TrustManager TRUST_MANAGER = new X509TrustManager() { - @Override - public void checkClientTrusted(X509Certificate[] x509Certificates, String s) { - } - - @Override - public void checkServerTrusted(X509Certificate[] x509Certificates, String s) { - } - - @Override - public X509Certificate[] getAcceptedIssuers() { - return new X509Certificate[0]; - } - }; - - client = client.newBuilder().sslSocketFactory(client.sslSocketFactory(), TRUST_MANAGER).build(); - - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - assert200Http2Response(execute(url), server.getHostName()); - - HttpUrl sanUrl = url.newBuilder().host("san.com").build(); - assert200Http2Response(execute(sanUrl), "san.com"); - - assertThat(client.connectionPool().connectionCount()).isEqualTo(2); - } - - private Response execute(HttpUrl url) throws IOException { - return client.newCall(new Request.Builder().url(url).build()).execute(); - } - - private void assert200Http2Response(Response response, String expectedHost) { - assertThat(response.code()).isEqualTo(200); - assertThat(response.request().url().host()).isEqualTo(expectedHost); - assertThat(response.protocol()).isEqualTo(Protocol.HTTP_2); - response.body().close(); - } -} diff --git a/okhttp/src/test/java/okhttp3/ConnectionCoalescingTest.kt b/okhttp/src/test/java/okhttp3/ConnectionCoalescingTest.kt new file mode 100644 index 000000000000..13027d01496d --- /dev/null +++ b/okhttp/src/test/java/okhttp3/ConnectionCoalescingTest.kt @@ -0,0 +1,534 @@ +/* + * Copyright (C) 2017 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.io.IOException +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.Proxy +import java.security.cert.X509Certificate +import java.util.Arrays +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicReference +import javax.net.ssl.HostnameVerifier +import javax.net.ssl.SSLPeerUnverifiedException +import javax.net.ssl.SSLSession +import javax.net.ssl.X509TrustManager +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import okhttp3.CertificatePinner.Companion.pin +import okhttp3.testing.PlatformRule +import okhttp3.tls.HandshakeCertificates +import okhttp3.tls.HeldCertificate +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +@Tag("Slowish") +class ConnectionCoalescingTest { + @RegisterExtension + val platform = PlatformRule() + + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + private lateinit var server: MockWebServer + private lateinit var client: OkHttpClient + private lateinit var rootCa: HeldCertificate + private lateinit var certificate: HeldCertificate + private val dns = FakeDns() + private lateinit var url: HttpUrl + private lateinit var serverIps: List + + @BeforeEach + fun setUp(server: MockWebServer) { + this.server = server + platform.assumeHttp2Support() + platform.assumeNotBouncyCastle() + rootCa = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(0) + .commonName("root") + .build() + certificate = HeldCertificate.Builder() + .signedBy(rootCa) + .serialNumber(2L) + .commonName(server.hostName) + .addSubjectAlternativeName(server.hostName) + .addSubjectAlternativeName("san.com") + .addSubjectAlternativeName("*.wildcard.com") + .addSubjectAlternativeName("differentdns.com") + .build() + serverIps = Dns.SYSTEM.lookup(server.hostName) + dns[server.hostName] = serverIps + dns["san.com"] = serverIps + dns["nonsan.com"] = serverIps + dns["www.wildcard.com"] = serverIps + dns["differentdns.com"] = listOf() + val handshakeCertificates = HandshakeCertificates.Builder() + .addTrustedCertificate(rootCa.certificate) + .build() + client = clientTestRule.newClientBuilder() + .fastFallback(false) // Avoid data races. + .dns(dns) + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .build() + val serverHandshakeCertificates = HandshakeCertificates.Builder() + .heldCertificate(certificate) + .build() + server.useHttps(serverHandshakeCertificates.sslSocketFactory()) + url = server.url("/robots.txt") + } + + /** + * Test connecting to the main host then an alternative, although only subject alternative names + * are used if present no special consideration of common name. + */ + @Test + fun commonThenAlternative() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), server.hostName) + val sanUrl = url.newBuilder().host("san.com").build() + assert200Http2Response(execute(sanUrl), "san.com") + assertThat(client.connectionPool.connectionCount()).isEqualTo(1) + } + + /** + * Test connecting to an alternative host then common name, although only subject alternative + * names are used if present no special consideration of common name. + */ + @Test + fun alternativeThenCommon() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + val sanUrl = url.newBuilder().host("san.com").build() + assert200Http2Response(execute(sanUrl), "san.com") + assert200Http2Response(execute(url), server.hostName) + assertThat(client.connectionPool.connectionCount()).isEqualTo(1) + } + + /** Test a previously coalesced connection that's no longer healthy. */ + @Test + fun staleCoalescedConnection() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + val connection = AtomicReference() + client = client.newBuilder() + .addNetworkInterceptor(Interceptor { chain: Interceptor.Chain? -> + connection.set(chain!!.connection()) + chain.proceed(chain.request()) + }) + .build() + dns["san.com"] = Dns.SYSTEM.lookup(server.hostName).subList(0, 1) + assert200Http2Response(execute(url), server.hostName) + + // Simulate a stale connection in the pool. + connection.get()!!.socket().close() + val sanUrl = url.newBuilder().host("san.com").build() + assert200Http2Response(execute(sanUrl), "san.com") + assertThat(client.connectionPool.connectionCount()).isEqualTo(1) + } + + /** + * This is an extraordinary test case. Here's what it's trying to simulate. + * - 2 requests happen concurrently to a host that can be coalesced onto a single connection. + * - Both request discover no existing connection. They both make a connection. + * - The first request "wins the race". + * - The second request discovers it "lost the race" and closes the connection it just opened. + * - The second request uses the coalesced connection from request1. + * - The coalesced connection is violently closed after servicing the first request. + * - The second request discovers the coalesced connection is unhealthy just after acquiring it. + */ + @Test + fun coalescedConnectionDestroyedAfterAcquire() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + dns["san.com"] = Dns.SYSTEM.lookup(server.hostName).subList(0, 1) + val sanUrl = url.newBuilder().host("san.com").build() + val latch1 = CountDownLatch(1) + val latch2 = CountDownLatch(1) + val latch3 = CountDownLatch(1) + val latch4 = CountDownLatch(1) + val listener1: EventListener = object : EventListener() { + override fun connectStart( + call: Call, inetSocketAddress: InetSocketAddress, + proxy: Proxy + ) { + try { + // Wait for request2 to guarantee we make 2 separate connections to the server. + latch1.await() + } catch (e: InterruptedException) { + throw AssertionError(e) + } + } + + override fun connectionAcquired(call: Call, connection: Connection) { + // We have the connection and it's in the pool. Let request2 proceed to make a connection. + latch2.countDown() + } + } + val request2Listener: EventListener = object : EventListener() { + override fun connectStart( + call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy, + ) { + // Let request1 proceed to make a connection. + latch1.countDown() + try { + // Wait until request1 makes the connection and puts it in the connection pool. + latch2.await() + } catch (e: InterruptedException) { + throw AssertionError(e) + } + } + + override fun connectionAcquired(call: Call, connection: Connection) { + // We obtained the coalesced connection. Let request1 violently destroy it. + latch3.countDown() + try { + latch4.await() + } catch (e: InterruptedException) { + throw AssertionError(e) + } + } + } + + // Get a reference to the connection so we can violently destroy it. + val connection = AtomicReference() + val client1 = client.newBuilder() + .addNetworkInterceptor(Interceptor { chain: Interceptor.Chain? -> + connection.set(chain!!.connection()) + chain.proceed(chain.request()) + }) + .eventListenerFactory(clientTestRule.wrap(listener1)) + .build() + val request = Request.Builder().url(sanUrl).build() + val call1 = client1.newCall(request) + call1.enqueue(object : Callback { + @Throws(IOException::class) + override fun onResponse(call: Call, response: Response) { + try { + // Wait until request2 acquires the connection before we destroy it violently. + latch3.await() + } catch (e: InterruptedException) { + throw AssertionError(e) + } + assert200Http2Response(response, "san.com") + connection.get()!!.socket().close() + latch4.countDown() + } + + override fun onFailure(call: Call, e: IOException) { + fail() + } + }) + val client2 = client.newBuilder() + .eventListenerFactory(clientTestRule.wrap(request2Listener)) + .build() + val call2 = client2.newCall(request) + val response = call2.execute() + assert200Http2Response(response, "san.com") + } + + /** If the existing connection matches a SAN but not a match for DNS then skip. */ + @Test + fun skipsWhenDnsDontMatch() { + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), server.hostName) + val differentDnsUrl = url.newBuilder().host("differentdns.com").build() + try { + execute(differentDnsUrl) + fail("expected a failed attempt to connect") + } catch (expected: IOException) { + } + } + + @Test + fun skipsOnRedirectWhenDnsDontMatch() { + server.enqueue( + MockResponse.Builder() + .code(301) + .addHeader("Location", url.newBuilder().host("differentdns.com").build()) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("unexpected call") + .build() + ) + try { + val response = execute(url) + response.close() + fail("expected a failed attempt to connect") + } catch (expected: IOException) { + } + } + + /** Not in the certificate SAN. */ + @Test + fun skipsWhenNotSubjectAltName() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), server.hostName) + val nonsanUrl = url.newBuilder().host("nonsan.com").build() + try { + execute(nonsanUrl) + fail("expected a failed attempt to connect") + } catch (expected: SSLPeerUnverifiedException) { + } + } + + @Test + fun skipsOnRedirectWhenNotSubjectAltName() { + server.enqueue( + MockResponse.Builder() + .code(301) + .addHeader("Location", url.newBuilder().host("nonsan.com").build()) + .build() + ) + server.enqueue(MockResponse()) + try { + val response = execute(url) + response.close() + fail("expected a failed attempt to connect") + } catch (expected: SSLPeerUnverifiedException) { + } + } + + /** Can still coalesce when pinning is used if pins match. */ + @Test + fun coalescesWhenCertificatePinsMatch() { + val pinner = CertificatePinner.Builder() + .add("san.com", pin(certificate.certificate)) + .build() + client = client.newBuilder().certificatePinner(pinner).build() + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), server.hostName) + val sanUrl = url.newBuilder().host("san.com").build() + assert200Http2Response(execute(sanUrl), "san.com") + assertThat(client.connectionPool.connectionCount()).isEqualTo(1) + } + + /** Certificate pinning used and not a match will avoid coalescing and try to connect. */ + @Test + fun skipsWhenCertificatePinningFails() { + val pinner = CertificatePinner.Builder() + .add("san.com", "sha1/afwiKY3RxoMmLkuRW1l7QsPZTJPwDS2pdDROQjXw8ig=") + .build() + client = client.newBuilder().certificatePinner(pinner).build() + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), server.hostName) + val sanUrl = url.newBuilder().host("san.com").build() + try { + execute(sanUrl) + fail("expected a failed attempt to connect") + } catch (expected: IOException) { + } + } + + @Test + fun skipsOnRedirectWhenCertificatePinningFails() { + val pinner = CertificatePinner.Builder() + .add("san.com", "sha1/afwiKY3RxoMmLkuRW1l7QsPZTJPwDS2pdDROQjXw8ig=") + .build() + client = client.newBuilder().certificatePinner(pinner).build() + server.enqueue( + MockResponse.Builder() + .code(301) + .addHeader("Location", url.newBuilder().host("san.com").build()) + .build() + ) + server.enqueue(MockResponse()) + try { + execute(url) + fail("expected a failed attempt to connect") + } catch (expected: SSLPeerUnverifiedException) { + } + } + + /** + * Skips coalescing when hostname verifier is overridden since the intention of the hostname + * verification is a black box. + */ + @Test + fun skipsWhenHostnameVerifierUsed() { + val verifier = HostnameVerifier { name: String?, session: SSLSession? -> true } + client = client.newBuilder().hostnameVerifier(verifier).build() + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), server.hostName) + val sanUrl = url.newBuilder().host("san.com").build() + assert200Http2Response(execute(sanUrl), "san.com") + assertThat(client.connectionPool.connectionCount()).isEqualTo(2) + } + + @Test + fun skipsOnRedirectWhenHostnameVerifierUsed() { + val verifier = HostnameVerifier { name: String?, session: SSLSession? -> true } + client = client.newBuilder().hostnameVerifier(verifier).build() + server.enqueue( + MockResponse.Builder() + .code(301) + .addHeader("Location", url.newBuilder().host("san.com").build()) + .build() + ) + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), "san.com") + assertThat(client.connectionPool.connectionCount()).isEqualTo(2) + assertThat(server.takeRequest().sequenceNumber) + .isEqualTo(0) // Fresh connection. + assertThat(server.takeRequest().sequenceNumber) + .isEqualTo(0) // Fresh connection. + } + + /** + * Check we would use an existing connection to a later DNS result instead of connecting to the + * first DNS result for the first time. + */ + @Test + fun prefersExistingCompatible() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + val connectCount = AtomicInteger() + val listener: EventListener = object : EventListener() { + override fun connectStart( + call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy + ) { + connectCount.getAndIncrement() + } + } + client = client.newBuilder() + .eventListenerFactory(clientTestRule.wrap(listener)) + .build() + assert200Http2Response(execute(url), server.hostName) + val sanUrl = url.newBuilder().host("san.com").build() + dns["san.com"] = Arrays.asList( + InetAddress.getByAddress("san.com", byteArrayOf(0, 0, 0, 0)), + serverIps[0] + ) + assert200Http2Response(execute(sanUrl), "san.com") + assertThat(client.connectionPool.connectionCount()).isEqualTo(1) + assertThat(connectCount.get()).isEqualTo(1) + } + + /** Check that wildcard SANs are supported. */ + @Test + fun commonThenWildcard() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), server.hostName) + val sanUrl = url.newBuilder().host("www.wildcard.com").build() + assert200Http2Response(execute(sanUrl), "www.wildcard.com") + assertThat(client.connectionPool.connectionCount()).isEqualTo(1) + } + + /** Network interceptors check for changes to target. */ + @Test + fun worksWithNetworkInterceptors() { + client = client.newBuilder() + .addNetworkInterceptor(Interceptor { chain: Interceptor.Chain? -> + chain!!.proceed( + chain.request() + ) + }) + .build() + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), server.hostName) + val sanUrl = url.newBuilder().host("san.com").build() + assert200Http2Response(execute(sanUrl), "san.com") + assertThat(client.connectionPool.connectionCount()).isEqualTo(1) + } + + @Test + fun misdirectedRequestResponseCode() { + server.enqueue( + MockResponse.Builder() + .body("seed connection") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(421) + .body("misdirected!") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("after misdirect") + .build() + ) + + // Seed the connection pool. + assert200Http2Response(execute(url), server.hostName) + + // Use the coalesced connection which should retry on a fresh connection. + val sanUrl = url.newBuilder() + .host("san.com") + .build() + execute(sanUrl).use { response -> + assertThat(response.code).isEqualTo(200) + assertThat(response.priorResponse!!.code).isEqualTo(421) + assertThat(response.body.string()).isEqualTo("after misdirect") + } + assertThat(server.takeRequest().sequenceNumber).isEqualTo(0) + assertThat(server.takeRequest().sequenceNumber).isEqualTo(1) + assertThat(server.takeRequest().sequenceNumber) + .isEqualTo(0) // Fresh connection. + assertThat(client.connectionPool.connectionCount()).isEqualTo(2) + } + + /** + * Won't coalesce if we can't clean certs e.g. a dev setup. + */ + @Test + fun redirectWithDevSetup() { + val TRUST_MANAGER: X509TrustManager = object : X509TrustManager { + override fun checkClientTrusted(x509Certificates: Array, s: String) { + } + + override fun checkServerTrusted(x509Certificates: Array, s: String) { + } + + override fun getAcceptedIssuers(): Array { + return arrayOf() + } + } + client = + client.newBuilder().sslSocketFactory(client.sslSocketFactory, TRUST_MANAGER).build() + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + assert200Http2Response(execute(url), server.hostName) + val sanUrl = url.newBuilder().host("san.com").build() + assert200Http2Response(execute(sanUrl), "san.com") + assertThat(client.connectionPool.connectionCount()).isEqualTo(2) + } + + private fun execute(url: HttpUrl) = client.newCall(Request(url = url)).execute() + + private fun assert200Http2Response(response: Response, expectedHost: String) { + assertThat(response.code).isEqualTo(200) + assertThat(response.request.url.host).isEqualTo(expectedHost) + assertThat(response.protocol).isEqualTo(Protocol.HTTP_2) + response.body.close() + } +} diff --git a/okhttp/src/test/java/okhttp3/ConnectionSpecTest.java b/okhttp/src/test/java/okhttp3/ConnectionSpecTest.java deleted file mode 100644 index 3d77cfdacdc6..000000000000 --- a/okhttp/src/test/java/okhttp3/ConnectionSpecTest.java +++ /dev/null @@ -1,379 +0,0 @@ -/* - * Copyright (C) 2015 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3; - -import okhttp3.internal.platform.Platform; -import okhttp3.testing.PlatformRule; -import okhttp3.testing.PlatformVersion; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import javax.net.ssl.SSLSocket; -import javax.net.ssl.SSLSocketFactory; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; -import java.util.concurrent.CopyOnWriteArraySet; - -import static java.util.Arrays.asList; -import static okhttp3.internal.Internal.applyConnectionSpec; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -public final class ConnectionSpecTest { - @RegisterExtension public final PlatformRule platform = new PlatformRule(); - - @Test public void noTlsVersions() { - try { - new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) - .tlsVersions(new TlsVersion[0]) - .build(); - fail(); - } catch (IllegalArgumentException expected) { - assertThat(expected.getMessage()).isEqualTo("At least one TLS version is required"); - } - } - - @Test public void noCipherSuites() { - try { - new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) - .cipherSuites(new CipherSuite[0]) - .build(); - fail(); - } catch (IllegalArgumentException expected) { - assertThat(expected.getMessage()).isEqualTo("At least one cipher suite is required"); - } - } - - @Test public void cleartextBuilder() { - ConnectionSpec cleartextSpec = new ConnectionSpec.Builder(false).build(); - assertThat(cleartextSpec.isTls()).isFalse(); - } - - @Test public void tlsBuilder_explicitCiphers() throws Exception { - ConnectionSpec tlsSpec = new ConnectionSpec.Builder(true) - .cipherSuites(CipherSuite.TLS_RSA_WITH_RC4_128_MD5) - .tlsVersions(TlsVersion.TLS_1_2) - .supportsTlsExtensions(true) - .build(); - assertThat(tlsSpec.cipherSuites()).containsExactly(CipherSuite.TLS_RSA_WITH_RC4_128_MD5); - assertThat(tlsSpec.tlsVersions()).containsExactly(TlsVersion.TLS_1_2); - assertThat(tlsSpec.supportsTlsExtensions()).isTrue(); - } - - @Test public void tlsBuilder_defaultCiphers() throws Exception { - ConnectionSpec tlsSpec = new ConnectionSpec.Builder(true) - .tlsVersions(TlsVersion.TLS_1_2) - .supportsTlsExtensions(true) - .build(); - assertThat(tlsSpec.cipherSuites()).isNull(); - assertThat(tlsSpec.tlsVersions()).containsExactly(TlsVersion.TLS_1_2); - assertThat(tlsSpec.supportsTlsExtensions()).isTrue(); - } - - @Test public void tls_defaultCiphers_noFallbackIndicator() throws Exception { - platform.assumeNotConscrypt(); - platform.assumeNotBouncyCastle(); - - ConnectionSpec tlsSpec = new ConnectionSpec.Builder(true) - .tlsVersions(TlsVersion.TLS_1_2) - .supportsTlsExtensions(false) - .build(); - - SSLSocket socket = (SSLSocket) SSLSocketFactory.getDefault().createSocket(); - socket.setEnabledCipherSuites(new String[] { - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName(), - }); - socket.setEnabledProtocols(new String[] { - TlsVersion.TLS_1_2.javaName(), - TlsVersion.TLS_1_1.javaName(), - }); - - assertThat(tlsSpec.isCompatible(socket)).isTrue(); - applyConnectionSpec(tlsSpec, socket, false /* isFallback */); - - assertThat(socket.getEnabledProtocols()).containsExactly(TlsVersion.TLS_1_2.javaName()); - - assertThat(socket.getEnabledCipherSuites()).containsExactlyInAnyOrder( - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName()); - } - - @Test public void tls_defaultCiphers_withFallbackIndicator() throws Exception { - platform.assumeNotConscrypt(); - platform.assumeNotBouncyCastle(); - - ConnectionSpec tlsSpec = new ConnectionSpec.Builder(true) - .tlsVersions(TlsVersion.TLS_1_2) - .supportsTlsExtensions(false) - .build(); - - SSLSocket socket = (SSLSocket) SSLSocketFactory.getDefault().createSocket(); - socket.setEnabledCipherSuites(new String[] { - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName(), - }); - socket.setEnabledProtocols(new String[] { - TlsVersion.TLS_1_2.javaName(), - TlsVersion.TLS_1_1.javaName(), - }); - - assertThat(tlsSpec.isCompatible(socket)).isTrue(); - applyConnectionSpec(tlsSpec, socket, true /* isFallback */); - - assertThat(socket.getEnabledProtocols()).containsExactly(TlsVersion.TLS_1_2.javaName()); - - List expectedCipherSuites = new ArrayList<>(); - expectedCipherSuites.add(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName()); - expectedCipherSuites.add(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName()); - if (asList(socket.getSupportedCipherSuites()).contains("TLS_FALLBACK_SCSV")) { - expectedCipherSuites.add("TLS_FALLBACK_SCSV"); - } - assertThat(socket.getEnabledCipherSuites()).containsExactlyElementsOf(expectedCipherSuites); - } - - @Test public void tls_explicitCiphers() throws Exception { - platform.assumeNotConscrypt(); - platform.assumeNotBouncyCastle(); - - ConnectionSpec tlsSpec = new ConnectionSpec.Builder(true) - .cipherSuites(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) - .tlsVersions(TlsVersion.TLS_1_2) - .supportsTlsExtensions(false) - .build(); - - SSLSocket socket = (SSLSocket) SSLSocketFactory.getDefault().createSocket(); - socket.setEnabledCipherSuites(new String[] { - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName(), - }); - socket.setEnabledProtocols(new String[] { - TlsVersion.TLS_1_2.javaName(), - TlsVersion.TLS_1_1.javaName(), - }); - - assertThat(tlsSpec.isCompatible(socket)).isTrue(); - applyConnectionSpec(tlsSpec, socket, true /* isFallback */); - - assertThat(socket.getEnabledProtocols()).containsExactly(TlsVersion.TLS_1_2.javaName()); - - List expectedCipherSuites = new ArrayList<>(); - expectedCipherSuites.add(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName()); - if (asList(socket.getSupportedCipherSuites()).contains("TLS_FALLBACK_SCSV")) { - expectedCipherSuites.add("TLS_FALLBACK_SCSV"); - } - assertThat(socket.getEnabledCipherSuites()).containsExactlyElementsOf(expectedCipherSuites); - } - - @Test public void tls_stringCiphersAndVersions() throws Exception { - // Supporting arbitrary input strings allows users to enable suites and versions that are not - // yet known to the library, but are supported by the platform. - new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) - .cipherSuites("MAGIC-CIPHER") - .tlsVersions("TLS9k") - .build(); - } - - @Test public void tls_missingRequiredCipher() throws Exception { - platform.assumeNotConscrypt(); - platform.assumeNotBouncyCastle(); - - ConnectionSpec tlsSpec = new ConnectionSpec.Builder(true) - .cipherSuites(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) - .tlsVersions(TlsVersion.TLS_1_2) - .supportsTlsExtensions(false) - .build(); - - SSLSocket socket = (SSLSocket) SSLSocketFactory.getDefault().createSocket(); - socket.setEnabledProtocols(new String[] { - TlsVersion.TLS_1_2.javaName(), - TlsVersion.TLS_1_1.javaName(), - }); - - socket.setEnabledCipherSuites(new String[] { - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName(), - }); - assertThat(tlsSpec.isCompatible(socket)).isTrue(); - - socket.setEnabledCipherSuites(new String[] { - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName(), - }); - assertThat(tlsSpec.isCompatible(socket)).isFalse(); - } - - @Test public void allEnabledCipherSuites() throws Exception { - platform.assumeNotConscrypt(); - platform.assumeNotBouncyCastle(); - - ConnectionSpec tlsSpec = new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) - .allEnabledCipherSuites() - .build(); - assertThat(tlsSpec.cipherSuites()).isNull(); - - SSLSocket sslSocket = (SSLSocket) SSLSocketFactory.getDefault().createSocket(); - sslSocket.setEnabledCipherSuites(new String[] { - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName(), - }); - - applyConnectionSpec(tlsSpec, sslSocket, false); - if (platform.isAndroid()) { - // https://developer.android.com/reference/javax/net/ssl/SSLSocket - Integer sdkVersion = platform.androidSdkVersion(); - if (sdkVersion != null && sdkVersion >= 29) { - assertThat(sslSocket.getEnabledCipherSuites()).containsExactly( - CipherSuite.TLS_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_AES_256_GCM_SHA384.javaName(), - CipherSuite.TLS_CHACHA20_POLY1305_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName()); - } else { - assertThat(sslSocket.getEnabledCipherSuites()).containsExactly( - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName()); - } - } else { - assertThat(sslSocket.getEnabledCipherSuites()).containsExactly( - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName()); - } - } - - @Test public void allEnabledTlsVersions() throws Exception { - platform.assumeNotConscrypt(); - - ConnectionSpec tlsSpec = new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) - .allEnabledTlsVersions() - .build(); - assertThat(tlsSpec.tlsVersions()).isNull(); - - SSLSocket sslSocket = (SSLSocket) SSLSocketFactory.getDefault().createSocket(); - if (PlatformVersion.INSTANCE.getMajorVersion() > 11) { - sslSocket.setEnabledProtocols(new String[] { - TlsVersion.SSL_3_0.javaName(), - TlsVersion.TLS_1_1.javaName(), - TlsVersion.TLS_1_2.javaName(), - TlsVersion.TLS_1_3.javaName() - }); - } else { - sslSocket.setEnabledProtocols(new String[] { - TlsVersion.SSL_3_0.javaName(), - TlsVersion.TLS_1_1.javaName(), - TlsVersion.TLS_1_2.javaName() - }); - } - - applyConnectionSpec(tlsSpec, sslSocket, false); - if (Platform.Companion.isAndroid()) { - Integer sdkVersion = platform.androidSdkVersion(); - // https://developer.android.com/reference/javax/net/ssl/SSLSocket - if (sdkVersion != null && sdkVersion >= 29) { - assertThat(sslSocket.getEnabledProtocols()).containsExactly( - TlsVersion.TLS_1_1.javaName(), TlsVersion.TLS_1_2.javaName(), - TlsVersion.TLS_1_3.javaName()); - } else if (sdkVersion != null && sdkVersion >= 26) { - assertThat(sslSocket.getEnabledProtocols()).containsExactly( - TlsVersion.TLS_1_1.javaName(), TlsVersion.TLS_1_2.javaName()); - } else { - assertThat(sslSocket.getEnabledProtocols()).containsExactly( - TlsVersion.SSL_3_0.javaName(), TlsVersion.TLS_1_1.javaName(), - TlsVersion.TLS_1_2.javaName()); - } - } else { - if (PlatformVersion.INSTANCE.getMajorVersion() > 11) { - assertThat(sslSocket.getEnabledProtocols()).containsExactly( - TlsVersion.SSL_3_0.javaName(), TlsVersion.TLS_1_1.javaName(), - TlsVersion.TLS_1_2.javaName(), TlsVersion.TLS_1_3.javaName()); - } else { - assertThat(sslSocket.getEnabledProtocols()).containsExactly( - TlsVersion.SSL_3_0.javaName(), TlsVersion.TLS_1_1.javaName(), - TlsVersion.TLS_1_2.javaName()); - } - } - } - - @Test public void tls_missingTlsVersion() throws Exception { - platform.assumeNotConscrypt(); - platform.assumeNotBouncyCastle(); - - ConnectionSpec tlsSpec = new ConnectionSpec.Builder(true) - .cipherSuites(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) - .tlsVersions(TlsVersion.TLS_1_2) - .supportsTlsExtensions(false) - .build(); - - SSLSocket socket = (SSLSocket) SSLSocketFactory.getDefault().createSocket(); - socket.setEnabledCipherSuites(new String[] { - CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName(), - }); - - socket.setEnabledProtocols( - new String[] {TlsVersion.TLS_1_2.javaName(), TlsVersion.TLS_1_1.javaName()}); - assertThat(tlsSpec.isCompatible(socket)).isTrue(); - - socket.setEnabledProtocols(new String[] {TlsVersion.TLS_1_1.javaName()}); - assertThat(tlsSpec.isCompatible(socket)).isFalse(); - } - - @Test public void equalsAndHashCode() throws Exception { - ConnectionSpec allCipherSuites = new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) - .allEnabledCipherSuites() - .build(); - ConnectionSpec allTlsVersions = new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) - .allEnabledTlsVersions() - .build(); - - Set set = new CopyOnWriteArraySet<>(); - assertThat(set.add(ConnectionSpec.MODERN_TLS)).isTrue(); - assertThat(set.add(ConnectionSpec.COMPATIBLE_TLS)).isTrue(); - assertThat(set.add(ConnectionSpec.CLEARTEXT)).isTrue(); - assertThat(set.add(allTlsVersions)).isTrue(); - assertThat(set.add(allCipherSuites)).isTrue(); - allCipherSuites.hashCode(); - assertThat(allCipherSuites.equals(null)).isFalse(); - - assertThat(set.remove(ConnectionSpec.MODERN_TLS)).isTrue(); - assertThat(set.remove(ConnectionSpec.COMPATIBLE_TLS)).isTrue(); - assertThat(set.remove(ConnectionSpec.CLEARTEXT)).isTrue(); - assertThat(set.remove(allTlsVersions)).isTrue(); - assertThat(set.remove(allCipherSuites)).isTrue(); - assertThat(set).isEmpty(); - allTlsVersions.hashCode(); - assertThat(allTlsVersions.equals(null)).isFalse(); - } - - @Test public void allEnabledToString() throws Exception { - ConnectionSpec connectionSpec = new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) - .allEnabledTlsVersions() - .allEnabledCipherSuites() - .build(); - assertThat(connectionSpec.toString()).isEqualTo( - ("ConnectionSpec(cipherSuites=[all enabled], tlsVersions=[all enabled], " - + "supportsTlsExtensions=true)")); - } - - @Test public void simpleToString() throws Exception { - ConnectionSpec connectionSpec = new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) - .tlsVersions(TlsVersion.TLS_1_2) - .cipherSuites(CipherSuite.TLS_RSA_WITH_RC4_128_MD5) - .build(); - assertThat(connectionSpec.toString()).isEqualTo( - ("ConnectionSpec(cipherSuites=[SSL_RSA_WITH_RC4_128_MD5], tlsVersions=[TLS_1_2], " - + "supportsTlsExtensions=true)")); - } -} diff --git a/okhttp/src/test/java/okhttp3/ConnectionSpecTest.kt b/okhttp/src/test/java/okhttp3/ConnectionSpecTest.kt new file mode 100644 index 000000000000..15558a396974 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/ConnectionSpecTest.kt @@ -0,0 +1,396 @@ +/* + * Copyright (C) 2015 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.util.Arrays +import java.util.concurrent.CopyOnWriteArraySet +import javax.net.ssl.SSLSocket +import javax.net.ssl.SSLSocketFactory +import okhttp3.internal.applyConnectionSpec +import okhttp3.internal.platform.Platform.Companion.isAndroid +import okhttp3.testing.PlatformRule +import okhttp3.testing.PlatformVersion.majorVersion +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +class ConnectionSpecTest { + @RegisterExtension + val platform = PlatformRule() + + @Test + fun noTlsVersions() { + try { + ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) + .tlsVersions(*arrayOf()) + .build() + fail() + } catch (expected: IllegalArgumentException) { + assertThat(expected.message) + .isEqualTo("At least one TLS version is required") + } + } + + @Test + fun noCipherSuites() { + try { + ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) + .cipherSuites(*arrayOf()) + .build() + fail() + } catch (expected: IllegalArgumentException) { + assertThat(expected.message) + .isEqualTo("At least one cipher suite is required") + } + } + + @Test + fun cleartextBuilder() { + val cleartextSpec = ConnectionSpec.Builder(false).build() + assertThat(cleartextSpec.isTls).isFalse() + } + + @Test + fun tlsBuilder_explicitCiphers() { + val tlsSpec = ConnectionSpec.Builder(true) + .cipherSuites(CipherSuite.TLS_RSA_WITH_RC4_128_MD5) + .tlsVersions(TlsVersion.TLS_1_2) + .supportsTlsExtensions(true) + .build() + assertThat(tlsSpec.cipherSuites) + .containsExactly(CipherSuite.TLS_RSA_WITH_RC4_128_MD5) + assertThat(tlsSpec.tlsVersions) + .containsExactly(TlsVersion.TLS_1_2) + assertThat(tlsSpec.supportsTlsExtensions).isTrue() + } + + @Test + fun tlsBuilder_defaultCiphers() { + val tlsSpec = ConnectionSpec.Builder(true) + .tlsVersions(TlsVersion.TLS_1_2) + .supportsTlsExtensions(true) + .build() + assertThat(tlsSpec.cipherSuites).isNull() + assertThat(tlsSpec.tlsVersions) + .containsExactly(TlsVersion.TLS_1_2) + assertThat(tlsSpec.supportsTlsExtensions).isTrue() + } + + @Test + fun tls_defaultCiphers_noFallbackIndicator() { + platform.assumeNotConscrypt() + platform.assumeNotBouncyCastle() + val tlsSpec = ConnectionSpec.Builder(true) + .tlsVersions(TlsVersion.TLS_1_2) + .supportsTlsExtensions(false) + .build() + val socket = SSLSocketFactory.getDefault().createSocket() as SSLSocket + socket.enabledCipherSuites = arrayOf( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + socket.enabledProtocols = arrayOf( + TlsVersion.TLS_1_2.javaName, + TlsVersion.TLS_1_1.javaName + ) + assertThat(tlsSpec.isCompatible(socket)).isTrue() + applyConnectionSpec(tlsSpec, socket, isFallback = false) + assertThat(socket.enabledProtocols).containsExactly( + TlsVersion.TLS_1_2.javaName + ) + assertThat(socket.enabledCipherSuites) + .containsExactlyInAnyOrder( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + } + + @Test + fun tls_defaultCiphers_withFallbackIndicator() { + platform.assumeNotConscrypt() + platform.assumeNotBouncyCastle() + val tlsSpec = ConnectionSpec.Builder(true) + .tlsVersions(TlsVersion.TLS_1_2) + .supportsTlsExtensions(false) + .build() + val socket = SSLSocketFactory.getDefault().createSocket() as SSLSocket + socket.enabledCipherSuites = arrayOf( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + socket.enabledProtocols = arrayOf( + TlsVersion.TLS_1_2.javaName, + TlsVersion.TLS_1_1.javaName + ) + assertThat(tlsSpec.isCompatible(socket)).isTrue() + applyConnectionSpec(tlsSpec, socket, isFallback = true) + assertThat(socket.enabledProtocols).containsExactly( + TlsVersion.TLS_1_2.javaName + ) + val expectedCipherSuites: MutableList = ArrayList() + expectedCipherSuites.add(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName) + expectedCipherSuites.add(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName) + if (listOf(*socket.supportedCipherSuites).contains("TLS_FALLBACK_SCSV")) { + expectedCipherSuites.add("TLS_FALLBACK_SCSV") + } + assertThat(socket.enabledCipherSuites) + .containsExactlyElementsOf(expectedCipherSuites) + } + + @Test + fun tls_explicitCiphers() { + platform.assumeNotConscrypt() + platform.assumeNotBouncyCastle() + val tlsSpec = ConnectionSpec.Builder(true) + .cipherSuites(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .tlsVersions(TlsVersion.TLS_1_2) + .supportsTlsExtensions(false) + .build() + val socket = SSLSocketFactory.getDefault().createSocket() as SSLSocket + socket.enabledCipherSuites = arrayOf( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + socket.enabledProtocols = arrayOf( + TlsVersion.TLS_1_2.javaName, + TlsVersion.TLS_1_1.javaName + ) + assertThat(tlsSpec.isCompatible(socket)).isTrue() + applyConnectionSpec(tlsSpec, socket, isFallback = true) + assertThat(socket.enabledProtocols).containsExactly( + TlsVersion.TLS_1_2.javaName + ) + val expectedCipherSuites: MutableList = ArrayList() + expectedCipherSuites.add(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName) + if (Arrays.asList(*socket.supportedCipherSuites).contains("TLS_FALLBACK_SCSV")) { + expectedCipherSuites.add("TLS_FALLBACK_SCSV") + } + assertThat(socket.enabledCipherSuites) + .containsExactlyElementsOf(expectedCipherSuites) + } + + @Test + fun tls_stringCiphersAndVersions() { + // Supporting arbitrary input strings allows users to enable suites and versions that are not + // yet known to the library, but are supported by the platform. + ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) + .cipherSuites("MAGIC-CIPHER") + .tlsVersions("TLS9k") + .build() + } + + @Test + fun tls_missingRequiredCipher() { + platform.assumeNotConscrypt() + platform.assumeNotBouncyCastle() + val tlsSpec = ConnectionSpec.Builder(true) + .cipherSuites(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .tlsVersions(TlsVersion.TLS_1_2) + .supportsTlsExtensions(false) + .build() + val socket = SSLSocketFactory.getDefault().createSocket() as SSLSocket + socket.enabledProtocols = arrayOf( + TlsVersion.TLS_1_2.javaName, + TlsVersion.TLS_1_1.javaName + ) + socket.enabledCipherSuites = arrayOf( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + assertThat(tlsSpec.isCompatible(socket)).isTrue() + socket.enabledCipherSuites = arrayOf( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + assertThat(tlsSpec.isCompatible(socket)).isFalse() + } + + @Test + fun allEnabledCipherSuites() { + platform.assumeNotConscrypt() + platform.assumeNotBouncyCastle() + val tlsSpec = ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) + .allEnabledCipherSuites() + .build() + assertThat(tlsSpec.cipherSuites).isNull() + val sslSocket = SSLSocketFactory.getDefault().createSocket() as SSLSocket + sslSocket.enabledCipherSuites = arrayOf( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + applyConnectionSpec(tlsSpec, sslSocket, false) + if (platform.isAndroid) { + // https://developer.android.com/reference/javax/net/ssl/SSLSocket + val sdkVersion = platform.androidSdkVersion() + if (sdkVersion != null && sdkVersion >= 29) { + assertThat(sslSocket.enabledCipherSuites) + .containsExactly( + CipherSuite.TLS_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_AES_256_GCM_SHA384.javaName, + CipherSuite.TLS_CHACHA20_POLY1305_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + } else { + assertThat(sslSocket.enabledCipherSuites) + .containsExactly( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + } + } else { + assertThat(sslSocket.enabledCipherSuites) + .containsExactly( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName, + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA.javaName + ) + } + } + + @Test + fun allEnabledTlsVersions() { + platform.assumeNotConscrypt() + val tlsSpec = ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) + .allEnabledTlsVersions() + .build() + assertThat(tlsSpec.tlsVersions).isNull() + val sslSocket = SSLSocketFactory.getDefault().createSocket() as SSLSocket + if (majorVersion > 11) { + sslSocket.enabledProtocols = arrayOf( + TlsVersion.SSL_3_0.javaName, + TlsVersion.TLS_1_1.javaName, + TlsVersion.TLS_1_2.javaName, + TlsVersion.TLS_1_3.javaName + ) + } else { + sslSocket.enabledProtocols = arrayOf( + TlsVersion.SSL_3_0.javaName, + TlsVersion.TLS_1_1.javaName, + TlsVersion.TLS_1_2.javaName + ) + } + applyConnectionSpec(tlsSpec, sslSocket, false) + if (isAndroid) { + val sdkVersion = platform.androidSdkVersion() + // https://developer.android.com/reference/javax/net/ssl/SSLSocket + if (sdkVersion != null && sdkVersion >= 29) { + assertThat(sslSocket.enabledProtocols) + .containsExactly( + TlsVersion.TLS_1_1.javaName, TlsVersion.TLS_1_2.javaName, + TlsVersion.TLS_1_3.javaName + ) + } else if (sdkVersion != null && sdkVersion >= 26) { + assertThat(sslSocket.enabledProtocols) + .containsExactly( + TlsVersion.TLS_1_1.javaName, TlsVersion.TLS_1_2.javaName + ) + } else { + assertThat(sslSocket.enabledProtocols) + .containsExactly( + TlsVersion.SSL_3_0.javaName, TlsVersion.TLS_1_1.javaName, + TlsVersion.TLS_1_2.javaName + ) + } + } else { + if (majorVersion > 11) { + assertThat(sslSocket.enabledProtocols) + .containsExactly( + TlsVersion.SSL_3_0.javaName, TlsVersion.TLS_1_1.javaName, + TlsVersion.TLS_1_2.javaName, TlsVersion.TLS_1_3.javaName + ) + } else { + assertThat(sslSocket.enabledProtocols) + .containsExactly( + TlsVersion.SSL_3_0.javaName, TlsVersion.TLS_1_1.javaName, + TlsVersion.TLS_1_2.javaName + ) + } + } + } + + @Test + fun tls_missingTlsVersion() { + platform.assumeNotConscrypt() + platform.assumeNotBouncyCastle() + val tlsSpec = ConnectionSpec.Builder(true) + .cipherSuites(CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .tlsVersions(TlsVersion.TLS_1_2) + .supportsTlsExtensions(false) + .build() + val socket = SSLSocketFactory.getDefault().createSocket() as SSLSocket + socket.enabledCipherSuites = arrayOf( + CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.javaName + ) + socket.enabledProtocols = arrayOf( + TlsVersion.TLS_1_2.javaName, + TlsVersion.TLS_1_1.javaName + ) + assertThat(tlsSpec.isCompatible(socket)).isTrue() + socket.enabledProtocols = arrayOf(TlsVersion.TLS_1_1.javaName) + assertThat(tlsSpec.isCompatible(socket)).isFalse() + } + + @Test + fun equalsAndHashCode() { + val allCipherSuites = ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) + .allEnabledCipherSuites() + .build() + val allTlsVersions = ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) + .allEnabledTlsVersions() + .build() + val set: MutableSet = CopyOnWriteArraySet() + assertThat(set.add(ConnectionSpec.MODERN_TLS)).isTrue() + assertThat(set.add(ConnectionSpec.COMPATIBLE_TLS)).isTrue() + assertThat(set.add(ConnectionSpec.CLEARTEXT)).isTrue() + assertThat(set.add(allTlsVersions)).isTrue() + assertThat(set.add(allCipherSuites)).isTrue() + allCipherSuites.hashCode() + assertThat(allCipherSuites.equals(null)).isFalse() + assertThat(set.remove(ConnectionSpec.MODERN_TLS)).isTrue() + assertThat(set.remove(ConnectionSpec.COMPATIBLE_TLS)) + .isTrue() + assertThat(set.remove(ConnectionSpec.CLEARTEXT)).isTrue() + assertThat(set.remove(allTlsVersions)).isTrue() + assertThat(set.remove(allCipherSuites)).isTrue() + assertThat(set).isEmpty() + allTlsVersions.hashCode() + assertThat(allTlsVersions.equals(null)).isFalse() + } + + @Test + fun allEnabledToString() { + val connectionSpec = ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) + .allEnabledTlsVersions() + .allEnabledCipherSuites() + .build() + assertThat(connectionSpec.toString()).isEqualTo( + "ConnectionSpec(cipherSuites=[all enabled], tlsVersions=[all enabled], " + + "supportsTlsExtensions=true)" + ) + } + + @Test + fun simpleToString() { + val connectionSpec = ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) + .tlsVersions(TlsVersion.TLS_1_2) + .cipherSuites(CipherSuite.TLS_RSA_WITH_RC4_128_MD5) + .build() + assertThat(connectionSpec.toString()).isEqualTo( + "ConnectionSpec(cipherSuites=[SSL_RSA_WITH_RC4_128_MD5], tlsVersions=[TLS_1_2], " + + "supportsTlsExtensions=true)" + ) + } +} diff --git a/okhttp/src/test/java/okhttp3/DispatcherTest.java b/okhttp/src/test/java/okhttp3/DispatcherTest.java deleted file mode 100644 index c43adc686908..000000000000 --- a/okhttp/src/test/java/okhttp3/DispatcherTest.java +++ /dev/null @@ -1,333 +0,0 @@ -package okhttp3; - -import java.io.IOException; -import java.io.InterruptedIOException; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -@Tag("Slowish") -public final class DispatcherTest { - @RegisterExtension public final OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - - final RecordingExecutor executor = new RecordingExecutor(this); - final RecordingCallback callback = new RecordingCallback(); - final RecordingWebSocketListener webSocketListener = new RecordingWebSocketListener(); - final Dispatcher dispatcher = new Dispatcher(executor); - final RecordingEventListener listener = new RecordingEventListener(); - OkHttpClient client = clientTestRule.newClientBuilder() - .dispatcher(dispatcher) - .eventListenerFactory(clientTestRule.wrap(listener)) - .build(); - - @BeforeEach public void setUp() throws Exception { - dispatcher.setMaxRequests(20); - dispatcher.setMaxRequestsPerHost(10); - listener.forbidLock(dispatcher); - } - - @Test public void maxRequestsZero() throws Exception { - try { - dispatcher.setMaxRequests(0); - fail(); - } catch (IllegalArgumentException expected) { - } - } - - @Test public void maxPerHostZero() throws Exception { - try { - dispatcher.setMaxRequestsPerHost(0); - fail(); - } catch (IllegalArgumentException expected) { - } - } - - @Test public void enqueuedJobsRunImmediately() throws Exception { - client.newCall(newRequest("http://a/1")).enqueue(callback); - executor.assertJobs("http://a/1"); - } - - @Test public void maxRequestsEnforced() throws Exception { - dispatcher.setMaxRequests(3); - client.newCall(newRequest("http://a/1")).enqueue(callback); - client.newCall(newRequest("http://a/2")).enqueue(callback); - client.newCall(newRequest("http://b/1")).enqueue(callback); - client.newCall(newRequest("http://b/2")).enqueue(callback); - executor.assertJobs("http://a/1", "http://a/2", "http://b/1"); - } - - @Test public void maxPerHostEnforced() throws Exception { - dispatcher.setMaxRequestsPerHost(2); - client.newCall(newRequest("http://a/1")).enqueue(callback); - client.newCall(newRequest("http://a/2")).enqueue(callback); - client.newCall(newRequest("http://a/3")).enqueue(callback); - executor.assertJobs("http://a/1", "http://a/2"); - } - - @Test public void maxPerHostNotEnforcedForWebSockets() { - dispatcher.setMaxRequestsPerHost(2); - client.newWebSocket(newRequest("http://a/1"), webSocketListener); - client.newWebSocket(newRequest("http://a/2"), webSocketListener); - client.newWebSocket(newRequest("http://a/3"), webSocketListener); - executor.assertJobs("http://a/1", "http://a/2", "http://a/3"); - } - - @Test public void increasingMaxRequestsPromotesJobsImmediately() throws Exception { - dispatcher.setMaxRequests(2); - client.newCall(newRequest("http://a/1")).enqueue(callback); - client.newCall(newRequest("http://b/1")).enqueue(callback); - client.newCall(newRequest("http://c/1")).enqueue(callback); - client.newCall(newRequest("http://a/2")).enqueue(callback); - client.newCall(newRequest("http://b/2")).enqueue(callback); - dispatcher.setMaxRequests(4); - executor.assertJobs("http://a/1", "http://b/1", "http://c/1", "http://a/2"); - } - - @Test public void increasingMaxPerHostPromotesJobsImmediately() throws Exception { - dispatcher.setMaxRequestsPerHost(2); - client.newCall(newRequest("http://a/1")).enqueue(callback); - client.newCall(newRequest("http://a/2")).enqueue(callback); - client.newCall(newRequest("http://a/3")).enqueue(callback); - client.newCall(newRequest("http://a/4")).enqueue(callback); - client.newCall(newRequest("http://a/5")).enqueue(callback); - dispatcher.setMaxRequestsPerHost(4); - executor.assertJobs("http://a/1", "http://a/2", "http://a/3", "http://a/4"); - } - - @Test public void oldJobFinishesNewJobCanRunDifferentHost() throws Exception { - dispatcher.setMaxRequests(1); - client.newCall(newRequest("http://a/1")).enqueue(callback); - client.newCall(newRequest("http://b/1")).enqueue(callback); - executor.finishJob("http://a/1"); - executor.assertJobs("http://b/1"); - } - - @Test public void oldJobFinishesNewJobWithSameHostStarts() throws Exception { - dispatcher.setMaxRequests(2); - dispatcher.setMaxRequestsPerHost(1); - client.newCall(newRequest("http://a/1")).enqueue(callback); - client.newCall(newRequest("http://b/1")).enqueue(callback); - client.newCall(newRequest("http://b/2")).enqueue(callback); - client.newCall(newRequest("http://a/2")).enqueue(callback); - executor.finishJob("http://a/1"); - executor.assertJobs("http://b/1", "http://a/2"); - } - - @Test public void oldJobFinishesNewJobCantRunDueToHostLimit() throws Exception { - dispatcher.setMaxRequestsPerHost(1); - client.newCall(newRequest("http://a/1")).enqueue(callback); - client.newCall(newRequest("http://b/1")).enqueue(callback); - client.newCall(newRequest("http://a/2")).enqueue(callback); - executor.finishJob("http://b/1"); - executor.assertJobs("http://a/1"); - } - - @Test public void enqueuedCallsStillRespectMaxCallsPerHost() throws Exception { - dispatcher.setMaxRequests(1); - dispatcher.setMaxRequestsPerHost(1); - client.newCall(newRequest("http://a/1")).enqueue(callback); - client.newCall(newRequest("http://b/1")).enqueue(callback); - client.newCall(newRequest("http://b/2")).enqueue(callback); - client.newCall(newRequest("http://b/3")).enqueue(callback); - dispatcher.setMaxRequests(3); - executor.finishJob("http://a/1"); - executor.assertJobs("http://b/1"); - } - - @Test public void cancelingRunningJobTakesNoEffectUntilJobFinishes() throws Exception { - dispatcher.setMaxRequests(1); - Call c1 = client.newCall(newRequest("http://a/1", "tag1")); - Call c2 = client.newCall(newRequest("http://a/2")); - c1.enqueue(callback); - c2.enqueue(callback); - c1.cancel(); - executor.assertJobs("http://a/1"); - executor.finishJob("http://a/1"); - executor.assertJobs("http://a/2"); - } - - @Test public void asyncCallAccessors() throws Exception { - dispatcher.setMaxRequests(3); - Call a1 = client.newCall(newRequest("http://a/1")); - Call a2 = client.newCall(newRequest("http://a/2")); - Call a3 = client.newCall(newRequest("http://a/3")); - Call a4 = client.newCall(newRequest("http://a/4")); - Call a5 = client.newCall(newRequest("http://a/5")); - a1.enqueue(callback); - a2.enqueue(callback); - a3.enqueue(callback); - a4.enqueue(callback); - a5.enqueue(callback); - assertThat(dispatcher.runningCallsCount()).isEqualTo(3); - assertThat(dispatcher.queuedCallsCount()).isEqualTo(2); - assertThat(dispatcher.runningCalls()).containsExactlyInAnyOrder(a1, a2, a3); - assertThat(dispatcher.queuedCalls()).containsExactlyInAnyOrder(a4, a5); - } - - @Test public void synchronousCallAccessors() throws Exception { - CountDownLatch ready = new CountDownLatch(2); - CountDownLatch waiting = new CountDownLatch(1); - client = client.newBuilder() - .addInterceptor(chain -> { - try { - ready.countDown(); - waiting.await(); - } catch (InterruptedException e) { - throw new AssertionError(); - } - throw new IOException(); - }) - .build(); - - Call a1 = client.newCall(newRequest("http://a/1")); - Call a2 = client.newCall(newRequest("http://a/2")); - Call a3 = client.newCall(newRequest("http://a/3")); - Call a4 = client.newCall(newRequest("http://a/4")); - Thread t1 = makeSynchronousCall(a1); - Thread t2 = makeSynchronousCall(a2); - - // We created 4 calls and started 2 of them. That's 2 running calls and 0 queued. - ready.await(); - assertThat(dispatcher.runningCallsCount()).isEqualTo(2); - assertThat(dispatcher.queuedCallsCount()).isEqualTo(0); - assertThat(dispatcher.runningCalls()).containsExactlyInAnyOrder(a1, a2); - assertThat(dispatcher.queuedCalls()).isEmpty(); - - // Cancel some calls. That doesn't impact running or queued. - a2.cancel(); - a3.cancel(); - assertThat(dispatcher.runningCalls()).containsExactlyInAnyOrder(a1, a2); - assertThat(dispatcher.queuedCalls()).isEmpty(); - - // Let the calls finish. - waiting.countDown(); - t1.join(); - t2.join(); - - // Now we should have 0 running calls and 0 queued calls. - assertThat(dispatcher.runningCallsCount()).isEqualTo(0); - assertThat(dispatcher.queuedCallsCount()).isEqualTo(0); - assertThat(dispatcher.runningCalls()).isEmpty(); - assertThat(dispatcher.queuedCalls()).isEmpty(); - - assertThat(a1.isExecuted()).isTrue(); - assertThat(a1.isCanceled()).isFalse(); - - assertThat(a2.isExecuted()).isTrue(); - assertThat(a2.isCanceled()).isTrue(); - - assertThat(a3.isExecuted()).isFalse(); - assertThat(a3.isCanceled()).isTrue(); - - assertThat(a4.isExecuted()).isFalse(); - assertThat(a4.isCanceled()).isFalse(); - } - - @Test public void idleCallbackInvokedWhenIdle() throws Exception { - AtomicBoolean idle = new AtomicBoolean(); - dispatcher.setIdleCallback(() -> idle.set(true)); - - client.newCall(newRequest("http://a/1")).enqueue(callback); - client.newCall(newRequest("http://a/2")).enqueue(callback); - executor.finishJob("http://a/1"); - assertThat(idle.get()).isFalse(); - - CountDownLatch ready = new CountDownLatch(1); - CountDownLatch proceed = new CountDownLatch(1); - client = client.newBuilder() - .addInterceptor(chain -> { - ready.countDown(); - try { - proceed.await(5, SECONDS); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - return chain.proceed(chain.request()); - }) - .build(); - - Thread t1 = makeSynchronousCall(client.newCall(newRequest("http://a/3"))); - ready.await(5, SECONDS); - executor.finishJob("http://a/2"); - assertThat(idle.get()).isFalse(); - - proceed.countDown(); - t1.join(); - assertThat(idle.get()).isTrue(); - } - - @Test public void executionRejectedImmediately() throws Exception { - Request request = newRequest("http://a/1"); - executor.shutdown(); - client.newCall(request).enqueue(callback); - callback.await(request.url()).assertFailure(InterruptedIOException.class); - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", "CallFailed"); - } - - @Test public void executionRejectedAfterMaxRequestsChange() throws Exception { - Request request1 = newRequest("http://a/1"); - Request request2 = newRequest("http://a/2"); - dispatcher.setMaxRequests(1); - client.newCall(request1).enqueue(callback); - executor.shutdown(); - client.newCall(request2).enqueue(callback); - dispatcher.setMaxRequests(2); // Trigger promotion. - callback.await(request2.url()).assertFailure(InterruptedIOException.class); - - assertThat(listener.recordedEventTypes()) - .containsExactly("CallStart", "CallStart", "CallFailed"); - } - - @Test public void executionRejectedAfterMaxRequestsPerHostChange() throws Exception { - Request request1 = newRequest("http://a/1"); - Request request2 = newRequest("http://a/2"); - dispatcher.setMaxRequestsPerHost(1); - client.newCall(request1).enqueue(callback); - executor.shutdown(); - client.newCall(request2).enqueue(callback); - dispatcher.setMaxRequestsPerHost(2); // Trigger promotion. - callback.await(request2.url()).assertFailure(InterruptedIOException.class); - assertThat(listener.recordedEventTypes()) - .containsExactly("CallStart", "CallStart", "CallFailed"); - } - - @Test public void executionRejectedAfterPrecedingCallFinishes() throws Exception { - Request request1 = newRequest("http://a/1"); - Request request2 = newRequest("http://a/2"); - dispatcher.setMaxRequests(1); - client.newCall(request1).enqueue(callback); - executor.shutdown(); - client.newCall(request2).enqueue(callback); - executor.finishJob("http://a/1"); // Trigger promotion. - callback.await(request2.url()).assertFailure(InterruptedIOException.class); - assertThat(listener.recordedEventTypes()) - .containsExactly("CallStart", "CallStart", "CallFailed"); - } - - private Thread makeSynchronousCall(Call call) { - Thread thread = new Thread(() -> { - try { - call.execute(); - throw new AssertionError(); - } catch (IOException expected) { - } - }); - thread.start(); - return thread; - } - - private Request newRequest(String url) { - return new Request.Builder().url(url).build(); - } - - private Request newRequest(String url, String tag) { - return new Request.Builder().url(url).tag(tag).build(); - } -} diff --git a/okhttp/src/test/java/okhttp3/DispatcherTest.kt b/okhttp/src/test/java/okhttp3/DispatcherTest.kt new file mode 100644 index 000000000000..3144f874f0b8 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/DispatcherTest.kt @@ -0,0 +1,350 @@ +package okhttp3 + +import java.io.IOException +import java.io.InterruptedIOException +import java.net.UnknownHostException +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +@Tag("Slowish") +class DispatcherTest { + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + private val executor = RecordingExecutor(this) + val callback = RecordingCallback() + val webSocketListener = RecordingWebSocketListener() + val dispatcher = Dispatcher(executor) + val listener = RecordingEventListener() + var client = clientTestRule.newClientBuilder() + .dns { throw UnknownHostException() } + .dispatcher(dispatcher) + .eventListenerFactory(clientTestRule.wrap(listener)) + .build() + + @BeforeEach + fun setUp() { + dispatcher.maxRequests = 20 + dispatcher.maxRequestsPerHost = 10 + listener.forbidLock(dispatcher) + } + + @Test + fun maxRequestsZero() { + try { + dispatcher.maxRequests = 0 + fail() + } catch (expected: IllegalArgumentException) { + } + } + + @Test + fun maxPerHostZero() { + try { + dispatcher.maxRequestsPerHost = 0 + fail() + } catch (expected: IllegalArgumentException) { + } + } + + @Test + fun enqueuedJobsRunImmediately() { + client.newCall(newRequest("http://a/1")).enqueue(callback) + executor.assertJobs("http://a/1") + } + + @Test + fun maxRequestsEnforced() { + dispatcher.maxRequests = 3 + client.newCall(newRequest("http://a/1")).enqueue(callback) + client.newCall(newRequest("http://a/2")).enqueue(callback) + client.newCall(newRequest("http://b/1")).enqueue(callback) + client.newCall(newRequest("http://b/2")).enqueue(callback) + executor.assertJobs("http://a/1", "http://a/2", "http://b/1") + } + + @Test + fun maxPerHostEnforced() { + dispatcher.maxRequestsPerHost = 2 + client.newCall(newRequest("http://a/1")).enqueue(callback) + client.newCall(newRequest("http://a/2")).enqueue(callback) + client.newCall(newRequest("http://a/3")).enqueue(callback) + executor.assertJobs("http://a/1", "http://a/2") + } + + @Test + fun maxPerHostNotEnforcedForWebSockets() { + dispatcher.maxRequestsPerHost = 2 + client.newWebSocket(newRequest("http://a/1"), webSocketListener) + client.newWebSocket(newRequest("http://a/2"), webSocketListener) + client.newWebSocket(newRequest("http://a/3"), webSocketListener) + executor.assertJobs("http://a/1", "http://a/2", "http://a/3") + } + + @Test + fun increasingMaxRequestsPromotesJobsImmediately() { + dispatcher.maxRequests = 2 + client.newCall(newRequest("http://a/1")).enqueue(callback) + client.newCall(newRequest("http://b/1")).enqueue(callback) + client.newCall(newRequest("http://c/1")).enqueue(callback) + client.newCall(newRequest("http://a/2")).enqueue(callback) + client.newCall(newRequest("http://b/2")).enqueue(callback) + dispatcher.maxRequests = 4 + executor.assertJobs("http://a/1", "http://b/1", "http://c/1", "http://a/2") + } + + @Test + fun increasingMaxPerHostPromotesJobsImmediately() { + dispatcher.maxRequestsPerHost = 2 + client.newCall(newRequest("http://a/1")).enqueue(callback) + client.newCall(newRequest("http://a/2")).enqueue(callback) + client.newCall(newRequest("http://a/3")).enqueue(callback) + client.newCall(newRequest("http://a/4")).enqueue(callback) + client.newCall(newRequest("http://a/5")).enqueue(callback) + dispatcher.maxRequestsPerHost = 4 + executor.assertJobs("http://a/1", "http://a/2", "http://a/3", "http://a/4") + } + + @Test + fun oldJobFinishesNewJobCanRunDifferentHost() { + dispatcher.maxRequests = 1 + client.newCall(newRequest("http://a/1")).enqueue(callback) + client.newCall(newRequest("http://b/1")).enqueue(callback) + executor.finishJob("http://a/1") + executor.assertJobs("http://b/1") + } + + @Test + fun oldJobFinishesNewJobWithSameHostStarts() { + dispatcher.maxRequests = 2 + dispatcher.maxRequestsPerHost = 1 + client.newCall(newRequest("http://a/1")).enqueue(callback) + client.newCall(newRequest("http://b/1")).enqueue(callback) + client.newCall(newRequest("http://b/2")).enqueue(callback) + client.newCall(newRequest("http://a/2")).enqueue(callback) + executor.finishJob("http://a/1") + executor.assertJobs("http://b/1", "http://a/2") + } + + @Test + fun oldJobFinishesNewJobCantRunDueToHostLimit() { + dispatcher.maxRequestsPerHost = 1 + client.newCall(newRequest("http://a/1")).enqueue(callback) + client.newCall(newRequest("http://b/1")).enqueue(callback) + client.newCall(newRequest("http://a/2")).enqueue(callback) + executor.finishJob("http://b/1") + executor.assertJobs("http://a/1") + } + + @Test + fun enqueuedCallsStillRespectMaxCallsPerHost() { + dispatcher.maxRequests = 1 + dispatcher.maxRequestsPerHost = 1 + client.newCall(newRequest("http://a/1")).enqueue(callback) + client.newCall(newRequest("http://b/1")).enqueue(callback) + client.newCall(newRequest("http://b/2")).enqueue(callback) + client.newCall(newRequest("http://b/3")).enqueue(callback) + dispatcher.maxRequests = 3 + executor.finishJob("http://a/1") + executor.assertJobs("http://b/1") + } + + @Test + fun cancelingRunningJobTakesNoEffectUntilJobFinishes() { + dispatcher.maxRequests = 1 + val c1 = client.newCall(newRequest("http://a/1", "tag1")) + val c2 = client.newCall(newRequest("http://a/2")) + c1.enqueue(callback) + c2.enqueue(callback) + c1.cancel() + executor.assertJobs("http://a/1") + executor.finishJob("http://a/1") + executor.assertJobs("http://a/2") + } + + @Test + fun asyncCallAccessors() { + dispatcher.maxRequests = 3 + val a1 = client.newCall(newRequest("http://a/1")) + val a2 = client.newCall(newRequest("http://a/2")) + val a3 = client.newCall(newRequest("http://a/3")) + val a4 = client.newCall(newRequest("http://a/4")) + val a5 = client.newCall(newRequest("http://a/5")) + a1.enqueue(callback) + a2.enqueue(callback) + a3.enqueue(callback) + a4.enqueue(callback) + a5.enqueue(callback) + assertThat(dispatcher.runningCallsCount()).isEqualTo(3) + assertThat(dispatcher.queuedCallsCount()).isEqualTo(2) + assertThat(dispatcher.runningCalls()) + .containsExactlyInAnyOrder(a1, a2, a3) + assertThat(dispatcher.queuedCalls()) + .containsExactlyInAnyOrder(a4, a5) + } + + @Test + fun synchronousCallAccessors() { + val ready = CountDownLatch(2) + val waiting = CountDownLatch(1) + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain? -> + try { + ready.countDown() + waiting.await() + } catch (e: InterruptedException) { + throw AssertionError() + } + throw IOException() + }) + .build() + val a1 = client.newCall(newRequest("http://a/1")) + val a2 = client.newCall(newRequest("http://a/2")) + val a3 = client.newCall(newRequest("http://a/3")) + val a4 = client.newCall(newRequest("http://a/4")) + val t1 = makeSynchronousCall(a1) + val t2 = makeSynchronousCall(a2) + + // We created 4 calls and started 2 of them. That's 2 running calls and 0 queued. + ready.await() + assertThat(dispatcher.runningCallsCount()).isEqualTo(2) + assertThat(dispatcher.queuedCallsCount()).isEqualTo(0) + assertThat(dispatcher.runningCalls()) + .containsExactlyInAnyOrder(a1, a2) + assertThat(dispatcher.queuedCalls()).isEmpty() + + // Cancel some calls. That doesn't impact running or queued. + a2.cancel() + a3.cancel() + assertThat(dispatcher.runningCalls()) + .containsExactlyInAnyOrder(a1, a2) + assertThat(dispatcher.queuedCalls()).isEmpty() + + // Let the calls finish. + waiting.countDown() + t1.join() + t2.join() + + // Now we should have 0 running calls and 0 queued calls. + assertThat(dispatcher.runningCallsCount()).isEqualTo(0) + assertThat(dispatcher.queuedCallsCount()).isEqualTo(0) + assertThat(dispatcher.runningCalls()).isEmpty() + assertThat(dispatcher.queuedCalls()).isEmpty() + assertThat(a1.isExecuted()).isTrue() + assertThat(a1.isCanceled()).isFalse() + assertThat(a2.isExecuted()).isTrue() + assertThat(a2.isCanceled()).isTrue() + assertThat(a3.isExecuted()).isFalse() + assertThat(a3.isCanceled()).isTrue() + assertThat(a4.isExecuted()).isFalse() + assertThat(a4.isCanceled()).isFalse() + } + + @Test + fun idleCallbackInvokedWhenIdle() { + val idle = AtomicBoolean() + dispatcher.idleCallback = Runnable { idle.set(true) } + client.newCall(newRequest("http://a/1")).enqueue(callback) + client.newCall(newRequest("http://a/2")).enqueue(callback) + executor.finishJob("http://a/1") + assertThat(idle.get()).isFalse() + val ready = CountDownLatch(1) + val proceed = CountDownLatch(1) + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain -> + ready.countDown() + try { + proceed.await(5, TimeUnit.SECONDS) + } catch (e: InterruptedException) { + throw RuntimeException(e) + } + chain.proceed(chain.request()) + }) + .build() + val t1 = makeSynchronousCall(client.newCall(newRequest("http://a/3"))) + ready.await(5, TimeUnit.SECONDS) + executor.finishJob("http://a/2") + assertThat(idle.get()).isFalse() + proceed.countDown() + t1.join() + assertThat(idle.get()).isTrue() + } + + @Test + fun executionRejectedImmediately() { + val request = newRequest("http://a/1") + executor.shutdown() + client.newCall(request).enqueue(callback) + callback.await(request.url).assertFailure(InterruptedIOException::class.java) + assertThat(listener.recordedEventTypes()) + .containsExactly("CallStart", "CallFailed") + } + + @Test + fun executionRejectedAfterMaxRequestsChange() { + val request1 = newRequest("http://a/1") + val request2 = newRequest("http://a/2") + dispatcher.maxRequests = 1 + client.newCall(request1).enqueue(callback) + executor.shutdown() + client.newCall(request2).enqueue(callback) + dispatcher.maxRequests = 2 // Trigger promotion. + callback.await(request2.url).assertFailure(InterruptedIOException::class.java) + assertThat(listener.recordedEventTypes()) + .containsExactly("CallStart", "CallStart", "CallFailed") + } + + @Test + fun executionRejectedAfterMaxRequestsPerHostChange() { + val request1 = newRequest("http://a/1") + val request2 = newRequest("http://a/2") + dispatcher.maxRequestsPerHost = 1 + client.newCall(request1).enqueue(callback) + executor.shutdown() + client.newCall(request2).enqueue(callback) + dispatcher.maxRequestsPerHost = 2 // Trigger promotion. + callback.await(request2.url).assertFailure(InterruptedIOException::class.java) + assertThat(listener.recordedEventTypes()) + .containsExactly("CallStart", "CallStart", "CallFailed") + } + + @Test + fun executionRejectedAfterPrecedingCallFinishes() { + val request1 = newRequest("http://a/1") + val request2 = newRequest("http://a/2") + dispatcher.maxRequests = 1 + client.newCall(request1).enqueue(callback) + executor.shutdown() + client.newCall(request2).enqueue(callback) + executor.finishJob("http://a/1") // Trigger promotion. + callback.await(request2.url).assertFailure(InterruptedIOException::class.java) + assertThat(listener.recordedEventTypes()) + .containsExactly("CallStart", "CallStart", "CallFailed") + } + + private fun makeSynchronousCall(call: Call): Thread { + val thread = Thread { + try { + call.execute() + throw AssertionError() + } catch (expected: IOException) { + } + } + thread.start() + return thread + } + + private fun newRequest(url: String): Request { + return Request.Builder().url(url).build() + } + + private fun newRequest(url: String, tag: String): Request { + return Request.Builder().url(url).tag(tag).build() + } +} diff --git a/okhttp/src/test/java/okhttp3/DuplexTest.java b/okhttp/src/test/java/okhttp3/DuplexTest.java deleted file mode 100644 index 29ad34f89b37..000000000000 --- a/okhttp/src/test/java/okhttp3/DuplexTest.java +++ /dev/null @@ -1,751 +0,0 @@ -/* - * Copyright (C) 2018 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3; - -import java.io.IOException; -import java.net.HttpURLConnection; -import java.net.ProtocolException; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executors; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import mockwebserver3.MockResponse; -import mockwebserver3.MockWebServer; -import mockwebserver3.internal.duplex.MockStreamHandler; -import okhttp3.internal.RecordingOkAuthenticator; -import okhttp3.internal.duplex.AsyncRequestBody; -import okhttp3.testing.PlatformRule; -import okhttp3.tls.HandshakeCertificates; -import okio.BufferedSink; -import okio.BufferedSource; -import org.jetbrains.annotations.Nullable; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.jupiter.api.extension.RegisterExtension; -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; - -@Timeout(30) -@Tag("Slowish") -public final class DuplexTest { - @RegisterExtension public final PlatformRule platform = new PlatformRule(); - @RegisterExtension public OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - - private MockWebServer server; - private RecordingEventListener listener = new RecordingEventListener(); - private final HandshakeCertificates handshakeCertificates - = platform.localhostHandshakeCertificates(); - private OkHttpClient client = clientTestRule.newClientBuilder() - .eventListenerFactory(clientTestRule.wrap(listener)) - .build(); - - private final ScheduledExecutorService executorService = Executors.newScheduledThreadPool(1); - - @BeforeEach public void setUp(MockWebServer server) { - this.server = server; - platform.assumeNotOpenJSSE(); - platform.assumeHttp2Support(); - } - - @AfterEach public void tearDown() { - executorService.shutdown(); - } - - @Test public void http1DoesntSupportDuplex() throws IOException { - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(new AsyncRequestBody()) - .build()); - try { - call.execute(); - fail(); - } catch (ProtocolException expected) { - } - } - - @Test public void trueDuplexClientWritesFirst() throws Exception { - enableProtocol(Protocol.HTTP_2); - MockStreamHandler body = new MockStreamHandler() - .receiveRequest("request A\n") - .sendResponse("response B\n") - .receiveRequest("request C\n") - .sendResponse("response D\n") - .receiveRequest("request E\n") - .sendResponse("response F\n") - .exhaustRequest() - .exhaustResponse(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .streamHandler(body) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(new AsyncRequestBody()) - .build()); - - try (Response response = call.execute()) { - BufferedSink requestBody = ((AsyncRequestBody) call.request().body()).takeSink(); - requestBody.writeUtf8("request A\n"); - requestBody.flush(); - - BufferedSource responseBody = response.body().source(); - assertThat(responseBody.readUtf8Line()).isEqualTo("response B"); - - requestBody.writeUtf8("request C\n"); - requestBody.flush(); - assertThat(responseBody.readUtf8Line()).isEqualTo("response D"); - - requestBody.writeUtf8("request E\n"); - requestBody.flush(); - assertThat(responseBody.readUtf8Line()).isEqualTo("response F"); - - requestBody.close(); - assertThat(responseBody.readUtf8Line()).isNull(); - } - - body.awaitSuccess(); - } - - @Test public void trueDuplexServerWritesFirst() throws Exception { - enableProtocol(Protocol.HTTP_2); - MockStreamHandler body = new MockStreamHandler() - .sendResponse("response A\n") - .receiveRequest("request B\n") - .sendResponse("response C\n") - .receiveRequest("request D\n") - .sendResponse("response E\n") - .receiveRequest("request F\n") - .exhaustResponse() - .exhaustRequest(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .streamHandler(body) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(new AsyncRequestBody()) - .build()); - - try (Response response = call.execute()) { - BufferedSink requestBody = ((AsyncRequestBody) call.request().body()).takeSink(); - BufferedSource responseBody = response.body().source(); - - assertThat(responseBody.readUtf8Line()).isEqualTo("response A"); - requestBody.writeUtf8("request B\n"); - requestBody.flush(); - - assertThat(responseBody.readUtf8Line()).isEqualTo("response C"); - requestBody.writeUtf8("request D\n"); - requestBody.flush(); - - assertThat(responseBody.readUtf8Line()).isEqualTo("response E"); - requestBody.writeUtf8("request F\n"); - requestBody.flush(); - - assertThat(responseBody.readUtf8Line()).isNull(); - requestBody.close(); - } - - body.awaitSuccess(); - } - - @Test public void clientReadsHeadersDataTrailers() throws Exception { - enableProtocol(Protocol.HTTP_2); - MockStreamHandler body = new MockStreamHandler() - .sendResponse("ok") - .exhaustResponse(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .addHeader("h1", "v1") - .addHeader("h2", "v2") - .trailers(Headers.of("trailers", "boom")) - .streamHandler(body) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - - try (Response response = call.execute()) { - assertThat(response.headers()).isEqualTo(Headers.of("h1", "v1", "h2", "v2")); - - BufferedSource responseBody = response.body().source(); - assertThat(responseBody.readUtf8(2)).isEqualTo("ok"); - assertTrue(responseBody.exhausted()); - assertThat(response.trailers()).isEqualTo(Headers.of("trailers", "boom")); - } - - body.awaitSuccess(); - } - - @Test public void serverReadsHeadersData() throws Exception { - TestUtil.assumeNotWindows(); - - enableProtocol(Protocol.HTTP_2); - MockStreamHandler body = new MockStreamHandler() - .exhaustResponse() - .receiveRequest("hey\n") - .receiveRequest("whats going on\n") - .exhaustRequest(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .addHeader("h1", "v1") - .addHeader("h2", "v2") - .streamHandler(body) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .method("POST", new AsyncRequestBody()) - .build(); - Call call = client.newCall(request); - - try (Response response = call.execute()) { - BufferedSink sink = ((AsyncRequestBody) request.body()).takeSink(); - sink.writeUtf8("hey\n"); - sink.writeUtf8("whats going on\n"); - sink.close(); - } - - body.awaitSuccess(); - } - - @Test public void requestBodyEndsAfterResponseBody() throws Exception { - enableProtocol(Protocol.HTTP_2); - MockStreamHandler body = new MockStreamHandler() - .exhaustResponse() - .receiveRequest("request A\n") - .exhaustRequest(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .streamHandler(body) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(new AsyncRequestBody()) - .build()); - - try (Response response = call.execute()) { - BufferedSource responseBody = response.body().source(); - assertTrue(responseBody.exhausted()); - - BufferedSink requestBody = ((AsyncRequestBody) call.request().body()).takeSink(); - requestBody.writeUtf8("request A\n"); - requestBody.close(); - } - - body.awaitSuccess(); - - assertThat(listener.recordedEventTypes()).containsExactly( - "CallStart", "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", "ConnectStart", - "SecureConnectStart", "SecureConnectEnd", "ConnectEnd", "ConnectionAcquired", - "RequestHeadersStart", "RequestHeadersEnd", "RequestBodyStart", "ResponseHeadersStart", - "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "RequestBodyEnd", - "ConnectionReleased", "CallEnd"); - } - - @Test public void duplexWith100Continue() throws Exception { - enableProtocol(Protocol.HTTP_2); - - MockStreamHandler body = new MockStreamHandler() - .receiveRequest("request body\n") - .sendResponse("response body\n") - .exhaustRequest(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .add100Continue() - .streamHandler(body) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .header("Expect", "100-continue") - .post(new AsyncRequestBody()) - .build()); - - try (Response response = call.execute()) { - BufferedSink requestBody = ((AsyncRequestBody) call.request().body()).takeSink(); - requestBody.writeUtf8("request body\n"); - requestBody.flush(); - - BufferedSource responseBody = response.body().source(); - assertThat(responseBody.readUtf8Line()).isEqualTo("response body"); - - requestBody.close(); - assertThat(responseBody.readUtf8Line()).isNull(); - } - - body.awaitSuccess(); - } - - /** - * Duplex calls that have follow-ups are weird. By the time we know there's a follow-up we've - * already split off another thread to stream the request body. Because we permit at most one - * exchange at a time we break the request stream out from under that writer. - */ - @Test public void duplexWithRedirect() throws Exception { - enableProtocol(Protocol.HTTP_2); - - CountDownLatch duplexResponseSent = new CountDownLatch(1); - listener = new RecordingEventListener() { - @Override public void responseHeadersEnd(Call call, Response response) { - try { - // Wait for the server to send the duplex response before acting on the 301 response - // and resetting the stream. - duplexResponseSent.await(); - } catch (InterruptedException e) { - throw new AssertionError(); - } - super.responseHeadersEnd(call, response); - } - }; - - client = client.newBuilder() - .eventListener(listener) - .build(); - - MockStreamHandler body = new MockStreamHandler() - .sendResponse("/a has moved!\n", duplexResponseSent) - .requestIOException() - .exhaustResponse(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .code(HttpURLConnection.HTTP_MOVED_PERM) - .addHeader("Location: /b") - .streamHandler(body) - .build()); - server.enqueue(new MockResponse.Builder() - .body("this is /b") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(new AsyncRequestBody()) - .build()); - - try (Response response = call.execute()) { - BufferedSource responseBody = response.body().source(); - assertThat(responseBody.readUtf8Line()).isEqualTo("this is /b"); - } - - BufferedSink requestBody = ((AsyncRequestBody) call.request().body()).takeSink(); - try { - requestBody.writeUtf8("request body\n"); - requestBody.flush(); - fail(); - } catch (IOException expected) { - assertThat(expected.getMessage()).isEqualTo("stream was reset: CANCEL"); - } - - body.awaitSuccess(); - - assertThat(listener.recordedEventTypes()).containsExactly( - "CallStart", "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", "ConnectStart", - "SecureConnectStart", "SecureConnectEnd", "ConnectEnd", "ConnectionAcquired", - "RequestHeadersStart", "RequestHeadersEnd", "RequestBodyStart", "ResponseHeadersStart", - "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "ConnectionReleased", "CallEnd", "RequestFailed"); - } - - /** - * Auth requires follow-ups. Unlike redirects, the auth follow-up also has a request body. This - * test makes a single call with two duplex requests! - */ - @Test public void duplexWithAuthChallenge() throws Exception { - enableProtocol(Protocol.HTTP_2); - - String credential = Credentials.basic("jesse", "secret"); - client = client.newBuilder() - .authenticator(new RecordingOkAuthenticator(credential, null)) - .build(); - - MockStreamHandler body1 = new MockStreamHandler() - .sendResponse("please authenticate!\n") - .requestIOException() - .exhaustResponse(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .code(HttpURLConnection.HTTP_UNAUTHORIZED) - .streamHandler(body1) - .build()); - MockStreamHandler body = new MockStreamHandler() - .sendResponse("response body\n") - .exhaustResponse() - .receiveRequest("request body\n") - .exhaustRequest(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .streamHandler(body) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(new AsyncRequestBody()) - .build()); - - Response response2 = call.execute(); - - // First duplex request is detached with violence. - BufferedSink requestBody1 = ((AsyncRequestBody) call.request().body()).takeSink(); - try { - requestBody1.writeUtf8("not authenticated\n"); - requestBody1.flush(); - fail(); - } catch (IOException expected) { - assertThat(expected.getMessage()).isEqualTo("stream was reset: CANCEL"); - } - body1.awaitSuccess(); - - // Second duplex request proceeds normally. - BufferedSink requestBody2 = ((AsyncRequestBody) call.request().body()).takeSink(); - requestBody2.writeUtf8("request body\n"); - requestBody2.close(); - BufferedSource responseBody2 = response2.body().source(); - assertThat(responseBody2.readUtf8Line()).isEqualTo("response body"); - assertTrue(responseBody2.exhausted()); - body.awaitSuccess(); - - // No more requests attempted! - ((AsyncRequestBody) call.request().body()).assertNoMoreSinks(); - } - - @Test public void fullCallTimeoutAppliesToSetup() throws Exception { - enableProtocol(Protocol.HTTP_2); - - server.enqueue(new MockResponse.Builder() - .headersDelay(500, TimeUnit.MILLISECONDS) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .post(new AsyncRequestBody()) - .build(); - - Call call = client.newCall(request); - call.timeout().timeout(250, TimeUnit.MILLISECONDS); - try { - call.execute(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("timeout"); - assertTrue(call.isCanceled()); - } - } - - @Test public void fullCallTimeoutDoesNotApplyOnceConnected() throws Exception { - enableProtocol(Protocol.HTTP_2); - - MockStreamHandler body = new MockStreamHandler() - .sendResponse("response A\n") - .sleep(750, TimeUnit.MILLISECONDS) - .sendResponse("response B\n") - .receiveRequest("request C\n") - .exhaustResponse() - .exhaustRequest(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .streamHandler(body) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .post(new AsyncRequestBody()) - .build(); - - Call call = client.newCall(request); - call.timeout().timeout(500, TimeUnit.MILLISECONDS); // Long enough for the first TLS handshake. - - try (Response response = call.execute()) { - BufferedSink requestBody = ((AsyncRequestBody) call.request().body()).takeSink(); - - BufferedSource responseBody = response.body().source(); - assertThat(responseBody.readUtf8Line()).isEqualTo("response A"); - assertThat(responseBody.readUtf8Line()).isEqualTo("response B"); - - requestBody.writeUtf8("request C\n"); - requestBody.close(); - assertThat(responseBody.readUtf8Line()).isNull(); - } - - body.awaitSuccess(); - } - - @Test public void duplexWithRewriteInterceptors() throws Exception { - enableProtocol(Protocol.HTTP_2); - MockStreamHandler body = new MockStreamHandler() - .receiveRequest("REQUEST A\n") - .sendResponse("response B\n") - .exhaustRequest() - .exhaustResponse(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .streamHandler(body) - .build()); - - client = client.newBuilder() - .addInterceptor(new UppercaseRequestInterceptor()) - .addInterceptor(new UppercaseResponseInterceptor()) - .build(); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(new AsyncRequestBody()) - .build()); - - try (Response response = call.execute()) { - BufferedSink requestBody = ((AsyncRequestBody) call.request().body()).takeSink(); - requestBody.writeUtf8("request A\n"); - requestBody.flush(); - - BufferedSource responseBody = response.body().source(); - assertThat(responseBody.readUtf8Line()).isEqualTo("RESPONSE B"); - - requestBody.close(); - assertThat(responseBody.readUtf8Line()).isNull(); - } - - body.awaitSuccess(); - } - - /** - * OkHttp currently doesn't implement failing the request body stream independently of failing the - * corresponding response body stream. This is necessary if we want servers to be able to stop - * inbound data and send an early 400 before the request body completes. - * - * This test sends a slow request that is canceled by the server. It expects the response to still - * be readable after the request stream is canceled. - */ - @Disabled - @Test public void serverCancelsRequestBodyAndSendsResponseBody() throws Exception { - client = client.newBuilder() - .retryOnConnectionFailure(false) - .build(); - - BlockingQueue log = new LinkedBlockingQueue<>(); - - enableProtocol(Protocol.HTTP_2); - MockStreamHandler body = new MockStreamHandler() - .sendResponse("success!") - .exhaustResponse() - .cancelStream(); - server.enqueue(new MockResponse.Builder() - .clearHeaders() - .streamHandler(body) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(new RequestBody() { - @Override public @Nullable MediaType contentType() { - return null; - } - - @Override public void writeTo(BufferedSink sink) throws IOException { - try { - for (int i = 0; i < 10; i++) { - sink.writeUtf8("."); - sink.flush(); - Thread.sleep(100); - } - } catch (IOException e) { - log.add(e.toString()); - throw e; - } catch (Exception e) { - log.add(e.toString()); - } - } - }) - .build()); - - try (Response response = call.execute()) { - assertThat(response.body().string()).isEqualTo("success!"); - } - - body.awaitSuccess(); - - assertThat(log.take()).contains("StreamResetException: stream was reset: CANCEL"); - } - - /** - * We delay sending the last byte of the request body 1500 ms. The 1000 ms read timeout should - * only elapse 1000 ms after the request body is sent. - */ - @Test public void headersReadTimeoutDoesNotStartUntilLastRequestBodyByteFire() { - enableProtocol(Protocol.HTTP_2); - - server.enqueue(new MockResponse.Builder() - .headersDelay(1500, TimeUnit.MILLISECONDS) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .post(new DelayedRequestBody(RequestBody.create("hello", null), 1500, TimeUnit.MILLISECONDS)) - .build(); - - client = client.newBuilder() - .readTimeout(1000, TimeUnit.MILLISECONDS) - .build(); - - Call call = client.newCall(request); - try { - call.execute(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("timeout"); - } - } - - /** Same as the previous test, but the server stalls sending the response body. */ - @Test public void bodyReadTimeoutDoesNotStartUntilLastRequestBodyByteFire() throws Exception { - enableProtocol(Protocol.HTTP_2); - - server.enqueue(new MockResponse.Builder() - .bodyDelay(1500, TimeUnit.MILLISECONDS) - .body("this should never be received") - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .post(new DelayedRequestBody(RequestBody.create("hello", null), 1500, TimeUnit.MILLISECONDS)) - .build(); - - client = client.newBuilder() - .readTimeout(1000, TimeUnit.MILLISECONDS) - .build(); - - Call call = client.newCall(request); - Response response = call.execute(); - try { - response.body().string(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("timeout"); - } - } - - /** - * We delay sending the last byte of the request body 1500 ms. The 1000 ms read timeout shouldn't - * elapse because it shouldn't start until the request body is sent. - */ - @Test public void headersReadTimeoutDoesNotStartUntilLastRequestBodyByteNoFire() throws Exception { - enableProtocol(Protocol.HTTP_2); - - server.enqueue(new MockResponse.Builder() - .headersDelay(500, TimeUnit.MILLISECONDS) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .post(new DelayedRequestBody(RequestBody.create("hello", null), 1500, TimeUnit.MILLISECONDS)) - .build(); - - client = client.newBuilder() - .readTimeout(1000, TimeUnit.MILLISECONDS) - .build(); - - Call call = client.newCall(request); - Response response = call.execute(); - assertThat(response.isSuccessful()).isTrue(); - } - - /** - * We delay sending the last byte of the request body 1500 ms. The 1000 ms read timeout shouldn't - * elapse because it shouldn't start until the request body is sent. - */ - @Test public void bodyReadTimeoutDoesNotStartUntilLastRequestBodyByteNoFire() throws Exception { - enableProtocol(Protocol.HTTP_2); - - server.enqueue(new MockResponse.Builder() - .bodyDelay(500, TimeUnit.MILLISECONDS) - .body("success") - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .post(new DelayedRequestBody(RequestBody.create("hello", null), 1500, TimeUnit.MILLISECONDS)) - .build(); - - client = client.newBuilder() - .readTimeout(1000, TimeUnit.MILLISECONDS) - .build(); - - Call call = client.newCall(request); - Response response = call.execute(); - assertThat(response.body().string()).isEqualTo("success"); - } - - /** - * Tests that use this will fail unless boot classpath is set. Ex. {@code - * -Xbootclasspath/p:/tmp/alpn-boot-8.0.0.v20140317} - */ - private void enableProtocol(Protocol protocol) { - enableTls(); - client = client.newBuilder() - .protocols(asList(protocol, Protocol.HTTP_1_1)) - .build(); - server.setProtocols(client.protocols()); - } - - private void enableTls() { - client = client.newBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .build(); - server.useHttps(handshakeCertificates.sslSocketFactory()); - } - - private class DelayedRequestBody extends RequestBody { - private final RequestBody delegate; - private final long delayMillis; - - public DelayedRequestBody(RequestBody delegate, long delay, TimeUnit timeUnit) { - this.delegate = delegate; - this.delayMillis = timeUnit.toMillis(delay); - } - - @Override public MediaType contentType() { - return delegate.contentType(); - } - - @Override public boolean isDuplex() { - return true; - } - - @Override public void writeTo(BufferedSink sink) throws IOException { - executorService.schedule(() -> { - try { - delegate.writeTo(sink); - sink.close(); - } catch (IOException e) { - throw new RuntimeException(e); - } - }, delayMillis, TimeUnit.MILLISECONDS); - } - } -} diff --git a/okhttp/src/test/java/okhttp3/DuplexTest.kt b/okhttp/src/test/java/okhttp3/DuplexTest.kt new file mode 100644 index 000000000000..924b07c57670 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/DuplexTest.kt @@ -0,0 +1,761 @@ +/* + * Copyright (C) 2018 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.io.IOException +import java.net.HttpURLConnection +import java.net.ProtocolException +import java.util.concurrent.BlockingQueue +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Executors +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.internal.duplex.MockStreamHandler +import okhttp3.Credentials.basic +import okhttp3.Headers.Companion.headersOf +import okhttp3.RequestBody.Companion.toRequestBody +import okhttp3.TestUtil.assumeNotWindows +import okhttp3.internal.RecordingOkAuthenticator +import okhttp3.internal.duplex.AsyncRequestBody +import okhttp3.testing.PlatformRule +import okio.BufferedSink +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.extension.RegisterExtension + +@Timeout(30) +@Tag("Slowish") +class DuplexTest { + @RegisterExtension + val platform = PlatformRule() + + @RegisterExtension + var clientTestRule = OkHttpClientTestRule() + private lateinit var server: MockWebServer + private var listener = RecordingEventListener() + private val handshakeCertificates = platform.localhostHandshakeCertificates() + private var client = clientTestRule.newClientBuilder() + .eventListenerFactory(clientTestRule.wrap(listener)) + .build() + private val executorService = Executors.newScheduledThreadPool(1) + + @BeforeEach + fun setUp(server: MockWebServer) { + this.server = server + platform.assumeNotOpenJSSE() + platform.assumeHttp2Support() + } + + @AfterEach + fun tearDown() { + executorService.shutdown() + } + + @Test + @Throws(IOException::class) + fun http1DoesntSupportDuplex() { + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(AsyncRequestBody()) + .build() + ) + try { + call.execute() + fail() + } catch (expected: ProtocolException) { + } + } + + @Test + fun trueDuplexClientWritesFirst() { + enableProtocol(Protocol.HTTP_2) + val body = MockStreamHandler() + .receiveRequest("request A\n") + .sendResponse("response B\n") + .receiveRequest("request C\n") + .sendResponse("response D\n") + .receiveRequest("request E\n") + .sendResponse("response F\n") + .exhaustRequest() + .exhaustResponse() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .streamHandler(body) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(AsyncRequestBody()) + .build() + ) + call.execute().use { response -> + val requestBody = (call.request().body as AsyncRequestBody?)!!.takeSink() + requestBody.writeUtf8("request A\n") + requestBody.flush() + val responseBody = response.body.source() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("response B") + requestBody.writeUtf8("request C\n") + requestBody.flush() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("response D") + requestBody.writeUtf8("request E\n") + requestBody.flush() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("response F") + requestBody.close() + assertThat(responseBody.readUtf8Line()).isNull() + } + body.awaitSuccess() + } + + @Test + fun trueDuplexServerWritesFirst() { + enableProtocol(Protocol.HTTP_2) + val body = MockStreamHandler() + .sendResponse("response A\n") + .receiveRequest("request B\n") + .sendResponse("response C\n") + .receiveRequest("request D\n") + .sendResponse("response E\n") + .receiveRequest("request F\n") + .exhaustResponse() + .exhaustRequest() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .streamHandler(body) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(AsyncRequestBody()) + .build() + ) + call.execute().use { response -> + val requestBody = (call.request().body as AsyncRequestBody?)!!.takeSink() + val responseBody = response.body.source() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("response A") + requestBody.writeUtf8("request B\n") + requestBody.flush() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("response C") + requestBody.writeUtf8("request D\n") + requestBody.flush() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("response E") + requestBody.writeUtf8("request F\n") + requestBody.flush() + assertThat(responseBody.readUtf8Line()).isNull() + requestBody.close() + } + body.awaitSuccess() + } + + @Test + fun clientReadsHeadersDataTrailers() { + enableProtocol(Protocol.HTTP_2) + val body = MockStreamHandler() + .sendResponse("ok") + .exhaustResponse() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .addHeader("h1", "v1") + .addHeader("h2", "v2") + .trailers(headersOf("trailers", "boom")) + .streamHandler(body) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + call.execute().use { response -> + assertThat?>(response.headers) + .isEqualTo( + headersOf("h1", "v1", "h2", "v2") + ) + val responseBody = response.body.source() + assertThat(responseBody.readUtf8(2)).isEqualTo("ok") + Assertions.assertTrue(responseBody.exhausted()) + assertThat?>(response.trailers()) + .isEqualTo( + headersOf("trailers", "boom") + ) + } + body.awaitSuccess() + } + + @Test + fun serverReadsHeadersData() { + assumeNotWindows() + enableProtocol(Protocol.HTTP_2) + val body = MockStreamHandler() + .exhaustResponse() + .receiveRequest("hey\n") + .receiveRequest("whats going on\n") + .exhaustRequest() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .addHeader("h1", "v1") + .addHeader("h2", "v2") + .streamHandler(body) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .method("POST", AsyncRequestBody()) + .build() + val call = client.newCall(request) + call.execute().use { response -> + val sink = (request.body as AsyncRequestBody?)!!.takeSink() + sink.writeUtf8("hey\n") + sink.writeUtf8("whats going on\n") + sink.close() + } + body.awaitSuccess() + } + + @Test + fun requestBodyEndsAfterResponseBody() { + enableProtocol(Protocol.HTTP_2) + val body = MockStreamHandler() + .exhaustResponse() + .receiveRequest("request A\n") + .exhaustRequest() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .streamHandler(body) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(AsyncRequestBody()) + .build() + ) + call.execute().use { response -> + val responseBody = response.body.source() + Assertions.assertTrue(responseBody.exhausted()) + val requestBody = (call.request().body as AsyncRequestBody?)!!.takeSink() + requestBody.writeUtf8("request A\n") + requestBody.close() + } + body.awaitSuccess() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", "ConnectStart", + "SecureConnectStart", "SecureConnectEnd", "ConnectEnd", "ConnectionAcquired", + "RequestHeadersStart", "RequestHeadersEnd", "RequestBodyStart", "ResponseHeadersStart", + "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "RequestBodyEnd", + "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun duplexWith100Continue() { + enableProtocol(Protocol.HTTP_2) + val body = MockStreamHandler() + .receiveRequest("request body\n") + .sendResponse("response body\n") + .exhaustRequest() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .add100Continue() + .streamHandler(body) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .header("Expect", "100-continue") + .post(AsyncRequestBody()) + .build() + ) + call.execute().use { response -> + val requestBody = (call.request().body as AsyncRequestBody?)!!.takeSink() + requestBody.writeUtf8("request body\n") + requestBody.flush() + val responseBody = response.body.source() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("response body") + requestBody.close() + assertThat(responseBody.readUtf8Line()).isNull() + } + body.awaitSuccess() + } + + /** + * Duplex calls that have follow-ups are weird. By the time we know there's a follow-up we've + * already split off another thread to stream the request body. Because we permit at most one + * exchange at a time we break the request stream out from under that writer. + */ + @Test + fun duplexWithRedirect() { + enableProtocol(Protocol.HTTP_2) + val duplexResponseSent = CountDownLatch(1) + listener = object : RecordingEventListener() { + override fun responseHeadersEnd(call: Call, response: Response) { + try { + // Wait for the server to send the duplex response before acting on the 301 response + // and resetting the stream. + duplexResponseSent.await() + } catch (e: InterruptedException) { + throw AssertionError() + } + super.responseHeadersEnd(call, response) + } + } + client = client.newBuilder() + .eventListener(listener) + .build() + val body = MockStreamHandler() + .sendResponse("/a has moved!\n", duplexResponseSent) + .requestIOException() + .exhaustResponse() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .code(HttpURLConnection.HTTP_MOVED_PERM) + .addHeader("Location: /b") + .streamHandler(body) + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("this is /b") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(AsyncRequestBody()) + .build() + ) + call.execute().use { response -> + val responseBody = response.body.source() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("this is /b") + } + val requestBody = (call.request().body as AsyncRequestBody?)!!.takeSink() + try { + requestBody.writeUtf8("request body\n") + requestBody.flush() + fail() + } catch (expected: IOException) { + assertThat(expected.message) + .isEqualTo("stream was reset: CANCEL") + } + body.awaitSuccess() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", "ConnectStart", + "SecureConnectStart", "SecureConnectEnd", "ConnectEnd", "ConnectionAcquired", + "RequestHeadersStart", "RequestHeadersEnd", "RequestBodyStart", "ResponseHeadersStart", + "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "ConnectionReleased", "CallEnd", "RequestFailed" + ) + } + + /** + * Auth requires follow-ups. Unlike redirects, the auth follow-up also has a request body. This + * test makes a single call with two duplex requests! + */ + @Test + fun duplexWithAuthChallenge() { + enableProtocol(Protocol.HTTP_2) + val credential = basic("jesse", "secret") + client = client.newBuilder() + .authenticator(RecordingOkAuthenticator(credential, null)) + .build() + val body1 = MockStreamHandler() + .sendResponse("please authenticate!\n") + .requestIOException() + .exhaustResponse() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .code(HttpURLConnection.HTTP_UNAUTHORIZED) + .streamHandler(body1) + .build() + ) + val body = MockStreamHandler() + .sendResponse("response body\n") + .exhaustResponse() + .receiveRequest("request body\n") + .exhaustRequest() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .streamHandler(body) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(AsyncRequestBody()) + .build() + ) + val response2 = call.execute() + + // First duplex request is detached with violence. + val requestBody1 = (call.request().body as AsyncRequestBody?)!!.takeSink() + try { + requestBody1.writeUtf8("not authenticated\n") + requestBody1.flush() + fail() + } catch (expected: IOException) { + assertThat(expected.message) + .isEqualTo("stream was reset: CANCEL") + } + body1.awaitSuccess() + + // Second duplex request proceeds normally. + val requestBody2 = (call.request().body as AsyncRequestBody?)!!.takeSink() + requestBody2.writeUtf8("request body\n") + requestBody2.close() + val responseBody2 = response2.body.source() + assertThat(responseBody2.readUtf8Line()) + .isEqualTo("response body") + Assertions.assertTrue(responseBody2.exhausted()) + body.awaitSuccess() + + // No more requests attempted! + (call.request().body as AsyncRequestBody?)!!.assertNoMoreSinks() + } + + @Test + fun fullCallTimeoutAppliesToSetup() { + enableProtocol(Protocol.HTTP_2) + server.enqueue( + MockResponse.Builder() + .headersDelay(500, TimeUnit.MILLISECONDS) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .post(AsyncRequestBody()) + .build() + val call = client.newCall(request) + call.timeout().timeout(250, TimeUnit.MILLISECONDS) + try { + call.execute() + fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("timeout") + Assertions.assertTrue(call.isCanceled()) + } + } + + @Test + fun fullCallTimeoutDoesNotApplyOnceConnected() { + enableProtocol(Protocol.HTTP_2) + val body = MockStreamHandler() + .sendResponse("response A\n") + .sleep(750, TimeUnit.MILLISECONDS) + .sendResponse("response B\n") + .receiveRequest("request C\n") + .exhaustResponse() + .exhaustRequest() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .streamHandler(body) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .post(AsyncRequestBody()) + .build() + val call = client.newCall(request) + call.timeout() + .timeout(500, TimeUnit.MILLISECONDS) // Long enough for the first TLS handshake. + call.execute().use { response -> + val requestBody = (call.request().body as AsyncRequestBody?)!!.takeSink() + val responseBody = response.body.source() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("response A") + assertThat(responseBody.readUtf8Line()) + .isEqualTo("response B") + requestBody.writeUtf8("request C\n") + requestBody.close() + assertThat(responseBody.readUtf8Line()).isNull() + } + body.awaitSuccess() + } + + @Test + fun duplexWithRewriteInterceptors() { + enableProtocol(Protocol.HTTP_2) + val body = MockStreamHandler() + .receiveRequest("REQUEST A\n") + .sendResponse("response B\n") + .exhaustRequest() + .exhaustResponse() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .streamHandler(body) + .build() + ) + client = client.newBuilder() + .addInterceptor(UppercaseRequestInterceptor()) + .addInterceptor(UppercaseResponseInterceptor()) + .build() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(AsyncRequestBody()) + .build() + ) + call.execute().use { response -> + val requestBody = (call.request().body as AsyncRequestBody?)!!.takeSink() + requestBody.writeUtf8("request A\n") + requestBody.flush() + val responseBody = response.body.source() + assertThat(responseBody.readUtf8Line()) + .isEqualTo("RESPONSE B") + requestBody.close() + assertThat(responseBody.readUtf8Line()).isNull() + } + body.awaitSuccess() + } + + /** + * OkHttp currently doesn't implement failing the request body stream independently of failing the + * corresponding response body stream. This is necessary if we want servers to be able to stop + * inbound data and send an early 400 before the request body completes. + * + * This test sends a slow request that is canceled by the server. It expects the response to still + * be readable after the request stream is canceled. + */ + @Disabled + @Test + fun serverCancelsRequestBodyAndSendsResponseBody() { + client = client.newBuilder() + .retryOnConnectionFailure(false) + .build() + val log: BlockingQueue = LinkedBlockingQueue() + enableProtocol(Protocol.HTTP_2) + val body = MockStreamHandler() + .sendResponse("success!") + .exhaustResponse() + .cancelStream() + server.enqueue( + MockResponse.Builder() + .clearHeaders() + .streamHandler(body) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(object : RequestBody() { + override fun contentType(): MediaType? { + return null + } + + override fun writeTo(sink: BufferedSink) { + try { + for (i in 0..9) { + sink.writeUtf8(".") + sink.flush() + Thread.sleep(100) + } + } catch (e: IOException) { + log.add(e.toString()) + throw e + } catch (e: Exception) { + log.add(e.toString()) + } + } + }) + .build() + ) + call.execute().use { response -> + assertThat(response.body.string()).isEqualTo("success!") + } + body.awaitSuccess() + assertThat(log.take()) + .contains("StreamResetException: stream was reset: CANCEL") + } + + /** + * We delay sending the last byte of the request body 1500 ms. The 1000 ms read timeout should + * only elapse 1000 ms after the request body is sent. + */ + @Test + fun headersReadTimeoutDoesNotStartUntilLastRequestBodyByteFire() { + enableProtocol(Protocol.HTTP_2) + server.enqueue( + MockResponse.Builder() + .headersDelay(1500, TimeUnit.MILLISECONDS) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .post(DelayedRequestBody("hello".toRequestBody(null), 1500, TimeUnit.MILLISECONDS)) + .build() + client = client.newBuilder() + .readTimeout(1000, TimeUnit.MILLISECONDS) + .build() + val call = client.newCall(request) + try { + call.execute() + fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("timeout") + } + } + + /** Same as the previous test, but the server stalls sending the response body. */ + @Test + fun bodyReadTimeoutDoesNotStartUntilLastRequestBodyByteFire() { + enableProtocol(Protocol.HTTP_2) + server.enqueue( + MockResponse.Builder() + .bodyDelay(1500, TimeUnit.MILLISECONDS) + .body("this should never be received") + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .post(DelayedRequestBody("hello".toRequestBody(null), 1500, TimeUnit.MILLISECONDS)) + .build() + client = client.newBuilder() + .readTimeout(1000, TimeUnit.MILLISECONDS) + .build() + val call = client.newCall(request) + val response = call.execute() + try { + response.body.string() + fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("timeout") + } + } + + /** + * We delay sending the last byte of the request body 1500 ms. The 1000 ms read timeout shouldn't + * elapse because it shouldn't start until the request body is sent. + */ + @Test + fun headersReadTimeoutDoesNotStartUntilLastRequestBodyByteNoFire() { + enableProtocol(Protocol.HTTP_2) + server.enqueue( + MockResponse.Builder() + .headersDelay(500, TimeUnit.MILLISECONDS) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .post(DelayedRequestBody("hello".toRequestBody(null), 1500, TimeUnit.MILLISECONDS)) + .build() + client = client.newBuilder() + .readTimeout(1000, TimeUnit.MILLISECONDS) + .build() + val call = client.newCall(request) + val response = call.execute() + assertThat(response.isSuccessful).isTrue() + } + + /** + * We delay sending the last byte of the request body 1500 ms. The 1000 ms read timeout shouldn't + * elapse because it shouldn't start until the request body is sent. + */ + @Test + fun bodyReadTimeoutDoesNotStartUntilLastRequestBodyByteNoFire() { + enableProtocol(Protocol.HTTP_2) + server.enqueue( + MockResponse.Builder() + .bodyDelay(500, TimeUnit.MILLISECONDS) + .body("success") + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .post(DelayedRequestBody("hello".toRequestBody(null), 1500, TimeUnit.MILLISECONDS)) + .build() + client = client.newBuilder() + .readTimeout(1000, TimeUnit.MILLISECONDS) + .build() + val call = client.newCall(request) + val response = call.execute() + assertThat(response.body.string()).isEqualTo("success") + } + + /** + * Tests that use this will fail unless boot classpath is set. Ex. `-Xbootclasspath/p:/tmp/alpn-boot-8.0.0.v20140317` + */ + private fun enableProtocol(protocol: Protocol) { + enableTls() + client = client.newBuilder() + .protocols(listOf(protocol, Protocol.HTTP_1_1)) + .build() + server.protocols = client.protocols + } + + private fun enableTls() { + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .build() + server.useHttps(handshakeCertificates.sslSocketFactory()) + } + + private inner class DelayedRequestBody( + private val delegate: RequestBody, + delay: Long, + timeUnit: TimeUnit + ) : RequestBody() { + private val delayMillis = timeUnit.toMillis(delay) + + override fun contentType() = delegate.contentType() + + override fun isDuplex() = true + + override fun writeTo(sink: BufferedSink) { + executorService.schedule({ + try { + delegate.writeTo(sink) + sink.close() + } catch (e: IOException) { + throw RuntimeException(e) + } + }, delayMillis, TimeUnit.MILLISECONDS) + } + } +} diff --git a/okhttp/src/test/java/okhttp3/EventListenerTest.java b/okhttp/src/test/java/okhttp3/EventListenerTest.java deleted file mode 100644 index 51780554ce68..000000000000 --- a/okhttp/src/test/java/okhttp3/EventListenerTest.java +++ /dev/null @@ -1,1657 +0,0 @@ -/* - * Copyright (C) 2017 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3; - -import java.io.File; -import java.io.IOException; -import java.io.InterruptedIOException; -import java.net.HttpURLConnection; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.Proxy; -import java.net.UnknownHostException; -import java.time.Duration; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; -import mockwebserver3.MockResponse; -import mockwebserver3.MockWebServer; -import mockwebserver3.SocketPolicy; -import mockwebserver3.SocketPolicy.DisconnectDuringRequestBody; -import mockwebserver3.SocketPolicy.DisconnectDuringResponseBody; -import mockwebserver3.SocketPolicy.FailHandshake; -import okhttp3.CallEvent.CallEnd; -import okhttp3.CallEvent.CallFailed; -import okhttp3.CallEvent.CallStart; -import okhttp3.CallEvent.ConnectEnd; -import okhttp3.CallEvent.ConnectFailed; -import okhttp3.CallEvent.ConnectStart; -import okhttp3.CallEvent.ConnectionAcquired; -import okhttp3.CallEvent.ConnectionReleased; -import okhttp3.CallEvent.DnsEnd; -import okhttp3.CallEvent.DnsStart; -import okhttp3.CallEvent.RequestBodyEnd; -import okhttp3.CallEvent.RequestBodyStart; -import okhttp3.CallEvent.RequestHeadersEnd; -import okhttp3.CallEvent.RequestHeadersStart; -import okhttp3.CallEvent.ResponseBodyEnd; -import okhttp3.CallEvent.ResponseBodyStart; -import okhttp3.CallEvent.ResponseFailed; -import okhttp3.CallEvent.ResponseHeadersEnd; -import okhttp3.CallEvent.ResponseHeadersStart; -import okhttp3.CallEvent.SecureConnectEnd; -import okhttp3.CallEvent.SecureConnectStart; -import okhttp3.internal.DoubleInetAddressDns; -import okhttp3.internal.RecordingOkAuthenticator; -import okhttp3.internal.connection.RealConnectionPool; -import okhttp3.logging.HttpLoggingInterceptor; -import okhttp3.testing.Flaky; -import okhttp3.testing.PlatformRule; -import okhttp3.tls.HandshakeCertificates; -import okio.Buffer; -import okio.BufferedSink; -import org.hamcrest.BaseMatcher; -import org.hamcrest.CoreMatchers; -import org.hamcrest.Description; -import org.hamcrest.Matcher; -import org.hamcrest.MatcherAssert; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.jupiter.api.extension.RegisterExtension; -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.CoreMatchers.any; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.junit.Assume.assumeThat; -import static org.junit.jupiter.api.Assertions.fail; - -@Flaky // STDOUT logging enabled for test -@Timeout(30) -@Tag("Slow") -public final class EventListenerTest { - public static final Matcher anyResponse = CoreMatchers.any(Response.class); - - @RegisterExtension public final PlatformRule platform = new PlatformRule(); - @RegisterExtension public final OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - - private MockWebServer server; - private final RecordingEventListener listener = new RecordingEventListener(); - private final HandshakeCertificates handshakeCertificates - = platform.localhostHandshakeCertificates(); - - private OkHttpClient client = clientTestRule.newClientBuilder() - .eventListenerFactory(clientTestRule.wrap(listener)) - .build(); - private SocksProxy socksProxy; - private Cache cache = null; - - @BeforeEach public void setUp(MockWebServer server) { - this.server = server; - - platform.assumeNotOpenJSSE(); - - listener.forbidLock(RealConnectionPool.Companion.get(client.connectionPool())); - listener.forbidLock(client.dispatcher()); - } - - @AfterEach public void tearDown() throws Exception { - if (socksProxy != null) { - socksProxy.shutdown(); - } - if (cache != null) { - cache.delete(); - } - } - - @Test public void successfulCallEventSequence() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - assertThat(response.body().string()).isEqualTo("abc"); - response.body().close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void successfulCallEventSequenceForIpAddress() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - String ipAddress = InetAddress.getLoopbackAddress().getHostAddress(); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/").newBuilder().host(ipAddress).build()) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - assertThat(response.body().string()).isEqualTo("abc"); - response.body().close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void successfulCallEventSequenceForEnqueue() throws Exception { - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - - final CountDownLatch completionLatch = new CountDownLatch(1); - Callback callback = new Callback() { - @Override public void onFailure(Call call, IOException e) { - completionLatch.countDown(); - } - - @Override public void onResponse(Call call, Response response) { - response.close(); - completionLatch.countDown(); - } - }; - - call.enqueue(callback); - - completionLatch.await(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void failedCallEventSequence() { - server.enqueue(new MockResponse.Builder() - .headersDelay(2, TimeUnit.SECONDS) - .build()); - - client = client.newBuilder() - .readTimeout(Duration.ofMillis(250)) - .build(); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - try { - call.execute(); - fail(); - } catch (IOException expected) { - assertThat(expected.getMessage()).isIn("timeout", "Read timed out"); - } - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseFailed", "ConnectionReleased", "CallFailed"); - } - - @Test public void failedDribbledCallEventSequence() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("0123456789") - .throttleBody(2, 100, TimeUnit.MILLISECONDS) - .socketPolicy(DisconnectDuringResponseBody.INSTANCE) - .build()); - - client = client.newBuilder() - .protocols(Collections.singletonList(Protocol.HTTP_1_1)) - .readTimeout(Duration.ofMillis(250)) - .build(); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - - Response response = call.execute(); - try { - response.body().string(); - fail(); - } catch (IOException expected) { - assertThat(expected.getMessage()).isEqualTo("unexpected end of stream"); - } - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseFailed", "ConnectionReleased", "CallFailed"); - ResponseFailed responseFailed = listener.removeUpToEvent(ResponseFailed.class); - assertThat(responseFailed.getIoe().getMessage()).isEqualTo("unexpected end of stream"); - } - - @Test public void canceledCallEventSequence() { - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - call.cancel(); - try { - call.execute(); - fail(); - } catch (IOException expected) { - assertThat(expected.getMessage()).isEqualTo("Canceled"); - } - - assertThat(listener.recordedEventTypes()).containsExactly( - "Canceled", "CallStart", "CallFailed"); - } - - @Test public void cancelAsyncCall() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - call.enqueue(new Callback() { - @Override public void onFailure(Call call, IOException e) { - } - - @Override public void onResponse(Call call, Response response) throws IOException { - response.close(); - } - }); - call.cancel(); - - assertThat(listener.recordedEventTypes()).contains("Canceled"); - } - - @Test public void multipleCancelsEmitsOnlyOneEvent() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - call.cancel(); - call.cancel(); - - assertThat(listener.recordedEventTypes()).containsExactly("Canceled"); - } - - private void assertSuccessfulEventOrder(Matcher responseMatcher) throws IOException { - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().string(); - response.body().close(); - - assumeThat(response, responseMatcher); - - assertThat(listener.recordedEventTypes()).containsExactly( - "CallStart", "ProxySelectStart", "ProxySelectEnd", - "DnsStart", "DnsEnd", "ConnectStart", - "SecureConnectStart", "SecureConnectEnd", "ConnectEnd", "ConnectionAcquired", - "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", - "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void secondCallEventSequence() throws IOException { - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)); - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - client.newCall(new Request.Builder() - .url(server.url("/")) - .build()).execute().close(); - - listener.removeUpToEvent(CallEnd.class); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - response.close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", "ConnectionAcquired", - "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", - "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - private void assertBytesReadWritten(RecordingEventListener listener, - @Nullable Matcher requestHeaderLength, @Nullable Matcher requestBodyBytes, - @Nullable Matcher responseHeaderLength, @Nullable Matcher responseBodyBytes) { - - if (requestHeaderLength != null) { - RequestHeadersEnd responseHeadersEnd = listener.removeUpToEvent(RequestHeadersEnd.class); - MatcherAssert.assertThat("request header length", responseHeadersEnd.getHeaderLength(), - requestHeaderLength); - } else { - assertThat(listener.recordedEventTypes()).doesNotContain("RequestHeadersEnd"); - } - - if (requestBodyBytes != null) { - RequestBodyEnd responseBodyEnd = listener.removeUpToEvent(RequestBodyEnd.class); - MatcherAssert.assertThat("request body bytes", responseBodyEnd.getBytesWritten(), requestBodyBytes); - } else { - assertThat(listener.recordedEventTypes()).doesNotContain("RequestBodyEnd"); - } - - if (responseHeaderLength != null) { - ResponseHeadersEnd responseHeadersEnd = listener.removeUpToEvent(ResponseHeadersEnd.class); - MatcherAssert.assertThat("response header length", responseHeadersEnd.getHeaderLength(), - responseHeaderLength); - } else { - assertThat(listener.recordedEventTypes()).doesNotContain("ResponseHeadersEnd"); - } - - if (responseBodyBytes != null) { - ResponseBodyEnd responseBodyEnd = listener.removeUpToEvent(ResponseBodyEnd.class); - MatcherAssert.assertThat("response body bytes", responseBodyEnd.getBytesRead(), responseBodyBytes); - } else { - assertThat(listener.recordedEventTypes()).doesNotContain("ResponseBodyEnd"); - } - } - - private Matcher greaterThan(final long value) { - return new BaseMatcher() { - @Override public void describeTo(Description description) { - description.appendText("> " + value); - } - - @Override public boolean matches(Object o) { - return ((Long) o) > value; - } - }; - } - - private Matcher matchesProtocol(final Protocol protocol) { - return new BaseMatcher() { - @Override public void describeTo(Description description) { - description.appendText("is HTTP/2"); - } - - @Override public boolean matches(Object o) { - return ((Response) o).protocol() == protocol; - } - }; - } - - @Test public void successfulEmptyH2CallEventSequence() throws IOException { - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)); - server.enqueue(new MockResponse()); - - assertSuccessfulEventOrder(matchesProtocol(Protocol.HTTP_2)); - - assertBytesReadWritten(listener, any(Long.class), null, greaterThan(0L), - equalTo(0L)); - } - - @Test public void successfulEmptyHttpsCallEventSequence() throws IOException { - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_1_1)); - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - assertSuccessfulEventOrder(anyResponse); - - assertBytesReadWritten(listener, any(Long.class), null, greaterThan(0L), - equalTo(3L)); - } - - @Test public void successfulChunkedHttpsCallEventSequence() throws IOException { - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_1_1)); - server.enqueue(new MockResponse.Builder() - .bodyDelay(100, TimeUnit.MILLISECONDS) - .chunkedBody("Hello!", 2) - .build()); - - assertSuccessfulEventOrder(anyResponse); - - assertBytesReadWritten(listener, any(Long.class), null, greaterThan(0L), - equalTo(6L)); - } - - @Test public void successfulChunkedH2CallEventSequence() throws IOException { - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)); - server.enqueue(new MockResponse.Builder() - .bodyDelay(100, TimeUnit.MILLISECONDS) - .chunkedBody("Hello!", 2) - .build()); - - assertSuccessfulEventOrder(matchesProtocol(Protocol.HTTP_2)); - - assertBytesReadWritten(listener, any(Long.class), null, equalTo(0L), - greaterThan(6L)); - } - - @Test public void successfulDnsLookup() throws IOException { - server.enqueue(new MockResponse()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - DnsStart dnsStart = listener.removeUpToEvent(DnsStart.class); - assertThat(dnsStart.getCall()).isSameAs(call); - assertThat(dnsStart.getDomainName()).isEqualTo(server.getHostName()); - - DnsEnd dnsEnd = listener.removeUpToEvent(DnsEnd.class); - assertThat(dnsEnd.getCall()).isSameAs(call); - assertThat(dnsEnd.getDomainName()).isEqualTo(server.getHostName()); - assertThat(dnsEnd.getInetAddressList().size()).isEqualTo(1); - } - - @Test public void noDnsLookupOnPooledConnection() throws IOException { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - // Seed the pool. - Call call1 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response1 = call1.execute(); - assertThat(response1.code()).isEqualTo(200); - response1.body().close(); - - listener.clearAllEvents(); - - Call call2 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response2 = call2.execute(); - assertThat(response2.code()).isEqualTo(200); - response2.body().close(); - - List recordedEvents = listener.recordedEventTypes(); - assertThat(recordedEvents).doesNotContain("DnsStart"); - assertThat(recordedEvents).doesNotContain("DnsEnd"); - } - - @Test public void multipleDnsLookupsForSingleCall() throws IOException { - server.enqueue(new MockResponse.Builder() - .code(301) - .setHeader("Location", "http://www.fakeurl:" + server.getPort()) - .build()); - server.enqueue(new MockResponse()); - - FakeDns dns = new FakeDns(); - dns.set("fakeurl", client.dns().lookup(server.getHostName())); - dns.set("www.fakeurl", client.dns().lookup(server.getHostName())); - - client = client.newBuilder() - .dns(dns) - .build(); - - Call call = client.newCall(new Request.Builder() - .url("http://fakeurl:" + server.getPort()) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - listener.removeUpToEvent(DnsStart.class); - listener.removeUpToEvent(DnsEnd.class); - listener.removeUpToEvent(DnsStart.class); - listener.removeUpToEvent(DnsEnd.class); - } - - @Test public void failedDnsLookup() { - client = client.newBuilder() - .dns(new FakeDns()) - .build(); - Call call = client.newCall(new Request.Builder() - .url("http://fakeurl/") - .build()); - try { - call.execute(); - fail(); - } catch (IOException expected) { - } - - listener.removeUpToEvent(DnsStart.class); - - CallFailed callFailed = listener.removeUpToEvent(CallFailed.class); - assertThat(callFailed.getCall()).isSameAs(call); - assertThat(callFailed.getIoe()).isInstanceOf(UnknownHostException.class); - } - - @Test public void emptyDnsLookup() { - Dns emptyDns = hostname -> Collections.emptyList(); - - client = client.newBuilder() - .dns(emptyDns) - .build(); - Call call = client.newCall(new Request.Builder() - .url("http://fakeurl/") - .build()); - try { - call.execute(); - fail(); - } catch (IOException expected) { - } - - listener.removeUpToEvent(DnsStart.class); - - CallFailed callFailed = listener.removeUpToEvent(CallFailed.class); - assertThat(callFailed.getCall()).isSameAs(call); - assertThat(callFailed.getIoe()).isInstanceOf(UnknownHostException.class); - } - - @Test public void successfulConnect() throws IOException { - server.enqueue(new MockResponse()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - InetAddress address = client.dns().lookup(server.getHostName()).get(0); - InetSocketAddress expectedAddress = new InetSocketAddress(address, server.getPort()); - - ConnectStart connectStart = listener.removeUpToEvent(ConnectStart.class); - assertThat(connectStart.getCall()).isSameAs(call); - assertThat(connectStart.getInetSocketAddress()).isEqualTo(expectedAddress); - assertThat(connectStart.getProxy()).isEqualTo(Proxy.NO_PROXY); - - ConnectEnd connectEnd = listener.removeUpToEvent(ConnectEnd.class); - assertThat(connectEnd.getCall()).isSameAs(call); - assertThat(connectEnd.getInetSocketAddress()).isEqualTo(expectedAddress); - assertThat(connectEnd.getProtocol()).isEqualTo(Protocol.HTTP_1_1); - } - - @Test public void failedConnect() throws UnknownHostException { - enableTlsWithTunnel(); - server.enqueue(new MockResponse.Builder() - .socketPolicy(FailHandshake.INSTANCE) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - try { - call.execute(); - fail(); - } catch (IOException expected) { - } - - InetAddress address = client.dns().lookup(server.getHostName()).get(0); - InetSocketAddress expectedAddress = new InetSocketAddress(address, server.getPort()); - - ConnectStart connectStart = listener.removeUpToEvent(ConnectStart.class); - assertThat(connectStart.getCall()).isSameAs(call); - assertThat(connectStart.getInetSocketAddress()).isEqualTo(expectedAddress); - assertThat(connectStart.getProxy()).isEqualTo(Proxy.NO_PROXY); - - ConnectFailed connectFailed = listener.removeUpToEvent(ConnectFailed.class); - assertThat(connectFailed.getCall()).isSameAs(call); - assertThat(connectFailed.getInetSocketAddress()).isEqualTo(expectedAddress); - assertThat(connectFailed.getProtocol()).isNull(); - assertThat(connectFailed.getIoe()).isNotNull(); - } - - @Test public void multipleConnectsForSingleCall() throws IOException { - enableTlsWithTunnel(); - server.enqueue(new MockResponse.Builder() - .socketPolicy(FailHandshake.INSTANCE) - .build()); - server.enqueue(new MockResponse()); - - client = client.newBuilder() - .dns(new DoubleInetAddressDns()) - .build(); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - listener.removeUpToEvent(ConnectStart.class); - listener.removeUpToEvent(ConnectFailed.class); - listener.removeUpToEvent(ConnectStart.class); - listener.removeUpToEvent(ConnectEnd.class); - } - - @Test public void successfulHttpProxyConnect() throws IOException { - server.enqueue(new MockResponse()); - - client = client.newBuilder() - .proxy(server.toProxyAddress()) - .build(); - - Call call = client.newCall(new Request.Builder() - .url("http://www.fakeurl") - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - InetAddress address = client.dns().lookup(server.getHostName()).get(0); - InetSocketAddress expectedAddress = new InetSocketAddress(address, server.getPort()); - - ConnectStart connectStart = listener.removeUpToEvent(ConnectStart.class); - assertThat(connectStart.getCall()).isSameAs(call); - assertThat(connectStart.getInetSocketAddress()).isEqualTo(expectedAddress); - assertThat(connectStart.getProxy()).isEqualTo(server.toProxyAddress()); - - ConnectEnd connectEnd = listener.removeUpToEvent(ConnectEnd.class); - assertThat(connectEnd.getCall()).isSameAs(call); - assertThat(connectEnd.getInetSocketAddress()).isEqualTo(expectedAddress); - assertThat(connectEnd.getProtocol()).isEqualTo(Protocol.HTTP_1_1); - } - - @Test public void successfulSocksProxyConnect() throws Exception { - server.enqueue(new MockResponse()); - - socksProxy = new SocksProxy(); - socksProxy.play(); - Proxy proxy = socksProxy.proxy(); - - client = client.newBuilder() - .proxy(proxy) - .build(); - - Call call = client.newCall(new Request.Builder() - .url("http://" + SocksProxy.HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS + ":" + server.getPort()) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - InetSocketAddress expectedAddress = InetSocketAddress.createUnresolved( - SocksProxy.HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS, server.getPort()); - - ConnectStart connectStart = listener.removeUpToEvent(ConnectStart.class); - assertThat(connectStart.getCall()).isSameAs(call); - assertThat(connectStart.getInetSocketAddress()).isEqualTo(expectedAddress); - assertThat(connectStart.getProxy()).isEqualTo(proxy); - - ConnectEnd connectEnd = listener.removeUpToEvent(ConnectEnd.class); - assertThat(connectEnd.getCall()).isSameAs(call); - assertThat(connectEnd.getInetSocketAddress()).isEqualTo(expectedAddress); - assertThat(connectEnd.getProtocol()).isEqualTo(Protocol.HTTP_1_1); - } - - @Test public void authenticatingTunnelProxyConnect() throws IOException { - enableTlsWithTunnel(); - server.enqueue(new MockResponse.Builder() - .inTunnel() - .code(407) - .addHeader("Proxy-Authenticate: Basic realm=\"localhost\"") - .addHeader("Connection: close") - .build()); - server.enqueue(new MockResponse.Builder() - .inTunnel() - .build()); - server.enqueue(new MockResponse()); - - client = client.newBuilder() - .proxy(server.toProxyAddress()) - .proxyAuthenticator(new RecordingOkAuthenticator("password", "Basic")) - .build(); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - listener.removeUpToEvent(ConnectStart.class); - - ConnectEnd connectEnd = listener.removeUpToEvent(ConnectEnd.class); - assertThat(connectEnd.getProtocol()).isNull(); - - listener.removeUpToEvent(ConnectStart.class); - listener.removeUpToEvent(ConnectEnd.class); - } - - @Test public void successfulSecureConnect() throws IOException { - enableTlsWithTunnel(); - server.enqueue(new MockResponse()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - SecureConnectStart secureStart = listener.removeUpToEvent(SecureConnectStart.class); - assertThat(secureStart.getCall()).isSameAs(call); - - SecureConnectEnd secureEnd = listener.removeUpToEvent(SecureConnectEnd.class); - assertThat(secureEnd.getCall()).isSameAs(call); - assertThat(secureEnd.getHandshake()).isNotNull(); - } - - @Test public void failedSecureConnect() { - enableTlsWithTunnel(); - server.enqueue(new MockResponse.Builder() - .socketPolicy(FailHandshake.INSTANCE) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - try { - call.execute(); - fail(); - } catch (IOException expected) { - } - - SecureConnectStart secureStart = listener.removeUpToEvent(SecureConnectStart.class); - assertThat(secureStart.getCall()).isSameAs(call); - - CallFailed callFailed = listener.removeUpToEvent(CallFailed.class); - assertThat(callFailed.getCall()).isSameAs(call); - assertThat(callFailed.getIoe()).isNotNull(); - } - - @Test public void secureConnectWithTunnel() throws IOException { - enableTlsWithTunnel(); - server.enqueue(new MockResponse.Builder() - .inTunnel() - .build()); - server.enqueue(new MockResponse()); - - client = client.newBuilder() - .proxy(server.toProxyAddress()) - .build(); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - SecureConnectStart secureStart = listener.removeUpToEvent(SecureConnectStart.class); - assertThat(secureStart.getCall()).isSameAs(call); - - SecureConnectEnd secureEnd = listener.removeUpToEvent(SecureConnectEnd.class); - assertThat(secureEnd.getCall()).isSameAs(call); - assertThat(secureEnd.getHandshake()).isNotNull(); - } - - @Test public void multipleSecureConnectsForSingleCall() throws IOException { - enableTlsWithTunnel(); - server.enqueue(new MockResponse.Builder() - .socketPolicy(FailHandshake.INSTANCE) - .build()); - server.enqueue(new MockResponse()); - - client = client.newBuilder() - .dns(new DoubleInetAddressDns()) - .build(); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - listener.removeUpToEvent(SecureConnectStart.class); - listener.removeUpToEvent(ConnectFailed.class); - - listener.removeUpToEvent(SecureConnectStart.class); - listener.removeUpToEvent(SecureConnectEnd.class); - } - - @Test public void noSecureConnectsOnPooledConnection() throws IOException { - enableTlsWithTunnel(); - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - client = client.newBuilder() - .dns(new DoubleInetAddressDns()) - .build(); - - // Seed the pool. - Call call1 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response1 = call1.execute(); - assertThat(response1.code()).isEqualTo(200); - response1.body().close(); - - listener.clearAllEvents(); - - Call call2 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response2 = call2.execute(); - assertThat(response2.code()).isEqualTo(200); - response2.body().close(); - - List recordedEvents = listener.recordedEventTypes(); - assertThat(recordedEvents).doesNotContain("SecureConnectStart"); - assertThat(recordedEvents).doesNotContain("SecureConnectEnd"); - } - - @Test public void successfulConnectionFound() throws IOException { - server.enqueue(new MockResponse()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.body().close(); - - ConnectionAcquired connectionAcquired = listener.removeUpToEvent(ConnectionAcquired.class); - assertThat(connectionAcquired.getCall()).isSameAs(call); - assertThat(connectionAcquired.getConnection()).isNotNull(); - } - - @Test public void noConnectionFoundOnFollowUp() throws IOException { - server.enqueue(new MockResponse.Builder() - .code(301) - .addHeader("Location", "/foo") - .build()); - server.enqueue(new MockResponse.Builder() - .body("ABC") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.body().string()).isEqualTo("ABC"); - - listener.removeUpToEvent(ConnectionAcquired.class); - - List remainingEvents = listener.recordedEventTypes(); - assertThat(remainingEvents).doesNotContain("ConnectionAcquired"); - } - - @Test public void pooledConnectionFound() throws IOException { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - // Seed the pool. - Call call1 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response1 = call1.execute(); - assertThat(response1.code()).isEqualTo(200); - response1.body().close(); - - ConnectionAcquired connectionAcquired1 = listener.removeUpToEvent(ConnectionAcquired.class); - listener.clearAllEvents(); - - Call call2 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response2 = call2.execute(); - assertThat(response2.code()).isEqualTo(200); - response2.body().close(); - - ConnectionAcquired connectionAcquired2 = listener.removeUpToEvent(ConnectionAcquired.class); - assertThat(connectionAcquired2.getConnection()).isSameAs( - connectionAcquired1.getConnection()); - } - - @Test public void multipleConnectionsFoundForSingleCall() throws IOException { - server.enqueue(new MockResponse.Builder() - .code(301) - .addHeader("Location", "/foo") - .addHeader("Connection", "Close") - .build()); - server.enqueue(new MockResponse.Builder() - .body("ABC") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.body().string()).isEqualTo("ABC"); - - listener.removeUpToEvent(ConnectionAcquired.class); - listener.removeUpToEvent(ConnectionAcquired.class); - } - - @Test public void responseBodyFailHttp1OverHttps() throws IOException { - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_1_1)); - responseBodyFail(Protocol.HTTP_1_1); - } - - @Test public void responseBodyFailHttp2OverHttps() throws IOException { - platform.assumeHttp2Support(); - - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)); - responseBodyFail(Protocol.HTTP_2); - } - - @Test public void responseBodyFailHttp() throws IOException { - responseBodyFail(Protocol.HTTP_1_1); - } - - private void responseBodyFail(Protocol expectedProtocol) throws IOException { - // Use a 2 MiB body so the disconnect won't happen until the client has read some data. - int responseBodySize = 2 * 1024 * 1024; // 2 MiB - server.enqueue(new MockResponse.Builder() - .body(new Buffer().write(new byte[responseBodySize])) - .socketPolicy(DisconnectDuringResponseBody.INSTANCE) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - if (expectedProtocol == Protocol.HTTP_2) { - // soft failure since client may not support depending on Platform - assumeThat(response, matchesProtocol(Protocol.HTTP_2)); - } - assertThat(response.protocol()).isEqualTo(expectedProtocol); - try { - response.body().string(); - fail(); - } catch (IOException expected) { - } - - CallFailed callFailed = listener.removeUpToEvent(CallFailed.class); - assertThat(callFailed.getIoe()).isNotNull(); - } - - @Test public void emptyResponseBody() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("") - .bodyDelay(1, TimeUnit.SECONDS) - .socketPolicy(DisconnectDuringResponseBody.INSTANCE) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - response.body().close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void emptyResponseBodyConnectionClose() throws IOException { - server.enqueue(new MockResponse.Builder() - .addHeader("Connection", "close") - .body("") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - response.body().close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void responseBodyClosedClosedWithoutReadingAllData() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("abc") - .bodyDelay(1, TimeUnit.SECONDS) - .socketPolicy(DisconnectDuringResponseBody.INSTANCE) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - response.body().close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void requestBodyFailHttp1OverHttps() { - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_1_1)); - - requestBodyFail(Protocol.HTTP_1_1); - } - - @Test public void requestBodyFailHttp2OverHttps() { - platform.assumeHttp2Support(); - - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)); - - requestBodyFail(Protocol.HTTP_2); - } - - @Test public void requestBodyFailHttp() { - requestBodyFail(null); - } - - private void requestBodyFail(@Nullable Protocol expectedProtocol) { - server.enqueue(new MockResponse.Builder() - .socketPolicy(DisconnectDuringRequestBody.INSTANCE) - .build()); - - NonCompletingRequestBody request = new NonCompletingRequestBody(); - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(request) - .build()); - try { - call.execute(); - fail(); - } catch (IOException expected) { - } - - if (expectedProtocol != null) { - ConnectionAcquired connectionAcquired = listener.removeUpToEvent(ConnectionAcquired.class); - assertThat(connectionAcquired.getConnection().protocol()).isEqualTo(expectedProtocol); - } - - CallFailed callFailed = listener.removeUpToEvent(CallFailed.class); - assertThat(callFailed.getIoe()).isNotNull(); - - assertThat(request.ioe).isNotNull(); - } - - private class NonCompletingRequestBody extends RequestBody { - private final byte[] chunk = new byte[1024 * 1024]; - IOException ioe; - - @Override public MediaType contentType() { - return MediaType.get("text/plain"); - } - - @Override public long contentLength() { - return chunk.length * 8L; - } - - @Override public void writeTo(BufferedSink sink) throws IOException { - try { - for (int i = 0; i < contentLength(); i += chunk.length) { - sink.write(chunk); - sink.flush(); - Thread.sleep(100); - } - } catch (IOException e) { - ioe = e; - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - } - - @Test public void requestBodyMultipleFailuresReportedOnlyOnce() { - RequestBody requestBody = new RequestBody() { - @Override public MediaType contentType() { - return MediaType.get("text/plain"); - } - - @Override public long contentLength() { - return 1024 * 1024 * 256; - } - - @Override public void writeTo(BufferedSink sink) throws IOException { - int failureCount = 0; - for (int i = 0; i < 1024; i++) { - try { - sink.write(new byte[1024 * 256]); - sink.flush(); - } catch (IOException e) { - failureCount++; - if (failureCount == 3) throw e; - } - } - } - }; - - server.enqueue(new MockResponse.Builder() - .socketPolicy(DisconnectDuringRequestBody.INSTANCE) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(requestBody) - .build()); - try { - call.execute(); - fail(); - } catch (IOException expected) { - } - - assertThat(listener.recordedEventTypes()).containsExactly( - "CallStart", "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", "ConnectStart", - "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", - "RequestBodyStart", "RequestFailed", "ResponseFailed", "ConnectionReleased", "CallFailed"); - } - - @Test public void requestBodySuccessHttp1OverHttps() throws IOException { - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_1_1)); - requestBodySuccess(RequestBody.create("Hello", MediaType.get("text/plain")), equalTo(5L), - equalTo(19L)); - } - - @Test public void requestBodySuccessHttp2OverHttps() throws IOException { - platform.assumeHttp2Support(); - - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)); - requestBodySuccess(RequestBody.create("Hello", MediaType.get("text/plain")), equalTo(5L), - equalTo(19L)); - } - - @Test public void requestBodySuccessHttp() throws IOException { - requestBodySuccess(RequestBody.create("Hello", MediaType.get("text/plain")), equalTo(5L), - equalTo(19L)); - } - - @Test public void requestBodySuccessStreaming() throws IOException { - RequestBody requestBody = new RequestBody() { - @Override public MediaType contentType() { - return MediaType.get("text/plain"); - } - - @Override public void writeTo(BufferedSink sink) throws IOException { - sink.write(new byte[8192]); - sink.flush(); - } - }; - - requestBodySuccess(requestBody, equalTo(8192L), equalTo(19L)); - } - - @Test public void requestBodySuccessEmpty() throws IOException { - requestBodySuccess(RequestBody.create("", MediaType.get("text/plain")), equalTo(0L), - equalTo(19L)); - } - - @Test public void successfulCallEventSequenceWithListener() throws IOException { - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - client = client.newBuilder() - .addNetworkInterceptor(new HttpLoggingInterceptor() - .setLevel(HttpLoggingInterceptor.Level.BODY)) - .build(); - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - assertThat(response.body().string()).isEqualTo("abc"); - response.body().close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - private void requestBodySuccess(RequestBody body, Matcher requestBodyBytes, - Matcher responseHeaderLength) throws IOException { - server.enqueue(new MockResponse.Builder() - .code(200) - .body("World!") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(body) - .build()); - Response response = call.execute(); - assertThat(response.body().string()).isEqualTo("World!"); - - assertBytesReadWritten(listener, any(Long.class), requestBodyBytes, responseHeaderLength, - equalTo(6L)); - } - - @Test public void timeToFirstByteHttp1OverHttps() throws IOException { - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_1_1)); - - timeToFirstByte(); - } - - @Test public void timeToFirstByteHttp2OverHttps() throws IOException { - platform.assumeHttp2Support(); - enableTlsWithTunnel(); - server.setProtocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)); - - timeToFirstByte(); - } - - /** - * Test to confirm that events are reported at the time they occur and no earlier and no later. - * This inserts a bunch of synthetic 250 ms delays into both client and server and confirms that - * the same delays make it back into the events. - * - * We've had bugs where we report an event when we request data rather than when the data actually - * arrives. https://github.com/square/okhttp/issues/5578 - */ - private void timeToFirstByte() throws IOException { - long applicationInterceptorDelay = 250L; - long networkInterceptorDelay = 250L; - long requestBodyDelay = 250L; - long responseHeadersStartDelay = 250L; - long responseBodyStartDelay = 250L; - long responseBodyEndDelay = 250L; - - // Warm up the client so the timing part of the test gets a pooled connection. - server.enqueue(new MockResponse()); - Call warmUpCall = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - try (Response warmUpResponse = warmUpCall.execute()) { - warmUpResponse.body().string(); - } - listener.clearAllEvents(); - - // Create a client with artificial delays. - client = client.newBuilder() - .addInterceptor(chain -> { - try { - Thread.sleep(applicationInterceptorDelay); - return chain.proceed(chain.request()); - } catch (InterruptedException e) { - throw new InterruptedIOException(); - } - }) - .addNetworkInterceptor(chain -> { - try { - Thread.sleep(networkInterceptorDelay); - return chain.proceed(chain.request()); - } catch (InterruptedException e) { - throw new InterruptedIOException(); - } - }) - .build(); - - // Create a request body with artificial delays. - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .post(new RequestBody() { - @Override public @Nullable MediaType contentType() { - return null; - } - - @Override public void writeTo(BufferedSink sink) throws IOException { - try { - Thread.sleep(requestBodyDelay); - sink.writeUtf8("abc"); - } catch (InterruptedException e) { - throw new InterruptedIOException(); - } - } - }) - .build()); - - // Create a response with artificial delays. - server.enqueue(new MockResponse.Builder() - .headersDelay(responseHeadersStartDelay, TimeUnit.MILLISECONDS) - .bodyDelay(responseBodyStartDelay, TimeUnit.MILLISECONDS) - .throttleBody(5, responseBodyEndDelay, TimeUnit.MILLISECONDS) - .body("fghijk") - .build()); - - // Make the call. - try (Response response = call.execute()) { - assertThat(response.body().string()).isEqualTo("fghijk"); - } - - // Confirm the events occur when expected. - listener.takeEvent(CallStart.class, 0L); - listener.takeEvent(ConnectionAcquired.class, applicationInterceptorDelay); - listener.takeEvent(RequestHeadersStart.class, networkInterceptorDelay); - listener.takeEvent(RequestHeadersEnd.class, 0L); - listener.takeEvent(RequestBodyStart.class, 0L); - listener.takeEvent(RequestBodyEnd.class, requestBodyDelay); - listener.takeEvent(ResponseHeadersStart.class, responseHeadersStartDelay); - listener.takeEvent(ResponseHeadersEnd.class, 0L); - listener.takeEvent(ResponseBodyStart.class, responseBodyStartDelay); - listener.takeEvent(ResponseBodyEnd.class, responseBodyEndDelay); - listener.takeEvent(ConnectionReleased.class, 0L); - listener.takeEvent(CallEnd.class, 0L); - } - - private void enableTlsWithTunnel() { - client = client.newBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .build(); - server.useHttps(handshakeCertificates.sslSocketFactory()); - } - - @Test public void redirectUsingSameConnectionEventSequence() throws IOException { - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_MOVED_TEMP) - .addHeader("Location: /foo") - .build()); - server.enqueue(new MockResponse()); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - call.execute(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", - "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", - "CallEnd"); - } - - @Test - public void redirectUsingNewConnectionEventSequence() throws IOException { - MockWebServer otherServer = new MockWebServer(); - server.enqueue( - new MockResponse.Builder() - .code(HttpURLConnection.HTTP_MOVED_TEMP) - .addHeader("Location: " + otherServer.url("/foo")) - .build()); - otherServer.enqueue(new MockResponse()); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - call.execute(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "ConnectionReleased", "ProxySelectStart", "ProxySelectEnd", - "DnsStart", "DnsEnd", "ConnectStart", "ConnectEnd", - "ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", - "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", - "CallEnd"); - } - - @Test public void applicationInterceptorProceedsMultipleTimes() throws Exception { - server.enqueue(new MockResponse.Builder().body("a").build()); - server.enqueue(new MockResponse.Builder().body("b").build()); - - client = client.newBuilder() - .addInterceptor(chain -> { - try (Response a = chain.proceed(chain.request())) { - assertThat(a.body().string()).isEqualTo("a"); - } - return chain.proceed(chain.request()); - }) - .build(); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - Response response = call.execute(); - assertThat(response.body().string()).isEqualTo("b"); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", - "ResponseBodyEnd", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", - "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", - "CallEnd"); - - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(0); - assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(1); - } - - @Test public void applicationInterceptorShortCircuit() throws Exception { - client = client.newBuilder() - .addInterceptor(chain -> new Response.Builder() - .request(chain.request()) - .protocol(Protocol.HTTP_1_1) - .code(200) - .message("OK") - .body(ResponseBody.create("a", null)) - .build()) - .build(); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - Response response = call.execute(); - assertThat(response.body().string()).isEqualTo("a"); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", "CallEnd"); - } - - /** Response headers start, then the entire request body, then response headers end. */ - @Test public void expectContinueStartsResponseHeadersEarly() throws Exception { - server.enqueue(new MockResponse.Builder() - .add100Continue() - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .header("Expect", "100-continue") - .post(RequestBody.create("abc", MediaType.get("text/plain"))) - .build(); - - Call call = client.newCall(request); - call.execute(); - - assertThat(listener.recordedEventTypes()).containsExactly( - "CallStart", "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", "ConnectStart", - "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", - "ResponseHeadersStart", "RequestBodyStart", "RequestBodyEnd", "ResponseHeadersEnd", - "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void timeToFirstByteGapBetweenResponseHeaderStartAndEnd() throws IOException { - long responseHeadersStartDelay = 250L; - server.enqueue(new MockResponse.Builder() - .add100Continue() - .headersDelay(responseHeadersStartDelay, TimeUnit.MILLISECONDS) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .header("Expect", "100-continue") - .post(RequestBody.create("abc", MediaType.get("text/plain"))) - .build(); - - Call call = client.newCall(request); - try (Response response = call.execute()) { - assertThat(response.body().string()).isEqualTo(""); - } - - listener.removeUpToEvent(ResponseHeadersStart.class); - listener.takeEvent(RequestBodyStart.class, 0L); - listener.takeEvent(RequestBodyEnd.class, 0L); - listener.takeEvent(ResponseHeadersEnd.class, responseHeadersStartDelay); - } - - @Test public void cacheMiss() throws IOException { - enableCache(); - - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - assertThat(response.body().string()).isEqualTo("abc"); - response.close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", "CacheMiss", - "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", - "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", - "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void conditionalCache() throws IOException { - enableCache(); - - server.enqueue(new MockResponse.Builder() - .addHeader("ETag", "v1") - .body("abc") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_MODIFIED) - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.close(); - - listener.clearAllEvents(); - - call = call.clone(); - - response = call.execute(); - assertThat(response.code()).isEqualTo(200); - assertThat(response.body().string()).isEqualTo("abc"); - response.close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", "CacheConditionalHit", - "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", - "ResponseBodyStart", "ResponseBodyEnd", "CacheHit", "ConnectionReleased", "CallEnd"); - } - - @Test public void conditionalCacheMiss() throws IOException { - enableCache(); - - server.enqueue(new MockResponse.Builder() - .addHeader("ETag: v1") - .body("abc") - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_OK) - .addHeader("ETag: v2") - .body("abd") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - response.close(); - - listener.clearAllEvents(); - - call = call.clone(); - - response = call.execute(); - assertThat(response.code()).isEqualTo(200); - assertThat(response.body().string()).isEqualTo("abd"); - response.close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", "CacheConditionalHit", - "ConnectionAcquired", "RequestHeadersStart", - "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "CacheMiss", - "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "CallEnd"); - } - - @Test public void satisfactionFailure() throws IOException { - enableCache(); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .cacheControl(CacheControl.FORCE_CACHE) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(504); - response.close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", "SatisfactionFailure", "CallEnd"); - } - - @Test public void cacheHit() throws IOException { - enableCache(); - - server.enqueue(new MockResponse.Builder() - .body("abc") - .addHeader("cache-control: public, max-age=300") - .build()); - - Call call = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response = call.execute(); - assertThat(response.code()).isEqualTo(200); - assertThat(response.body().string()).isEqualTo("abc"); - response.close(); - - listener.clearAllEvents(); - - call = call.clone(); - response = call.execute(); - assertThat(response.code()).isEqualTo(200); - assertThat(response.body().string()).isEqualTo("abc"); - response.close(); - - assertThat(listener.recordedEventTypes()).containsExactly("CallStart", "CacheHit", "CallEnd"); - } - - private Cache enableCache() throws IOException { - cache = makeCache(); - client = client.newBuilder().cache(cache).build(); - return cache; - } - - private Cache makeCache() throws IOException { - File cacheDir = File.createTempFile("cache-", ".dir"); - cacheDir.delete(); - return new Cache(cacheDir, 1024 * 1024); - } -} diff --git a/okhttp/src/test/java/okhttp3/EventListenerTest.kt b/okhttp/src/test/java/okhttp3/EventListenerTest.kt new file mode 100644 index 000000000000..93063b982d50 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/EventListenerTest.kt @@ -0,0 +1,1850 @@ +/* + * Copyright (C) 2017 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.io.File +import java.io.IOException +import java.io.InterruptedIOException +import java.net.HttpURLConnection +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.Proxy +import java.net.UnknownHostException +import java.time.Duration +import java.util.Arrays +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.SocketPolicy.DisconnectDuringRequestBody +import mockwebserver3.SocketPolicy.DisconnectDuringResponseBody +import mockwebserver3.SocketPolicy.FailHandshake +import okhttp3.CallEvent.CallEnd +import okhttp3.CallEvent.CallFailed +import okhttp3.CallEvent.CallStart +import okhttp3.CallEvent.ConnectStart +import okhttp3.CallEvent.ConnectionAcquired +import okhttp3.CallEvent.DnsEnd +import okhttp3.CallEvent.DnsStart +import okhttp3.CallEvent.RequestBodyEnd +import okhttp3.CallEvent.RequestBodyStart +import okhttp3.CallEvent.RequestHeadersEnd +import okhttp3.CallEvent.RequestHeadersStart +import okhttp3.CallEvent.ResponseBodyEnd +import okhttp3.CallEvent.ResponseBodyStart +import okhttp3.CallEvent.ResponseFailed +import okhttp3.CallEvent.ResponseHeadersEnd +import okhttp3.CallEvent.ResponseHeadersStart +import okhttp3.CallEvent.SecureConnectEnd +import okhttp3.CallEvent.SecureConnectStart +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.RequestBody.Companion.toRequestBody +import okhttp3.ResponseBody.Companion.toResponseBody +import okhttp3.internal.DoubleInetAddressDns +import okhttp3.internal.RecordingOkAuthenticator +import okhttp3.internal.connection.RealConnectionPool.Companion.get +import okhttp3.logging.HttpLoggingInterceptor +import okhttp3.testing.Flaky +import okhttp3.testing.PlatformRule +import okio.Buffer +import okio.BufferedSink +import org.assertj.core.api.Assertions.assertThat +import org.hamcrest.BaseMatcher +import org.hamcrest.CoreMatchers +import org.hamcrest.Description +import org.hamcrest.Matcher +import org.hamcrest.MatcherAssert +import org.junit.Assume +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.extension.RegisterExtension + +@Flaky // STDOUT logging enabled for test +@Timeout(30) +@Tag("Slow") +class EventListenerTest { + @RegisterExtension + val platform = PlatformRule() + + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + private lateinit var server: MockWebServer + private val listener: RecordingEventListener = RecordingEventListener() + private val handshakeCertificates = platform.localhostHandshakeCertificates() + private var client = clientTestRule.newClientBuilder() + .eventListenerFactory(clientTestRule.wrap(listener)) + .build() + private var socksProxy: SocksProxy? = null + private var cache: Cache? = null + + @BeforeEach + fun setUp(server: MockWebServer) { + this.server = server + platform.assumeNotOpenJSSE() + listener.forbidLock(get(client.connectionPool)) + listener.forbidLock(client.dispatcher) + } + + @AfterEach + fun tearDown() { + if (socksProxy != null) { + socksProxy!!.shutdown() + } + if (cache != null) { + cache!!.delete() + } + } + + @Test + fun successfulCallEventSequence() { + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + assertThat(response.body.string()).isEqualTo("abc") + response.body.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun successfulCallEventSequenceForIpAddress() { + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val ipAddress = InetAddress.getLoopbackAddress().hostAddress + val call = client.newCall( + Request.Builder() + .url(server.url("/").newBuilder().host(ipAddress!!).build()) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + assertThat(response.body.string()).isEqualTo("abc") + response.body.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun successfulCallEventSequenceForEnqueue() { + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val completionLatch = CountDownLatch(1) + val callback: Callback = object : Callback { + override fun onFailure(call: Call, e: IOException) { + completionLatch.countDown() + } + + override fun onResponse(call: Call, response: Response) { + response.close() + completionLatch.countDown() + } + } + call.enqueue(callback) + completionLatch.await() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun failedCallEventSequence() { + server.enqueue( + MockResponse.Builder() + .headersDelay(2, TimeUnit.SECONDS) + .build() + ) + client = client.newBuilder() + .readTimeout(Duration.ofMillis(250)) + .build() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + try { + call.execute() + fail() + } catch (expected: IOException) { + assertThat(expected.message).isIn("timeout", "Read timed out") + } + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseFailed", "ConnectionReleased", "CallFailed" + ) + } + + @Test + fun failedDribbledCallEventSequence() { + server.enqueue( + MockResponse.Builder() + .body("0123456789") + .throttleBody(2, 100, TimeUnit.MILLISECONDS) + .socketPolicy(DisconnectDuringResponseBody) + .build() + ) + client = client.newBuilder() + .protocols(listOf(Protocol.HTTP_1_1)) + .readTimeout(Duration.ofMillis(250)) + .build() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + try { + response.body.string() + fail() + } catch (expected: IOException) { + assertThat(expected.message).isEqualTo("unexpected end of stream") + } + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseFailed", "ConnectionReleased", "CallFailed" + ) + val responseFailed = listener.removeUpToEvent() + assertThat(responseFailed.ioe.message).isEqualTo("unexpected end of stream") + } + + @Test + fun canceledCallEventSequence() { + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + call.cancel() + try { + call.execute() + fail() + } catch (expected: IOException) { + assertThat(expected.message).isEqualTo("Canceled") + } + assertThat(listener.recordedEventTypes()).containsExactly( + "Canceled", "CallStart", "CallFailed" + ) + } + + @Test + fun cancelAsyncCall() { + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + call.enqueue(object : Callback { + override fun onFailure(call: Call, e: IOException) { + } + + override fun onResponse(call: Call, response: Response) { + response.close() + } + }) + call.cancel() + assertThat(listener.recordedEventTypes()).contains("Canceled") + } + + @Test + fun multipleCancelsEmitsOnlyOneEvent() { + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + call.cancel() + call.cancel() + assertThat(listener.recordedEventTypes()).containsExactly("Canceled") + } + + private fun assertSuccessfulEventOrder(responseMatcher: Matcher?) { + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.string() + response.body.close() + Assume.assumeThat(response, responseMatcher) + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", + "ProxySelectEnd", + "DnsStart", + "DnsEnd", + "ConnectStart", + "SecureConnectStart", + "SecureConnectEnd", + "ConnectEnd", + "ConnectionAcquired", + "RequestHeadersStart", + "RequestHeadersEnd", + "ResponseHeadersStart", + "ResponseHeadersEnd", + "ResponseBodyStart", + "ResponseBodyEnd", + "ConnectionReleased", + "CallEnd" + ) + } + + @Test + fun secondCallEventSequence() { + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1) + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ).execute().close() + listener.removeUpToEvent() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + response.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ConnectionAcquired", + "RequestHeadersStart", + "RequestHeadersEnd", + "ResponseHeadersStart", + "ResponseHeadersEnd", + "ResponseBodyStart", + "ResponseBodyEnd", + "ConnectionReleased", + "CallEnd" + ) + } + + private fun assertBytesReadWritten( + listener: RecordingEventListener, + requestHeaderLength: Matcher?, requestBodyBytes: Matcher?, + responseHeaderLength: Matcher?, responseBodyBytes: Matcher? + ) { + if (requestHeaderLength != null) { + val responseHeadersEnd = listener.removeUpToEvent() + MatcherAssert.assertThat( + "request header length", responseHeadersEnd.headerLength, + requestHeaderLength + ) + } else { + assertThat(listener.recordedEventTypes()) + .doesNotContain("RequestHeadersEnd") + } + if (requestBodyBytes != null) { + val responseBodyEnd: RequestBodyEnd = listener.removeUpToEvent() + MatcherAssert.assertThat( + "request body bytes", + responseBodyEnd.bytesWritten, + requestBodyBytes + ) + } else { + assertThat(listener.recordedEventTypes()).doesNotContain("RequestBodyEnd") + } + if (responseHeaderLength != null) { + val responseHeadersEnd: ResponseHeadersEnd = + listener.removeUpToEvent() + MatcherAssert.assertThat( + "response header length", responseHeadersEnd.headerLength, + responseHeaderLength + ) + } else { + assertThat(listener.recordedEventTypes()) + .doesNotContain("ResponseHeadersEnd") + } + if (responseBodyBytes != null) { + val responseBodyEnd: ResponseBodyEnd = listener.removeUpToEvent() + MatcherAssert.assertThat( + "response body bytes", + responseBodyEnd.bytesRead, + responseBodyBytes + ) + } else { + assertThat(listener.recordedEventTypes()).doesNotContain("ResponseBodyEnd") + } + } + + private fun greaterThan(value: Long): Matcher { + return object : BaseMatcher() { + override fun describeTo(description: Description?) { + description!!.appendText("> $value") + } + + override fun matches(o: Any?): Boolean { + return (o as Long?)!! > value + } + } + } + + private fun matchesProtocol(protocol: Protocol?): Matcher { + return object : BaseMatcher() { + override fun describeTo(description: Description?) { + description!!.appendText("is HTTP/2") + } + + override fun matches(o: Any?): Boolean { + return (o as Response?)!!.protocol == protocol + } + } + } + + @Test + fun successfulEmptyH2CallEventSequence() { + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1) + server.enqueue(MockResponse()) + assertSuccessfulEventOrder(matchesProtocol(Protocol.HTTP_2)) + assertBytesReadWritten( + listener, CoreMatchers.any(Long::class.java), null, greaterThan(0L), + CoreMatchers.equalTo(0L) + ) + } + + @Test + fun successfulEmptyHttpsCallEventSequence() { + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_1_1) + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + assertSuccessfulEventOrder(anyResponse) + assertBytesReadWritten( + listener, CoreMatchers.any(Long::class.java), null, greaterThan(0L), + CoreMatchers.equalTo(3L) + ) + } + + @Test + fun successfulChunkedHttpsCallEventSequence() { + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_1_1) + server.enqueue( + MockResponse.Builder() + .bodyDelay(100, TimeUnit.MILLISECONDS) + .chunkedBody("Hello!", 2) + .build() + ) + assertSuccessfulEventOrder(anyResponse) + assertBytesReadWritten( + listener, CoreMatchers.any(Long::class.java), null, greaterThan(0L), + CoreMatchers.equalTo(6L) + ) + } + + @Test + fun successfulChunkedH2CallEventSequence() { + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1) + server.enqueue( + MockResponse.Builder() + .bodyDelay(100, TimeUnit.MILLISECONDS) + .chunkedBody("Hello!", 2) + .build() + ) + assertSuccessfulEventOrder(matchesProtocol(Protocol.HTTP_2)) + assertBytesReadWritten( + listener, CoreMatchers.any(Long::class.java), null, CoreMatchers.equalTo(0L), + greaterThan(6L) + ) + } + + @Test + fun successfulDnsLookup() { + server.enqueue(MockResponse()) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + val dnsStart: DnsStart = listener.removeUpToEvent() + assertThat(dnsStart.call).isSameAs(call) + assertThat(dnsStart.domainName).isEqualTo(server.hostName) + val dnsEnd: DnsEnd = listener.removeUpToEvent() + assertThat(dnsEnd.call).isSameAs(call) + assertThat(dnsEnd.domainName).isEqualTo(server.hostName) + assertThat(dnsEnd.inetAddressList.size).isEqualTo(1) + } + + @Test + fun noDnsLookupOnPooledConnection() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + + // Seed the pool. + val call1 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response1 = call1.execute() + assertThat(response1.code).isEqualTo(200) + response1.body.close() + listener.clearAllEvents() + val call2 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response2 = call2.execute() + assertThat(response2.code).isEqualTo(200) + response2.body.close() + val recordedEvents: List = listener.recordedEventTypes() + assertThat(recordedEvents).doesNotContain("DnsStart") + assertThat(recordedEvents).doesNotContain("DnsEnd") + } + + @Test + fun multipleDnsLookupsForSingleCall() { + server.enqueue( + MockResponse.Builder() + .code(301) + .setHeader("Location", "http://www.fakeurl:" + server.port) + .build() + ) + server.enqueue(MockResponse()) + val dns = FakeDns() + dns["fakeurl"] = client.dns.lookup(server.hostName) + dns["www.fakeurl"] = client.dns.lookup(server.hostName) + client = client.newBuilder() + .dns(dns) + .build() + val call = client.newCall( + Request.Builder() + .url("http://fakeurl:" + server.port) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + listener.removeUpToEvent() + listener.removeUpToEvent() + listener.removeUpToEvent() + listener.removeUpToEvent() + } + + @Test + fun failedDnsLookup() { + client = client.newBuilder() + .dns(FakeDns()) + .build() + val call = client.newCall( + Request.Builder() + .url("http://fakeurl/") + .build() + ) + try { + call.execute() + fail() + } catch (expected: IOException) { + } + listener.removeUpToEvent() + val callFailed: CallFailed = listener.removeUpToEvent() + assertThat(callFailed.call).isSameAs(call) + assertThat(callFailed.ioe).isInstanceOf( + UnknownHostException::class.java + ) + } + + @Test + fun emptyDnsLookup() { + val emptyDns = Dns { listOf() } + client = client.newBuilder() + .dns(emptyDns) + .build() + val call = client.newCall( + Request.Builder() + .url("http://fakeurl/") + .build() + ) + try { + call.execute() + fail() + } catch (expected: IOException) { + } + listener.removeUpToEvent() + val callFailed: CallFailed = listener.removeUpToEvent() + assertThat(callFailed.call).isSameAs(call) + assertThat(callFailed.ioe).isInstanceOf( + UnknownHostException::class.java + ) + } + + @Test + fun successfulConnect() { + server.enqueue(MockResponse()) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + val address = client.dns.lookup(server.hostName)[0] + val expectedAddress = InetSocketAddress(address, server.port) + val connectStart = listener.removeUpToEvent() + assertThat(connectStart.call).isSameAs(call) + assertThat(connectStart.inetSocketAddress).isEqualTo(expectedAddress) + assertThat(connectStart.proxy).isEqualTo(Proxy.NO_PROXY) + val connectEnd = listener.removeUpToEvent() + assertThat(connectEnd.call).isSameAs(call) + assertThat(connectEnd.inetSocketAddress).isEqualTo(expectedAddress) + assertThat(connectEnd.protocol).isEqualTo(Protocol.HTTP_1_1) + } + + @Test + fun failedConnect() { + enableTlsWithTunnel() + server.enqueue( + MockResponse.Builder() + .socketPolicy(FailHandshake) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + try { + call.execute() + fail() + } catch (expected: IOException) { + } + val address = client.dns.lookup(server.hostName)[0] + val expectedAddress = InetSocketAddress(address, server.port) + val connectStart = listener.removeUpToEvent() + assertThat(connectStart.call).isSameAs(call) + assertThat(connectStart.inetSocketAddress).isEqualTo(expectedAddress) + assertThat(connectStart.proxy).isEqualTo(Proxy.NO_PROXY) + val connectFailed = listener.removeUpToEvent() + assertThat(connectFailed.call).isSameAs(call) + assertThat(connectFailed.inetSocketAddress).isEqualTo(expectedAddress) + assertThat(connectFailed.protocol).isNull() + assertThat(connectFailed.ioe).isNotNull() + } + + @Test + fun multipleConnectsForSingleCall() { + enableTlsWithTunnel() + server.enqueue( + MockResponse.Builder() + .socketPolicy(FailHandshake) + .build() + ) + server.enqueue(MockResponse()) + client = client.newBuilder() + .dns(DoubleInetAddressDns()) + .build() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + listener.removeUpToEvent() + listener.removeUpToEvent() + listener.removeUpToEvent() + listener.removeUpToEvent() + } + + @Test + fun successfulHttpProxyConnect() { + server.enqueue(MockResponse()) + client = client.newBuilder() + .proxy(server.toProxyAddress()) + .build() + val call = client.newCall( + Request.Builder() + .url("http://www.fakeurl") + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + val address = client.dns.lookup(server.hostName)[0] + val expectedAddress = InetSocketAddress(address, server.port) + val connectStart: ConnectStart = listener.removeUpToEvent( + ) + assertThat(connectStart.call).isSameAs(call) + assertThat(connectStart.inetSocketAddress).isEqualTo(expectedAddress) + assertThat(connectStart.proxy).isEqualTo( + server.toProxyAddress() + ) + val connectEnd = listener.removeUpToEvent() + assertThat(connectEnd.call).isSameAs(call) + assertThat(connectEnd.inetSocketAddress).isEqualTo(expectedAddress) + assertThat(connectEnd.protocol).isEqualTo(Protocol.HTTP_1_1) + } + + @Test + fun successfulSocksProxyConnect() { + server.enqueue(MockResponse()) + socksProxy = SocksProxy() + socksProxy!!.play() + val proxy = socksProxy!!.proxy() + client = client.newBuilder() + .proxy(proxy) + .build() + val call = client.newCall( + Request.Builder() + .url("http://" + SocksProxy.HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS + ":" + server.port) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + val expectedAddress = InetSocketAddress.createUnresolved( + SocksProxy.HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS, server.port + ) + val connectStart = listener.removeUpToEvent() + assertThat(connectStart.call).isSameAs(call) + assertThat(connectStart.inetSocketAddress).isEqualTo(expectedAddress) + assertThat(connectStart.proxy).isEqualTo(proxy) + val connectEnd = listener.removeUpToEvent() + assertThat(connectEnd.call).isSameAs(call) + assertThat(connectEnd.inetSocketAddress).isEqualTo(expectedAddress) + assertThat(connectEnd.protocol).isEqualTo(Protocol.HTTP_1_1) + } + + @Test + fun authenticatingTunnelProxyConnect() { + enableTlsWithTunnel() + server.enqueue( + MockResponse.Builder() + .inTunnel() + .code(407) + .addHeader("Proxy-Authenticate: Basic realm=\"localhost\"") + .addHeader("Connection: close") + .build() + ) + server.enqueue( + MockResponse.Builder() + .inTunnel() + .build() + ) + server.enqueue(MockResponse()) + client = client.newBuilder() + .proxy(server.toProxyAddress()) + .proxyAuthenticator(RecordingOkAuthenticator("password", "Basic")) + .build() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + listener.removeUpToEvent() + val connectEnd = listener.removeUpToEvent() + assertThat(connectEnd.protocol).isNull() + listener.removeUpToEvent() + listener.removeUpToEvent() + } + + @Test + fun successfulSecureConnect() { + enableTlsWithTunnel() + server.enqueue(MockResponse()) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + val secureStart = listener.removeUpToEvent() + assertThat(secureStart.call).isSameAs(call) + val secureEnd = listener.removeUpToEvent() + assertThat(secureEnd.call).isSameAs(call) + assertThat(secureEnd.handshake).isNotNull() + } + + @Test + fun failedSecureConnect() { + enableTlsWithTunnel() + server.enqueue( + MockResponse.Builder() + .socketPolicy(FailHandshake) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + try { + call.execute() + fail() + } catch (expected: IOException) { + } + val secureStart = listener.removeUpToEvent() + assertThat(secureStart.call).isSameAs(call) + val callFailed = listener.removeUpToEvent() + assertThat(callFailed.call).isSameAs(call) + assertThat(callFailed.ioe).isNotNull() + } + + @Test + fun secureConnectWithTunnel() { + enableTlsWithTunnel() + server.enqueue( + MockResponse.Builder() + .inTunnel() + .build() + ) + server.enqueue(MockResponse()) + client = client.newBuilder() + .proxy(server.toProxyAddress()) + .build() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + val secureStart = listener.removeUpToEvent() + assertThat(secureStart.call).isSameAs(call) + val secureEnd = listener.removeUpToEvent() + assertThat(secureEnd.call).isSameAs(call) + assertThat(secureEnd.handshake).isNotNull() + } + + @Test + fun multipleSecureConnectsForSingleCall() { + enableTlsWithTunnel() + server.enqueue( + MockResponse.Builder() + .socketPolicy(FailHandshake) + .build() + ) + server.enqueue(MockResponse()) + client = client.newBuilder() + .dns(DoubleInetAddressDns()) + .build() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + listener.removeUpToEvent() + listener.removeUpToEvent() + listener.removeUpToEvent() + listener.removeUpToEvent() + } + + @Test + fun noSecureConnectsOnPooledConnection() { + enableTlsWithTunnel() + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + client = client.newBuilder() + .dns(DoubleInetAddressDns()) + .build() + + // Seed the pool. + val call1 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response1 = call1.execute() + assertThat(response1.code).isEqualTo(200) + response1.body.close() + listener.clearAllEvents() + val call2 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response2 = call2.execute() + assertThat(response2.code).isEqualTo(200) + response2.body.close() + val recordedEvents: List = listener.recordedEventTypes() + assertThat(recordedEvents).doesNotContain("SecureConnectStart") + assertThat(recordedEvents).doesNotContain("SecureConnectEnd") + } + + @Test + fun successfulConnectionFound() { + server.enqueue(MockResponse()) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + response.body.close() + val connectionAcquired = listener.removeUpToEvent() + assertThat(connectionAcquired.call).isSameAs(call) + assertThat(connectionAcquired.connection).isNotNull() + } + + @Test + fun noConnectionFoundOnFollowUp() { + server.enqueue( + MockResponse.Builder() + .code(301) + .addHeader("Location", "/foo") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("ABC") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.body.string()).isEqualTo("ABC") + listener.removeUpToEvent() + val remainingEvents = listener.recordedEventTypes() + assertThat(remainingEvents).doesNotContain("ConnectionAcquired") + } + + @Test + fun pooledConnectionFound() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + + // Seed the pool. + val call1 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response1 = call1.execute() + assertThat(response1.code).isEqualTo(200) + response1.body.close() + val connectionAcquired1 = listener.removeUpToEvent() + listener.clearAllEvents() + val call2 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response2 = call2.execute() + assertThat(response2.code).isEqualTo(200) + response2.body.close() + val connectionAcquired2 = listener.removeUpToEvent() + assertThat(connectionAcquired2.connection).isSameAs( + connectionAcquired1.connection + ) + } + + @Test + fun multipleConnectionsFoundForSingleCall() { + server.enqueue( + MockResponse.Builder() + .code(301) + .addHeader("Location", "/foo") + .addHeader("Connection", "Close") + .build() + ) + server.enqueue( + MockResponse.Builder() + .body("ABC") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.body.string()).isEqualTo("ABC") + listener.removeUpToEvent() + listener.removeUpToEvent() + } + + @Test + fun responseBodyFailHttp1OverHttps() { + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_1_1) + responseBodyFail(Protocol.HTTP_1_1) + } + + @Test + fun responseBodyFailHttp2OverHttps() { + platform.assumeHttp2Support() + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1) + responseBodyFail(Protocol.HTTP_2) + } + + @Test + fun responseBodyFailHttp() { + responseBodyFail(Protocol.HTTP_1_1) + } + + private fun responseBodyFail(expectedProtocol: Protocol?) { + // Use a 2 MiB body so the disconnect won't happen until the client has read some data. + val responseBodySize = 2 * 1024 * 1024 // 2 MiB + server.enqueue( + MockResponse.Builder() + .body(Buffer().write(ByteArray(responseBodySize))) + .socketPolicy(DisconnectDuringResponseBody) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + if (expectedProtocol == Protocol.HTTP_2) { + // soft failure since client may not support depending on Platform + Assume.assumeThat(response, matchesProtocol(Protocol.HTTP_2)) + } + assertThat(response.protocol).isEqualTo(expectedProtocol) + try { + response.body.string() + fail() + } catch (expected: IOException) { + } + val callFailed = listener.removeUpToEvent() + assertThat(callFailed.ioe).isNotNull() + } + + @Test + fun emptyResponseBody() { + server.enqueue( + MockResponse.Builder() + .body("") + .bodyDelay(1, TimeUnit.SECONDS) + .socketPolicy(DisconnectDuringResponseBody) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + response.body.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun emptyResponseBodyConnectionClose() { + server.enqueue( + MockResponse.Builder() + .addHeader("Connection", "close") + .body("") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + response.body.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun responseBodyClosedClosedWithoutReadingAllData() { + server.enqueue( + MockResponse.Builder() + .body("abc") + .bodyDelay(1, TimeUnit.SECONDS) + .socketPolicy(DisconnectDuringResponseBody) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + response.body.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun requestBodyFailHttp1OverHttps() { + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_1_1) + requestBodyFail(Protocol.HTTP_1_1) + } + + @Test + fun requestBodyFailHttp2OverHttps() { + platform.assumeHttp2Support() + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1) + requestBodyFail(Protocol.HTTP_2) + } + + @Test + fun requestBodyFailHttp() { + requestBodyFail(null) + } + + private fun requestBodyFail(expectedProtocol: Protocol?) { + server.enqueue( + MockResponse.Builder() + .socketPolicy(DisconnectDuringRequestBody) + .build() + ) + val request = NonCompletingRequestBody() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(request) + .build() + ) + try { + call.execute() + fail() + } catch (expected: IOException) { + } + if (expectedProtocol != null) { + val connectionAcquired = listener.removeUpToEvent() + assertThat(connectionAcquired.connection.protocol()) + .isEqualTo(expectedProtocol) + } + val callFailed = listener.removeUpToEvent() + assertThat(callFailed.ioe).isNotNull() + assertThat(request.ioe).isNotNull() + } + + private inner class NonCompletingRequestBody : RequestBody() { + private val chunk: ByteArray? = ByteArray(1024 * 1024) + var ioe: IOException? = null + override fun contentType(): MediaType? { + return "text/plain".toMediaType() + } + + override fun contentLength(): Long { + return chunk!!.size * 8L + } + + override fun writeTo(sink: BufferedSink) { + try { + var i = 0 + while (i < contentLength()) { + sink.write(chunk!!) + sink.flush() + Thread.sleep(100) + i += chunk.size + } + } catch (e: IOException) { + ioe = e + } catch (e: InterruptedException) { + throw RuntimeException(e) + } + } + } + + @Test + fun requestBodyMultipleFailuresReportedOnlyOnce() { + val requestBody: RequestBody = object : RequestBody() { + override fun contentType() = "text/plain".toMediaType() + + override fun contentLength(): Long { + return 1024 * 1024 * 256 + } + + override fun writeTo(sink: BufferedSink) { + var failureCount = 0 + for (i in 0..1023) { + try { + sink.write(ByteArray(1024 * 256)) + sink.flush() + } catch (e: IOException) { + failureCount++ + if (failureCount == 3) throw e + } + } + } + } + server.enqueue( + MockResponse.Builder() + .socketPolicy(DisconnectDuringRequestBody) + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(requestBody) + .build() + ) + try { + call.execute() + fail() + } catch (expected: IOException) { + } + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", + "ProxySelectEnd", + "DnsStart", + "DnsEnd", + "ConnectStart", + "ConnectEnd", + "ConnectionAcquired", + "RequestHeadersStart", + "RequestHeadersEnd", + "RequestBodyStart", + "RequestFailed", + "ResponseFailed", + "ConnectionReleased", + "CallFailed" + ) + } + + @Test + fun requestBodySuccessHttp1OverHttps() { + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_1_1) + requestBodySuccess( + "Hello".toRequestBody("text/plain".toMediaType()), CoreMatchers.equalTo(5L), + CoreMatchers.equalTo(19L) + ) + } + + @Test + fun requestBodySuccessHttp2OverHttps() { + platform.assumeHttp2Support() + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1) + requestBodySuccess( + "Hello".toRequestBody("text/plain".toMediaType()), CoreMatchers.equalTo(5L), + CoreMatchers.equalTo(19L) + ) + } + + @Test + fun requestBodySuccessHttp() { + requestBodySuccess( + "Hello".toRequestBody("text/plain".toMediaType()), CoreMatchers.equalTo(5L), + CoreMatchers.equalTo(19L) + ) + } + + @Test + fun requestBodySuccessStreaming() { + val requestBody: RequestBody = object : RequestBody() { + override fun contentType() = "text/plain".toMediaType() + + override fun writeTo(sink: BufferedSink) { + sink.write(ByteArray(8192)) + sink.flush() + } + } + requestBodySuccess(requestBody, CoreMatchers.equalTo(8192L), CoreMatchers.equalTo(19L)) + } + + @Test + fun requestBodySuccessEmpty() { + requestBodySuccess( + "".toRequestBody("text/plain".toMediaType()), CoreMatchers.equalTo(0L), + CoreMatchers.equalTo(19L) + ) + } + + @Test + fun successfulCallEventSequenceWithListener() { + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + client = client.newBuilder() + .addNetworkInterceptor( + HttpLoggingInterceptor() + .setLevel(HttpLoggingInterceptor.Level.BODY) + ) + .build() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + assertThat(response.body.string()).isEqualTo("abc") + response.body.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + private fun requestBodySuccess( + body: RequestBody?, requestBodyBytes: Matcher?, + responseHeaderLength: Matcher? + ) { + server.enqueue( + MockResponse.Builder() + .code(200) + .body("World!") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(body!!) + .build() + ) + val response = call.execute() + assertThat(response.body.string()).isEqualTo("World!") + assertBytesReadWritten( + listener, CoreMatchers.any(Long::class.java), requestBodyBytes, responseHeaderLength, + CoreMatchers.equalTo(6L) + ) + } + + @Test + fun timeToFirstByteHttp1OverHttps() { + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_1_1) + timeToFirstByte() + } + + @Test + fun timeToFirstByteHttp2OverHttps() { + platform.assumeHttp2Support() + enableTlsWithTunnel() + server.protocols = Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1) + timeToFirstByte() + } + + /** + * Test to confirm that events are reported at the time they occur and no earlier and no later. + * This inserts a bunch of synthetic 250 ms delays into both client and server and confirms that + * the same delays make it back into the events. + * + * We've had bugs where we report an event when we request data rather than when the data actually + * arrives. https://github.com/square/okhttp/issues/5578 + */ + private fun timeToFirstByte() { + val applicationInterceptorDelay = 250L + val networkInterceptorDelay = 250L + val requestBodyDelay = 250L + val responseHeadersStartDelay = 250L + val responseBodyStartDelay = 250L + val responseBodyEndDelay = 250L + + // Warm up the client so the timing part of the test gets a pooled connection. + server.enqueue(MockResponse()) + val warmUpCall = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + warmUpCall.execute().use { warmUpResponse -> warmUpResponse.body.string() } + listener.clearAllEvents() + + // Create a client with artificial delays. + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain -> + try { + Thread.sleep(applicationInterceptorDelay) + return@Interceptor chain.proceed(chain.request()) + } catch (e: InterruptedException) { + throw InterruptedIOException() + } + }) + .addNetworkInterceptor(Interceptor { chain: Interceptor.Chain -> + try { + Thread.sleep(networkInterceptorDelay) + return@Interceptor chain.proceed(chain.request()) + } catch (e: InterruptedException) { + throw InterruptedIOException() + } + }) + .build() + + // Create a request body with artificial delays. + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .post(object : RequestBody() { + override fun contentType(): MediaType? { + return null + } + + override fun writeTo(sink: BufferedSink) { + try { + Thread.sleep(requestBodyDelay) + sink.writeUtf8("abc") + } catch (e: InterruptedException) { + throw InterruptedIOException() + } + } + }) + .build() + ) + + // Create a response with artificial delays. + server.enqueue( + MockResponse.Builder() + .headersDelay(responseHeadersStartDelay, TimeUnit.MILLISECONDS) + .bodyDelay(responseBodyStartDelay, TimeUnit.MILLISECONDS) + .throttleBody(5, responseBodyEndDelay, TimeUnit.MILLISECONDS) + .body("fghijk") + .build() + ) + call.execute().use { response -> + assertThat(response.body.string()).isEqualTo("fghijk") + } + + // Confirm the events occur when expected. + listener.takeEvent(CallStart::class.java, 0L) + listener.takeEvent(ConnectionAcquired::class.java, applicationInterceptorDelay) + listener.takeEvent(RequestHeadersStart::class.java, networkInterceptorDelay) + listener.takeEvent(RequestHeadersEnd::class.java, 0L) + listener.takeEvent(RequestBodyStart::class.java, 0L) + listener.takeEvent(RequestBodyEnd::class.java, requestBodyDelay) + listener.takeEvent(ResponseHeadersStart::class.java, responseHeadersStartDelay) + listener.takeEvent(ResponseHeadersEnd::class.java, 0L) + listener.takeEvent(ResponseBodyStart::class.java, responseBodyStartDelay) + listener.takeEvent(ResponseBodyEnd::class.java, responseBodyEndDelay) + listener.takeEvent(CallEvent.ConnectionReleased::class.java, 0L) + listener.takeEvent(CallEnd::class.java, 0L) + } + + private fun enableTlsWithTunnel() { + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .build() + server.useHttps(handshakeCertificates.sslSocketFactory()) + } + + @Test + fun redirectUsingSameConnectionEventSequence() { + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_MOVED_TEMP) + .addHeader("Location: /foo") + .build() + ) + server.enqueue(MockResponse()) + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + call.execute() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", + "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", + "CallEnd" + ) + } + + @Test + fun redirectUsingNewConnectionEventSequence() { + val otherServer = MockWebServer() + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_MOVED_TEMP) + .addHeader("Location: " + otherServer.url("/foo")) + .build() + ) + otherServer.enqueue(MockResponse()) + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + call.execute() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", + "ProxySelectEnd", + "DnsStart", + "DnsEnd", + "ConnectStart", + "ConnectEnd", + "ConnectionAcquired", + "RequestHeadersStart", + "RequestHeadersEnd", + "ResponseHeadersStart", + "ResponseHeadersEnd", + "ResponseBodyStart", + "ResponseBodyEnd", + "ConnectionReleased", + "ProxySelectStart", + "ProxySelectEnd", + "DnsStart", + "DnsEnd", + "ConnectStart", + "ConnectEnd", + "ConnectionAcquired", + "RequestHeadersStart", + "RequestHeadersEnd", + "ResponseHeadersStart", + "ResponseHeadersEnd", + "ResponseBodyStart", + "ResponseBodyEnd", + "ConnectionReleased", + "CallEnd" + ) + } + + @Test + fun applicationInterceptorProceedsMultipleTimes() { + server.enqueue(MockResponse.Builder().body("a").build()) + server.enqueue(MockResponse.Builder().body("b").build()) + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain? -> + chain!!.proceed(chain.request()) + .use { a -> assertThat(a.body.string()).isEqualTo("a") } + chain.proceed(chain.request()) + }) + .build() + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + val response = call.execute() + assertThat(response.body.string()).isEqualTo("b") + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "ResponseBodyStart", + "ResponseBodyEnd", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", + "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", + "CallEnd" + ) + assertThat(server.takeRequest().sequenceNumber).isEqualTo(0) + assertThat(server.takeRequest().sequenceNumber).isEqualTo(1) + } + + @Test + fun applicationInterceptorShortCircuit() { + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain? -> + Response.Builder() + .request(chain!!.request()) + .protocol(Protocol.HTTP_1_1) + .code(200) + .message("OK") + .body("a".toResponseBody(null)) + .build() + }) + .build() + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + val response = call.execute() + assertThat(response.body.string()).isEqualTo("a") + assertThat(listener.recordedEventTypes()) + .containsExactly("CallStart", "CallEnd") + } + + /** Response headers start, then the entire request body, then response headers end. */ + @Test + fun expectContinueStartsResponseHeadersEarly() { + server.enqueue( + MockResponse.Builder() + .add100Continue() + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .header("Expect", "100-continue") + .post("abc".toRequestBody("text/plain".toMediaType())) + .build() + val call = client.newCall(request) + call.execute() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", "ConnectStart", + "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", + "ResponseHeadersStart", "RequestBodyStart", "RequestBodyEnd", "ResponseHeadersEnd", + "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun timeToFirstByteGapBetweenResponseHeaderStartAndEnd() { + val responseHeadersStartDelay = 250L + server.enqueue( + MockResponse.Builder() + .add100Continue() + .headersDelay(responseHeadersStartDelay, TimeUnit.MILLISECONDS) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .header("Expect", "100-continue") + .post("abc".toRequestBody("text/plain".toMediaType())) + .build() + val call = client.newCall(request) + call.execute() + .use { response -> assertThat(response.body.string()).isEqualTo("") } + listener.removeUpToEvent() + listener.takeEvent(RequestBodyStart::class.java, 0L) + listener.takeEvent(RequestBodyEnd::class.java, 0L) + listener.takeEvent(ResponseHeadersEnd::class.java, responseHeadersStartDelay) + } + + @Test + fun cacheMiss() { + enableCache() + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(200) + assertThat(response.body.string()).isEqualTo("abc") + response.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", "CacheMiss", + "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", + "ConnectStart", "ConnectEnd", "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", + "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun conditionalCache() { + enableCache() + server.enqueue( + MockResponse.Builder() + .addHeader("ETag", "v1") + .body("abc") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_MODIFIED) + .build() + ) + var call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + var response = call.execute() + assertThat(response.code).isEqualTo(200) + response.close() + listener.clearAllEvents() + call = call.clone() + response = call.execute() + assertThat(response.code).isEqualTo(200) + assertThat(response.body.string()).isEqualTo("abc") + response.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", "CacheConditionalHit", + "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", + "ResponseBodyStart", "ResponseBodyEnd", "CacheHit", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun conditionalCacheMiss() { + enableCache() + server.enqueue( + MockResponse.Builder() + .addHeader("ETag: v1") + .body("abc") + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_OK) + .addHeader("ETag: v2") + .body("abd") + .build() + ) + var call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + var response = call.execute() + assertThat(response.code).isEqualTo(200) + response.close() + listener.clearAllEvents() + call = call.clone() + response = call.execute() + assertThat(response.code).isEqualTo(200) + assertThat(response.body.string()).isEqualTo("abd") + response.close() + assertThat(listener.recordedEventTypes()).containsExactly( + "CallStart", "CacheConditionalHit", + "ConnectionAcquired", "RequestHeadersStart", + "RequestHeadersEnd", "ResponseHeadersStart", "ResponseHeadersEnd", "CacheMiss", + "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "CallEnd" + ) + } + + @Test + fun satisfactionFailure() { + enableCache() + val call = client.newCall( + Request.Builder() + .url(server.url("/")) + .cacheControl(CacheControl.FORCE_CACHE) + .build() + ) + val response = call.execute() + assertThat(response.code).isEqualTo(504) + response.close() + assertThat(listener.recordedEventTypes()) + .containsExactly("CallStart", "SatisfactionFailure", "CallEnd") + } + + @Test + fun cacheHit() { + enableCache() + server.enqueue( + MockResponse.Builder() + .body("abc") + .addHeader("cache-control: public, max-age=300") + .build() + ) + var call = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + var response = call.execute() + assertThat(response.code).isEqualTo(200) + assertThat(response.body.string()).isEqualTo("abc") + response.close() + listener.clearAllEvents() + call = call.clone() + response = call.execute() + assertThat(response.code).isEqualTo(200) + assertThat(response.body.string()).isEqualTo("abc") + response.close() + assertThat(listener.recordedEventTypes()) + .containsExactly("CallStart", "CacheHit", "CallEnd") + } + + private fun enableCache(): Cache? { + cache = makeCache() + client = client.newBuilder().cache(cache).build() + return cache + } + + private fun makeCache(): Cache { + val cacheDir = File.createTempFile("cache-", ".dir") + cacheDir.delete() + return Cache(cacheDir, (1024 * 1024).toLong()) + } + + companion object { + val anyResponse = CoreMatchers.any(Response::class.java) + } +} diff --git a/okhttp/src/test/java/okhttp3/InterceptorTest.java b/okhttp/src/test/java/okhttp3/InterceptorTest.java deleted file mode 100644 index 1b0a2fee1ce2..000000000000 --- a/okhttp/src/test/java/okhttp3/InterceptorTest.java +++ /dev/null @@ -1,918 +0,0 @@ -/* - * Copyright (C) 2014 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3; - -import java.io.IOException; -import java.net.SocketTimeoutException; -import java.time.Duration; -import java.util.Locale; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.SynchronousQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import kotlin.Unit; -import mockwebserver3.MockResponse; -import mockwebserver3.MockWebServer; -import mockwebserver3.RecordedRequest; -import mockwebserver3.SocketPolicy; -import mockwebserver3.SocketPolicy.DisconnectAtEnd; -import okio.Buffer; -import okio.BufferedSink; -import okio.ForwardingSink; -import okio.ForwardingSource; -import okio.GzipSink; -import okio.Okio; -import okio.Sink; -import okio.Source; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; -import static okhttp3.TestUtil.assertSuppressed; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; - -@Tag("Slow") -public final class InterceptorTest { - @RegisterExtension public final OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - - private MockWebServer server; - private OkHttpClient client = clientTestRule.newClient(); - private final RecordingCallback callback = new RecordingCallback(); - - @BeforeEach - public void setUp(MockWebServer server) throws Exception { - this.server = server; - } - - @Test public void applicationInterceptorsCanShortCircuitResponses() throws Exception { - server.shutdown(); // Accept no connections. - - Request request = new Request.Builder() - .url("https://localhost:1/") - .build(); - - Response interceptorResponse = new Response.Builder() - .request(request) - .protocol(Protocol.HTTP_1_1) - .code(200) - .message("Intercepted!") - .body(ResponseBody.create("abc", MediaType.get("text/plain; charset=utf-8"))) - .build(); - - client = client.newBuilder() - .addInterceptor(chain -> interceptorResponse) - .build(); - - Response response = client.newCall(request).execute(); - assertThat(response).isSameAs(interceptorResponse); - } - - @Test public void networkInterceptorsCannotShortCircuitResponses() throws Exception { - server.enqueue(new MockResponse.Builder() - .code(500) - .build()); - - Interceptor interceptor = chain -> new Response.Builder() - .request(chain.request()) - .protocol(Protocol.HTTP_1_1) - .code(200) - .message("Intercepted!") - .body(ResponseBody.create("abc", MediaType.get("text/plain; charset=utf-8"))) - .build(); - client = client.newBuilder() - .addNetworkInterceptor(interceptor) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - try { - client.newCall(request).execute(); - fail(); - } catch (IllegalStateException expected) { - assertThat(expected.getMessage()).isEqualTo( - ("network interceptor " + interceptor + " must call proceed() exactly once")); - } - } - - @Test public void networkInterceptorsCannotCallProceedMultipleTimes() throws Exception { - server.enqueue(new MockResponse()); - server.enqueue(new MockResponse()); - - Interceptor interceptor = chain -> { - chain.proceed(chain.request()); - return chain.proceed(chain.request()); - }; - client = client.newBuilder() - .addNetworkInterceptor(interceptor) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - try { - client.newCall(request).execute(); - fail(); - } catch (IllegalStateException expected) { - assertThat(expected.getMessage()).isEqualTo( - ("network interceptor " + interceptor + " must call proceed() exactly once")); - } - } - - @Test public void networkInterceptorsCannotChangeServerAddress() throws Exception { - server.enqueue(new MockResponse.Builder() - .code(500) - .build()); - - Interceptor interceptor = chain -> { - Address address = chain.connection().route().address(); - String sameHost = address.url().host(); - int differentPort = address.url().port() + 1; - return chain.proceed(chain.request().newBuilder() - .url("http://" + sameHost + ":" + differentPort + "/") - .build()); - }; - client = client.newBuilder() - .addNetworkInterceptor(interceptor) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - try { - client.newCall(request).execute(); - fail(); - } catch (IllegalStateException expected) { - assertThat(expected.getMessage()).isEqualTo( - ("network interceptor " + interceptor + " must retain the same host and port")); - } - } - - @Test public void networkInterceptorsHaveConnectionAccess() throws Exception { - server.enqueue(new MockResponse()); - - Interceptor interceptor = chain -> { - Connection connection = chain.connection(); - assertThat(connection).isNotNull(); - return chain.proceed(chain.request()); - }; - client = client.newBuilder() - .addNetworkInterceptor(interceptor) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - client.newCall(request).execute(); - } - - @Test public void networkInterceptorsObserveNetworkHeaders() throws Exception { - server.enqueue(new MockResponse.Builder() - .body(gzip("abcabcabc")) - .addHeader("Content-Encoding: gzip") - .build()); - - Interceptor interceptor = chain -> { - // The network request has everything: User-Agent, Host, Accept-Encoding. - Request networkRequest = chain.request(); - assertThat(networkRequest.header("User-Agent")).isNotNull(); - assertThat(networkRequest.header("Host")).isEqualTo( - (server.getHostName() + ":" + server.getPort())); - assertThat(networkRequest.header("Accept-Encoding")).isNotNull(); - - // The network response also has everything, including the raw gzipped content. - Response networkResponse = chain.proceed(networkRequest); - assertThat(networkResponse.header("Content-Encoding")).isEqualTo("gzip"); - return networkResponse; - }; - client = client.newBuilder() - .addNetworkInterceptor(interceptor) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - // No extra headers in the application's request. - assertThat(request.header("User-Agent")).isNull(); - assertThat(request.header("Host")).isNull(); - assertThat(request.header("Accept-Encoding")).isNull(); - - // No extra headers in the application's response. - Response response = client.newCall(request).execute(); - assertThat(request.header("Content-Encoding")).isNull(); - assertThat(response.body().string()).isEqualTo("abcabcabc"); - } - - @Test public void networkInterceptorsCanChangeRequestMethodFromGetToPost() throws Exception { - server.enqueue(new MockResponse()); - - Interceptor interceptor = chain -> { - Request originalRequest = chain.request(); - MediaType mediaType = MediaType.get("text/plain"); - RequestBody body = RequestBody.create("abc", mediaType); - return chain.proceed(originalRequest.newBuilder() - .method("POST", body) - .header("Content-Type", mediaType.toString()) - .header("Content-Length", Long.toString(body.contentLength())) - .build()); - }; - client = client.newBuilder() - .addNetworkInterceptor(interceptor) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .get() - .build(); - - client.newCall(request).execute(); - - RecordedRequest recordedRequest = server.takeRequest(); - assertThat(recordedRequest.getMethod()).isEqualTo("POST"); - assertThat(recordedRequest.getBody().readUtf8()).isEqualTo("abc"); - } - - @Test public void applicationInterceptorsRewriteRequestToServer() throws Exception { - rewriteRequestToServer(false); - } - - @Test public void networkInterceptorsRewriteRequestToServer() throws Exception { - rewriteRequestToServer(true); - } - - private void rewriteRequestToServer(boolean network) throws Exception { - server.enqueue(new MockResponse()); - - addInterceptor(network, chain -> { - Request originalRequest = chain.request(); - return chain.proceed(originalRequest.newBuilder() - .method("POST", uppercase(originalRequest.body())) - .addHeader("OkHttp-Intercepted", "yep") - .build()); - }); - - Request request = new Request.Builder() - .url(server.url("/")) - .addHeader("Original-Header", "foo") - .method("PUT", RequestBody.create("abc", MediaType.get("text/plain"))) - .build(); - - client.newCall(request).execute(); - - RecordedRequest recordedRequest = server.takeRequest(); - assertThat(recordedRequest.getBody().readUtf8()).isEqualTo("ABC"); - assertThat(recordedRequest.getHeaders().get("Original-Header")).isEqualTo("foo"); - assertThat(recordedRequest.getHeaders().get("OkHttp-Intercepted")).isEqualTo("yep"); - assertThat(recordedRequest.getMethod()).isEqualTo("POST"); - } - - @Test public void applicationInterceptorsRewriteResponseFromServer() throws Exception { - rewriteResponseFromServer(false); - } - - @Test public void networkInterceptorsRewriteResponseFromServer() throws Exception { - rewriteResponseFromServer(true); - } - - private void rewriteResponseFromServer(boolean network) throws Exception { - server.enqueue(new MockResponse.Builder() - .addHeader("Original-Header: foo") - .body("abc") - .build()); - - addInterceptor(network, chain -> { - Response originalResponse = chain.proceed(chain.request()); - return originalResponse.newBuilder() - .body(uppercase(originalResponse.body())) - .addHeader("OkHttp-Intercepted", "yep") - .build(); - }); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - Response response = client.newCall(request).execute(); - assertThat(response.body().string()).isEqualTo("ABC"); - assertThat(response.header("OkHttp-Intercepted")).isEqualTo("yep"); - assertThat(response.header("Original-Header")).isEqualTo("foo"); - } - - @Test public void multipleApplicationInterceptors() throws Exception { - multipleInterceptors(false); - } - - @Test public void multipleNetworkInterceptors() throws Exception { - multipleInterceptors(true); - } - - private void multipleInterceptors(boolean network) throws Exception { - server.enqueue(new MockResponse()); - - addInterceptor(network, chain -> { - Request originalRequest = chain.request(); - Response originalResponse = chain.proceed(originalRequest.newBuilder() - .addHeader("Request-Interceptor", "Android") // 1. Added first. - .build()); - return originalResponse.newBuilder() - .addHeader("Response-Interceptor", "Donut") // 4. Added last. - .build(); - }); - addInterceptor(network, chain -> { - Request originalRequest = chain.request(); - Response originalResponse = chain.proceed(originalRequest.newBuilder() - .addHeader("Request-Interceptor", "Bob") // 2. Added second. - .build()); - return originalResponse.newBuilder() - .addHeader("Response-Interceptor", "Cupcake") // 3. Added third. - .build(); - }); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - Response response = client.newCall(request).execute(); - assertThat(response.headers("Response-Interceptor")).containsExactly("Cupcake", "Donut"); - - RecordedRequest recordedRequest = server.takeRequest(); - assertThat(recordedRequest.getHeaders().values("Request-Interceptor")) - .containsExactly("Android", "Bob"); - } - - @Test public void asyncApplicationInterceptors() throws Exception { - asyncInterceptors(false); - } - - @Test public void asyncNetworkInterceptors() throws Exception { - asyncInterceptors(true); - } - - private void asyncInterceptors(boolean network) throws Exception { - server.enqueue(new MockResponse()); - - addInterceptor(network, chain -> { - Response originalResponse = chain.proceed(chain.request()); - return originalResponse.newBuilder() - .addHeader("OkHttp-Intercepted", "yep") - .build(); - }); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - client.newCall(request).enqueue(callback); - - callback.await(request.url()) - .assertCode(200) - .assertHeader("OkHttp-Intercepted", "yep"); - } - - @Test public void applicationInterceptorsCanMakeMultipleRequestsToServer() throws Exception { - server.enqueue(new MockResponse.Builder().body("a").build()); - server.enqueue(new MockResponse.Builder().body("b").build()); - - client = client.newBuilder() - .addInterceptor(chain -> { - Response response1 = chain.proceed(chain.request()); - response1.body().close(); - return chain.proceed(chain.request()); - }) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - Response response = client.newCall(request).execute(); - assertThat("b").isEqualTo(response.body().string()); - } - - /** Make sure interceptors can interact with the OkHttp client. */ - @Test public void interceptorMakesAnUnrelatedRequest() throws Exception { - server.enqueue(new MockResponse.Builder().body("a").build()); // Fetched by interceptor. - server.enqueue(new MockResponse.Builder().body("b").build()); // Fetched directly. - - client = client.newBuilder() - .addInterceptor(chain -> { - if (chain.request().url().encodedPath().equals("/b")) { - Request requestA = new Request.Builder() - .url(server.url("/a")) - .build(); - Response responseA = client.newCall(requestA).execute(); - assertThat(responseA.body().string()).isEqualTo("a"); - } - - return chain.proceed(chain.request()); - }) - .build(); - - Request requestB = new Request.Builder() - .url(server.url("/b")) - .build(); - Response responseB = client.newCall(requestB).execute(); - assertThat(responseB.body().string()).isEqualTo("b"); - } - - /** Make sure interceptors can interact with the OkHttp client asynchronously. */ - @Test public void interceptorMakesAnUnrelatedAsyncRequest() throws Exception { - server.enqueue(new MockResponse.Builder().body("a").build()); // Fetched by interceptor. - server.enqueue(new MockResponse.Builder().body("b").build()); // Fetched directly. - - client = client.newBuilder() - .addInterceptor(chain -> { - if (chain.request().url().encodedPath().equals("/b")) { - Request requestA = new Request.Builder() - .url(server.url("/a")) - .build(); - - try { - RecordingCallback callbackA = new RecordingCallback(); - client.newCall(requestA).enqueue(callbackA); - callbackA.await(requestA.url()).assertBody("a"); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - return chain.proceed(chain.request()); - }) - .build(); - - Request requestB = new Request.Builder() - .url(server.url("/b")) - .build(); - RecordingCallback callbackB = new RecordingCallback(); - client.newCall(requestB).enqueue(callbackB); - callbackB.await(requestB.url()).assertBody("b"); - } - - @Test public void applicationInterceptorThrowsRuntimeExceptionSynchronous() throws Exception { - interceptorThrowsRuntimeExceptionSynchronous(false); - } - - @Test public void networkInterceptorThrowsRuntimeExceptionSynchronous() throws Exception { - interceptorThrowsRuntimeExceptionSynchronous(true); - } - - /** - * When an interceptor throws an unexpected exception, synchronous callers can catch it and deal - * with it. - */ - private void interceptorThrowsRuntimeExceptionSynchronous(boolean network) throws Exception { - addInterceptor(network, chain -> { throw new RuntimeException("boom!"); }); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - try { - client.newCall(request).execute(); - fail(); - } catch (RuntimeException expected) { - assertThat(expected.getMessage()).isEqualTo("boom!"); - } - } - - @Test public void networkInterceptorModifiedRequestIsReturned() throws IOException { - server.enqueue(new MockResponse()); - - Interceptor modifyHeaderInterceptor = chain -> { - Request modifiedRequest = chain.request() - .newBuilder() - .header("User-Agent", "intercepted request") - .build(); - return chain.proceed(modifiedRequest); - }; - - client = client.newBuilder() - .addNetworkInterceptor(modifyHeaderInterceptor) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .header("User-Agent", "user request") - .build(); - - Response response = client.newCall(request).execute(); - assertThat(response.request().header("User-Agent")).isNotNull(); - assertThat(response.request().header("User-Agent")).isEqualTo("user request"); - assertThat(response.networkResponse().request().header("User-Agent")).isEqualTo( - "intercepted request"); - } - - @Test public void applicationInterceptorThrowsRuntimeExceptionAsynchronous() throws Exception { - interceptorThrowsRuntimeExceptionAsynchronous(false); - } - - @Test public void networkInterceptorThrowsRuntimeExceptionAsynchronous() throws Exception { - interceptorThrowsRuntimeExceptionAsynchronous(true); - } - - /** - * When an interceptor throws an unexpected exception, asynchronous calls are canceled. The - * exception goes to the uncaught exception handler. - */ - private void interceptorThrowsRuntimeExceptionAsynchronous(boolean network) throws Exception { - RuntimeException boom = new RuntimeException("boom!"); - addInterceptor(network, chain -> { throw boom; }); - - ExceptionCatchingExecutor executor = new ExceptionCatchingExecutor(); - client = client.newBuilder() - .dispatcher(new Dispatcher(executor)) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - Call call = client.newCall(request); - call.enqueue(callback); - RecordedResponse recordedResponse = callback.await(server.url("/")); - assertThat(recordedResponse.failure) - .hasMessage("canceled due to java.lang.RuntimeException: boom!"); - assertSuppressed(recordedResponse.failure, throwables -> { - assertThat(throwables).contains(boom); - return Unit.INSTANCE; - }); - assertThat(call.isCanceled()).isTrue(); - - assertThat(executor.takeException()).isEqualTo(boom); - } - - @Test public void applicationInterceptorReturnsNull() throws Exception { - server.enqueue(new MockResponse()); - - Interceptor interceptor = chain -> { - chain.proceed(chain.request()); - return null; - }; - client = client.newBuilder() - .addInterceptor(interceptor) - .build(); - - ExceptionCatchingExecutor executor = new ExceptionCatchingExecutor(); - client = client.newBuilder() - .dispatcher(new Dispatcher(executor)) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - try { - client.newCall(request).execute(); - fail(); - } catch (NullPointerException expected) { - assertThat(expected.getMessage()).isEqualTo( - ("interceptor " + interceptor + " returned null")); - } - } - - @Test public void networkInterceptorReturnsNull() throws Exception { - server.enqueue(new MockResponse()); - - Interceptor interceptor = chain -> { - chain.proceed(chain.request()); - return null; - }; - client = client.newBuilder() - .addNetworkInterceptor(interceptor) - .build(); - - ExceptionCatchingExecutor executor = new ExceptionCatchingExecutor(); - client = client.newBuilder() - .dispatcher(new Dispatcher(executor)) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - try { - client.newCall(request).execute(); - fail(); - } catch (NullPointerException expected) { - assertThat(expected.getMessage()).isEqualTo( - ("interceptor " + interceptor + " returned null")); - } - } - - @Test public void networkInterceptorReturnsConnectionOnEmptyBody() throws Exception { - server.enqueue(new MockResponse.Builder() - .socketPolicy(DisconnectAtEnd.INSTANCE) - .addHeader("Connection", "Close") - .build()); - - Interceptor interceptor = chain -> { - Response response = chain.proceed(chain.request()); - assertThat(chain.connection()).isNotNull(); - return response; - }; - - client = client.newBuilder() - .addNetworkInterceptor(interceptor) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - Response response = client.newCall(request).execute(); - response.body().close(); - } - - @Test public void connectTimeout() throws Exception { - Interceptor interceptor1 = chainA -> { - assertThat(chainA.connectTimeoutMillis()).isEqualTo(5000); - - Interceptor.Chain chainB = chainA.withConnectTimeout(100, TimeUnit.MILLISECONDS); - assertThat(chainB.connectTimeoutMillis()).isEqualTo(100); - - return chainB.proceed(chainA.request()); - }; - - Interceptor interceptor2 = chain -> { - assertThat(chain.connectTimeoutMillis()).isEqualTo(100); - return chain.proceed(chain.request()); - }; - - client = client.newBuilder() - .connectTimeout(Duration.ofSeconds(5)) - .addInterceptor(interceptor1) - .addInterceptor(interceptor2) - .build(); - - Request request1 = - new Request.Builder() - .url("http://" + TestUtil.UNREACHABLE_ADDRESS_IPV4) - .build(); - Call call = client.newCall(request1); - - long startNanos = System.nanoTime(); - try { - call.execute(); - fail(); - } catch (SocketTimeoutException expected) { - } - long elapsedNanos = System.nanoTime() - startNanos; - - assertTrue(elapsedNanos < TimeUnit.SECONDS.toNanos(5), - "Timeout should have taken ~100ms but was " + (elapsedNanos / 1e6) + " ms"); - } - - @Test public void chainWithReadTimeout() throws Exception { - Interceptor interceptor1 = chainA -> { - assertThat(chainA.readTimeoutMillis()).isEqualTo(5000); - - Interceptor.Chain chainB = chainA.withReadTimeout(100, TimeUnit.MILLISECONDS); - assertThat(chainB.readTimeoutMillis()).isEqualTo(100); - - return chainB.proceed(chainA.request()); - }; - - Interceptor interceptor2 = chain -> { - assertThat(chain.readTimeoutMillis()).isEqualTo(100); - return chain.proceed(chain.request()); - }; - - client = client.newBuilder() - .readTimeout(Duration.ofSeconds(5)) - .addInterceptor(interceptor1) - .addInterceptor(interceptor2) - .build(); - - server.enqueue(new MockResponse.Builder() - .body("abc") - .throttleBody(1, 1, TimeUnit.SECONDS) - .build()); - - Request request1 = new Request.Builder() - .url(server.url("/")) - .build(); - Call call = client.newCall(request1); - Response response = call.execute(); - ResponseBody body = response.body(); - try { - body.string(); - fail(); - } catch (SocketTimeoutException expected) { - } - } - - @Test public void networkInterceptorCannotChangeReadTimeout() throws Exception { - addInterceptor(true, chain -> - chain.withReadTimeout(100, TimeUnit.MILLISECONDS).proceed(chain.request())); - - Request request1 = new Request.Builder().url(server.url("/")).build(); - Call call = client.newCall(request1); - try { - call.execute(); - fail(); - } catch (IllegalStateException expected) { - assertThat(expected.getMessage()).isEqualTo("Timeouts can't be adjusted in a network interceptor"); - } - } - - @Test public void networkInterceptorCannotChangeWriteTimeout() throws Exception { - addInterceptor(true, chain -> - chain.withWriteTimeout(100, TimeUnit.MILLISECONDS).proceed(chain.request())); - - Request request1 = new Request.Builder().url(server.url("/")).build(); - Call call = client.newCall(request1); - try { - call.execute(); - fail(); - } catch (IllegalStateException expected) { - assertThat(expected.getMessage()).isEqualTo("Timeouts can't be adjusted in a network interceptor"); - } - } - - @Test public void networkInterceptorCannotChangeConnectTimeout() throws Exception { - addInterceptor(true, chain -> - chain.withConnectTimeout(100, TimeUnit.MILLISECONDS).proceed(chain.request())); - - Request request1 = new Request.Builder().url(server.url("/")).build(); - Call call = client.newCall(request1); - try { - call.execute(); - fail(); - } catch (IllegalStateException expected) { - assertThat(expected.getMessage()).isEqualTo("Timeouts can't be adjusted in a network interceptor"); - } - } - - @Test public void chainWithWriteTimeout() throws Exception { - Interceptor interceptor1 = chainA -> { - assertThat(chainA.writeTimeoutMillis()).isEqualTo(5000); - - Interceptor.Chain chainB = chainA.withWriteTimeout(100, TimeUnit.MILLISECONDS); - assertThat(chainB.writeTimeoutMillis()).isEqualTo(100); - - return chainB.proceed(chainA.request()); - }; - - Interceptor interceptor2 = chain -> { - assertThat(chain.writeTimeoutMillis()).isEqualTo(100); - return chain.proceed(chain.request()); - }; - - client = client.newBuilder() - .writeTimeout(Duration.ofSeconds(5)) - .addInterceptor(interceptor1) - .addInterceptor(interceptor2) - .build(); - - server.enqueue(new MockResponse.Builder() - .body("abc") - .throttleBody(1, 1, TimeUnit.SECONDS) - .build()); - - byte[] data = new byte[2 * 1024 * 1024]; // 2 MiB. - Request request1 = new Request.Builder() - .url(server.url("/")) - .post(RequestBody.create(data, MediaType.get("text/plain"))) - .build(); - Call call = client.newCall(request1); - - try { - call.execute(); // we want this call to throw a SocketTimeoutException - fail(); - } catch (SocketTimeoutException expected) { - } - } - - @Test public void chainCanCancelCall() throws Exception { - AtomicReference callRef = new AtomicReference<>(); - - Interceptor interceptor = chain -> { - Call call = chain.call(); - callRef.set(call); - - assertThat(call.isCanceled()).isFalse(); - call.cancel(); - assertThat(call.isCanceled()).isTrue(); - - return chain.proceed(chain.request()); - }; - - client = client.newBuilder() - .addInterceptor(interceptor) - .build(); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - Call call = client.newCall(request); - - try { - call.execute(); - fail(); - } catch (IOException expected) { - } - - assertThat(callRef.get()).isSameAs(call); - } - - private RequestBody uppercase(RequestBody original) { - return new RequestBody() { - @Override public MediaType contentType() { - return original.contentType(); - } - - @Override public long contentLength() throws IOException { - return original.contentLength(); - } - - @Override public void writeTo(BufferedSink sink) throws IOException { - Sink uppercase = uppercase(sink); - BufferedSink bufferedSink = Okio.buffer(uppercase); - original.writeTo(bufferedSink); - bufferedSink.emit(); - } - }; - } - - private Sink uppercase(BufferedSink original) { - return new ForwardingSink(original) { - @Override public void write(Buffer source, long byteCount) throws IOException { - original.writeUtf8(source.readUtf8(byteCount).toUpperCase(Locale.US)); - } - }; - } - - static ResponseBody uppercase(ResponseBody original) throws IOException { - return ResponseBody.create(Okio.buffer(uppercase(original.source())), - original.contentType(), original.contentLength()); - } - - private static Source uppercase(Source original) { - return new ForwardingSource(original) { - @Override public long read(Buffer sink, long byteCount) throws IOException { - Buffer mixedCase = new Buffer(); - long count = original.read(mixedCase, byteCount); - sink.writeUtf8(mixedCase.readUtf8().toUpperCase(Locale.US)); - return count; - } - }; - } - - private Buffer gzip(String data) throws IOException { - Buffer result = new Buffer(); - BufferedSink sink = Okio.buffer(new GzipSink(result)); - sink.writeUtf8(data); - sink.close(); - return result; - } - - private void addInterceptor(boolean network, Interceptor interceptor) { - OkHttpClient.Builder builder = client.newBuilder(); - if (network) { - builder.addNetworkInterceptor(interceptor); - } else { - builder.addInterceptor(interceptor); - } - client = builder.build(); - } - - /** Catches exceptions that are otherwise headed for the uncaught exception handler. */ - private static class ExceptionCatchingExecutor extends ThreadPoolExecutor { - private final BlockingQueue exceptions = new LinkedBlockingQueue<>(); - - public ExceptionCatchingExecutor() { - super(1, 1, 0, TimeUnit.SECONDS, new SynchronousQueue<>()); - } - - @Override public void execute(Runnable runnable) { - super.execute(() -> { - try { - runnable.run(); - } catch (Exception e) { - exceptions.add(e); - } - }); - } - - public Exception takeException() throws Exception { - return exceptions.take(); - } - } -} diff --git a/okhttp/src/test/java/okhttp3/InterceptorTest.kt b/okhttp/src/test/java/okhttp3/InterceptorTest.kt new file mode 100644 index 000000000000..ca4260b43aed --- /dev/null +++ b/okhttp/src/test/java/okhttp3/InterceptorTest.kt @@ -0,0 +1,854 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.io.IOException +import java.net.SocketTimeoutException +import java.time.Duration +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.SynchronousQueue +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.SocketPolicy.DisconnectAtEnd +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.RequestBody.Companion.toRequestBody +import okhttp3.ResponseBody.Companion.toResponseBody +import okhttp3.TestUtil.assertSuppressed +import okio.Buffer +import okio.BufferedSink +import okio.ForwardingSink +import okio.ForwardingSource +import okio.GzipSink +import okio.Sink +import okio.Source +import okio.buffer +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +@Tag("Slow") +class InterceptorTest { + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + private lateinit var server: MockWebServer + private var client = clientTestRule.newClient() + private val callback = RecordingCallback() + + @BeforeEach + fun setUp(server: MockWebServer) { + this.server = server + } + + @Test + fun applicationInterceptorsCanShortCircuitResponses() { + server.shutdown() // Accept no connections. + val request = Request.Builder() + .url("https://localhost:1/") + .build() + val interceptorResponse = Response.Builder() + .request(request) + .protocol(Protocol.HTTP_1_1) + .code(200) + .message("Intercepted!") + .body("abc".toResponseBody("text/plain; charset=utf-8".toMediaType())) + .build() + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain? -> interceptorResponse }) + .build() + val response = client.newCall(request).execute() + assertThat(response).isSameAs(interceptorResponse) + } + + @Test + fun networkInterceptorsCannotShortCircuitResponses() { + server.enqueue( + MockResponse.Builder() + .code(500) + .build() + ) + val interceptor = Interceptor { chain: Interceptor.Chain -> + Response.Builder() + .request(chain.request()) + .protocol(Protocol.HTTP_1_1) + .code(200) + .message("Intercepted!") + .body("abc".toResponseBody("text/plain; charset=utf-8".toMediaType())) + .build() + } + client = client.newBuilder() + .addNetworkInterceptor(interceptor) + .build() + val request = Request.Builder() + .url(server.url("/")) + .build() + try { + client.newCall(request).execute() + fail() + } catch (expected: IllegalStateException) { + assertThat(expected.message).isEqualTo( + "network interceptor $interceptor must call proceed() exactly once" + ) + } + } + + @Test + fun networkInterceptorsCannotCallProceedMultipleTimes() { + server.enqueue(MockResponse()) + server.enqueue(MockResponse()) + val interceptor = Interceptor { chain: Interceptor.Chain -> + chain.proceed(chain.request()) + chain.proceed(chain.request()) + } + client = client.newBuilder() + .addNetworkInterceptor(interceptor) + .build() + val request = Request.Builder() + .url(server.url("/")) + .build() + try { + client.newCall(request).execute() + fail() + } catch (expected: IllegalStateException) { + assertThat(expected.message).isEqualTo( + "network interceptor $interceptor must call proceed() exactly once" + ) + } + } + + @Test + fun networkInterceptorsCannotChangeServerAddress() { + server.enqueue( + MockResponse.Builder() + .code(500) + .build() + ) + val interceptor = Interceptor { chain: Interceptor.Chain -> + val address = chain.connection()!!.route().address + val sameHost = address.url.host + val differentPort = address.url.port + 1 + chain.proceed( + chain.request().newBuilder() + .url("http://$sameHost:$differentPort/") + .build() + ) + } + client = client.newBuilder() + .addNetworkInterceptor(interceptor) + .build() + val request = Request.Builder() + .url(server.url("/")) + .build() + try { + client.newCall(request).execute() + fail() + } catch (expected: IllegalStateException) { + assertThat(expected.message).isEqualTo( + "network interceptor $interceptor must retain the same host and port" + ) + } + } + + @Test + fun networkInterceptorsHaveConnectionAccess() { + server.enqueue(MockResponse()) + val interceptor = Interceptor { chain: Interceptor.Chain -> + val connection = chain.connection() + assertThat(connection).isNotNull() + chain.proceed(chain.request()) + } + client = client.newBuilder() + .addNetworkInterceptor(interceptor) + .build() + val request = Request.Builder() + .url(server.url("/")) + .build() + client.newCall(request).execute() + } + + @Test + fun networkInterceptorsObserveNetworkHeaders() { + server.enqueue( + MockResponse.Builder() + .body(gzip("abcabcabc")) + .addHeader("Content-Encoding: gzip") + .build() + ) + val interceptor = Interceptor { chain: Interceptor.Chain -> + // The network request has everything: User-Agent, Host, Accept-Encoding. + val networkRequest = chain.request() + assertThat(networkRequest.header("User-Agent")).isNotNull() + assertThat(networkRequest.header("Host")).isEqualTo( + server.hostName + ":" + server.port + ) + assertThat(networkRequest.header("Accept-Encoding")).isNotNull() + + // The network response also has everything, including the raw gzipped content. + val networkResponse = chain.proceed(networkRequest) + assertThat(networkResponse.header("Content-Encoding")).isEqualTo("gzip") + networkResponse + } + client = client.newBuilder() + .addNetworkInterceptor(interceptor) + .build() + val request = Request.Builder() + .url(server.url("/")) + .build() + + // No extra headers in the application's request. + assertThat(request.header("User-Agent")).isNull() + assertThat(request.header("Host")).isNull() + assertThat(request.header("Accept-Encoding")).isNull() + + // No extra headers in the application's response. + val response = client.newCall(request).execute() + assertThat(request.header("Content-Encoding")).isNull() + assertThat(response.body.string()).isEqualTo("abcabcabc") + } + + @Test + fun networkInterceptorsCanChangeRequestMethodFromGetToPost() { + server.enqueue(MockResponse()) + val interceptor = Interceptor { chain: Interceptor.Chain -> + val originalRequest = chain.request() + val mediaType = "text/plain".toMediaType() + val body = "abc".toRequestBody(mediaType) + chain.proceed( + originalRequest.newBuilder() + .method("POST", body) + .header("Content-Type", mediaType.toString()) + .header("Content-Length", body.contentLength().toString()) + .build() + ) + } + client = client.newBuilder() + .addNetworkInterceptor(interceptor) + .build() + val request = Request.Builder() + .url(server.url("/")) + .get() + .build() + client.newCall(request).execute() + val recordedRequest = server.takeRequest() + assertThat(recordedRequest.method).isEqualTo("POST") + assertThat(recordedRequest.body.readUtf8()).isEqualTo("abc") + } + + @Test + fun applicationInterceptorsRewriteRequestToServer() { + rewriteRequestToServer(false) + } + + @Test + fun networkInterceptorsRewriteRequestToServer() { + rewriteRequestToServer(true) + } + + private fun rewriteRequestToServer(network: Boolean) { + server.enqueue(MockResponse()) + addInterceptor(network) { chain: Interceptor.Chain -> + val originalRequest = chain.request() + chain.proceed( + originalRequest.newBuilder() + .method("POST", uppercase(originalRequest.body)) + .addHeader("OkHttp-Intercepted", "yep") + .build() + ) + } + val request = Request.Builder() + .url(server.url("/")) + .addHeader("Original-Header", "foo") + .method("PUT", "abc".toRequestBody("text/plain".toMediaType())) + .build() + client.newCall(request).execute() + val recordedRequest = server.takeRequest() + assertThat(recordedRequest.body.readUtf8()).isEqualTo("ABC") + assertThat(recordedRequest.headers["Original-Header"]).isEqualTo("foo") + assertThat(recordedRequest.headers["OkHttp-Intercepted"]).isEqualTo("yep") + assertThat(recordedRequest.method).isEqualTo("POST") + } + + @Test + fun applicationInterceptorsRewriteResponseFromServer() { + rewriteResponseFromServer(false) + } + + @Test + fun networkInterceptorsRewriteResponseFromServer() { + rewriteResponseFromServer(true) + } + + private fun rewriteResponseFromServer(network: Boolean) { + server.enqueue( + MockResponse.Builder() + .addHeader("Original-Header: foo") + .body("abc") + .build() + ) + addInterceptor(network) { chain: Interceptor.Chain -> + val originalResponse = chain.proceed(chain.request()) + originalResponse.newBuilder() + .body(uppercase(originalResponse.body)) + .addHeader("OkHttp-Intercepted", "yep") + .build() + } + val request = Request.Builder() + .url(server.url("/")) + .build() + val response = client.newCall(request).execute() + assertThat(response.body.string()).isEqualTo("ABC") + assertThat(response.header("OkHttp-Intercepted")).isEqualTo("yep") + assertThat(response.header("Original-Header")).isEqualTo("foo") + } + + @Test + fun multipleApplicationInterceptors() { + multipleInterceptors(false) + } + + @Test + fun multipleNetworkInterceptors() { + multipleInterceptors(true) + } + + private fun multipleInterceptors(network: Boolean) { + server.enqueue(MockResponse()) + addInterceptor(network) { chain: Interceptor.Chain -> + val originalRequest = chain.request() + val originalResponse = chain.proceed( + originalRequest.newBuilder() + .addHeader("Request-Interceptor", "Android") // 1. Added first. + .build() + ) + originalResponse.newBuilder() + .addHeader("Response-Interceptor", "Donut") // 4. Added last. + .build() + } + addInterceptor(network) { chain: Interceptor.Chain -> + val originalRequest = chain.request() + val originalResponse = chain.proceed( + originalRequest.newBuilder() + .addHeader("Request-Interceptor", "Bob") // 2. Added second. + .build() + ) + originalResponse.newBuilder() + .addHeader("Response-Interceptor", "Cupcake") // 3. Added third. + .build() + } + val request = Request.Builder() + .url(server.url("/")) + .build() + val response = client.newCall(request).execute() + assertThat(response.headers("Response-Interceptor")) + .containsExactly("Cupcake", "Donut") + val recordedRequest = server.takeRequest() + assertThat(recordedRequest.headers.values("Request-Interceptor")) + .containsExactly("Android", "Bob") + } + + @Test + fun asyncApplicationInterceptors() { + asyncInterceptors(false) + } + + @Test + fun asyncNetworkInterceptors() { + asyncInterceptors(true) + } + + private fun asyncInterceptors(network: Boolean) { + server.enqueue(MockResponse()) + addInterceptor(network) { chain: Interceptor.Chain -> + val originalResponse = chain.proceed(chain.request()) + originalResponse.newBuilder() + .addHeader("OkHttp-Intercepted", "yep") + .build() + } + val request = Request.Builder() + .url(server.url("/")) + .build() + client.newCall(request).enqueue(callback) + callback.await(request.url) + .assertCode(200) + .assertHeader("OkHttp-Intercepted", "yep") + } + + @Test + fun applicationInterceptorsCanMakeMultipleRequestsToServer() { + server.enqueue(MockResponse.Builder().body("a").build()) + server.enqueue(MockResponse.Builder().body("b").build()) + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain -> + val response1 = chain.proceed(chain.request()) + response1.body.close() + chain.proceed(chain.request()) + }) + .build() + val request = Request.Builder() + .url(server.url("/")) + .build() + val response = client.newCall(request).execute() + assertThat("b").isEqualTo(response.body.string()) + } + + /** Make sure interceptors can interact with the OkHttp client. */ + @Test + fun interceptorMakesAnUnrelatedRequest() { + server.enqueue(MockResponse.Builder().body("a").build()) // Fetched by interceptor. + server.enqueue(MockResponse.Builder().body("b").build()) // Fetched directly. + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain -> + if (chain.request().url.encodedPath == "/b") { + val requestA = Request.Builder() + .url(server.url("/a")) + .build() + val responseA = client.newCall(requestA).execute() + assertThat(responseA.body.string()).isEqualTo("a") + } + chain.proceed(chain.request()) + }) + .build() + val requestB = Request.Builder() + .url(server.url("/b")) + .build() + val responseB = client.newCall(requestB).execute() + assertThat(responseB.body.string()).isEqualTo("b") + } + + /** Make sure interceptors can interact with the OkHttp client asynchronously. */ + @Test + fun interceptorMakesAnUnrelatedAsyncRequest() { + server.enqueue(MockResponse.Builder().body("a").build()) // Fetched by interceptor. + server.enqueue(MockResponse.Builder().body("b").build()) // Fetched directly. + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain -> + if (chain.request().url.encodedPath == "/b") { + val requestA = Request.Builder() + .url(server.url("/a")) + .build() + try { + val callbackA = RecordingCallback() + client.newCall(requestA).enqueue(callbackA) + callbackA.await(requestA.url).assertBody("a") + } catch (e: Exception) { + throw RuntimeException(e) + } + } + chain.proceed(chain.request()) + }) + .build() + val requestB = Request.Builder() + .url(server.url("/b")) + .build() + val callbackB = RecordingCallback() + client.newCall(requestB).enqueue(callbackB) + callbackB.await(requestB.url).assertBody("b") + } + + @Test + fun applicationInterceptorThrowsRuntimeExceptionSynchronous() { + interceptorThrowsRuntimeExceptionSynchronous(false) + } + + @Test + fun networkInterceptorThrowsRuntimeExceptionSynchronous() { + interceptorThrowsRuntimeExceptionSynchronous(true) + } + + /** + * When an interceptor throws an unexpected exception, synchronous callers can catch it and deal + * with it. + */ + private fun interceptorThrowsRuntimeExceptionSynchronous(network: Boolean) { + addInterceptor(network) { chain: Interceptor.Chain? -> throw RuntimeException("boom!") } + val request = Request.Builder() + .url(server.url("/")) + .build() + try { + client.newCall(request).execute() + fail() + } catch (expected: RuntimeException) { + assertThat(expected.message).isEqualTo("boom!") + } + } + + @Test + fun networkInterceptorModifiedRequestIsReturned() { + server.enqueue(MockResponse()) + val modifyHeaderInterceptor = Interceptor { chain: Interceptor.Chain -> + val modifiedRequest = chain.request() + .newBuilder() + .header("User-Agent", "intercepted request") + .build() + chain.proceed(modifiedRequest) + } + client = client.newBuilder() + .addNetworkInterceptor(modifyHeaderInterceptor) + .build() + val request = Request.Builder() + .url(server.url("/")) + .header("User-Agent", "user request") + .build() + val response = client.newCall(request).execute() + assertThat(response.request.header("User-Agent")).isNotNull() + assertThat(response.request.header("User-Agent")).isEqualTo("user request") + assertThat(response.networkResponse!!.request.header("User-Agent")).isEqualTo( + "intercepted request" + ) + } + + @Test + fun applicationInterceptorThrowsRuntimeExceptionAsynchronous() { + interceptorThrowsRuntimeExceptionAsynchronous(false) + } + + @Test + fun networkInterceptorThrowsRuntimeExceptionAsynchronous() { + interceptorThrowsRuntimeExceptionAsynchronous(true) + } + + /** + * When an interceptor throws an unexpected exception, asynchronous calls are canceled. The + * exception goes to the uncaught exception handler. + */ + private fun interceptorThrowsRuntimeExceptionAsynchronous(network: Boolean) { + val boom = RuntimeException("boom!") + addInterceptor(network) { chain: Interceptor.Chain? -> throw boom } + val executor = ExceptionCatchingExecutor() + client = client.newBuilder() + .dispatcher(Dispatcher(executor)) + .build() + val request = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request) + call.enqueue(callback) + val recordedResponse = callback.await(server.url("/")) + assertThat(recordedResponse.failure) + .hasMessage("canceled due to java.lang.RuntimeException: boom!") + recordedResponse.failure!!.assertSuppressed { throwables: List? -> + assertThat(throwables).contains(boom) + Unit + } + assertThat(call.isCanceled()).isTrue() + assertThat(executor.takeException()).isEqualTo(boom) + } + + @Test + fun networkInterceptorReturnsConnectionOnEmptyBody() { + server.enqueue( + MockResponse.Builder() + .socketPolicy(DisconnectAtEnd) + .addHeader("Connection", "Close") + .build() + ) + val interceptor = Interceptor { chain: Interceptor.Chain -> + val response = chain.proceed(chain.request()) + assertThat(chain.connection()).isNotNull() + response + } + client = client.newBuilder() + .addNetworkInterceptor(interceptor) + .build() + val request = Request.Builder() + .url(server.url("/")) + .build() + val response = client.newCall(request).execute() + response.body.close() + } + + @Test + fun connectTimeout() { + val interceptor1 = Interceptor { chainA: Interceptor.Chain -> + assertThat(chainA.connectTimeoutMillis()).isEqualTo(5000) + val chainB = chainA.withConnectTimeout(100, TimeUnit.MILLISECONDS) + assertThat(chainB.connectTimeoutMillis()).isEqualTo(100) + chainB.proceed(chainA.request()) + } + val interceptor2 = Interceptor { chain: Interceptor.Chain -> + assertThat(chain.connectTimeoutMillis()).isEqualTo(100) + chain.proceed(chain.request()) + } + client = client.newBuilder() + .connectTimeout(Duration.ofSeconds(5)) + .addInterceptor(interceptor1) + .addInterceptor(interceptor2) + .build() + val request1 = Request.Builder() + .url("http://" + TestUtil.UNREACHABLE_ADDRESS_IPV4) + .build() + val call = client.newCall(request1) + val startNanos = System.nanoTime() + try { + call.execute() + fail() + } catch (expected: SocketTimeoutException) { + } + val elapsedNanos = System.nanoTime() - startNanos + org.junit.jupiter.api.Assertions.assertTrue( + elapsedNanos < TimeUnit.SECONDS.toNanos(5), + "Timeout should have taken ~100ms but was " + elapsedNanos / 1e6 + " ms" + ) + } + + @Test + fun chainWithReadTimeout() { + val interceptor1 = Interceptor { chainA: Interceptor.Chain -> + assertThat(chainA.readTimeoutMillis()).isEqualTo(5000) + val chainB = chainA.withReadTimeout(100, TimeUnit.MILLISECONDS) + assertThat(chainB.readTimeoutMillis()).isEqualTo(100) + chainB.proceed(chainA.request()) + } + val interceptor2 = Interceptor { chain: Interceptor.Chain -> + assertThat(chain.readTimeoutMillis()).isEqualTo(100) + chain.proceed(chain.request()) + } + client = client.newBuilder() + .readTimeout(Duration.ofSeconds(5)) + .addInterceptor(interceptor1) + .addInterceptor(interceptor2) + .build() + server.enqueue( + MockResponse.Builder() + .body("abc") + .throttleBody(1, 1, TimeUnit.SECONDS) + .build() + ) + val request1 = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request1) + val response = call.execute() + val body = response.body + try { + body.string() + fail() + } catch (expected: SocketTimeoutException) { + } + } + + @Test + fun networkInterceptorCannotChangeReadTimeout() { + addInterceptor(true) { chain: Interceptor.Chain -> + chain.withReadTimeout( + 100, + TimeUnit.MILLISECONDS + ).proceed(chain.request()) + } + val request1 = Request.Builder().url(server.url("/")).build() + val call = client.newCall(request1) + try { + call.execute() + fail() + } catch (expected: IllegalStateException) { + assertThat(expected.message) + .isEqualTo("Timeouts can't be adjusted in a network interceptor") + } + } + + @Test + fun networkInterceptorCannotChangeWriteTimeout() { + addInterceptor(true) { chain: Interceptor.Chain -> + chain.withWriteTimeout( + 100, + TimeUnit.MILLISECONDS + ).proceed(chain.request()) + } + val request1 = Request.Builder().url(server.url("/")).build() + val call = client.newCall(request1) + try { + call.execute() + fail() + } catch (expected: IllegalStateException) { + assertThat(expected.message) + .isEqualTo("Timeouts can't be adjusted in a network interceptor") + } + } + + @Test + fun networkInterceptorCannotChangeConnectTimeout() { + addInterceptor(true) { chain: Interceptor.Chain -> + chain.withConnectTimeout( + 100, + TimeUnit.MILLISECONDS + ).proceed(chain.request()) + } + val request1 = Request.Builder().url(server.url("/")).build() + val call = client.newCall(request1) + try { + call.execute() + fail() + } catch (expected: IllegalStateException) { + assertThat(expected.message) + .isEqualTo("Timeouts can't be adjusted in a network interceptor") + } + } + + @Test + fun chainWithWriteTimeout() { + val interceptor1 = Interceptor { chainA: Interceptor.Chain -> + assertThat(chainA.writeTimeoutMillis()).isEqualTo(5000) + val chainB = chainA.withWriteTimeout(100, TimeUnit.MILLISECONDS) + assertThat(chainB.writeTimeoutMillis()).isEqualTo(100) + chainB.proceed(chainA.request()) + } + val interceptor2 = Interceptor { chain: Interceptor.Chain -> + assertThat(chain.writeTimeoutMillis()).isEqualTo(100) + chain.proceed(chain.request()) + } + client = client.newBuilder() + .writeTimeout(Duration.ofSeconds(5)) + .addInterceptor(interceptor1) + .addInterceptor(interceptor2) + .build() + server.enqueue( + MockResponse.Builder() + .body("abc") + .throttleBody(1, 1, TimeUnit.SECONDS) + .build() + ) + val data = ByteArray(2 * 1024 * 1024) // 2 MiB. + val request1 = Request.Builder() + .url(server.url("/")) + .post(data.toRequestBody("text/plain".toMediaType())) + .build() + val call = client.newCall(request1) + try { + call.execute() // we want this call to throw a SocketTimeoutException + fail() + } catch (expected: SocketTimeoutException) { + } + } + + @Test + fun chainCanCancelCall() { + val callRef = AtomicReference() + val interceptor = Interceptor { chain: Interceptor.Chain -> + val call = chain.call() + callRef.set(call) + assertThat(call.isCanceled()).isFalse() + call.cancel() + assertThat(call.isCanceled()).isTrue() + chain.proceed(chain.request()) + } + client = client.newBuilder() + .addInterceptor(interceptor) + .build() + val request = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request) + try { + call.execute() + fail() + } catch (expected: IOException) { + } + assertThat(callRef.get()).isSameAs(call) + } + + private fun uppercase(original: RequestBody?): RequestBody { + return object : RequestBody() { + override fun contentType(): MediaType? { + return original!!.contentType() + } + + override fun contentLength(): Long { + return original!!.contentLength() + } + + override fun writeTo(sink: BufferedSink) { + val uppercase = uppercase(sink) + val bufferedSink = uppercase.buffer() + original!!.writeTo(bufferedSink) + bufferedSink.emit() + } + } + } + + private fun uppercase(original: BufferedSink): Sink { + return object : ForwardingSink(original) { + override fun write(source: Buffer, byteCount: Long) { + original.writeUtf8(source.readUtf8(byteCount).uppercase()) + } + } + } + + private fun gzip(data: String): Buffer { + val result = Buffer() + val sink = GzipSink(result).buffer() + sink.writeUtf8(data) + sink.close() + return result + } + + private fun addInterceptor(network: Boolean, interceptor: Interceptor) { + val builder = client.newBuilder() + if (network) { + builder.addNetworkInterceptor(interceptor) + } else { + builder.addInterceptor(interceptor) + } + client = builder.build() + } + + /** Catches exceptions that are otherwise headed for the uncaught exception handler. */ + private class ExceptionCatchingExecutor : + ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, SynchronousQueue()) { + private val exceptions: BlockingQueue = LinkedBlockingQueue() + override fun execute(runnable: Runnable) { + super.execute { + try { + runnable.run() + } catch (e: Exception) { + exceptions.add(e) + } + } + } + + fun takeException(): Exception { + return exceptions.take() + } + } + + companion object { + fun uppercase(original: ResponseBody): ResponseBody { + return object : ResponseBody() { + override fun contentType()= original.contentType() + + override fun contentLength() = original.contentLength() + + override fun source() = uppercase(original.source()).buffer() + } + } + + private fun uppercase(original: Source): Source { + return object : ForwardingSource(original) { + override fun read(sink: Buffer, byteCount: Long): Long { + val mixedCase = Buffer() + val count = original.read(mixedCase, byteCount) + sink.writeUtf8(mixedCase.readUtf8().uppercase()) + return count + } + } + } + } +} diff --git a/okhttp/src/test/java/okhttp3/MultipartBodyTest.java b/okhttp/src/test/java/okhttp3/MultipartBodyTest.java deleted file mode 100644 index 2555692b3d6a..000000000000 --- a/okhttp/src/test/java/okhttp3/MultipartBodyTest.java +++ /dev/null @@ -1,281 +0,0 @@ -/* - * Copyright (C) 2014 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3; - -import java.io.IOException; -import okio.Buffer; -import okio.BufferedSink; -import org.junit.jupiter.api.Test; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -public final class MultipartBodyTest { - @Test public void onePartRequired() throws Exception { - try { - new MultipartBody.Builder().build(); - fail(); - } catch (IllegalStateException e) { - assertThat(e.getMessage()).isEqualTo("Multipart body must have at least one part."); - } - } - - @Test public void singlePart() throws Exception { - String expected = "" - + "--123\r\n" - + "\r\n" - + "Hello, World!\r\n" - + "--123--\r\n"; - - MultipartBody body = new MultipartBody.Builder("123") - .addPart(RequestBody.create("Hello, World!", null)) - .build(); - - assertThat(body.boundary()).isEqualTo("123"); - assertThat(body.type()).isEqualTo(MultipartBody.MIXED); - assertThat(body.contentType().toString()).isEqualTo("multipart/mixed; boundary=123"); - assertThat(body.parts().size()).isEqualTo(1); - assertThat(body.contentLength()).isEqualTo(33L); - - Buffer buffer = new Buffer(); - body.writeTo(buffer); - assertThat(body.contentLength()).isEqualTo(buffer.size()); - assertThat(buffer.readUtf8()).isEqualTo(expected); - } - - @Test public void threeParts() throws Exception { - String expected = "" - + "--123\r\n" - + "\r\n" - + "Quick\r\n" - + "--123\r\n" - + "\r\n" - + "Brown\r\n" - + "--123\r\n" - + "\r\n" - + "Fox\r\n" - + "--123--\r\n"; - - MultipartBody body = new MultipartBody.Builder("123") - .addPart(RequestBody.create("Quick", null)) - .addPart(RequestBody.create("Brown", null)) - .addPart(RequestBody.create("Fox", null)) - .build(); - - assertThat(body.boundary()).isEqualTo("123"); - assertThat(body.type()).isEqualTo(MultipartBody.MIXED); - assertThat(body.contentType().toString()).isEqualTo("multipart/mixed; boundary=123"); - assertThat(body.parts().size()).isEqualTo(3); - assertThat(body.contentLength()).isEqualTo(55L); - - Buffer buffer = new Buffer(); - body.writeTo(buffer); - assertThat(body.contentLength()).isEqualTo(buffer.size()); - assertThat(buffer.readUtf8()).isEqualTo(expected); - } - - @Test public void fieldAndTwoFiles() throws Exception { - String expected = "" - + "--AaB03x\r\n" - + "Content-Disposition: form-data; name=\"submit-name\"\r\n" - + "\r\n" - + "Larry\r\n" - + "--AaB03x\r\n" - + "Content-Disposition: form-data; name=\"files\"\r\n" - + "Content-Type: multipart/mixed; boundary=BbC04y\r\n" - + "\r\n" - + "--BbC04y\r\n" - + "Content-Disposition: file; filename=\"file1.txt\"\r\n" - + "Content-Type: text/plain; charset=utf-8\r\n" - + "\r\n" - + "... contents of file1.txt ...\r\n" - + "--BbC04y\r\n" - + "Content-Disposition: file; filename=\"file2.gif\"\r\n" - + "Content-Transfer-Encoding: binary\r\n" - + "Content-Type: image/gif\r\n" - + "\r\n" - + "... contents of file2.gif ...\r\n" - + "--BbC04y--\r\n" - + "\r\n" - + "--AaB03x--\r\n"; - - MultipartBody body = new MultipartBody.Builder("AaB03x") - .setType(MultipartBody.FORM) - .addFormDataPart("submit-name", "Larry") - .addFormDataPart("files", null, - new MultipartBody.Builder("BbC04y") - .addPart( - Headers.of("Content-Disposition", "file; filename=\"file1.txt\""), - RequestBody.create( - "... contents of file1.txt ...", MediaType.get("text/plain"))) - .addPart( - Headers.of( - "Content-Disposition", "file; filename=\"file2.gif\"", - "Content-Transfer-Encoding", "binary"), - RequestBody.create( - "... contents of file2.gif ...".getBytes(UTF_8), - MediaType.get("image/gif"))) - .build()) - .build(); - - assertThat(body.boundary()).isEqualTo("AaB03x"); - assertThat(body.type()).isEqualTo(MultipartBody.FORM); - assertThat(body.contentType().toString()).isEqualTo( - "multipart/form-data; boundary=AaB03x"); - assertThat(body.parts().size()).isEqualTo(2); - assertThat(body.contentLength()).isEqualTo(488L); - - Buffer buffer = new Buffer(); - body.writeTo(buffer); - assertThat(body.contentLength()).isEqualTo(buffer.size()); - assertThat(buffer.readUtf8()).isEqualTo(expected); - } - - @Test public void stringEscapingIsWeird() throws Exception { - String expected = "" - + "--AaB03x\r\n" - + "Content-Disposition: form-data; name=\"field with spaces\"; filename=\"filename with spaces.txt\"\r\n" - + "Content-Type: text/plain; charset=utf-8\r\n" - + "\r\n" - + "okay\r\n" - + "--AaB03x\r\n" - + "Content-Disposition: form-data; name=\"field with %22\"\r\n" - + "\r\n" - + "\"\r\n" - + "--AaB03x\r\n" - + "Content-Disposition: form-data; name=\"field with %22\"\r\n" - + "\r\n" - + "%22\r\n" - + "--AaB03x\r\n" - + "Content-Disposition: form-data; name=\"field with \u007e\"\r\n" - + "\r\n" - + "Alpha\r\n" - + "--AaB03x--\r\n"; - - MultipartBody body = new MultipartBody.Builder("AaB03x") - .setType(MultipartBody.FORM) - .addFormDataPart("field with spaces", "filename with spaces.txt", - RequestBody.create("okay", MediaType.get("text/plain; charset=utf-8"))) - .addFormDataPart("field with \"", "\"") - .addFormDataPart("field with %22", "%22") - .addFormDataPart("field with \u007e", "Alpha") - .build(); - - Buffer buffer = new Buffer(); - body.writeTo(buffer); - assertThat(buffer.readUtf8()).isEqualTo(expected); - } - - @Test public void streamingPartHasNoLength() throws Exception { - class StreamingBody extends RequestBody { - private final String body; - - StreamingBody(String body) { - this.body = body; - } - - @Override public MediaType contentType() { - return null; - } - - @Override public void writeTo(BufferedSink sink) throws IOException { - sink.writeUtf8(body); - } - } - - String expected = "" - + "--123\r\n" - + "\r\n" - + "Quick\r\n" - + "--123\r\n" - + "\r\n" - + "Brown\r\n" - + "--123\r\n" - + "\r\n" - + "Fox\r\n" - + "--123--\r\n"; - - MultipartBody body = new MultipartBody.Builder("123") - .addPart(RequestBody.create("Quick", null)) - .addPart(new StreamingBody("Brown")) - .addPart(RequestBody.create("Fox", null)) - .build(); - - assertThat(body.boundary()).isEqualTo("123"); - assertThat(body.type()).isEqualTo(MultipartBody.MIXED); - assertThat(body.contentType().toString()).isEqualTo("multipart/mixed; boundary=123"); - assertThat(body.parts().size()).isEqualTo(3); - assertThat(body.contentLength()).isEqualTo(-1); - - Buffer buffer = new Buffer(); - body.writeTo(buffer); - assertThat(buffer.readUtf8()).isEqualTo(expected); - } - - @Test public void contentTypeHeaderIsForbidden() throws Exception { - MultipartBody.Builder multipart = new MultipartBody.Builder(); - try { - multipart.addPart(Headers.of("Content-Type", "text/plain"), - RequestBody.create("Hello, World!", null)); - fail(); - } catch (IllegalArgumentException expected) { - } - } - - @Test public void contentLengthHeaderIsForbidden() throws Exception { - MultipartBody.Builder multipart = new MultipartBody.Builder(); - try { - multipart.addPart(Headers.of("Content-Length", "13"), - RequestBody.create("Hello, World!", null)); - fail(); - } catch (IllegalArgumentException expected) { - } - } - - @Test public void partAccessors() throws IOException { - MultipartBody body = new MultipartBody.Builder() - .addPart(Headers.of("Foo", "Bar"), RequestBody.create("Baz", null)) - .build(); - assertThat(body.parts().size()).isEqualTo(1); - - Buffer part1Buffer = new Buffer(); - MultipartBody.Part part1 = body.part(0); - part1.body().writeTo(part1Buffer); - assertThat(part1.headers()).isEqualTo(Headers.of("Foo", "Bar")); - assertThat(part1Buffer.readUtf8()).isEqualTo("Baz"); - } - - @Test public void nonAsciiFilename() throws Exception { - String expected = "" - + "--AaB03x\r\n" - + "Content-Disposition: form-data; name=\"attachment\"; filename=\"resumé.pdf\"\r\n" - + "Content-Type: application/pdf; charset=utf-8\r\n" - + "\r\n" - + "Jesse’s Resumé\r\n" - + "--AaB03x--\r\n"; - - MultipartBody body = new MultipartBody.Builder("AaB03x") - .setType(MultipartBody.FORM) - .addFormDataPart("attachment", "resumé.pdf", - RequestBody.create("Jesse’s Resumé", MediaType.parse("application/pdf"))) - .build(); - - Buffer buffer = new Buffer(); - body.writeTo(buffer); - assertThat(buffer.readUtf8()).isEqualTo(expected); - } -} diff --git a/okhttp/src/test/java/okhttp3/MultipartBodyTest.kt b/okhttp/src/test/java/okhttp3/MultipartBodyTest.kt new file mode 100644 index 000000000000..dd8350fb6c13 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/MultipartBodyTest.kt @@ -0,0 +1,301 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.io.IOException +import java.nio.charset.StandardCharsets +import okhttp3.Headers.Companion.headersOf +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.MediaType.Companion.toMediaTypeOrNull +import okhttp3.RequestBody.Companion.toRequestBody +import okio.Buffer +import okio.BufferedSink +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.Test + +class MultipartBodyTest { + @Test + fun onePartRequired() { + try { + MultipartBody.Builder().build() + fail() + } catch (e: IllegalStateException) { + assertThat(e.message) + .isEqualTo("Multipart body must have at least one part.") + } + } + + @Test + fun singlePart() { + val expected = """ + |--123 + | + |Hello, World! + |--123-- + | + """.trimMargin().replace("\n", "\r\n") + val body = MultipartBody.Builder("123") + .addPart("Hello, World!".toRequestBody(null)) + .build() + assertThat(body.boundary).isEqualTo("123") + assertThat(body.type).isEqualTo(MultipartBody.MIXED) + assertThat(body.contentType().toString()) + .isEqualTo("multipart/mixed; boundary=123") + assertThat(body.parts.size).isEqualTo(1) + assertThat(body.contentLength()).isEqualTo(33L) + val buffer = Buffer() + body.writeTo(buffer) + assertThat(body.contentLength()).isEqualTo(buffer.size) + assertThat(buffer.readUtf8()).isEqualTo(expected) + } + + @Test + fun threeParts() { + val expected = """ + |--123 + | + |Quick + |--123 + | + |Brown + |--123 + | + |Fox + |--123-- + | + """.trimMargin().replace("\n", "\r\n") + val body = MultipartBody.Builder("123") + .addPart("Quick".toRequestBody(null)) + .addPart("Brown".toRequestBody(null)) + .addPart("Fox".toRequestBody(null)) + .build() + assertThat(body.boundary).isEqualTo("123") + assertThat(body.type).isEqualTo(MultipartBody.MIXED) + assertThat(body.contentType().toString()) + .isEqualTo("multipart/mixed; boundary=123") + assertThat(body.parts.size).isEqualTo(3) + assertThat(body.contentLength()).isEqualTo(55L) + val buffer = Buffer() + body.writeTo(buffer) + assertThat(body.contentLength()).isEqualTo(buffer.size) + assertThat(buffer.readUtf8()).isEqualTo(expected) + } + + @Test + fun fieldAndTwoFiles() { + val expected = """ + |--AaB03x + |Content-Disposition: form-data; name="submit-name" + | + |Larry + |--AaB03x + |Content-Disposition: form-data; name="files" + |Content-Type: multipart/mixed; boundary=BbC04y + | + |--BbC04y + |Content-Disposition: file; filename="file1.txt" + |Content-Type: text/plain; charset=utf-8 + | + |... contents of file1.txt ... + |--BbC04y + |Content-Disposition: file; filename="file2.gif" + |Content-Transfer-Encoding: binary + |Content-Type: image/gif + | + |... contents of file2.gif ... + |--BbC04y-- + | + |--AaB03x-- + | + """.trimMargin().replace("\n", "\r\n") + val body = MultipartBody.Builder("AaB03x") + .setType(MultipartBody.FORM) + .addFormDataPart("submit-name", "Larry") + .addFormDataPart( + "files", null, + MultipartBody.Builder("BbC04y") + .addPart( + headersOf("Content-Disposition", "file; filename=\"file1.txt\""), + "... contents of file1.txt ...".toRequestBody("text/plain".toMediaType()) + ) + .addPart( + headersOf( + "Content-Disposition", "file; filename=\"file2.gif\"", + "Content-Transfer-Encoding", "binary" + ), + "... contents of file2.gif ...".toByteArray(StandardCharsets.UTF_8) + .toRequestBody("image/gif".toMediaType()) + ) + .build() + ) + .build() + assertThat(body.boundary).isEqualTo("AaB03x") + assertThat(body.type).isEqualTo(MultipartBody.FORM) + assertThat(body.contentType().toString()).isEqualTo( + "multipart/form-data; boundary=AaB03x" + ) + assertThat(body.parts.size).isEqualTo(2) + assertThat(body.contentLength()).isEqualTo(488L) + val buffer = Buffer() + body.writeTo(buffer) + assertThat(body.contentLength()).isEqualTo(buffer.size) + assertThat(buffer.readUtf8()).isEqualTo(expected) + } + + @Test + fun stringEscapingIsWeird() { + val expected = """ + |--AaB03x + |Content-Disposition: form-data; name="field with spaces"; filename="filename with spaces.txt" + |Content-Type: text/plain; charset=utf-8 + | + |okay + |--AaB03x + |Content-Disposition: form-data; name="field with %22" + | + |" + |--AaB03x + |Content-Disposition: form-data; name="field with %22" + | + |%22 + |--AaB03x + |Content-Disposition: form-data; name="field with ~" + | + |Alpha + |--AaB03x-- + | + """.trimMargin().replace("\n", "\r\n") + val body = MultipartBody.Builder("AaB03x") + .setType(MultipartBody.FORM) + .addFormDataPart( + "field with spaces", "filename with spaces.txt", + "okay".toRequestBody("text/plain; charset=utf-8".toMediaType()) + ) + .addFormDataPart("field with \"", "\"") + .addFormDataPart("field with %22", "%22") + .addFormDataPart("field with \u007e", "Alpha") + .build() + val buffer = Buffer() + body.writeTo(buffer) + assertThat(buffer.readUtf8()).isEqualTo(expected) + } + + @Test + fun streamingPartHasNoLength() { + class StreamingBody(private val body: String) : RequestBody() { + override fun contentType(): MediaType? { + return null + } + + @Throws(IOException::class) + override fun writeTo(sink: BufferedSink) { + sink.writeUtf8(body) + } + } + + val expected = """ + |--123 + | + |Quick + |--123 + | + |Brown + |--123 + | + |Fox + |--123-- + | + """.trimMargin().replace("\n", "\r\n") + val body = MultipartBody.Builder("123") + .addPart("Quick".toRequestBody(null)) + .addPart(StreamingBody("Brown")) + .addPart("Fox".toRequestBody(null)) + .build() + assertThat(body.boundary).isEqualTo("123") + assertThat(body.type).isEqualTo(MultipartBody.MIXED) + assertThat(body.contentType().toString()) + .isEqualTo("multipart/mixed; boundary=123") + assertThat(body.parts.size).isEqualTo(3) + assertThat(body.contentLength()).isEqualTo(-1) + val buffer = Buffer() + body.writeTo(buffer) + assertThat(buffer.readUtf8()).isEqualTo(expected) + } + + @Test + fun contentTypeHeaderIsForbidden() { + val multipart = MultipartBody.Builder() + try { + multipart.addPart( + headersOf("Content-Type", "text/plain"), + "Hello, World!".toRequestBody(null) + ) + fail() + } catch (expected: IllegalArgumentException) { + } + } + + @Test + fun contentLengthHeaderIsForbidden() { + val multipart = MultipartBody.Builder() + try { + multipart.addPart( + headersOf("Content-Length", "13"), + "Hello, World!".toRequestBody(null) + ) + fail() + } catch (expected: IllegalArgumentException) { + } + } + + @Test + @Throws(IOException::class) + fun partAccessors() { + val body = MultipartBody.Builder() + .addPart(headersOf("Foo", "Bar"), "Baz".toRequestBody(null)) + .build() + assertThat(body.parts.size).isEqualTo(1) + val part1Buffer = Buffer() + val part1 = body.part(0) + part1.body.writeTo(part1Buffer) + assertThat(part1.headers).isEqualTo(headersOf("Foo", "Bar")) + assertThat(part1Buffer.readUtf8()).isEqualTo("Baz") + } + + @Test + fun nonAsciiFilename() { + val expected = """ + |--AaB03x + |Content-Disposition: form-data; name="attachment"; filename="resumé.pdf" + |Content-Type: application/pdf; charset=utf-8 + | + |Jesse’s Resumé + |--AaB03x-- + | + """.trimMargin().replace("\n", "\r\n") + val body = MultipartBody.Builder("AaB03x") + .setType(MultipartBody.FORM) + .addFormDataPart( + "attachment", "resumé.pdf", + "Jesse’s Resumé".toRequestBody("application/pdf".toMediaTypeOrNull()) + ) + .build() + val buffer = Buffer() + body.writeTo(buffer) + assertThat(buffer.readUtf8()).isEqualTo(expected) + } +} diff --git a/okhttp/src/test/java/okhttp3/WholeOperationTimeoutTest.java b/okhttp/src/test/java/okhttp3/WholeOperationTimeoutTest.java deleted file mode 100644 index 4c07133b0236..000000000000 --- a/okhttp/src/test/java/okhttp3/WholeOperationTimeoutTest.java +++ /dev/null @@ -1,348 +0,0 @@ -/* - * Copyright (C) 2018 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3; - -import java.io.IOException; -import java.io.InterruptedIOException; -import java.net.HttpURLConnection; -import java.time.Duration; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import mockwebserver3.MockResponse; -import mockwebserver3.MockWebServer; -import okhttp3.testing.Flaky; -import okio.BufferedSink; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.jupiter.api.extension.RegisterExtension; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -@Timeout(30) -@Tag("Slow") -public final class WholeOperationTimeoutTest { - /** A large response body. Smaller bodies might successfully read after the socket is closed! */ - private static final String BIG_ENOUGH_BODY = TestUtil.repeat('a', 64 * 1024); - - @RegisterExtension public final OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - - private MockWebServer server; - private final OkHttpClient client = clientTestRule.newClient(); - - @BeforeEach - public void setUp(MockWebServer server) throws Exception { - this.server = server; - } - - @Test public void defaultConfigIsNoTimeout() throws Exception { - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - Call call = client.newCall(request); - assertThat(call.timeout().timeoutNanos()).isEqualTo(0); - } - - @Test public void configureClientDefault() throws Exception { - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - OkHttpClient timeoutClient = client.newBuilder() - .callTimeout(Duration.ofMillis(456)) - .build(); - - Call call = timeoutClient.newCall(request); - assertThat(call.timeout().timeoutNanos()).isEqualTo(TimeUnit.MILLISECONDS.toNanos(456)); - } - - @Test public void timeoutWritingRequest() throws Exception { - server.enqueue(new MockResponse()); - - Request request = new Request.Builder() - .url(server.url("/")) - .post(sleepingRequestBody(500)) - .build(); - - Call call = client.newCall(request); - call.timeout().timeout(250, TimeUnit.MILLISECONDS); - try { - call.execute(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("timeout"); - assertThat(call.isCanceled()).isTrue(); - } - } - - @Test public void timeoutWritingRequestWithEnqueue() throws Exception { - server.enqueue(new MockResponse()); - - Request request = new Request.Builder() - .url(server.url("/")) - .post(sleepingRequestBody(500)) - .build(); - - final CountDownLatch latch = new CountDownLatch(1); - final AtomicReference exceptionRef = new AtomicReference<>(); - - Call call = client.newCall(request); - call.timeout().timeout(250, TimeUnit.MILLISECONDS); - call.enqueue(new Callback() { - @Override public void onFailure(Call call, IOException e) { - exceptionRef.set(e); - latch.countDown(); - } - - @Override public void onResponse(Call call, Response response) throws IOException { - response.close(); - latch.countDown(); - } - }); - - latch.await(); - assertThat(call.isCanceled()).isTrue(); - assertThat(exceptionRef.get()).isNotNull(); - } - - @Test public void timeoutProcessing() throws Exception { - server.enqueue(new MockResponse.Builder() - .headersDelay(500, TimeUnit.MILLISECONDS) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - Call call = client.newCall(request); - call.timeout().timeout(250, TimeUnit.MILLISECONDS); - try { - call.execute(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("timeout"); - assertThat(call.isCanceled()).isTrue(); - } - } - - @Test public void timeoutProcessingWithEnqueue() throws Exception { - server.enqueue(new MockResponse.Builder() - .headersDelay(500, TimeUnit.MILLISECONDS) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - final CountDownLatch latch = new CountDownLatch(1); - final AtomicReference exceptionRef = new AtomicReference<>(); - - Call call = client.newCall(request); - call.timeout().timeout(250, TimeUnit.MILLISECONDS); - call.enqueue(new Callback() { - @Override public void onFailure(Call call, IOException e) { - exceptionRef.set(e); - latch.countDown(); - } - - @Override public void onResponse(Call call, Response response) throws IOException { - response.close(); - latch.countDown(); - } - }); - - latch.await(); - assertThat(call.isCanceled()).isTrue(); - assertThat(exceptionRef.get()).isNotNull(); - } - - @Test public void timeoutReadingResponse() throws Exception { - server.enqueue(new MockResponse.Builder() - .body(BIG_ENOUGH_BODY) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - Call call = client.newCall(request); - call.timeout().timeout(250, TimeUnit.MILLISECONDS); - Response response = call.execute(); - Thread.sleep(500); - try { - response.body().source().readUtf8(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("timeout"); - assertThat(call.isCanceled()).isTrue(); - } - } - - @Test public void timeoutReadingResponseWithEnqueue() throws Exception { - server.enqueue(new MockResponse.Builder() - .body(BIG_ENOUGH_BODY) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - - final CountDownLatch latch = new CountDownLatch(1); - final AtomicReference exceptionRef = new AtomicReference<>(); - - Call call = client.newCall(request); - call.timeout().timeout(250, TimeUnit.MILLISECONDS); - call.enqueue(new Callback() { - @Override public void onFailure(Call call, IOException e) { - latch.countDown(); - } - - @Override public void onResponse(Call call, Response response) throws IOException { - try { - Thread.sleep(500); - } catch (InterruptedException e) { - throw new AssertionError(); - } - try { - response.body().source().readUtf8(); - fail(); - } catch (IOException e) { - exceptionRef.set(e); - } finally { - latch.countDown(); - } - } - }); - - latch.await(); - assertThat(call.isCanceled()).isTrue(); - assertThat(exceptionRef.get()).isNotNull(); - } - - @Test public void singleTimeoutForAllFollowUpRequests() throws Exception { - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_MOVED_TEMP) - .setHeader("Location", "/b") - .headersDelay(100, TimeUnit.MILLISECONDS) - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_MOVED_TEMP) - .setHeader("Location", "/c") - .headersDelay(100, TimeUnit.MILLISECONDS) - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_MOVED_TEMP) - .setHeader("Location", "/d") - .headersDelay(100, TimeUnit.MILLISECONDS) - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_MOVED_TEMP) - .setHeader("Location", "/e") - .headersDelay(100, TimeUnit.MILLISECONDS) - .build()); - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_MOVED_TEMP) - .setHeader("Location", "/f") - .headersDelay(100, TimeUnit.MILLISECONDS) - .build()); - server.enqueue(new MockResponse()); - - Request request = new Request.Builder() - .url(server.url("/a")) - .build(); - - Call call = client.newCall(request); - call.timeout().timeout(250, TimeUnit.MILLISECONDS); - try { - call.execute(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("timeout"); - assertThat(call.isCanceled()).isTrue(); - } - } - - @Test - public void timeoutFollowingRedirectOnNewConnection() throws Exception { - MockWebServer otherServer = new MockWebServer(); - - server.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_MOVED_TEMP) - .setHeader("Location", otherServer.url("/")) - .build()); - - otherServer.enqueue(new MockResponse.Builder() - .headersDelay(500, TimeUnit.MILLISECONDS) - .build()); - - Request request = new Request.Builder().url(server.url("/")).build(); - - Call call = client.newCall(request); - call.timeout().timeout(250, TimeUnit.MILLISECONDS); - try { - call.execute(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("timeout"); - assertThat(call.isCanceled()).isTrue(); - } - } - - @Flaky - @Test public void noTimeout() throws Exception { - // Flaky https://github.com/square/okhttp/issues/5304 - - server.enqueue(new MockResponse.Builder() - .headersDelay(250, TimeUnit.MILLISECONDS) - .body(BIG_ENOUGH_BODY) - .build()); - - Request request = new Request.Builder() - .url(server.url("/")) - .post(sleepingRequestBody(250)) - .build(); - - Call call = client.newCall(request); - call.timeout().timeout(2000, TimeUnit.MILLISECONDS); - Response response = call.execute(); - Thread.sleep(250); - response.body().source().readUtf8(); - response.close(); - assertThat(call.isCanceled()).isFalse(); - } - - private RequestBody sleepingRequestBody(final int sleepMillis) { - return new RequestBody() { - @Override public MediaType contentType() { - return MediaType.parse("text/plain"); - } - - @Override public void writeTo(BufferedSink sink) throws IOException { - try { - sink.writeUtf8("abc"); - sink.flush(); - Thread.sleep(sleepMillis); - sink.writeUtf8("def"); - } catch (InterruptedException e) { - throw new InterruptedIOException(); - } - } - }; - } -} diff --git a/okhttp/src/test/java/okhttp3/WholeOperationTimeoutTest.kt b/okhttp/src/test/java/okhttp3/WholeOperationTimeoutTest.kt new file mode 100644 index 000000000000..c88eff6c6a55 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/WholeOperationTimeoutTest.kt @@ -0,0 +1,362 @@ +/* + * Copyright (C) 2018 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.io.IOException +import java.io.InterruptedIOException +import java.net.HttpURLConnection +import java.time.Duration +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import okhttp3.MediaType.Companion.toMediaTypeOrNull +import okhttp3.TestUtil.repeat +import okhttp3.testing.Flaky +import okio.BufferedSink +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.extension.RegisterExtension + +@Timeout(30) +@Tag("Slow") +class WholeOperationTimeoutTest { + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + private val client = clientTestRule.newClient() + + private lateinit var server: MockWebServer + + @BeforeEach + fun setUp(server: MockWebServer) { + this.server = server + } + + @Test + fun defaultConfigIsNoTimeout() { + val request = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request) + assertThat(call.timeout().timeoutNanos()).isEqualTo(0) + } + + @Test + fun configureClientDefault() { + val request = Request.Builder() + .url(server.url("/")) + .build() + val timeoutClient = client.newBuilder() + .callTimeout(Duration.ofMillis(456)) + .build() + val call = timeoutClient.newCall(request) + assertThat(call.timeout().timeoutNanos()) + .isEqualTo(TimeUnit.MILLISECONDS.toNanos(456)) + } + + @Test + fun timeoutWritingRequest() { + server.enqueue(MockResponse()) + val request = Request.Builder() + .url(server.url("/")) + .post(sleepingRequestBody(500)) + .build() + val call = client.newCall(request) + call.timeout().timeout(250, TimeUnit.MILLISECONDS) + try { + call.execute() + fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("timeout") + assertThat(call.isCanceled()).isTrue() + } + } + + @Test + fun timeoutWritingRequestWithEnqueue() { + server.enqueue(MockResponse()) + val request = Request.Builder() + .url(server.url("/")) + .post(sleepingRequestBody(500)) + .build() + val latch = CountDownLatch(1) + val exceptionRef = AtomicReference() + val call = client.newCall(request) + call.timeout().timeout(250, TimeUnit.MILLISECONDS) + call.enqueue(object : Callback { + override fun onFailure(call: Call, e: IOException) { + exceptionRef.set(e) + latch.countDown() + } + + @Throws(IOException::class) + override fun onResponse(call: Call, response: Response) { + response.close() + latch.countDown() + } + }) + latch.await() + assertThat(call.isCanceled()).isTrue() + assertThat(exceptionRef.get()).isNotNull() + } + + @Test + fun timeoutProcessing() { + server.enqueue( + MockResponse.Builder() + .headersDelay(500, TimeUnit.MILLISECONDS) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request) + call.timeout().timeout(250, TimeUnit.MILLISECONDS) + try { + call.execute() + fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("timeout") + assertThat(call.isCanceled()).isTrue() + } + } + + @Test + fun timeoutProcessingWithEnqueue() { + server.enqueue( + MockResponse.Builder() + .headersDelay(500, TimeUnit.MILLISECONDS) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .build() + val latch = CountDownLatch(1) + val exceptionRef = AtomicReference() + val call = client.newCall(request) + call.timeout().timeout(250, TimeUnit.MILLISECONDS) + call.enqueue(object : Callback { + override fun onFailure(call: Call, e: IOException) { + exceptionRef.set(e) + latch.countDown() + } + + @Throws(IOException::class) + override fun onResponse(call: Call, response: Response) { + response.close() + latch.countDown() + } + }) + latch.await() + assertThat(call.isCanceled()).isTrue() + assertThat(exceptionRef.get()).isNotNull() + } + + @Test + fun timeoutReadingResponse() { + server.enqueue( + MockResponse.Builder() + .body(BIG_ENOUGH_BODY) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request) + call.timeout().timeout(250, TimeUnit.MILLISECONDS) + val response = call.execute() + Thread.sleep(500) + try { + response.body.source().readUtf8() + fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("timeout") + assertThat(call.isCanceled()).isTrue() + } + } + + @Test + fun timeoutReadingResponseWithEnqueue() { + server.enqueue( + MockResponse.Builder() + .body(BIG_ENOUGH_BODY) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .build() + val latch = CountDownLatch(1) + val exceptionRef = AtomicReference() + val call = client.newCall(request) + call.timeout().timeout(250, TimeUnit.MILLISECONDS) + call.enqueue(object : Callback { + override fun onFailure(call: Call, e: IOException) { + latch.countDown() + } + + @Throws(IOException::class) + override fun onResponse(call: Call, response: Response) { + try { + Thread.sleep(500) + } catch (e: InterruptedException) { + throw AssertionError() + } + try { + response.body.source().readUtf8() + fail() + } catch (e: IOException) { + exceptionRef.set(e) + } finally { + latch.countDown() + } + } + }) + latch.await() + assertThat(call.isCanceled()).isTrue() + assertThat(exceptionRef.get()).isNotNull() + } + + @Test + fun singleTimeoutForAllFollowUpRequests() { + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_MOVED_TEMP) + .setHeader("Location", "/b") + .headersDelay(100, TimeUnit.MILLISECONDS) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_MOVED_TEMP) + .setHeader("Location", "/c") + .headersDelay(100, TimeUnit.MILLISECONDS) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_MOVED_TEMP) + .setHeader("Location", "/d") + .headersDelay(100, TimeUnit.MILLISECONDS) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_MOVED_TEMP) + .setHeader("Location", "/e") + .headersDelay(100, TimeUnit.MILLISECONDS) + .build() + ) + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_MOVED_TEMP) + .setHeader("Location", "/f") + .headersDelay(100, TimeUnit.MILLISECONDS) + .build() + ) + server.enqueue(MockResponse()) + val request = Request.Builder() + .url(server.url("/a")) + .build() + val call = client.newCall(request) + call.timeout().timeout(250, TimeUnit.MILLISECONDS) + try { + call.execute() + fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("timeout") + assertThat(call.isCanceled()).isTrue() + } + } + + @Test + fun timeoutFollowingRedirectOnNewConnection() { + val otherServer = MockWebServer() + server.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_MOVED_TEMP) + .setHeader("Location", otherServer.url("/")) + .build() + ) + otherServer.enqueue( + MockResponse.Builder() + .headersDelay(500, TimeUnit.MILLISECONDS) + .build() + ) + val request = Request.Builder().url(server.url("/")).build() + val call = client.newCall(request) + call.timeout().timeout(250, TimeUnit.MILLISECONDS) + try { + call.execute() + fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("timeout") + assertThat(call.isCanceled()).isTrue() + } + } + + @Flaky + @Test + fun noTimeout() { + // Flaky https://github.com/square/okhttp/issues/5304 + server.enqueue( + MockResponse.Builder() + .headersDelay(250, TimeUnit.MILLISECONDS) + .body(BIG_ENOUGH_BODY) + .build() + ) + val request = Request.Builder() + .url(server.url("/")) + .post(sleepingRequestBody(250)) + .build() + val call = client.newCall(request) + call.timeout().timeout(2000, TimeUnit.MILLISECONDS) + val response = call.execute() + Thread.sleep(250) + response.body.source().readUtf8() + response.close() + assertThat(call.isCanceled()).isFalse() + } + + private fun sleepingRequestBody(sleepMillis: Int): RequestBody { + return object : RequestBody() { + override fun contentType(): MediaType? { + return "text/plain".toMediaTypeOrNull() + } + + @Throws(IOException::class) + override fun writeTo(sink: BufferedSink) { + try { + sink.writeUtf8("abc") + sink.flush() + Thread.sleep(sleepMillis.toLong()) + sink.writeUtf8("def") + } catch (e: InterruptedException) { + throw InterruptedIOException() + } + } + } + } + + companion object { + /** A large response body. Smaller bodies might successfully read after the socket is closed! */ + private val BIG_ENOUGH_BODY = repeat('a', 64 * 1024) + } +} diff --git a/okhttp/src/test/java/okhttp3/internal/cache2/FileOperatorTest.java b/okhttp/src/test/java/okhttp3/internal/cache2/FileOperatorTest.java deleted file mode 100644 index 067a2bac1bb2..000000000000 --- a/okhttp/src/test/java/okhttp3/internal/cache2/FileOperatorTest.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright (C) 2016 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3.internal.cache2; - -import java.io.File; -import java.io.IOException; -import java.io.RandomAccessFile; -import java.util.Random; -import okio.Buffer; -import okio.BufferedSink; -import okio.BufferedSource; -import okio.ByteString; -import okio.Okio; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -public final class FileOperatorTest { - @TempDir public File tempDir; - private File file; - private RandomAccessFile randomAccessFile; - - @BeforeEach public void setUp() throws Exception { - file = new File(tempDir, "test"); - randomAccessFile = new RandomAccessFile(file, "rw"); - } - - @AfterEach public void tearDown() throws Exception { - randomAccessFile.close(); - } - - @Test public void read() throws Exception { - write(ByteString.encodeUtf8("Hello, World")); - - FileOperator operator = new FileOperator(randomAccessFile.getChannel()); - - Buffer buffer = new Buffer(); - operator.read(0, buffer, 5); - assertThat(buffer.readUtf8()).isEqualTo("Hello"); - - operator.read(4, buffer, 5); - assertThat(buffer.readUtf8()).isEqualTo("o, Wo"); - } - - @Test public void write() throws Exception { - FileOperator operator = new FileOperator(randomAccessFile.getChannel()); - - Buffer buffer1 = new Buffer().writeUtf8("Hello, World"); - operator.write(0, buffer1, 5); - assertThat(buffer1.readUtf8()).isEqualTo(", World"); - - Buffer buffer2 = new Buffer().writeUtf8("icopter!"); - operator.write(3, buffer2, 7); - assertThat(buffer2.readUtf8()).isEqualTo("!"); - - assertThat(snapshot()).isEqualTo(ByteString.encodeUtf8("Helicopter")); - } - - @Test public void readAndWrite() throws Exception { - FileOperator operator = new FileOperator(randomAccessFile.getChannel()); - - write(ByteString.encodeUtf8("woman god creates dinosaurs destroys. ")); - Buffer buffer = new Buffer(); - operator.read(6, buffer, 21); - operator.read(36, buffer, 1); - operator.read(5, buffer, 5); - operator.read(28, buffer, 8); - operator.read(17, buffer, 10); - operator.read(36, buffer, 2); - operator.read(2, buffer, 4); - operator.write(0, buffer, buffer.size()); - operator.read(0, buffer, 12); - operator.read(47, buffer, 3); - operator.read(45, buffer, 2); - operator.read(47, buffer, 3); - operator.read(26, buffer, 10); - operator.read(23, buffer, 3); - operator.write(47, buffer, buffer.size()); - operator.read(62, buffer, 6); - operator.read(4, buffer, 19); - operator.write(80, buffer, buffer.size()); - - assertThat(ByteString.encodeUtf8("" - + "god creates dinosaurs. " - + "god destroys dinosaurs. " - + "god creates man. " - + "man destroys god. " - + "man creates dinosaurs. ")).isEqualTo(snapshot()); - } - - @Test public void multipleOperatorsShareOneFile() throws Exception { - FileOperator operatorA = new FileOperator(randomAccessFile.getChannel()); - FileOperator operatorB = new FileOperator(randomAccessFile.getChannel()); - - Buffer bufferA = new Buffer(); - Buffer bufferB = new Buffer(); - - bufferA.writeUtf8("Dodgson!\n"); - operatorA.write(0, bufferA, 9); - - bufferB.writeUtf8("You shouldn't use my name.\n"); - operatorB.write(9, bufferB, 27); - - bufferA.writeUtf8("Dodgson, we've got Dodgson here!\n"); - operatorA.write(36, bufferA, 33); - - operatorB.read(0, bufferB, 9); - assertThat(bufferB.readUtf8()).isEqualTo("Dodgson!\n"); - - operatorA.read(9, bufferA, 27); - assertThat(bufferA.readUtf8()).isEqualTo("You shouldn't use my name.\n"); - - operatorB.read(36, bufferB, 33); - assertThat(bufferB.readUtf8()).isEqualTo("Dodgson, we've got Dodgson here!\n"); - } - - @Test public void largeRead() throws Exception { - ByteString data = randomByteString(1000000); - write(data); - - FileOperator operator = new FileOperator(randomAccessFile.getChannel()); - - Buffer buffer = new Buffer(); - operator.read(0, buffer, data.size()); - assertThat(buffer.readByteString()).isEqualTo(data); - } - - @Test public void largeWrite() throws Exception { - ByteString data = randomByteString(1000000); - - FileOperator operator = new FileOperator(randomAccessFile.getChannel()); - - Buffer buffer = new Buffer().write(data); - operator.write(0, buffer, data.size()); - - assertThat(snapshot()).isEqualTo(data); - } - - @Test public void readBounds() throws Exception { - FileOperator operator = new FileOperator(randomAccessFile.getChannel()); - Buffer buffer = new Buffer(); - try { - operator.read(0, buffer, -1L); - fail(); - } catch (IndexOutOfBoundsException expected) { - } - } - - @Test public void writeBounds() throws Exception { - FileOperator operator = new FileOperator(randomAccessFile.getChannel()); - Buffer buffer = new Buffer().writeUtf8("abc"); - try { - operator.write(0, buffer, -1L); - fail(); - } catch (IndexOutOfBoundsException expected) { - } - try { - operator.write(0, buffer, 4L); - fail(); - } catch (IndexOutOfBoundsException expected) { - } - } - - private ByteString randomByteString(int byteCount) { - byte[] bytes = new byte[byteCount]; - new Random(0).nextBytes(bytes); - return ByteString.of(bytes); - } - - private ByteString snapshot() throws IOException { - randomAccessFile.getChannel().force(false); - BufferedSource source = Okio.buffer(Okio.source(file)); - return source.readByteString(); - } - - private void write(ByteString data) throws IOException { - BufferedSink sink = Okio.buffer(Okio.sink(file)); - sink.write(data); - sink.close(); - } -} diff --git a/okhttp/src/test/java/okhttp3/internal/cache2/FileOperatorTest.kt b/okhttp/src/test/java/okhttp3/internal/cache2/FileOperatorTest.kt new file mode 100644 index 000000000000..946c599d1f62 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/internal/cache2/FileOperatorTest.kt @@ -0,0 +1,209 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.cache2 + +import java.io.File +import java.io.RandomAccessFile +import java.util.Random +import okio.Buffer +import okio.ByteString +import okio.ByteString.Companion.encodeUtf8 +import okio.buffer +import okio.sink +import okio.source +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir + +class FileOperatorTest { + @TempDir + var tempDir: File? = null + private var file: File? = null + private var randomAccessFile: RandomAccessFile? = null + + @BeforeEach + fun setUp() { + file = File(tempDir, "test") + randomAccessFile = RandomAccessFile(file, "rw") + } + + @AfterEach + fun tearDown() { + randomAccessFile!!.close() + } + + @Test + fun read() { + write("Hello, World".encodeUtf8()) + val operator = FileOperator( + randomAccessFile!!.getChannel() + ) + val buffer = Buffer() + operator.read(0, buffer, 5) + assertThat(buffer.readUtf8()).isEqualTo("Hello") + operator.read(4, buffer, 5) + assertThat(buffer.readUtf8()).isEqualTo("o, Wo") + } + + @Test + fun write() { + val operator = FileOperator( + randomAccessFile!!.getChannel() + ) + val buffer1 = Buffer().writeUtf8("Hello, World") + operator.write(0, buffer1, 5) + assertThat(buffer1.readUtf8()).isEqualTo(", World") + val buffer2 = Buffer().writeUtf8("icopter!") + operator.write(3, buffer2, 7) + assertThat(buffer2.readUtf8()).isEqualTo("!") + assertThat(snapshot()).isEqualTo("Helicopter".encodeUtf8()) + } + + @Test + fun readAndWrite() { + val operator = FileOperator( + randomAccessFile!!.getChannel() + ) + write("woman god creates dinosaurs destroys. ".encodeUtf8()) + val buffer = Buffer() + operator.read(6, buffer, 21) + operator.read(36, buffer, 1) + operator.read(5, buffer, 5) + operator.read(28, buffer, 8) + operator.read(17, buffer, 10) + operator.read(36, buffer, 2) + operator.read(2, buffer, 4) + operator.write(0, buffer, buffer.size) + operator.read(0, buffer, 12) + operator.read(47, buffer, 3) + operator.read(45, buffer, 2) + operator.read(47, buffer, 3) + operator.read(26, buffer, 10) + operator.read(23, buffer, 3) + operator.write(47, buffer, buffer.size) + operator.read(62, buffer, 6) + operator.read(4, buffer, 19) + operator.write(80, buffer, buffer.size) + assertThat(snapshot()).isEqualTo( + ("" + + "god creates dinosaurs. " + + "god destroys dinosaurs. " + + "god creates man. " + + "man destroys god. " + + "man creates dinosaurs. " + ).encodeUtf8() + ) + } + + @Test + fun multipleOperatorsShareOneFile() { + val operatorA = FileOperator( + randomAccessFile!!.getChannel() + ) + val operatorB = FileOperator( + randomAccessFile!!.getChannel() + ) + val bufferA = Buffer() + val bufferB = Buffer() + bufferA.writeUtf8("Dodgson!\n") + operatorA.write(0, bufferA, 9) + bufferB.writeUtf8("You shouldn't use my name.\n") + operatorB.write(9, bufferB, 27) + bufferA.writeUtf8("Dodgson, we've got Dodgson here!\n") + operatorA.write(36, bufferA, 33) + operatorB.read(0, bufferB, 9) + assertThat(bufferB.readUtf8()).isEqualTo("Dodgson!\n") + operatorA.read(9, bufferA, 27) + assertThat(bufferA.readUtf8()).isEqualTo("You shouldn't use my name.\n") + operatorB.read(36, bufferB, 33) + assertThat(bufferB.readUtf8()).isEqualTo("Dodgson, we've got Dodgson here!\n") + } + + @Test + fun largeRead() { + val data = randomByteString(1000000) + write(data) + val operator = FileOperator( + randomAccessFile!!.getChannel() + ) + val buffer = Buffer() + operator.read(0, buffer, data.size.toLong()) + assertThat(buffer.readByteString()).isEqualTo(data) + } + + @Test + fun largeWrite() { + val data = randomByteString(1000000) + val operator = FileOperator( + randomAccessFile!!.getChannel() + ) + val buffer = Buffer().write(data) + operator.write(0, buffer, data.size.toLong()) + assertThat(snapshot()).isEqualTo(data) + } + + @Test + fun readBounds() { + val operator = FileOperator( + randomAccessFile!!.getChannel() + ) + val buffer = Buffer() + try { + operator.read(0, buffer, -1L) + fail() + } catch (expected: IndexOutOfBoundsException) { + } + } + + @Test + fun writeBounds() { + val operator = FileOperator( + randomAccessFile!!.getChannel() + ) + val buffer = Buffer().writeUtf8("abc") + try { + operator.write(0, buffer, -1L) + fail() + } catch (expected: IndexOutOfBoundsException) { + } + try { + operator.write(0, buffer, 4L) + fail() + } catch (expected: IndexOutOfBoundsException) { + } + } + + private fun randomByteString(byteCount: Int): ByteString { + val bytes = ByteArray(byteCount) + Random(0).nextBytes(bytes) + return ByteString.of(*bytes) + } + + private fun snapshot(): ByteString { + randomAccessFile!!.getChannel().force(false) + val source = file!!.source().buffer() + return source.readByteString() + } + + private fun write(data: ByteString) { + val sink = file!!.sink().buffer() + sink.write(data) + sink.close() + } +} diff --git a/okhttp/src/test/java/okhttp3/internal/cache2/RelayTest.java b/okhttp/src/test/java/okhttp3/internal/cache2/RelayTest.java deleted file mode 100644 index 5c8caf1893fc..000000000000 --- a/okhttp/src/test/java/okhttp3/internal/cache2/RelayTest.java +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Copyright (C) 2016 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3.internal.cache2; - -import java.io.File; -import java.io.IOException; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import okio.Buffer; -import okio.BufferedSink; -import okio.BufferedSource; -import okio.ByteString; -import okio.Okio; -import okio.Pipe; -import okio.Source; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -@Tag("Slowish") -public final class RelayTest { - @TempDir File tempDir; - private final ExecutorService executor = Executors.newCachedThreadPool(); - private final ByteString metadata = ByteString.encodeUtf8("great metadata!"); - private File file; - - @BeforeEach - void setUp() { - file = new File(tempDir, "test"); - } - - @AfterEach public void tearDown() throws Exception { - executor.shutdown(); - } - - @Test public void singleSource() throws Exception { - Buffer upstream = new Buffer(); - upstream.writeUtf8("abcdefghijklm"); - - Relay relay = Relay.Companion.edit(file, upstream, metadata, 1024); - Source source = relay.newSource(); - Buffer sourceBuffer = new Buffer(); - - assertThat(source.read(sourceBuffer, 5)).isEqualTo(5); - assertThat(sourceBuffer.readUtf8()).isEqualTo("abcde"); - - assertThat(source.read(sourceBuffer, 1024)).isEqualTo(8); - assertThat(sourceBuffer.readUtf8()).isEqualTo("fghijklm"); - - assertThat(source.read(sourceBuffer, 1024)).isEqualTo(-1); - assertThat(sourceBuffer.size()).isEqualTo(0); - - source.close(); - assertThat(relay.isClosed()).isTrue(); - assertFile(Relay.PREFIX_CLEAN, 13L, metadata.size(), "abcdefghijklm", metadata); - } - - @Test public void multipleSources() throws Exception { - Buffer upstream = new Buffer(); - upstream.writeUtf8("abcdefghijklm"); - - Relay relay = Relay.Companion.edit(file, upstream, metadata, 1024); - BufferedSource source1 = Okio.buffer(relay.newSource()); - BufferedSource source2 = Okio.buffer(relay.newSource()); - - assertThat(source1.readUtf8()).isEqualTo("abcdefghijklm"); - assertThat(source2.readUtf8()).isEqualTo("abcdefghijklm"); - source1.close(); - source2.close(); - assertThat(relay.isClosed()).isTrue(); - - assertFile(Relay.PREFIX_CLEAN, 13L, metadata.size(), "abcdefghijklm", metadata); - } - - @Test public void readFromBuffer() throws Exception { - Buffer upstream = new Buffer(); - upstream.writeUtf8("abcdefghij"); - - Relay relay = Relay.Companion.edit(file, upstream, metadata, 5); - BufferedSource source1 = Okio.buffer(relay.newSource()); - BufferedSource source2 = Okio.buffer(relay.newSource()); - - assertThat(source1.readUtf8(5)).isEqualTo("abcde"); - assertThat(source2.readUtf8(5)).isEqualTo("abcde"); - assertThat(source2.readUtf8(5)).isEqualTo("fghij"); - assertThat(source1.readUtf8(5)).isEqualTo("fghij"); - assertThat(source1.exhausted()).isTrue(); - assertThat(source2.exhausted()).isTrue(); - source1.close(); - source2.close(); - assertThat(relay.isClosed()).isTrue(); - - assertFile(Relay.PREFIX_CLEAN, 10L, metadata.size(), "abcdefghij", metadata); - } - - @Test public void readFromFile() throws Exception { - Buffer upstream = new Buffer(); - upstream.writeUtf8("abcdefghijklmnopqrst"); - - Relay relay = Relay.Companion.edit(file, upstream, metadata, 5); - BufferedSource source1 = Okio.buffer(relay.newSource()); - BufferedSource source2 = Okio.buffer(relay.newSource()); - - assertThat(source1.readUtf8(10)).isEqualTo("abcdefghij"); - assertThat(source2.readUtf8(10)).isEqualTo("abcdefghij"); - assertThat(source2.readUtf8(10)).isEqualTo("klmnopqrst"); - assertThat(source1.readUtf8(10)).isEqualTo("klmnopqrst"); - assertThat(source1.exhausted()).isTrue(); - assertThat(source2.exhausted()).isTrue(); - source1.close(); - source2.close(); - assertThat(relay.isClosed()).isTrue(); - - assertFile(Relay.PREFIX_CLEAN, 20L, metadata.size(), "abcdefghijklmnopqrst", metadata); - } - - @Test public void readAfterEdit() throws Exception { - Buffer upstream = new Buffer(); - upstream.writeUtf8("abcdefghij"); - - Relay relay1 = Relay.Companion.edit(file, upstream, metadata, 5); - BufferedSource source1 = Okio.buffer(relay1.newSource()); - assertThat(source1.readUtf8(10)).isEqualTo("abcdefghij"); - assertThat(source1.exhausted()).isTrue(); - source1.close(); - assertThat(relay1.isClosed()).isTrue(); - - // Since relay1 is closed, new sources cannot be created. - assertThat(relay1.newSource()).isNull(); - - Relay relay2 = Relay.Companion.read(file); - assertThat(relay2.metadata()).isEqualTo(metadata); - BufferedSource source2 = Okio.buffer(relay2.newSource()); - assertThat(source2.readUtf8(10)).isEqualTo("abcdefghij"); - assertThat(source2.exhausted()).isTrue(); - source2.close(); - assertThat(relay2.isClosed()).isTrue(); - - // Since relay2 is closed, new sources cannot be created. - assertThat(relay2.newSource()).isNull(); - - assertFile(Relay.PREFIX_CLEAN, 10L, metadata.size(), "abcdefghij", metadata); - } - - @Test public void closeBeforeExhaustLeavesDirtyFile() throws Exception { - Buffer upstream = new Buffer(); - upstream.writeUtf8("abcdefghij"); - - Relay relay1 = Relay.Companion.edit(file, upstream, metadata, 5); - BufferedSource source1 = Okio.buffer(relay1.newSource()); - assertThat(source1.readUtf8(10)).isEqualTo("abcdefghij"); - source1.close(); // Not exhausted! - assertThat(relay1.isClosed()).isTrue(); - - try { - Relay.Companion.read(file); - fail(); - } catch (IOException expected) { - assertThat(expected.getMessage()).isEqualTo("unreadable cache file"); - } - - assertFile(Relay.PREFIX_DIRTY, -1L, -1, null, null); - } - - @Test public void redundantCallsToCloseAreIgnored() throws Exception { - Buffer upstream = new Buffer(); - upstream.writeUtf8("abcde"); - - Relay relay = Relay.Companion.edit(file, upstream, metadata, 1024); - Source source1 = relay.newSource(); - Source source2 = relay.newSource(); - - source1.close(); - source1.close(); // Unnecessary. Shouldn't decrement the reference count. - assertThat(relay.isClosed()).isFalse(); - - source2.close(); - assertThat(relay.isClosed()).isTrue(); - assertFile(Relay.PREFIX_DIRTY, -1L, -1, null, null); - } - - @Test public void racingReaders() throws Exception { - Pipe pipe = new Pipe(1024); - BufferedSink sink = Okio.buffer(pipe.sink()); - - Relay relay = Relay.Companion.edit(file, pipe.source(), metadata, 5); - - Future future1 = executor.submit(sourceReader(relay.newSource())); - Future future2 = executor.submit(sourceReader(relay.newSource())); - - Thread.sleep(500); - sink.writeUtf8("abcdefghij"); - - Thread.sleep(500); - sink.writeUtf8("klmnopqrst"); - sink.close(); - - assertThat(future1.get()).isEqualTo(ByteString.encodeUtf8("abcdefghijklmnopqrst")); - assertThat(future2.get()).isEqualTo(ByteString.encodeUtf8("abcdefghijklmnopqrst")); - - assertThat(relay.isClosed()).isTrue(); - - assertFile(Relay.PREFIX_CLEAN, 20L, metadata.size(), "abcdefghijklmnopqrst", metadata); - } - - /** Returns a callable that reads all of source, closes it, and returns the bytes. */ - private Callable sourceReader(final Source source) { - return () -> { - Buffer buffer = new Buffer(); - while (source.read(buffer, 16384) != -1) { - } - source.close(); - return buffer.readByteString(); - }; - } - - private void assertFile(ByteString prefix, long upstreamSize, int metadataSize, String upstream, - ByteString metadata) throws IOException { - BufferedSource source = Okio.buffer(Okio.source(file)); - assertThat(source.readByteString(prefix.size())).isEqualTo(prefix); - assertThat(source.readLong()).isEqualTo(upstreamSize); - assertThat(source.readLong()).isEqualTo(metadataSize); - if (upstream != null) { - assertThat(source.readUtf8(upstreamSize)).isEqualTo(upstream); - } - if (metadata != null) { - assertThat(source.readByteString(metadataSize)).isEqualTo(metadata); - } - source.close(); - } -} diff --git a/okhttp/src/test/java/okhttp3/internal/cache2/RelayTest.kt b/okhttp/src/test/java/okhttp3/internal/cache2/RelayTest.kt new file mode 100644 index 000000000000..eb1cb33bf208 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/internal/cache2/RelayTest.kt @@ -0,0 +1,233 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.cache2 + +import java.io.File +import java.io.IOException +import java.util.concurrent.Callable +import java.util.concurrent.Executors +import okhttp3.internal.cache2.Relay.Companion.edit +import okhttp3.internal.cache2.Relay.Companion.read +import okio.Buffer +import okio.ByteString +import okio.ByteString.Companion.encodeUtf8 +import okio.Pipe +import okio.Source +import okio.buffer +import okio.source +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir + +@Tag("Slowish") +class RelayTest { + @TempDir + var tempDir: File? = null + private val executor = Executors.newCachedThreadPool() + private val metadata: ByteString = "great metadata!".encodeUtf8() + private lateinit var file: File + + @BeforeEach + fun setUp() { + file = File(tempDir, "test") + } + + @AfterEach + fun tearDown() { + executor.shutdown() + } + + @Test + fun singleSource() { + val upstream = Buffer() + upstream.writeUtf8("abcdefghijklm") + val relay = edit(file, upstream, metadata, 1024) + val source = relay.newSource() + val sourceBuffer = Buffer() + assertThat(source!!.read(sourceBuffer, 5)).isEqualTo(5) + assertThat(sourceBuffer.readUtf8()).isEqualTo("abcde") + assertThat(source.read(sourceBuffer, 1024)).isEqualTo(8) + assertThat(sourceBuffer.readUtf8()).isEqualTo("fghijklm") + assertThat(source.read(sourceBuffer, 1024)).isEqualTo(-1) + assertThat(sourceBuffer.size).isEqualTo(0) + source.close() + assertThat(relay.isClosed).isTrue() + assertFile(Relay.PREFIX_CLEAN, 13L, metadata.size, "abcdefghijklm", metadata) + } + + @Test + fun multipleSources() { + val upstream = Buffer() + upstream.writeUtf8("abcdefghijklm") + val relay = edit(file, upstream, metadata, 1024) + val source1 = relay.newSource()!!.buffer() + val source2 = relay.newSource()!!.buffer() + assertThat(source1.readUtf8()).isEqualTo("abcdefghijklm") + assertThat(source2.readUtf8()).isEqualTo("abcdefghijklm") + source1.close() + source2.close() + assertThat(relay.isClosed).isTrue() + assertFile(Relay.PREFIX_CLEAN, 13L, metadata.size, "abcdefghijklm", metadata) + } + + @Test + fun readFromBuffer() { + val upstream = Buffer() + upstream.writeUtf8("abcdefghij") + val relay = edit(file, upstream, metadata, 5) + val source1 = relay.newSource()!!.buffer() + val source2 = relay.newSource()!!.buffer() + assertThat(source1.readUtf8(5)).isEqualTo("abcde") + assertThat(source2.readUtf8(5)).isEqualTo("abcde") + assertThat(source2.readUtf8(5)).isEqualTo("fghij") + assertThat(source1.readUtf8(5)).isEqualTo("fghij") + assertThat(source1.exhausted()).isTrue() + assertThat(source2.exhausted()).isTrue() + source1.close() + source2.close() + assertThat(relay.isClosed).isTrue() + assertFile(Relay.PREFIX_CLEAN, 10L, metadata.size, "abcdefghij", metadata) + } + + @Test + fun readFromFile() { + val upstream = Buffer() + upstream.writeUtf8("abcdefghijklmnopqrst") + val relay = edit(file, upstream, metadata, 5) + val source1 = relay.newSource()!!.buffer() + val source2 = relay.newSource()!!.buffer() + assertThat(source1.readUtf8(10)).isEqualTo("abcdefghij") + assertThat(source2.readUtf8(10)).isEqualTo("abcdefghij") + assertThat(source2.readUtf8(10)).isEqualTo("klmnopqrst") + assertThat(source1.readUtf8(10)).isEqualTo("klmnopqrst") + assertThat(source1.exhausted()).isTrue() + assertThat(source2.exhausted()).isTrue() + source1.close() + source2.close() + assertThat(relay.isClosed).isTrue() + assertFile(Relay.PREFIX_CLEAN, 20L, metadata.size, "abcdefghijklmnopqrst", metadata) + } + + @Test + fun readAfterEdit() { + val upstream = Buffer() + upstream.writeUtf8("abcdefghij") + val relay1 = edit(file, upstream, metadata, 5) + val source1 = relay1.newSource()!!.buffer() + assertThat(source1.readUtf8(10)).isEqualTo("abcdefghij") + assertThat(source1.exhausted()).isTrue() + source1.close() + assertThat(relay1.isClosed).isTrue() + + // Since relay1 is closed, new sources cannot be created. + assertThat(relay1.newSource()).isNull() + val relay2 = read(file) + assertThat(relay2.metadata()).isEqualTo(metadata) + val source2 = relay2.newSource()!!.buffer() + assertThat(source2.readUtf8(10)).isEqualTo("abcdefghij") + assertThat(source2.exhausted()).isTrue() + source2.close() + assertThat(relay2.isClosed).isTrue() + + // Since relay2 is closed, new sources cannot be created. + assertThat(relay2.newSource()).isNull() + assertFile(Relay.PREFIX_CLEAN, 10L, metadata.size, "abcdefghij", metadata) + } + + @Test + fun closeBeforeExhaustLeavesDirtyFile() { + val upstream = Buffer() + upstream.writeUtf8("abcdefghij") + val relay1 = edit(file, upstream, metadata, 5) + val source1 = relay1.newSource()!!.buffer() + assertThat(source1.readUtf8(10)).isEqualTo("abcdefghij") + source1.close() // Not exhausted! + assertThat(relay1.isClosed).isTrue() + try { + read(file) + org.junit.jupiter.api.Assertions.fail() + } catch (expected: IOException) { + assertThat(expected.message).isEqualTo("unreadable cache file") + } + assertFile(Relay.PREFIX_DIRTY, -1L, -1, null, null) + } + + @Test + fun redundantCallsToCloseAreIgnored() { + val upstream = Buffer() + upstream.writeUtf8("abcde") + val relay = edit(file, upstream, metadata, 1024) + val source1 = relay.newSource() + val source2 = relay.newSource() + source1!!.close() + source1.close() // Unnecessary. Shouldn't decrement the reference count. + assertThat(relay.isClosed).isFalse() + source2!!.close() + assertThat(relay.isClosed).isTrue() + assertFile(Relay.PREFIX_DIRTY, -1L, -1, null, null) + } + + @Test + fun racingReaders() { + val pipe = Pipe(1024) + val sink = pipe.sink.buffer() + val relay = edit(file, pipe.source, metadata, 5) + val future1 = executor.submit(sourceReader(relay.newSource())) + val future2 = executor.submit(sourceReader(relay.newSource())) + Thread.sleep(500) + sink.writeUtf8("abcdefghij") + Thread.sleep(500) + sink.writeUtf8("klmnopqrst") + sink.close() + assertThat(future1.get()) + .isEqualTo("abcdefghijklmnopqrst".encodeUtf8()) + assertThat(future2.get()) + .isEqualTo("abcdefghijklmnopqrst".encodeUtf8()) + assertThat(relay.isClosed).isTrue() + assertFile(Relay.PREFIX_CLEAN, 20L, metadata.size, "abcdefghijklmnopqrst", metadata) + } + + /** Returns a callable that reads all of source, closes it, and returns the bytes. */ + private fun sourceReader(source: Source?): Callable { + return Callable { + val buffer = Buffer() + while (source!!.read(buffer, 16384) != -1L) { + } + source.close() + buffer.readByteString() + } + } + + private fun assertFile( + prefix: ByteString, upstreamSize: Long, metadataSize: Int, upstream: String?, + metadata: ByteString? + ) { + val source = file.source().buffer() + assertThat(source.readByteString(prefix.size.toLong())).isEqualTo(prefix) + assertThat(source.readLong()).isEqualTo(upstreamSize) + assertThat(source.readLong()).isEqualTo(metadataSize.toLong()) + if (upstream != null) { + assertThat(source.readUtf8(upstreamSize)).isEqualTo(upstream) + } + if (metadata != null) { + assertThat(source.readByteString(metadataSize.toLong())).isEqualTo(metadata) + } + source.close() + } +} diff --git a/okhttp/src/test/java/okhttp3/internal/http2/BaseTestHandler.java b/okhttp/src/test/java/okhttp3/internal/http2/BaseTestHandler.java deleted file mode 100644 index 88679b79eb10..000000000000 --- a/okhttp/src/test/java/okhttp3/internal/http2/BaseTestHandler.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (C) 2013 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3.internal.http2; - -import java.io.IOException; -import java.util.List; -import okio.BufferedSource; -import okio.ByteString; - -import static org.junit.jupiter.api.Assertions.fail; - -class BaseTestHandler implements Http2Reader.Handler { - @Override public void data(boolean inFinished, int streamId, BufferedSource source, int length) - throws IOException { - fail(); - } - - @Override public void headers(boolean inFinished, int streamId, int associatedStreamId, - List
headerBlock) { - fail(); - } - - @Override public void rstStream(int streamId, ErrorCode errorCode) { - fail(); - } - - @Override public void settings(boolean clearPrevious, Settings settings) { - fail(); - } - - @Override public void ackSettings() { - fail(); - } - - @Override public void ping(boolean ack, int payload1, int payload2) { - fail(); - } - - @Override public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) { - fail(); - } - - @Override public void windowUpdate(int streamId, long windowSizeIncrement) { - fail(); - } - - @Override public void priority(int streamId, int streamDependency, int weight, - boolean exclusive) { - fail(); - } - - @Override - public void pushPromise(int streamId, int associatedStreamId, List
headerBlock) { - fail(); - } - - @Override public void alternateService(int streamId, String origin, ByteString protocol, - String host, int port, long maxAge) { - fail(); - } -} diff --git a/okhttp/src/test/java/okhttp3/internal/http2/BaseTestHandler.kt b/okhttp/src/test/java/okhttp3/internal/http2/BaseTestHandler.kt new file mode 100644 index 000000000000..2352083db671 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/internal/http2/BaseTestHandler.kt @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2013 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.http2 + +import okio.BufferedSource +import okio.ByteString +import org.junit.jupiter.api.Assertions.fail + +internal open class BaseTestHandler : Http2Reader.Handler { + override fun data( + inFinished: Boolean, + streamId: Int, + source: BufferedSource, + length: Int, + ) { + fail() + } + + override fun headers( + inFinished: Boolean, + streamId: Int, + associatedStreamId: Int, + headerBlock: List
, + ) { + fail() + } + + override fun rstStream( + streamId: Int, + errorCode: ErrorCode, + ) { + fail() + } + + override fun settings( + clearPrevious: Boolean, + settings: Settings, + ) { + fail() + } + + override fun ackSettings() { + fail() + } + + override fun ping( + ack: Boolean, + payload1: Int, + payload2: Int, + ) { + fail() + } + + override fun goAway( + lastGoodStreamId: Int, + errorCode: ErrorCode, + debugData: ByteString, + ) { + fail() + } + + override fun windowUpdate( + streamId: Int, + windowSizeIncrement: Long, + ) { + fail() + } + + override fun priority( + streamId: Int, + streamDependency: Int, + weight: Int, + exclusive: Boolean, + ) { + fail() + } + + override fun pushPromise( + streamId: Int, + associatedStreamId: Int, + headerBlock: List
, + ) { + fail() + } + + override fun alternateService( + streamId: Int, + origin: String, + protocol: ByteString, + host: String, + port: Int, + maxAge: Long, + ) { + fail() + } +} diff --git a/okhttp/src/test/java/okhttp3/internal/http2/HpackTest.java b/okhttp/src/test/java/okhttp3/internal/http2/HpackTest.java deleted file mode 100644 index 59d88a991944..000000000000 --- a/okhttp/src/test/java/okhttp3/internal/http2/HpackTest.java +++ /dev/null @@ -1,1107 +0,0 @@ -/* - * Copyright (C) 2013 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3.internal.http2; - -import java.io.IOException; -import java.util.Arrays; -import java.util.List; -import okio.Buffer; -import okio.ByteString; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import static java.util.Arrays.asList; -import static okhttp3.TestUtil.headerEntries; -import static okio.ByteString.decodeHex; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -public final class HpackTest { - private final Buffer bytesIn = new Buffer(); - private Hpack.Reader hpackReader; - private final Buffer bytesOut = new Buffer(); - private Hpack.Writer hpackWriter; - - @BeforeEach public void reset() { - hpackReader = newReader(bytesIn); - hpackWriter = new Hpack.Writer(4096, false, bytesOut); - } - - /** - * Variable-length quantity special cases strings which are longer than 127 bytes. Values such as - * cookies can be 4KiB, and should be possible to send. - * - *

http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#section-5.2 - */ - @Test public void largeHeaderValue() throws IOException { - char[] value = new char[4096]; - Arrays.fill(value, '!'); - List

headerBlock = headerEntries("cookie", new String(value)); - - hpackWriter.writeHeaders(headerBlock); - bytesIn.writeAll(bytesOut); - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(0); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerBlock); - } - - /** - * HPACK has a max header table size, which can be smaller than the max header message. Ensure the - * larger header content is not lost. - */ - @Test public void tooLargeToHPackIsStillEmitted() throws IOException { - bytesIn.writeByte(0x21); // Dynamic table size update (size = 1). - bytesIn.writeByte(0x00); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-key"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(0); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo( - headerEntries("custom-key", "custom-header")); - } - - /** Oldest entries are evicted to support newer ones. */ - @Test public void writerEviction() throws IOException { - List
headerBlock = - headerEntries( - "custom-foo", "custom-header", - "custom-bar", "custom-header", - "custom-baz", "custom-header"); - - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-foo"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-bar"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-baz"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - // Set to only support 110 bytes (enough for 2 headers). - // Use a new Writer because we don't support change the dynamic table - // size after Writer constructed. - Hpack.Writer writer = new Hpack.Writer(110, false, bytesOut); - writer.writeHeaders(headerBlock); - - assertThat(bytesOut).isEqualTo(bytesIn); - assertThat(writer.headerCount).isEqualTo(2); - - int tableLength = writer.dynamicTable.length; - Header entry = writer.dynamicTable[tableLength - 1]; - checkEntry(entry, "custom-bar", "custom-header", 55); - - entry = writer.dynamicTable[tableLength - 2]; - checkEntry(entry, "custom-baz", "custom-header", 55); - } - - @Test public void readerEviction() throws IOException { - List
headerBlock = - headerEntries( - "custom-foo", "custom-header", - "custom-bar", "custom-header", - "custom-baz", "custom-header"); - - // Set to only support 110 bytes (enough for 2 headers). - bytesIn.writeByte(0x3F); // Dynamic table size update (size = 110). - bytesIn.writeByte(0x4F); - - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-foo"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-bar"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-baz"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(2); - - Header entry1 = hpackReader.dynamicTable[readerHeaderTableLength() - 1]; - checkEntry(entry1, "custom-bar", "custom-header", 55); - - Header entry2 = hpackReader.dynamicTable[readerHeaderTableLength() - 2]; - checkEntry(entry2, "custom-baz", "custom-header", 55); - - // Once a header field is decoded and added to the reconstructed header - // list, it cannot be removed from it. Hence, foo is here. - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerBlock); - - // Simulate receiving a small dynamic table size update, that implies eviction. - bytesIn.writeByte(0x3F); // Dynamic table size update (size = 55). - bytesIn.writeByte(0x18); - hpackReader.readHeaders(); - assertThat(hpackReader.headerCount).isEqualTo(1); - } - - /** Header table backing array is initially 8 long, let's ensure it grows. */ - @Test public void dynamicallyGrowsBeyond64Entries() throws IOException { - // Lots of headers need more room! - hpackReader = new Hpack.Reader(bytesIn, 16384, 4096); - bytesIn.writeByte(0x3F); // Dynamic table size update (size = 16384). - bytesIn.writeByte(0xE1); - bytesIn.writeByte(0x7F); - - for (int i = 0; i < 256; i++) { - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-foo"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - } - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(256); - } - - @Test public void huffmanDecodingSupported() throws IOException { - bytesIn.writeByte(0x44); // == Literal indexed == - // Indexed name (idx = 4) -> :path - bytesIn.writeByte(0x8c); // Literal value Huffman encoded 12 bytes - // decodes to www.example.com which is length 15 - bytesIn.write(decodeHex("f1e3c2e5f23a6ba0ab90f4ff")); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(1); - assertThat(hpackReader.dynamicTableByteCount).isEqualTo(52); - - Header entry = hpackReader.dynamicTable[readerHeaderTableLength() - 1]; - checkEntry(entry, ":path", "www.example.com", 52); - } - - /** - * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.2.1 - */ - @Test public void readLiteralHeaderFieldWithIndexing() throws IOException { - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-key"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(1); - assertThat(hpackReader.dynamicTableByteCount).isEqualTo(55); - - Header entry = hpackReader.dynamicTable[readerHeaderTableLength() - 1]; - checkEntry(entry, "custom-key", "custom-header", 55); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo( - headerEntries("custom-key", "custom-header")); - } - - /** - * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.2.2 - */ - @Test public void literalHeaderFieldWithoutIndexingIndexedName() throws IOException { - List
headerBlock = headerEntries(":path", "/sample/path"); - - bytesIn.writeByte(0x04); // == Literal not indexed == - // Indexed name (idx = 4) -> :path - bytesIn.writeByte(0x0c); // Literal value (len = 12) - bytesIn.writeUtf8("/sample/path"); - - hpackWriter.writeHeaders(headerBlock); - assertThat(bytesOut).isEqualTo(bytesIn); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(0); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerBlock); - } - - @Test public void literalHeaderFieldWithoutIndexingNewName() throws IOException { - List
headerBlock = headerEntries("custom-key", "custom-header"); - - bytesIn.writeByte(0x00); // Not indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-key"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(0); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerBlock); - } - - @Test public void literalHeaderFieldNeverIndexedIndexedName() throws IOException { - bytesIn.writeByte(0x14); // == Literal never indexed == - // Indexed name (idx = 4) -> :path - bytesIn.writeByte(0x0c); // Literal value (len = 12) - bytesIn.writeUtf8("/sample/path"); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(0); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo( - headerEntries(":path", "/sample/path")); - } - - @Test public void literalHeaderFieldNeverIndexedNewName() throws IOException { - List
headerBlock = headerEntries("custom-key", "custom-header"); - - bytesIn.writeByte(0x10); // Never indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-key"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(0); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerBlock); - } - - @Test public void literalHeaderFieldWithIncrementalIndexingIndexedName() throws IOException { - List
headerBlock = headerEntries(":path", "/sample/path"); - - bytesIn.writeByte(0x44); // Indexed name (idx = 4) -> :path - bytesIn.writeByte(0x0c); // Literal value (len = 12) - bytesIn.writeUtf8("/sample/path"); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(1); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerBlock); - } - - @Test public void literalHeaderFieldWithIncrementalIndexingNewName() throws IOException { - List
headerBlock = headerEntries("custom-key", "custom-header"); - - bytesIn.writeByte(0x40); // Never indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-key"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - hpackWriter.writeHeaders(headerBlock); - assertThat(bytesOut).isEqualTo(bytesIn); - - assertThat(hpackWriter.headerCount).isEqualTo(1); - - Header entry = hpackWriter.dynamicTable[hpackWriter.dynamicTable.length - 1]; - checkEntry(entry, "custom-key", "custom-header", 55); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(1); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerBlock); - } - - @Test public void theSameHeaderAfterOneIncrementalIndexed() throws IOException { - List
headerBlock = - headerEntries( - "custom-key", "custom-header", - "custom-key", "custom-header"); - - bytesIn.writeByte(0x40); // Never indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-key"); - - bytesIn.writeByte(0x0d); // Literal value (len = 13) - bytesIn.writeUtf8("custom-header"); - - bytesIn.writeByte(0xbe); // Indexed name and value (idx = 63) - - hpackWriter.writeHeaders(headerBlock); - assertThat(bytesOut).isEqualTo(bytesIn); - - assertThat(hpackWriter.headerCount).isEqualTo(1); - - Header entry = hpackWriter.dynamicTable[hpackWriter.dynamicTable.length - 1]; - checkEntry(entry, "custom-key", "custom-header", 55); - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(1); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerBlock); - } - - @Test public void staticHeaderIsNotCopiedIntoTheIndexedTable() throws IOException { - bytesIn.writeByte(0x82); // == Indexed - Add == - // idx = 2 -> :method: GET - - hpackReader.readHeaders(); - - assertThat(hpackReader.headerCount).isEqualTo(0); - assertThat(hpackReader.dynamicTableByteCount).isEqualTo(0); - - assertThat(hpackReader.dynamicTable[readerHeaderTableLength() - 1]).isNull(); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo( - headerEntries(":method", "GET")); - } - - // Example taken from twitter/hpack DecoderTest.testUnusedIndex - @Test public void readIndexedHeaderFieldIndex0() throws IOException { - bytesIn.writeByte(0x80); // == Indexed - Add idx = 0 - - try { - hpackReader.readHeaders(); - fail(""); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("index == 0"); - } - } - - // Example taken from twitter/hpack DecoderTest.testIllegalIndex - @Test public void readIndexedHeaderFieldTooLargeIndex() throws IOException { - bytesIn.writeShort(0xff00); // == Indexed - Add idx = 127 - - try { - hpackReader.readHeaders(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("Header index too large 127"); - } - } - - // Example taken from twitter/hpack DecoderTest.testInsidiousIndex - @Test public void readIndexedHeaderFieldInsidiousIndex() throws IOException { - bytesIn.writeByte(0xff); // == Indexed - Add == - bytesIn.write(decodeHex("8080808008")); // idx = -2147483521 - - try { - hpackReader.readHeaders(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("Header index too large -2147483521"); - } - } - - // Example taken from twitter/hpack DecoderTest.testHeaderTableSizeUpdate - @Test public void minMaxHeaderTableSize() throws IOException { - bytesIn.writeByte(0x20); - hpackReader.readHeaders(); - - assertThat(hpackReader.maxDynamicTableByteCount()).isEqualTo(0); - - bytesIn.writeByte(0x3f); // encode size 4096 - bytesIn.writeByte(0xe1); - bytesIn.writeByte(0x1f); - hpackReader.readHeaders(); - - assertThat(hpackReader.maxDynamicTableByteCount()).isEqualTo(4096); - } - - // Example taken from twitter/hpack DecoderTest.testIllegalHeaderTableSizeUpdate - @Test public void cannotSetTableSizeLargerThanSettingsValue() throws IOException { - bytesIn.writeByte(0x3f); // encode size 4097 - bytesIn.writeByte(0xe2); - bytesIn.writeByte(0x1f); - - try { - hpackReader.readHeaders(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("Invalid dynamic table size update 4097"); - } - } - - // Example taken from twitter/hpack DecoderTest.testInsidiousMaxHeaderSize - @Test public void readHeaderTableStateChangeInsidiousMaxHeaderByteCount() throws IOException { - bytesIn.writeByte(0x3f); - bytesIn.write(decodeHex("e1ffffff07")); // count = -2147483648 - - try { - hpackReader.readHeaders(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("Invalid dynamic table size update -2147483648"); - } - } - - /** - * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.2.4 - */ - @Test public void readIndexedHeaderFieldFromStaticTableWithoutBuffering() throws IOException { - bytesIn.writeByte(0x20); // Dynamic table size update (size = 0). - bytesIn.writeByte(0x82); // == Indexed - Add == - // idx = 2 -> :method: GET - - hpackReader.readHeaders(); - - // Not buffered in header table. - assertThat(hpackReader.headerCount).isEqualTo(0); - - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo( - headerEntries(":method", "GET")); - } - - @Test public void readLiteralHeaderWithIncrementalIndexingStaticName() throws IOException { - bytesIn.writeByte(0x7d); // == Literal indexed == - // Indexed name (idx = 60) -> "www-authenticate" - bytesIn.writeByte(0x05); // Literal value (len = 5) - bytesIn.writeUtf8("Basic"); - - hpackReader.readHeaders(); - - assertThat(hpackReader.getAndResetHeaderList()) - .containsExactly(new Header("www-authenticate", "Basic")); - } - - @Test public void readLiteralHeaderWithIncrementalIndexingDynamicName() throws IOException { - bytesIn.writeByte(0x40); - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-foo"); - bytesIn.writeByte(0x05); // Literal value (len = 5) - bytesIn.writeUtf8("Basic"); - - bytesIn.writeByte(0x7e); - bytesIn.writeByte(0x06); // Literal value (len = 6) - bytesIn.writeUtf8("Basic2"); - - hpackReader.readHeaders(); - - assertThat(hpackReader.getAndResetHeaderList()).containsExactly( - new Header("custom-foo", "Basic"), new Header("custom-foo", "Basic2")); - } - - /** - * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.2 - */ - @Test public void readRequestExamplesWithoutHuffman() throws IOException { - firstRequestWithoutHuffman(); - hpackReader.readHeaders(); - checkReadFirstRequestWithoutHuffman(); - - secondRequestWithoutHuffman(); - hpackReader.readHeaders(); - checkReadSecondRequestWithoutHuffman(); - - thirdRequestWithoutHuffman(); - hpackReader.readHeaders(); - checkReadThirdRequestWithoutHuffman(); - } - - @Test public void readFailingRequestExample() throws IOException { - bytesIn.writeByte(0x82); // == Indexed - Add == - // idx = 2 -> :method: GET - bytesIn.writeByte(0x86); // == Indexed - Add == - // idx = 7 -> :scheme: http - bytesIn.writeByte(0x84); // == Indexed - Add == - - bytesIn.writeByte(0x7f); // == Bad index! == - - // Indexed name (idx = 4) -> :authority - bytesIn.writeByte(0x0f); // Literal value (len = 15) - bytesIn.writeUtf8("www.example.com"); - - try { - hpackReader.readHeaders(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo("Header index too large 78"); - } - } - - private void firstRequestWithoutHuffman() { - bytesIn.writeByte(0x82); // == Indexed - Add == - // idx = 2 -> :method: GET - bytesIn.writeByte(0x86); // == Indexed - Add == - // idx = 7 -> :scheme: http - bytesIn.writeByte(0x84); // == Indexed - Add == - // idx = 6 -> :path: / - bytesIn.writeByte(0x41); // == Literal indexed == - // Indexed name (idx = 4) -> :authority - bytesIn.writeByte(0x0f); // Literal value (len = 15) - bytesIn.writeUtf8("www.example.com"); - } - - private void checkReadFirstRequestWithoutHuffman() { - assertThat(hpackReader.headerCount).isEqualTo(1); - - // [ 1] (s = 57) :authority: www.example.com - Header entry = hpackReader.dynamicTable[readerHeaderTableLength() - 1]; - checkEntry(entry, ":authority", "www.example.com", 57); - - // Table size: 57 - assertThat(hpackReader.dynamicTableByteCount).isEqualTo(57); - - // Decoded header list: - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerEntries( - ":method", "GET", - ":scheme", "http", - ":path", "/", - ":authority", "www.example.com")); - } - - private void secondRequestWithoutHuffman() { - bytesIn.writeByte(0x82); // == Indexed - Add == - // idx = 2 -> :method: GET - bytesIn.writeByte(0x86); // == Indexed - Add == - // idx = 7 -> :scheme: http - bytesIn.writeByte(0x84); // == Indexed - Add == - // idx = 6 -> :path: / - bytesIn.writeByte(0xbe); // == Indexed - Add == - // Indexed name (idx = 62) -> :authority: www.example.com - bytesIn.writeByte(0x58); // == Literal indexed == - // Indexed name (idx = 24) -> cache-control - bytesIn.writeByte(0x08); // Literal value (len = 8) - bytesIn.writeUtf8("no-cache"); - } - - private void checkReadSecondRequestWithoutHuffman() { - assertThat(hpackReader.headerCount).isEqualTo(2); - - // [ 1] (s = 53) cache-control: no-cache - Header entry = hpackReader.dynamicTable[readerHeaderTableLength() - 2]; - checkEntry(entry, "cache-control", "no-cache", 53); - - // [ 2] (s = 57) :authority: www.example.com - entry = hpackReader.dynamicTable[readerHeaderTableLength() - 1]; - checkEntry(entry, ":authority", "www.example.com", 57); - - // Table size: 110 - assertThat(hpackReader.dynamicTableByteCount).isEqualTo(110); - - // Decoded header list: - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerEntries( - ":method", "GET", - ":scheme", "http", - ":path", "/", - ":authority", "www.example.com", - "cache-control", "no-cache")); - } - - private void thirdRequestWithoutHuffman() { - bytesIn.writeByte(0x82); // == Indexed - Add == - // idx = 2 -> :method: GET - bytesIn.writeByte(0x87); // == Indexed - Add == - // idx = 7 -> :scheme: http - bytesIn.writeByte(0x85); // == Indexed - Add == - // idx = 5 -> :path: /index.html - bytesIn.writeByte(0xbf); // == Indexed - Add == - // Indexed name (idx = 63) -> :authority: www.example.com - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x0a); // Literal name (len = 10) - bytesIn.writeUtf8("custom-key"); - bytesIn.writeByte(0x0c); // Literal value (len = 12) - bytesIn.writeUtf8("custom-value"); - } - - private void checkReadThirdRequestWithoutHuffman() { - assertThat(hpackReader.headerCount).isEqualTo(3); - - // [ 1] (s = 54) custom-key: custom-value - Header entry = hpackReader.dynamicTable[readerHeaderTableLength() - 3]; - checkEntry(entry, "custom-key", "custom-value", 54); - - // [ 2] (s = 53) cache-control: no-cache - entry = hpackReader.dynamicTable[readerHeaderTableLength() - 2]; - checkEntry(entry, "cache-control", "no-cache", 53); - - // [ 3] (s = 57) :authority: www.example.com - entry = hpackReader.dynamicTable[readerHeaderTableLength() - 1]; - checkEntry(entry, ":authority", "www.example.com", 57); - - // Table size: 164 - assertThat(hpackReader.dynamicTableByteCount).isEqualTo(164); - - // Decoded header list: - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerEntries( - ":method", "GET", - ":scheme", "https", - ":path", "/index.html", - ":authority", "www.example.com", - "custom-key", "custom-value")); - } - - /** - * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.4 - */ - @Test public void readRequestExamplesWithHuffman() throws IOException { - firstRequestWithHuffman(); - hpackReader.readHeaders(); - checkReadFirstRequestWithHuffman(); - - secondRequestWithHuffman(); - hpackReader.readHeaders(); - checkReadSecondRequestWithHuffman(); - - thirdRequestWithHuffman(); - hpackReader.readHeaders(); - checkReadThirdRequestWithHuffman(); - } - - private void firstRequestWithHuffman() { - bytesIn.writeByte(0x82); // == Indexed - Add == - // idx = 2 -> :method: GET - bytesIn.writeByte(0x86); // == Indexed - Add == - // idx = 6 -> :scheme: http - bytesIn.writeByte(0x84); // == Indexed - Add == - // idx = 4 -> :path: / - bytesIn.writeByte(0x41); // == Literal indexed == - // Indexed name (idx = 1) -> :authority - bytesIn.writeByte(0x8c); // Literal value Huffman encoded 12 bytes - // decodes to www.example.com which is length 15 - bytesIn.write(decodeHex("f1e3c2e5f23a6ba0ab90f4ff")); - } - - private void checkReadFirstRequestWithHuffman() { - assertThat(hpackReader.headerCount).isEqualTo(1); - - // [ 1] (s = 57) :authority: www.example.com - Header entry = hpackReader.dynamicTable[readerHeaderTableLength() - 1]; - checkEntry(entry, ":authority", "www.example.com", 57); - - // Table size: 57 - assertThat(hpackReader.dynamicTableByteCount).isEqualTo(57); - - // Decoded header list: - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerEntries( - ":method", "GET", - ":scheme", "http", - ":path", "/", - ":authority", "www.example.com")); - } - - private void secondRequestWithHuffman() { - bytesIn.writeByte(0x82); // == Indexed - Add == - // idx = 2 -> :method: GET - bytesIn.writeByte(0x86); // == Indexed - Add == - // idx = 6 -> :scheme: http - bytesIn.writeByte(0x84); // == Indexed - Add == - // idx = 4 -> :path: / - bytesIn.writeByte(0xbe); // == Indexed - Add == - // idx = 62 -> :authority: www.example.com - bytesIn.writeByte(0x58); // == Literal indexed == - // Indexed name (idx = 24) -> cache-control - bytesIn.writeByte(0x86); // Literal value Huffman encoded 6 bytes - // decodes to no-cache which is length 8 - bytesIn.write(decodeHex("a8eb10649cbf")); - } - - private void checkReadSecondRequestWithHuffman() { - assertThat(hpackReader.headerCount).isEqualTo(2); - - // [ 1] (s = 53) cache-control: no-cache - Header entry = hpackReader.dynamicTable[readerHeaderTableLength() - 2]; - checkEntry(entry, "cache-control", "no-cache", 53); - - // [ 2] (s = 57) :authority: www.example.com - entry = hpackReader.dynamicTable[readerHeaderTableLength() - 1]; - checkEntry(entry, ":authority", "www.example.com", 57); - - // Table size: 110 - assertThat(hpackReader.dynamicTableByteCount).isEqualTo(110); - - // Decoded header list: - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerEntries( - ":method", "GET", - ":scheme", "http", - ":path", "/", - ":authority", "www.example.com", - "cache-control", "no-cache")); - } - - private void thirdRequestWithHuffman() { - bytesIn.writeByte(0x82); // == Indexed - Add == - // idx = 2 -> :method: GET - bytesIn.writeByte(0x87); // == Indexed - Add == - // idx = 7 -> :scheme: https - bytesIn.writeByte(0x85); // == Indexed - Add == - // idx = 5 -> :path: /index.html - bytesIn.writeByte(0xbf); // == Indexed - Add == - // idx = 63 -> :authority: www.example.com - bytesIn.writeByte(0x40); // Literal indexed - bytesIn.writeByte(0x88); // Literal name Huffman encoded 8 bytes - // decodes to custom-key which is length 10 - bytesIn.write(decodeHex("25a849e95ba97d7f")); - bytesIn.writeByte(0x89); // Literal value Huffman encoded 9 bytes - // decodes to custom-value which is length 12 - bytesIn.write(decodeHex("25a849e95bb8e8b4bf")); - } - - private void checkReadThirdRequestWithHuffman() { - assertThat(hpackReader.headerCount).isEqualTo(3); - - // [ 1] (s = 54) custom-key: custom-value - Header entry = hpackReader.dynamicTable[readerHeaderTableLength() - 3]; - checkEntry(entry, "custom-key", "custom-value", 54); - - // [ 2] (s = 53) cache-control: no-cache - entry = hpackReader.dynamicTable[readerHeaderTableLength() - 2]; - checkEntry(entry, "cache-control", "no-cache", 53); - - // [ 3] (s = 57) :authority: www.example.com - entry = hpackReader.dynamicTable[readerHeaderTableLength() - 1]; - checkEntry(entry, ":authority", "www.example.com", 57); - - // Table size: 164 - assertThat(hpackReader.dynamicTableByteCount).isEqualTo(164); - - // Decoded header list: - assertThat(hpackReader.getAndResetHeaderList()).isEqualTo(headerEntries( - ":method", "GET", - ":scheme", "https", - ":path", "/index.html", - ":authority", "www.example.com", - "custom-key", "custom-value")); - } - - @Test public void readSingleByteInt() throws IOException { - assertThat(newReader(byteStream()).readInt(10, 31)).isEqualTo(10); - assertThat(newReader(byteStream()).readInt(0xe0 | 10, 31)).isEqualTo(10); - } - - @Test public void readMultibyteInt() throws IOException { - assertThat(newReader(byteStream(154, 10)).readInt(31, 31)).isEqualTo(1337); - } - - @Test public void writeSingleByteInt() throws IOException { - hpackWriter.writeInt(10, 31, 0); - assertBytes(10); - hpackWriter.writeInt(10, 31, 0xe0); - assertBytes(0xe0 | 10); - } - - @Test public void writeMultibyteInt() throws IOException { - hpackWriter.writeInt(1337, 31, 0); - assertBytes(31, 154, 10); - hpackWriter.writeInt(1337, 31, 0xe0); - assertBytes(0xe0 | 31, 154, 10); - } - - @Test public void max31BitValue() throws IOException { - hpackWriter.writeInt(0x7fffffff, 31, 0); - assertBytes(31, 224, 255, 255, 255, 7); - assertThat(newReader(byteStream(224, 255, 255, 255, 7)).readInt(31, 31)).isEqualTo( - (long) 0x7fffffff); - } - - @Test public void prefixMask() throws IOException { - hpackWriter.writeInt(31, 31, 0); - assertBytes(31, 0); - assertThat(newReader(byteStream(0)).readInt(31, 31)).isEqualTo(31); - } - - @Test public void prefixMaskMinusOne() throws IOException { - hpackWriter.writeInt(30, 31, 0); - assertBytes(30); - assertThat(newReader(byteStream(0)).readInt(31, 31)).isEqualTo(31); - } - - @Test public void zero() throws IOException { - hpackWriter.writeInt(0, 31, 0); - assertBytes(0); - assertThat(newReader(byteStream()).readInt(0, 31)).isEqualTo(0); - } - - @Test public void lowercaseHeaderNameBeforeEmit() throws IOException { - hpackWriter.writeHeaders(asList(new Header("FoO", "BaR"))); - assertBytes(0x40, 3, 'f', 'o', 'o', 3, 'B', 'a', 'R'); - } - - @Test public void mixedCaseHeaderNameIsMalformed() throws IOException { - try { - newReader(byteStream(0, 3, 'F', 'o', 'o', 3, 'B', 'a', 'R')).readHeaders(); - fail(); - } catch (IOException e) { - assertThat(e.getMessage()).isEqualTo( - "PROTOCOL_ERROR response malformed: mixed case name: Foo"); - } - } - - @Test public void emptyHeaderName() throws IOException { - hpackWriter.writeByteString(ByteString.encodeUtf8("")); - assertBytes(0); - assertThat(newReader(byteStream(0)).readByteString()).isEqualTo(ByteString.EMPTY); - } - - @Test public void emitsDynamicTableSizeUpdate() throws IOException { - hpackWriter.resizeHeaderTable(2048); - hpackWriter.writeHeaders(asList(new Header("foo", "bar"))); - assertBytes( - 0x3F, 0xE1, 0xF, // Dynamic table size update (size = 2048). - 0x40, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'); - - hpackWriter.resizeHeaderTable(8192); - hpackWriter.writeHeaders(asList(new Header("bar", "foo"))); - assertBytes( - 0x3F, 0xE1, 0x3F, // Dynamic table size update (size = 8192). - 0x40, 3, 'b', 'a', 'r', 3, 'f', 'o', 'o'); - - // No more dynamic table updates should be emitted. - hpackWriter.writeHeaders(asList(new Header("far", "boo"))); - assertBytes(0x40, 3, 'f', 'a', 'r', 3, 'b', 'o', 'o'); - } - - @Test public void noDynamicTableSizeUpdateWhenSizeIsEqual() throws IOException { - int currentSize = hpackWriter.headerTableSizeSetting; - hpackWriter.resizeHeaderTable(currentSize); - hpackWriter.writeHeaders(asList(new Header("foo", "bar"))); - - assertBytes(0x40, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'); - } - - @Test public void growDynamicTableSize() throws IOException { - hpackWriter.resizeHeaderTable(8192); - hpackWriter.resizeHeaderTable(16384); - hpackWriter.writeHeaders(asList(new Header("foo", "bar"))); - - assertBytes( - 0x3F, 0xE1, 0x7F, // Dynamic table size update (size = 16384). - 0x40, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'); - } - - @Test public void shrinkDynamicTableSize() throws IOException { - hpackWriter.resizeHeaderTable(2048); - hpackWriter.resizeHeaderTable(0); - hpackWriter.writeHeaders(asList(new Header("foo", "bar"))); - - assertBytes( - 0x20, // Dynamic size update (size = 0). - 0x40, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'); - } - - @Test public void manyDynamicTableSizeChanges() throws IOException { - hpackWriter.resizeHeaderTable(16384); - hpackWriter.resizeHeaderTable(8096); - hpackWriter.resizeHeaderTable(0); - hpackWriter.resizeHeaderTable(4096); - hpackWriter.resizeHeaderTable(2048); - hpackWriter.writeHeaders(asList(new Header("foo", "bar"))); - - assertBytes( - 0x20, // Dynamic size update (size = 0). - 0x3F, 0xE1, 0xF, // Dynamic size update (size = 2048). - 0x40, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'); - } - - @Test public void dynamicTableEvictionWhenSizeLowered() throws IOException { - List
headerBlock = - headerEntries( - "custom-key1", "custom-header", - "custom-key2", "custom-header"); - hpackWriter.writeHeaders(headerBlock); - assertThat(hpackWriter.headerCount).isEqualTo(2); - - hpackWriter.resizeHeaderTable(56); - assertThat(hpackWriter.headerCount).isEqualTo(1); - - hpackWriter.resizeHeaderTable(0); - assertThat(hpackWriter.headerCount).isEqualTo(0); - } - - @Test public void noEvictionOnDynamicTableSizeIncrease() throws IOException { - List
headerBlock = - headerEntries( - "custom-key1", "custom-header", - "custom-key2", "custom-header"); - hpackWriter.writeHeaders(headerBlock); - assertThat(hpackWriter.headerCount).isEqualTo(2); - - hpackWriter.resizeHeaderTable(8192); - assertThat(hpackWriter.headerCount).isEqualTo(2); - } - - @Test public void dynamicTableSizeHasAnUpperBound() { - hpackWriter.resizeHeaderTable(1048576); - assertThat(hpackWriter.maxDynamicTableByteCount).isEqualTo(16384); - } - - @Test public void huffmanEncode() throws IOException { - hpackWriter = new Hpack.Writer(4096, true, bytesOut); - hpackWriter.writeHeaders(headerEntries("foo", "bar")); - - ByteString expected = new Buffer() - .writeByte(0x40) // Literal header, new name. - .writeByte(0x82) // String literal is Huffman encoded (len = 2). - .writeByte(0x94) // 'foo' Huffman encoded. - .writeByte(0xE7) - .writeByte(3) // String literal not Huffman encoded (len = 3). - .writeByte('b') - .writeByte('a') - .writeByte('r') - .readByteString(); - - ByteString actual = bytesOut.readByteString(); - assertThat(actual).isEqualTo(expected); - } - - @Test public void staticTableIndexedHeaders() throws IOException { - hpackWriter.writeHeaders(headerEntries(":method", "GET")); - assertBytes(0x82); - assertThat(hpackWriter.headerCount).isEqualTo(0); - - hpackWriter.writeHeaders(headerEntries(":method", "POST")); - assertBytes(0x83); - assertThat(hpackWriter.headerCount).isEqualTo(0); - - hpackWriter.writeHeaders(headerEntries(":path", "/")); - assertBytes(0x84); - assertThat(hpackWriter.headerCount).isEqualTo(0); - - hpackWriter.writeHeaders(headerEntries(":path", "/index.html")); - assertBytes(0x85); - assertThat(hpackWriter.headerCount).isEqualTo(0); - - hpackWriter.writeHeaders(headerEntries(":scheme", "http")); - assertBytes(0x86); - assertThat(hpackWriter.headerCount).isEqualTo(0); - - hpackWriter.writeHeaders(headerEntries(":scheme", "https")); - assertBytes(0x87); - assertThat(hpackWriter.headerCount).isEqualTo(0); - } - - @Test public void dynamicTableIndexedHeader() throws IOException { - hpackWriter.writeHeaders(headerEntries("custom-key", "custom-header")); - assertBytes(0x40, - 10, 'c', 'u', 's', 't', 'o', 'm', '-', 'k', 'e', 'y', - 13, 'c', 'u', 's', 't', 'o', 'm', '-', 'h', 'e', 'a', 'd', 'e', 'r'); - assertThat(hpackWriter.headerCount).isEqualTo(1); - - hpackWriter.writeHeaders(headerEntries("custom-key", "custom-header")); - assertBytes(0xbe); - assertThat(hpackWriter.headerCount).isEqualTo(1); - } - - @Test public void doNotIndexPseudoHeaders() throws IOException { - hpackWriter.writeHeaders(headerEntries(":method", "PUT")); - assertBytes(0x02, 3, 'P', 'U', 'T'); - assertThat(hpackWriter.headerCount).isEqualTo(0); - - hpackWriter.writeHeaders(headerEntries(":path", "/okhttp")); - assertBytes(0x04, 7, '/', 'o', 'k', 'h', 't', 't', 'p'); - assertThat(hpackWriter.headerCount).isEqualTo(0); - } - - @Test public void incrementalIndexingWithAuthorityPseudoHeader() throws IOException { - hpackWriter.writeHeaders(headerEntries(":authority", "foo.com")); - assertBytes(0x41, 7, 'f', 'o', 'o', '.', 'c', 'o', 'm'); - assertThat(hpackWriter.headerCount).isEqualTo(1); - - hpackWriter.writeHeaders(headerEntries(":authority", "foo.com")); - assertBytes(0xbe); - assertThat(hpackWriter.headerCount).isEqualTo(1); - - // If the :authority header somehow changes, it should be re-added to the dynamic table. - hpackWriter.writeHeaders(headerEntries(":authority", "bar.com")); - assertBytes(0x41, 7, 'b', 'a', 'r', '.', 'c', 'o', 'm'); - assertThat(hpackWriter.headerCount).isEqualTo(2); - - hpackWriter.writeHeaders(headerEntries(":authority", "bar.com")); - assertBytes(0xbe); - assertThat(hpackWriter.headerCount).isEqualTo(2); - } - - @Test public void incrementalIndexingWithStaticTableIndexedName() throws IOException { - hpackWriter.writeHeaders(headerEntries("accept-encoding", "gzip")); - assertBytes(0x50, 4, 'g', 'z', 'i', 'p'); - assertThat(hpackWriter.headerCount).isEqualTo(1); - - hpackWriter.writeHeaders(headerEntries("accept-encoding", "gzip")); - assertBytes(0xbe); - assertThat(hpackWriter.headerCount).isEqualTo(1); - } - - @Test public void incrementalIndexingWithDynamcTableIndexedName() throws IOException { - hpackWriter.writeHeaders(headerEntries("foo", "bar")); - assertBytes(0x40, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'); - assertThat(hpackWriter.headerCount).isEqualTo(1); - - hpackWriter.writeHeaders(headerEntries("foo", "bar1")); - assertBytes(0x7e, 4, 'b', 'a', 'r', '1'); - assertThat(hpackWriter.headerCount).isEqualTo(2); - - hpackWriter.writeHeaders(headerEntries("foo", "bar1")); - assertBytes(0xbe); - assertThat(hpackWriter.headerCount).isEqualTo(2); - } - - private Hpack.Reader newReader(Buffer source) { - return new Hpack.Reader(source, 4096); - } - - private Buffer byteStream(int... bytes) { - return new Buffer().write(intArrayToByteArray(bytes)); - } - - private void checkEntry(Header entry, String name, String value, int size) { - assertThat(entry.name.utf8()).isEqualTo(name); - assertThat(entry.value.utf8()).isEqualTo(value); - assertThat(entry.hpackSize).isEqualTo(size); - } - - private void assertBytes(int... bytes) throws IOException { - ByteString expected = intArrayToByteArray(bytes); - ByteString actual = bytesOut.readByteString(); - assertThat(actual).isEqualTo(expected); - } - - private ByteString intArrayToByteArray(int[] bytes) { - byte[] data = new byte[bytes.length]; - for (int i = 0; i < bytes.length; i++) { - data[i] = (byte) bytes[i]; - } - return ByteString.of(data); - } - - private int readerHeaderTableLength() { - return hpackReader.dynamicTable.length; - } -} diff --git a/okhttp/src/test/java/okhttp3/internal/http2/HpackTest.kt b/okhttp/src/test/java/okhttp3/internal/http2/HpackTest.kt new file mode 100644 index 000000000000..00ff1ed13a69 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/internal/http2/HpackTest.kt @@ -0,0 +1,1103 @@ +/* + * Copyright (C) 2013 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.http2 + +import java.io.IOException +import java.util.Arrays +import okhttp3.TestUtil.headerEntries +import okio.Buffer +import okio.ByteString +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.encodeUtf8 +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test + +class HpackTest { + private val bytesIn = Buffer() + private var hpackReader: Hpack.Reader? = null + private val bytesOut = Buffer() + private var hpackWriter: Hpack.Writer? = null + + @BeforeEach + fun reset() { + hpackReader = newReader(bytesIn) + hpackWriter = Hpack.Writer(4096, false, bytesOut) + } + + /** + * Variable-length quantity special cases strings which are longer than 127 bytes. Values such as + * cookies can be 4KiB, and should be possible to send. + * + * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#section-5.2 + */ + @Test + fun largeHeaderValue() { + val value = CharArray(4096) + Arrays.fill(value, '!') + val headerBlock = headerEntries("cookie", String(value)) + hpackWriter!!.writeHeaders(headerBlock) + bytesIn.writeAll(bytesOut) + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(0) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo(headerBlock) + } + + /** + * HPACK has a max header table size, which can be smaller than the max header message. Ensure the + * larger header content is not lost. + */ + @Test + fun tooLargeToHPackIsStillEmitted() { + bytesIn.writeByte(0x21) // Dynamic table size update (size = 1). + bytesIn.writeByte(0x00) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-key") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(0) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries("custom-key", "custom-header") + ) + } + + /** Oldest entries are evicted to support newer ones. */ + @Test + fun writerEviction() { + val headerBlock = headerEntries( + "custom-foo", "custom-header", + "custom-bar", "custom-header", + "custom-baz", "custom-header" + ) + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-foo") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-bar") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-baz") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + + // Set to only support 110 bytes (enough for 2 headers). + // Use a new Writer because we don't support change the dynamic table + // size after Writer constructed. + val writer = Hpack.Writer(110, false, bytesOut) + writer.writeHeaders(headerBlock) + assertThat(bytesOut).isEqualTo(bytesIn) + assertThat(writer.headerCount).isEqualTo(2) + val tableLength = writer.dynamicTable.size + var entry = writer.dynamicTable[tableLength - 1]!! + checkEntry(entry, "custom-bar", "custom-header", 55) + entry = writer.dynamicTable[tableLength - 2]!! + checkEntry(entry, "custom-baz", "custom-header", 55) + } + + @Test + fun readerEviction() { + val headerBlock = headerEntries( + "custom-foo", "custom-header", + "custom-bar", "custom-header", + "custom-baz", "custom-header" + ) + + // Set to only support 110 bytes (enough for 2 headers). + bytesIn.writeByte(0x3F) // Dynamic table size update (size = 110). + bytesIn.writeByte(0x4F) + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-foo") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-bar") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-baz") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(2) + val entry1 = hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]!! + checkEntry(entry1, "custom-bar", "custom-header", 55) + val entry2 = hpackReader!!.dynamicTable[readerHeaderTableLength() - 2]!! + checkEntry(entry2, "custom-baz", "custom-header", 55) + + // Once a header field is decoded and added to the reconstructed header + // list, it cannot be removed from it. Hence, foo is here. + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo(headerBlock) + + // Simulate receiving a small dynamic table size update, that implies eviction. + bytesIn.writeByte(0x3F) // Dynamic table size update (size = 55). + bytesIn.writeByte(0x18) + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(1) + } + + /** Header table backing array is initially 8 long, let's ensure it grows. */ + @Test + fun dynamicallyGrowsBeyond64Entries() { + // Lots of headers need more room! + hpackReader = Hpack.Reader(bytesIn, 16384, 4096) + bytesIn.writeByte(0x3F) // Dynamic table size update (size = 16384). + bytesIn.writeByte(0xE1) + bytesIn.writeByte(0x7F) + for (i in 0..255) { + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-foo") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + } + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(256) + } + + @Test + fun huffmanDecodingSupported() { + bytesIn.writeByte(0x44) // == Literal indexed == + // Indexed name (idx = 4) -> :path + bytesIn.writeByte(0x8c) // Literal value Huffman encoded 12 bytes + // decodes to www.example.com which is length 15 + bytesIn.write("f1e3c2e5f23a6ba0ab90f4ff".decodeHex()) + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(1) + assertThat(hpackReader!!.dynamicTableByteCount).isEqualTo(52) + val entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]!! + checkEntry(entry, ":path", "www.example.com", 52) + } + + /** + * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.2.1 + */ + @Test + fun readLiteralHeaderFieldWithIndexing() { + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-key") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(1) + assertThat(hpackReader!!.dynamicTableByteCount).isEqualTo(55) + val entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]!! + checkEntry(entry, "custom-key", "custom-header", 55) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries("custom-key", "custom-header") + ) + } + + /** + * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.2.2 + */ + @Test + fun literalHeaderFieldWithoutIndexingIndexedName() { + val headerBlock = headerEntries(":path", "/sample/path") + bytesIn.writeByte(0x04) // == Literal not indexed == + // Indexed name (idx = 4) -> :path + bytesIn.writeByte(0x0c) // Literal value (len = 12) + bytesIn.writeUtf8("/sample/path") + hpackWriter!!.writeHeaders(headerBlock) + assertThat(bytesOut).isEqualTo(bytesIn) + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(0) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo(headerBlock) + } + + @Test + fun literalHeaderFieldWithoutIndexingNewName() { + val headerBlock = headerEntries("custom-key", "custom-header") + bytesIn.writeByte(0x00) // Not indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-key") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(0) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo(headerBlock) + } + + @Test + fun literalHeaderFieldNeverIndexedIndexedName() { + bytesIn.writeByte(0x14) // == Literal never indexed == + // Indexed name (idx = 4) -> :path + bytesIn.writeByte(0x0c) // Literal value (len = 12) + bytesIn.writeUtf8("/sample/path") + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(0) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries(":path", "/sample/path") + ) + } + + @Test + fun literalHeaderFieldNeverIndexedNewName() { + val headerBlock = headerEntries("custom-key", "custom-header") + bytesIn.writeByte(0x10) // Never indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-key") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(0) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo(headerBlock) + } + + @Test + fun literalHeaderFieldWithIncrementalIndexingIndexedName() { + val headerBlock = headerEntries(":path", "/sample/path") + bytesIn.writeByte(0x44) // Indexed name (idx = 4) -> :path + bytesIn.writeByte(0x0c) // Literal value (len = 12) + bytesIn.writeUtf8("/sample/path") + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(1) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo(headerBlock) + } + + @Test + fun literalHeaderFieldWithIncrementalIndexingNewName() { + val headerBlock = headerEntries("custom-key", "custom-header") + bytesIn.writeByte(0x40) // Never indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-key") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + hpackWriter!!.writeHeaders(headerBlock) + assertThat(bytesOut).isEqualTo(bytesIn) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + val entry = hpackWriter!!.dynamicTable[hpackWriter!!.dynamicTable.size - 1]!! + checkEntry(entry, "custom-key", "custom-header", 55) + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(1) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo(headerBlock) + } + + @Test + fun theSameHeaderAfterOneIncrementalIndexed() { + val headerBlock = headerEntries( + "custom-key", "custom-header", + "custom-key", "custom-header" + ) + bytesIn.writeByte(0x40) // Never indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-key") + bytesIn.writeByte(0x0d) // Literal value (len = 13) + bytesIn.writeUtf8("custom-header") + bytesIn.writeByte(0xbe) // Indexed name and value (idx = 63) + hpackWriter!!.writeHeaders(headerBlock) + assertThat(bytesOut).isEqualTo(bytesIn) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + val entry = hpackWriter!!.dynamicTable[hpackWriter!!.dynamicTable.size - 1]!! + checkEntry(entry, "custom-key", "custom-header", 55) + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(1) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo(headerBlock) + } + + @Test + fun staticHeaderIsNotCopiedIntoTheIndexedTable() { + bytesIn.writeByte(0x82) // == Indexed - Add == + // idx = 2 -> :method: GET + hpackReader!!.readHeaders() + assertThat(hpackReader!!.headerCount).isEqualTo(0) + assertThat(hpackReader!!.dynamicTableByteCount).isEqualTo(0) + assertThat(hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]).isNull() + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries(":method", "GET") + ) + } + + // Example taken from twitter/hpack DecoderTest.testUnusedIndex + @Test + fun readIndexedHeaderFieldIndex0() { + bytesIn.writeByte(0x80) // == Indexed - Add idx = 0 + try { + hpackReader!!.readHeaders() + org.junit.jupiter.api.Assertions.fail("") + } catch (e: IOException) { + assertThat(e.message).isEqualTo("index == 0") + } + } + + // Example taken from twitter/hpack DecoderTest.testIllegalIndex + @Test + fun readIndexedHeaderFieldTooLargeIndex() { + bytesIn.writeShort(0xff00) // == Indexed - Add idx = 127 + try { + hpackReader!!.readHeaders() + org.junit.jupiter.api.Assertions.fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("Header index too large 127") + } + } + + // Example taken from twitter/hpack DecoderTest.testInsidiousIndex + @Test + fun readIndexedHeaderFieldInsidiousIndex() { + bytesIn.writeByte(0xff) // == Indexed - Add == + bytesIn.write("8080808008".decodeHex()) // idx = -2147483521 + try { + hpackReader!!.readHeaders() + org.junit.jupiter.api.Assertions.fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("Header index too large -2147483521") + } + } + + // Example taken from twitter/hpack DecoderTest.testHeaderTableSizeUpdate + @Test + fun minMaxHeaderTableSize() { + bytesIn.writeByte(0x20) + hpackReader!!.readHeaders() + assertThat(hpackReader!!.maxDynamicTableByteCount()).isEqualTo(0) + bytesIn.writeByte(0x3f) // encode size 4096 + bytesIn.writeByte(0xe1) + bytesIn.writeByte(0x1f) + hpackReader!!.readHeaders() + assertThat(hpackReader!!.maxDynamicTableByteCount()).isEqualTo(4096) + } + + // Example taken from twitter/hpack DecoderTest.testIllegalHeaderTableSizeUpdate + @Test + fun cannotSetTableSizeLargerThanSettingsValue() { + bytesIn.writeByte(0x3f) // encode size 4097 + bytesIn.writeByte(0xe2) + bytesIn.writeByte(0x1f) + try { + hpackReader!!.readHeaders() + org.junit.jupiter.api.Assertions.fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("Invalid dynamic table size update 4097") + } + } + + // Example taken from twitter/hpack DecoderTest.testInsidiousMaxHeaderSize + @Test + fun readHeaderTableStateChangeInsidiousMaxHeaderByteCount() { + bytesIn.writeByte(0x3f) + bytesIn.write("e1ffffff07".decodeHex()) // count = -2147483648 + try { + hpackReader!!.readHeaders() + org.junit.jupiter.api.Assertions.fail() + } catch (e: IOException) { + assertThat(e.message) + .isEqualTo("Invalid dynamic table size update -2147483648") + } + } + + /** + * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.2.4 + */ + @Test + fun readIndexedHeaderFieldFromStaticTableWithoutBuffering() { + bytesIn.writeByte(0x20) // Dynamic table size update (size = 0). + bytesIn.writeByte(0x82) // == Indexed - Add == + // idx = 2 -> :method: GET + hpackReader!!.readHeaders() + + // Not buffered in header table. + assertThat(hpackReader!!.headerCount).isEqualTo(0) + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries(":method", "GET") + ) + } + + @Test + fun readLiteralHeaderWithIncrementalIndexingStaticName() { + bytesIn.writeByte(0x7d) // == Literal indexed == + // Indexed name (idx = 60) -> "www-authenticate" + bytesIn.writeByte(0x05) // Literal value (len = 5) + bytesIn.writeUtf8("Basic") + hpackReader!!.readHeaders() + assertThat(hpackReader!!.getAndResetHeaderList()) + .containsExactly(Header("www-authenticate", "Basic")) + } + + @Test + fun readLiteralHeaderWithIncrementalIndexingDynamicName() { + bytesIn.writeByte(0x40) + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-foo") + bytesIn.writeByte(0x05) // Literal value (len = 5) + bytesIn.writeUtf8("Basic") + bytesIn.writeByte(0x7e) + bytesIn.writeByte(0x06) // Literal value (len = 6) + bytesIn.writeUtf8("Basic2") + hpackReader!!.readHeaders() + assertThat(hpackReader!!.getAndResetHeaderList()).containsExactly( + Header("custom-foo", "Basic"), Header("custom-foo", "Basic2") + ) + } + + /** + * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.2 + */ + @Test + fun readRequestExamplesWithoutHuffman() { + firstRequestWithoutHuffman() + hpackReader!!.readHeaders() + checkReadFirstRequestWithoutHuffman() + secondRequestWithoutHuffman() + hpackReader!!.readHeaders() + checkReadSecondRequestWithoutHuffman() + thirdRequestWithoutHuffman() + hpackReader!!.readHeaders() + checkReadThirdRequestWithoutHuffman() + } + + @Test + fun readFailingRequestExample() { + bytesIn.writeByte(0x82) // == Indexed - Add == + // idx = 2 -> :method: GET + bytesIn.writeByte(0x86) // == Indexed - Add == + // idx = 7 -> :scheme: http + bytesIn.writeByte(0x84) // == Indexed - Add == + bytesIn.writeByte(0x7f) // == Bad index! == + + // Indexed name (idx = 4) -> :authority + bytesIn.writeByte(0x0f) // Literal value (len = 15) + bytesIn.writeUtf8("www.example.com") + try { + hpackReader!!.readHeaders() + org.junit.jupiter.api.Assertions.fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo("Header index too large 78") + } + } + + private fun firstRequestWithoutHuffman() { + bytesIn.writeByte(0x82) // == Indexed - Add == + // idx = 2 -> :method: GET + bytesIn.writeByte(0x86) // == Indexed - Add == + // idx = 7 -> :scheme: http + bytesIn.writeByte(0x84) // == Indexed - Add == + // idx = 6 -> :path: / + bytesIn.writeByte(0x41) // == Literal indexed == + // Indexed name (idx = 4) -> :authority + bytesIn.writeByte(0x0f) // Literal value (len = 15) + bytesIn.writeUtf8("www.example.com") + } + + private fun checkReadFirstRequestWithoutHuffman() { + assertThat(hpackReader!!.headerCount).isEqualTo(1) + + // [ 1] (s = 57) :authority: www.example.com + val entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]!! + checkEntry(entry, ":authority", "www.example.com", 57) + + // Table size: 57 + assertThat(hpackReader!!.dynamicTableByteCount).isEqualTo(57) + + // Decoded header list: + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries( + ":method", "GET", + ":scheme", "http", + ":path", "/", + ":authority", "www.example.com" + ) + ) + } + + private fun secondRequestWithoutHuffman() { + bytesIn.writeByte(0x82) // == Indexed - Add == + // idx = 2 -> :method: GET + bytesIn.writeByte(0x86) // == Indexed - Add == + // idx = 7 -> :scheme: http + bytesIn.writeByte(0x84) // == Indexed - Add == + // idx = 6 -> :path: / + bytesIn.writeByte(0xbe) // == Indexed - Add == + // Indexed name (idx = 62) -> :authority: www.example.com + bytesIn.writeByte(0x58) // == Literal indexed == + // Indexed name (idx = 24) -> cache-control + bytesIn.writeByte(0x08) // Literal value (len = 8) + bytesIn.writeUtf8("no-cache") + } + + private fun checkReadSecondRequestWithoutHuffman() { + assertThat(hpackReader!!.headerCount).isEqualTo(2) + + // [ 1] (s = 53) cache-control: no-cache + var entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 2]!! + checkEntry(entry, "cache-control", "no-cache", 53) + + // [ 2] (s = 57) :authority: www.example.com + entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]!! + checkEntry(entry, ":authority", "www.example.com", 57) + + // Table size: 110 + assertThat(hpackReader!!.dynamicTableByteCount).isEqualTo(110) + + // Decoded header list: + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries( + ":method", "GET", + ":scheme", "http", + ":path", "/", + ":authority", "www.example.com", + "cache-control", "no-cache" + ) + ) + } + + private fun thirdRequestWithoutHuffman() { + bytesIn.writeByte(0x82) // == Indexed - Add == + // idx = 2 -> :method: GET + bytesIn.writeByte(0x87) // == Indexed - Add == + // idx = 7 -> :scheme: http + bytesIn.writeByte(0x85) // == Indexed - Add == + // idx = 5 -> :path: /index.html + bytesIn.writeByte(0xbf) // == Indexed - Add == + // Indexed name (idx = 63) -> :authority: www.example.com + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x0a) // Literal name (len = 10) + bytesIn.writeUtf8("custom-key") + bytesIn.writeByte(0x0c) // Literal value (len = 12) + bytesIn.writeUtf8("custom-value") + } + + private fun checkReadThirdRequestWithoutHuffman() { + assertThat(hpackReader!!.headerCount).isEqualTo(3) + + // [ 1] (s = 54) custom-key: custom-value + var entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 3]!! + checkEntry(entry, "custom-key", "custom-value", 54) + + // [ 2] (s = 53) cache-control: no-cache + entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 2]!! + checkEntry(entry, "cache-control", "no-cache", 53) + + // [ 3] (s = 57) :authority: www.example.com + entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]!! + checkEntry(entry, ":authority", "www.example.com", 57) + + // Table size: 164 + assertThat(hpackReader!!.dynamicTableByteCount).isEqualTo(164) + + // Decoded header list: + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries( + ":method", "GET", + ":scheme", "https", + ":path", "/index.html", + ":authority", "www.example.com", + "custom-key", "custom-value" + ) + ) + } + + /** + * http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-12#appendix-C.4 + */ + @Test + fun readRequestExamplesWithHuffman() { + firstRequestWithHuffman() + hpackReader!!.readHeaders() + checkReadFirstRequestWithHuffman() + secondRequestWithHuffman() + hpackReader!!.readHeaders() + checkReadSecondRequestWithHuffman() + thirdRequestWithHuffman() + hpackReader!!.readHeaders() + checkReadThirdRequestWithHuffman() + } + + private fun firstRequestWithHuffman() { + bytesIn.writeByte(0x82) // == Indexed - Add == + // idx = 2 -> :method: GET + bytesIn.writeByte(0x86) // == Indexed - Add == + // idx = 6 -> :scheme: http + bytesIn.writeByte(0x84) // == Indexed - Add == + // idx = 4 -> :path: / + bytesIn.writeByte(0x41) // == Literal indexed == + // Indexed name (idx = 1) -> :authority + bytesIn.writeByte(0x8c) // Literal value Huffman encoded 12 bytes + // decodes to www.example.com which is length 15 + bytesIn.write("f1e3c2e5f23a6ba0ab90f4ff".decodeHex()) + } + + private fun checkReadFirstRequestWithHuffman() { + assertThat(hpackReader!!.headerCount).isEqualTo(1) + + // [ 1] (s = 57) :authority: www.example.com + val entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]!! + checkEntry(entry, ":authority", "www.example.com", 57) + + // Table size: 57 + assertThat(hpackReader!!.dynamicTableByteCount).isEqualTo(57) + + // Decoded header list: + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries( + ":method", "GET", + ":scheme", "http", + ":path", "/", + ":authority", "www.example.com" + ) + ) + } + + private fun secondRequestWithHuffman() { + bytesIn.writeByte(0x82) // == Indexed - Add == + // idx = 2 -> :method: GET + bytesIn.writeByte(0x86) // == Indexed - Add == + // idx = 6 -> :scheme: http + bytesIn.writeByte(0x84) // == Indexed - Add == + // idx = 4 -> :path: / + bytesIn.writeByte(0xbe) // == Indexed - Add == + // idx = 62 -> :authority: www.example.com + bytesIn.writeByte(0x58) // == Literal indexed == + // Indexed name (idx = 24) -> cache-control + bytesIn.writeByte(0x86) // Literal value Huffman encoded 6 bytes + // decodes to no-cache which is length 8 + bytesIn.write("a8eb10649cbf".decodeHex()) + } + + private fun checkReadSecondRequestWithHuffman() { + assertThat(hpackReader!!.headerCount).isEqualTo(2) + + // [ 1] (s = 53) cache-control: no-cache + var entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 2]!! + checkEntry(entry, "cache-control", "no-cache", 53) + + // [ 2] (s = 57) :authority: www.example.com + entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]!! + checkEntry(entry, ":authority", "www.example.com", 57) + + // Table size: 110 + assertThat(hpackReader!!.dynamicTableByteCount).isEqualTo(110) + + // Decoded header list: + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries( + ":method", "GET", + ":scheme", "http", + ":path", "/", + ":authority", "www.example.com", + "cache-control", "no-cache" + ) + ) + } + + private fun thirdRequestWithHuffman() { + bytesIn.writeByte(0x82) // == Indexed - Add == + // idx = 2 -> :method: GET + bytesIn.writeByte(0x87) // == Indexed - Add == + // idx = 7 -> :scheme: https + bytesIn.writeByte(0x85) // == Indexed - Add == + // idx = 5 -> :path: /index.html + bytesIn.writeByte(0xbf) // == Indexed - Add == + // idx = 63 -> :authority: www.example.com + bytesIn.writeByte(0x40) // Literal indexed + bytesIn.writeByte(0x88) // Literal name Huffman encoded 8 bytes + // decodes to custom-key which is length 10 + bytesIn.write("25a849e95ba97d7f".decodeHex()) + bytesIn.writeByte(0x89) // Literal value Huffman encoded 9 bytes + // decodes to custom-value which is length 12 + bytesIn.write("25a849e95bb8e8b4bf".decodeHex()) + } + + private fun checkReadThirdRequestWithHuffman() { + assertThat(hpackReader!!.headerCount).isEqualTo(3) + + // [ 1] (s = 54) custom-key: custom-value + var entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 3]!! + checkEntry(entry, "custom-key", "custom-value", 54) + + // [ 2] (s = 53) cache-control: no-cache + entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 2]!! + checkEntry(entry, "cache-control", "no-cache", 53) + + // [ 3] (s = 57) :authority: www.example.com + entry = hpackReader!!.dynamicTable[readerHeaderTableLength() - 1]!! + checkEntry(entry, ":authority", "www.example.com", 57) + + // Table size: 164 + assertThat(hpackReader!!.dynamicTableByteCount).isEqualTo(164) + + // Decoded header list: + assertThat(hpackReader!!.getAndResetHeaderList()).isEqualTo( + headerEntries( + ":method", "GET", + ":scheme", "https", + ":path", "/index.html", + ":authority", "www.example.com", + "custom-key", "custom-value" + ) + ) + } + + @Test + fun readSingleByteInt() { + assertThat(newReader(byteStream()).readInt(10, 31)).isEqualTo(10) + assertThat(newReader(byteStream()).readInt(0xe0 or 10, 31)).isEqualTo(10) + } + + @Test + fun readMultibyteInt() { + assertThat(newReader(byteStream(154, 10)).readInt(31, 31)).isEqualTo(1337) + } + + @Test + fun writeSingleByteInt() { + hpackWriter!!.writeInt(10, 31, 0) + assertBytes(10) + hpackWriter!!.writeInt(10, 31, 0xe0) + assertBytes(0xe0 or 10) + } + + @Test + fun writeMultibyteInt() { + hpackWriter!!.writeInt(1337, 31, 0) + assertBytes(31, 154, 10) + hpackWriter!!.writeInt(1337, 31, 0xe0) + assertBytes(0xe0 or 31, 154, 10) + } + + @Test + fun max31BitValue() { + hpackWriter!!.writeInt(0x7fffffff, 31, 0) + assertBytes(31, 224, 255, 255, 255, 7) + assertThat(newReader(byteStream(224, 255, 255, 255, 7)).readInt(31, 31)) + .isEqualTo(0x7fffffffL) + } + + @Test + fun prefixMask() { + hpackWriter!!.writeInt(31, 31, 0) + assertBytes(31, 0) + assertThat(newReader(byteStream(0)).readInt(31, 31)).isEqualTo(31) + } + + @Test + fun prefixMaskMinusOne() { + hpackWriter!!.writeInt(30, 31, 0) + assertBytes(30) + assertThat(newReader(byteStream(0)).readInt(31, 31)).isEqualTo(31) + } + + @Test + fun zero() { + hpackWriter!!.writeInt(0, 31, 0) + assertBytes(0) + assertThat(newReader(byteStream()).readInt(0, 31)).isEqualTo(0) + } + + @Test + fun lowercaseHeaderNameBeforeEmit() { + hpackWriter!!.writeHeaders(listOf(Header("FoO", "BaR"))) + assertBytes(0x40, 3, 'f'.code, 'o'.code, 'o'.code, 3, 'B'.code, 'a'.code, 'R'.code) + } + + @Test + fun mixedCaseHeaderNameIsMalformed() { + try { + newReader( + byteStream( + 0, + 3, + 'F'.code, + 'o'.code, + 'o'.code, + 3, + 'B'.code, + 'a'.code, + 'R'.code + ) + ).readHeaders() + org.junit.jupiter.api.Assertions.fail() + } catch (e: IOException) { + assertThat(e.message).isEqualTo( + "PROTOCOL_ERROR response malformed: mixed case name: Foo" + ) + } + } + + @Test + fun emptyHeaderName() { + hpackWriter!!.writeByteString("".encodeUtf8()) + assertBytes(0) + assertThat(newReader(byteStream(0)).readByteString()) + .isEqualTo(ByteString.EMPTY) + } + + @Test + fun emitsDynamicTableSizeUpdate() { + hpackWriter!!.resizeHeaderTable(2048) + hpackWriter!!.writeHeaders(listOf(Header("foo", "bar"))) + assertBytes( + 0x3F, 0xE1, 0xF, // Dynamic table size update (size = 2048). + 0x40, 3, 'f'.code, 'o'.code, 'o'.code, 3, 'b'.code, 'a'.code, 'r'.code + ) + hpackWriter!!.resizeHeaderTable(8192) + hpackWriter!!.writeHeaders(listOf(Header("bar", "foo"))) + assertBytes( + 0x3F, 0xE1, 0x3F, // Dynamic table size update (size = 8192). + 0x40, 3, 'b'.code, 'a'.code, 'r'.code, 3, 'f'.code, 'o'.code, 'o'.code + ) + + // No more dynamic table updates should be emitted. + hpackWriter!!.writeHeaders(listOf(Header("far", "boo"))) + assertBytes(0x40, 3, 'f'.code, 'a'.code, 'r'.code, 3, 'b'.code, 'o'.code, 'o'.code) + } + + @Test + fun noDynamicTableSizeUpdateWhenSizeIsEqual() { + val currentSize = hpackWriter!!.headerTableSizeSetting + hpackWriter!!.resizeHeaderTable(currentSize) + hpackWriter!!.writeHeaders(listOf(Header("foo", "bar"))) + assertBytes(0x40, 3, 'f'.code, 'o'.code, 'o'.code, 3, 'b'.code, 'a'.code, 'r'.code) + } + + @Test + fun growDynamicTableSize() { + hpackWriter!!.resizeHeaderTable(8192) + hpackWriter!!.resizeHeaderTable(16384) + hpackWriter!!.writeHeaders(listOf(Header("foo", "bar"))) + assertBytes( + 0x3F, 0xE1, 0x7F, // Dynamic table size update (size = 16384). + 0x40, 3, 'f'.code, 'o'.code, 'o'.code, 3, 'b'.code, 'a'.code, 'r'.code + ) + } + + @Test + fun shrinkDynamicTableSize() { + hpackWriter!!.resizeHeaderTable(2048) + hpackWriter!!.resizeHeaderTable(0) + hpackWriter!!.writeHeaders(listOf(Header("foo", "bar"))) + assertBytes( + 0x20, // Dynamic size update (size = 0). + 0x40, 3, 'f'.code, 'o'.code, 'o'.code, 3, 'b'.code, 'a'.code, 'r'.code + ) + } + + @Test + fun manyDynamicTableSizeChanges() { + hpackWriter!!.resizeHeaderTable(16384) + hpackWriter!!.resizeHeaderTable(8096) + hpackWriter!!.resizeHeaderTable(0) + hpackWriter!!.resizeHeaderTable(4096) + hpackWriter!!.resizeHeaderTable(2048) + hpackWriter!!.writeHeaders(listOf(Header("foo", "bar"))) + assertBytes( + 0x20, // Dynamic size update (size = 0). + 0x3F, 0xE1, 0xF, // Dynamic size update (size = 2048). + 0x40, 3, 'f'.code, 'o'.code, 'o'.code, 3, 'b'.code, 'a'.code, 'r'.code + ) + } + + @Test + fun dynamicTableEvictionWhenSizeLowered() { + val headerBlock = headerEntries( + "custom-key1", "custom-header", + "custom-key2", "custom-header" + ) + hpackWriter!!.writeHeaders(headerBlock) + assertThat(hpackWriter!!.headerCount).isEqualTo(2) + hpackWriter!!.resizeHeaderTable(56) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + hpackWriter!!.resizeHeaderTable(0) + assertThat(hpackWriter!!.headerCount).isEqualTo(0) + } + + @Test + fun noEvictionOnDynamicTableSizeIncrease() { + val headerBlock = headerEntries( + "custom-key1", "custom-header", + "custom-key2", "custom-header" + ) + hpackWriter!!.writeHeaders(headerBlock) + assertThat(hpackWriter!!.headerCount).isEqualTo(2) + hpackWriter!!.resizeHeaderTable(8192) + assertThat(hpackWriter!!.headerCount).isEqualTo(2) + } + + @Test + fun dynamicTableSizeHasAnUpperBound() { + hpackWriter!!.resizeHeaderTable(1048576) + assertThat(hpackWriter!!.maxDynamicTableByteCount).isEqualTo(16384) + } + + @Test + fun huffmanEncode() { + hpackWriter = Hpack.Writer(4096, true, bytesOut) + hpackWriter!!.writeHeaders(headerEntries("foo", "bar")) + val expected = Buffer() + .writeByte(0x40) // Literal header, new name. + .writeByte(0x82) // String literal is Huffman encoded (len = 2). + .writeByte(0x94) // 'foo' Huffman encoded. + .writeByte(0xE7) + .writeByte(3) // String literal not Huffman encoded (len = 3). + .writeByte('b'.code) + .writeByte('a'.code) + .writeByte('r'.code) + .readByteString() + val actual = bytesOut.readByteString() + assertThat(actual).isEqualTo(expected) + } + + @Test + fun staticTableIndexedHeaders() { + hpackWriter!!.writeHeaders(headerEntries(":method", "GET")) + assertBytes(0x82) + assertThat(hpackWriter!!.headerCount).isEqualTo(0) + hpackWriter!!.writeHeaders(headerEntries(":method", "POST")) + assertBytes(0x83) + assertThat(hpackWriter!!.headerCount).isEqualTo(0) + hpackWriter!!.writeHeaders(headerEntries(":path", "/")) + assertBytes(0x84) + assertThat(hpackWriter!!.headerCount).isEqualTo(0) + hpackWriter!!.writeHeaders(headerEntries(":path", "/index.html")) + assertBytes(0x85) + assertThat(hpackWriter!!.headerCount).isEqualTo(0) + hpackWriter!!.writeHeaders(headerEntries(":scheme", "http")) + assertBytes(0x86) + assertThat(hpackWriter!!.headerCount).isEqualTo(0) + hpackWriter!!.writeHeaders(headerEntries(":scheme", "https")) + assertBytes(0x87) + assertThat(hpackWriter!!.headerCount).isEqualTo(0) + } + + @Test + fun dynamicTableIndexedHeader() { + hpackWriter!!.writeHeaders(headerEntries("custom-key", "custom-header")) + assertBytes( + 0x40, + 10, + 'c'.code, + 'u'.code, + 's'.code, + 't'.code, + 'o'.code, + 'm'.code, + '-'.code, + 'k'.code, + 'e'.code, + 'y'.code, + 13, + 'c'.code, + 'u'.code, + 's'.code, + 't'.code, + 'o'.code, + 'm'.code, + '-'.code, + 'h'.code, + 'e'.code, + 'a'.code, + 'd'.code, + 'e'.code, + 'r'.code + ) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + hpackWriter!!.writeHeaders(headerEntries("custom-key", "custom-header")) + assertBytes(0xbe) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + } + + @Test + fun doNotIndexPseudoHeaders() { + hpackWriter!!.writeHeaders(headerEntries(":method", "PUT")) + assertBytes(0x02, 3, 'P'.code, 'U'.code, 'T'.code) + assertThat(hpackWriter!!.headerCount).isEqualTo(0) + hpackWriter!!.writeHeaders(headerEntries(":path", "/okhttp")) + assertBytes(0x04, 7, '/'.code, 'o'.code, 'k'.code, 'h'.code, 't'.code, 't'.code, 'p'.code) + assertThat(hpackWriter!!.headerCount).isEqualTo(0) + } + + @Test + fun incrementalIndexingWithAuthorityPseudoHeader() { + hpackWriter!!.writeHeaders(headerEntries(":authority", "foo.com")) + assertBytes(0x41, 7, 'f'.code, 'o'.code, 'o'.code, '.'.code, 'c'.code, 'o'.code, 'm'.code) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + hpackWriter!!.writeHeaders(headerEntries(":authority", "foo.com")) + assertBytes(0xbe) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + + // If the :authority header somehow changes, it should be re-added to the dynamic table. + hpackWriter!!.writeHeaders(headerEntries(":authority", "bar.com")) + assertBytes(0x41, 7, 'b'.code, 'a'.code, 'r'.code, '.'.code, 'c'.code, 'o'.code, 'm'.code) + assertThat(hpackWriter!!.headerCount).isEqualTo(2) + hpackWriter!!.writeHeaders(headerEntries(":authority", "bar.com")) + assertBytes(0xbe) + assertThat(hpackWriter!!.headerCount).isEqualTo(2) + } + + @Test + fun incrementalIndexingWithStaticTableIndexedName() { + hpackWriter!!.writeHeaders(headerEntries("accept-encoding", "gzip")) + assertBytes(0x50, 4, 'g'.code, 'z'.code, 'i'.code, 'p'.code) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + hpackWriter!!.writeHeaders(headerEntries("accept-encoding", "gzip")) + assertBytes(0xbe) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + } + + @Test + fun incrementalIndexingWithDynamcTableIndexedName() { + hpackWriter!!.writeHeaders(headerEntries("foo", "bar")) + assertBytes(0x40, 3, 'f'.code, 'o'.code, 'o'.code, 3, 'b'.code, 'a'.code, 'r'.code) + assertThat(hpackWriter!!.headerCount).isEqualTo(1) + hpackWriter!!.writeHeaders(headerEntries("foo", "bar1")) + assertBytes(0x7e, 4, 'b'.code, 'a'.code, 'r'.code, '1'.code) + assertThat(hpackWriter!!.headerCount).isEqualTo(2) + hpackWriter!!.writeHeaders(headerEntries("foo", "bar1")) + assertBytes(0xbe) + assertThat(hpackWriter!!.headerCount).isEqualTo(2) + } + + private fun newReader(source: Buffer): Hpack.Reader { + return Hpack.Reader(source, 4096) + } + + private fun byteStream(vararg bytes: Int): Buffer { + return Buffer().write(intArrayToByteArray(bytes)) + } + + private fun checkEntry(entry: Header, name: String, value: String, size: Int) { + assertThat(entry.name.utf8()).isEqualTo(name) + assertThat(entry.value.utf8()).isEqualTo(value) + assertThat(entry.hpackSize).isEqualTo(size) + } + + private fun assertBytes(vararg bytes: Int) { + val expected = intArrayToByteArray(bytes) + val actual = bytesOut.readByteString() + assertThat(actual).isEqualTo(expected) + } + + private fun intArrayToByteArray(bytes: IntArray): ByteString { + val data = ByteArray(bytes.size) + for (i in bytes.indices) { + data[i] = bytes[i].toByte() + } + return ByteString.of(*data) + } + + private fun readerHeaderTableLength(): Int { + return hpackReader!!.dynamicTable.size + } +} diff --git a/okhttp/src/test/java/okhttp3/internal/tls/CertificatePinnerChainValidationTest.java b/okhttp/src/test/java/okhttp3/internal/tls/CertificatePinnerChainValidationTest.java deleted file mode 100644 index bd9db61157c9..000000000000 --- a/okhttp/src/test/java/okhttp3/internal/tls/CertificatePinnerChainValidationTest.java +++ /dev/null @@ -1,632 +0,0 @@ -/* - * Copyright (C) 2016 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3.internal.tls; - -import mockwebserver3.MockResponse; -import mockwebserver3.MockWebServer; -import mockwebserver3.SocketPolicy.DisconnectAtEnd; -import okhttp3.Call; -import okhttp3.CertificatePinner; -import okhttp3.OkHttpClient; -import okhttp3.OkHttpClientTestRule; -import okhttp3.RecordingHostnameVerifier; -import okhttp3.Request; -import okhttp3.Response; -import okhttp3.internal.platform.Platform; -import okhttp3.testing.PlatformRule; -import okhttp3.tls.HandshakeCertificates; -import okhttp3.tls.HeldCertificate; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import javax.net.ssl.KeyManager; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.SSLPeerUnverifiedException; -import javax.net.ssl.SSLSocketFactory; -import javax.net.ssl.TrustManager; -import javax.net.ssl.X509KeyManager; -import javax.net.ssl.X509TrustManager; -import java.security.GeneralSecurityException; -import java.security.SecureRandom; -import java.security.cert.X509Certificate; -import java.util.Collections; - -import static okhttp3.tls.internal.TlsUtil.newKeyManager; -import static okhttp3.tls.internal.TlsUtil.newTrustManager; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -public final class CertificatePinnerChainValidationTest { - @RegisterExtension PlatformRule platform = new PlatformRule(); - @RegisterExtension OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - - private MockWebServer server; - - @BeforeEach - public void setup(MockWebServer server) { - this.server = server; - platform.assumeNotBouncyCastle(); - } - - /** The pinner should pull the root certificate from the trust manager. */ - @Test public void pinRootNotPresentInChain() throws Exception { - // Fails on 11.0.1 https://github.com/square/okhttp/issues/4703 - - HeldCertificate rootCa = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(1) - .commonName("root") - .build(); - HeldCertificate intermediateCa = new HeldCertificate.Builder() - .signedBy(rootCa) - .certificateAuthority(0) - .serialNumber(2L) - .commonName("intermediate_ca") - .build(); - HeldCertificate certificate = new HeldCertificate.Builder() - .signedBy(intermediateCa) - .serialNumber(3L) - .commonName(server.getHostName()) - .build(); - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add(server.getHostName(), CertificatePinner.pin(rootCa.certificate())) - .build(); - HandshakeCertificates handshakeCertificates = new HandshakeCertificates.Builder() - .addTrustedCertificate(rootCa.certificate()) - .build(); - OkHttpClient client = clientTestRule.newClientBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .certificatePinner(certificatePinner) - .build(); - - HandshakeCertificates serverHandshakeCertificates = new HandshakeCertificates.Builder() - .heldCertificate(certificate, intermediateCa.certificate()) - .build(); - server.useHttps(serverHandshakeCertificates.sslSocketFactory()); - - // The request should complete successfully. - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - Call call1 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response1 = call1.execute(); - assertThat(response1.body().string()).isEqualTo("abc"); - } - - /** The pinner should accept an intermediate from the server's chain. */ - @Test public void pinIntermediatePresentInChain() throws Exception { - // Fails on 11.0.1 https://github.com/square/okhttp/issues/4703 - - HeldCertificate rootCa = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(1) - .commonName("root") - .build(); - HeldCertificate intermediateCa = new HeldCertificate.Builder() - .signedBy(rootCa) - .certificateAuthority(0) - .serialNumber(2L) - .commonName("intermediate_ca") - .build(); - HeldCertificate certificate = new HeldCertificate.Builder() - .signedBy(intermediateCa) - .serialNumber(3L) - .commonName(server.getHostName()) - .build(); - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add(server.getHostName(), CertificatePinner.pin(intermediateCa.certificate())) - .build(); - HandshakeCertificates handshakeCertificates = new HandshakeCertificates.Builder() - .addTrustedCertificate(rootCa.certificate()) - .build(); - OkHttpClient client = clientTestRule.newClientBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .certificatePinner(certificatePinner) - .build(); - - HandshakeCertificates serverHandshakeCertificates = new HandshakeCertificates.Builder() - .heldCertificate(certificate, intermediateCa.certificate()) - .build(); - server.useHttps(serverHandshakeCertificates.sslSocketFactory()); - - // The request should complete successfully. - server.enqueue(new MockResponse.Builder() - .body("abc") - .socketPolicy(DisconnectAtEnd.INSTANCE) - .build()); - Call call1 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response1 = call1.execute(); - assertThat(response1.body().string()).isEqualTo("abc"); - response1.close(); - - // Force a fresh connection for the next request. - client.connectionPool().evictAll(); - - // Confirm that a second request also succeeds. This should detect caching problems. - server.enqueue(new MockResponse.Builder() - .body("def") - .socketPolicy(DisconnectAtEnd.INSTANCE) - .build()); - Call call2 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response2 = call2.execute(); - assertThat(response2.body().string()).isEqualTo("def"); - response2.close(); - } - - @Test public void unrelatedPinnedLeafCertificateInChain() throws Exception { - // https://github.com/square/okhttp/issues/4729 - platform.expectFailureOnConscryptPlatform(); - platform.expectFailureOnCorrettoPlatform(); - platform.expectFailureOnLoomPlatform(); - - // Start with a trusted root CA certificate. - HeldCertificate rootCa = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(1) - .commonName("root") - .build(); - - // Add a good intermediate CA, and have that issue a good certificate to localhost. Prepare an - // SSL context for an HTTP client under attack. It includes the trusted CA and a pinned - // certificate. - HeldCertificate goodIntermediateCa = new HeldCertificate.Builder() - .signedBy(rootCa) - .certificateAuthority(0) - .serialNumber(2L) - .commonName("good_intermediate_ca") - .build(); - HeldCertificate goodCertificate = new HeldCertificate.Builder() - .signedBy(goodIntermediateCa) - .serialNumber(3L) - .commonName(server.getHostName()) - .build(); - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add(server.getHostName(), CertificatePinner.pin(goodCertificate.certificate())) - .build(); - HandshakeCertificates handshakeCertificates = new HandshakeCertificates.Builder() - .addTrustedCertificate(rootCa.certificate()) - .build(); - OkHttpClient client = clientTestRule.newClientBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .certificatePinner(certificatePinner) - .build(); - - // Add a bad intermediate CA and have that issue a rogue certificate for localhost. Prepare - // an SSL context for an attacking webserver. It includes both these rogue certificates plus the - // trusted good certificate above. The attack is that by including the good certificate in the - // chain, we may trick the certificate pinner into accepting the rouge certificate. - HeldCertificate compromisedIntermediateCa = new HeldCertificate.Builder() - .signedBy(rootCa) - .certificateAuthority(0) - .serialNumber(4L) - .commonName("bad_intermediate_ca") - .build(); - HeldCertificate rogueCertificate = new HeldCertificate.Builder() - .serialNumber(5L) - .signedBy(compromisedIntermediateCa) - .commonName(server.getHostName()) - .build(); - - SSLSocketFactory socketFactory = newServerSocketFactory(rogueCertificate, - compromisedIntermediateCa.certificate(), goodCertificate.certificate()); - - server.useHttps(socketFactory); - server.enqueue(new MockResponse.Builder() - .body("abc") - .addHeader("Content-Type: text/plain") - .build()); - - // Make a request from client to server. It should succeed certificate checks (unfortunately the - // rogue CA is trusted) but it should fail certificate pinning. - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - Call call = client.newCall(request); - try { - call.execute(); - fail(); - } catch (SSLPeerUnverifiedException expected) { - // Certificate pinning fails! - String message = expected.getMessage(); - assertThat(message).startsWith("Certificate pinning failure!"); - } - } - - @Test public void unrelatedPinnedIntermediateCertificateInChain() throws Exception { - // https://github.com/square/okhttp/issues/4729 - platform.expectFailureOnConscryptPlatform(); - platform.expectFailureOnCorrettoPlatform(); - platform.expectFailureOnLoomPlatform(); - - // Start with two root CA certificates, one is good and the other is compromised. - HeldCertificate rootCa = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(1) - .commonName("root") - .build(); - HeldCertificate compromisedRootCa = new HeldCertificate.Builder() - .serialNumber(2L) - .certificateAuthority(1) - .commonName("compromised_root") - .build(); - - // Add a good intermediate CA, and have that issue a good certificate to localhost. Prepare an - // SSL context for an HTTP client under attack. It includes the trusted CA and a pinned - // certificate. - HeldCertificate goodIntermediateCa = new HeldCertificate.Builder() - .signedBy(rootCa) - .certificateAuthority(0) - .serialNumber(3L) - .commonName("intermediate_ca") - .build(); - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add(server.getHostName(), CertificatePinner.pin(goodIntermediateCa.certificate())) - .build(); - HandshakeCertificates handshakeCertificates = new HandshakeCertificates.Builder() - .addTrustedCertificate(rootCa.certificate()) - .addTrustedCertificate(compromisedRootCa.certificate()) - .build(); - OkHttpClient client = clientTestRule.newClientBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .certificatePinner(certificatePinner) - .build(); - - // The attacker compromises the root CA, issues an intermediate with the same common name - // "intermediate_ca" as the good CA. This signs a rogue certificate for localhost. The server - // serves the good CAs certificate in the chain, which means the certificate pinner sees a - // different set of certificates than the SSL verifier. - HeldCertificate compromisedIntermediateCa = new HeldCertificate.Builder() - .signedBy(compromisedRootCa) - .certificateAuthority(0) - .serialNumber(4L) - .commonName("intermediate_ca") - .build(); - HeldCertificate rogueCertificate = new HeldCertificate.Builder() - .serialNumber(5L) - .signedBy(compromisedIntermediateCa) - .commonName(server.getHostName()) - .build(); - - SSLSocketFactory socketFactory = newServerSocketFactory(rogueCertificate, - goodIntermediateCa.certificate(), compromisedIntermediateCa.certificate()); - server.useHttps(socketFactory); - server.enqueue(new MockResponse.Builder() - .body("abc") - .addHeader("Content-Type: text/plain") - .build()); - - // Make a request from client to server. It should succeed certificate checks (unfortunately the - // rogue CA is trusted) but it should fail certificate pinning. - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - Call call = client.newCall(request); - try { - call.execute(); - fail(); - } catch (SSLHandshakeException expected) { - // On Android, the handshake fails before the certificate pinner runs. - String message = expected.getMessage(); - assertThat(message).contains("Could not validate certificate"); - } catch (SSLPeerUnverifiedException expected) { - // On OpenJDK, the handshake succeeds but the certificate pinner fails. - String message = expected.getMessage(); - assertThat(message).startsWith("Certificate pinning failure!"); - } - } - - /** - * Not checking the CA bit created a vulnerability in old OkHttp releases. It is exploited by - * triggering different chains to be discovered by the TLS engine and our chain cleaner. In this - * attack there's several different chains. - * - *

The victim's gets a non-CA certificate signed by a CA, and pins the CA root and/or - * intermediate. This is business as usual. - * - *

{@code
-   *
-   *   pinnedRoot (trusted by CertificatePinner)
-   *     -> pinnedIntermediate (trusted by CertificatePinner)
-   *       -> realVictim
-   *
-   * }
- * - *

The attacker compromises a CA. They take the public key from an intermediate certificate - * signed by the compromised CA's certificate and uses it in a non-CA certificate. They ask the - * pinned CA above to sign it for non-certificate-authority uses: - * - *

{@code
-   *
-   *   pinnedRoot (trusted by CertificatePinner)
-   *     -> pinnedIntermediate (trusted by CertificatePinner)
-   *         -> attackerSwitch
-   *
-   * }
- * - *

The attacker serves a set of certificates that yields a too-long chain in our certificate - * pinner. The served certificates (incorrectly) formed a single chain to the pinner: - * - *

{@code
-   *
-   *   attackerCa
-   *     -> attackerIntermediate
-   *         -> pinnedRoot (trusted by CertificatePinner)
-   *             -> pinnedIntermediate (trusted by CertificatePinner)
-   *                 -> attackerSwitch (not a CA certificate!)
-   *                     -> phonyVictim
-   *
-   * }
- * - * But this chain is wrong because the attackerSwitch certificate is being used in a CA role even - * though it is not a CA certificate. There are pinned certificates in the chain! The correct - * chain is much shorter because it skips the non-CA certificate. - * - *
{@code
-   *
-   *   attackerCa
-   *     -> attackerIntermediate
-   *         -> phonyVictim
-   *
-   * }
- * - * Some implementations fail the TLS handshake when they see the long chain, and don't give - * CertificatePinner the opportunity to produce a different chain from their own. This includes - * the OpenJDK 11 TLS implementation, which itself fails the handshake when it encounters a non-CA - * certificate. - */ - @Test public void signersMustHaveCaBitSet() throws Exception { - HeldCertificate attackerCa = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(4) - .commonName("attacker ca") - .build(); - HeldCertificate attackerIntermediate = new HeldCertificate.Builder() - .serialNumber(2L) - .certificateAuthority(3) - .commonName("attacker") - .signedBy(attackerCa) - .build(); - HeldCertificate pinnedRoot = new HeldCertificate.Builder() - .serialNumber(3L) - .certificateAuthority(2) - .commonName("pinned root") - .signedBy(attackerIntermediate) - .build(); - HeldCertificate pinnedIntermediate = new HeldCertificate.Builder() - .serialNumber(4L) - .certificateAuthority(1) - .commonName("pinned intermediate") - .signedBy(pinnedRoot) - .build(); - HeldCertificate attackerSwitch = new HeldCertificate.Builder() - .serialNumber(5L) - .keyPair(attackerIntermediate.keyPair()) // share keys between compromised CA and leaf! - .commonName("attacker") - .addSubjectAlternativeName("attacker.com") - .signedBy(pinnedIntermediate) - .build(); - HeldCertificate phonyVictim = new HeldCertificate.Builder() - .serialNumber(6L) - .signedBy(attackerSwitch) - .addSubjectAlternativeName("victim.com") - .commonName("victim") - .build(); - - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add(server.getHostName(), CertificatePinner.pin(pinnedRoot.certificate())) - .build(); - HandshakeCertificates handshakeCertificates = new HandshakeCertificates.Builder() - .addTrustedCertificate(pinnedRoot.certificate()) - .addTrustedCertificate(attackerCa.certificate()) - .build(); - OkHttpClient client = clientTestRule.newClientBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .certificatePinner(certificatePinner) - .build(); - - HandshakeCertificates serverHandshakeCertificates = new HandshakeCertificates.Builder() - .heldCertificate( - phonyVictim, - attackerSwitch.certificate(), - pinnedIntermediate.certificate(), - pinnedRoot.certificate(), - attackerIntermediate.certificate() - ) - .build(); - server.useHttps(serverHandshakeCertificates.sslSocketFactory()); - - server.enqueue(new MockResponse()); - - // Make a request from client to server. It should succeed certificate checks (unfortunately the - // rogue CA is trusted) but it should fail certificate pinning. - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - Call call = client.newCall(request); - try (Response response = call.execute()) { - fail("expected connection failure but got " + response); - } catch (SSLPeerUnverifiedException expected) { - // Certificate pinning fails! - String message = expected.getMessage(); - assertThat(message).startsWith("Certificate pinning failure!"); - } catch (SSLHandshakeException expected) { - // We didn't have the opportunity to do certificate pinning because the handshake failed. - assertThat(expected).hasMessageContaining("this is not a CA certificate"); - } - } - - /** - * Attack the CA intermediates check by presenting unrelated chains to the handshake vs. - * certificate pinner. - * - * This chain is valid but not pinned: - * - *
{@code
-   *
-   *   attackerCa
-   *    -> phonyVictim
-   *
-   * }
- * - * This chain is pinned but not valid: - * - *
{@code
-   *
-   *   attackerCa
-   *     -> pinnedRoot (trusted by CertificatePinner)
-   *         -> compromisedIntermediate (max intermediates: 0)
-   *             -> attackerIntermediate (max intermediates: 0)
-   *                 -> phonyVictim
-   * }
- */ - @Test public void intermediateMustNotHaveMoreIntermediatesThanSigner() throws Exception { - HeldCertificate attackerCa = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(2) - .commonName("attacker ca") - .build(); - HeldCertificate pinnedRoot = new HeldCertificate.Builder() - .serialNumber(2L) - .certificateAuthority(1) - .commonName("pinned root") - .signedBy(attackerCa) - .build(); - HeldCertificate compromisedIntermediate = new HeldCertificate.Builder() - .serialNumber(3L) - .certificateAuthority(0) - .commonName("compromised intermediate") - .signedBy(pinnedRoot) - .build(); - HeldCertificate attackerIntermediate = new HeldCertificate.Builder() - .keyPair(attackerCa.keyPair()) // Share keys between compromised CA and intermediate! - .serialNumber(4L) - .certificateAuthority(0) // More intermediates than permitted by signer! - .commonName("attacker intermediate") - .signedBy(compromisedIntermediate) - .build(); - HeldCertificate phonyVictim = new HeldCertificate.Builder() - .serialNumber(5L) - .signedBy(attackerIntermediate) - .addSubjectAlternativeName("victim.com") - .commonName("victim") - .build(); - - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add(server.getHostName(), CertificatePinner.pin(pinnedRoot.certificate())) - .build(); - HandshakeCertificates handshakeCertificates = new HandshakeCertificates.Builder() - .addTrustedCertificate(pinnedRoot.certificate()) - .addTrustedCertificate(attackerCa.certificate()) - .build(); - OkHttpClient client = clientTestRule.newClientBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .certificatePinner(certificatePinner) - .build(); - - HandshakeCertificates serverHandshakeCertificates = new HandshakeCertificates.Builder() - .heldCertificate( - phonyVictim, - attackerIntermediate.certificate(), - compromisedIntermediate.certificate(), - pinnedRoot.certificate() - ) - .build(); - server.useHttps(serverHandshakeCertificates.sslSocketFactory()); - - server.enqueue(new MockResponse()); - - // Make a request from client to server. It should not succeed certificate checks. - Request request = new Request.Builder() - .url(server.url("/")) - .build(); - Call call = client.newCall(request); - try (Response response = call.execute()) { - fail("expected connection failure but got " + response); - } catch (SSLHandshakeException expected) { - } - } - - @Test public void lonePinnedCertificate() throws Exception { - HeldCertificate onlyCertificate = new HeldCertificate.Builder() - .serialNumber(1L) - .commonName("root") - .build(); - CertificatePinner certificatePinner = new CertificatePinner.Builder() - .add(server.getHostName(), CertificatePinner.pin(onlyCertificate.certificate())) - .build(); - HandshakeCertificates handshakeCertificates = new HandshakeCertificates.Builder() - .addTrustedCertificate(onlyCertificate.certificate()) - .build(); - OkHttpClient client = clientTestRule.newClientBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .certificatePinner(certificatePinner) - .build(); - - HandshakeCertificates serverHandshakeCertificates = new HandshakeCertificates.Builder() - .heldCertificate(onlyCertificate) - .build(); - server.useHttps(serverHandshakeCertificates.sslSocketFactory()); - - // The request should complete successfully. - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - Call call1 = client.newCall(new Request.Builder() - .url(server.url("/")) - .build()); - Response response1 = call1.execute(); - assertThat(response1.body().string()).isEqualTo("abc"); - } - - private SSLSocketFactory newServerSocketFactory(HeldCertificate heldCertificate, - X509Certificate... intermediates) throws GeneralSecurityException { - // Test setup fails on JDK9 - // java.security.KeyStoreException: Certificate chain is not valid - // at sun.security.pkcs12.PKCS12KeyStore.setKeyEntry - // http://openjdk.java.net/jeps/229 - // http://hg.openjdk.java.net/jdk9/jdk9/jdk/file/2c1c21d11e58/src/share/classes/sun/security/pkcs12/PKCS12KeyStore.java#l596 - String keystoreType = platform.isJdk9() ? "JKS" : null; - X509KeyManager x509KeyManager = newKeyManager(keystoreType, heldCertificate, intermediates); - X509TrustManager trustManager = newTrustManager( - keystoreType, Collections.emptyList(), Collections.emptyList()); - SSLContext sslContext = Platform.get().newSSLContext(); - sslContext.init(new KeyManager[] {x509KeyManager}, new TrustManager[] {trustManager}, - new SecureRandom()); - return sslContext.getSocketFactory(); - } -} diff --git a/okhttp/src/test/java/okhttp3/internal/tls/CertificatePinnerChainValidationTest.kt b/okhttp/src/test/java/okhttp3/internal/tls/CertificatePinnerChainValidationTest.kt new file mode 100644 index 000000000000..e38ccb6f6738 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/internal/tls/CertificatePinnerChainValidationTest.kt @@ -0,0 +1,653 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.tls + +import java.security.SecureRandom +import java.security.cert.X509Certificate +import javax.net.ssl.KeyManager +import javax.net.ssl.SSLHandshakeException +import javax.net.ssl.SSLPeerUnverifiedException +import javax.net.ssl.SSLSocketFactory +import javax.net.ssl.TrustManager +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.SocketPolicy.DisconnectAtEnd +import okhttp3.CertificatePinner +import okhttp3.CertificatePinner.Companion.pin +import okhttp3.OkHttpClientTestRule +import okhttp3.RecordingHostnameVerifier +import okhttp3.Request +import okhttp3.internal.platform.Platform.Companion.get +import okhttp3.testing.PlatformRule +import okhttp3.tls.HandshakeCertificates +import okhttp3.tls.HeldCertificate +import okhttp3.tls.internal.TlsUtil.newKeyManager +import okhttp3.tls.internal.TlsUtil.newTrustManager +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +class CertificatePinnerChainValidationTest { + @RegisterExtension + var platform = PlatformRule() + + @RegisterExtension + var clientTestRule = OkHttpClientTestRule() + + private lateinit var server: MockWebServer + + @BeforeEach + fun setup(server: MockWebServer) { + this.server = server + platform.assumeNotBouncyCastle() + } + + /** + * The pinner should pull the root certificate from the trust manager. + */ + @Test + fun pinRootNotPresentInChain() { + // Fails on 11.0.1 https://github.com/square/okhttp/issues/4703 + val rootCa = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(1) + .commonName("root") + .build() + val intermediateCa = HeldCertificate.Builder() + .signedBy(rootCa) + .certificateAuthority(0) + .serialNumber(2L) + .commonName("intermediate_ca") + .build() + val certificate = HeldCertificate.Builder() + .signedBy(intermediateCa) + .serialNumber(3L) + .commonName(server.hostName) + .build() + val certificatePinner = CertificatePinner.Builder() + .add(server.hostName, pin(rootCa.certificate)) + .build() + val handshakeCertificates = HandshakeCertificates.Builder() + .addTrustedCertificate(rootCa.certificate) + .build() + val client = clientTestRule.newClientBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .certificatePinner(certificatePinner) + .build() + val serverHandshakeCertificates = HandshakeCertificates.Builder() + .heldCertificate(certificate, intermediateCa.certificate) + .build() + server.useHttps(serverHandshakeCertificates.sslSocketFactory()) + + // The request should complete successfully. + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call1 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response1 = call1.execute() + assertThat(response1.body.string()).isEqualTo("abc") + } + + /** + * The pinner should accept an intermediate from the server's chain. + */ + @Test + fun pinIntermediatePresentInChain() { + // Fails on 11.0.1 https://github.com/square/okhttp/issues/4703 + val rootCa = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(1) + .commonName("root") + .build() + val intermediateCa = HeldCertificate.Builder() + .signedBy(rootCa) + .certificateAuthority(0) + .serialNumber(2L) + .commonName("intermediate_ca") + .build() + val certificate = HeldCertificate.Builder() + .signedBy(intermediateCa) + .serialNumber(3L) + .commonName(server.hostName) + .build() + val certificatePinner = CertificatePinner.Builder() + .add(server.hostName, pin(intermediateCa.certificate)) + .build() + val handshakeCertificates = HandshakeCertificates.Builder() + .addTrustedCertificate(rootCa.certificate) + .build() + val client = clientTestRule.newClientBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .certificatePinner(certificatePinner) + .build() + val serverHandshakeCertificates = HandshakeCertificates.Builder() + .heldCertificate(certificate, intermediateCa.certificate) + .build() + server.useHttps(serverHandshakeCertificates.sslSocketFactory()) + + // The request should complete successfully. + server.enqueue( + MockResponse.Builder() + .body("abc") + .socketPolicy(DisconnectAtEnd) + .build() + ) + val call1 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response1 = call1.execute() + assertThat(response1.body.string()).isEqualTo("abc") + response1.close() + + // Force a fresh connection for the next request. + client.connectionPool.evictAll() + + // Confirm that a second request also succeeds. This should detect caching problems. + server.enqueue( + MockResponse.Builder() + .body("def") + .socketPolicy(DisconnectAtEnd) + .build() + ) + val call2 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response2 = call2.execute() + assertThat(response2.body.string()).isEqualTo("def") + response2.close() + } + + @Test + fun unrelatedPinnedLeafCertificateInChain() { + // https://github.com/square/okhttp/issues/4729 + platform.expectFailureOnConscryptPlatform() + platform.expectFailureOnCorrettoPlatform() + platform.expectFailureOnLoomPlatform() + + // Start with a trusted root CA certificate. + val rootCa = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(1) + .commonName("root") + .build() + + // Add a good intermediate CA, and have that issue a good certificate to localhost. Prepare an + // SSL context for an HTTP client under attack. It includes the trusted CA and a pinned + // certificate. + val goodIntermediateCa = HeldCertificate.Builder() + .signedBy(rootCa) + .certificateAuthority(0) + .serialNumber(2L) + .commonName("good_intermediate_ca") + .build() + val goodCertificate = HeldCertificate.Builder() + .signedBy(goodIntermediateCa) + .serialNumber(3L) + .commonName(server.hostName) + .build() + val certificatePinner = CertificatePinner.Builder() + .add(server.hostName, pin(goodCertificate.certificate)) + .build() + val handshakeCertificates = HandshakeCertificates.Builder() + .addTrustedCertificate(rootCa.certificate) + .build() + val client = clientTestRule.newClientBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .certificatePinner(certificatePinner) + .build() + + // Add a bad intermediate CA and have that issue a rogue certificate for localhost. Prepare + // an SSL context for an attacking webserver. It includes both these rogue certificates plus the + // trusted good certificate above. The attack is that by including the good certificate in the + // chain, we may trick the certificate pinner into accepting the rouge certificate. + val compromisedIntermediateCa = HeldCertificate.Builder() + .signedBy(rootCa) + .certificateAuthority(0) + .serialNumber(4L) + .commonName("bad_intermediate_ca") + .build() + val rogueCertificate = HeldCertificate.Builder() + .serialNumber(5L) + .signedBy(compromisedIntermediateCa) + .commonName(server.hostName) + .build() + val socketFactory = newServerSocketFactory( + rogueCertificate, + compromisedIntermediateCa.certificate, goodCertificate.certificate + ) + server.useHttps(socketFactory) + server.enqueue( + MockResponse.Builder() + .body("abc") + .addHeader("Content-Type: text/plain") + .build() + ) + + // Make a request from client to server. It should succeed certificate checks (unfortunately the + // rogue CA is trusted) but it should fail certificate pinning. + val request = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request) + try { + call.execute() + fail() + } catch (expected: SSLPeerUnverifiedException) { + // Certificate pinning fails! + val message = expected.message + assertThat(message).startsWith("Certificate pinning failure!") + } + } + + @Test + fun unrelatedPinnedIntermediateCertificateInChain() { + // https://github.com/square/okhttp/issues/4729 + platform.expectFailureOnConscryptPlatform() + platform.expectFailureOnCorrettoPlatform() + platform.expectFailureOnLoomPlatform() + + // Start with two root CA certificates, one is good and the other is compromised. + val rootCa = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(1) + .commonName("root") + .build() + val compromisedRootCa = HeldCertificate.Builder() + .serialNumber(2L) + .certificateAuthority(1) + .commonName("compromised_root") + .build() + + // Add a good intermediate CA, and have that issue a good certificate to localhost. Prepare an + // SSL context for an HTTP client under attack. It includes the trusted CA and a pinned + // certificate. + val goodIntermediateCa = HeldCertificate.Builder() + .signedBy(rootCa) + .certificateAuthority(0) + .serialNumber(3L) + .commonName("intermediate_ca") + .build() + val certificatePinner = CertificatePinner.Builder() + .add(server.hostName, pin(goodIntermediateCa.certificate)) + .build() + val handshakeCertificates = HandshakeCertificates.Builder() + .addTrustedCertificate(rootCa.certificate) + .addTrustedCertificate(compromisedRootCa.certificate) + .build() + val client = clientTestRule.newClientBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .certificatePinner(certificatePinner) + .build() + + // The attacker compromises the root CA, issues an intermediate with the same common name + // "intermediate_ca" as the good CA. This signs a rogue certificate for localhost. The server + // serves the good CAs certificate in the chain, which means the certificate pinner sees a + // different set of certificates than the SSL verifier. + val compromisedIntermediateCa = HeldCertificate.Builder() + .signedBy(compromisedRootCa) + .certificateAuthority(0) + .serialNumber(4L) + .commonName("intermediate_ca") + .build() + val rogueCertificate = HeldCertificate.Builder() + .serialNumber(5L) + .signedBy(compromisedIntermediateCa) + .commonName(server.hostName) + .build() + val socketFactory = newServerSocketFactory( + rogueCertificate, + goodIntermediateCa.certificate, compromisedIntermediateCa.certificate + ) + server.useHttps(socketFactory) + server.enqueue( + MockResponse.Builder() + .body("abc") + .addHeader("Content-Type: text/plain") + .build() + ) + + // Make a request from client to server. It should succeed certificate checks (unfortunately the + // rogue CA is trusted) but it should fail certificate pinning. + val request = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request) + try { + call.execute() + fail() + } catch (expected: SSLHandshakeException) { + // On Android, the handshake fails before the certificate pinner runs. + val message = expected.message + assertThat(message).contains("Could not validate certificate") + } catch (expected: SSLPeerUnverifiedException) { + // On OpenJDK, the handshake succeeds but the certificate pinner fails. + val message = expected.message + assertThat(message).startsWith("Certificate pinning failure!") + } + } + + /** + * Not checking the CA bit created a vulnerability in old OkHttp releases. It is exploited by + * triggering different chains to be discovered by the TLS engine and our chain cleaner. In this + * attack there's several different chains. + * + * + * The victim's gets a non-CA certificate signed by a CA, and pins the CA root and/or + * intermediate. This is business as usual. + * + * ``` + * pinnedRoot (trusted by CertificatePinner) + * -> pinnedIntermediate (trusted by CertificatePinner) + * -> realVictim + * ``` + * + * The attacker compromises a CA. They take the public key from an intermediate certificate + * signed by the compromised CA's certificate and uses it in a non-CA certificate. They ask the + * pinned CA above to sign it for non-certificate-authority uses: + * + * ``` + * pinnedRoot (trusted by CertificatePinner) + * -> pinnedIntermediate (trusted by CertificatePinner) + * -> attackerSwitch + * ``` + * + * The attacker serves a set of certificates that yields a too-long chain in our certificate + * pinner. The served certificates (incorrectly) formed a single chain to the pinner: + * + * ``` + * attackerCa + * -> attackerIntermediate + * -> pinnedRoot (trusted by CertificatePinner) + * -> pinnedIntermediate (trusted by CertificatePinner) + * -> attackerSwitch (not a CA certificate!) + * -> phonyVictim + * ``` + * + * But this chain is wrong because the attackerSwitch certificate is being used in a CA role even + * though it is not a CA certificate. There are pinned certificates in the chain! The correct + * chain is much shorter because it skips the non-CA certificate. + * + * ``` + * attackerCa + * -> attackerIntermediate + * -> phonyVictim + * ``` + * + * Some implementations fail the TLS handshake when they see the long chain, and don't give + * CertificatePinner the opportunity to produce a different chain from their own. This includes + * the OpenJDK 11 TLS implementation, which itself fails the handshake when it encounters a non-CA + * certificate. + */ + @Test + fun signersMustHaveCaBitSet() { + val attackerCa = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(4) + .commonName("attacker ca") + .build() + val attackerIntermediate = HeldCertificate.Builder() + .serialNumber(2L) + .certificateAuthority(3) + .commonName("attacker") + .signedBy(attackerCa) + .build() + val pinnedRoot = HeldCertificate.Builder() + .serialNumber(3L) + .certificateAuthority(2) + .commonName("pinned root") + .signedBy(attackerIntermediate) + .build() + val pinnedIntermediate = HeldCertificate.Builder() + .serialNumber(4L) + .certificateAuthority(1) + .commonName("pinned intermediate") + .signedBy(pinnedRoot) + .build() + val attackerSwitch = HeldCertificate.Builder() + .serialNumber(5L) + .keyPair(attackerIntermediate.keyPair) // share keys between compromised CA and leaf! + .commonName("attacker") + .addSubjectAlternativeName("attacker.com") + .signedBy(pinnedIntermediate) + .build() + val phonyVictim = HeldCertificate.Builder() + .serialNumber(6L) + .signedBy(attackerSwitch) + .addSubjectAlternativeName("victim.com") + .commonName("victim") + .build() + val certificatePinner = CertificatePinner.Builder() + .add(server.hostName, pin(pinnedRoot.certificate)) + .build() + val handshakeCertificates = HandshakeCertificates.Builder() + .addTrustedCertificate(pinnedRoot.certificate) + .addTrustedCertificate(attackerCa.certificate) + .build() + val client = clientTestRule.newClientBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .certificatePinner(certificatePinner) + .build() + val serverHandshakeCertificates = HandshakeCertificates.Builder() + .heldCertificate( + phonyVictim, + attackerSwitch.certificate, + pinnedIntermediate.certificate, + pinnedRoot.certificate, + attackerIntermediate.certificate + ) + .build() + server.useHttps(serverHandshakeCertificates.sslSocketFactory()) + server.enqueue(MockResponse()) + + // Make a request from client to server. It should succeed certificate checks (unfortunately the + // rogue CA is trusted) but it should fail certificate pinning. + val request = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request) + try { + call.execute() + .use { response -> fail("expected connection failure but got $response") } + } catch (expected: SSLPeerUnverifiedException) { + // Certificate pinning fails! + val message = expected.message + assertThat(message).startsWith("Certificate pinning failure!") + } catch (expected: SSLHandshakeException) { + // We didn't have the opportunity to do certificate pinning because the handshake failed. + assertThat(expected).hasMessageContaining("this is not a CA certificate") + } + } + + /** + * Attack the CA intermediates check by presenting unrelated chains to the handshake vs. + * certificate pinner. + * + * This chain is valid but not pinned: + * + * ``` + * attackerCa + * -> phonyVictim + * ``` + * + * + * This chain is pinned but not valid: + * + * ``` + * attackerCa + * -> pinnedRoot (trusted by CertificatePinner) + * -> compromisedIntermediate (max intermediates: 0) + * -> attackerIntermediate (max intermediates: 0) + * -> phonyVictim + * ``` + */ + @Test + fun intermediateMustNotHaveMoreIntermediatesThanSigner() { + val attackerCa = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(2) + .commonName("attacker ca") + .build() + val pinnedRoot = HeldCertificate.Builder() + .serialNumber(2L) + .certificateAuthority(1) + .commonName("pinned root") + .signedBy(attackerCa) + .build() + val compromisedIntermediate = HeldCertificate.Builder() + .serialNumber(3L) + .certificateAuthority(0) + .commonName("compromised intermediate") + .signedBy(pinnedRoot) + .build() + val attackerIntermediate = HeldCertificate.Builder() + .keyPair(attackerCa.keyPair) // Share keys between compromised CA and intermediate! + .serialNumber(4L) + .certificateAuthority(0) // More intermediates than permitted by signer! + .commonName("attacker intermediate") + .signedBy(compromisedIntermediate) + .build() + val phonyVictim = HeldCertificate.Builder() + .serialNumber(5L) + .signedBy(attackerIntermediate) + .addSubjectAlternativeName("victim.com") + .commonName("victim") + .build() + val certificatePinner = CertificatePinner.Builder() + .add(server.hostName, pin(pinnedRoot.certificate)) + .build() + val handshakeCertificates = HandshakeCertificates.Builder() + .addTrustedCertificate(pinnedRoot.certificate) + .addTrustedCertificate(attackerCa.certificate) + .build() + val client = clientTestRule.newClientBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .certificatePinner(certificatePinner) + .build() + val serverHandshakeCertificates = HandshakeCertificates.Builder() + .heldCertificate( + phonyVictim, + attackerIntermediate.certificate, + compromisedIntermediate.certificate, + pinnedRoot.certificate + ) + .build() + server.useHttps(serverHandshakeCertificates.sslSocketFactory()) + server.enqueue(MockResponse()) + + // Make a request from client to server. It should not succeed certificate checks. + val request = Request.Builder() + .url(server.url("/")) + .build() + val call = client.newCall(request) + try { + call.execute().use { response -> + fail("expected connection failure but got $response") + } + } catch (expected: SSLHandshakeException) { + } + } + + @Test + fun lonePinnedCertificate() { + val onlyCertificate = HeldCertificate.Builder() + .serialNumber(1L) + .commonName("root") + .build() + val certificatePinner = CertificatePinner.Builder() + .add(server.hostName, pin(onlyCertificate.certificate)) + .build() + val handshakeCertificates = HandshakeCertificates.Builder() + .addTrustedCertificate(onlyCertificate.certificate) + .build() + val client = clientTestRule.newClientBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .certificatePinner(certificatePinner) + .build() + val serverHandshakeCertificates = HandshakeCertificates.Builder() + .heldCertificate(onlyCertificate) + .build() + server.useHttps(serverHandshakeCertificates.sslSocketFactory()) + + // The request should complete successfully. + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call1 = client.newCall( + Request.Builder() + .url(server.url("/")) + .build() + ) + val response1 = call1.execute() + assertThat(response1.body.string()).isEqualTo("abc") + } + + private fun newServerSocketFactory( + heldCertificate: HeldCertificate, + vararg intermediates: X509Certificate + ): SSLSocketFactory { + // Test setup fails on JDK9 + // java.security.KeyStoreException: Certificate chain is not valid + // at sun.security.pkcs12.PKCS12KeyStore.setKeyEntry + // http://openjdk.java.net/jeps/229 + // http://hg.openjdk.java.net/jdk9/jdk9/jdk/file/2c1c21d11e58/src/share/classes/sun/security/pkcs12/PKCS12KeyStore.java#l596 + val keystoreType = if (platform.isJdk9()) "JKS" else null + val x509KeyManager = newKeyManager(keystoreType, heldCertificate, *intermediates) + val trustManager = newTrustManager( + keystoreType, emptyList(), emptyList() + ) + val sslContext = get().newSSLContext() + sslContext.init( + arrayOf(x509KeyManager), arrayOf(trustManager), + SecureRandom() + ) + return sslContext.socketFactory + } +} diff --git a/okhttp/src/test/java/okhttp3/internal/tls/ClientAuthTest.java b/okhttp/src/test/java/okhttp3/internal/tls/ClientAuthTest.java deleted file mode 100644 index 1ae1e8763905..000000000000 --- a/okhttp/src/test/java/okhttp3/internal/tls/ClientAuthTest.java +++ /dev/null @@ -1,374 +0,0 @@ -/* - * Copyright (C) 2016 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3.internal.tls; - -import java.io.IOException; -import java.net.SocketException; -import java.security.GeneralSecurityException; -import java.security.SecureRandom; -import java.security.cert.X509Certificate; -import java.util.Collections; -import java.util.List; -import javax.net.ssl.KeyManager; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLException; -import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.SSLPeerUnverifiedException; -import javax.net.ssl.SSLSocketFactory; -import javax.net.ssl.TrustManager; -import javax.net.ssl.X509KeyManager; -import javax.net.ssl.X509TrustManager; -import javax.security.auth.x500.X500Principal; -import mockwebserver3.MockResponse; -import mockwebserver3.MockWebServer; -import mockwebserver3.junit5.internal.MockWebServerExtension; -import okhttp3.Call; -import okhttp3.OkHttpClient; -import okhttp3.OkHttpClientTestRule; -import okhttp3.RecordingEventListener; -import okhttp3.Request; -import okhttp3.Response; -import okhttp3.internal.http2.ConnectionShutdownException; -import okhttp3.testing.Flaky; -import okhttp3.testing.PlatformRule; -import okhttp3.tls.HandshakeCertificates; -import okhttp3.tls.HeldCertificate; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.api.extension.RegisterExtension; -import org.junitpioneer.jupiter.RetryingTest; - -import static java.util.Arrays.asList; -import static okhttp3.tls.internal.TlsUtil.newKeyManager; -import static okhttp3.tls.internal.TlsUtil.newTrustManager; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -@Tag("Slowish") -@ExtendWith(MockWebServerExtension.class) -public final class ClientAuthTest { - @RegisterExtension public final PlatformRule platform = new PlatformRule(); - @RegisterExtension public final OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - - private MockWebServer server; - private HeldCertificate serverRootCa; - private HeldCertificate serverIntermediateCa; - private HeldCertificate serverCert; - private HeldCertificate clientRootCa; - private HeldCertificate clientIntermediateCa; - private HeldCertificate clientCert; - - @BeforeEach - public void setUp(MockWebServer server) { - this.server = server; - - platform.assumeNotOpenJSSE(); - platform.assumeNotBouncyCastle(); - - serverRootCa = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(1) - .commonName("root") - .addSubjectAlternativeName("root_ca.com") - .build(); - serverIntermediateCa = new HeldCertificate.Builder() - .signedBy(serverRootCa) - .certificateAuthority(0) - .serialNumber(2L) - .commonName("intermediate_ca") - .addSubjectAlternativeName("intermediate_ca.com") - .build(); - - serverCert = new HeldCertificate.Builder() - .signedBy(serverIntermediateCa) - .serialNumber(3L) - .commonName("Local Host") - .addSubjectAlternativeName(server.getHostName()) - .build(); - - clientRootCa = new HeldCertificate.Builder() - .serialNumber(1L) - .certificateAuthority(1) - .commonName("root") - .addSubjectAlternativeName("root_ca.com") - .build(); - clientIntermediateCa = new HeldCertificate.Builder() - .signedBy(serverRootCa) - .certificateAuthority(0) - .serialNumber(2L) - .commonName("intermediate_ca") - .addSubjectAlternativeName("intermediate_ca.com") - .build(); - - clientCert = new HeldCertificate.Builder() - .signedBy(clientIntermediateCa) - .serialNumber(4L) - .commonName("Jethro Willis") - .addSubjectAlternativeName("jethrowillis.com") - .build(); - } - - @Test public void clientAuthForWants() throws Exception { - OkHttpClient client = buildClient(clientCert, clientIntermediateCa.certificate()); - - SSLSocketFactory socketFactory = buildServerSslSocketFactory(); - - server.useHttps(socketFactory); - server.requestClientAuth(); - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - Response response = call.execute(); - assertThat(response.handshake().peerPrincipal()).isEqualTo( - new X500Principal("CN=Local Host")); - assertThat(response.handshake().localPrincipal()).isEqualTo( - new X500Principal("CN=Jethro Willis")); - assertThat(response.body().string()).isEqualTo("abc"); - } - - @Test public void clientAuthForNeeds() throws Exception { - OkHttpClient client = buildClient(clientCert, clientIntermediateCa.certificate()); - - SSLSocketFactory socketFactory = buildServerSslSocketFactory(); - - server.useHttps(socketFactory); - server.requireClientAuth(); - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - Response response = call.execute(); - assertThat(response.handshake().peerPrincipal()).isEqualTo( - new X500Principal("CN=Local Host")); - assertThat(response.handshake().localPrincipal()).isEqualTo( - new X500Principal("CN=Jethro Willis")); - assertThat(response.body().string()).isEqualTo("abc"); - } - - @Test public void clientAuthSkippedForNone() throws Exception { - OkHttpClient client = buildClient(clientCert, clientIntermediateCa.certificate()); - - SSLSocketFactory socketFactory = buildServerSslSocketFactory(); - - server.useHttps(socketFactory); - server.noClientAuth(); - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - Response response = call.execute(); - assertThat(response.handshake().peerPrincipal()).isEqualTo( - new X500Principal("CN=Local Host")); - assertThat(response.handshake().localPrincipal()).isNull(); - assertThat(response.body().string()).isEqualTo("abc"); - } - - @Test public void missingClientAuthSkippedForWantsOnly() throws Exception { - OkHttpClient client = buildClient(null, clientIntermediateCa.certificate()); - - SSLSocketFactory socketFactory = buildServerSslSocketFactory(); - - server.useHttps(socketFactory); - server.requestClientAuth(); - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - Response response = call.execute(); - assertThat(response.handshake().peerPrincipal()).isEqualTo( - new X500Principal("CN=Local Host")); - assertThat(response.handshake().localPrincipal()).isNull(); - assertThat(response.body().string()).isEqualTo("abc"); - } - - @Flaky @RetryingTest(5) - public void missingClientAuthFailsForNeeds() throws Exception { - // Fails with 11.0.1 https://github.com/square/okhttp/issues/4598 - // StreamReset stream was reset: PROT... - - OkHttpClient client = buildClient(null, clientIntermediateCa.certificate()); - - SSLSocketFactory socketFactory = buildServerSslSocketFactory(); - - server.useHttps(socketFactory); - server.requireClientAuth(); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - - try { - call.execute(); - fail(); - } catch (SSLHandshakeException expected) { - // JDK 11+ - } catch (SSLException expected) { - // javax.net.ssl.SSLException: readRecord - } catch (SocketException expected) { - // Conscrypt, JDK 8 (>= 292), JDK 9 - } catch (IOException expected) { - assertThat(expected.getMessage()).isEqualTo("exhausted all routes"); - } - } - - @Test public void commonNameIsNotTrusted() throws Exception { - serverCert = new HeldCertificate.Builder() - .signedBy(serverIntermediateCa) - .serialNumber(3L) - .commonName(server.getHostName()) - .addSubjectAlternativeName("different-host.com") - .build(); - - OkHttpClient client = buildClient(clientCert, clientIntermediateCa.certificate()); - - SSLSocketFactory socketFactory = buildServerSslSocketFactory(); - - server.useHttps(socketFactory); - server.requireClientAuth(); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - - try { - call.execute(); - fail(); - } catch (SSLPeerUnverifiedException expected) { - } - } - - @Test public void invalidClientAuthFails() throws Throwable { - // Fails with https://github.com/square/okhttp/issues/4598 - // StreamReset stream was reset: PROT... - - HeldCertificate clientCert2 = new HeldCertificate.Builder() - .serialNumber(4L) - .commonName("Jethro Willis") - .build(); - - OkHttpClient client = buildClient(clientCert2); - - SSLSocketFactory socketFactory = buildServerSslSocketFactory(); - - server.useHttps(socketFactory); - server.requireClientAuth(); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - - try { - call.execute(); - fail(); - } catch (SSLHandshakeException expected) { - // JDK 11+ - } catch (SSLException expected) { - // javax.net.ssl.SSLException: readRecord - } catch (SocketException expected) { - // Conscrypt, JDK 8 (>= 292), JDK 9 - } catch (ConnectionShutdownException expected) { - // It didn't fail until it reached the application layer. - } catch (IOException expected) { - assertThat(expected.getMessage()).isEqualTo("exhausted all routes"); - } - } - - @Test public void invalidClientAuthEvents() throws Throwable { - server.enqueue(new MockResponse.Builder() - .body("abc") - .build()); - - clientCert = new HeldCertificate.Builder() - .signedBy(clientIntermediateCa) - .serialNumber(4L) - .commonName("Jethro Willis") - .addSubjectAlternativeName("jethrowillis.com") - .validityInterval(1, 2) - .build(); - - OkHttpClient client = buildClient(clientCert, clientIntermediateCa.certificate()); - - RecordingEventListener listener = new RecordingEventListener(); - - client = client.newBuilder() - .eventListener(listener) - .build(); - - SSLSocketFactory socketFactory = buildServerSslSocketFactory(); - - server.useHttps(socketFactory); - server.requireClientAuth(); - - Call call = client.newCall(new Request.Builder().url(server.url("/")).build()); - - try { - call.execute(); - fail(); - } catch (IOException expected) { - } - - // Observed Events are variable - // JDK 14 - // CallStart, ProxySelectStart, ProxySelectEnd, DnsStart, DnsEnd, ConnectStart, SecureConnectStart, - // SecureConnectEnd, ConnectEnd, ConnectionAcquired, RequestHeadersStart, RequestHeadersEnd, - // ResponseFailed, ConnectionReleased, CallFailed - // JDK 8 - // CallStart, ProxySelectStart, ProxySelectEnd, DnsStart, DnsEnd, ConnectStart, SecureConnectStart, - // ConnectFailed, CallFailed - // Gradle - JDK 11 - // CallStart, ProxySelectStart, ProxySelectEnd, DnsStart, DnsEnd, ConnectStart, SecureConnectStart, - // SecureConnectEnd, ConnectFailed, CallFailed - - List recordedEventTypes = listener.recordedEventTypes(); - assertThat(recordedEventTypes).startsWith( - "CallStart", "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", "ConnectStart", "SecureConnectStart"); - assertThat(recordedEventTypes).endsWith("CallFailed"); - } - - private OkHttpClient buildClient( - HeldCertificate heldCertificate, X509Certificate... intermediates) { - HandshakeCertificates.Builder builder = new HandshakeCertificates.Builder() - .addTrustedCertificate(serverRootCa.certificate()); - - if (heldCertificate != null) { - builder.heldCertificate(heldCertificate, intermediates); - } - - HandshakeCertificates handshakeCertificates = builder.build(); - return clientTestRule.newClientBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .build(); - } - - private SSLSocketFactory buildServerSslSocketFactory() { - // The test uses JDK default SSL Context instead of the Platform provided one - // as Conscrypt seems to have some differences, we only want to test client side here. - try { - X509KeyManager keyManager = newKeyManager( - null, serverCert, serverIntermediateCa.certificate()); - X509TrustManager trustManager = newTrustManager(null, - asList(serverRootCa.certificate(), clientRootCa.certificate()), Collections.emptyList()); - SSLContext sslContext = SSLContext.getInstance("TLS"); - sslContext.init(new KeyManager[] {keyManager}, new TrustManager[] {trustManager}, - new SecureRandom()); - return sslContext.getSocketFactory(); - } catch (GeneralSecurityException e) { - throw new AssertionError(e); - } - } -} diff --git a/okhttp/src/test/java/okhttp3/internal/tls/ClientAuthTest.kt b/okhttp/src/test/java/okhttp3/internal/tls/ClientAuthTest.kt new file mode 100644 index 000000000000..8481190d8d82 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/internal/tls/ClientAuthTest.kt @@ -0,0 +1,364 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.tls + +import java.io.IOException +import java.net.SocketException +import java.security.GeneralSecurityException +import java.security.SecureRandom +import java.security.cert.X509Certificate +import java.util.Arrays +import javax.net.ssl.KeyManager +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLException +import javax.net.ssl.SSLHandshakeException +import javax.net.ssl.SSLPeerUnverifiedException +import javax.net.ssl.SSLSocketFactory +import javax.net.ssl.TrustManager +import javax.security.auth.x500.X500Principal +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.junit5.internal.MockWebServerExtension +import okhttp3.OkHttpClient +import okhttp3.OkHttpClientTestRule +import okhttp3.RecordingEventListener +import okhttp3.Request +import okhttp3.internal.http2.ConnectionShutdownException +import okhttp3.testing.Flaky +import okhttp3.testing.PlatformRule +import okhttp3.tls.HandshakeCertificates +import okhttp3.tls.HeldCertificate +import okhttp3.tls.internal.TlsUtil.newKeyManager +import okhttp3.tls.internal.TlsUtil.newTrustManager +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.extension.RegisterExtension +import org.junitpioneer.jupiter.RetryingTest + +@Tag("Slowish") +@ExtendWith(MockWebServerExtension::class) +class ClientAuthTest { + @RegisterExtension + val platform = PlatformRule() + + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + + private lateinit var server: MockWebServer + private lateinit var serverRootCa: HeldCertificate + private lateinit var serverIntermediateCa: HeldCertificate + private lateinit var serverCert: HeldCertificate + private lateinit var clientRootCa: HeldCertificate + private lateinit var clientIntermediateCa: HeldCertificate + private lateinit var clientCert: HeldCertificate + + @BeforeEach + fun setUp(server: MockWebServer) { + this.server = server + platform.assumeNotOpenJSSE() + platform.assumeNotBouncyCastle() + serverRootCa = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(1) + .commonName("root") + .addSubjectAlternativeName("root_ca.com") + .build() + serverIntermediateCa = HeldCertificate.Builder() + .signedBy(serverRootCa) + .certificateAuthority(0) + .serialNumber(2L) + .commonName("intermediate_ca") + .addSubjectAlternativeName("intermediate_ca.com") + .build() + serverCert = HeldCertificate.Builder() + .signedBy(serverIntermediateCa) + .serialNumber(3L) + .commonName("Local Host") + .addSubjectAlternativeName(server.hostName) + .build() + clientRootCa = HeldCertificate.Builder() + .serialNumber(1L) + .certificateAuthority(1) + .commonName("root") + .addSubjectAlternativeName("root_ca.com") + .build() + clientIntermediateCa = HeldCertificate.Builder() + .signedBy(serverRootCa) + .certificateAuthority(0) + .serialNumber(2L) + .commonName("intermediate_ca") + .addSubjectAlternativeName("intermediate_ca.com") + .build() + clientCert = HeldCertificate.Builder() + .signedBy(clientIntermediateCa) + .serialNumber(4L) + .commonName("Jethro Willis") + .addSubjectAlternativeName("jethrowillis.com") + .build() + } + + @Test + fun clientAuthForWants() { + val client = buildClient(clientCert, clientIntermediateCa.certificate) + val socketFactory = buildServerSslSocketFactory() + server.useHttps(socketFactory) + server.requestClientAuth() + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + val response = call.execute() + assertThat(response.handshake!!.peerPrincipal).isEqualTo( + X500Principal("CN=Local Host") + ) + assertThat(response.handshake!!.localPrincipal).isEqualTo( + X500Principal("CN=Jethro Willis") + ) + assertThat(response.body.string()).isEqualTo("abc") + } + + @Test + fun clientAuthForNeeds() { + val client = buildClient(clientCert, clientIntermediateCa.certificate) + val socketFactory = buildServerSslSocketFactory() + server.useHttps(socketFactory) + server.requireClientAuth() + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + val response = call.execute() + assertThat(response.handshake!!.peerPrincipal).isEqualTo( + X500Principal("CN=Local Host") + ) + assertThat(response.handshake!!.localPrincipal).isEqualTo( + X500Principal("CN=Jethro Willis") + ) + assertThat(response.body.string()).isEqualTo("abc") + } + + @Test + fun clientAuthSkippedForNone() { + val client = buildClient(clientCert, clientIntermediateCa.certificate) + val socketFactory = buildServerSslSocketFactory() + server.useHttps(socketFactory) + server.noClientAuth() + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + val response = call.execute() + assertThat(response.handshake!!.peerPrincipal).isEqualTo( + X500Principal("CN=Local Host") + ) + assertThat(response.handshake!!.localPrincipal).isNull() + assertThat(response.body.string()).isEqualTo("abc") + } + + @Test + fun missingClientAuthSkippedForWantsOnly() { + val client = buildClient(null, clientIntermediateCa.certificate) + val socketFactory = buildServerSslSocketFactory() + server.useHttps(socketFactory) + server.requestClientAuth() + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + val response = call.execute() + assertThat(response.handshake!!.peerPrincipal).isEqualTo( + X500Principal("CN=Local Host") + ) + assertThat(response.handshake!!.localPrincipal).isNull() + assertThat(response.body.string()).isEqualTo("abc") + } + + @Flaky + @RetryingTest(5) + fun missingClientAuthFailsForNeeds() { + // Fails with 11.0.1 https://github.com/square/okhttp/issues/4598 + // StreamReset stream was reset: PROT... + val client = buildClient(null, clientIntermediateCa.certificate) + val socketFactory = buildServerSslSocketFactory() + server.useHttps(socketFactory) + server.requireClientAuth() + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + try { + call.execute() + fail() + } catch (expected: SSLHandshakeException) { + // JDK 11+ + } catch (expected: SSLException) { + // javax.net.ssl.SSLException: readRecord + } catch (expected: SocketException) { + // Conscrypt, JDK 8 (>= 292), JDK 9 + } catch (expected: IOException) { + assertThat(expected.message).isEqualTo("exhausted all routes") + } + } + + @Test + fun commonNameIsNotTrusted() { + serverCert = HeldCertificate.Builder() + .signedBy(serverIntermediateCa) + .serialNumber(3L) + .commonName(server.hostName) + .addSubjectAlternativeName("different-host.com") + .build() + val client = buildClient(clientCert, clientIntermediateCa.certificate) + val socketFactory = buildServerSslSocketFactory() + server.useHttps(socketFactory) + server.requireClientAuth() + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + try { + call.execute() + fail() + } catch (expected: SSLPeerUnverifiedException) { + } + } + + @Test + fun invalidClientAuthFails() { + // Fails with https://github.com/square/okhttp/issues/4598 + // StreamReset stream was reset: PROT... + val clientCert2 = HeldCertificate.Builder() + .serialNumber(4L) + .commonName("Jethro Willis") + .build() + val client = buildClient(clientCert2) + val socketFactory = buildServerSslSocketFactory() + server.useHttps(socketFactory) + server.requireClientAuth() + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + try { + call.execute() + fail() + } catch (expected: SSLHandshakeException) { + // JDK 11+ + } catch (expected: SSLException) { + // javax.net.ssl.SSLException: readRecord + } catch (expected: SocketException) { + // Conscrypt, JDK 8 (>= 292), JDK 9 + } catch (expected: ConnectionShutdownException) { + // It didn't fail until it reached the application layer. + } catch (expected: IOException) { + assertThat(expected.message).isEqualTo("exhausted all routes") + } + } + + @Test + fun invalidClientAuthEvents() { + server.enqueue( + MockResponse.Builder() + .body("abc") + .build() + ) + clientCert = HeldCertificate.Builder() + .signedBy(clientIntermediateCa) + .serialNumber(4L) + .commonName("Jethro Willis") + .addSubjectAlternativeName("jethrowillis.com") + .validityInterval(1, 2) + .build() + var client = buildClient(clientCert, clientIntermediateCa.certificate) + val listener = RecordingEventListener() + client = client.newBuilder() + .eventListener(listener) + .build() + val socketFactory = buildServerSslSocketFactory() + server.useHttps(socketFactory) + server.requireClientAuth() + val call = client.newCall(Request.Builder().url(server.url("/")).build()) + try { + call.execute() + fail() + } catch (expected: IOException) { + } + + // Observed Events are variable + // JDK 14 + // CallStart, ProxySelectStart, ProxySelectEnd, DnsStart, DnsEnd, ConnectStart, SecureConnectStart, + // SecureConnectEnd, ConnectEnd, ConnectionAcquired, RequestHeadersStart, RequestHeadersEnd, + // ResponseFailed, ConnectionReleased, CallFailed + // JDK 8 + // CallStart, ProxySelectStart, ProxySelectEnd, DnsStart, DnsEnd, ConnectStart, SecureConnectStart, + // ConnectFailed, CallFailed + // Gradle - JDK 11 + // CallStart, ProxySelectStart, ProxySelectEnd, DnsStart, DnsEnd, ConnectStart, SecureConnectStart, + // SecureConnectEnd, ConnectFailed, CallFailed + val recordedEventTypes = listener.recordedEventTypes() + assertThat(recordedEventTypes).startsWith( + "CallStart", + "ProxySelectStart", + "ProxySelectEnd", + "DnsStart", + "DnsEnd", + "ConnectStart", + "SecureConnectStart" + ) + assertThat(recordedEventTypes).endsWith("CallFailed") + } + + private fun buildClient( + heldCertificate: HeldCertificate?, vararg intermediates: X509Certificate + ): OkHttpClient { + val builder = HandshakeCertificates.Builder() + .addTrustedCertificate(serverRootCa.certificate) + if (heldCertificate != null) { + builder.heldCertificate(heldCertificate, *intermediates) + } + val handshakeCertificates = builder.build() + return clientTestRule.newClientBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .build() + } + + private fun buildServerSslSocketFactory(): SSLSocketFactory { + // The test uses JDK default SSL Context instead of the Platform provided one + // as Conscrypt seems to have some differences, we only want to test client side here. + return try { + val keyManager = newKeyManager( + null, serverCert, serverIntermediateCa.certificate + ) + val trustManager = newTrustManager( + null, + Arrays.asList(serverRootCa.certificate, clientRootCa.certificate), emptyList() + ) + val sslContext = SSLContext.getInstance("TLS") + sslContext.init( + arrayOf(keyManager), arrayOf(trustManager), + SecureRandom() + ) + sslContext.socketFactory + } catch (e: GeneralSecurityException) { + throw AssertionError(e) + } + } +} diff --git a/okhttp/src/test/java/okhttp3/internal/ws/RealWebSocketTest.java b/okhttp/src/test/java/okhttp3/internal/ws/RealWebSocketTest.java deleted file mode 100644 index cccc27c980af..000000000000 --- a/okhttp/src/test/java/okhttp3/internal/ws/RealWebSocketTest.java +++ /dev/null @@ -1,490 +0,0 @@ -/* - * Copyright (C) 2014 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3.internal.ws; - -import java.io.EOFException; -import java.io.IOException; -import java.net.ProtocolException; -import java.net.SocketTimeoutException; -import java.util.Random; -import okhttp3.Headers; -import okhttp3.Protocol; -import okhttp3.Request; -import okhttp3.Response; -import okhttp3.TestUtil; -import okhttp3.internal.concurrent.TaskFaker; -import okio.ByteString; -import okio.Okio; -import okio.Pipe; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import static okhttp3.internal.ws.RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.fail; - -@Tag("Slow") -public final class RealWebSocketTest { - // NOTE: Fields are named 'client' and 'server' for cognitive simplicity. This differentiation has - // zero effect on the behavior of the WebSocket API which is why tests are only written once - // from the perspective of a single peer. - - private final Random random = new Random(0); - private final Pipe client2Server = new Pipe(8192L); - private final Pipe server2client = new Pipe(8192L); - - private final TaskFaker taskFaker = new TaskFaker(); - private final TestStreams client = new TestStreams( - true, taskFaker, server2client, client2Server); - private final TestStreams server = new TestStreams( - false, taskFaker, client2Server, server2client); - - @BeforeEach public void setUp() throws IOException { - client.initWebSocket(random, 0); - server.initWebSocket(random, 0); - } - - @AfterEach public void tearDown() throws Exception { - client.listener.assertExhausted(); - server.listener.assertExhausted(); - server.getSource().close(); - client.getSource().close(); - server.webSocket.tearDown(); - client.webSocket.tearDown(); - taskFaker.close(); - } - - @Test public void close() throws IOException { - client.webSocket.close(1000, "Hello!"); - taskFaker.runTasks(); - // This will trigger a close response. - assertThat(server.processNextFrame()).isFalse(); - server.listener.assertClosing(1000, "Hello!"); - server.webSocket.finishReader(); - server.webSocket.close(1000, "Goodbye!"); - assertThat(client.processNextFrame()).isFalse(); - client.listener.assertClosing(1000, "Goodbye!"); - client.webSocket.finishReader(); - server.listener.assertClosed(1000, "Hello!"); - client.listener.assertClosed(1000, "Goodbye!"); - } - - @Test public void clientCloseThenMethodsReturnFalse() throws IOException { - client.webSocket.close(1000, "Hello!"); - - assertThat(client.webSocket.close(1000, "Hello!")).isFalse(); - assertThat(client.webSocket.send("Hello!")).isFalse(); - } - - @Test public void clientCloseWith0Fails() throws IOException { - try { - client.webSocket.close(0, null); - fail(); - } catch (IllegalArgumentException expected) { - assertThat("Code must be in range [1000,5000): 0").isEqualTo(expected.getMessage()); - } - } - - @Test public void afterSocketClosedPingFailsWebSocket() throws IOException { - client2Server.source().close(); - client.webSocket.pong(ByteString.encodeUtf8("Ping!")); - taskFaker.runTasks(); - client.listener.assertFailure(IOException.class, "source is closed"); - - assertThat(client.webSocket.send("Hello!")).isFalse(); - } - - @Test public void socketClosedDuringMessageKillsWebSocket() throws IOException { - client2Server.source().close(); - - assertThat(client.webSocket.send("Hello!")).isTrue(); - taskFaker.runTasks(); - client.listener.assertFailure(IOException.class, "source is closed"); - - // A failed write prevents further use of the WebSocket instance. - assertThat(client.webSocket.send("Hello!")).isFalse(); - assertThat(client.webSocket.pong(ByteString.encodeUtf8("Ping!"))).isFalse(); - } - - @Test public void serverCloseThenWritingPingSucceeds() throws IOException { - server.webSocket.close(1000, "Hello!"); - taskFaker.runTasks(); - client.processNextFrame(); - client.listener.assertClosing(1000, "Hello!"); - - assertThat(client.webSocket.pong(ByteString.encodeUtf8("Pong?"))).isTrue(); - } - - @Test public void clientCanWriteMessagesAfterServerClose() throws IOException { - server.webSocket.close(1000, "Hello!"); - taskFaker.runTasks(); - client.processNextFrame(); - client.listener.assertClosing(1000, "Hello!"); - - assertThat(client.webSocket.send("Hi!")).isTrue(); - server.processNextFrame(); - server.listener.assertTextMessage("Hi!"); - } - - @Test public void serverCloseThenClientClose() throws IOException { - server.webSocket.close(1000, "Hello!"); - taskFaker.runTasks(); - - client.processNextFrame(); - client.listener.assertClosing(1000, "Hello!"); - assertThat(client.webSocket.close(1000, "Bye!")).isTrue(); - client.webSocket.finishReader(); - client.listener.assertClosed(1000, "Hello!"); - - server.processNextFrame(); - server.listener.assertClosing(1000, "Bye!"); - server.webSocket.finishReader(); - server.listener.assertClosed(1000, "Bye!"); - } - - @Test public void emptyCloseInitiatesShutdown() throws IOException { - server.getSink().write(ByteString.decodeHex("8800")).emit(); // Close without code. - client.processNextFrame(); - client.listener.assertClosing(1005, ""); - client.webSocket.finishReader(); - - assertThat(client.webSocket.close(1000, "Bye!")).isTrue(); - taskFaker.runTasks(); - server.processNextFrame(); - server.listener.assertClosing(1000, "Bye!"); - server.webSocket.finishReader(); - - client.listener.assertClosed(1005, ""); - } - - @Test public void clientCloseClosesConnection() throws IOException { - client.webSocket.close(1000, "Hello!"); - taskFaker.runTasks(); - assertThat(client.closed).isFalse(); - server.processNextFrame(); // Read client closing, send server close. - server.listener.assertClosing(1000, "Hello!"); - server.webSocket.finishReader(); - - server.webSocket.close(1000, "Goodbye!"); - client.processNextFrame(); // Read server closing, close connection. - taskFaker.runTasks(); - client.listener.assertClosing(1000, "Goodbye!"); - client.webSocket.finishReader(); - assertThat(client.closed).isTrue(); - - // Server and client both finished closing, connection is closed. - server.listener.assertClosed(1000, "Hello!"); - client.listener.assertClosed(1000, "Goodbye!"); - } - - @Test public void serverCloseClosesConnection() throws IOException { - server.webSocket.close(1000, "Hello!"); - taskFaker.runTasks(); - - client.processNextFrame(); // Read server close, send client close, close connection. - assertThat(client.closed).isFalse(); - client.listener.assertClosing(1000, "Hello!"); - client.webSocket.finishReader(); - - client.webSocket.close(1000, "Hello!"); - server.processNextFrame(); - server.listener.assertClosing(1000, "Hello!"); - server.webSocket.finishReader(); - - client.listener.assertClosed(1000, "Hello!"); - server.listener.assertClosed(1000, "Hello!"); - } - - @Test public void clientAndServerCloseClosesConnection() throws Exception { - // Send close from both sides at the same time. - server.webSocket.close(1000, "Hello!"); - taskFaker.runTasks(); - client.processNextFrame(); // Read close, close connection close. - - assertThat(client.closed).isFalse(); - client.webSocket.close(1000, "Hi!"); - server.processNextFrame(); - - client.listener.assertClosing(1000, "Hello!"); - server.listener.assertClosing(1000, "Hi!"); - client.webSocket.finishReader(); - server.webSocket.finishReader(); - client.listener.assertClosed(1000, "Hello!"); - server.listener.assertClosed(1000, "Hi!"); - taskFaker.runTasks(); - assertThat(client.closed).isTrue(); - - server.listener.assertExhausted(); // Client should not have sent second close. - client.listener.assertExhausted(); // Server should not have sent second close. - } - - @Test public void serverCloseBreaksReadMessageLoop() throws IOException { - server.webSocket.send("Hello!"); - server.webSocket.close(1000, "Bye!"); - taskFaker.runTasks(); - assertThat(client.processNextFrame()).isTrue(); - client.listener.assertTextMessage("Hello!"); - assertThat(client.processNextFrame()).isFalse(); - client.listener.assertClosing(1000, "Bye!"); - } - - @Test public void protocolErrorBeforeCloseSendsFailure() throws IOException { - server.getSink().write(ByteString.decodeHex("0a00")).emit(); // Invalid non-final ping frame. - - client.processNextFrame(); // Detects error, send close, close connection. - taskFaker.runTasks(); - client.webSocket.finishReader(); - assertThat(client.closed).isTrue(); - client.listener.assertFailure(ProtocolException.class, "Control frames must be final."); - - server.processNextFrame(); - taskFaker.runTasks(); - server.listener.assertFailure(); - } - - @Test public void protocolErrorInCloseResponseClosesConnection() throws IOException { - client.webSocket.close(1000, "Hello"); - taskFaker.runTasks(); - server.processNextFrame(); - // Not closed until close reply is received. - assertThat(client.closed).isFalse(); - - // Manually write an invalid masked close frame. - server.getSink().write(ByteString.decodeHex("888760b420bb635c68de0cd84f")).emit(); - - client.processNextFrame();// Detects error, disconnects immediately since close already sent. - client.webSocket.finishReader(); - assertThat(client.closed).isTrue(); - client.listener.assertFailure( - ProtocolException.class, "Server-sent frames must not be masked."); - - server.listener.assertClosing(1000, "Hello"); - server.listener.assertExhausted(); // Client should not have sent second close. - } - - @Test public void protocolErrorAfterCloseDoesNotSendClose() throws IOException { - client.webSocket.close(1000, "Hello!"); - taskFaker.runTasks(); - server.processNextFrame(); - - // Not closed until close reply is received. - assertThat(client.closed).isFalse(); - server.getSink().write(ByteString.decodeHex("0a00")).emit(); // Invalid non-final ping frame. - - client.processNextFrame(); // Detects error, disconnects immediately since close already sent. - client.webSocket.finishReader(); - taskFaker.runTasks(); - assertThat(client.closed).isTrue(); - client.listener.assertFailure(ProtocolException.class, "Control frames must be final."); - - server.listener.assertClosing(1000, "Hello!"); - - server.listener.assertExhausted(); // Client should not have sent second close. - } - - @Test public void networkErrorReportedAsFailure() throws IOException { - server.getSink().close(); - client.processNextFrame(); - taskFaker.runTasks(); - client.listener.assertFailure(EOFException.class); - } - - @Test public void closeThrowingFailsConnection() throws IOException { - client2Server.source().close(); - client.webSocket.close(1000, null); - taskFaker.runTasks(); - client.listener.assertFailure(IOException.class, "source is closed"); - } - - @Test public void closeMessageAndConnectionCloseThrowingDoesNotMaskOriginal() throws IOException { - // So when the client sends close it throws an IOException. - server.getSource().close(); - - client.webSocket.close(1000, "Bye!"); - taskFaker.runTasks(); - client.webSocket.finishReader(); - client.listener.assertFailure(IOException.class, "source is closed"); - assertThat(client.closed).isTrue(); - } - - @Test public void pingOnInterval() throws IOException { - client.initWebSocket(random, 500); - taskFaker.advanceUntil(ns(500L)); - - server.processNextFrame(); // Ping. - client.processNextFrame(); // Pong. - - taskFaker.advanceUntil(ns(1_000L)); - server.processNextFrame(); // Ping. - client.processNextFrame(); // Pong. - - taskFaker.advanceUntil(ns(1_500L)); - server.processNextFrame(); // Ping. - client.processNextFrame(); // Pong. - } - - @Test public void unacknowledgedPingFailsConnection() throws IOException { - client.initWebSocket(random, 500); - - // Don't process the ping and pong frames! - taskFaker.advanceUntil(ns(500L)); - taskFaker.advanceUntil(ns(1_000L)); - client.listener.assertFailure(SocketTimeoutException.class, - "sent ping but didn't receive pong within 500ms (after 0 successful ping/pongs)"); - } - - @Test public void unexpectedPongsDoNotInterfereWithFailureDetection() throws IOException { - client.initWebSocket(random, 500); - - // At 0ms the server sends 3 unexpected pongs. The client accepts 'em and ignores em. - server.webSocket.pong(ByteString.encodeUtf8("pong 1")); - taskFaker.runTasks(); - client.processNextFrame(); - server.webSocket.pong(ByteString.encodeUtf8("pong 2")); - client.processNextFrame(); - taskFaker.runTasks(); - server.webSocket.pong(ByteString.encodeUtf8("pong 3")); - client.processNextFrame(); - - // After 500ms the client automatically pings and the server pongs back. - taskFaker.advanceUntil(ns(500L)); - server.processNextFrame(); // Ping. - client.processNextFrame(); // Pong. - - // After 1000ms the client will attempt a ping 2, but we don't process it. That'll cause the - // client to fail at 1500ms when it's time to send ping 3 because pong 2 hasn't been received. - taskFaker.advanceUntil(ns(1_000L)); - taskFaker.advanceUntil(ns(1_500L)); - client.listener.assertFailure(SocketTimeoutException.class, - "sent ping but didn't receive pong within 500ms (after 1 successful ping/pongs)"); - } - - @Test public void messagesNotCompressedWhenNotConfigured() throws IOException { - String message = TestUtil.repeat('a', (int) DEFAULT_MINIMUM_DEFLATE_SIZE); - server.webSocket.send(message); - taskFaker.runTasks(); - - assertThat(client.clientSourceBufferSize()).isGreaterThan(message.length()); // Not compressed. - assertThat(client.processNextFrame()).isTrue(); - client.listener.assertTextMessage(message); - } - - @Test public void messagesCompressedWhenConfigured() throws IOException { - Headers headers = Headers.of("Sec-WebSocket-Extensions", "permessage-deflate"); - client.initWebSocket(random, 0, headers); - server.initWebSocket(random, 0, headers); - - String message = TestUtil.repeat('a', (int) DEFAULT_MINIMUM_DEFLATE_SIZE); - server.webSocket.send(message); - - taskFaker.runTasks(); - assertThat(client.clientSourceBufferSize()).isLessThan(message.length()); // Compressed! - assertThat(client.processNextFrame()).isTrue(); - client.listener.assertTextMessage(message); - } - - @Test public void smallMessagesNotCompressed() throws IOException { - Headers headers = Headers.of("Sec-WebSocket-Extensions", "permessage-deflate"); - client.initWebSocket(random, 0, headers); - server.initWebSocket(random, 0, headers); - - String message = TestUtil.repeat('a', (int) DEFAULT_MINIMUM_DEFLATE_SIZE - 1); - server.webSocket.send(message); - taskFaker.runTasks(); - - assertThat(client.clientSourceBufferSize()).isGreaterThan(message.length()); // Not compressed. - assertThat(client.processNextFrame()).isTrue(); - client.listener.assertTextMessage(message); - } - - private static long ns(long millis) { - return millis * 1_000_000L; - } - - /** One peer's streams, listener, and web socket in the test. */ - private static class TestStreams extends RealWebSocket.Streams { - private final String name; - private final WebSocketRecorder listener; - private final TaskFaker taskFaker; - final Pipe sourcePipe; - final Pipe sinkPipe; - private RealWebSocket webSocket; - boolean closed; - - public TestStreams(boolean client, TaskFaker taskFaker, Pipe source, Pipe sink) { - super(client, Okio.buffer(source.source()), Okio.buffer(sink.sink())); - this.name = client ? "client" : "server"; - this.listener = new WebSocketRecorder(name); - this.taskFaker = taskFaker; - this.sourcePipe = source; - this.sinkPipe = sink; - } - - public void initWebSocket(Random random, int pingIntervalMillis) throws IOException { - initWebSocket(random, pingIntervalMillis, Headers.of()); - } - - public void initWebSocket( - Random random, int pingIntervalMillis, Headers responseHeaders) throws IOException { - String url = "http://example.com/websocket"; - Response response = new Response.Builder() - .code(101) - .message("OK") - .request(new Request.Builder().url(url).build()) - .headers(responseHeaders) - .protocol(Protocol.HTTP_1_1) - .build(); - webSocket = new RealWebSocket(taskFaker.getTaskRunner(), response.request(), listener, random, - pingIntervalMillis, WebSocketExtensions.Companion.parse(responseHeaders), - DEFAULT_MINIMUM_DEFLATE_SIZE); - webSocket.initReaderAndWriter(name, this); - } - - /** - * Peeks the number of bytes available for the client to read immediately. This doesn't block so - * it requires that bytes have already been flushed by the server. - */ - public long clientSourceBufferSize() throws IOException { - getSource().request(1L); - return getSource().getBuffer().size(); - } - - public boolean processNextFrame() throws IOException { - return webSocket.processNextFrame(); - } - - @Override public void close() { - if (closed) { - throw new AssertionError("Already closed"); - } - try { - getSource().close(); - } catch (IOException ignored) { - } - try { - getSink().close(); - } catch (IOException ignored) { - } - closed = true; - } - - @Override public void cancel() { - sourcePipe.cancel(); - sinkPipe.cancel(); - } - } -} diff --git a/okhttp/src/test/java/okhttp3/internal/ws/RealWebSocketTest.kt b/okhttp/src/test/java/okhttp3/internal/ws/RealWebSocketTest.kt new file mode 100644 index 000000000000..ee86d660d56d --- /dev/null +++ b/okhttp/src/test/java/okhttp3/internal/ws/RealWebSocketTest.kt @@ -0,0 +1,497 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.ws + +import java.io.EOFException +import java.io.IOException +import java.net.ProtocolException +import java.net.SocketTimeoutException +import java.util.Random +import okhttp3.Headers +import okhttp3.Headers.Companion.headersOf +import okhttp3.Protocol +import okhttp3.Request +import okhttp3.Response +import okhttp3.TestUtil.repeat +import okhttp3.internal.concurrent.TaskFaker +import okhttp3.internal.ws.WebSocketExtensions.Companion.parse +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.encodeUtf8 +import okio.Pipe +import okio.buffer +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test + +@Tag("Slow") +class RealWebSocketTest { + // NOTE: Fields are named 'client' and 'server' for cognitive simplicity. This differentiation has + // zero effect on the behavior of the WebSocket API which is why tests are only written once + // from the perspective of a single peer. + private val random = Random(0) + private val client2Server = Pipe(8192L) + private val server2client = Pipe(8192L) + private val taskFaker = TaskFaker() + private val client = TestStreams(true, taskFaker, server2client, client2Server) + private val server = TestStreams(false, taskFaker, client2Server, server2client) + + @BeforeEach + fun setUp() { + client.initWebSocket(random, 0) + server.initWebSocket(random, 0) + } + + @AfterEach + @Throws(Exception::class) + fun tearDown() { + client.listener.assertExhausted() + server.listener.assertExhausted() + server.source.close() + client.source.close() + server.webSocket!!.tearDown() + client.webSocket!!.tearDown() + taskFaker.close() + } + + @Test + fun close() { + client.webSocket!!.close(1000, "Hello!") + taskFaker.runTasks() + // This will trigger a close response. + assertThat(server.processNextFrame()).isFalse() + server.listener.assertClosing(1000, "Hello!") + server.webSocket!!.finishReader() + server.webSocket!!.close(1000, "Goodbye!") + assertThat(client.processNextFrame()).isFalse() + client.listener.assertClosing(1000, "Goodbye!") + client.webSocket!!.finishReader() + server.listener.assertClosed(1000, "Hello!") + client.listener.assertClosed(1000, "Goodbye!") + } + + @Test + fun clientCloseThenMethodsReturnFalse() { + client.webSocket!!.close(1000, "Hello!") + assertThat(client.webSocket!!.close(1000, "Hello!")).isFalse() + assertThat(client.webSocket!!.send("Hello!")).isFalse() + } + + @Test + fun clientCloseWith0Fails() { + try { + client.webSocket!!.close(0, null) + org.junit.jupiter.api.Assertions.fail() + } catch (expected: IllegalArgumentException) { + assertThat("Code must be in range [1000,5000): 0") + .isEqualTo(expected.message) + } + } + + @Test + fun afterSocketClosedPingFailsWebSocket() { + client2Server.source.close() + client.webSocket!!.pong("Ping!".encodeUtf8()) + taskFaker.runTasks() + client.listener.assertFailure(IOException::class.java, "source is closed") + assertThat(client.webSocket!!.send("Hello!")).isFalse() + } + + @Test + fun socketClosedDuringMessageKillsWebSocket() { + client2Server.source.close() + assertThat(client.webSocket!!.send("Hello!")).isTrue() + taskFaker.runTasks() + client.listener.assertFailure(IOException::class.java, "source is closed") + + // A failed write prevents further use of the WebSocket instance. + assertThat(client.webSocket!!.send("Hello!")).isFalse() + assertThat(client.webSocket!!.pong("Ping!".encodeUtf8())).isFalse() + } + + @Test + fun serverCloseThenWritingPingSucceeds() { + server.webSocket!!.close(1000, "Hello!") + taskFaker.runTasks() + client.processNextFrame() + client.listener.assertClosing(1000, "Hello!") + assertThat(client.webSocket!!.pong("Pong?".encodeUtf8())).isTrue() + } + + @Test + fun clientCanWriteMessagesAfterServerClose() { + server.webSocket!!.close(1000, "Hello!") + taskFaker.runTasks() + client.processNextFrame() + client.listener.assertClosing(1000, "Hello!") + assertThat(client.webSocket!!.send("Hi!")).isTrue() + server.processNextFrame() + server.listener.assertTextMessage("Hi!") + } + + @Test + fun serverCloseThenClientClose() { + server.webSocket!!.close(1000, "Hello!") + taskFaker.runTasks() + client.processNextFrame() + client.listener.assertClosing(1000, "Hello!") + assertThat(client.webSocket!!.close(1000, "Bye!")).isTrue() + client.webSocket!!.finishReader() + client.listener.assertClosed(1000, "Hello!") + server.processNextFrame() + server.listener.assertClosing(1000, "Bye!") + server.webSocket!!.finishReader() + server.listener.assertClosed(1000, "Bye!") + } + + @Test + fun emptyCloseInitiatesShutdown() { + server.sink.write("8800".decodeHex()).emit() // Close without code. + client.processNextFrame() + client.listener.assertClosing(1005, "") + client.webSocket!!.finishReader() + assertThat(client.webSocket!!.close(1000, "Bye!")).isTrue() + taskFaker.runTasks() + server.processNextFrame() + server.listener.assertClosing(1000, "Bye!") + server.webSocket!!.finishReader() + client.listener.assertClosed(1005, "") + } + + @Test + fun clientCloseClosesConnection() { + client.webSocket!!.close(1000, "Hello!") + taskFaker.runTasks() + assertThat(client.closed).isFalse() + server.processNextFrame() // Read client closing, send server close. + server.listener.assertClosing(1000, "Hello!") + server.webSocket!!.finishReader() + server.webSocket!!.close(1000, "Goodbye!") + client.processNextFrame() // Read server closing, close connection. + taskFaker.runTasks() + client.listener.assertClosing(1000, "Goodbye!") + client.webSocket!!.finishReader() + assertThat(client.closed).isTrue() + + // Server and client both finished closing, connection is closed. + server.listener.assertClosed(1000, "Hello!") + client.listener.assertClosed(1000, "Goodbye!") + } + + @Test + fun serverCloseClosesConnection() { + server.webSocket!!.close(1000, "Hello!") + taskFaker.runTasks() + client.processNextFrame() // Read server close, send client close, close connection. + assertThat(client.closed).isFalse() + client.listener.assertClosing(1000, "Hello!") + client.webSocket!!.finishReader() + client.webSocket!!.close(1000, "Hello!") + server.processNextFrame() + server.listener.assertClosing(1000, "Hello!") + server.webSocket!!.finishReader() + client.listener.assertClosed(1000, "Hello!") + server.listener.assertClosed(1000, "Hello!") + } + + @Test + @Throws(Exception::class) + fun clientAndServerCloseClosesConnection() { + // Send close from both sides at the same time. + server.webSocket!!.close(1000, "Hello!") + taskFaker.runTasks() + client.processNextFrame() // Read close, close connection close. + assertThat(client.closed).isFalse() + client.webSocket!!.close(1000, "Hi!") + server.processNextFrame() + client.listener.assertClosing(1000, "Hello!") + server.listener.assertClosing(1000, "Hi!") + client.webSocket!!.finishReader() + server.webSocket!!.finishReader() + client.listener.assertClosed(1000, "Hello!") + server.listener.assertClosed(1000, "Hi!") + taskFaker.runTasks() + assertThat(client.closed).isTrue() + server.listener.assertExhausted() // Client should not have sent second close. + client.listener.assertExhausted() // Server should not have sent second close. + } + + @Test + fun serverCloseBreaksReadMessageLoop() { + server.webSocket!!.send("Hello!") + server.webSocket!!.close(1000, "Bye!") + taskFaker.runTasks() + assertThat(client.processNextFrame()).isTrue() + client.listener.assertTextMessage("Hello!") + assertThat(client.processNextFrame()).isFalse() + client.listener.assertClosing(1000, "Bye!") + } + + @Test + fun protocolErrorBeforeCloseSendsFailure() { + server.sink.write("0a00".decodeHex()).emit() // Invalid non-final ping frame. + client.processNextFrame() // Detects error, send close, close connection. + taskFaker.runTasks() + client.webSocket!!.finishReader() + assertThat(client.closed).isTrue() + client.listener.assertFailure( + ProtocolException::class.java, + "Control frames must be final." + ) + server.processNextFrame() + taskFaker.runTasks() + server.listener.assertFailure() + } + + @Test + fun protocolErrorInCloseResponseClosesConnection() { + client.webSocket!!.close(1000, "Hello") + taskFaker.runTasks() + server.processNextFrame() + // Not closed until close reply is received. + assertThat(client.closed).isFalse() + + // Manually write an invalid masked close frame. + server.sink.write("888760b420bb635c68de0cd84f".decodeHex()).emit() + client.processNextFrame() // Detects error, disconnects immediately since close already sent. + client.webSocket!!.finishReader() + assertThat(client.closed).isTrue() + client.listener.assertFailure( + ProtocolException::class.java, "Server-sent frames must not be masked." + ) + server.listener.assertClosing(1000, "Hello") + server.listener.assertExhausted() // Client should not have sent second close. + } + + @Test + fun protocolErrorAfterCloseDoesNotSendClose() { + client.webSocket!!.close(1000, "Hello!") + taskFaker.runTasks() + server.processNextFrame() + + // Not closed until close reply is received. + assertThat(client.closed).isFalse() + server.sink.write("0a00".decodeHex()).emit() // Invalid non-final ping frame. + client.processNextFrame() // Detects error, disconnects immediately since close already sent. + client.webSocket!!.finishReader() + taskFaker.runTasks() + assertThat(client.closed).isTrue() + client.listener.assertFailure( + ProtocolException::class.java, + "Control frames must be final." + ) + server.listener.assertClosing(1000, "Hello!") + server.listener.assertExhausted() // Client should not have sent second close. + } + + @Test + fun networkErrorReportedAsFailure() { + server.sink.close() + client.processNextFrame() + taskFaker.runTasks() + client.listener.assertFailure(EOFException::class.java) + } + + @Test + fun closeThrowingFailsConnection() { + client2Server.source.close() + client.webSocket!!.close(1000, null) + taskFaker.runTasks() + client.listener.assertFailure(IOException::class.java, "source is closed") + } + + @Test + fun closeMessageAndConnectionCloseThrowingDoesNotMaskOriginal() { + // So when the client sends close it throws an IOException. + server.source.close() + client.webSocket!!.close(1000, "Bye!") + taskFaker.runTasks() + client.webSocket!!.finishReader() + client.listener.assertFailure(IOException::class.java, "source is closed") + assertThat(client.closed).isTrue() + } + + @Test + fun pingOnInterval() { + client.initWebSocket(random, 500) + taskFaker.advanceUntil(ns(500L)) + server.processNextFrame() // Ping. + client.processNextFrame() // Pong. + taskFaker.advanceUntil(ns(1000L)) + server.processNextFrame() // Ping. + client.processNextFrame() // Pong. + taskFaker.advanceUntil(ns(1500L)) + server.processNextFrame() // Ping. + client.processNextFrame() // Pong. + } + + @Test + fun unacknowledgedPingFailsConnection() { + client.initWebSocket(random, 500) + + // Don't process the ping and pong frames! + taskFaker.advanceUntil(ns(500L)) + taskFaker.advanceUntil(ns(1000L)) + client.listener.assertFailure( + SocketTimeoutException::class.java, + "sent ping but didn't receive pong within 500ms (after 0 successful ping/pongs)" + ) + } + + @Test + fun unexpectedPongsDoNotInterfereWithFailureDetection() { + client.initWebSocket(random, 500) + + // At 0ms the server sends 3 unexpected pongs. The client accepts 'em and ignores em. + server.webSocket!!.pong("pong 1".encodeUtf8()) + taskFaker.runTasks() + client.processNextFrame() + server.webSocket!!.pong("pong 2".encodeUtf8()) + client.processNextFrame() + taskFaker.runTasks() + server.webSocket!!.pong("pong 3".encodeUtf8()) + client.processNextFrame() + + // After 500ms the client automatically pings and the server pongs back. + taskFaker.advanceUntil(ns(500L)) + server.processNextFrame() // Ping. + client.processNextFrame() // Pong. + + // After 1000ms the client will attempt a ping 2, but we don't process it. That'll cause the + // client to fail at 1500ms when it's time to send ping 3 because pong 2 hasn't been received. + taskFaker.advanceUntil(ns(1000L)) + taskFaker.advanceUntil(ns(1500L)) + client.listener.assertFailure( + SocketTimeoutException::class.java, + "sent ping but didn't receive pong within 500ms (after 1 successful ping/pongs)" + ) + } + + @Test + fun messagesNotCompressedWhenNotConfigured() { + val message = repeat('a', RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE.toInt()) + server.webSocket!!.send(message) + taskFaker.runTasks() + assertThat(client.clientSourceBufferSize()) + .isGreaterThan(message.length.toLong()) // Not compressed. + assertThat(client.processNextFrame()).isTrue() + client.listener.assertTextMessage(message) + } + + @Test + fun messagesCompressedWhenConfigured() { + val headers = headersOf("Sec-WebSocket-Extensions", "permessage-deflate") + client.initWebSocket(random, 0, headers) + server.initWebSocket(random, 0, headers) + val message = repeat('a', RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE.toInt()) + server.webSocket!!.send(message) + taskFaker.runTasks() + assertThat(client.clientSourceBufferSize()) + .isLessThan(message.length.toLong()) // Compressed! + assertThat(client.processNextFrame()).isTrue() + client.listener.assertTextMessage(message) + } + + @Test + fun smallMessagesNotCompressed() { + val headers = headersOf("Sec-WebSocket-Extensions", "permessage-deflate") + client.initWebSocket(random, 0, headers) + server.initWebSocket(random, 0, headers) + val message = repeat('a', RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE.toInt() - 1) + server.webSocket!!.send(message) + taskFaker.runTasks() + assertThat(client.clientSourceBufferSize()) + .isGreaterThan(message.length.toLong()) // Not compressed. + assertThat(client.processNextFrame()).isTrue() + client.listener.assertTextMessage(message) + } + + /** One peer's streams, listener, and web socket in the test. */ + private class TestStreams( + client: Boolean, + private val taskFaker: TaskFaker, + private val sourcePipe: Pipe, + private val sinkPipe: Pipe, + ) : RealWebSocket.Streams(client, sourcePipe.source.buffer(), sinkPipe.sink.buffer()) { + private val name = if (client) "client" else "server" + val listener = WebSocketRecorder(name) + var webSocket: RealWebSocket? = null + var closed = false + + fun initWebSocket( + random: Random?, + pingIntervalMillis: Int, + responseHeaders: Headers? = headersOf(), + ) { + val url = "http://example.com/websocket" + val response = Response.Builder() + .code(101) + .message("OK") + .request(Request.Builder().url(url).build()) + .headers(responseHeaders!!) + .protocol(Protocol.HTTP_1_1) + .build() + webSocket = RealWebSocket( + taskFaker.taskRunner, response.request, listener, random!!, + pingIntervalMillis.toLong(), parse( + responseHeaders + ), + RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE + ) + webSocket!!.initReaderAndWriter(name, this) + } + + /** + * Peeks the number of bytes available for the client to read immediately. This doesn't block so + * it requires that bytes have already been flushed by the server. + */ + fun clientSourceBufferSize(): Long { + source.request(1L) + return source.buffer.size + } + + fun processNextFrame(): Boolean { + return webSocket!!.processNextFrame() + } + + override fun close() { + if (closed) { + throw AssertionError("Already closed") + } + try { + source.close() + } catch (ignored: IOException) { + } + try { + sink.close() + } catch (ignored: IOException) { + } + closed = true + } + + override fun cancel() { + sourcePipe.cancel() + sinkPipe.cancel() + } + } + + companion object { + private fun ns(millis: Long): Long { + return millis * 1000000L + } + } +} diff --git a/okhttp/src/test/java/okhttp3/internal/ws/WebSocketHttpTest.java b/okhttp/src/test/java/okhttp3/internal/ws/WebSocketHttpTest.java deleted file mode 100644 index d1736b372b2e..000000000000 --- a/okhttp/src/test/java/okhttp3/internal/ws/WebSocketHttpTest.java +++ /dev/null @@ -1,1100 +0,0 @@ -/* - * Copyright (C) 2014 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package okhttp3.internal.ws; - -import java.io.EOFException; -import java.io.IOException; -import java.io.InterruptedIOException; -import java.net.HttpURLConnection; -import java.net.ProtocolException; -import java.net.SocketTimeoutException; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Random; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import mockwebserver3.Dispatcher; -import mockwebserver3.MockResponse; -import mockwebserver3.MockWebServer; -import mockwebserver3.RecordedRequest; -import mockwebserver3.SocketPolicy; -import mockwebserver3.SocketPolicy.KeepOpen; -import mockwebserver3.SocketPolicy.NoResponse; -import okhttp3.OkHttpClient; -import okhttp3.OkHttpClientTestRule; -import okhttp3.Protocol; -import okhttp3.RecordingEventListener; -import okhttp3.RecordingHostnameVerifier; -import okhttp3.Request; -import okhttp3.Response; -import okhttp3.TestLogHandler; -import okhttp3.TestUtil; -import okhttp3.WebSocket; -import okhttp3.WebSocketListener; -import okhttp3.internal.UnreadableResponseBody; -import okhttp3.internal.concurrent.TaskRunner; -import okhttp3.testing.Flaky; -import okhttp3.testing.PlatformRule; -import okhttp3.tls.HandshakeCertificates; -import okio.Buffer; -import okio.ByteString; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; -import static java.util.Arrays.asList; -import static okhttp3.TestUtil.repeat; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.data.Offset.offset; -import static org.junit.jupiter.api.Assertions.fail; - -@Flaky -@Tag("Slow") -public final class WebSocketHttpTest { - // Flaky https://github.com/square/okhttp/issues/4515 - // Flaky https://github.com/square/okhttp/issues/4953 - - @RegisterExtension OkHttpClientTestRule clientTestRule = configureClientTestRule(); - @RegisterExtension PlatformRule platform = new PlatformRule(); - @RegisterExtension TestLogHandler testLogHandler = new TestLogHandler(OkHttpClient.class); - - private MockWebServer webServer; - private final HandshakeCertificates handshakeCertificates - = platform.localhostHandshakeCertificates(); - private final WebSocketRecorder clientListener = new WebSocketRecorder("client"); - private final WebSocketRecorder serverListener = new WebSocketRecorder("server"); - private final Random random = new Random(0); - private OkHttpClient client = clientTestRule.newClientBuilder() - .writeTimeout(Duration.ofMillis(500)) - .readTimeout(Duration.ofMillis(500)) - .addInterceptor(chain -> { - Response response = chain.proceed(chain.request()); - // Ensure application interceptors never see a null body. - assertThat(response.body()).isNotNull(); - return response; - }) - .build(); - - private OkHttpClientTestRule configureClientTestRule() { - OkHttpClientTestRule clientTestRule = new OkHttpClientTestRule(); - clientTestRule.setRecordTaskRunner(true); - return clientTestRule; - } - - @BeforeEach public void setUp(MockWebServer webServer) { - this.webServer = webServer; - - platform.assumeNotOpenJSSE(); - } - - @AfterEach public void tearDown() throws InterruptedException { - clientListener.assertExhausted(); - } - - @Test public void textMessage() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - WebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - webSocket.send("Hello, WebSockets!"); - serverListener.assertTextMessage("Hello, WebSockets!"); - - closeWebSockets(webSocket, server); - } - - @Test public void binaryMessage() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - WebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - webSocket.send(ByteString.encodeUtf8("Hello!")); - serverListener.assertBinaryMessage(ByteString.of(new byte[] {'H', 'e', 'l', 'l', 'o', '!'})); - - closeWebSockets(webSocket, server); - } - - @Test public void nullStringThrows() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - WebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - try { - webSocket.send((String) null); - fail(); - } catch (NullPointerException expected) { - } - - closeWebSockets(webSocket, server); - } - - @Test public void nullByteStringThrows() { - TestUtil.assumeNotWindows(); - - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - WebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - try { - webSocket.send((ByteString) null); - fail(); - } catch (NullPointerException expected) { - } - - closeWebSockets(webSocket, server); - } - - @Test public void serverMessage() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - WebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - server.send("Hello, WebSockets!"); - clientListener.assertTextMessage("Hello, WebSockets!"); - - closeWebSockets(webSocket, server); - } - - @Test public void throwingOnOpenFailsImmediately() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - - final RuntimeException e = new RuntimeException(); - clientListener.setNextEventDelegate(new WebSocketListener() { - @Override public void onOpen(WebSocket webSocket, Response response) { - throw e; - } - }); - newWebSocket(); - - serverListener.assertOpen(); - serverListener.assertFailure(EOFException.class); - serverListener.assertExhausted(); - clientListener.assertFailure(e); - } - - @Disabled("AsyncCall currently lets runtime exceptions propagate.") - @Test public void throwingOnFailLogs() throws Exception { - webServer.enqueue(new MockResponse.Builder() - .code(200) - .body("Body") - .build()); - - final RuntimeException e = new RuntimeException("boom"); - clientListener.setNextEventDelegate(new WebSocketListener() { - @Override public void onFailure(WebSocket webSocket, Throwable t, Response response) { - throw e; - } - }); - - newWebSocket(); - - assertThat(testLogHandler.take()).isEqualTo("INFO: [WS client] onFailure"); - } - - @Test public void throwingOnMessageClosesImmediatelyAndFails() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - final RuntimeException e = new RuntimeException(); - clientListener.setNextEventDelegate(new WebSocketListener() { - @Override public void onMessage(WebSocket webSocket, String text) { - throw e; - } - }); - - server.send("Hello, WebSockets!"); - clientListener.assertFailure(e); - serverListener.assertFailure(EOFException.class); - serverListener.assertExhausted(); - } - - @Test public void throwingOnClosingClosesImmediatelyAndFails() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - final RuntimeException e = new RuntimeException(); - clientListener.setNextEventDelegate(new WebSocketListener() { - @Override public void onClosing(WebSocket webSocket, int code, String reason) { - throw e; - } - }); - - server.close(1000, "bye"); - clientListener.assertFailure(e); - serverListener.assertFailure(); - serverListener.assertExhausted(); - } - - @Test public void unplannedCloseHandledByCloseWithoutFailure() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - clientListener.setNextEventDelegate(new WebSocketListener() { - @Override public void onClosing(WebSocket webSocket, int code, String reason) { - webSocket.close(1000, null); - } - }); - - server.close(1001, "bye"); - clientListener.assertClosed(1001, "bye"); - clientListener.assertExhausted(); - serverListener.assertClosing(1000, ""); - serverListener.assertClosed(1000, ""); - serverListener.assertExhausted(); - } - - @Test public void unplannedCloseHandledWithoutFailure() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - newWebSocket(); - - WebSocket webSocket = clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - closeWebSockets(webSocket, server); - } - - @Test public void non101RetainsBody() throws IOException { - webServer.enqueue(new MockResponse.Builder() - .code(200) - .body("Body") - .build()); - newWebSocket(); - - clientListener.assertFailure(200, "Body", ProtocolException.class, - "Expected HTTP 101 response but was '200 OK'"); - } - - @Test public void notFound() throws IOException { - webServer.enqueue(new MockResponse.Builder() - .status("HTTP/1.1 404 Not Found") - .build()); - newWebSocket(); - - clientListener.assertFailure(404, null, ProtocolException.class, - "Expected HTTP 101 response but was '404 Not Found'"); - } - - @Test public void clientTimeoutClosesBody() { - webServer.enqueue(new MockResponse.Builder() - .code(408) - .build()); - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - WebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - webSocket.send("abc"); - serverListener.assertTextMessage("abc"); - - server.send("def"); - clientListener.assertTextMessage("def"); - - closeWebSockets(webSocket, server); - } - - @Test public void missingConnectionHeader() throws IOException { - webServer.enqueue(new MockResponse.Builder() - .code(101) - .setHeader("Upgrade", "websocket") - .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") - .build()); - webServer.enqueue(new MockResponse.Builder() - .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) - .build()); - - RealWebSocket webSocket = newWebSocket(); - - clientListener.assertFailure(101, null, ProtocolException.class, - "Expected 'Connection' header value 'Upgrade' but was 'null'"); - - webSocket.cancel(); - } - - @Test public void wrongConnectionHeader() throws IOException { - webServer.enqueue(new MockResponse.Builder() - .code(101) - .setHeader("Upgrade", "websocket") - .setHeader("Connection", "Downgrade") - .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") - .build()); - webServer.enqueue(new MockResponse.Builder() - .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) - .build()); - - RealWebSocket webSocket = newWebSocket(); - - clientListener.assertFailure(101, null, ProtocolException.class, - "Expected 'Connection' header value 'Upgrade' but was 'Downgrade'"); - - webSocket.cancel(); - } - - @Test public void missingUpgradeHeader() throws IOException { - webServer.enqueue(new MockResponse.Builder() - .code(101) - .setHeader("Connection", "Upgrade") - .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") - .build()); - webServer.enqueue(new MockResponse.Builder() - .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) - .build()); - - RealWebSocket webSocket = newWebSocket(); - - clientListener.assertFailure(101, null, ProtocolException.class, - "Expected 'Upgrade' header value 'websocket' but was 'null'"); - - webSocket.cancel(); - } - - @Test public void wrongUpgradeHeader() throws IOException { - webServer.enqueue(new MockResponse.Builder() - .code(101) - .setHeader("Connection", "Upgrade") - .setHeader("Upgrade", "Pepsi") - .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") - .build()); - webServer.enqueue(new MockResponse.Builder() - .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) - .build()); - - RealWebSocket webSocket = newWebSocket(); - - clientListener.assertFailure(101, null, ProtocolException.class, - "Expected 'Upgrade' header value 'websocket' but was 'Pepsi'"); - - webSocket.cancel(); - } - - @Test public void missingMagicHeader() throws IOException { - webServer.enqueue(new MockResponse.Builder() - .code(101) - .setHeader("Connection", "Upgrade") - .setHeader("Upgrade", "websocket") - .build()); - webServer.enqueue(new MockResponse.Builder() - .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) - .build()); - - RealWebSocket webSocket = newWebSocket(); - - clientListener.assertFailure(101, null, ProtocolException.class, - "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'"); - - webSocket.cancel(); - } - - @Test public void wrongMagicHeader() throws IOException { - webServer.enqueue(new MockResponse.Builder() - .code(101) - .setHeader("Connection", "Upgrade") - .setHeader("Upgrade", "websocket") - .setHeader("Sec-WebSocket-Accept", "magic") - .build()); - webServer.enqueue(new MockResponse.Builder() - .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) - .build()); - - RealWebSocket webSocket = newWebSocket(); - - clientListener.assertFailure(101, null, ProtocolException.class, - "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'"); - - webSocket.cancel(); - } - - @Test public void clientIncludesForbiddenHeader() throws IOException { - newWebSocket(new Request.Builder() - .url(webServer.url("/")) - .header("Sec-WebSocket-Extensions", "permessage-deflate") - .build()); - - clientListener.assertFailure(ProtocolException.class, - "Request header not permitted: 'Sec-WebSocket-Extensions'"); - } - - @SuppressWarnings("KotlinInternalInJava") - @Test public void webSocketAndApplicationInterceptors() { - final AtomicInteger interceptedCount = new AtomicInteger(); - - client = client.newBuilder() - .addInterceptor(chain -> { - assertThat(chain.request().body()).isNull(); - Response response = chain.proceed(chain.request()); - assertThat(response.header("Connection")).isEqualTo("Upgrade"); - assertThat(response.body()).isInstanceOf(UnreadableResponseBody.class); - interceptedCount.incrementAndGet(); - return response; - }) - .build(); - - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - - WebSocket webSocket = newWebSocket(); - clientListener.assertOpen(); - assertThat(interceptedCount.get()).isEqualTo(1); - - closeWebSockets(webSocket, serverListener.assertOpen()); - } - - @Test public void webSocketAndNetworkInterceptors() { - client = client.newBuilder() - .addNetworkInterceptor(chain -> { - throw new AssertionError(); // Network interceptors don't execute. - }) - .build(); - - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - - WebSocket webSocket = newWebSocket(); - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - closeWebSockets(webSocket, server); - } - - @Test public void overflowOutgoingQueue() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - - WebSocket webSocket = newWebSocket(); - clientListener.assertOpen(); - - // Send messages until the client's outgoing buffer overflows! - ByteString message = ByteString.of(new byte[1024 * 1024]); - long messageCount = 0; - while (true) { - boolean success = webSocket.send(message); - if (!success) break; - - messageCount++; - long queueSize = webSocket.queueSize(); - assertThat(queueSize).isBetween(0L, messageCount * message.size()); - // Expect to fail before enqueueing 32 MiB. - assertThat(messageCount).isLessThan(32L); - } - - // Confirm all sent messages were received, followed by a client-initiated close. - WebSocket server = serverListener.assertOpen(); - for (int i = 0; i < messageCount; i++) { - serverListener.assertBinaryMessage(message); - } - serverListener.assertClosing(1001, ""); - - // When the server acknowledges the close the connection shuts down gracefully. - server.close(1000, null); - clientListener.assertClosing(1000, ""); - clientListener.assertClosed(1000, ""); - serverListener.assertClosed(1001, ""); - } - - @Test public void closeReasonMaximumLength() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - - String clientReason = repeat('C', 123); - String serverReason = repeat('S', 123); - - WebSocket webSocket = newWebSocket(); - WebSocket server = serverListener.assertOpen(); - - clientListener.assertOpen(); - webSocket.close(1000, clientReason); - serverListener.assertClosing(1000, clientReason); - - server.close(1000, serverReason); - clientListener.assertClosing(1000, serverReason); - clientListener.assertClosed(1000, serverReason); - - serverListener.assertClosed(1000, clientReason); - } - - @Test public void closeReasonTooLong() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - - WebSocket webSocket = newWebSocket(); - WebSocket server = serverListener.assertOpen(); - - clientListener.assertOpen(); - String reason = repeat('X', 124); - try { - webSocket.close(1000, reason); - fail(); - } catch (IllegalArgumentException expected) { - assertThat(expected.getMessage()).isEqualTo(("reason.size() > 123: " + reason)); - } - - webSocket.close(1000, null); - serverListener.assertClosing(1000, ""); - - server.close(1000, null); - clientListener.assertClosing(1000, ""); - clientListener.assertClosed(1000, ""); - - serverListener.assertClosed(1000, ""); - } - - @Test public void wsScheme() { - TestUtil.assumeNotWindows(); - - websocketScheme("ws"); - } - - @Test public void wsUppercaseScheme() { - websocketScheme("WS"); - } - - @Test public void wssScheme() { - webServer.useHttps(handshakeCertificates.sslSocketFactory()); - client = client.newBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .build(); - - websocketScheme("wss"); - } - - @Test public void httpsScheme() { - webServer.useHttps(handshakeCertificates.sslSocketFactory()); - client = client.newBuilder() - .sslSocketFactory( - handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager()) - .hostnameVerifier(new RecordingHostnameVerifier()) - .build(); - - websocketScheme("https"); - } - - @Test public void readTimeoutAppliesToHttpRequest() { - webServer.enqueue(new MockResponse.Builder() - .socketPolicy(NoResponse.INSTANCE) - .build()); - - WebSocket webSocket = newWebSocket(); - - clientListener.assertFailure(SocketTimeoutException.class, "timeout", "Read timed out"); - assertThat(webSocket.close(1000, null)).isFalse(); - } - - /** - * There's no read timeout when reading the first byte of a new frame. But as soon as we start - * reading a frame we enable the read timeout. In this test we have the server returning the first - * byte of a frame but no more frames. - */ - @Test public void readTimeoutAppliesWithinFrames() { - webServer.setDispatcher(new Dispatcher() { - @Override public MockResponse dispatch(RecordedRequest request) { - return upgradeResponse(request) - .body(new Buffer().write(ByteString.decodeHex("81"))) // Truncated frame. - .removeHeader("Content-Length") - .socketPolicy(KeepOpen.INSTANCE) - .build(); - } - }); - - WebSocket webSocket = newWebSocket(); - clientListener.assertOpen(); - - clientListener.assertFailure(SocketTimeoutException.class, "timeout", "Read timed out"); - assertThat(webSocket.close(1000, null)).isFalse(); - } - - @Test public void readTimeoutDoesNotApplyAcrossFrames() throws Exception { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - WebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - // Sleep longer than the HTTP client's read timeout. - Thread.sleep(client.readTimeoutMillis() + 500); - - server.send("abc"); - clientListener.assertTextMessage("abc"); - - closeWebSockets(webSocket, server); - } - - @Test public void clientPingsServerOnInterval() throws Exception { - client = client.newBuilder() - .pingInterval(Duration.ofMillis(500)) - .build(); - - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - RealWebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - RealWebSocket server = (RealWebSocket) serverListener.assertOpen(); - - long startNanos = System.nanoTime(); - while (webSocket.receivedPongCount() < 3) { - Thread.sleep(50); - } - - long elapsedUntilPong3 = System.nanoTime() - startNanos; - assertThat(TimeUnit.NANOSECONDS.toMillis(elapsedUntilPong3)) - .isCloseTo(1500L, offset(250L)); - - // The client pinged the server 3 times, and it has ponged back 3 times. - assertThat(webSocket.sentPingCount()).isEqualTo(3); - assertThat(server.receivedPingCount()).isEqualTo(3); - assertThat(webSocket.receivedPongCount()).isEqualTo(3); - - // The server has never pinged the client. - assertThat(server.receivedPongCount()).isEqualTo(0); - assertThat(webSocket.receivedPingCount()).isEqualTo(0); - - closeWebSockets(webSocket, server); - } - - @Test public void clientDoesNotPingServerByDefault() throws Exception { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - RealWebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - RealWebSocket server = (RealWebSocket) serverListener.assertOpen(); - - Thread.sleep(1000); - - // No pings and no pongs. - assertThat(webSocket.sentPingCount()).isEqualTo(0); - assertThat(webSocket.receivedPingCount()).isEqualTo(0); - assertThat(webSocket.receivedPongCount()).isEqualTo(0); - assertThat(server.sentPingCount()).isEqualTo(0); - assertThat(server.receivedPingCount()).isEqualTo(0); - assertThat(server.receivedPongCount()).isEqualTo(0); - - closeWebSockets(webSocket, server); - } - - /** - * Configure the websocket to send pings every 500 ms. Artificially prevent the server from - * responding to pings. The client should give up when attempting to send its 2nd ping, at about - * 1000 ms. - */ - @Test public void unacknowledgedPingFailsConnection() { - TestUtil.assumeNotWindows(); - - client = client.newBuilder() - .pingInterval(Duration.ofMillis(500)) - .build(); - - // Stall in onOpen to prevent pongs from being sent. - final CountDownLatch latch = new CountDownLatch(1); - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(new WebSocketListener() { - @Override public void onOpen(WebSocket webSocket, Response response) { - try { - latch.await(); // The server can't respond to pings! - } catch (InterruptedException e) { - throw new AssertionError(e); - } - } - }) - .build()); - - long openAtNanos = System.nanoTime(); - newWebSocket(); - clientListener.assertOpen(); - clientListener.assertFailure(SocketTimeoutException.class, - "sent ping but didn't receive pong within 500ms (after 0 successful ping/pongs)"); - latch.countDown(); - - long elapsedUntilFailure = System.nanoTime() - openAtNanos; - assertThat(TimeUnit.NANOSECONDS.toMillis(elapsedUntilFailure)) - .isCloseTo(1000L, offset(250L)); - } - - /** https://github.com/square/okhttp/issues/2788 */ - @Test public void clientCancelsIfCloseIsNotAcknowledged() { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - RealWebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - // Initiate a close on the client, which will schedule a hard cancel in 500 ms. - long closeAtNanos = System.nanoTime(); - webSocket.close(1000, "goodbye", 500L); - serverListener.assertClosing(1000, "goodbye"); - - // Confirm that the hard cancel occurred after 500 ms. - clientListener.assertFailure(); - long elapsedUntilFailure = System.nanoTime() - closeAtNanos; - assertThat(TimeUnit.NANOSECONDS.toMillis(elapsedUntilFailure)) - .isCloseTo(500L, offset(250L)); - - // Close the server and confirm it saw what we expected. - server.close(1000, null); - serverListener.assertClosed(1000, "goodbye"); - } - - @Test public void webSocketsDontTriggerEventListener() { - RecordingEventListener listener = new RecordingEventListener(); - - client = client.newBuilder() - .eventListenerFactory(clientTestRule.wrap(listener)) - .build(); - - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - WebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - webSocket.send("Web Sockets and Events?!"); - serverListener.assertTextMessage("Web Sockets and Events?!"); - - webSocket.close(1000, ""); - serverListener.assertClosing(1000, ""); - - server.close(1000, ""); - clientListener.assertClosing(1000, ""); - clientListener.assertClosed(1000, ""); - serverListener.assertClosed(1000, ""); - - assertThat(listener.recordedEventTypes()).isEmpty(); - } - - @Test public void callTimeoutAppliesToSetup() throws Exception { - webServer.enqueue(new MockResponse.Builder() - .headersDelay(500, TimeUnit.MILLISECONDS) - .build()); - - client = client.newBuilder() - .readTimeout(Duration.ZERO) - .writeTimeout(Duration.ZERO) - .callTimeout(Duration.ofMillis(100)) - .build(); - - newWebSocket(); - clientListener.assertFailure(InterruptedIOException.class, "timeout"); - } - - @Test public void callTimeoutDoesNotApplyOnceConnected() throws Exception { - client = client.newBuilder() - .callTimeout(Duration.ofMillis(100)) - .build(); - - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - WebSocket webSocket = newWebSocket(); - - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - Thread.sleep(500); - - server.send("Hello, WebSockets!"); - clientListener.assertTextMessage("Hello, WebSockets!"); - - closeWebSockets(webSocket, server); - } - - /** - * We had a bug where web socket connections were leaked if the HTTP connection upgrade was not - * successful. This test confirms that connections are released back to the connection pool! - * https://github.com/square/okhttp/issues/4258 - */ - @Test public void webSocketConnectionIsReleased() throws Exception { - // This test assumes HTTP/1.1 pooling semantics. - client = client.newBuilder() - .protocols(asList(Protocol.HTTP_1_1)) - .build(); - - webServer.enqueue(new MockResponse.Builder() - .code(HttpURLConnection.HTTP_NOT_FOUND) - .body("not found!") - .build()); - webServer.enqueue(new MockResponse()); - - newWebSocket(); - clientListener.assertFailure(); - - Request regularRequest = new Request.Builder() - .url(webServer.url("/")) - .build(); - Response response = client.newCall(regularRequest).execute(); - response.close(); - - assertThat(webServer.takeRequest().getSequenceNumber()).isEqualTo(0); - assertThat(webServer.takeRequest().getSequenceNumber()).isEqualTo(1); - } - - /** https://github.com/square/okhttp/issues/5705 */ - @Test public void closeWithoutSuccessfulConnect() { - Request request = new Request.Builder() - .url(webServer.url("/")) - .build(); - WebSocket webSocket = client.newWebSocket(request, clientListener); - webSocket.send("hello"); - webSocket.close(1000, null); - } - - /** https://github.com/square/okhttp/issues/7768 */ - @Test public void reconnectingToNonWebSocket() throws InterruptedException { - for (int i = 0; i < 30; i++) { - webServer.enqueue(new MockResponse.Builder() - .bodyDelay(100, TimeUnit.MILLISECONDS) - .body("Wrong endpoint") - .code(401) - .build()); - } - - Request request = new Request.Builder() - .url(webServer.url("/")) - .build(); - - CountDownLatch attempts = new CountDownLatch(20); - - List webSockets = Collections.synchronizedList(new ArrayList<>()); - - WebSocketListener reconnectOnFailure = new WebSocketListener() { - @Override - public void onFailure(WebSocket webSocket, Throwable t, Response response) { - if (attempts.getCount() > 0) { - clientListener.setNextEventDelegate(this); - webSockets.add(client.newWebSocket(request, clientListener)); - attempts.countDown(); - } - } - }; - - clientListener.setNextEventDelegate(reconnectOnFailure); - - webSockets.add(client.newWebSocket(request, clientListener)); - - attempts.await(); - - synchronized (webSockets) { - for (WebSocket webSocket : webSockets) { - webSocket.cancel(); - } - } - } - - @Test public void compressedMessages() throws Exception { - successfulExtensions("permessage-deflate"); - } - - @Test public void compressedMessagesNoClientContextTakeover() throws Exception { - successfulExtensions("permessage-deflate; client_no_context_takeover"); - } - - @Test public void compressedMessagesNoServerContextTakeover() throws Exception { - successfulExtensions("permessage-deflate; server_no_context_takeover"); - } - - @Test public void unexpectedExtensionParameter() throws Exception { - extensionNegotiationFailure("permessage-deflate; unknown_parameter=15"); - } - - @Test public void clientMaxWindowBitsIncluded() throws Exception { - extensionNegotiationFailure("permessage-deflate; client_max_window_bits=15"); - } - - @Test public void serverMaxWindowBitsTooLow() throws Exception { - extensionNegotiationFailure("permessage-deflate; server_max_window_bits=7"); - } - - @Test public void serverMaxWindowBitsTooHigh() throws Exception { - extensionNegotiationFailure("permessage-deflate; server_max_window_bits=16"); - } - - @Test public void serverMaxWindowBitsJustRight() throws Exception { - successfulExtensions("permessage-deflate; server_max_window_bits=15"); - } - - private void successfulExtensions(String extensionsHeader) throws Exception { - webServer.enqueue(new MockResponse.Builder() - .addHeader("Sec-WebSocket-Extensions", extensionsHeader) - .webSocketUpgrade(serverListener) - .build()); - - WebSocket client = newWebSocket(); - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - // Server to client message big enough to be compressed. - String message1 = TestUtil.repeat('a', (int) RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE); - server.send(message1); - clientListener.assertTextMessage(message1); - - // Client to server message big enough to be compressed. - String message2 = TestUtil.repeat('b', (int) RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE); - client.send(message2); - serverListener.assertTextMessage(message2); - - // Empty server to client message. - String message3 = ""; - server.send(message3); - clientListener.assertTextMessage(message3); - - // Empty client to server message. - String message4 = ""; - client.send(message4); - serverListener.assertTextMessage(message4); - - // Server to client message that shares context with message1. - String message5 = message1 + message1; - server.send(message5); - clientListener.assertTextMessage(message5); - - // Client to server message that shares context with message2. - String message6 = message2 + message2; - client.send(message6); - serverListener.assertTextMessage(message6); - - closeWebSockets(client, server); - - RecordedRequest upgradeRequest = webServer.takeRequest(); - assertThat(upgradeRequest.getHeaders().get("Sec-WebSocket-Extensions")) - .isEqualTo("permessage-deflate"); - } - - private void extensionNegotiationFailure(String extensionsHeader) throws Exception { - webServer.enqueue(new MockResponse.Builder() - .addHeader("Sec-WebSocket-Extensions", extensionsHeader) - .webSocketUpgrade(serverListener) - .build()); - - newWebSocket(); - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - String clientReason = "unexpected Sec-WebSocket-Extensions in response header"; - serverListener.assertClosing(1010, clientReason); - server.close(1010, ""); - clientListener.assertClosing(1010, ""); - clientListener.assertClosed(1010, ""); - serverListener.assertClosed(1010, clientReason); - clientListener.assertExhausted(); - serverListener.assertExhausted(); - } - - private MockResponse.Builder upgradeResponse(RecordedRequest request) { - String key = request.getHeaders().get("Sec-WebSocket-Key"); - return new MockResponse.Builder() - .status("HTTP/1.1 101 Switching Protocols") - .setHeader("Connection", "Upgrade") - .setHeader("Upgrade", "websocket") - .setHeader("Sec-WebSocket-Accept", WebSocketProtocol.INSTANCE.acceptHeader(key)); - } - - private void websocketScheme(String scheme) { - webServer.enqueue(new MockResponse.Builder() - .webSocketUpgrade(serverListener) - .build()); - - Request request = new Request.Builder() - .url(scheme + "://" + webServer.getHostName() + ":" + webServer.getPort() + "/") - .build(); - - RealWebSocket webSocket = newWebSocket(request); - clientListener.assertOpen(); - WebSocket server = serverListener.assertOpen(); - - webSocket.send("abc"); - serverListener.assertTextMessage("abc"); - - closeWebSockets(webSocket, server); - } - - private RealWebSocket newWebSocket() { - return newWebSocket(new Request.Builder().get().url(webServer.url("/")).build()); - } - - private RealWebSocket newWebSocket(Request request) { - RealWebSocket webSocket = new RealWebSocket(TaskRunner.INSTANCE, request, clientListener, - random, client.pingIntervalMillis(), null, 0L); - webSocket.connect(client); - return webSocket; - } - - private void closeWebSockets(WebSocket client, WebSocket server) { - server.close(1001, ""); - clientListener.assertClosing(1001, ""); - client.close(1000, ""); - serverListener.assertClosing(1000, ""); - clientListener.assertClosed(1001, ""); - serverListener.assertClosed(1000, ""); - clientListener.assertExhausted(); - serverListener.assertExhausted(); - } -} diff --git a/okhttp/src/test/java/okhttp3/internal/ws/WebSocketHttpTest.kt b/okhttp/src/test/java/okhttp3/internal/ws/WebSocketHttpTest.kt new file mode 100644 index 000000000000..2c28e9c44359 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/internal/ws/WebSocketHttpTest.kt @@ -0,0 +1,1143 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.ws + +import java.io.EOFException +import java.io.IOException +import java.io.InterruptedIOException +import java.net.HttpURLConnection +import java.net.ProtocolException +import java.net.SocketTimeoutException +import java.time.Duration +import java.util.Arrays +import java.util.Collections +import java.util.Random +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger +import mockwebserver3.Dispatcher +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.RecordedRequest +import mockwebserver3.SocketPolicy +import mockwebserver3.SocketPolicy.KeepOpen +import mockwebserver3.SocketPolicy.NoResponse +import okhttp3.Interceptor +import okhttp3.OkHttpClient +import okhttp3.OkHttpClientTestRule +import okhttp3.Protocol +import okhttp3.RecordingEventListener +import okhttp3.RecordingHostnameVerifier +import okhttp3.Request +import okhttp3.Response +import okhttp3.TestLogHandler +import okhttp3.TestUtil.assumeNotWindows +import okhttp3.TestUtil.repeat +import okhttp3.WebSocket +import okhttp3.WebSocketListener +import okhttp3.internal.UnreadableResponseBody +import okhttp3.internal.concurrent.TaskRunner +import okhttp3.internal.ws.WebSocketProtocol.acceptHeader +import okhttp3.testing.Flaky +import okhttp3.testing.PlatformRule +import okio.Buffer +import okio.ByteString +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.encodeUtf8 +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.data.Offset +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +@Flaky +@Tag("Slow") +class WebSocketHttpTest { + // Flaky https://github.com/square/okhttp/issues/4515 + // Flaky https://github.com/square/okhttp/issues/4953 + @RegisterExtension + var clientTestRule = configureClientTestRule() + + @RegisterExtension + var platform = PlatformRule() + + @RegisterExtension + var testLogHandler = TestLogHandler(OkHttpClient::class.java) + private lateinit var webServer: MockWebServer + private val handshakeCertificates = platform.localhostHandshakeCertificates() + private val clientListener = WebSocketRecorder("client") + private val serverListener = WebSocketRecorder("server") + private val random = Random(0) + private var client = clientTestRule.newClientBuilder() + .writeTimeout(Duration.ofMillis(500)) + .readTimeout(Duration.ofMillis(500)) + .addInterceptor(Interceptor { chain: Interceptor.Chain -> + val response = chain.proceed(chain.request()) + // Ensure application interceptors never see a null body. + assertThat(response.body).isNotNull() + response + }) + .build() + + private fun configureClientTestRule(): OkHttpClientTestRule { + val clientTestRule = OkHttpClientTestRule() + clientTestRule.recordTaskRunner = true + return clientTestRule + } + + @BeforeEach + fun setUp(webServer: MockWebServer) { + this.webServer = webServer + platform.assumeNotOpenJSSE() + } + + @AfterEach + @Throws(InterruptedException::class) + fun tearDown() { + clientListener.assertExhausted() + } + + @Test + fun textMessage() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + webSocket.send("Hello, WebSockets!") + serverListener.assertTextMessage("Hello, WebSockets!") + closeWebSockets(webSocket, server) + } + + @Test + fun binaryMessage() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + webSocket.send("Hello!".encodeUtf8()) + serverListener.assertBinaryMessage("Hello!".encodeUtf8()) + closeWebSockets(webSocket, server) + } + + @Test + fun serverMessage() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + server.send("Hello, WebSockets!") + clientListener.assertTextMessage("Hello, WebSockets!") + closeWebSockets(webSocket, server) + } + + @Test + fun throwingOnOpenFailsImmediately() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val e = RuntimeException() + clientListener.setNextEventDelegate(object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + throw e + } + }) + newWebSocket() + serverListener.assertOpen() + serverListener.assertFailure(EOFException::class.java) + serverListener.assertExhausted() + clientListener.assertFailure(e) + } + + @Disabled("AsyncCall currently lets runtime exceptions propagate.") + @Test + @Throws( + Exception::class + ) + fun throwingOnFailLogs() { + webServer.enqueue( + MockResponse.Builder() + .code(200) + .body("Body") + .build() + ) + val e = RuntimeException("boom") + clientListener.setNextEventDelegate(object : WebSocketListener() { + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + throw e + } + }) + newWebSocket() + assertThat(testLogHandler.take()).isEqualTo("INFO: [WS client] onFailure") + } + + @Test + fun throwingOnMessageClosesImmediatelyAndFails() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + val e = RuntimeException() + clientListener.setNextEventDelegate(object : WebSocketListener() { + override fun onMessage(webSocket: WebSocket, text: String) { + throw e + } + }) + server.send("Hello, WebSockets!") + clientListener.assertFailure(e) + serverListener.assertFailure(EOFException::class.java) + serverListener.assertExhausted() + } + + @Test + fun throwingOnClosingClosesImmediatelyAndFails() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + val e = RuntimeException() + clientListener.setNextEventDelegate(object : WebSocketListener() { + override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { + throw e + } + }) + server.close(1000, "bye") + clientListener.assertFailure(e) + serverListener.assertFailure() + serverListener.assertExhausted() + } + + @Test + fun unplannedCloseHandledByCloseWithoutFailure() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + clientListener.setNextEventDelegate(object : WebSocketListener() { + override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { + webSocket.close(1000, null) + } + }) + server.close(1001, "bye") + clientListener.assertClosed(1001, "bye") + clientListener.assertExhausted() + serverListener.assertClosing(1000, "") + serverListener.assertClosed(1000, "") + serverListener.assertExhausted() + } + + @Test + fun unplannedCloseHandledWithoutFailure() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + newWebSocket() + val webSocket = clientListener.assertOpen() + val server = serverListener.assertOpen() + closeWebSockets(webSocket, server) + } + + @Test + @Throws(IOException::class) + fun non101RetainsBody() { + webServer.enqueue( + MockResponse.Builder() + .code(200) + .body("Body") + .build() + ) + newWebSocket() + clientListener.assertFailure( + 200, "Body", ProtocolException::class.java, + "Expected HTTP 101 response but was '200 OK'" + ) + } + + @Test + @Throws(IOException::class) + fun notFound() { + webServer.enqueue( + MockResponse.Builder() + .status("HTTP/1.1 404 Not Found") + .build() + ) + newWebSocket() + clientListener.assertFailure( + 404, null, ProtocolException::class.java, + "Expected HTTP 101 response but was '404 Not Found'" + ) + } + + @Test + fun clientTimeoutClosesBody() { + webServer.enqueue( + MockResponse.Builder() + .code(408) + .build() + ) + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + webSocket.send("abc") + serverListener.assertTextMessage("abc") + server.send("def") + clientListener.assertTextMessage("def") + closeWebSockets(webSocket, server) + } + + @Test + @Throws(IOException::class) + fun missingConnectionHeader() { + webServer.enqueue( + MockResponse.Builder() + .code(101) + .setHeader("Upgrade", "websocket") + .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") + .build() + ) + webServer.enqueue( + MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart) + .build() + ) + val webSocket = newWebSocket() + clientListener.assertFailure( + 101, null, ProtocolException::class.java, + "Expected 'Connection' header value 'Upgrade' but was 'null'" + ) + webSocket.cancel() + } + + @Test + @Throws(IOException::class) + fun wrongConnectionHeader() { + webServer.enqueue( + MockResponse.Builder() + .code(101) + .setHeader("Upgrade", "websocket") + .setHeader("Connection", "Downgrade") + .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") + .build() + ) + webServer.enqueue( + MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart) + .build() + ) + val webSocket = newWebSocket() + clientListener.assertFailure( + 101, null, ProtocolException::class.java, + "Expected 'Connection' header value 'Upgrade' but was 'Downgrade'" + ) + webSocket.cancel() + } + + @Test + @Throws(IOException::class) + fun missingUpgradeHeader() { + webServer.enqueue( + MockResponse.Builder() + .code(101) + .setHeader("Connection", "Upgrade") + .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") + .build() + ) + webServer.enqueue( + MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart) + .build() + ) + val webSocket = newWebSocket() + clientListener.assertFailure( + 101, null, ProtocolException::class.java, + "Expected 'Upgrade' header value 'websocket' but was 'null'" + ) + webSocket.cancel() + } + + @Test + @Throws(IOException::class) + fun wrongUpgradeHeader() { + webServer.enqueue( + MockResponse.Builder() + .code(101) + .setHeader("Connection", "Upgrade") + .setHeader("Upgrade", "Pepsi") + .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") + .build() + ) + webServer.enqueue( + MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart) + .build() + ) + val webSocket = newWebSocket() + clientListener.assertFailure( + 101, null, ProtocolException::class.java, + "Expected 'Upgrade' header value 'websocket' but was 'Pepsi'" + ) + webSocket.cancel() + } + + @Test + @Throws(IOException::class) + fun missingMagicHeader() { + webServer.enqueue( + MockResponse.Builder() + .code(101) + .setHeader("Connection", "Upgrade") + .setHeader("Upgrade", "websocket") + .build() + ) + webServer.enqueue( + MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart) + .build() + ) + val webSocket = newWebSocket() + clientListener.assertFailure( + 101, null, ProtocolException::class.java, + "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'" + ) + webSocket.cancel() + } + + @Test + @Throws(IOException::class) + fun wrongMagicHeader() { + webServer.enqueue( + MockResponse.Builder() + .code(101) + .setHeader("Connection", "Upgrade") + .setHeader("Upgrade", "websocket") + .setHeader("Sec-WebSocket-Accept", "magic") + .build() + ) + webServer.enqueue( + MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart) + .build() + ) + val webSocket = newWebSocket() + clientListener.assertFailure( + 101, null, ProtocolException::class.java, + "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'" + ) + webSocket.cancel() + } + + @Test + @Throws(IOException::class) + fun clientIncludesForbiddenHeader() { + newWebSocket( + Request.Builder() + .url(webServer.url("/")) + .header("Sec-WebSocket-Extensions", "permessage-deflate") + .build() + ) + clientListener.assertFailure( + ProtocolException::class.java, + "Request header not permitted: 'Sec-WebSocket-Extensions'" + ) + } + + @Test + fun webSocketAndApplicationInterceptors() { + val interceptedCount = AtomicInteger() + client = client.newBuilder() + .addInterceptor(Interceptor { chain: Interceptor.Chain -> + assertThat(chain.request().body).isNull() + val response = chain.proceed(chain.request()) + assertThat(response.header("Connection")).isEqualTo("Upgrade") + assertThat(response.body).isInstanceOf( + UnreadableResponseBody::class.java + ) + interceptedCount.incrementAndGet() + response + }) + .build() + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + assertThat(interceptedCount.get()).isEqualTo(1) + closeWebSockets(webSocket, serverListener.assertOpen()) + } + + @Test + fun webSocketAndNetworkInterceptors() { + client = client.newBuilder() + .addNetworkInterceptor(Interceptor { chain: Interceptor.Chain? -> + throw AssertionError() // Network interceptors don't execute. + }) + .build() + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + closeWebSockets(webSocket, server) + } + + @Test + fun overflowOutgoingQueue() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + + // Send messages until the client's outgoing buffer overflows! + val message: ByteString = ByteString.of(*ByteArray(1024 * 1024)) + var messageCount: Long = 0 + while (true) { + val success = webSocket.send(message) + if (!success) break + messageCount++ + val queueSize = webSocket.queueSize() + assertThat(queueSize).isBetween(0L, messageCount * message.size) + // Expect to fail before enqueueing 32 MiB. + assertThat(messageCount).isLessThan(32L) + } + + // Confirm all sent messages were received, followed by a client-initiated close. + val server = serverListener.assertOpen() + for (i in 0 until messageCount) { + serverListener.assertBinaryMessage(message) + } + serverListener.assertClosing(1001, "") + + // When the server acknowledges the close the connection shuts down gracefully. + server.close(1000, null) + clientListener.assertClosing(1000, "") + clientListener.assertClosed(1000, "") + serverListener.assertClosed(1001, "") + } + + @Test + fun closeReasonMaximumLength() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val clientReason = repeat('C', 123) + val serverReason = repeat('S', 123) + val webSocket: WebSocket = newWebSocket() + val server = serverListener.assertOpen() + clientListener.assertOpen() + webSocket.close(1000, clientReason) + serverListener.assertClosing(1000, clientReason) + server.close(1000, serverReason) + clientListener.assertClosing(1000, serverReason) + clientListener.assertClosed(1000, serverReason) + serverListener.assertClosed(1000, clientReason) + } + + @Test + fun closeReasonTooLong() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + val server = serverListener.assertOpen() + clientListener.assertOpen() + val reason = repeat('X', 124) + try { + webSocket.close(1000, reason) + org.junit.jupiter.api.Assertions.fail() + } catch (expected: IllegalArgumentException) { + assertThat(expected.message).isEqualTo("reason.size() > 123: $reason") + } + webSocket.close(1000, null) + serverListener.assertClosing(1000, "") + server.close(1000, null) + clientListener.assertClosing(1000, "") + clientListener.assertClosed(1000, "") + serverListener.assertClosed(1000, "") + } + + @Test + fun wsScheme() { + assumeNotWindows() + websocketScheme("ws") + } + + @Test + fun wsUppercaseScheme() { + websocketScheme("WS") + } + + @Test + fun wssScheme() { + webServer.useHttps(handshakeCertificates.sslSocketFactory()) + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .build() + websocketScheme("wss") + } + + @Test + fun httpsScheme() { + webServer.useHttps(handshakeCertificates.sslSocketFactory()) + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .build() + websocketScheme("https") + } + + @Test + fun readTimeoutAppliesToHttpRequest() { + webServer.enqueue( + MockResponse.Builder() + .socketPolicy(NoResponse) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertFailure( + SocketTimeoutException::class.java, + "timeout", + "Read timed out" + ) + assertThat(webSocket.close(1000, null)).isFalse() + } + + /** + * There's no read timeout when reading the first byte of a new frame. But as soon as we start + * reading a frame we enable the read timeout. In this test we have the server returning the first + * byte of a frame but no more frames. + */ + @Test + fun readTimeoutAppliesWithinFrames() { + webServer.dispatcher = object : Dispatcher() { + override fun dispatch(request: RecordedRequest): MockResponse { + return upgradeResponse(request) + .body(Buffer().write("81".decodeHex())) // Truncated frame. + .removeHeader("Content-Length") + .socketPolicy(KeepOpen) + .build() + } + } + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + clientListener.assertFailure( + SocketTimeoutException::class.java, + "timeout", + "Read timed out" + ) + assertThat(webSocket.close(1000, null)).isFalse() + } + + @Test + @Throws(Exception::class) + fun readTimeoutDoesNotApplyAcrossFrames() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + + // Sleep longer than the HTTP client's read timeout. + Thread.sleep((client.readTimeoutMillis + 500).toLong()) + server.send("abc") + clientListener.assertTextMessage("abc") + closeWebSockets(webSocket, server) + } + + @Test + @Throws(Exception::class) + fun clientPingsServerOnInterval() { + client = client.newBuilder() + .pingInterval(Duration.ofMillis(500)) + .build() + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() as RealWebSocket + val startNanos = System.nanoTime() + while (webSocket.receivedPongCount() < 3) { + Thread.sleep(50) + } + val elapsedUntilPong3 = System.nanoTime() - startNanos + assertThat(TimeUnit.NANOSECONDS.toMillis(elapsedUntilPong3)) + .isCloseTo(1500L, Offset.offset(250L)) + + // The client pinged the server 3 times, and it has ponged back 3 times. + assertThat(webSocket.sentPingCount()).isEqualTo(3) + assertThat(server.receivedPingCount()).isEqualTo(3) + assertThat(webSocket.receivedPongCount()).isEqualTo(3) + + // The server has never pinged the client. + assertThat(server.receivedPongCount()).isEqualTo(0) + assertThat(webSocket.receivedPingCount()).isEqualTo(0) + closeWebSockets(webSocket, server) + } + + @Test + @Throws(Exception::class) + fun clientDoesNotPingServerByDefault() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() as RealWebSocket + Thread.sleep(1000) + + // No pings and no pongs. + assertThat(webSocket.sentPingCount()).isEqualTo(0) + assertThat(webSocket.receivedPingCount()).isEqualTo(0) + assertThat(webSocket.receivedPongCount()).isEqualTo(0) + assertThat(server.sentPingCount()).isEqualTo(0) + assertThat(server.receivedPingCount()).isEqualTo(0) + assertThat(server.receivedPongCount()).isEqualTo(0) + closeWebSockets(webSocket, server) + } + + /** + * Configure the websocket to send pings every 500 ms. Artificially prevent the server from + * responding to pings. The client should give up when attempting to send its 2nd ping, at about + * 1000 ms. + */ + @Test + fun unacknowledgedPingFailsConnection() { + assumeNotWindows() + client = client.newBuilder() + .pingInterval(Duration.ofMillis(500)) + .build() + + // Stall in onOpen to prevent pongs from being sent. + val latch = CountDownLatch(1) + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + try { + latch.await() // The server can't respond to pings! + } catch (e: InterruptedException) { + throw AssertionError(e) + } + } + }) + .build() + ) + val openAtNanos = System.nanoTime() + newWebSocket() + clientListener.assertOpen() + clientListener.assertFailure( + SocketTimeoutException::class.java, + "sent ping but didn't receive pong within 500ms (after 0 successful ping/pongs)" + ) + latch.countDown() + val elapsedUntilFailure = System.nanoTime() - openAtNanos + assertThat(TimeUnit.NANOSECONDS.toMillis(elapsedUntilFailure)) + .isCloseTo(1000L, Offset.offset(250L)) + } + + /** https://github.com/square/okhttp/issues/2788 */ + @Test + fun clientCancelsIfCloseIsNotAcknowledged() { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + + // Initiate a close on the client, which will schedule a hard cancel in 500 ms. + val closeAtNanos = System.nanoTime() + webSocket.close(1000, "goodbye", 500L) + serverListener.assertClosing(1000, "goodbye") + + // Confirm that the hard cancel occurred after 500 ms. + clientListener.assertFailure() + val elapsedUntilFailure = System.nanoTime() - closeAtNanos + assertThat(TimeUnit.NANOSECONDS.toMillis(elapsedUntilFailure)) + .isCloseTo(500L, Offset.offset(250L)) + + // Close the server and confirm it saw what we expected. + server.close(1000, null) + serverListener.assertClosed(1000, "goodbye") + } + + @Test + fun webSocketsDontTriggerEventListener() { + val listener = RecordingEventListener() + client = client.newBuilder() + .eventListenerFactory(clientTestRule.wrap(listener)) + .build() + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + webSocket.send("Web Sockets and Events?!") + serverListener.assertTextMessage("Web Sockets and Events?!") + webSocket.close(1000, "") + serverListener.assertClosing(1000, "") + server.close(1000, "") + clientListener.assertClosing(1000, "") + clientListener.assertClosed(1000, "") + serverListener.assertClosed(1000, "") + assertThat(listener.recordedEventTypes()).isEmpty() + } + + @Test + @Throws(Exception::class) + fun callTimeoutAppliesToSetup() { + webServer.enqueue( + MockResponse.Builder() + .headersDelay(500, TimeUnit.MILLISECONDS) + .build() + ) + client = client.newBuilder() + .readTimeout(Duration.ZERO) + .writeTimeout(Duration.ZERO) + .callTimeout(Duration.ofMillis(100)) + .build() + newWebSocket() + clientListener.assertFailure(InterruptedIOException::class.java, "timeout") + } + + @Test + @Throws(Exception::class) + fun callTimeoutDoesNotApplyOnceConnected() { + client = client.newBuilder() + .callTimeout(Duration.ofMillis(100)) + .build() + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val webSocket: WebSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + Thread.sleep(500) + server.send("Hello, WebSockets!") + clientListener.assertTextMessage("Hello, WebSockets!") + closeWebSockets(webSocket, server) + } + + /** + * We had a bug where web socket connections were leaked if the HTTP connection upgrade was not + * successful. This test confirms that connections are released back to the connection pool! + * https://github.com/square/okhttp/issues/4258 + */ + @Test + @Throws(Exception::class) + fun webSocketConnectionIsReleased() { + // This test assumes HTTP/1.1 pooling semantics. + client = client.newBuilder() + .protocols(Arrays.asList(Protocol.HTTP_1_1)) + .build() + webServer.enqueue( + MockResponse.Builder() + .code(HttpURLConnection.HTTP_NOT_FOUND) + .body("not found!") + .build() + ) + webServer.enqueue(MockResponse()) + newWebSocket() + clientListener.assertFailure() + val regularRequest = Request.Builder() + .url(webServer.url("/")) + .build() + val response = client.newCall(regularRequest).execute() + response.close() + assertThat(webServer.takeRequest().sequenceNumber).isEqualTo(0) + assertThat(webServer.takeRequest().sequenceNumber).isEqualTo(1) + } + + /** https://github.com/square/okhttp/issues/5705 */ + @Test + fun closeWithoutSuccessfulConnect() { + val request = Request.Builder() + .url(webServer.url("/")) + .build() + val webSocket = client.newWebSocket(request, clientListener) + webSocket.send("hello") + webSocket.close(1000, null) + } + + /** https://github.com/square/okhttp/issues/7768 */ + @Test + @Throws(InterruptedException::class) + fun reconnectingToNonWebSocket() { + for (i in 0..29) { + webServer.enqueue( + MockResponse.Builder() + .bodyDelay(100, TimeUnit.MILLISECONDS) + .body("Wrong endpoint") + .code(401) + .build() + ) + } + val request = Request.Builder() + .url(webServer.url("/")) + .build() + val attempts = CountDownLatch(20) + val webSockets = Collections.synchronizedList(ArrayList()) + val reconnectOnFailure: WebSocketListener = object : WebSocketListener() { + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + if (attempts.count > 0) { + clientListener.setNextEventDelegate(this) + webSockets.add(client.newWebSocket(request, clientListener)) + attempts.countDown() + } + } + } + clientListener.setNextEventDelegate(reconnectOnFailure) + webSockets.add(client.newWebSocket(request, clientListener)) + attempts.await() + synchronized(webSockets) { + for (webSocket in webSockets) { + webSocket.cancel() + } + } + } + + @Test + @Throws(Exception::class) + fun compressedMessages() { + successfulExtensions("permessage-deflate") + } + + @Test + @Throws(Exception::class) + fun compressedMessagesNoClientContextTakeover() { + successfulExtensions("permessage-deflate; client_no_context_takeover") + } + + @Test + @Throws(Exception::class) + fun compressedMessagesNoServerContextTakeover() { + successfulExtensions("permessage-deflate; server_no_context_takeover") + } + + @Test + @Throws(Exception::class) + fun unexpectedExtensionParameter() { + extensionNegotiationFailure("permessage-deflate; unknown_parameter=15") + } + + @Test + @Throws(Exception::class) + fun clientMaxWindowBitsIncluded() { + extensionNegotiationFailure("permessage-deflate; client_max_window_bits=15") + } + + @Test + @Throws(Exception::class) + fun serverMaxWindowBitsTooLow() { + extensionNegotiationFailure("permessage-deflate; server_max_window_bits=7") + } + + @Test + @Throws(Exception::class) + fun serverMaxWindowBitsTooHigh() { + extensionNegotiationFailure("permessage-deflate; server_max_window_bits=16") + } + + @Test + @Throws(Exception::class) + fun serverMaxWindowBitsJustRight() { + successfulExtensions("permessage-deflate; server_max_window_bits=15") + } + + @Throws(Exception::class) + private fun successfulExtensions(extensionsHeader: String) { + webServer.enqueue( + MockResponse.Builder() + .addHeader("Sec-WebSocket-Extensions", extensionsHeader) + .webSocketUpgrade(serverListener) + .build() + ) + val client: WebSocket = newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + + // Server to client message big enough to be compressed. + val message1 = repeat('a', RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE.toInt()) + server.send(message1) + clientListener.assertTextMessage(message1) + + // Client to server message big enough to be compressed. + val message2 = repeat('b', RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE.toInt()) + client.send(message2) + serverListener.assertTextMessage(message2) + + // Empty server to client message. + val message3 = "" + server.send(message3) + clientListener.assertTextMessage(message3) + + // Empty client to server message. + val message4 = "" + client.send(message4) + serverListener.assertTextMessage(message4) + + // Server to client message that shares context with message1. + val message5 = message1 + message1 + server.send(message5) + clientListener.assertTextMessage(message5) + + // Client to server message that shares context with message2. + val message6 = message2 + message2 + client.send(message6) + serverListener.assertTextMessage(message6) + closeWebSockets(client, server) + val upgradeRequest = webServer.takeRequest() + assertThat(upgradeRequest.headers["Sec-WebSocket-Extensions"]) + .isEqualTo("permessage-deflate") + } + + @Throws(Exception::class) + private fun extensionNegotiationFailure(extensionsHeader: String) { + webServer.enqueue( + MockResponse.Builder() + .addHeader("Sec-WebSocket-Extensions", extensionsHeader) + .webSocketUpgrade(serverListener) + .build() + ) + newWebSocket() + clientListener.assertOpen() + val server = serverListener.assertOpen() + val clientReason = "unexpected Sec-WebSocket-Extensions in response header" + serverListener.assertClosing(1010, clientReason) + server.close(1010, "") + clientListener.assertClosing(1010, "") + clientListener.assertClosed(1010, "") + serverListener.assertClosed(1010, clientReason) + clientListener.assertExhausted() + serverListener.assertExhausted() + } + + private fun upgradeResponse(request: RecordedRequest): MockResponse.Builder { + val key = request.headers["Sec-WebSocket-Key"] + return MockResponse.Builder() + .status("HTTP/1.1 101 Switching Protocols") + .setHeader("Connection", "Upgrade") + .setHeader("Upgrade", "websocket") + .setHeader("Sec-WebSocket-Accept", acceptHeader(key!!)) + } + + private fun websocketScheme(scheme: String) { + webServer.enqueue( + MockResponse.Builder() + .webSocketUpgrade(serverListener) + .build() + ) + val request = Request.Builder() + .url(scheme + "://" + webServer.hostName + ":" + webServer.port + "/") + .build() + val webSocket = newWebSocket(request) + clientListener.assertOpen() + val server = serverListener.assertOpen() + webSocket.send("abc") + serverListener.assertTextMessage("abc") + closeWebSockets(webSocket, server) + } + + private fun newWebSocket( + request: Request = Request.Builder().get().url( + webServer.url("/") + ).build() + ): RealWebSocket { + val webSocket = RealWebSocket( + TaskRunner.INSTANCE, request, clientListener, + random, client.pingIntervalMillis.toLong(), null, 0L + ) + webSocket.connect(client) + return webSocket + } + + private fun closeWebSockets(client: WebSocket, server: WebSocket) { + server.close(1001, "") + clientListener.assertClosing(1001, "") + client.close(1000, "") + serverListener.assertClosing(1000, "") + clientListener.assertClosed(1001, "") + serverListener.assertClosed(1000, "") + clientListener.assertExhausted() + serverListener.assertExhausted() + } +}