diff --git a/src/main/java/org/apache/xml/security/signature/XMLSignature.java b/src/main/java/org/apache/xml/security/signature/XMLSignature.java
index b2ec541e5..658bcaf37 100644
--- a/src/main/java/org/apache/xml/security/signature/XMLSignature.java
+++ b/src/main/java/org/apache/xml/security/signature/XMLSignature.java
@@ -684,11 +684,7 @@ private void setSignatureValueElement(byte[] bytes) {
signatureValueElement.removeChild(signatureValueElement.getFirstChild());
}
- String base64codedValue = XMLUtils.encodeToString(bytes);
-
- if (base64codedValue.length() > 76 && !XMLUtils.ignoreLineBreaks()) {
- base64codedValue = "\n" + base64codedValue + "\n";
- }
+ String base64codedValue = XMLUtils.encodeElementValue(bytes);
Text t = createText(base64codedValue);
signatureValueElement.appendChild(t);
diff --git a/src/main/java/org/apache/xml/security/stax/impl/processor/output/AbstractEncryptOutputProcessor.java b/src/main/java/org/apache/xml/security/stax/impl/processor/output/AbstractEncryptOutputProcessor.java
index efa2fa5a8..611ada923 100644
--- a/src/main/java/org/apache/xml/security/stax/impl/processor/output/AbstractEncryptOutputProcessor.java
+++ b/src/main/java/org/apache/xml/security/stax/impl/processor/output/AbstractEncryptOutputProcessor.java
@@ -40,7 +40,6 @@
import javax.xml.stream.XMLStreamConstants;
import javax.xml.stream.XMLStreamException;
-import org.apache.commons.codec.binary.Base64OutputStream;
import org.apache.xml.security.algorithms.JCEMapper;
import org.apache.xml.security.encryption.XMLCipherUtil;
import org.apache.xml.security.exceptions.XMLSecurityException;
@@ -175,12 +174,7 @@ public void init(OutputProcessorChain outputProcessorChain) throws XMLSecurityEx
symmetricCipher.init(Cipher.ENCRYPT_MODE, encryptionPartDef.getSymmetricKey(), parameterSpec);
characterEventGeneratorOutputStream = new CharacterEventGeneratorOutputStream();
- Base64OutputStream base64EncoderStream = null; //NOPMD
- if (XMLUtils.isIgnoreLineBreaks()) {
- base64EncoderStream = new Base64OutputStream(characterEventGeneratorOutputStream, true, 0, null);
- } else {
- base64EncoderStream = new Base64OutputStream(characterEventGeneratorOutputStream, true);
- }
+ OutputStream base64EncoderStream = XMLUtils.encodeStream(characterEventGeneratorOutputStream); //NOPMD
base64EncoderStream.write(iv);
OutputStream outputStream = new CipherOutputStream(base64EncoderStream, symmetricCipher); //NOPMD
diff --git a/src/main/java/org/apache/xml/security/utils/ElementProxy.java b/src/main/java/org/apache/xml/security/utils/ElementProxy.java
index 7e7828f2f..298fbbe01 100644
--- a/src/main/java/org/apache/xml/security/utils/ElementProxy.java
+++ b/src/main/java/org/apache/xml/security/utils/ElementProxy.java
@@ -313,9 +313,7 @@ public void addTextElement(String text, String localname) {
*/
public void addBase64Text(byte[] bytes) {
if (bytes != null) {
- Text t = XMLUtils.ignoreLineBreaks()
- ? createText(XMLUtils.encodeToString(bytes))
- : createText("\n" + XMLUtils.encodeToString(bytes) + "\n");
+ Text t = createText(XMLUtils.encodeElementValue(bytes));
appendSelf(t);
}
}
diff --git a/src/main/java/org/apache/xml/security/utils/XMLUtils.java b/src/main/java/org/apache/xml/security/utils/XMLUtils.java
index 9027469cd..a375044e5 100644
--- a/src/main/java/org/apache/xml/security/utils/XMLUtils.java
+++ b/src/main/java/org/apache/xml/security/utils/XMLUtils.java
@@ -56,14 +56,38 @@
/**
* DOM and XML accessibility and comfort functions.
*
+ * @implNote
+ * The following system properties affect XML formatting:
+ *
+ * - {@systemProperty org.apache.xml.security.ignoreLineBreaks} - ignores all line breaks,
+ * making a single-line document. Overrides all other formatting options. Default: false
+ * - {@systemProperty org.apache.xml.security.base64.ignoreLineBreaks} - ignores line breaks in base64Binary values.
+ * Takes precedence over line length and separator options (see below). Default: false
+ * - {@systemProperty org.apache.xml.security.base64.lineSeparator} - Sets the line separator sequence in base64Binary values.
+ * Possible values: crlf, lf. Default: crlf
+ * - {@systemProperty org.apache.xml.security.base64.lineLength} - Sets maximum line length in base64Binary values.
+ * The value is rounded down to the nearest multiple of 4. Values less than 4 are ignored. Default: 76
+ *
*/
public final class XMLUtils {
+ private static final Logger LOG = System.getLogger(XMLUtils.class.getName());
+
+ private static final String IGNORE_LINE_BREAKS_PROP = "org.apache.xml.security.ignoreLineBreaks";
+
private static boolean ignoreLineBreaks =
AccessController.doPrivileged(
- (PrivilegedAction) () -> Boolean.getBoolean("org.apache.xml.security.ignoreLineBreaks"));
+ (PrivilegedAction) () -> Boolean.getBoolean(IGNORE_LINE_BREAKS_PROP));
- private static final Logger LOG = System.getLogger(XMLUtils.class.getName());
+ private static Base64FormattingOptions base64Formatting =
+ AccessController.doPrivileged(
+ (PrivilegedAction) () -> new Base64FormattingOptions());
+
+ private static Base64.Encoder base64Encoder = (ignoreLineBreaks || base64Formatting.isIgnoreLineBreaks()) ?
+ Base64.getEncoder() :
+ Base64.getMimeEncoder(base64Formatting.getLineLength(), base64Formatting.getLineSeparator().getBytes());
+
+ private static Base64.Decoder base64Decoder = Base64.getMimeDecoder();
private static XMLParser xmlParserImpl =
AccessController.doPrivileged(
@@ -515,18 +539,48 @@ public static void addReturnBeforeChild(Element e, Node child) {
}
public static String encodeToString(byte[] bytes) {
- if (ignoreLineBreaks) {
- return Base64.getEncoder().encodeToString(bytes);
+ return base64Encoder.encodeToString(bytes);
+ }
+
+ /**
+ * Encodes bytes using Base64, with or without line breaks, depending on configuration (see {@link XMLUtils}).
+ * @param bytes Bytes to encode
+ * @return Base64 string
+ */
+ public static String encodeElementValue(byte[] bytes) {
+ String encoded = encodeToString(bytes);
+ if (!ignoreLineBreaks && !base64Formatting.isIgnoreLineBreaks()
+ && encoded.length() > base64Formatting.getLineLength()) {
+ encoded = "\n" + encoded + "\n";
}
- return Base64.getMimeEncoder().encodeToString(bytes);
+ return encoded;
+ }
+
+ /**
+ * Wraps output stream for Base64 encoding.
+ * Output data may contain line breaks or not, depending on configuration (see {@link XMLUtils})
+ * @param stream The underlying output stream to write Base64-encoded data
+ * @return Stream which writes binary data using Base64 encoder
+ */
+ public static OutputStream encodeStream(OutputStream stream) {
+ return base64Encoder.wrap(stream);
}
public static byte[] decode(String encodedString) {
- return Base64.getMimeDecoder().decode(encodedString);
+ return base64Decoder.decode(encodedString);
}
public static byte[] decode(byte[] encodedBytes) {
- return Base64.getMimeDecoder().decode(encodedBytes);
+ return base64Decoder.decode(encodedBytes);
+ }
+
+ /**
+ * Wraps input stream for Base64 decoding.
+ * @param stream Input stream with Base64-encoded data
+ * @return Input stream with decoded binary data
+ */
+ public static InputStream decodeStream(InputStream stream) {
+ return base64Decoder.wrap(stream);
}
public static boolean isIgnoreLineBreaks() {
@@ -1068,4 +1122,90 @@ public static byte[] getBytes(BigInteger big, int bitlen) {
return resizedBytes;
}
+
+ /**
+ * Aggregates formatting options for base64Binary values.
+ */
+ static class Base64FormattingOptions {
+ private static final String BASE64_IGNORE_LINE_BREAKS_PROP = "org.apache.xml.security.base64.ignoreLineBreaks";
+ private static final String BASE64_LINE_SEPARATOR_PROP = "org.apache.xml.security.base64.lineSeparator";
+ private static final String BASE64_LINE_LENGTH_PROP = "org.apache.xml.security.base64.lineLength";
+
+ private boolean ignoreLineBreaks = false;
+ private Base64LineSeparator lineSeparator = Base64LineSeparator.CRLF;
+ private int lineLength = 76;
+
+ /**
+ * Creates new formatting options by reading system properties.
+ */
+ Base64FormattingOptions() {
+ String ignoreLineBreaksProp = System.getProperty(BASE64_IGNORE_LINE_BREAKS_PROP);
+ ignoreLineBreaks = Boolean.parseBoolean(ignoreLineBreaksProp);
+ if (XMLUtils.ignoreLineBreaks && ignoreLineBreaksProp != null && !ignoreLineBreaks) {
+ LOG.log(Level.WARNING, "{0} property takes precedence over {1}, line breaks will be ignored",
+ IGNORE_LINE_BREAKS_PROP, BASE64_IGNORE_LINE_BREAKS_PROP);
+ }
+
+ String lineSeparatorProp = System.getProperty(BASE64_LINE_SEPARATOR_PROP);
+ if (lineSeparatorProp != null) {
+ try {
+ lineSeparator = Base64LineSeparator.valueOf(lineSeparatorProp.toUpperCase());
+ if (XMLUtils.ignoreLineBreaks || ignoreLineBreaks) {
+ LOG.log(Level.WARNING, "Property {0} has no effect since line breaks are ignored",
+ BASE64_LINE_SEPARATOR_PROP);
+ }
+ } catch (IllegalArgumentException e) {
+ LOG.log(Level.WARNING, "Illegal value of {0} property is ignored: {1}",
+ BASE64_LINE_SEPARATOR_PROP, lineSeparatorProp);
+ }
+ }
+
+ String lineLengthProp = System.getProperty(BASE64_LINE_LENGTH_PROP);
+ if (lineLengthProp != null) {
+ try {
+ int lineLength = Integer.parseInt(lineLengthProp);
+ if (lineLength >= 4) {
+ this.lineLength = lineLength;
+ if (XMLUtils.ignoreLineBreaks || ignoreLineBreaks) {
+ LOG.log(Level.WARNING, "Property {0} has no effect since line breaks are ignored",
+ BASE64_LINE_LENGTH_PROP);
+ }
+ } else {
+ LOG.log(Level.WARNING, "Illegal value of {0} property is ignored: {1}",
+ BASE64_LINE_LENGTH_PROP, lineLengthProp);
+ }
+ } catch (NumberFormatException e) {
+ LOG.log(Level.WARNING, "Illegal value of {0} property is ignored: {1}",
+ BASE64_LINE_LENGTH_PROP, lineLengthProp);
+ }
+ }
+ }
+
+ public boolean isIgnoreLineBreaks() {
+ return ignoreLineBreaks;
+ }
+
+ public Base64LineSeparator getLineSeparator() {
+ return lineSeparator;
+ }
+
+ public int getLineLength() {
+ return lineLength;
+ }
+ }
+
+ enum Base64LineSeparator {
+ CRLF(new byte[]{'\r', '\n'}),
+ LF(new byte[]{'\n'});
+
+ private byte[] bytes;
+
+ Base64LineSeparator(byte[] bytes) {
+ this.bytes = bytes;
+ }
+
+ byte[] getBytes() {
+ return bytes;
+ }
+ }
}
diff --git a/src/test/java/org/apache/xml/security/utils/XMLUtilsTest.java b/src/test/java/org/apache/xml/security/utils/XMLUtilsTest.java
new file mode 100644
index 000000000..c852ebd6c
--- /dev/null
+++ b/src/test/java/org/apache/xml/security/utils/XMLUtilsTest.java
@@ -0,0 +1,318 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.xml.security.utils;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.*;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.*;
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * This test checks {@link XMLUtils} class methods, responsible for Base64 values formatting in XML documents.
+ * Since it is a utility class with static methods, and it is configured with system properties,
+ * we need to reload the class after system properties are set in each test case.
+ * This test uses a special implementation of {@link ClassLoader} to achieve this and calls {@code XMLUtils} methods
+ * using reflection.
+ *
+ * There are three methods producing Base64-encoded data in {@code XMLUtils}:
+ *
+ * - {@link XMLUtils#encodeToString(byte[])}
+ * - {@link XMLUtils#encodeElementValue(byte[])}
+ * - {@link XMLUtils#encodeStream(OutputStream)}
(creates a wrapper stream, which applies the same encoding
+ * as {@code encodeToString(byte[])})
+ *
+ * In the tests, formatting of the outputs of these methods is checked.
+ */
+public class XMLUtilsTest {
+ private static final byte[] data = new byte[60]; // long enough for a line break in MIME encoding
+
+ private Properties backup;
+ private ClassLoader classLoader;
+
+ @BeforeEach
+ public void createClassLoader() {
+ /* create custom classloader to reload XMLUtils class and its nested classes in each test */
+ ClassLoader parent = getClass().getClassLoader();
+ Collection> classesToReload = List.of(
+ XMLUtils.class,
+ XMLUtils.Base64FormattingOptions.class,
+ XMLUtils.Base64LineSeparator.class
+ );
+ classLoader = new ReloadingClassLoader(parent, classesToReload);
+
+ /*
+ * XMLUtils instantiates XMLParserImpl, but its package is not exported,
+ * thus unavailable for the new classloader.
+ */
+ ModuleLayer.boot().findModule("org.apache.santuario.xmlsec").orElseThrow()
+ .addOpens("org.apache.xml.security.parser", classLoader.getUnnamedModule());
+ }
+
+ @BeforeEach
+ public void backupProperties() {
+ backup = new Properties();
+ backup.putAll(System.getProperties());
+ }
+
+ @AfterEach
+ public void restoreProperties() {
+ System.setProperties(backup);
+ }
+
+ @Test
+ public void testAllPropertiesUnset() throws ReflectiveOperationException, IOException {
+ System.clearProperty("org.apache.xml.security.ignoreLineBreaks");
+ System.clearProperty("org.apache.xml.security.base64.ignoreLineBreaks");
+ System.clearProperty("org.apache.xml.security.base64.lineSeparator");
+ System.clearProperty("org.apache.xml.security.base64.lineLength");
+
+ Class> xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName());
+ String encoded = encodeToString(xmlUtilsClass, data);
+ String elementValue = encodeElementValue(xmlUtilsClass, data);
+ String encodedWithStream = encodeUsingStream(xmlUtilsClass, data);
+
+ assertThat(encoded, containsString("\r\n"));
+ OptionalInt maxLineLength = Arrays.stream(encoded.split("\r\n")).mapToInt(String::length).max();
+ assertTrue(maxLineLength.isPresent());
+ assertEquals(76, maxLineLength.getAsInt());
+
+ assertThat(elementValue, containsString(encoded));
+ assertThat(elementValue, startsWith("\n"));
+ assertThat(elementValue, endsWith("\n"));
+
+ assertEquals(encoded, encodedWithStream);
+ }
+
+ @Test
+ public void testIgnoreLineBreaksSet() throws ReflectiveOperationException, IOException {
+ System.setProperty("org.apache.xml.security.ignoreLineBreaks", "true");
+ System.clearProperty("org.apache.xml.security.base64.ignoreLineBreaks");
+ System.clearProperty("org.apache.xml.security.base64.lineSeparator");
+ System.clearProperty("org.apache.xml.security.base64.lineLength");
+
+ Class> xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName());
+ String encoded = encodeToString(xmlUtilsClass, data);
+ String elementValue = encodeElementValue(xmlUtilsClass, data);
+ String encodedWithStream = encodeUsingStream(xmlUtilsClass, data);
+
+ assertThat(encoded, not(containsString("\r\n")));
+ assertThat(encoded, not(containsString("\n")));
+ assertThat(elementValue, not(containsString("\r\n")));
+ assertThat(elementValue, not(containsString("\n")));
+
+ assertEquals(encoded, encodedWithStream);
+ }
+
+ @Test
+ public void testIgnoreLineBreaksTakesPrecedence() throws ReflectiveOperationException, IOException {
+ System.setProperty("org.apache.xml.security.ignoreLineBreaks", "true");
+ System.setProperty("org.apache.xml.security.base64.ignoreLineBreaks", "false");
+ System.setProperty("org.apache.xml.security.base64.lineSeparator", "crlf");
+ System.setProperty("org.apache.xml.security.base64.lineLength", "40");
+
+ Class> xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName());
+ String encoded = encodeToString(xmlUtilsClass, data);
+ String elementValue = encodeElementValue(xmlUtilsClass, data);
+ String encodedWithStream = encodeUsingStream(xmlUtilsClass, data);
+
+ assertThat(encoded, not(containsString("\r\n")));
+ assertThat(encoded, not(containsString("\n")));
+ assertThat(elementValue, not(containsString("\r\n")));
+ assertThat(elementValue, not(containsString("\n")));
+
+ assertEquals(encoded, encodedWithStream);
+ }
+
+ @Test
+ public void testBase64IgnoreLineBreaksSet() throws ReflectiveOperationException, IOException {
+ System.clearProperty("org.apache.xml.security.ignoreLineBreaks");
+ System.setProperty("org.apache.xml.security.base64.ignoreLineBreaks", "true");
+ System.clearProperty("org.apache.xml.security.base64.lineSeparator");
+ System.clearProperty("org.apache.xml.security.base64.lineLength");
+
+ Class> xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName());
+ String encoded = encodeToString(xmlUtilsClass, data);
+ String elementValue = encodeElementValue(xmlUtilsClass, data);
+ String encodedWithStream = encodeUsingStream(xmlUtilsClass, data);
+
+ assertThat(encoded, not(containsString("\r\n")));
+ assertThat(encoded, not(containsString("\n")));
+ assertThat(elementValue, not(containsString("\r\n")));
+ assertThat(elementValue, not(containsString("\n")));
+
+ assertEquals(encoded, encodedWithStream);
+ }
+
+ @Test
+ public void testBase64IgnoreLineBreaksTakesPrecedence() throws ReflectiveOperationException, IOException {
+ System.clearProperty("org.apache.xml.security.ignoreLineBreaks");
+ System.setProperty("org.apache.xml.security.base64.ignoreLineBreaks", "true");
+ System.setProperty("org.apache.xml.security.base64.lineSeparator", "crlf");
+ System.setProperty("org.apache.xml.security.base64.lineLength", "40");
+
+ Class> xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName());
+ String encoded = encodeToString(xmlUtilsClass, data);
+ String elementValue = encodeElementValue(xmlUtilsClass, data);
+ String encodedWithStream = encodeUsingStream(xmlUtilsClass, data);
+
+ assertThat(encoded, not(containsString("\r\n")));
+ assertThat(encoded, not(containsString("\n")));
+ assertThat(elementValue, not(containsString("\r\n")));
+ assertThat(elementValue, not(containsString("\n")));
+
+ assertEquals(encoded, encodedWithStream);
+ }
+
+ @Test
+ public void testBase64CustomFormatting() throws ReflectiveOperationException, IOException {
+ System.clearProperty("org.apache.xml.security.ignoreLineBreaks");
+ System.clearProperty("org.apache.xml.security.base64.ignoreLineBreaks");
+ System.setProperty("org.apache.xml.security.base64.lineSeparator", "lf");
+ System.setProperty("org.apache.xml.security.base64.lineLength", "40");
+
+ Class> xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName());
+ String encoded = encodeToString(xmlUtilsClass, data);
+ String elementValue = encodeElementValue(xmlUtilsClass, data);
+ String encodedWithStream = encodeUsingStream(xmlUtilsClass, data);
+
+ assertThat(encoded, not(containsString("\r\n")));
+ assertThat(encoded, containsString("\n"));
+ OptionalInt maxLineLength = Arrays.stream(encoded.split("\n")).mapToInt(String::length).max();
+ assertTrue(maxLineLength.isPresent());
+ assertEquals(40, maxLineLength.getAsInt());
+
+ assertThat(elementValue, containsString(encoded));
+ assertThat(elementValue, startsWith("\n"));
+ assertThat(elementValue, endsWith("\n"));
+
+ assertEquals(encoded, encodedWithStream);
+ }
+
+ @Test
+ public void testIllegalPropertiesAreIgnored() throws ReflectiveOperationException, IOException {
+ System.setProperty("org.apache.xml.security.ignoreLineBreaks", "illegal");
+ System.setProperty("org.apache.xml.security.base64.ignoreLineBreaks", "illegal");
+ System.setProperty("org.apache.xml.security.base64.lineSeparator", "illegal");
+ System.setProperty("org.apache.xml.security.base64.lineLength", "illegal");
+
+ Class> xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName());
+ String encoded = encodeToString(xmlUtilsClass, data);
+ String elementValue = encodeElementValue(xmlUtilsClass, data);
+ String encodedWithStream = encodeUsingStream(xmlUtilsClass, data);
+
+ assertThat(encoded, containsString("\r\n"));
+ OptionalInt maxLineLength = Arrays.stream(encoded.split("\r\n")).mapToInt(String::length).max();
+ assertTrue(maxLineLength.isPresent());
+ assertEquals(76, maxLineLength.getAsInt());
+
+ assertThat(elementValue, containsString(encoded));
+ assertThat(elementValue, startsWith("\n"));
+ assertThat(elementValue, endsWith("\n"));
+
+ assertEquals(encoded, encodedWithStream);
+ }
+
+ private String encodeToString(Class> xmlUtilsClass, byte[] bytes) throws ReflectiveOperationException {
+ return (String) xmlUtilsClass.getMethod("encodeToString", byte[].class).invoke(null, (Object) bytes);
+ }
+
+ private String encodeElementValue(Class> xmlUtilsClass, byte[] bytes) throws ReflectiveOperationException {
+ return (String) xmlUtilsClass.getMethod("encodeElementValue", byte[].class).invoke(null, (Object) bytes);
+ }
+
+ private OutputStream encodeStream(Class> xmlUtilsClass, OutputStream stream) throws ReflectiveOperationException {
+ return (OutputStream) xmlUtilsClass.getMethod("encodeStream", OutputStream.class).invoke(null, stream);
+ }
+
+ private String encodeUsingStream(Class> xmlUtilsClass, byte[] bytes) throws ReflectiveOperationException, IOException {
+ try (ByteArrayOutputStream encoded = new ByteArrayOutputStream();
+ OutputStream raw = encodeStream(xmlUtilsClass, encoded)) {
+ raw.write(bytes);
+ raw.flush();
+ return encoded.toString(StandardCharsets.US_ASCII);
+ }
+ }
+
+ /**
+ * This implementation of {@code ClassLoader} reloads given classes from bytecode,
+ * even if they are already loaded by the parent class loader.
+ */
+ private static class ReloadingClassLoader extends ClassLoader {
+ private Collection classNames;
+
+ /**
+ * Creates new class loader.
+ * @param parent Parent class loader.
+ * @param classes Set of classes to be forcefully reloaded
+ */
+ private ReloadingClassLoader(ClassLoader parent, Collection> classes) {
+ super("TestClassLoader", parent);
+ this.classNames = classes.stream().map(Class::getName).collect(Collectors.toSet());
+ }
+
+ @Override
+ protected Class> loadClass(String name, boolean resolve) throws ClassNotFoundException {
+ if (classNames.contains(name)) {
+ Class> clazz = findClass(name);
+ if (resolve) {
+ resolveClass(clazz);
+ }
+ return clazz;
+ }
+ return super.loadClass(name, resolve);
+ }
+
+ @Override
+ protected Class> findClass(String name) throws ClassNotFoundException {
+ if (classNames.contains(name)) {
+ Class> parentLoadedClass = getParent().loadClass(name);
+ String resourceName = synthesizeClassName(parentLoadedClass) + ".class";
+ byte[] classData;
+ try (InputStream in = parentLoadedClass.getResourceAsStream(resourceName)) {
+ if (in == null) {
+ throw new ClassNotFoundException("Could not load class " + name);
+ }
+ classData = in.readAllBytes();
+ } catch (IOException e) {
+ throw new ClassNotFoundException("Could not load class " + name, e);
+ }
+
+ return defineClass(name, classData, 0, classData.length);
+ }
+ throw new ClassNotFoundException("Class not found: " + name);
+ }
+
+ private String synthesizeClassName(Class> clazz) {
+ String name = clazz.getSimpleName();
+ if (clazz.isMemberClass()) name = synthesizeClassName(clazz.getEnclosingClass()) + "$" + name;
+ return name;
+ }
+ }
+}