Skip to content

Commit c44ccdc

Browse files
david-perezHarry Barber
authored and
Harry Barber
committed
Fix recursive constraint violations with paths over list and map shapes (#2371)
* Fix recursive constraint violations with paths over list and map shapes There is a widespread assumption throughout the generation of constraint violations that does not hold true all the time, namely, that a recursive constraint violation graph has the same requirements with regards to boxing as the regular shape graph. Some types corresponding to recursive shapes are boxed to introduce indirection and thus not generate an infinitely recursive type. The algorithm however does not superfluously introduce boxes when the cycle goes through a list shape or a map shape. Why list shapes and map shapes? List shapes and map shapes get rendered in Rust as `Vec<T>` and `HashMap<K, V>`, respectively, they're the only Smithy shapes that "organically" introduce indirection (via a pointer to the heap) in the recursive path. For other recursive paths, we thus have to introduce the indirection artificially ourselves using `Box`. This is done in the `RecursiveShapeBoxer` model transform. However, the constraint violation graph needs to box types in recursive paths more often. Since we don't collect constraint violations (yet, see #2040), the constraint violation graph never holds `Vec<T>`s or `HashMap<K, V>`s, only simple types. Indeed, the following simple recursive model: ```smithy union Recursive { list: List } @Length(min: 69) list List { member: Recursive } ``` has a cycle that goes through a list shape, so no shapes in it need boxing in the regular shape graph. However, the constraint violation graph is infinitely recursive if we don't introduce boxing somewhere: ```rust pub mod model { pub mod list { pub enum ConstraintViolation { Length(usize), Member( usize, crate::model::recursive::ConstraintViolation, ), } } pub mod recursive { pub enum ConstraintViolation { List(crate::model::list::ConstraintViolation), } } } ``` This commit fixes things by making the `RecursiveShapeBoxer` model transform configurable so that the "cycles through lists and maps introduce indirection" assumption can be lifted. This allows a server model transform, `RecursiveConstraintViolationBoxer`, to tag member shapes along recursive paths with a new trait, `ConstraintViolationRustBoxTrait`, that the constraint violation type generation then utilizes to ensure that no infinitely recursive constraint violation types get generated. For example, for the above model, the generated Rust code would now look like: ```rust pub mod model { pub mod list { pub enum ConstraintViolation { Length(usize), Member( usize, std::boxed::Box(crate::model::recursive::ConstraintViolation), ), } } pub mod recursive { pub enum ConstraintViolation { List(crate::model::list::ConstraintViolation), } } } ``` Likewise, places where constraint violations are handled (like where unconstrained types are converted to constrained types) have been updated to account for the scenario where they now are or need to be boxed. Parametrized tests have been added to exhaustively test combinations of models exercising recursive paths going through (sparse and non-sparse) list and map shapes, as well as union and structure shapes (`RecursiveConstraintViolationsTest`). These tests even assert that the specific member shapes along the cycles are tagged as expected (`RecursiveConstraintViolationBoxerTest`). * Address comments
1 parent 7a69677 commit c44ccdc

File tree

29 files changed

+479
-78
lines changed

29 files changed

+479
-78
lines changed

codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,14 @@ class ClientCodegenVisitor(
102102
// Add errors attached at the service level to the models
103103
.let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) }
104104
// Add `Box<T>` to recursive shapes as necessary
105-
.let(RecursiveShapeBoxer::transform)
105+
.let(RecursiveShapeBoxer()::transform)
106106
// Normalize the `message` field on errors when enabled in settings (default: true)
107107
.letIf(settings.codegenConfig.addMessageToErrors, AddErrorMessage::transform)
108108
// NormalizeOperations by ensuring every operation has an input & output shape
109109
.let(OperationNormalizer::transform)
110110
// Drop unsupported event stream operations from the model
111111
.let { RemoveEventStreamOperations.transform(it, settings) }
112-
// - Normalize event stream operations
112+
// Normalize event stream operations
113113
.let(EventStreamNormalizer::transform)
114114

115115
/**

codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ResiliencyConfigCustomizationTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ internal class ResiliencyConfigCustomizationTest {
3636

3737
@Test
3838
fun `generates a valid config`() {
39-
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
39+
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
4040
val project = TestWorkspace.testProject()
4141
val codegenContext = testCodegenContext(model, settings = project.rustSettings())
4242

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ import software.amazon.smithy.model.traits.Trait
1212
/**
1313
* Trait indicating that this shape should be represented with `Box<T>` when converted into Rust
1414
*
15-
* This is used to handle recursive shapes. See RecursiveShapeBoxer.
15+
* This is used to handle recursive shapes.
16+
* See [software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer].
1617
*
1718
* This trait is synthetic, applied during code generation, and never used in actual models.
1819
*/

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt

+62-32
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,50 @@ package software.amazon.smithy.rust.codegen.core.smithy.transformers
77

88
import software.amazon.smithy.codegen.core.TopologicalIndex
99
import software.amazon.smithy.model.Model
10-
import software.amazon.smithy.model.shapes.ListShape
10+
import software.amazon.smithy.model.shapes.CollectionShape
1111
import software.amazon.smithy.model.shapes.MapShape
1212
import software.amazon.smithy.model.shapes.MemberShape
13-
import software.amazon.smithy.model.shapes.SetShape
1413
import software.amazon.smithy.model.shapes.Shape
1514
import software.amazon.smithy.model.transform.ModelTransformer
1615
import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
1716
import software.amazon.smithy.rust.codegen.core.util.hasTrait
1817

19-
object RecursiveShapeBoxer {
18+
class RecursiveShapeBoxer(
2019
/**
21-
* Transform a model which may contain recursive shapes into a model annotated with [RustBoxTrait]
20+
* A predicate that determines when a cycle in the shape graph contains "indirection". If a cycle contains
21+
* indirection, no shape needs to be tagged. What constitutes indirection is up to the caller to decide.
22+
*/
23+
private val containsIndirectionPredicate: (Collection<Shape>) -> Boolean = ::containsIndirection,
24+
/**
25+
* A closure that gets called on one member shape of a cycle that does not contain indirection for "fixing". For
26+
* example, the [RustBoxTrait] trait can be used to tag the member shape.
27+
*/
28+
private val boxShapeFn: (MemberShape) -> MemberShape = ::addRustBoxTrait,
29+
) {
30+
/**
31+
* Transform a model which may contain recursive shapes.
2232
*
23-
* When recursive shapes do NOT go through a List, Map, or Set, they must be boxed in Rust. This function will
24-
* iteratively find loops & add the `RustBox` trait in a deterministic way until it reaches a fixed point.
33+
* For example, when recursive shapes do NOT go through a `CollectionShape` or a `MapShape` shape, they must be
34+
* boxed in Rust. This function will iteratively find cycles and call [boxShapeFn] on a member shape in the
35+
* cycle to act on it. This is done in a deterministic way until it reaches a fixed point.
2536
*
26-
* This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. Even so
37+
* This function MUST be deterministic (always choose the same shapes to fix). If it is not, that is a bug. Even so
2738
* this function may cause backward compatibility issues in certain pathological cases where a changes to recursive
2839
* structures cause different members to be boxed. We may need to address these via customizations.
40+
*
41+
* For example, given the following model,
42+
*
43+
* ```smithy
44+
* namespace com.example
45+
*
46+
* structure Recursive {
47+
* recursiveStruct: Recursive
48+
* anotherField: Boolean
49+
* }
50+
* ```
51+
*
52+
* The `com.example#Recursive$recursiveStruct` member shape is part of a cycle, but the
53+
* `com.example#Recursive$anotherField` member shape is not.
2954
*/
3055
fun transform(model: Model): Model {
3156
val next = transformInner(model)
@@ -37,16 +62,17 @@ object RecursiveShapeBoxer {
3762
}
3863

3964
/**
40-
* If [model] contains a recursive loop that must be boxed, apply one instance of [RustBoxTrait] return the new model.
41-
* If [model] contains no loops, return null.
65+
* If [model] contains a recursive loop that must be boxed, return the transformed model resulting form a call to
66+
* [boxShapeFn].
67+
* If [model] contains no loops, return `null`.
4268
*/
4369
private fun transformInner(model: Model): Model? {
44-
// Execute 1-step of the boxing algorithm in the path to reaching a fixed point
45-
// 1. Find all the shapes that are part of a cycle
46-
// 2. Find all the loops that those shapes are part of
47-
// 3. Filter out the loops that go through a layer of indirection
48-
// 3. Pick _just one_ of the remaining loops to fix
49-
// 4. Select the member shape in that loop with the earliest shape id
70+
// Execute 1 step of the boxing algorithm in the path to reaching a fixed point:
71+
// 1. Find all the shapes that are part of a cycle.
72+
// 2. Find all the loops that those shapes are part of.
73+
// 3. Filter out the loops that go through a layer of indirection.
74+
// 3. Pick _just one_ of the remaining loops to fix.
75+
// 4. Select the member shape in that loop with the earliest shape id.
5076
// 5. Box it.
5177
// (External to this function) Go back to 1.
5278
val index = TopologicalIndex.of(model)
@@ -58,34 +84,38 @@ object RecursiveShapeBoxer {
5884
// Flatten the connections into shapes.
5985
loops.map { it.shapes }
6086
}
61-
val loopToFix = loops.firstOrNull { !containsIndirection(it) }
87+
val loopToFix = loops.firstOrNull { !containsIndirectionPredicate(it) }
6288

6389
return loopToFix?.let { loop: List<Shape> ->
6490
check(loop.isNotEmpty())
65-
// pick the shape to box in a deterministic way
91+
// Pick the shape to box in a deterministic way.
6692
val shapeToBox = loop.filterIsInstance<MemberShape>().minByOrNull { it.id }!!
6793
ModelTransformer.create().mapShapes(model) { shape ->
6894
if (shape == shapeToBox) {
69-
shape.asMemberShape().get().toBuilder().addTrait(RustBoxTrait()).build()
95+
boxShapeFn(shape.asMemberShape().get())
7096
} else {
7197
shape
7298
}
7399
}
74100
}
75101
}
102+
}
76103

77-
/**
78-
* Check if a List<Shape> contains a shape which will use a pointer when represented in Rust, avoiding the
79-
* need to add more Boxes
80-
*/
81-
private fun containsIndirection(loop: List<Shape>): Boolean {
82-
return loop.find {
83-
when (it) {
84-
is ListShape,
85-
is MapShape,
86-
is SetShape, -> true
87-
else -> it.hasTrait<RustBoxTrait>()
88-
}
89-
} != null
104+
/**
105+
* Check if a `List<Shape>` contains a shape which will use a pointer when represented in Rust, avoiding the
106+
* need to add more `Box`es.
107+
*
108+
* Why `CollectionShape`s and `MapShape`s? Note that `CollectionShape`s get rendered in Rust as `Vec<T>`, and
109+
* `MapShape`s as `HashMap<String, T>`; they're the only Smithy shapes that "organically" introduce indirection
110+
* (via a pointer to the heap) in the recursive path. For other recursive paths, we thus have to introduce the
111+
* indirection artificially ourselves using `Box`.
112+
*
113+
*/
114+
private fun containsIndirection(loop: Collection<Shape>): Boolean = loop.find {
115+
when (it) {
116+
is CollectionShape, is MapShape -> true
117+
else -> it.hasTrait<RustBoxTrait>()
90118
}
91-
}
119+
} != null
120+
121+
private fun addRustBoxTrait(shape: MemberShape): MemberShape = shape.toBuilder().addTrait(RustBoxTrait()).build()

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class InstantiatorTest {
8282
@required
8383
num: Integer
8484
}
85-
""".asSmithyModel().let { RecursiveShapeBoxer.transform(it) }
85+
""".asSmithyModel().let { RecursiveShapeBoxer().transform(it) }
8686

8787
private val codegenContext = testCodegenContext(model)
8888
private val symbolProvider = codegenContext.symbolProvider

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ class StructureGeneratorTest {
327327
@Test
328328
fun `it generates accessor methods`() {
329329
val testModel =
330-
RecursiveShapeBoxer.transform(
330+
RecursiveShapeBoxer().transform(
331331
"""
332332
namespace test
333333

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class AwsQueryParserGeneratorTest {
4242

4343
@Test
4444
fun `it modifies operation parsing to include Response and Result tags`() {
45-
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
45+
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
4646
val codegenContext = testCodegenContext(model)
4747
val symbolProvider = codegenContext.symbolProvider
4848
val parserGenerator = AwsQueryParserGenerator(

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Ec2QueryParserGeneratorTest {
4242

4343
@Test
4444
fun `it modifies operation parsing to include Response and Result tags`() {
45-
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
45+
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
4646
val codegenContext = testCodegenContext(model)
4747
val symbolProvider = codegenContext.symbolProvider
4848
val parserGenerator = Ec2QueryParserGenerator(

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class JsonParserGeneratorTest {
114114

115115
@Test
116116
fun `generates valid deserializers`() {
117-
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
117+
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
118118
val codegenContext = testCodegenContext(model)
119119
val symbolProvider = codegenContext.symbolProvider
120120
fun builderSymbol(shape: StructureShape): Symbol =

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ internal class XmlBindingTraitParserGeneratorTest {
9494

9595
@Test
9696
fun `generates valid parsers`() {
97-
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
97+
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
9898
val codegenContext = testCodegenContext(model)
9999
val symbolProvider = codegenContext.symbolProvider
100100
val parserGenerator = XmlBindingTraitParserGenerator(

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class AwsQuerySerializerGeneratorTest {
9292
true -> CodegenTarget.CLIENT
9393
false -> CodegenTarget.SERVER
9494
}
95-
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
95+
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
9696
val codegenContext = testCodegenContext(model, codegenTarget = codegenTarget)
9797
val symbolProvider = codegenContext.symbolProvider
9898
val parserGenerator = AwsQuerySerializerGenerator(testCodegenContext(model, codegenTarget = codegenTarget))

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class Ec2QuerySerializerGeneratorTest {
8585

8686
@Test
8787
fun `generates valid serializers`() {
88-
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
88+
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
8989
val codegenContext = testCodegenContext(model)
9090
val symbolProvider = codegenContext.symbolProvider
9191
val parserGenerator = Ec2QuerySerializerGenerator(codegenContext)

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class JsonSerializerGeneratorTest {
100100

101101
@Test
102102
fun `generates valid serializers`() {
103-
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
103+
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
104104
val codegenContext = testCodegenContext(model)
105105
val symbolProvider = codegenContext.symbolProvider
106106
val parserSerializer = JsonSerializerGenerator(

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ internal class XmlBindingTraitSerializerGeneratorTest {
105105

106106
@Test
107107
fun `generates valid serializers`() {
108-
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
108+
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
109109
val codegenContext = testCodegenContext(model)
110110
val symbolProvider = codegenContext.symbolProvider
111111
val parserGenerator = XmlBindingTraitSerializerGenerator(

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ internal class RecursiveShapeBoxerTest {
3131
hello: Hello
3232
}
3333
""".asSmithyModel()
34-
RecursiveShapeBoxer.transform(model) shouldBe model
34+
RecursiveShapeBoxer().transform(model) shouldBe model
3535
}
3636

3737
@Test
@@ -43,7 +43,7 @@ internal class RecursiveShapeBoxerTest {
4343
anotherField: Boolean
4444
}
4545
""".asSmithyModel()
46-
val transformed = RecursiveShapeBoxer.transform(model)
46+
val transformed = RecursiveShapeBoxer().transform(model)
4747
val member: MemberShape = transformed.lookup("com.example#Recursive\$RecursiveStruct")
4848
member.expectTrait<RustBoxTrait>()
4949
}
@@ -70,7 +70,7 @@ internal class RecursiveShapeBoxerTest {
7070
third: SecondTree
7171
}
7272
""".asSmithyModel()
73-
val transformed = RecursiveShapeBoxer.transform(model)
73+
val transformed = RecursiveShapeBoxer().transform(model)
7474
val boxed = transformed.shapes().filter { it.hasTrait<RustBoxTrait>() }.toList()
7575
boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe setOf(
7676
"Atom\$add",

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class RecursiveShapesIntegrationTest {
6666
}
6767
output.message shouldContain "has infinite size"
6868

69-
val fixedProject = check(RecursiveShapeBoxer.transform(model))
69+
val fixedProject = check(RecursiveShapeBoxer().transform(model))
7070
fixedProject.compileAndTest()
7171
}
7272
}

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt

+4-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser
7575
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
7676
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput
7777
import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachValidationExceptionToConstrainedOperationInputsInAllowList
78+
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer
7879
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException
7980
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger
8081
import java.util.logging.Logger
@@ -159,7 +160,9 @@ open class ServerCodegenVisitor(
159160
// Add errors attached at the service level to the models
160161
.let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) }
161162
// Add `Box<T>` to recursive shapes as necessary
162-
.let(RecursiveShapeBoxer::transform)
163+
.let(RecursiveShapeBoxer()::transform)
164+
// Add `Box<T>` to recursive constraint violations as necessary
165+
.let(RecursiveConstraintViolationBoxer::transform)
163166
// Normalize operations by adding synthetic input and output shapes to every operation
164167
.let(OperationNormalizer::transform)
165168
// Remove the EBS model's own `ValidationException`, which collides with `smithy.framework#ValidationException`

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/CollectionConstraintViolationGenerator.kt

+13-4
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@ import software.amazon.smithy.model.shapes.CollectionShape
99
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
1010
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
1111
import software.amazon.smithy.rust.codegen.core.rustlang.join
12-
import software.amazon.smithy.rust.codegen.core.rustlang.rust
1312
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
13+
import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed
1414
import software.amazon.smithy.rust.codegen.core.smithy.module
15+
import software.amazon.smithy.rust.codegen.core.util.hasTrait
16+
import software.amazon.smithy.rust.codegen.core.util.letIf
1517
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
1618
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
1719
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
20+
import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait
1821
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput
1922

2023
class CollectionConstraintViolationGenerator(
@@ -38,16 +41,22 @@ class CollectionConstraintViolationGenerator(
3841
private val constraintsInfo: List<TraitInfo> = collectionConstraintsInfo.map { it.toTraitInfo() }
3942

4043
fun render() {
41-
val memberShape = model.expectShape(shape.member.target)
44+
val targetShape = model.expectShape(shape.member.target)
4245
val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
4346
val constraintViolationName = constraintViolationSymbol.name
44-
val isMemberConstrained = memberShape.canReachConstrainedShape(model, symbolProvider)
47+
val isMemberConstrained = targetShape.canReachConstrainedShape(model, symbolProvider)
4548
val constraintViolationVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE)
4649

4750
modelsModuleWriter.withInlineModule(constraintViolationSymbol.module()) {
4851
val constraintViolationVariants = constraintsInfo.map { it.constraintViolationVariant }.toMutableList()
4952
if (isMemberConstrained) {
5053
constraintViolationVariants += {
54+
val memberConstraintViolationSymbol =
55+
constraintViolationSymbolProvider.toSymbol(targetShape).letIf(
56+
shape.member.hasTrait<ConstraintViolationRustBoxTrait>(),
57+
) {
58+
it.makeRustBoxed()
59+
}
5160
rustTemplate(
5261
"""
5362
/// Constraint violation error when an element doesn't satisfy its own constraints.
@@ -56,7 +65,7 @@ class CollectionConstraintViolationGenerator(
5665
##[doc(hidden)]
5766
Member(usize, #{MemberConstraintViolationSymbol})
5867
""",
59-
"MemberConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(memberShape),
68+
"MemberConstraintViolationSymbol" to memberConstraintViolationSymbol,
6069
)
6170
}
6271
}

0 commit comments

Comments
 (0)