Skip to content

Commit 2b1535f

Browse files
committed
TestcontainersBeanRegistrationAotProcessor that replaces InstanceSupplier of Container by a reflection equivalent
1 parent 4718485 commit 2b1535f

File tree

5 files changed

+218
-3
lines changed

5 files changed

+218
-3
lines changed

spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java

+90
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,26 @@
1818

1919
import java.lang.annotation.Retention;
2020
import java.lang.annotation.RetentionPolicy;
21+
import java.util.function.BiConsumer;
2122

2223
import org.junit.jupiter.api.AfterEach;
2324
import org.junit.jupiter.api.Test;
2425
import org.testcontainers.containers.Container;
2526
import org.testcontainers.containers.PostgreSQLContainer;
2627

28+
import org.springframework.aot.test.generate.TestGenerationContext;
2729
import org.springframework.boot.testcontainers.beans.TestcontainerBeanDefinition;
2830
import org.springframework.boot.testcontainers.context.ImportTestcontainers;
2931
import org.springframework.boot.testsupport.container.DisabledIfDockerUnavailable;
3032
import org.springframework.boot.testsupport.container.TestImage;
33+
import org.springframework.context.ApplicationContextInitializer;
3134
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
35+
import org.springframework.context.aot.ApplicationContextAotGenerator;
36+
import org.springframework.context.support.GenericApplicationContext;
37+
import org.springframework.core.test.tools.CompileWithForkedClassLoader;
38+
import org.springframework.core.test.tools.Compiled;
39+
import org.springframework.core.test.tools.TestCompiler;
40+
import org.springframework.javapoet.ClassName;
3241
import org.springframework.test.context.DynamicPropertyRegistry;
3342
import org.springframework.test.context.DynamicPropertySource;
3443

@@ -43,6 +52,8 @@
4352
@DisabledIfDockerUnavailable
4453
class ImportTestcontainersTests {
4554

55+
private final TestGenerationContext generationContext = new TestGenerationContext();
56+
4657
private AnnotationConfigApplicationContext applicationContext;
4758

4859
@AfterEach
@@ -122,6 +133,70 @@ void importWhenHasBadArgsDynamicPropertySourceMethod() {
122133
.withMessage("@DynamicPropertySource method 'containerProperties' must be static");
123134
}
124135

136+
@Test
137+
@CompileWithForkedClassLoader
138+
void importTestcontainersImportWithoutValueAotContributionRegistersTestcontainers() {
139+
this.applicationContext = new AnnotationConfigApplicationContext();
140+
this.applicationContext.register(ImportWithoutValue.class);
141+
compile((freshContext, compiled) -> {
142+
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
143+
assertThat(container).isSameAs(ImportWithoutValue.container);
144+
});
145+
}
146+
147+
@Test
148+
@CompileWithForkedClassLoader
149+
void importTestcontainersImportWithValueAotContributionRegistersTestcontainers() {
150+
this.applicationContext = new AnnotationConfigApplicationContext();
151+
this.applicationContext.register(ImportWithValue.class);
152+
compile((freshContext, compiled) -> {
153+
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
154+
assertThat(container).isSameAs(ContainerDefinitions.container);
155+
});
156+
}
157+
158+
@Test
159+
@CompileWithForkedClassLoader
160+
void importTestcontainersWithDynamicPropertySourceAotContributionRegistersTestcontainers() {
161+
this.applicationContext = new AnnotationConfigApplicationContext();
162+
this.applicationContext.register(ContainerDefinitionsWithDynamicPropertySource.class);
163+
compile((freshContext, compiled) -> {
164+
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
165+
assertThat(container).isSameAs(ContainerDefinitionsWithDynamicPropertySource.container);
166+
});
167+
}
168+
169+
@Test
170+
@CompileWithForkedClassLoader
171+
void importTestcontainersWithCustomPostgreSQLContainerAotContributionRegistersTestcontainers() {
172+
this.applicationContext = new AnnotationConfigApplicationContext();
173+
this.applicationContext.register(CustomPostgreSQLContainerDefinitions.class);
174+
compile((freshContext, compiled) -> {
175+
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
176+
assertThat(container).isSameAs(CustomPostgreSQLContainerDefinitions.container);
177+
});
178+
}
179+
180+
@SuppressWarnings("unchecked")
181+
private void compile(BiConsumer<GenericApplicationContext, Compiled> result) {
182+
ClassName className = processAheadOfTime();
183+
TestCompiler.forSystem().with(this.generationContext).compile((compiled) -> {
184+
GenericApplicationContext freshApplicationContext = new GenericApplicationContext();
185+
ApplicationContextInitializer<GenericApplicationContext> initializer = compiled
186+
.getInstance(ApplicationContextInitializer.class, className.toString());
187+
initializer.initialize(freshApplicationContext);
188+
freshApplicationContext.refresh();
189+
result.accept(freshApplicationContext, compiled);
190+
});
191+
}
192+
193+
private ClassName processAheadOfTime() {
194+
ClassName className = new ApplicationContextAotGenerator().processAheadOfTime(this.applicationContext,
195+
this.generationContext);
196+
this.generationContext.writeGeneratedContent();
197+
return className;
198+
}
199+
125200
@ImportTestcontainers
126201
static class ImportWithoutValue {
127202

@@ -196,4 +271,19 @@ void containerProperties() {
196271

197272
}
198273

274+
@ImportTestcontainers
275+
static class CustomPostgreSQLContainerDefinitions {
276+
277+
static CustomPostgreSQLContainer container = new CustomPostgreSQLContainer();
278+
279+
}
280+
281+
static class CustomPostgreSQLContainer extends PostgreSQLContainer<CustomPostgreSQLContainer> {
282+
283+
CustomPostgreSQLContainer() {
284+
super("postgres:14");
285+
}
286+
287+
}
288+
199289
}

spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2012-2023 the original author or authors.
2+
* Copyright 2012-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -38,9 +38,10 @@ class TestcontainerFieldBeanDefinition extends RootBeanDefinition implements Tes
3838
TestcontainerFieldBeanDefinition(Field field, Container<?> container) {
3939
this.container = container;
4040
this.annotations = MergedAnnotations.from(field);
41-
this.setBeanClass(container.getClass());
41+
setBeanClass(container.getClass());
4242
setInstanceSupplier(() -> container);
4343
setRole(ROLE_INFRASTRUCTURE);
44+
setAttribute(TestcontainerFieldBeanDefinition.class.getName(), field);
4445
}
4546

4647
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright 2012-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.boot.testcontainers.context;
18+
19+
import java.lang.reflect.Field;
20+
21+
import javax.lang.model.element.Modifier;
22+
23+
import org.testcontainers.containers.Container;
24+
25+
import org.springframework.aot.generate.GeneratedMethod;
26+
import org.springframework.aot.generate.GenerationContext;
27+
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
28+
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
29+
import org.springframework.beans.factory.aot.BeanRegistrationCode;
30+
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments;
31+
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator;
32+
import org.springframework.beans.factory.support.InstanceSupplier;
33+
import org.springframework.beans.factory.support.RegisteredBean;
34+
import org.springframework.beans.factory.support.RootBeanDefinition;
35+
import org.springframework.javapoet.ClassName;
36+
import org.springframework.javapoet.CodeBlock;
37+
import org.springframework.util.Assert;
38+
import org.springframework.util.ClassUtils;
39+
import org.springframework.util.ReflectionUtils;
40+
41+
/**
42+
* {@link BeanRegistrationAotProcessor} that replaces InstanceSupplier of
43+
* {@link Container} by a reflection equivalent.
44+
*
45+
* @author Dmytro Nosan
46+
*/
47+
class TestcontainersBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor {
48+
49+
@Override
50+
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
51+
RootBeanDefinition bd = registeredBean.getMergedBeanDefinition();
52+
String attributeName = TestcontainerFieldBeanDefinition.class.getName();
53+
Object field = bd.getAttribute(attributeName);
54+
if (field != null) {
55+
Assert.isInstanceOf(Field.class, field,
56+
"BeanDefinition attribute '" + attributeName + "' value must be a type of '" + Field.class + "'");
57+
return BeanRegistrationAotContribution.withCustomCodeFragments(
58+
(codeFragments) -> new AotContribution(codeFragments, registeredBean, ((Field) field)));
59+
}
60+
return null;
61+
}
62+
63+
static class AotContribution extends BeanRegistrationCodeFragmentsDecorator {
64+
65+
private final RegisteredBean registeredBean;
66+
67+
private final Field field;
68+
69+
AotContribution(BeanRegistrationCodeFragments delegate, RegisteredBean registeredBean, Field field) {
70+
super(delegate);
71+
this.registeredBean = registeredBean;
72+
this.field = field;
73+
}
74+
75+
@Override
76+
public ClassName getTarget(RegisteredBean registeredBean) {
77+
return ClassName.get(this.registeredBean.getBeanClass());
78+
}
79+
80+
@Override
81+
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
82+
BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) {
83+
Class<?> beanClass = this.registeredBean.getBeanClass();
84+
Class<?> testClass = this.field.getDeclaringClass();
85+
String fieldName = this.field.getName();
86+
GeneratedMethod generatedMethod = beanRegistrationCode.getMethods()
87+
.add("getInstance", (method) -> method
88+
.addJavadoc("Get the bean instance for '$L'.", this.registeredBean.getBeanName())
89+
.addModifiers(Modifier.PRIVATE, Modifier.STATIC)
90+
.returns(beanClass)
91+
.addStatement("$T<?> testClass = $T.forName($S, null)", Class.class, ClassUtils.class,
92+
testClass.getName())
93+
.addStatement("$T field = $T.findField(testClass, $S)", Field.class, ReflectionUtils.class,
94+
fieldName)
95+
.addStatement("$T.notNull(field, $S)", Assert.class, "Field '" + fieldName + "' is not found")
96+
.addStatement("$T.makeAccessible(field)", ReflectionUtils.class)
97+
.addStatement("$T container = ($T) $T.getField(field, null)", beanClass, beanClass,
98+
ReflectionUtils.class)
99+
.addStatement("$T.notNull(container, $S)", Assert.class,
100+
"Container field '" + fieldName + "' must not have a null value")
101+
.addStatement("return container")
102+
.addException(ClassNotFoundException.class));
103+
return CodeBlock.of("$T.using($T::$L)", InstanceSupplier.class, beanRegistrationCode.getClassName(),
104+
generatedMethod.getName());
105+
}
106+
107+
}
108+
109+
}

spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java

+11
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
import org.testcontainers.containers.Container;
2727

2828
import org.springframework.beans.BeansException;
29+
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
2930
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
3031
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
3132
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
33+
import org.springframework.beans.factory.support.RegisteredBean;
3234
import org.springframework.beans.factory.support.RootBeanDefinition;
3335
import org.springframework.context.ApplicationEventPublisher;
3436
import org.springframework.context.ApplicationEventPublisherAware;
@@ -166,4 +168,13 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
166168

167169
}
168170

171+
static class TestcontainersEventPublisherBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter {
172+
173+
@Override
174+
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
175+
return EventPublisherRegistrar.NAME.equals(registeredBean.getBeanName());
176+
}
177+
178+
}
179+
169180
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter=\
2-
org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter
2+
org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter,\
3+
org.springframework.boot.testcontainers.properties.TestcontainersPropertySource.TestcontainersEventPublisherBeanRegistrationExcludeFilter
34

45
org.springframework.aot.hint.RuntimeHintsRegistrar=\
56
org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory.ContainerConnectionDetailsFactoriesRuntimeHints
7+
8+
org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\
9+
org.springframework.boot.testcontainers.context.TestcontainersBeanRegistrationAotProcessor

0 commit comments

Comments
 (0)