Skip to content

Commit 4b85273

Browse files
author
Bryan Donlan
committed
Restore KMS caching logic
We now verify that the requested region is reachable before caching the KMS client.
1 parent 0e15a35 commit 4b85273

File tree

10 files changed

+148
-22
lines changed

10 files changed

+148
-22
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
## 1.3.5
44

5-
(nothing yet)
5+
### Minor Changes
6+
7+
* Restored the KMS client cache with a fix for the memory leak.
68

79
## 1.3.4
810

src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
import java.util.Objects;
2727
import java.util.concurrent.ConcurrentHashMap;
2828

29+
import com.amazonaws.AmazonServiceException;
2930
import com.amazonaws.ClientConfiguration;
31+
import com.amazonaws.Request;
32+
import com.amazonaws.Response;
3033
import com.amazonaws.auth.AWSCredentials;
3134
import com.amazonaws.auth.AWSCredentialsProvider;
3235
import com.amazonaws.auth.AWSStaticCredentialsProvider;
@@ -71,12 +74,16 @@ public interface RegionalClientSupplier {
7174
AWSKMS getClient(String regionName);
7275
}
7376

74-
public static final class Builder implements Cloneable {
77+
public static class Builder implements Cloneable {
7578
private String defaultRegion_ = null;
7679
private RegionalClientSupplier regionalClientSupplier_ = null;
7780
private AWSKMSClientBuilder templateBuilder_ = null;
7881
private List<String> keyIds_ = new ArrayList<>();
7982

83+
Builder() {
84+
// Default access: Don't allow outside classes to extend this class
85+
}
86+
8087
public Builder clone() {
8188
try {
8289
Builder cloned = (Builder) super.clone();
@@ -259,11 +266,68 @@ private RegionalClientSupplier clientFactory() {
259266
AWSKMSClientBuilder builder = templateBuilder_ != null ? cloneClientBuilder(templateBuilder_)
260267
: AWSKMSClientBuilder.standard();
261268

269+
ConcurrentHashMap<String, AWSKMS> clientCache = new ConcurrentHashMap<>();
270+
snoopClientCache(clientCache);
271+
262272
return region -> {
263-
// Clone yet again as we're going to change the region field.
264-
return cloneClientBuilder(builder).withRegion(region).build();
273+
AWSKMS kms = clientCache.get(region);
274+
275+
if (kms != null) return kms;
276+
277+
// We can't just use computeIfAbsent as we need to avoid leaking KMS clients if we're asked to decrypt
278+
// an EDK with a bogus region in its ARN. So we'll install a request handler to identify the first
279+
// successful call, and cache it when we see that.
280+
SuccessfulRequestCacher cacher = new SuccessfulRequestCacher(clientCache, region);
281+
ArrayList<RequestHandler2> handlers = new ArrayList<>();
282+
if (builder.getRequestHandlers() != null) {
283+
handlers.addAll(builder.getRequestHandlers());
284+
}
285+
handlers.add(cacher);
286+
287+
kms = cloneClientBuilder(builder)
288+
.withRegion(region)
289+
.withRequestHandlers(handlers.toArray(new RequestHandler2[handlers.size()]))
290+
.build();
291+
cacher.client_ = kms;
292+
293+
return kms;
265294
};
266295
}
296+
297+
protected void snoopClientCache(ConcurrentHashMap<String, AWSKMS> map) {
298+
// no-op - this is a test hook
299+
}
300+
}
301+
302+
private static class SuccessfulRequestCacher extends RequestHandler2 {
303+
private final ConcurrentHashMap<String, AWSKMS> cache_;
304+
private final String region_;
305+
private AWSKMS client_;
306+
307+
volatile boolean ranBefore_ = false;
308+
309+
private SuccessfulRequestCacher(
310+
final ConcurrentHashMap<String, AWSKMS> cache,
311+
final String region
312+
) {
313+
this.region_ = region;
314+
this.cache_ = cache;
315+
}
316+
317+
@Override public void afterResponse(final Request<?> request, final Response<?> response) {
318+
if (ranBefore_) return;
319+
ranBefore_ = true;
320+
321+
cache_.putIfAbsent(region_, client_);
322+
}
323+
324+
@Override public void afterError(final Request<?> request, final Response<?> response, final Exception e) {
325+
if (ranBefore_) return;
326+
if (e instanceof AmazonServiceException) {
327+
ranBefore_ = true;
328+
cache_.putIfAbsent(region_, client_);
329+
}
330+
}
267331
}
268332

269333
public static Builder builder() {

src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import com.amazonaws.encryptionsdk.model.CipherFrameHeadersTest;
2424
import com.amazonaws.encryptionsdk.model.KeyBlobTest;
2525
import com.amazonaws.encryptionsdk.multi.MultipleMasterKeyTest;
26-
import com.amazonaws.services.kms.KMSProviderBuilderMockTests;
26+
import com.amazonaws.encryptionsdk.kms.KMSProviderBuilderMockTests;
2727

2828
@RunWith(Suite.class)
2929
@Suite.SuiteClasses({

src/test/java/com/amazonaws/encryptionsdk/IntegrationTestSuite.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import org.junit.runner.RunWith;
44
import org.junit.runners.Suite;
55

6-
import com.amazonaws.services.kms.KMSProviderBuilderIntegrationTests;
7-
import com.amazonaws.services.kms.XCompatKmsDecryptTest;
6+
import com.amazonaws.encryptionsdk.kms.KMSProviderBuilderIntegrationTests;
7+
import com.amazonaws.encryptionsdk.kms.XCompatKmsDecryptTest;
88

99
@RunWith(Suite.class)
1010
@Suite.SuiteClasses({

src/test/java/com/amazonaws/services/kms/KMSProviderBuilderIntegrationTests.java renamed to src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderIntegrationTests.java

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
package com.amazonaws.services.kms;
1+
package com.amazonaws.encryptionsdk.kms;
22

33
import static org.junit.Assert.assertEquals;
4+
import static org.junit.Assert.assertFalse;
5+
import static org.junit.Assert.assertNotNull;
46
import static org.junit.Assert.assertTrue;
57
import static org.junit.Assert.fail;
68
import static org.mockito.ArgumentMatchers.any;
@@ -10,29 +12,91 @@
1012
import static org.mockito.Mockito.spy;
1113
import static org.mockito.Mockito.verify;
1214

15+
import java.nio.charset.StandardCharsets;
1316
import java.util.Arrays;
17+
import java.util.Collections;
18+
import java.util.HashMap;
19+
import java.util.concurrent.ConcurrentHashMap;
20+
import java.util.concurrent.atomic.AtomicReference;
1421

1522
import org.junit.Test;
1623
import org.mockito.ArgumentCaptor;
1724

1825
import com.amazonaws.AbortedException;
19-
import com.amazonaws.AmazonWebServiceRequest;
2026
import com.amazonaws.ClientConfiguration;
2127
import com.amazonaws.Request;
2228
import com.amazonaws.auth.AWSCredentials;
2329
import com.amazonaws.auth.AWSCredentialsProvider;
2430
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
2531
import com.amazonaws.client.builder.AwsClientBuilder;
2632
import com.amazonaws.encryptionsdk.AwsCrypto;
33+
import com.amazonaws.encryptionsdk.CryptoAlgorithm;
2734
import com.amazonaws.encryptionsdk.CryptoResult;
2835
import com.amazonaws.encryptionsdk.MasterKeyProvider;
2936
import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
3037
import com.amazonaws.encryptionsdk.internal.VersionInfo;
31-
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
38+
import com.amazonaws.encryptionsdk.model.KeyBlob;
3239
import com.amazonaws.handlers.RequestHandler2;
3340
import com.amazonaws.http.exception.HttpRequestTimeoutException;
41+
import com.amazonaws.services.kms.AWSKMS;
42+
import com.amazonaws.services.kms.AWSKMSClientBuilder;
3443

3544
public class KMSProviderBuilderIntegrationTests {
45+
@Test
46+
public void whenBogusRegionsDecrypted_doesNotLeakClients() throws Exception {
47+
AtomicReference<ConcurrentHashMap<String, AWSKMS>> kmsCache = new AtomicReference<>();
48+
49+
KmsMasterKeyProvider mkp = (new KmsMasterKeyProvider.Builder() {
50+
@Override protected void snoopClientCache(
51+
final ConcurrentHashMap<String, AWSKMS> map
52+
) {
53+
kmsCache.set(map);
54+
}
55+
}).build();
56+
57+
try {
58+
mkp.decryptDataKey(
59+
CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256,
60+
Collections.singleton(
61+
new KeyBlob("aws-kms",
62+
"arn:aws:kms:us-bogus-1:123456789010:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"
63+
.getBytes(StandardCharsets.UTF_8),
64+
new byte[40]
65+
)
66+
),
67+
new HashMap<>()
68+
);
69+
fail("Expected CannotUnwrapDataKeyException");
70+
} catch (CannotUnwrapDataKeyException e) {
71+
// ok
72+
}
73+
74+
assertTrue(kmsCache.get().isEmpty());
75+
}
76+
77+
@Test
78+
public void whenOperationSuccessful_clientIsCached() {
79+
AtomicReference<ConcurrentHashMap<String, AWSKMS>> kmsCache = new AtomicReference<>();
80+
81+
KmsMasterKeyProvider mkp = (new KmsMasterKeyProvider.Builder() {
82+
@Override protected void snoopClientCache(
83+
final ConcurrentHashMap<String, AWSKMS> map
84+
) {
85+
kmsCache.set(map);
86+
}
87+
}).withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[0])
88+
.build();
89+
90+
new AwsCrypto().encryptData(mkp, new byte[1]);
91+
92+
AWSKMS kms = kmsCache.get().get("us-west-2");
93+
assertNotNull(kms);
94+
95+
new AwsCrypto().encryptData(mkp, new byte[1]);
96+
97+
// Cache entry should stay the same
98+
assertEquals(kms, kmsCache.get().get("us-west-2"));
99+
}
36100

37101
@Test
38102
public void whenConstructedWithoutArguments_canUseMultipleRegions() throws Exception {
@@ -75,7 +139,7 @@ public void whenHandlerConfigured_handlerIsInvoked() throws Exception {
75139
KmsMasterKeyProvider.builder()
76140
.withClientBuilder(
77141
AWSKMSClientBuilder.standard()
78-
.withRequestHandlers(handler)
142+
.withRequestHandlers(handler)
79143
)
80144
.withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[0])
81145
.build();

src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java renamed to src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderMockTests.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
package com.amazonaws.services.kms;
1+
package com.amazonaws.encryptionsdk.kms;
22

33
import static com.amazonaws.encryptionsdk.multi.MultipleProviderFactory.buildMultiProvider;
44
import static com.amazonaws.regions.Region.getRegion;
5-
import static com.amazonaws.regions.Regions.DEFAULT_REGION;
65
import static com.amazonaws.regions.Regions.fromName;
76
import static java.util.Collections.singletonList;
87
import static org.junit.Assert.assertEquals;
@@ -30,8 +29,6 @@
3029
import com.amazonaws.encryptionsdk.AwsCrypto;
3130
import com.amazonaws.encryptionsdk.MasterKeyProvider;
3231
import com.amazonaws.encryptionsdk.internal.VersionInfo;
33-
import com.amazonaws.encryptionsdk.kms.KmsMasterKey;
34-
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
3532
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier;
3633
import com.amazonaws.regions.Region;
3734
import com.amazonaws.regions.Regions;

src/test/java/com/amazonaws/services/kms/KMSTestFixtures.java renamed to src/test/java/com/amazonaws/encryptionsdk/kms/KMSTestFixtures.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.amazonaws.services.kms;
1+
package com.amazonaws.encryptionsdk.kms;
22

33
final class KMSTestFixtures {
44
private KMSTestFixtures() {

src/test/java/com/amazonaws/services/kms/LegacyKMSMasterKeyProviderTests.java renamed to src/test/java/com/amazonaws/encryptionsdk/kms/LegacyKMSMasterKeyProviderTests.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
* specific language governing permissions and limitations under the License.
1212
*/
1313

14-
package com.amazonaws.services.kms;
14+
package com.amazonaws.encryptionsdk.kms;
1515

1616
import static com.amazonaws.encryptionsdk.CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF;
1717
import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate;
@@ -34,11 +34,10 @@
3434
import com.amazonaws.encryptionsdk.MasterKeyProvider;
3535
import com.amazonaws.encryptionsdk.MasterKeyRequest;
3636
import com.amazonaws.encryptionsdk.jce.JceMasterKey;
37-
import com.amazonaws.encryptionsdk.kms.KmsMasterKey;
38-
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
3937
import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory;
4038
import com.amazonaws.regions.Region;
4139
import com.amazonaws.regions.Regions;
40+
import com.amazonaws.services.kms.AWSKMS;
4241

4342
public class LegacyKMSMasterKeyProviderTests {
4443
private static final String WRAPPING_ALG = "AES/GCM/NoPadding";

src/test/java/com/amazonaws/services/kms/MockKMSClient.java renamed to src/test/java/com/amazonaws/encryptionsdk/kms/MockKMSClient.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
* specific language governing permissions and limitations under the License.
1212
*/
1313

14-
package com.amazonaws.services.kms;
14+
package com.amazonaws.encryptionsdk.kms;
1515

1616
import java.nio.ByteBuffer;
1717
import java.security.SecureRandom;
@@ -29,6 +29,7 @@
2929
import com.amazonaws.ResponseMetadata;
3030
import com.amazonaws.regions.Region;
3131
import com.amazonaws.regions.Regions;
32+
import com.amazonaws.services.kms.AWSKMSClient;
3233
import com.amazonaws.services.kms.model.CreateAliasRequest;
3334
import com.amazonaws.services.kms.model.CreateAliasResult;
3435
import com.amazonaws.services.kms.model.CreateGrantRequest;

src/test/java/com/amazonaws/services/kms/XCompatKmsDecryptTest.java renamed to src/test/java/com/amazonaws/encryptionsdk/kms/XCompatKmsDecryptTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
* specific language governing permissions and limitations under the License.
1212
*/
1313

14-
package com.amazonaws.services.kms;
14+
package com.amazonaws.encryptionsdk.kms;
1515

1616
import static org.junit.Assert.assertArrayEquals;
1717

@@ -32,7 +32,6 @@
3232

3333
import com.amazonaws.encryptionsdk.AwsCrypto;
3434
import com.amazonaws.encryptionsdk.CryptoResult;
35-
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
3635
import com.fasterxml.jackson.core.type.TypeReference;
3736
import com.fasterxml.jackson.databind.ObjectMapper;
3837

0 commit comments

Comments
 (0)