diff --git a/vertx-auth-common/src/main/java/io/vertx/ext/auth/impl/jose/JWT.java b/vertx-auth-common/src/main/java/io/vertx/ext/auth/impl/jose/JWT.java index 61c3a7c21..9266a8cf9 100644 --- a/vertx-auth-common/src/main/java/io/vertx/ext/auth/impl/jose/JWT.java +++ b/vertx-auth-common/src/main/java/io/vertx/ext/auth/impl/jose/JWT.java @@ -44,8 +44,46 @@ public final class JWT { private static final Logger LOG = LoggerFactory.getLogger(JWT.class); - // simple random as its value is just to create entropy - private static final Random RND = new Random(); + /** + * Internal holder of keys and usage tracking + */ + private static class KeyRing implements Iterable { + private final List signers = new ArrayList<>(); + private int cnt = 0; + + int size() { + return signers.size(); + } + + JWS get(int pos) { + return signers.get(pos); + } + + JWS get() { + int size = signers.size(); + switch (size) { + case 0: + return null; + case 1: + return signers.get(0); + default: + return signers.get(cnt++ % size); + } + } + + JWS set(int pos, JWS jws) { + return signers.set(pos, jws); + } + + boolean add(JWS jws) { + return signers.add(jws); + } + + @Override + public Iterator iterator() { + return signers.iterator(); + } + } private static final Charset UTF8 = StandardCharsets.UTF_8; @@ -54,8 +92,8 @@ public final class JWT { private MessageDigest nonceDigest; // keep 2 maps (1 for sing, 1 for verify) this simplifies the lookups - private final Map> SIGN = new ConcurrentHashMap<>(); - private final Map> VERIFY = new ConcurrentHashMap<>(); + private final Map SIGN = new ConcurrentHashMap<>(); + private final Map VERIFY = new ConcurrentHashMap<>(); /** * Adds a JSON Web Key (rfc7517) to the signature maps. @@ -66,14 +104,14 @@ public final class JWT { public JWT addJWK(JWK jwk) { if (jwk.use() == null || "sig".equals(jwk.use())) { - List current; + KeyRing current; synchronized (this) { if (jwk.mac() != null || jwk.publicKey() != null) { - current = VERIFY.computeIfAbsent(jwk.getAlgorithm(), k -> new ArrayList<>()); + current = VERIFY.computeIfAbsent(jwk.getAlgorithm(), k -> new KeyRing()); addJWK(current, jwk); } if (jwk.mac() != null || jwk.privateKey() != null) { - current = SIGN.computeIfAbsent(jwk.getAlgorithm(), k -> new ArrayList<>()); + current = SIGN.computeIfAbsent(jwk.getAlgorithm(), k -> new KeyRing()); addJWK(current, jwk); } } @@ -129,13 +167,13 @@ public JWT nonceAlgorithm(String alg) { return this; } - private void addJWK(List current, JWK jwk) { + private void addJWK(KeyRing keyring, JWK jwk) { boolean replaced = false; - for (int i = 0; i < current.size(); i++) { - if (current.get(i).jwk().label().equals(jwk.label())) { + for (int i = 0; i < keyring.size(); i++) { + if (keyring.get(i).jwk().label().equals(jwk.label())) { // replace LOG.info("replacing JWK with label " + jwk.label()); - current.set(i, new JWS(jwk)); + keyring.set(i, new JWS(jwk)); replaced = true; break; } @@ -143,7 +181,7 @@ private void addJWK(List current, JWK jwk) { if (!replaced) { // non existent, add it! - current.add(new JWS(jwk)); + keyring.add(new JWS(jwk)); } } @@ -270,7 +308,7 @@ public JsonObject decode(final String token, boolean full, List crls) t // verify signature. `sign` will return base64 string. if (!unsecure) { - List signatures = VERIFY.get(alg); + KeyRing signatures = VERIFY.get(alg); if (signatures == null || signatures.size() == 0) { throw new NoSuchKeyIdException(alg); @@ -331,14 +369,14 @@ public String sign(JsonObject payload, JWTOptions options) { final String kid; if (!unsecure) { - List signatures = SIGN.get(algorithm); + KeyRing signatures = SIGN.get(algorithm); if (signatures == null || signatures.size() == 0) { throw new RuntimeException("Algorithm not supported/allowed: " + algorithm); } // lock the crypto implementation - jws = signatures.get(signatures.size() == 1 ? 0 : RND.nextInt(signatures.size())); + jws = signatures.get(); kid = jws.jwk().getId(); } else { jws = null;