Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 169 additions & 45 deletions msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/JwtHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
package com.microsoft.aad.msal4j;

import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.Signature;
import java.security.spec.MGF1ParameterSpec;
import java.security.spec.PSSParameterSpec;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
Expand All @@ -22,62 +25,183 @@ static ClientAssertion buildJwt(String clientId, final ClientCertificate credent
ParameterValidationUtils.validateNotNull("credential", clientId);

try {
final long time = System.currentTimeMillis();

// Build header
Map<String, Object> header = new HashMap<>();
header.put("alg", "RS256");
header.put("typ", "JWT");

if (sendX5c) {
List<String> certs = new ArrayList<>(credential.getEncodedPublicKeyCertificateChain());
header.put("x5c", certs);
}

//SHA-256 is preferred, however certain flows still require SHA-1 due to what is supported server-side. If SHA-256
// is not supported or the IClientCredential.publicCertificateHash256() method is not implemented, the library will default to SHA-1.
String hash256 = credential.publicCertificateHash256();
if (useSha1 || hash256 == null) {
header.put("x5t", credential.publicCertificateHash());
} else {
header.put("x5t#S256", hash256);
// First try with PS256 (preferred)
return generatePS256Jwt(clientId, credential, jwtAudience, sendX5c, useSha1);
} catch (InvalidKeyException e) {
// If the key isn't compatible with PSS, fall back to RS256.
// This is for backwards compatibility, as the Signature instance created with SHA256withRSA
// accepted key types that weren't RSAPrivateKey but the RSASSA-PSS signature does not.
try {
return generateRs256Jwt(clientId, credential, jwtAudience, sendX5c, useSha1);
} catch (Exception fallbackException) {
throw new MsalClientException(fallbackException);
}
} catch (Exception e) {
throw new MsalClientException(e);
}
}

// Build payload
Map<String, Object> payload = new HashMap<>();
payload.put("aud", jwtAudience);
payload.put("iss", clientId);
payload.put("jti", UUID.randomUUID().toString());
payload.put("nbf", time / 1000);
payload.put("exp", time / 1000 + Constants.AAD_JWT_TOKEN_LIFETIME_SECONDS);
payload.put("sub", clientId);
/**
* Generates a JWT signed using the PS256 algorithm (RSASSA-PSS with SHA-256).
*
* @param clientId The client ID to use as the issuer and subject
* @param credential The certificate credential used for signing
* @param jwtAudience The audience claim for the JWT
* @param sendX5c Whether to include the x5c header with certificate chain
* @param useSha1 Whether to use SHA-1 hash for thumbprint instead of SHA-256
* @return A ClientAssertion containing the signed JWT
* @throws Exception If JWT creation or signing fails
*/
private static ClientAssertion generatePS256Jwt(String clientId, ClientCertificate credential,
String jwtAudience, boolean sendX5c,
boolean useSha1) throws Exception {
// Build header with PS256 algorithm
Map<String, Object> header = createHeader(credential, sendX5c, useSha1, "PS256");

// Build payload
Map<String, Object> payload = createPayload(clientId, jwtAudience, System.currentTimeMillis());

// Encode header and payload
String jsonHeader = JsonHelper.writeJsonMap(header);
String jsonPayload = JsonHelper.writeJsonMap(payload);
String encodedHeader = base64UrlEncode(jsonHeader.getBytes(StandardCharsets.UTF_8));
String encodedPayload = base64UrlEncode(jsonPayload.getBytes(StandardCharsets.UTF_8));
String dataToSign = encodedHeader + "." + encodedPayload;

// Sign with PS256
byte[] signatureBytes = signWithPS256(credential, dataToSign);
String encodedSignature = base64UrlEncode(signatureBytes);

// Build the JWT
String jwt = dataToSign + "." + encodedSignature;
return new ClientAssertion(jwt);
}

// Concatenate header and payload
String jsonHeader = JsonHelper.writeJsonMap(header);
String jsonPayload = JsonHelper.writeJsonMap(payload);
/**
* Generates a JWT signed using the RS256 algorithm (RSASSA-PKCS1-v1_5 with SHA-256).
* This is used as a fallback when PS256 is not supported by the private key.
*
* @param clientId The client ID to use as the issuer and subject
* @param credential The certificate credential used for signing
* @param jwtAudience The audience claim for the JWT
* @param sendX5c Whether to include the x5c header with certificate chain
* @param useSha1 Whether to use SHA-1 hash for thumbprint instead of SHA-256
* @return A ClientAssertion containing the signed JWT
* @throws Exception If JWT creation or signing fails
*/
private static ClientAssertion generateRs256Jwt(String clientId, ClientCertificate credential,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seem to be a lot of code dup between thjis method and the ps256 version

String jwtAudience, boolean sendX5c,
boolean useSha1) throws Exception {
// Build header with RS256 algorithm
Map<String, Object> header = createHeader(credential, sendX5c, useSha1, "RS256");

// Build payload
Map<String, Object> payload = createPayload(clientId, jwtAudience, System.currentTimeMillis());

// Encode header and payload
String jsonHeader = JsonHelper.writeJsonMap(header);
String jsonPayload = JsonHelper.writeJsonMap(payload);
String encodedHeader = base64UrlEncode(jsonHeader.getBytes(StandardCharsets.UTF_8));
String encodedPayload = base64UrlEncode(jsonPayload.getBytes(StandardCharsets.UTF_8));
String dataToSign = encodedHeader + "." + encodedPayload;

// Sign with RS256
byte[] signatureBytes = signWithRS256(credential, dataToSign);
String encodedSignature = base64UrlEncode(signatureBytes);

// Build the JWT
String jwt = dataToSign + "." + encodedSignature;
return new ClientAssertion(jwt);
}

String encodedHeader = base64UrlEncode(jsonHeader.getBytes(StandardCharsets.UTF_8));
String encodedPayload = base64UrlEncode(jsonPayload.getBytes(StandardCharsets.UTF_8));
/**
* Creates the JWT header with the specified algorithm and certificate information.
*
* @param credential The certificate credential containing thumbprint and chain
* @param sendX5c Whether to include the x5c header with certificate chain
* @param useSha1 Whether to use SHA-1 hash for thumbprint instead of SHA-256
* @param algorithm The signing algorithm to specify in the header (PS256 or RS256)
* @return A map containing the JWT header claims
* @throws Exception If certificate operations fail
*/
private static Map<String, Object> createHeader(ClientCertificate credential, boolean sendX5c,
boolean useSha1, String algorithm) throws Exception {
Map<String, Object> header = new HashMap<>();
header.put("alg", algorithm);
header.put("typ", "JWT");

if (sendX5c) {
List<String> certs = new ArrayList<>(credential.getEncodedPublicKeyCertificateChain());
header.put("x5c", certs);
}

// Create signature
String dataToSign = encodedHeader + "." + encodedPayload;
// SHA-256 is preferred, however certain flows still require SHA-1
String hash256 = credential.publicCertificateHash256();
if (useSha1 || hash256 == null) {
header.put("x5t", credential.publicCertificateHash());
} else {
header.put("x5t#S256", hash256);
}

Signature sig = Signature.getInstance("SHA256withRSA");
sig.initSign(credential.privateKey());
sig.update(dataToSign.getBytes(StandardCharsets.UTF_8));
byte[] signatureBytes = sig.sign();
return header;
}

String encodedSignature = base64UrlEncode(signatureBytes);
/**
* Creates the JWT payload with standard claims.
*
* @param clientId The client ID to use as the issuer and subject
* @param audience The audience claim for the JWT
* @param time The current time in milliseconds
* @return A map containing the JWT payload claims
*/
private static Map<String, Object> createPayload(String clientId, String audience, long time) {
Map<String, Object> payload = new HashMap<>();
payload.put("aud", audience);
payload.put("iss", clientId);
payload.put("jti", UUID.randomUUID().toString());
payload.put("nbf", time / 1000);
payload.put("exp", time / 1000 + Constants.AAD_JWT_TOKEN_LIFETIME_SECONDS);
payload.put("sub", clientId);
return payload;
}

// Build the JWT
String jwt = dataToSign + "." + encodedSignature;
/**
* Signs data using the PS256 algorithm (RSASSA-PSS with SHA-256).
*
* @param credential The certificate credential containing the private key
* @param dataToSign The data to sign
* @return The signature bytes
* @throws Exception If signing fails
*/
private static byte[] signWithPS256(ClientCertificate credential, String dataToSign) throws Exception {
Signature sig = Signature.getInstance("RSASSA-PSS");
sig.setParameter(new PSSParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1));
sig.initSign(credential.privateKey());
sig.update(dataToSign.getBytes(StandardCharsets.UTF_8));
return sig.sign();
}

return new ClientAssertion(jwt);
} catch (final Exception e) {
throw new MsalClientException(e);
}
/**
* Signs data using the RS256 algorithm (RSASSA-PKCS1-v1_5 with SHA-256).
*
* @param credential The certificate credential containing the private key
* @param dataToSign The data to sign
* @return The signature bytes
* @throws Exception If signing fails
*/
private static byte[] signWithRS256(ClientCertificate credential, String dataToSign) throws Exception {
Signature sig = Signature.getInstance("SHA256withRSA");
sig.initSign(credential.privateKey());
sig.update(dataToSign.getBytes(StandardCharsets.UTF_8));
return sig.sign();
}

/**
* Encodes bytes using Base64URL encoding without padding.
*
* @param data The data to encode
* @return The Base64URL encoded string
*/
private static String base64UrlEncode(byte[] data) {
return Base64.getUrlEncoder().withoutPadding().encodeToString(data);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

package com.microsoft.aad.msal4j;

import com.nimbusds.jwt.SignedJWT;
import org.junit.jupiter.api.Test;

import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.cert.CertificateEncodingException;
import java.security.interfaces.RSAPrivateKey;
import java.util.*;

import static org.junit.jupiter.api.Assertions.*;
Expand Down Expand Up @@ -115,7 +118,7 @@ void JwtHelper_buildJwt_ValidSha1AndSha256Assertions() throws MsalClientExceptio

// Decode and verify headers
String headerJson = new String(Base64.getUrlDecoder().decode(jwtParts[0]));
assertTrue(headerJson.contains("\"alg\":\"RS256\""), "Header should specify RS256 algorithm");
assertTrue(headerJson.contains("\"alg\":\"PS256\""), "Header should specify RS256 algorithm");
assertTrue(headerJson.contains("\"typ\":\"JWT\""), "Header should specify JWT type");
assertTrue(headerJson.contains("\"x5t#S256\":\"certificateHash256\""), "Header should contain x5t#S256");
assertTrue(headerJson.contains("\"x5c\":[\"cert1\",\"cert2\"]"), "Header should contain x5c");
Expand Down Expand Up @@ -187,4 +190,99 @@ void JsonHelper_createIdTokenFromEncodedTokenString_InvalidJsonInToken() {

assertEquals(AuthenticationErrorCode.INVALID_JSON, exception.errorCode());
}

@Test
void JwtHelper_buildJwt_UsesPSS256WhenSupported() throws Exception {
Copy link
Member

@bgavrilMS bgavrilMS Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way it's implemented in MSAL .NET is:

  • for AAD, use SHA2 and PSS
  • for ADFS, use SHA1 and PKCS1

I see that the way to specify a certificate in MSAL Java is similar. So why not have some higher level tests, i.e. start from the public API and assert what gets put on the wire?

// Create a certificate mock with an RSAPrivateKey that supports PSS
RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) TestHelper.getPrivateKey();

ClientCertificate clientCertificateMock = mock(ClientCertificate.class);
when(clientCertificateMock.privateKey()).thenReturn(rsaPrivateKey);
when(clientCertificateMock.publicCertificateHash()).thenReturn("certificateHash");
when(clientCertificateMock.publicCertificateHash256()).thenReturn("certificateHash256");
when(clientCertificateMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2"));

String clientId = "clientId";
String audience = "https://login.microsoftonline.com/common/oauth2/v2.0/token";

// Create the JWT
ClientAssertion clientAssertion = JwtHelper.buildJwt(clientId, clientCertificateMock, audience, true, false);

assertNotNull(clientAssertion);
String jwt = clientAssertion.assertion();
String[] jwtParts = jwt.split("\\.");
assertEquals(3, jwtParts.length, "JWT should have three parts");

// Decode and verify header uses PS256
String headerJson = new String(Base64.getUrlDecoder().decode(jwtParts[0]));
assertTrue(headerJson.contains("\"alg\":\"PS256\""), "Header should specify PS256 algorithm");

// Parse the JWT to verify the algorithm is PS256
SignedJWT signedJWT = SignedJWT.parse(jwt);
assertEquals("PS256", signedJWT.getHeader().getAlgorithm().getName(), "JWT should use PS256 algorithm");
}

@Test
void JwtHelper_buildJwt_FallsBackToRS256WhenPSSNotSupported() throws Exception {
// When loaded from the Windows-MY keystore the PrivateKey will be a sun.security.mscapi.CPrivateKey,
// which for some reason works with the library's older RS256 signature but not the newer PSS signature.
PrivateKey nonRsaCompatibleKey = TestHelper.getPrivateKeyFromKeystore();

// This key should cause the PSS code to fail with an InvalidKeyException
ClientCertificate clientCertificateMock = mock(ClientCertificate.class);
when(clientCertificateMock.privateKey()).thenReturn(nonRsaCompatibleKey);
when(clientCertificateMock.publicCertificateHash()).thenReturn("certificateHash");
when(clientCertificateMock.publicCertificateHash256()).thenReturn("certificateHash256");
when(clientCertificateMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2"));

String clientId = "clientId";
String audience = "https://login.microsoftonline.com/common/oauth2/v2.0/token";

// Create the JWT - this should fallback to RS256
ClientAssertion clientAssertion = JwtHelper.buildJwt(clientId, clientCertificateMock, audience, true, false);

assertNotNull(clientAssertion);
String jwt = clientAssertion.assertion();
String[] jwtParts = jwt.split("\\.");
assertEquals(3, jwtParts.length, "JWT should have three parts");

// Decode and verify header uses RS256 as fallback
String headerJson = new String(Base64.getUrlDecoder().decode(jwtParts[0]));
assertTrue(headerJson.contains("\"alg\":\"RS256\""), "Header should specify RS256 algorithm as fallback");
}

@Test
void JwtHelper_buildJwt_UsesCorrectSignatureAlgorithmsBasedOnKeyType() throws Exception {
// Use real keys for both RSA and non-RSA tests
RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) TestHelper.getPrivateKey();
PrivateKey nonRsaPrivateKey = TestHelper.privateKeyFromKeystore;

ClientCertificate rsaCertMock = mock(ClientCertificate.class);
when(rsaCertMock.privateKey()).thenReturn(rsaPrivateKey);
when(rsaCertMock.publicCertificateHash256()).thenReturn("certHash256");
when(rsaCertMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2"));

ClientCertificate nonRsaCertMock = mock(ClientCertificate.class);
when(nonRsaCertMock.privateKey()).thenReturn(nonRsaPrivateKey);
when(nonRsaCertMock.publicCertificateHash256()).thenReturn("certHash256");
when(nonRsaCertMock.getEncodedPublicKeyCertificateChain()).thenReturn(Arrays.asList("cert1", "cert2"));

String clientId = "clientId";
String audience = "https://login.microsoftonline.com/common/oauth2/v2.0/token";

// Test RSA key -> should use PS256
ClientAssertion rsaAssertion = JwtHelper.buildJwt(clientId, rsaCertMock, audience, true, false);
String rsaJwt = rsaAssertion.assertion();
String rsaHeader = new String(Base64.getUrlDecoder().decode(rsaJwt.split("\\.")[0]));
assertTrue(rsaHeader.contains("\"alg\":\"PS256\""), "RSA key should produce PS256 algorithm");

// Test non-RSA key -> should fallback to RS256
ClientAssertion nonRsaAssertion = JwtHelper.buildJwt(clientId, nonRsaCertMock, audience, true, false);
String nonRsaJwt = nonRsaAssertion.assertion();
String nonRsaHeader = new String(Base64.getUrlDecoder().decode(nonRsaJwt.split("\\.")[0]));
assertTrue(nonRsaHeader.contains("\"alg\":\"RS256\""), "Non-RSA key should fallback to RS256 algorithm");

// Verify we're actually using different keys for the different tests
assertNotEquals(rsaJwt, nonRsaJwt, "The two assertions should be different");
}
}
Loading