-
Notifications
You must be signed in to change notification settings - Fork 153
Use recommended algorithm in assertions #990
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,10 @@ | |
package com.microsoft.aad.msal4j; | ||
|
||
import java.nio.charset.StandardCharsets; | ||
import java.security.InvalidKeyException; | ||
import java.security.Signature; | ||
import java.security.spec.MGF1ParameterSpec; | ||
import java.security.spec.PSSParameterSpec; | ||
import java.util.ArrayList; | ||
import java.util.Base64; | ||
import java.util.HashMap; | ||
|
@@ -22,62 +25,183 @@ static ClientAssertion buildJwt(String clientId, final ClientCertificate credent | |
ParameterValidationUtils.validateNotNull("credential", clientId); | ||
|
||
try { | ||
final long time = System.currentTimeMillis(); | ||
|
||
// Build header | ||
Map<String, Object> header = new HashMap<>(); | ||
header.put("alg", "RS256"); | ||
header.put("typ", "JWT"); | ||
|
||
if (sendX5c) { | ||
List<String> certs = new ArrayList<>(credential.getEncodedPublicKeyCertificateChain()); | ||
header.put("x5c", certs); | ||
} | ||
|
||
//SHA-256 is preferred, however certain flows still require SHA-1 due to what is supported server-side. If SHA-256 | ||
// is not supported or the IClientCredential.publicCertificateHash256() method is not implemented, the library will default to SHA-1. | ||
String hash256 = credential.publicCertificateHash256(); | ||
if (useSha1 || hash256 == null) { | ||
header.put("x5t", credential.publicCertificateHash()); | ||
} else { | ||
header.put("x5t#S256", hash256); | ||
// First try with PS256 (preferred) | ||
return generatePS256Jwt(clientId, credential, jwtAudience, sendX5c, useSha1); | ||
} catch (InvalidKeyException e) { | ||
// If the key isn't compatible with PSS, fall back to RS256. | ||
// This is for backwards compatibility, as the Signature instance created with SHA256withRSA | ||
// accepted key types that weren't RSAPrivateKey but the RSASSA-PSS signature does not. | ||
try { | ||
return generateRs256Jwt(clientId, credential, jwtAudience, sendX5c, useSha1); | ||
} catch (Exception fallbackException) { | ||
throw new MsalClientException(fallbackException); | ||
} | ||
} catch (Exception e) { | ||
throw new MsalClientException(e); | ||
} | ||
} | ||
|
||
// Build payload | ||
Map<String, Object> payload = new HashMap<>(); | ||
payload.put("aud", jwtAudience); | ||
payload.put("iss", clientId); | ||
payload.put("jti", UUID.randomUUID().toString()); | ||
payload.put("nbf", time / 1000); | ||
payload.put("exp", time / 1000 + Constants.AAD_JWT_TOKEN_LIFETIME_SECONDS); | ||
payload.put("sub", clientId); | ||
/** | ||
* Generates a JWT signed using the PS256 algorithm (RSASSA-PSS with SHA-256). | ||
* | ||
* @param clientId The client ID to use as the issuer and subject | ||
* @param credential The certificate credential used for signing | ||
* @param jwtAudience The audience claim for the JWT | ||
* @param sendX5c Whether to include the x5c header with certificate chain | ||
* @param useSha1 Whether to use SHA-1 hash for thumbprint instead of SHA-256 | ||
* @return A ClientAssertion containing the signed JWT | ||
* @throws Exception If JWT creation or signing fails | ||
*/ | ||
private static ClientAssertion generatePS256Jwt(String clientId, ClientCertificate credential, | ||
String jwtAudience, boolean sendX5c, | ||
boolean useSha1) throws Exception { | ||
// Build header with PS256 algorithm | ||
Map<String, Object> header = createHeader(credential, sendX5c, useSha1, "PS256"); | ||
|
||
// Build payload | ||
Map<String, Object> payload = createPayload(clientId, jwtAudience, System.currentTimeMillis()); | ||
|
||
// Encode header and payload | ||
String jsonHeader = JsonHelper.writeJsonMap(header); | ||
String jsonPayload = JsonHelper.writeJsonMap(payload); | ||
String encodedHeader = base64UrlEncode(jsonHeader.getBytes(StandardCharsets.UTF_8)); | ||
String encodedPayload = base64UrlEncode(jsonPayload.getBytes(StandardCharsets.UTF_8)); | ||
String dataToSign = encodedHeader + "." + encodedPayload; | ||
|
||
// Sign with PS256 | ||
byte[] signatureBytes = signWithPS256(credential, dataToSign); | ||
String encodedSignature = base64UrlEncode(signatureBytes); | ||
|
||
// Build the JWT | ||
String jwt = dataToSign + "." + encodedSignature; | ||
return new ClientAssertion(jwt); | ||
} | ||
|
||
// Concatenate header and payload | ||
String jsonHeader = JsonHelper.writeJsonMap(header); | ||
String jsonPayload = JsonHelper.writeJsonMap(payload); | ||
/** | ||
* Generates a JWT signed using the RS256 algorithm (RSASSA-PKCS1-v1_5 with SHA-256). | ||
* This is used as a fallback when PS256 is not supported by the private key. | ||
* | ||
* @param clientId The client ID to use as the issuer and subject | ||
* @param credential The certificate credential used for signing | ||
* @param jwtAudience The audience claim for the JWT | ||
* @param sendX5c Whether to include the x5c header with certificate chain | ||
* @param useSha1 Whether to use SHA-1 hash for thumbprint instead of SHA-256 | ||
* @return A ClientAssertion containing the signed JWT | ||
* @throws Exception If JWT creation or signing fails | ||
*/ | ||
private static ClientAssertion generateRs256Jwt(String clientId, ClientCertificate credential, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There seem to be a lot of code dup between thjis method and the ps256 version |
||
String jwtAudience, boolean sendX5c, | ||
boolean useSha1) throws Exception { | ||
// Build header with RS256 algorithm | ||
Map<String, Object> header = createHeader(credential, sendX5c, useSha1, "RS256"); | ||
|
||
// Build payload | ||
Map<String, Object> payload = createPayload(clientId, jwtAudience, System.currentTimeMillis()); | ||
|
||
// Encode header and payload | ||
String jsonHeader = JsonHelper.writeJsonMap(header); | ||
String jsonPayload = JsonHelper.writeJsonMap(payload); | ||
String encodedHeader = base64UrlEncode(jsonHeader.getBytes(StandardCharsets.UTF_8)); | ||
String encodedPayload = base64UrlEncode(jsonPayload.getBytes(StandardCharsets.UTF_8)); | ||
String dataToSign = encodedHeader + "." + encodedPayload; | ||
|
||
// Sign with RS256 | ||
byte[] signatureBytes = signWithRS256(credential, dataToSign); | ||
String encodedSignature = base64UrlEncode(signatureBytes); | ||
|
||
// Build the JWT | ||
String jwt = dataToSign + "." + encodedSignature; | ||
return new ClientAssertion(jwt); | ||
} | ||
|
||
String encodedHeader = base64UrlEncode(jsonHeader.getBytes(StandardCharsets.UTF_8)); | ||
String encodedPayload = base64UrlEncode(jsonPayload.getBytes(StandardCharsets.UTF_8)); | ||
/** | ||
* Creates the JWT header with the specified algorithm and certificate information. | ||
* | ||
* @param credential The certificate credential containing thumbprint and chain | ||
* @param sendX5c Whether to include the x5c header with certificate chain | ||
* @param useSha1 Whether to use SHA-1 hash for thumbprint instead of SHA-256 | ||
* @param algorithm The signing algorithm to specify in the header (PS256 or RS256) | ||
* @return A map containing the JWT header claims | ||
* @throws Exception If certificate operations fail | ||
*/ | ||
private static Map<String, Object> createHeader(ClientCertificate credential, boolean sendX5c, | ||
boolean useSha1, String algorithm) throws Exception { | ||
Map<String, Object> header = new HashMap<>(); | ||
header.put("alg", algorithm); | ||
header.put("typ", "JWT"); | ||
|
||
if (sendX5c) { | ||
List<String> certs = new ArrayList<>(credential.getEncodedPublicKeyCertificateChain()); | ||
header.put("x5c", certs); | ||
} | ||
|
||
// Create signature | ||
String dataToSign = encodedHeader + "." + encodedPayload; | ||
// SHA-256 is preferred, however certain flows still require SHA-1 | ||
String hash256 = credential.publicCertificateHash256(); | ||
if (useSha1 || hash256 == null) { | ||
header.put("x5t", credential.publicCertificateHash()); | ||
} else { | ||
header.put("x5t#S256", hash256); | ||
} | ||
|
||
Signature sig = Signature.getInstance("SHA256withRSA"); | ||
sig.initSign(credential.privateKey()); | ||
sig.update(dataToSign.getBytes(StandardCharsets.UTF_8)); | ||
byte[] signatureBytes = sig.sign(); | ||
return header; | ||
} | ||
|
||
String encodedSignature = base64UrlEncode(signatureBytes); | ||
/** | ||
* Creates the JWT payload with standard claims. | ||
* | ||
* @param clientId The client ID to use as the issuer and subject | ||
* @param audience The audience claim for the JWT | ||
* @param time The current time in milliseconds | ||
* @return A map containing the JWT payload claims | ||
*/ | ||
private static Map<String, Object> createPayload(String clientId, String audience, long time) { | ||
Map<String, Object> payload = new HashMap<>(); | ||
payload.put("aud", audience); | ||
payload.put("iss", clientId); | ||
payload.put("jti", UUID.randomUUID().toString()); | ||
payload.put("nbf", time / 1000); | ||
payload.put("exp", time / 1000 + Constants.AAD_JWT_TOKEN_LIFETIME_SECONDS); | ||
payload.put("sub", clientId); | ||
return payload; | ||
} | ||
|
||
// Build the JWT | ||
String jwt = dataToSign + "." + encodedSignature; | ||
/** | ||
* Signs data using the PS256 algorithm (RSASSA-PSS with SHA-256). | ||
* | ||
* @param credential The certificate credential containing the private key | ||
* @param dataToSign The data to sign | ||
* @return The signature bytes | ||
* @throws Exception If signing fails | ||
*/ | ||
private static byte[] signWithPS256(ClientCertificate credential, String dataToSign) throws Exception { | ||
Signature sig = Signature.getInstance("RSASSA-PSS"); | ||
sig.setParameter(new PSSParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1)); | ||
sig.initSign(credential.privateKey()); | ||
sig.update(dataToSign.getBytes(StandardCharsets.UTF_8)); | ||
return sig.sign(); | ||
} | ||
|
||
return new ClientAssertion(jwt); | ||
} catch (final Exception e) { | ||
throw new MsalClientException(e); | ||
} | ||
/** | ||
* Signs data using the RS256 algorithm (RSASSA-PKCS1-v1_5 with SHA-256). | ||
* | ||
* @param credential The certificate credential containing the private key | ||
* @param dataToSign The data to sign | ||
* @return The signature bytes | ||
* @throws Exception If signing fails | ||
*/ | ||
private static byte[] signWithRS256(ClientCertificate credential, String dataToSign) throws Exception { | ||
Signature sig = Signature.getInstance("SHA256withRSA"); | ||
sig.initSign(credential.privateKey()); | ||
sig.update(dataToSign.getBytes(StandardCharsets.UTF_8)); | ||
return sig.sign(); | ||
} | ||
|
||
/** | ||
* Encodes bytes using Base64URL encoding without padding. | ||
* | ||
* @param data The data to encode | ||
* @return The Base64URL encoded string | ||
*/ | ||
private static String base64UrlEncode(byte[] data) { | ||
return Base64.getUrlEncoder().withoutPadding().encodeToString(data); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,11 +3,14 @@ | |
|
||
package com.microsoft.aad.msal4j; | ||
|
||
import com.nimbusds.jwt.SignedJWT; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.nio.charset.StandardCharsets; | ||
import java.security.NoSuchAlgorithmException; | ||
import java.security.PrivateKey; | ||
import java.security.cert.CertificateEncodingException; | ||
import java.security.interfaces.RSAPrivateKey; | ||
import java.util.*; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
|
@@ -115,7 +118,7 @@ void JwtHelper_buildJwt_ValidSha1AndSha256Assertions() throws MsalClientExceptio | |
|
||
// Decode and verify headers | ||
String headerJson = new String(Base64.getUrlDecoder().decode(jwtParts[0])); | ||
assertTrue(headerJson.contains("\"alg\":\"RS256\""), "Header should specify RS256 algorithm"); | ||
assertTrue(headerJson.contains("\"alg\":\"PS256\""), "Header should specify RS256 algorithm"); | ||
assertTrue(headerJson.contains("\"typ\":\"JWT\""), "Header should specify JWT type"); | ||
assertTrue(headerJson.contains("\"x5t#S256\":\"certificateHash256\""), "Header should contain x5t#S256"); | ||
assertTrue(headerJson.contains("\"x5c\":[\"cert1\",\"cert2\"]"), "Header should contain x5c"); | ||
|
@@ -187,4 +190,99 @@ void JsonHelper_createIdTokenFromEncodedTokenString_InvalidJsonInToken() { | |
|
||
assertEquals(AuthenticationErrorCode.INVALID_JSON, exception.errorCode()); | ||
} | ||
|
||
@Test | ||
void JwtHelper_buildJwt_UsesPSS256WhenSupported() throws Exception { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The way it's implemented in MSAL .NET is:
I see that the way to specify a certificate in MSAL Java is similar. So why not have some higher level tests, i.e. start from the public API and assert what gets put on the wire? |
||
// Create a certificate mock with an RSAPrivateKey that supports PSS | ||
RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) TestHelper.getPrivateKey(); | ||
|
||
ClientCertificate clientCertificateMock = mock(ClientCertificate.class); | ||
when(clientCertificateMock.privateKey()).thenReturn(rsaPrivateKey); | ||
when(clientCertificateMock.publicCertificateHash()).thenReturn("certificateHash"); | ||
when(clientCertificateMock.publicCertificateHash256()).thenReturn("certificateHash256"); | ||
when(clientCertificateMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2")); | ||
|
||
String clientId = "clientId"; | ||
String audience = "https://login.microsoftonline.com/common/oauth2/v2.0/token"; | ||
|
||
// Create the JWT | ||
ClientAssertion clientAssertion = JwtHelper.buildJwt(clientId, clientCertificateMock, audience, true, false); | ||
|
||
assertNotNull(clientAssertion); | ||
String jwt = clientAssertion.assertion(); | ||
String[] jwtParts = jwt.split("\\."); | ||
assertEquals(3, jwtParts.length, "JWT should have three parts"); | ||
|
||
// Decode and verify header uses PS256 | ||
String headerJson = new String(Base64.getUrlDecoder().decode(jwtParts[0])); | ||
assertTrue(headerJson.contains("\"alg\":\"PS256\""), "Header should specify PS256 algorithm"); | ||
|
||
// Parse the JWT to verify the algorithm is PS256 | ||
SignedJWT signedJWT = SignedJWT.parse(jwt); | ||
assertEquals("PS256", signedJWT.getHeader().getAlgorithm().getName(), "JWT should use PS256 algorithm"); | ||
} | ||
|
||
@Test | ||
void JwtHelper_buildJwt_FallsBackToRS256WhenPSSNotSupported() throws Exception { | ||
// When loaded from the Windows-MY keystore the PrivateKey will be a sun.security.mscapi.CPrivateKey, | ||
// which for some reason works with the library's older RS256 signature but not the newer PSS signature. | ||
PrivateKey nonRsaCompatibleKey = TestHelper.getPrivateKeyFromKeystore(); | ||
|
||
// This key should cause the PSS code to fail with an InvalidKeyException | ||
ClientCertificate clientCertificateMock = mock(ClientCertificate.class); | ||
when(clientCertificateMock.privateKey()).thenReturn(nonRsaCompatibleKey); | ||
when(clientCertificateMock.publicCertificateHash()).thenReturn("certificateHash"); | ||
when(clientCertificateMock.publicCertificateHash256()).thenReturn("certificateHash256"); | ||
when(clientCertificateMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2")); | ||
|
||
String clientId = "clientId"; | ||
String audience = "https://login.microsoftonline.com/common/oauth2/v2.0/token"; | ||
|
||
// Create the JWT - this should fallback to RS256 | ||
ClientAssertion clientAssertion = JwtHelper.buildJwt(clientId, clientCertificateMock, audience, true, false); | ||
|
||
assertNotNull(clientAssertion); | ||
String jwt = clientAssertion.assertion(); | ||
String[] jwtParts = jwt.split("\\."); | ||
assertEquals(3, jwtParts.length, "JWT should have three parts"); | ||
|
||
// Decode and verify header uses RS256 as fallback | ||
String headerJson = new String(Base64.getUrlDecoder().decode(jwtParts[0])); | ||
assertTrue(headerJson.contains("\"alg\":\"RS256\""), "Header should specify RS256 algorithm as fallback"); | ||
} | ||
|
||
@Test | ||
void JwtHelper_buildJwt_UsesCorrectSignatureAlgorithmsBasedOnKeyType() throws Exception { | ||
// Use real keys for both RSA and non-RSA tests | ||
RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) TestHelper.getPrivateKey(); | ||
PrivateKey nonRsaPrivateKey = TestHelper.privateKeyFromKeystore; | ||
|
||
ClientCertificate rsaCertMock = mock(ClientCertificate.class); | ||
when(rsaCertMock.privateKey()).thenReturn(rsaPrivateKey); | ||
when(rsaCertMock.publicCertificateHash256()).thenReturn("certHash256"); | ||
when(rsaCertMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2")); | ||
|
||
ClientCertificate nonRsaCertMock = mock(ClientCertificate.class); | ||
when(nonRsaCertMock.privateKey()).thenReturn(nonRsaPrivateKey); | ||
when(nonRsaCertMock.publicCertificateHash256()).thenReturn("certHash256"); | ||
when(nonRsaCertMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2")); | ||
|
||
String clientId = "clientId"; | ||
String audience = "https://login.microsoftonline.com/common/oauth2/v2.0/token"; | ||
|
||
// Test RSA key -> should use PS256 | ||
ClientAssertion rsaAssertion = JwtHelper.buildJwt(clientId, rsaCertMock, audience, true, false); | ||
String rsaJwt = rsaAssertion.assertion(); | ||
String rsaHeader = new String(Base64.getUrlDecoder().decode(rsaJwt.split("\\.")[0])); | ||
assertTrue(rsaHeader.contains("\"alg\":\"PS256\""), "RSA key should produce PS256 algorithm"); | ||
|
||
// Test non-RSA key -> should fallback to RS256 | ||
ClientAssertion nonRsaAssertion = JwtHelper.buildJwt(clientId, nonRsaCertMock, audience, true, false); | ||
String nonRsaJwt = nonRsaAssertion.assertion(); | ||
String nonRsaHeader = new String(Base64.getUrlDecoder().decode(nonRsaJwt.split("\\.")[0])); | ||
assertTrue(nonRsaHeader.contains("\"alg\":\"RS256\""), "Non-RSA key should fallback to RS256 algorithm"); | ||
|
||
// Verify we're actually using different keys for the different tests | ||
assertNotEquals(rsaJwt, nonRsaJwt, "The two assertions should be different"); | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.