Skip to content

Add EnumSection to allow decorators to modify enum member attributes #4039

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
10 changes: 10 additions & 0 deletions .changelog/1740703869.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
applies_to: ["client", "server"]
authors: [Dorenavant]
references: []
breaking: false
new_feature: true
bug_fix: false
---

Add EnumSection to allow decorators to modify enum member attributes
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,13 @@ class ClientCodegenVisitor(
model = codegenDecorator.transformModel(untransformedService, baseModel, settings)
// the model transformer _might_ change the service shape
val service = settings.getService(model)
symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(settings, model, service, rustSymbolProviderConfig, codegenDecorator)
symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(
settings,
model,
service,
rustSymbolProviderConfig,
codegenDecorator,
)

codegenContext =
ClientCodegenContext(
Expand Down Expand Up @@ -177,7 +183,10 @@ class ClientCodegenVisitor(
)
try {
// use an increased max_width to make rustfmt fail less frequently
"cargo fmt -- --config max_width=150".runCommand(fileManifest.baseDir, timeout = settings.codegenConfig.formatTimeoutSeconds.toLong())
"cargo fmt -- --config max_width=150".runCommand(
fileManifest.baseDir,
timeout = settings.codegenConfig.formatTimeoutSeconds.toLong(),
)
} catch (err: CommandError) {
logger.warning("Failed to run cargo fmt: [${service.id}]\n${err.output}")
}
Expand Down Expand Up @@ -236,7 +245,10 @@ class ClientCodegenVisitor(

implBlock(symbolProvider.toSymbol(shape)) {
BuilderGenerator.renderConvenienceMethod(this, symbolProvider, shape)
if (codegenContext.protocolImpl?.httpBindingResolver?.handlesEventStreamInitialResponse(shape) == true) {
if (codegenContext.protocolImpl?.httpBindingResolver?.handlesEventStreamInitialResponse(
shape,
) == true
) {
BuilderGenerator.renderIntoBuilderMethod(this, symbolProvider, shape)
}
}
Expand All @@ -251,6 +263,7 @@ class ClientCodegenVisitor(
}
struct to builder
}

else -> {
val errorGenerator =
ErrorGenerator(
Expand Down Expand Up @@ -283,7 +296,11 @@ class ClientCodegenVisitor(
if (shape.hasTrait<EnumTrait>()) {
val privateModule = privateModule(shape)
rustCrate.inPrivateModuleWithReexport(privateModule, symbolProvider.toSymbol(shape)) {
ClientEnumGenerator(codegenContext, shape).render(this)
ClientEnumGenerator(
codegenContext,
shape,
codegenDecorator.enumCustomizations(codegenContext, emptyList()),
).render(this)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGeneratorContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumMemberModel
Expand Down Expand Up @@ -278,7 +279,11 @@ data class InfallibleEnumType(
}
}

class ClientEnumGenerator(codegenContext: ClientCodegenContext, shape: StringShape) :
class ClientEnumGenerator(
codegenContext: ClientCodegenContext,
shape: StringShape,
customizations: List<EnumCustomization>,
) :
EnumGenerator(
codegenContext.model,
codegenContext.symbolProvider,
Expand All @@ -290,6 +295,7 @@ class ClientEnumGenerator(codegenContext: ClientCodegenContext, shape: StringSha
parent = ClientRustModule.primitives,
),
),
customizations,
)

private fun unknownVariantError(): RuntimeType =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ClientEnumGeneratorTest {
val context = testClientCodegenContext(model)
val project = TestWorkspace.testProject(context.symbolProvider)
project.moduleFor(shape) {
ClientEnumGenerator(context, shape).render(this)
ClientEnumGenerator(context, shape, emptyList()).render(this)
unitTest(
"matching_on_enum_should_be_forward_compatible",
"""
Expand Down Expand Up @@ -88,7 +88,7 @@ class ClientEnumGeneratorTest {
val context = testClientCodegenContext(model)
val project = TestWorkspace.testProject(context.symbolProvider)
project.moduleFor(shape) {
ClientEnumGenerator(context, shape).render(this)
ClientEnumGenerator(context, shape, emptyList()).render(this)
unitTest(
"impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait",
"""
Expand Down Expand Up @@ -123,7 +123,7 @@ class ClientEnumGeneratorTest {
val context = testClientCodegenContext(model)
val project = TestWorkspace.testProject(context.symbolProvider)
project.moduleFor(shape) {
ClientEnumGenerator(context, shape).render(this)
ClientEnumGenerator(context, shape, emptyList()).render(this)
unitTest(
"it_escapes_the_unknown_variant_if_the_enum_has_an_unknown_value_in_the_model",
"""
Expand Down Expand Up @@ -156,7 +156,7 @@ class ClientEnumGeneratorTest {
val project = TestWorkspace.testProject(context.symbolProvider)
project.moduleFor(shape) {
rust("##![allow(deprecated)]")
ClientEnumGenerator(context, shape).render(this)
ClientEnumGenerator(context, shape, emptyList()).render(this)
unitTest(
"generated_named_enums_roundtrip",
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ internal class ClientInstantiatorTest {

val project = TestWorkspace.testProject(symbolProvider)
project.moduleFor(shape) {
ClientEnumGenerator(codegenContext, shape).render(this)
ClientEnumGenerator(codegenContext, shape, emptyList()).render(this)
unitTest("generate_named_enums") {
withBlock("let result = ", ";") {
sut.render(this, shape, data)
Expand All @@ -74,7 +74,7 @@ internal class ClientInstantiatorTest {

val project = TestWorkspace.testProject(symbolProvider)
project.moduleFor(shape) {
ClientEnumGenerator(codegenContext, shape).render(this)
ClientEnumGenerator(codegenContext, shape, emptyList()).render(this)
unitTest("generate_unnamed_enums") {
withBlock("let result = ", ";") {
sut.render(this, shape, data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.ModuleDocProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.ManifestCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization
Expand Down Expand Up @@ -59,7 +60,8 @@ interface CoreCodegenDecorator<CodegenContext, CodegenSettings> {
fun extras(
codegenContext: CodegenContext,
rustCrate: RustCrate,
) {}
) {
}

/**
* Customize the documentation provider for module documentation.
Expand Down Expand Up @@ -94,6 +96,14 @@ interface CoreCodegenDecorator<CodegenContext, CodegenSettings> {
baseCustomizations: List<StructureCustomization>,
): List<StructureCustomization> = baseCustomizations

/**
* Hook to customize enums generated by `EnumGenerator`.
*/
fun enumCustomizations(
codegenContext: CodegenContext,
baseCustomizations: List<EnumCustomization>,
): List<EnumCustomization> = baseCustomizations

// TODO(https://github.com/smithy-lang/smithy-rs/issues/1401): Move builder customizations into `ClientCodegenDecorator`

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom
import software.amazon.smithy.rust.codegen.core.util.REDACTION
Expand All @@ -39,6 +42,25 @@ import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.shouldRedact
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

/** EnumGenerator customization sections */
sealed class EnumSection(name: String) : Section(name) {
abstract val shape: Shape

/** Hook to add additional attributes to an enum member */
data class AdditionalMemberAttributes(override val shape: Shape, val definition: EnumDefinition) :
EnumSection("AdditionalMemberAttributes")

/** Hook to add additional trait implementations */
data class AdditionalTraitImpls(override val shape: Shape) : EnumSection("AdditionalTraitImpls")

/** Hook to add additional enum members */
data class AdditionalEnumMembers(override val shape: Shape) :
EnumSection("AdditionalEnumMembers")
}

/** Customizations for EnumGenerator */
abstract class EnumCustomization : NamedCustomization<EnumSection>()

data class EnumGeneratorContext(
val enumName: String,
val enumMeta: RustMetadata,
Expand Down Expand Up @@ -86,6 +108,7 @@ class EnumMemberModel(
private val parentShape: Shape,
private val definition: EnumDefinition,
private val symbolProvider: RustSymbolProvider,
private val customizations: List<EnumCustomization>,
) {
companion object {
/**
Expand Down Expand Up @@ -140,6 +163,10 @@ class EnumMemberModel(
fun render(writer: RustWriter) {
renderDocumentation(writer)
renderDeprecated(writer)
writer.writeCustomizations(
customizations,
EnumSection.AdditionalMemberAttributes(parentShape, definition),
)
writer.write("${derivedName()},")
}
}
Expand Down Expand Up @@ -167,6 +194,7 @@ open class EnumGenerator(
private val symbolProvider: RustSymbolProvider,
private val shape: StringShape,
private val enumType: EnumType,
private val customizations: List<EnumCustomization>,
) {
companion object {
/** Name of the function on the enum impl to get a vec of value names */
Expand All @@ -180,7 +208,8 @@ open class EnumGenerator(
enumName = symbol.name,
enumMeta = symbol.expectRustMetadata(),
enumTrait = enumTrait,
sortedMembers = enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(shape, it, symbolProvider) },
sortedMembers = enumTrait.values.sortedBy { it.value }
.map { EnumMemberModel(shape, it, symbolProvider, customizations) },
)

fun render(writer: RustWriter) {
Expand All @@ -193,6 +222,7 @@ open class EnumGenerator(
writer.renderUnnamedEnum()
}
enumType.additionalEnumImpls(context)(writer)
writer.writeCustomizations(customizations, EnumSection.AdditionalTraitImpls(shape))

if (shape.shouldRedact(model)) {
writer.renderDebugImplForSensitiveEnum()
Expand Down Expand Up @@ -266,6 +296,10 @@ open class EnumGenerator(
context.enumMeta.render(this)
rustBlock("enum ${context.enumName}") {
context.sortedMembers.forEach { member -> member.render(this) }
writeCustomizations(
customizations,
EnumSection.AdditionalEnumMembers(shape),
)
enumType.additionalEnumMembers(context)(this)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class EnumGeneratorTest {
testModel.lookup("test#EnumWithUnknown"),
enumTrait.values.first { it.name.orNull() == name },
symbolProvider,
emptyList(),
)

@Test
Expand Down Expand Up @@ -112,7 +113,7 @@ class EnumGeneratorTest {
shape: StringShape,
enumType: EnumType = TestEnumType,
) {
EnumGenerator(model, provider, shape, enumType).render(this)
EnumGenerator(model, provider, shape, enumType, emptyList()).render(this)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class JsonParserGeneratorTest {
project.moduleFor(top) {
UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render()
val enum = model.lookup<StringShape>("test#FooEnum")
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}
model.lookup<OperationShape>("test#Op").outputShape(model).also { output ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ internal class XmlBindingTraitParserGeneratorTest {
project.moduleFor(top) {
UnionGenerator(model, symbolProvider, this, choiceShape).render()
model.lookup<StringShape>("test#FooEnum").also { enum ->
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class AwsQuerySerializerGeneratorTest {
renderUnknownVariant = generateUnknownVariant,
).render()
val enum = model.lookup<StringShape>("test#FooEnum")
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}

Expand Down Expand Up @@ -316,7 +316,7 @@ class AwsQuerySerializerGeneratorTest {
renderUnknownVariant = generateUnknownVariant,
).render()
val enum = model.lookup<StringShape>("test#FooEnum")
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class Ec2QuerySerializerGeneratorTest {
project.moduleFor(top) {
UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render()
val enum = model.lookup<StringShape>("test#FooEnum")
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}

Expand Down Expand Up @@ -298,7 +298,7 @@ class Ec2QuerySerializerGeneratorTest {
project.moduleFor(top) {
UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render()
val enum = model.lookup<StringShape>("test#FooEnum")
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class JsonSerializerGeneratorTest {
project.moduleFor(top) {
UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render()
val enum = model.lookup<StringShape>("test#FooEnum")
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}

Expand Down Expand Up @@ -333,7 +333,7 @@ class JsonSerializerGeneratorTest {
project.moduleFor(top) {
UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render()
val enum = model.lookup<StringShape>("test#FooEnum")
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ internal class XmlBindingTraitSerializerGeneratorTest {
project.moduleFor(top) {
UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render()
val enum = model.lookup<StringShape>("test#FooEnum")
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}
model.lookup<OperationShape>("test#Op").inputShape(model).also { input ->
Expand Down Expand Up @@ -334,7 +334,7 @@ internal class XmlBindingTraitSerializerGeneratorTest {
project.moduleFor(top) {
UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render()
val enum = model.lookup<StringShape>("test#FooEnum")
EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this)
EnumGenerator(model, symbolProvider, enum, TestEnumType, emptyList()).render(this)
}
}
model.lookup<OperationShape>("test#Op").inputShape(model).also { input ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,12 @@ open class ServerCodegenVisitor(
fun serverEnumGeneratorFactory(
codegenContext: ServerCodegenContext,
shape: StringShape,
) = ServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator)
) = ServerEnumGenerator(
codegenContext,
shape,
validationExceptionConversionGenerator,
codegenDecorator.enumCustomizations(codegenContext, emptyList()),
)
stringShape(shape, ::serverEnumGeneratorFactory)
}

Expand Down
Loading