Skip to content

Commit

Permalink
Add support for SSL bundles in SAML2 signing auto-configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
scottfrederick committed Jul 30, 2024
1 parent 40bb05d commit e136fc1
Show file tree
Hide file tree
Showing 9 changed files with 438 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public static class Credential {
* SSL bundle providing a private key used for signing and a Relying Party
* X509Certificate shared with the identity provider.
*/
private Bundle bundle;
private String bundle;

public Resource getPrivateKeyLocation() {
return this.privateKeyLocation;
Expand All @@ -196,11 +196,11 @@ public void setCertificateLocation(Resource certificate) {
this.certificateLocation = certificate;
}

public Bundle getBundle() {
public String getBundle() {
return this.bundle;
}

public void setBundle(Bundle bundle) {
public void setBundle(String bundle) {
this.bundle = bundle;
}

Expand Down Expand Up @@ -241,7 +241,7 @@ public static class Credential {
* SSL bundle providing a private key used for decrypting and a Relying Party
* X509Certificate shared with the identity provider.
*/
private Bundle bundle;
private String bundle;

public Resource getPrivateKeyLocation() {
return this.privateKeyLocation;
Expand All @@ -259,11 +259,11 @@ public void setCertificateLocation(Resource certificate) {
this.certificateLocation = certificate;
}

public Bundle getBundle() {
public String getBundle() {
return this.bundle;
}

public void setBundle(Bundle bundle) {
public void setBundle(String bundle) {
this.bundle = bundle;
}

Expand Down Expand Up @@ -400,7 +400,7 @@ public static class Credential {
* SSL bundle providing the X.509 certificate used for verification of
* incoming SAML messages.
*/
private Bundle bundle;
private String bundle;

public Resource getCertificateLocation() {
return this.certificate;
Expand All @@ -410,11 +410,11 @@ public void setCertificateLocation(Resource certificate) {
this.certificate = certificate;
}

public Bundle getBundle() {
public String getBundle() {
return this.bundle;
}

public void setBundle(Bundle bundle) {
public void setBundle(String bundle) {
this.bundle = bundle;
}

Expand Down Expand Up @@ -470,33 +470,4 @@ public void setBinding(Saml2MessageBinding binding) {

}

public static class Bundle {

/**
* Name of the SSL bundle.
*/
private String name;

/**
* Alias for the certificate to use from the SSL bundle.
*/
private String alias;

public String getName() {
return this.name;
}

public void setName(String name) {
this.name = name;
}

public String getAlias() {
return this.alias;
}

public void setAlias(String alias) {
this.alias = alias;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
package org.springframework.boot.autoconfigure.security.saml2;

import java.io.InputStream;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
Expand All @@ -34,14 +30,14 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.security.saml2.Saml2RelyingPartyProperties.AssertingParty;
import org.springframework.boot.autoconfigure.security.saml2.Saml2RelyingPartyProperties.AssertingParty.Verification;
import org.springframework.boot.autoconfigure.security.saml2.Saml2RelyingPartyProperties.Bundle;
import org.springframework.boot.autoconfigure.security.saml2.Saml2RelyingPartyProperties.Decryption;
import org.springframework.boot.autoconfigure.security.saml2.Saml2RelyingPartyProperties.Registration;
import org.springframework.boot.autoconfigure.security.saml2.Saml2RelyingPartyProperties.Registration.Signing.Credential;
import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundleKey;
import org.springframework.boot.ssl.SslBundleKeyStore;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.boot.ssl.SslStoreBundle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Conditional;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -163,9 +159,9 @@ private void validateSigningCredentials(Registration properties, boolean signReq
}

private Saml2X509Credential asSigningCredential(Credential properties, SslBundles sslBundles) {
Bundle sslBundle = properties.getBundle();
String sslBundle = properties.getBundle();
if (sslBundle != null) {
PrivateKey privateKey = getPrivateKey(sslBundle.getName(), sslBundles);
PrivateKey privateKey = getPrivateKey(sslBundle, sslBundles);
X509Certificate certificate = getCertificate(sslBundle, sslBundles);
return new Saml2X509Credential(privateKey, certificate, Saml2X509CredentialType.SIGNING);
}
Expand All @@ -175,9 +171,9 @@ private Saml2X509Credential asSigningCredential(Credential properties, SslBundle
}

private Saml2X509Credential asDecryptionCredential(Decryption.Credential properties, SslBundles sslBundles) {
Bundle sslBundle = properties.getBundle();
String sslBundle = properties.getBundle();
if (sslBundle != null) {
PrivateKey privateKey = getPrivateKey(sslBundle.getName(), sslBundles);
PrivateKey privateKey = getPrivateKey(sslBundle, sslBundles);
X509Certificate certificate = getCertificate(sslBundle, sslBundles);
return new Saml2X509Credential(privateKey, certificate, Saml2X509CredentialType.DECRYPTION);
}
Expand All @@ -187,7 +183,7 @@ private Saml2X509Credential asDecryptionCredential(Decryption.Credential propert
}

private Saml2X509Credential asVerificationCredential(Verification.Credential properties, SslBundles sslBundles) {
Bundle sslBundle = properties.getBundle();
String sslBundle = properties.getBundle();
if (sslBundle != null) {
X509Certificate certificate = getCertificate(sslBundle, sslBundles);
return new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION,
Expand Down Expand Up @@ -221,34 +217,19 @@ private X509Certificate readCertificate(Resource location) {
}

private PrivateKey getPrivateKey(String sslBundle, SslBundles sslBundles) {
try {
SslBundle bundle = sslBundles.getBundle(sslBundle);
SslBundleKey key = bundle.getKey();
KeyStore keyStore = bundle.getStores().getKeyStore();
Key privateKey = keyStore.getKey(key.getAlias(), key.getPassword().toCharArray());
Assert.notNull(privateKey,
"Private key with alias '" + key.getAlias() + "' was not found in SSL bundle '" + sslBundle + "'");
Assert.isInstanceOf(PrivateKey.class, privateKey);
return (PrivateKey) privateKey;
}
catch (GeneralSecurityException ex) {
throw new IllegalStateException("Error getting private key from SSL bundle '" + sslBundle + "'", ex);
}
SslBundle bundle = sslBundles.getBundle(sslBundle);
SslStoreBundle stores = bundle.getStores();
PrivateKey privateKey = SslBundleKeyStore.from(stores.getKeyStore(), bundle.getKey()).getPrivateKey();
Assert.notNull(privateKey, "KeyStore in SSL bundle '" + sslBundle + "' must have a private key");
return privateKey;
}

private X509Certificate getCertificate(Bundle sslBundle, SslBundles sslBundles) {
try {
SslBundle bundle = sslBundles.getBundle(sslBundle.getName());
KeyStore keyStore = bundle.getStores().getKeyStore();
Certificate certificate = keyStore.getCertificate(sslBundle.getAlias());
Assert.notNull(certificate, "Certificate with alias '" + sslBundle.getAlias()
+ "' was not found in SSL bundle '" + sslBundle + "'");
Assert.isInstanceOf(X509Certificate.class, certificate);
return (X509Certificate) certificate;
}
catch (GeneralSecurityException ex) {
throw new IllegalStateException("Error getting certificate from SSL bundle '" + sslBundle + "'", ex);
}
private X509Certificate getCertificate(String sslBundle, SslBundles sslBundles) {
SslBundle bundle = sslBundles.getBundle(sslBundle);
SslStoreBundle stores = bundle.getStores();
X509Certificate certificate = SslBundleKeyStore.from(stores.getKeyStore(), bundle.getKey()).getCertificate();
Assert.notNull(certificate, "KeyStore in SSL bundle '" + sslBundle + "' must have a certificate");
return certificate;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ class Saml2RelyingPartyAutoConfigurationTests {

private static final String PREFIX = "spring.security.saml2.relyingparty.registration";

private final WebApplicationContextRunner contextRunner = new WebApplicationContextRunner().withConfiguration(
AutoConfigurations.of(Saml2RelyingPartyAutoConfiguration.class, SecurityAutoConfiguration.class, SslAutoConfiguration.class));
private final WebApplicationContextRunner contextRunner = new WebApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(Saml2RelyingPartyAutoConfiguration.class,
SecurityAutoConfiguration.class, SslAutoConfiguration.class));

@Test
void autoConfigurationShouldBeConditionalOnRelyingPartyRegistrationRepositoryClass() {
Expand Down Expand Up @@ -366,19 +367,16 @@ private String[] getPropertyValuesWithSslBundles() {
"spring.ssl.bundle.pem.saml.key.password=secret1",
"spring.ssl.bundle.pem.saml.keystore.certificate=classpath:saml/certificate-location",
"spring.ssl.bundle.pem.saml.keystore.private-key=classpath:saml/private-key-location",
PREFIX + ".foo.signing.credentials[0].bundle.name=saml",
PREFIX + ".foo.signing.credentials[0].bundle.alias=key-alias",
PREFIX + ".foo.decryption.credentials[0].bundle.name=saml",
PREFIX + ".foo.decryption.credentials[0].bundle.alias=key-alias",
PREFIX + ".foo.signing.credentials[0].bundle=saml",
PREFIX + ".foo.decryption.credentials[0].bundle=saml",
PREFIX + ".foo.singlelogout.url=https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/SLOService.php",
PREFIX + ".foo.singlelogout.response-url=https://simplesaml-for-spring-saml.cfapps.io/",
PREFIX + ".foo.singlelogout.binding=post",
PREFIX + ".foo.assertingparty.singlesignon.url=https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/SSOService.php",
PREFIX + ".foo.assertingparty.singlesignon.binding=post",
PREFIX + ".foo.assertingparty.singlesignon.sign-request=false",
PREFIX + ".foo.assertingparty.entity-id=https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/metadata.php",
PREFIX + ".foo.assertingparty.verification.credentials[0].bundle.name=saml",
PREFIX + ".foo.assertingparty.verification.credentials[0].bundle.alias=key-alias",
PREFIX + ".foo.assertingparty.verification.credentials[0].bundle=saml",
PREFIX + ".foo.asserting-party.singlelogout.url=https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/SLOService.php",
PREFIX + ".foo.asserting-party.singlelogout.response-url=https://simplesaml-for-spring-saml.cfapps.io/",
PREFIX + ".foo.asserting-party.singlelogout.binding=post",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.ssl;

import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.PrivateKey;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Enumeration;

import org.springframework.util.Assert;

/**
* Provides access to private keys and certificates from a {@link KeyStore} created from
* an {@link SslBundle}.
*
* @author Scott Frederick
* @since 3.4.0
*/
public final class SslBundleKeyStore {

private final KeyStore keyStore;

private final SslBundleKey sslBundleKey;

private SslBundleKeyStore(KeyStore keyStore, SslBundleKey sslBundleKey) {
this.sslBundleKey = sslBundleKey;
this.keyStore = keyStore;
}

public PrivateKey getPrivateKey() {
String keyAlias = this.sslBundleKey.getAlias();
String keyPassword = this.sslBundleKey.getPassword();
try {
if (keyAlias != null) {
Key key = this.keyStore.getKey(keyAlias, keyPassword.toCharArray());
Assert.notNull(key, "Private key with alias '" + keyAlias + "' was not found in SSL bundle");
Assert.isInstanceOf(PrivateKey.class, key,
"Key with alias '" + keyAlias + "' was expected to be a PrivateKey");
return (PrivateKey) key;
}
return getFirstPrivateKey(this.keyStore, keyPassword);
}
catch (UnrecoverableKeyException kex) {
throw new IllegalArgumentException("Key with alias '" + keyAlias
+ "' could not be retrieved from the key store with the provided password", kex);
}
catch (GeneralSecurityException ex) {
throw new IllegalStateException("Error getting private key from SSL bundle", ex);
}
}

public X509Certificate getCertificate() {
String keyAlias = (this.sslBundleKey != null) ? this.sslBundleKey.getAlias() : null;
try {
if (keyAlias != null) {
Certificate certificate = this.keyStore.getCertificate(keyAlias);
Assert.notNull(certificate, "Certificate with alias '" + keyAlias + "' was not found in SSL bundle");
Assert.isInstanceOf(X509Certificate.class, certificate,
"Certificate with alias '" + keyAlias + "' was expected to be an X509Certificate");
return (X509Certificate) certificate;
}
return getFirstCertificate(this.keyStore);
}
catch (GeneralSecurityException ex) {
throw new IllegalStateException("Error getting X509 certificate from SSL bundle", ex);
}
}

private PrivateKey getFirstPrivateKey(KeyStore keyStore, String keyPassword) throws KeyStoreException {
Enumeration<String> aliases = keyStore.aliases();
while (aliases.hasMoreElements()) {
String alias = aliases.nextElement();
if (keyStore.isKeyEntry(alias)) {
try {
Key key = keyStore.getKey(alias, (keyPassword != null) ? keyPassword.toCharArray() : null);
if (key instanceof PrivateKey privateKey) {
return privateKey;
}
}
catch (GeneralSecurityException ex) {
// did not find the key matching the password, keep looking
}
}
}
return null;
}

private X509Certificate getFirstCertificate(KeyStore keyStore) throws KeyStoreException {
Enumeration<String> aliases = keyStore.aliases();
while (aliases.hasMoreElements()) {
String alias = aliases.nextElement();
Certificate certificate = keyStore.getCertificate(alias);
if (certificate instanceof X509Certificate) {
return (X509Certificate) certificate;
}
}
return null;
}

public static SslBundleKeyStore from(KeyStore keyStore, SslBundleKey sslBundleKey) {
return new SslBundleKeyStore(keyStore, sslBundleKey);
}

}
Loading

0 comments on commit e136fc1

Please sign in to comment.