Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/java-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ jobs:
--file duo-client/pom.xml
- name: Test with Maven
run: >
mvn test
--batch-mode
mvn verify
--batch-mode
-file duo-client/pom.xml
- name: Lint with checkstyle
run: mvn checkstyle:check
25 changes: 25 additions & 0 deletions duo-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@
<version>3.12.4</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
<version>4.12.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-tls</artifactId>
<version>4.12.0</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down Expand Up @@ -123,6 +135,19 @@
<parallel>methods</parallel>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-failsafe-plugin</artifactId>
<version>3.2.5</version>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.cyclonedx</groupId>
<artifactId>cyclonedx-maven-plugin</artifactId>
Expand Down
42 changes: 41 additions & 1 deletion duo-client/src/main/java/com/duosecurity/client/Http.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class Http {
private Headers.Builder headers;
private SortedMap<String, Object> params = new TreeMap<String, Object>();
protected int sigVersion = 5;
private long maxBackoffMs = MAX_BACKOFF_MS;
private Random random = new Random();
private OkHttpClient httpClient;
private SortedMap<String, String> additionalDuoHeaders = new TreeMap<String, String>();
Expand Down Expand Up @@ -314,10 +315,14 @@ private Response executeRequest(Request request) throws Exception {
long backoffMs = INITIAL_BACKOFF_MS;
while (true) {
Response response = httpClient.newCall(request).execute();
if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > MAX_BACKOFF_MS) {
if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > maxBackoffMs) {
return response;
}

// Close the 429 response to release the connection back to the pool before retrying
if (response.body() != null) {
response.close();
}
sleep(backoffMs + nextRandomInt(1000));
backoffMs *= BACKOFF_FACTOR;
}
Expand All @@ -327,6 +332,13 @@ protected void sleep(long ms) throws Exception {
Thread.sleep(ms);
}

protected void setMaxBackoffMs(long maxBackoffMs) {
if (maxBackoffMs < 0) {
throw new IllegalArgumentException("maxBackoffMs must be >= 0");
}
this.maxBackoffMs = maxBackoffMs;
}

public void signRequest(String ikey, String skey)
throws UnsupportedEncodingException {
signRequest(ikey, skey, sigVersion);
Expand Down Expand Up @@ -529,6 +541,7 @@ protected abstract static class ClientBuilder<T extends Http> {
private final String uri;

private int timeout = DEFAULT_TIMEOUT_SECS;
private long maxBackoffMs = MAX_BACKOFF_MS;
private String[] caCerts = null;
private SortedMap<String, String> additionalDuoHeaders = new TreeMap<String, String>();
private Map<String, String> headers = new HashMap<String, String>();
Expand Down Expand Up @@ -558,6 +571,32 @@ public ClientBuilder<T> useTimeout(int timeout) {
return this;
}

/**
* Set the maximum base backoff time in milliseconds for rate limit (429) retries.
* When a request receives a 429 response, the client retries with exponential
* backoff until the base backoff exceeds this threshold. Note that actual sleep
* time includes up to 1000ms of random jitter on top of the base backoff.
* Setting to 0 disables retries (as does any value below the initial
* backoff of 1000ms). Default is 32000ms (32 seconds).
*
* <p>Note: When using method chaining from outside this package (e.g. with
* {@code AuthBuilder} or {@code AdminBuilder}), assign the builder to a variable
* and call methods separately, then call {@code build()}. This is a known
* limitation of all {@code ClientBuilder} methods.
*
* @param maxBackoffMs the maximum base backoff in milliseconds (must be >= 0)
* @return the Builder
* @throws IllegalArgumentException if maxBackoffMs is negative
*/
public ClientBuilder<T> useMaxBackoffMs(long maxBackoffMs) {
if (maxBackoffMs < 0) {
throw new IllegalArgumentException("maxBackoffMs must be >= 0");
}
this.maxBackoffMs = maxBackoffMs;

return this;
}

/**
* Provide custom CA certificates for certificate pinning.
*
Expand Down Expand Up @@ -604,6 +643,7 @@ public ClientBuilder<T> addHeader(String name, String value) {
*/
public T build() {
T duoClient = createClient(method, host, uri, timeout);
duoClient.setMaxBackoffMs(maxBackoffMs);
if (caCerts != null) {
duoClient.useCustomCertificates(caCerts);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package com.duosecurity.client;

import okhttp3.OkHttpClient;
import okhttp3.Response;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.tls.HandshakeCertificates;
import okhttp3.tls.HeldCertificate;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

import java.lang.reflect.Field;

import static org.junit.Assert.assertEquals;

public class HttpRateLimitRetryIntegrationIT {

private MockWebServer server;
private HandshakeCertificates clientCerts;

@Before
public void setUp() throws Exception {
HeldCertificate serverCert = new HeldCertificate.Builder()
.addSubjectAlternativeName("localhost")
.build();

HandshakeCertificates serverCerts = new HandshakeCertificates.Builder()
.heldCertificate(serverCert)
.build();

clientCerts = new HandshakeCertificates.Builder()
.addTrustedCertificate(serverCert.certificate())
.build();

server = new MockWebServer();
server.useHttps(serverCerts.sslSocketFactory(), false);
server.start();
}

@After
public void tearDown() throws Exception {
server.shutdown();
}

/**
* Builds an Http spy pointing at the MockWebServer, with sleep() stubbed out to avoid real
* delays and the OkHttpClient replaced with one that trusts the test certificate.
*
* <p>The builder must be constructed with host "localhost" (no port) so that CertificatePinner
* accepts the pattern. This method then sets the real host (with port) and replaces the
* OkHttpClient via reflection before the spy is used.
*/
private Http buildSpyHttp(Http.ClientBuilder<Http> builder) throws Exception {
Http spy = Mockito.spy(builder.build());
Mockito.doNothing().when(spy).sleep(Mockito.any(Long.class));

// Point the host at the MockWebServer port (CertificatePinner rejects host:port patterns,
// so the builder uses "localhost" and we fix it here after construction).
Field hostField = Http.class.getDeclaredField("host");
hostField.setAccessible(true);
hostField.set(spy, "localhost:" + server.getPort());

// Replace the OkHttpClient with one configured to trust the test certificate
OkHttpClient testClient = new OkHttpClient.Builder()
.sslSocketFactory(clientCerts.sslSocketFactory(), clientCerts.trustManager())
.build();

Field httpClientField = Http.class.getDeclaredField("httpClient");
httpClientField.setAccessible(true);
httpClientField.set(spy, testClient);

return spy;
}

private Http.HttpBuilder defaultBuilder() {
// Use "localhost" without a port — CertificatePinner rejects host:port patterns.
// buildSpyHttp sets the real host (with port) via reflection after construction.
return new Http.HttpBuilder("GET", "localhost", "/foo/bar");
}

@Test
public void testSingleRateLimitRetry() throws Exception {
server.enqueue(new MockResponse().setResponseCode(429));
server.enqueue(new MockResponse().setResponseCode(200));

Http http = buildSpyHttp(defaultBuilder());
Response response = http.executeHttpRequest();

assertEquals(200, response.code());
assertEquals(2, server.getRequestCount());
Mockito.verify(http, Mockito.times(1)).sleep(Mockito.any(Long.class));
}

@Test
public void testRateLimitExhaustsDefaultMaxBackoff() throws Exception {
// Enqueue more responses than will ever be consumed
for (int i = 0; i < 10; i++) {
server.enqueue(new MockResponse().setResponseCode(429));
}

Http http = buildSpyHttp(defaultBuilder());
Response response = http.executeHttpRequest();

assertEquals(429, response.code());
// Default max backoff (32s): sleeps at 1s, 2s, 4s, 8s, 16s, 32s = 6 sleeps, 7 total requests
assertEquals(7, server.getRequestCount());
Mockito.verify(http, Mockito.times(6)).sleep(Mockito.any(Long.class));
}

@Test
public void testCustomMaxBackoffLimitsRetries() throws Exception {
for (int i = 0; i < 10; i++) {
server.enqueue(new MockResponse().setResponseCode(429));
}

Http http = buildSpyHttp(defaultBuilder().useMaxBackoffMs(4000));
Response response = http.executeHttpRequest();

assertEquals(429, response.code());
// maxBackoff=4000: sleeps at 1s, 2s, 4s = 3 sleeps, 4 total requests (next would be 8s > 4s)
assertEquals(4, server.getRequestCount());
Mockito.verify(http, Mockito.times(3)).sleep(Mockito.any(Long.class));
}

@Test
public void testMaxBackoffZeroDisablesRetry() throws Exception {
server.enqueue(new MockResponse().setResponseCode(429));

Http http = buildSpyHttp(defaultBuilder().useMaxBackoffMs(0));
Response response = http.executeHttpRequest();

assertEquals(429, response.code());
assertEquals(1, server.getRequestCount());
Mockito.verify(http, Mockito.never()).sleep(Mockito.any(Long.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ public class HttpRateLimitRetryTest {

private final int RANDOM_INT = 234;

@Before
public void before() throws Exception {
http = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build();
http = Mockito.spy(http);
private void setupHttp(Http client) throws Exception {
http = Mockito.spy(client);

Field httpClientField = Http.class.getDeclaredField("httpClient");
httpClientField.setAccessible(true);
Expand All @@ -39,6 +37,12 @@ public void before() throws Exception {
Mockito.doNothing().when(http).sleep(Mockito.any(Long.class));
}

@Before
public void before() throws Exception {
Http client = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build();
setupHttp(client);
}

@Test
public void testSingleRateLimitRetry() throws Exception {
final List<Response> responses = new ArrayList<Response>();
Expand Down Expand Up @@ -128,4 +132,87 @@ public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
assertEquals(16000L + RANDOM_INT, (long) sleepTimes.get(4));
assertEquals(32000L + RANDOM_INT, (long) sleepTimes.get(5));
}

@Test
public void testMaxBackoffZeroDisablesRetry() throws Exception {
Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(0)
.build();
setupHttp(customHttp);

final List<Response> responses = new ArrayList<Response>();

Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer<Call>() {
@Override
public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
Call call = Mockito.mock(Call.class);

Response resp = new Response.Builder()
.protocol(Protocol.HTTP_2)
.code(429)
.request((Request) invocationOnMock.getArguments()[0])
.message("HTTP 429")
.build();
responses.add(resp);
Mockito.when(call.execute()).thenReturn(resp);

return call;
}
});

Response actualRes = http.executeHttpRequest();
assertEquals(1, responses.size());
assertEquals(429, actualRes.code());

// Verify no sleep was called
Mockito.verify(http, Mockito.never()).sleep(Mockito.any(Long.class));
}

@Test
public void testMaxBackoffCustomLimit() throws Exception {
Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(4000)
.build();
setupHttp(customHttp);

final List<Response> responses = new ArrayList<Response>();

Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer<Call>() {
@Override
public Call answer(InvocationOnMock invocationOnMock) throws Throwable {
Call call = Mockito.mock(Call.class);

Response resp = new Response.Builder()
.protocol(Protocol.HTTP_2)
.code(429)
.request((Request) invocationOnMock.getArguments()[0])
.message("HTTP 429")
.build();
responses.add(resp);
Mockito.when(call.execute()).thenReturn(resp);

return call;
}
});

// With maxBackoff=4000, retries at 1000, 2000, 4000, then 8000 > 4000 exits
// That's 4 total requests (1 initial + 3 retries)
Response actualRes = http.executeHttpRequest();
assertEquals(4, responses.size());
assertEquals(429, actualRes.code());

ArgumentCaptor<Long> sleepCapture = ArgumentCaptor.forClass(Long.class);
Mockito.verify(http, Mockito.times(3)).sleep(sleepCapture.capture());
List<Long> sleepTimes = sleepCapture.getAllValues();
assertEquals(1000L + RANDOM_INT, (long) sleepTimes.get(0));
assertEquals(2000L + RANDOM_INT, (long) sleepTimes.get(1));
assertEquals(4000L + RANDOM_INT, (long) sleepTimes.get(2));
}

@Test(expected = IllegalArgumentException.class)
public void testMaxBackoffNegativeThrows() {
new Http.HttpBuilder("GET", "example.test", "/foo/bar")
.useMaxBackoffMs(-1)
.build();
}
}