Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache URL Override #8234

Merged
merged 13 commits into from
Mar 22, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,22 @@ class DnsOverHttps internal constructor(
}

private fun getCacheOnlyResponse(request: Request): Response? {
if (!post && client.cache != null) {
if (client.cache != null) {
try {
// Use the cache without hitting the network first
// 504 code indicates that the Cache is stale
val preferCache =
val onlyIfCached =
CacheControl.Builder()
.onlyIfCached()
.build()
val cacheRequest = request.newBuilder().cacheControl(preferCache).build()

var cacheUrl = request.url

val cacheRequest =
request.newBuilder()
.cacheControl(onlyIfCached)
.cacheUrlOverride(cacheUrl)
.build()

val cacheResponse = client.newCall(cacheRequest).execute()

Expand Down Expand Up @@ -247,7 +254,12 @@ class DnsOverHttps internal constructor(
val query = DnsRecordCodec.encodeQuery(hostname, type)

if (post) {
url(url).post(query.toRequestBody(DNS_MESSAGE))
url(url)
.cacheUrlOverride(
url.newBuilder()
.addQueryParameter("hostname", hostname).build(),
)
.post(query.toRequestBody(DNS_MESSAGE))
} else {
val encoded = query.base64Url().replace("=", "")
val requestUrl = url.newBuilder().addQueryParameter("dns", encoded).build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import assertk.assertions.containsExactlyInAnyOrder
import assertk.assertions.hasMessage
import assertk.assertions.isEqualTo
import assertk.assertions.isInstanceOf
import assertk.assertions.isNull
import java.io.EOFException
import java.io.File
import java.io.IOException
Expand Down Expand Up @@ -167,30 +168,79 @@ class DnsOverHttpsTest {
// 3. successful network response
// 4. successful stale cached GET response
// 5. unsuccessful response
// TODO how closely to follow POST rules on caching?
@Test
fun usesCache() {
val cache = Cache("cache".toPath(), (100 * 1024).toLong(), cacheFs)
val cachedClient = bootstrapClient.newBuilder().cache(cache).build()
val cachedDns = buildLocalhost(cachedClient, false)
server.enqueue(
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112",

repeat(2) {
server.enqueue(
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112",
)
.newBuilder()
.setHeader("cache-control", "private, max-age=298")
.build(),
)
.newBuilder()
.setHeader("cache-control", "private, max-age=298")
.build(),
)
}

var result = cachedDns.lookup("google.com")
assertThat(result).containsExactly(address("157.240.1.18"))
val recordedRequest = server.takeRequest()
var recordedRequest = server.takeRequest()
assertThat(recordedRequest.method).isEqualTo("GET")
assertThat(recordedRequest.path)
.isEqualTo("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ")

result = cachedDns.lookup("google.com")
assertThat(server.takeRequest(1, TimeUnit.MILLISECONDS)).isNull()
assertThat(result).isEqualTo(listOf(address("157.240.1.18")))

result = cachedDns.lookup("www.google.com")
assertThat(result).containsExactly(address("157.240.1.18"))
recordedRequest = server.takeRequest()
assertThat(recordedRequest.method).isEqualTo("GET")
assertThat(recordedRequest.path)
.isEqualTo("/lookup?ct&dns=AAABAAABAAAAAAAAA3d3dwZnb29nbGUDY29tAAABAAE")
}

@Test
fun usesCacheEvenForPost() {
val cache = Cache("cache".toPath(), (100 * 1024).toLong(), cacheFs)
val cachedClient = bootstrapClient.newBuilder().cache(cache).build()
val cachedDns = buildLocalhost(cachedClient, false, post = true)
repeat(2) {
server.enqueue(
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112",
)
.newBuilder()
.setHeader("cache-control", "private, max-age=298")
.build(),
)
}

var result = cachedDns.lookup("google.com")
assertThat(result).containsExactly(address("157.240.1.18"))
var recordedRequest = server.takeRequest()
assertThat(recordedRequest.method).isEqualTo("POST")
assertThat(recordedRequest.path)
.isEqualTo("/lookup?ct")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Off-topic: we’re naughty for having a field called path that contains a path and a query. Thanks for confirming that we use the right URL for these!


result = cachedDns.lookup("google.com")
assertThat(server.takeRequest(0, TimeUnit.MILLISECONDS)).isNull()
assertThat(result).isEqualTo(listOf(address("157.240.1.18")))

result = cachedDns.lookup("www.google.com")
assertThat(result).containsExactly(address("157.240.1.18"))
recordedRequest = server.takeRequest(0, TimeUnit.MILLISECONDS)!!
assertThat(recordedRequest.method).isEqualTo("POST")
assertThat(recordedRequest.path)
.isEqualTo("/lookup?ct")
}

@Test
Expand Down Expand Up @@ -245,12 +295,14 @@ class DnsOverHttpsTest {
private fun buildLocalhost(
bootstrapClient: OkHttpClient,
includeIPv6: Boolean,
post: Boolean = false,
): DnsOverHttps {
val url = server.url("/lookup?ct")
return DnsOverHttps.Builder().client(bootstrapClient)
.includeIPv6(includeIPv6)
.resolvePrivateAddresses(true)
.url(url)
.post(post)
.build()
}

Expand Down
2 changes: 2 additions & 0 deletions okhttp/api/okhttp.api
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,7 @@ public final class okhttp3/Request {
public synthetic fun <init> (Lokhttp3/HttpUrl;Lokhttp3/Headers;Ljava/lang/String;Lokhttp3/RequestBody;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun body ()Lokhttp3/RequestBody;
public final fun cacheControl ()Lokhttp3/CacheControl;
public final fun cacheUrlOverride ()Lokhttp3/HttpUrl;
public final fun header (Ljava/lang/String;)Ljava/lang/String;
public final fun headers ()Lokhttp3/Headers;
public final fun headers (Ljava/lang/String;)Ljava/util/List;
Expand All @@ -1048,6 +1049,7 @@ public class okhttp3/Request$Builder {
public fun addHeader (Ljava/lang/String;Ljava/lang/String;)Lokhttp3/Request$Builder;
public fun build ()Lokhttp3/Request;
public fun cacheControl (Lokhttp3/CacheControl;)Lokhttp3/Request$Builder;
public final fun cacheUrlOverride (Lokhttp3/HttpUrl;)Lokhttp3/Request$Builder;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we’re not adding OkHttpExperimentalApi on this, and so we’re committing to this exact API and name. I’m okay with that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we reconsider as a policy for new API?

public final fun delete ()Lokhttp3/Request$Builder;
public fun delete (Lokhttp3/RequestBody;)Lokhttp3/Request$Builder;
public static synthetic fun delete$default (Lokhttp3/Request$Builder;Lokhttp3/RequestBody;ILjava/lang/Object;)Lokhttp3/Request$Builder;
Expand Down
17 changes: 17 additions & 0 deletions okhttp/src/main/kotlin/okhttp3/Request.kt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class Request internal constructor(builder: Builder) {
@get:JvmName("body")
val body: RequestBody? = builder.body

@get:JvmName("cacheUrlOverride")
val cacheUrlOverride: HttpUrl? = builder.cacheUrlOverride

internal val tags: Map<KClass<*>, Any> = builder.tags.toMap()

internal var lazyCacheControl: CacheControl? = null
Expand Down Expand Up @@ -183,6 +186,7 @@ class Request internal constructor(builder: Builder) {
internal var method: String
internal var headers: Headers.Builder
internal var body: RequestBody? = null
internal var cacheUrlOverride: HttpUrl? = null

/** A mutable map of tags, or an immutable empty map if we don't have any. */
internal var tags = mapOf<KClass<*>, Any>()
Expand All @@ -202,6 +206,7 @@ class Request internal constructor(builder: Builder) {
else -> request.tags.toMutableMap()
}
this.headers = request.headers.newBuilder()
this.cacheUrlOverride = request.cacheUrlOverride
}

open fun url(url: HttpUrl): Builder =
Expand Down Expand Up @@ -316,6 +321,18 @@ class Request internal constructor(builder: Builder) {
tag: T?,
) = commonTag(type.kotlin, tag)

/**
* Override the [Request.url] for caching, if it is either polluted with
* transient query params, or has a canonical URL possibly for a CDN.
*
* Note that POST requests will not be sent to the server if this URL is set
* and matches a cached response.
*/
fun cacheUrlOverride(cacheUrlOverride: HttpUrl?) =
apply {
this.cacheUrlOverride = cacheUrlOverride
}

open fun build(): Request = Request(this)
}
}
23 changes: 20 additions & 3 deletions okhttp/src/main/kotlin/okhttp3/internal/cache/CacheInterceptor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import okhttp3.EventListener
import okhttp3.Headers
import okhttp3.Interceptor
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
import okhttp3.internal.closeQuietly
import okhttp3.internal.connection.RealCall
Expand All @@ -44,7 +45,7 @@ class CacheInterceptor(internal val cache: Cache?) : Interceptor {
@Throws(IOException::class)
override fun intercept(chain: Interceptor.Chain): Response {
val call = chain.call()
val cacheCandidate = cache?.get(chain.request())
val cacheCandidate = cache?.get(chain.request().requestForCache())

val now = System.currentTimeMillis()

Expand Down Expand Up @@ -132,9 +133,11 @@ class CacheInterceptor(internal val cache: Cache?) : Interceptor {
.build()

if (cache != null) {
if (response.promisesBody() && CacheStrategy.isCacheable(response, networkRequest)) {
val cacheNetworkRequest = networkRequest.requestForCache()

if (response.promisesBody() && CacheStrategy.isCacheable(response, cacheNetworkRequest)) {
// Offer this request to the cache.
val cacheRequest = cache.put(response)
val cacheRequest = cache.put(response.newBuilder().request(cacheNetworkRequest).build())
return cacheWritingResponse(cacheRequest, response).also {
if (cacheResponse != null) {
// This will log a conditional cache miss only.
Expand Down Expand Up @@ -285,3 +288,17 @@ class CacheInterceptor(internal val cache: Cache?) : Interceptor {
}
}
}

private fun Request.requestForCache(): Request {
val cacheUrlOverride = cacheUrlOverride

return if (cacheUrlOverride != null && (method == "GET" || method == "POST")) {
newBuilder()
.get()
.url(cacheUrlOverride)
.cacheUrlOverride(null)
.build()
} else {
this
}
}
59 changes: 59 additions & 0 deletions okhttp/src/test/java/okhttp3/CacheTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -1002,11 +1002,22 @@ class CacheTest {
testRequestMethod("POST", false)
}

@Test
fun requestMethodPostIsNotCachedUnlessOverridden() {
// Supported via cacheUrlOverride
testRequestMethod("POST", true, withOverride = true)
}

@Test
fun requestMethodPutIsNotCached() {
testRequestMethod("PUT", false)
}

@Test
fun requestMethodPutIsNotCachedEvenWithOverride() {
testRequestMethod("PUT", false, withOverride = true)
}

@Test
fun requestMethodDeleteIsNotCached() {
testRequestMethod("DELETE", false)
Expand All @@ -1020,6 +1031,7 @@ class CacheTest {
private fun testRequestMethod(
requestMethod: String,
expectCached: Boolean,
withOverride: Boolean = false,
) {
// 1. Seed the cache (potentially).
// 2. Expect a cache hit or miss.
Expand All @@ -1038,6 +1050,11 @@ class CacheTest {
val request =
Request.Builder()
.url(url)
.apply {
if (withOverride) {
cacheUrlOverride(url)
}
}
.method(requestMethod, requestBodyOrNull(requestMethod))
.build()
val response1 = client.newCall(request).execute()
Expand Down Expand Up @@ -3250,6 +3267,48 @@ CLEAN $urlKey ${entryMetadata.length} ${entryBody.length}
)
}

@Test
fun getHasCorrectResponse() {
val request = Request(server.url("/abc"))

val response = testBasicCachingRules(request)

assertThat(response.request.url).isEqualTo(request.url)
assertThat(response.cacheResponse!!.request.url).isEqualTo(request.url)
}

@Test
fun postWithOverrideResponse() {
val url = server.url("/abc?token=123")
val cacheUrlOverride = url.newBuilder().removeAllQueryParameters("token").build()

val request =
Request.Builder()
.url(url)
.method("POST", "XYZ".toRequestBody())
.cacheUrlOverride(cacheUrlOverride)
.build()

val response = testBasicCachingRules(request)

assertThat(response.request.url).isEqualTo(request.url)
assertThat(response.cacheResponse!!.request.url).isEqualTo(cacheUrlOverride)
}

private fun testBasicCachingRules(request: Request): Response {
val mockResponse =
MockResponse.Builder()
.addHeader("Last-Modified: " + formatDate(-1, TimeUnit.HOURS))
.addHeader("Expires: " + formatDate(1, TimeUnit.HOURS))
.status("HTTP/1.1 200 Fantastic")
server.enqueue(mockResponse.build())

client.newCall(request).execute().use {
it.body.bytes()
}
return client.newCall(request).execute()
}

private operator fun get(url: HttpUrl): Response {
val request =
Request.Builder()
Expand Down
Loading