diff --git a/okhttp-dnsoverhttps/src/main/kotlin/okhttp3/dnsoverhttps/DnsOverHttps.kt b/okhttp-dnsoverhttps/src/main/kotlin/okhttp3/dnsoverhttps/DnsOverHttps.kt index 2944ddebb947..cc3872e5e194 100644 --- a/okhttp-dnsoverhttps/src/main/kotlin/okhttp3/dnsoverhttps/DnsOverHttps.kt +++ b/okhttp-dnsoverhttps/src/main/kotlin/okhttp3/dnsoverhttps/DnsOverHttps.kt @@ -187,15 +187,22 @@ class DnsOverHttps internal constructor( } private fun getCacheOnlyResponse(request: Request): Response? { - if (!post && client.cache != null) { + if (client.cache != null) { try { // Use the cache without hitting the network first // 504 code indicates that the Cache is stale - val preferCache = + val onlyIfCached = CacheControl.Builder() .onlyIfCached() .build() - val cacheRequest = request.newBuilder().cacheControl(preferCache).build() + + var cacheUrl = request.url + + val cacheRequest = + request.newBuilder() + .cacheControl(onlyIfCached) + .cacheUrlOverride(cacheUrl) + .build() val cacheResponse = client.newCall(cacheRequest).execute() @@ -247,7 +254,12 @@ class DnsOverHttps internal constructor( val query = DnsRecordCodec.encodeQuery(hostname, type) if (post) { - url(url).post(query.toRequestBody(DNS_MESSAGE)) + url(url) + .cacheUrlOverride( + url.newBuilder() + .addQueryParameter("hostname", hostname).build(), + ) + .post(query.toRequestBody(DNS_MESSAGE)) } else { val encoded = query.base64Url().replace("=", "") val requestUrl = url.newBuilder().addQueryParameter("dns", encoded).build() diff --git a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.kt b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.kt index 6cdac3bc1fe4..78bbaaac78fc 100644 --- a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.kt +++ b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.kt @@ -22,6 +22,7 @@ import assertk.assertions.containsExactlyInAnyOrder import assertk.assertions.hasMessage import assertk.assertions.isEqualTo import assertk.assertions.isInstanceOf +import assertk.assertions.isNull import java.io.EOFException import java.io.File import java.io.IOException @@ -167,30 +168,79 @@ class DnsOverHttpsTest { // 3. successful network response // 4. successful stale cached GET response // 5. unsuccessful response - // TODO how closely to follow POST rules on caching? @Test fun usesCache() { val cache = Cache("cache".toPath(), (100 * 1024).toLong(), cacheFs) val cachedClient = bootstrapClient.newBuilder().cache(cache).build() val cachedDns = buildLocalhost(cachedClient, false) - server.enqueue( - dnsResponse( - "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + - "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + - "0003b00049df00112", + + repeat(2) { + server.enqueue( + dnsResponse( + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + + "0003b00049df00112", + ) + .newBuilder() + .setHeader("cache-control", "private, max-age=298") + .build(), ) - .newBuilder() - .setHeader("cache-control", "private, max-age=298") - .build(), - ) + } + var result = cachedDns.lookup("google.com") assertThat(result).containsExactly(address("157.240.1.18")) - val recordedRequest = server.takeRequest() + var recordedRequest = server.takeRequest() assertThat(recordedRequest.method).isEqualTo("GET") assertThat(recordedRequest.path) .isEqualTo("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ") + + result = cachedDns.lookup("google.com") + assertThat(server.takeRequest(1, TimeUnit.MILLISECONDS)).isNull() + assertThat(result).isEqualTo(listOf(address("157.240.1.18"))) + + result = cachedDns.lookup("www.google.com") + assertThat(result).containsExactly(address("157.240.1.18")) + recordedRequest = server.takeRequest() + assertThat(recordedRequest.method).isEqualTo("GET") + assertThat(recordedRequest.path) + .isEqualTo("/lookup?ct&dns=AAABAAABAAAAAAAAA3d3dwZnb29nbGUDY29tAAABAAE") + } + + @Test + fun usesCacheEvenForPost() { + val cache = Cache("cache".toPath(), (100 * 1024).toLong(), cacheFs) + val cachedClient = bootstrapClient.newBuilder().cache(cache).build() + val cachedDns = buildLocalhost(cachedClient, false, post = true) + repeat(2) { + server.enqueue( + dnsResponse( + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + + "0003b00049df00112", + ) + .newBuilder() + .setHeader("cache-control", "private, max-age=298") + .build(), + ) + } + + var result = cachedDns.lookup("google.com") + assertThat(result).containsExactly(address("157.240.1.18")) + var recordedRequest = server.takeRequest() + assertThat(recordedRequest.method).isEqualTo("POST") + assertThat(recordedRequest.path) + .isEqualTo("/lookup?ct") + result = cachedDns.lookup("google.com") + assertThat(server.takeRequest(0, TimeUnit.MILLISECONDS)).isNull() assertThat(result).isEqualTo(listOf(address("157.240.1.18"))) + + result = cachedDns.lookup("www.google.com") + assertThat(result).containsExactly(address("157.240.1.18")) + recordedRequest = server.takeRequest(0, TimeUnit.MILLISECONDS)!! + assertThat(recordedRequest.method).isEqualTo("POST") + assertThat(recordedRequest.path) + .isEqualTo("/lookup?ct") } @Test @@ -245,12 +295,14 @@ class DnsOverHttpsTest { private fun buildLocalhost( bootstrapClient: OkHttpClient, includeIPv6: Boolean, + post: Boolean = false, ): DnsOverHttps { val url = server.url("/lookup?ct") return DnsOverHttps.Builder().client(bootstrapClient) .includeIPv6(includeIPv6) .resolvePrivateAddresses(true) .url(url) + .post(post) .build() } diff --git a/okhttp/api/okhttp.api b/okhttp/api/okhttp.api index 8b7d85d030e6..f784a8736968 100644 --- a/okhttp/api/okhttp.api +++ b/okhttp/api/okhttp.api @@ -1030,6 +1030,7 @@ public final class okhttp3/Request { public synthetic fun (Lokhttp3/HttpUrl;Lokhttp3/Headers;Ljava/lang/String;Lokhttp3/RequestBody;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun body ()Lokhttp3/RequestBody; public final fun cacheControl ()Lokhttp3/CacheControl; + public final fun cacheUrlOverride ()Lokhttp3/HttpUrl; public final fun header (Ljava/lang/String;)Ljava/lang/String; public final fun headers ()Lokhttp3/Headers; public final fun headers (Ljava/lang/String;)Ljava/util/List; @@ -1048,6 +1049,7 @@ public class okhttp3/Request$Builder { public fun addHeader (Ljava/lang/String;Ljava/lang/String;)Lokhttp3/Request$Builder; public fun build ()Lokhttp3/Request; public fun cacheControl (Lokhttp3/CacheControl;)Lokhttp3/Request$Builder; + public final fun cacheUrlOverride (Lokhttp3/HttpUrl;)Lokhttp3/Request$Builder; public final fun delete ()Lokhttp3/Request$Builder; public fun delete (Lokhttp3/RequestBody;)Lokhttp3/Request$Builder; public static synthetic fun delete$default (Lokhttp3/Request$Builder;Lokhttp3/RequestBody;ILjava/lang/Object;)Lokhttp3/Request$Builder; diff --git a/okhttp/src/main/kotlin/okhttp3/Request.kt b/okhttp/src/main/kotlin/okhttp3/Request.kt index 4e399789c562..57cdc93f331e 100644 --- a/okhttp/src/main/kotlin/okhttp3/Request.kt +++ b/okhttp/src/main/kotlin/okhttp3/Request.kt @@ -54,6 +54,9 @@ class Request internal constructor(builder: Builder) { @get:JvmName("body") val body: RequestBody? = builder.body + @get:JvmName("cacheUrlOverride") + val cacheUrlOverride: HttpUrl? = builder.cacheUrlOverride + internal val tags: Map, Any> = builder.tags.toMap() internal var lazyCacheControl: CacheControl? = null @@ -183,6 +186,7 @@ class Request internal constructor(builder: Builder) { internal var method: String internal var headers: Headers.Builder internal var body: RequestBody? = null + internal var cacheUrlOverride: HttpUrl? = null /** A mutable map of tags, or an immutable empty map if we don't have any. */ internal var tags = mapOf, Any>() @@ -202,6 +206,7 @@ class Request internal constructor(builder: Builder) { else -> request.tags.toMutableMap() } this.headers = request.headers.newBuilder() + this.cacheUrlOverride = request.cacheUrlOverride } open fun url(url: HttpUrl): Builder = @@ -316,6 +321,18 @@ class Request internal constructor(builder: Builder) { tag: T?, ) = commonTag(type.kotlin, tag) + /** + * Override the [Request.url] for caching, if it is either polluted with + * transient query params, or has a canonical URL possibly for a CDN. + * + * Note that POST requests will not be sent to the server if this URL is set + * and matches a cached response. + */ + fun cacheUrlOverride(cacheUrlOverride: HttpUrl?) = + apply { + this.cacheUrlOverride = cacheUrlOverride + } + open fun build(): Request = Request(this) } } diff --git a/okhttp/src/main/kotlin/okhttp3/internal/cache/CacheInterceptor.kt b/okhttp/src/main/kotlin/okhttp3/internal/cache/CacheInterceptor.kt index e86c2470ba0c..47c4bf013123 100644 --- a/okhttp/src/main/kotlin/okhttp3/internal/cache/CacheInterceptor.kt +++ b/okhttp/src/main/kotlin/okhttp3/internal/cache/CacheInterceptor.kt @@ -25,6 +25,7 @@ import okhttp3.EventListener import okhttp3.Headers import okhttp3.Interceptor import okhttp3.Protocol +import okhttp3.Request import okhttp3.Response import okhttp3.internal.closeQuietly import okhttp3.internal.connection.RealCall @@ -44,7 +45,7 @@ class CacheInterceptor(internal val cache: Cache?) : Interceptor { @Throws(IOException::class) override fun intercept(chain: Interceptor.Chain): Response { val call = chain.call() - val cacheCandidate = cache?.get(chain.request()) + val cacheCandidate = cache?.get(chain.request().requestForCache()) val now = System.currentTimeMillis() @@ -132,9 +133,11 @@ class CacheInterceptor(internal val cache: Cache?) : Interceptor { .build() if (cache != null) { - if (response.promisesBody() && CacheStrategy.isCacheable(response, networkRequest)) { + val cacheNetworkRequest = networkRequest.requestForCache() + + if (response.promisesBody() && CacheStrategy.isCacheable(response, cacheNetworkRequest)) { // Offer this request to the cache. - val cacheRequest = cache.put(response) + val cacheRequest = cache.put(response.newBuilder().request(cacheNetworkRequest).build()) return cacheWritingResponse(cacheRequest, response).also { if (cacheResponse != null) { // This will log a conditional cache miss only. @@ -285,3 +288,17 @@ class CacheInterceptor(internal val cache: Cache?) : Interceptor { } } } + +private fun Request.requestForCache(): Request { + val cacheUrlOverride = cacheUrlOverride + + return if (cacheUrlOverride != null && (method == "GET" || method == "POST")) { + newBuilder() + .get() + .url(cacheUrlOverride) + .cacheUrlOverride(null) + .build() + } else { + this + } +} diff --git a/okhttp/src/test/java/okhttp3/CacheTest.kt b/okhttp/src/test/java/okhttp3/CacheTest.kt index e6921ddf495e..707f0fcebd95 100644 --- a/okhttp/src/test/java/okhttp3/CacheTest.kt +++ b/okhttp/src/test/java/okhttp3/CacheTest.kt @@ -1002,11 +1002,22 @@ class CacheTest { testRequestMethod("POST", false) } + @Test + fun requestMethodPostIsNotCachedUnlessOverridden() { + // Supported via cacheUrlOverride + testRequestMethod("POST", true, withOverride = true) + } + @Test fun requestMethodPutIsNotCached() { testRequestMethod("PUT", false) } + @Test + fun requestMethodPutIsNotCachedEvenWithOverride() { + testRequestMethod("PUT", false, withOverride = true) + } + @Test fun requestMethodDeleteIsNotCached() { testRequestMethod("DELETE", false) @@ -1020,6 +1031,7 @@ class CacheTest { private fun testRequestMethod( requestMethod: String, expectCached: Boolean, + withOverride: Boolean = false, ) { // 1. Seed the cache (potentially). // 2. Expect a cache hit or miss. @@ -1038,6 +1050,11 @@ class CacheTest { val request = Request.Builder() .url(url) + .apply { + if (withOverride) { + cacheUrlOverride(url) + } + } .method(requestMethod, requestBodyOrNull(requestMethod)) .build() val response1 = client.newCall(request).execute() @@ -3250,6 +3267,48 @@ CLEAN $urlKey ${entryMetadata.length} ${entryBody.length} ) } + @Test + fun getHasCorrectResponse() { + val request = Request(server.url("/abc")) + + val response = testBasicCachingRules(request) + + assertThat(response.request.url).isEqualTo(request.url) + assertThat(response.cacheResponse!!.request.url).isEqualTo(request.url) + } + + @Test + fun postWithOverrideResponse() { + val url = server.url("/abc?token=123") + val cacheUrlOverride = url.newBuilder().removeAllQueryParameters("token").build() + + val request = + Request.Builder() + .url(url) + .method("POST", "XYZ".toRequestBody()) + .cacheUrlOverride(cacheUrlOverride) + .build() + + val response = testBasicCachingRules(request) + + assertThat(response.request.url).isEqualTo(request.url) + assertThat(response.cacheResponse!!.request.url).isEqualTo(cacheUrlOverride) + } + + private fun testBasicCachingRules(request: Request): Response { + val mockResponse = + MockResponse.Builder() + .addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS)) + .addHeader("Expires: " + formatDate(1, TimeUnit.HOURS)) + .status("HTTP/1.1 200 Fantastic") + server.enqueue(mockResponse.build()) + + client.newCall(request).execute().use { + it.body.bytes() + } + return client.newCall(request).execute() + } + private operator fun get(url: HttpUrl): Response { val request = Request.Builder()