Skip to content

Commit 78adf24

Browse files
authored
fix: correctly codegen maps with enum keys (#1052)
1 parent b866db4 commit 78adf24

File tree

14 files changed

+340
-55
lines changed

14 files changed

+340
-55
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"id": "a7f82a33-11f1-4184-97af-ff713e922dfc",
3+
"type": "bugfix",
4+
"description": "⚠️ **IMPORTANT**: Fix codegen for map shapes which use string enums as map keys. See the [**Map key changes** breaking change announcement](https://github.com/awslabs/aws-sdk-kotlin/discussions/1258) for more details",
5+
"issues": [
6+
"awslabs/smithy-kotlin#1045"
7+
],
8+
"requiresMinorVersionBump": true
9+
}

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ data class KotlinSettings(
4343
val sdkId: String,
4444
val build: BuildSettings = BuildSettings.Default,
4545
val api: ApiSettings = ApiSettings.Default,
46+
val debug: Boolean = false,
4647
) {
4748

4849
/**
@@ -104,12 +105,14 @@ data class KotlinSettings(
104105
val sdkId = config.getStringMemberOrDefault(SDK_ID, serviceId.name)
105106
val build = config.getObjectMember(BUILD_SETTINGS)
106107
val api = config.getObjectMember(API_SETTINGS)
108+
val debug = config.getBooleanMemberOrDefault("debug", false)
107109
return KotlinSettings(
108110
serviceId,
109111
PackageSettings(packageName, version, desc),
110112
sdkId,
111113
BuildSettings.fromNode(build),
112114
ApiSettings.fromNode(api),
115+
debug,
113116
)
114117
}
115118
}

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegator.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class KotlinDelegator(
150150
val needsNewline = writers.containsKey(formattedFilename)
151151
val writer = writers.getOrPut(formattedFilename) {
152152
val kotlinWriter = KotlinWriter(namespace)
153+
if (settings.debug) kotlinWriter.enableStackTraceComments(true)
153154

154155
// Register all integrations [SectionWriterBindings] on the writer.
155156
integrations.forEach { integration ->

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package software.amazon.smithy.kotlin.codegen.core
66

77
import software.amazon.smithy.codegen.core.*
88
import software.amazon.smithy.kotlin.codegen.KotlinSettings
9-
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
109
import software.amazon.smithy.kotlin.codegen.lang.kotlinReservedWords
1110
import software.amazon.smithy.kotlin.codegen.model.*
1211
import software.amazon.smithy.kotlin.codegen.utils.dq
@@ -162,15 +161,18 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
162161
}
163162

164163
override fun mapShape(shape: MapShape): Symbol {
165-
val reference = toSymbol(shape.value)
166-
val valueSuffix = if (reference.isNullable) "?" else ""
167-
val valueType = "${reference.name}$valueSuffix"
168-
val fullyQualifiedValueType = "${reference.fullName}$valueSuffix"
164+
val keyReference = toSymbol(shape.key)
165+
val keyType = keyReference.name
166+
val fullyQualifiedKeyType = keyReference.fullName
167+
168+
val valueReference = toSymbol(shape.value)
169+
val valueSuffix = if (valueReference.isNullable) "?" else ""
170+
val valueType = "${valueReference.name}$valueSuffix"
171+
val fullyQualifiedValueType = "${valueReference.fullName}$valueSuffix"
169172

170-
val keyType = KotlinTypes.String.name
171-
val fullyQualifiedKeyType = KotlinTypes.String.fullName
172173
return createSymbolBuilder(shape, "Map<$keyType, $valueType>")
173-
.addReferences(reference)
174+
.addReferences(keyReference)
175+
.addReferences(valueReference)
174176
.putProperty(SymbolProperty.FULLY_QUALIFIED_NAME_HINT, "Map<$fullyQualifiedKeyType, $fullyQualifiedValueType>")
175177
.putProperty(SymbolProperty.MUTABLE_COLLECTION_FUNCTION, "mutableMapOf<$keyType, $valueType>")
176178
.putProperty(SymbolProperty.IMMUTABLE_COLLECTION_FUNCTION, "mapOf<$keyType, $valueType>")

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ShapeValueGenerator.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,16 @@ class ShapeValueGenerator(
188188
}
189189
is MapShape -> {
190190
memberShape = generator.model.expectShape(currShape.value.target)
191-
writer.writeInline("#S to ", keyNode.value)
191+
192+
val keyTarget = generator.model.expectShape(currShape.key.target)
193+
if (keyTarget.isEnum) {
194+
val keySymbol = generator.symbolProvider.toSymbol(currShape.key)
195+
writer.writeInline("#T.fromValue(#S)", keySymbol, keyNode.value)
196+
} else {
197+
writer.writeInline("#S", keyNode.value)
198+
}
199+
200+
writer.writeInline(" to ")
192201

193202
if (valueNode is NullNode) {
194203
writer.write("null")

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package software.amazon.smithy.kotlin.codegen.rendering.serde
66

77
import software.amazon.smithy.codegen.core.CodegenException
8+
import software.amazon.smithy.codegen.core.Symbol
89
import software.amazon.smithy.kotlin.codegen.core.*
910
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
1011
import software.amazon.smithy.kotlin.codegen.model.*
@@ -145,9 +146,10 @@ open class DeserializeStructGenerator(
145146
.indent()
146147
.withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
147148
write(
148-
"val #L = #T<String, #T#L>()",
149+
"val #L = #T<#T, #T#L>()",
149150
mutableCollectionName,
150151
KotlinTypes.Collections.mutableMapOf,
152+
ctx.symbolProvider.toSymbol(targetShape.key),
151153
ctx.symbolProvider.toSymbol(targetShape.value),
152154
nullabilitySuffix(targetShape.isSparse),
153155
)
@@ -168,6 +170,8 @@ open class DeserializeStructGenerator(
168170
nestingLevel: Int,
169171
parentMemberName: String,
170172
) {
173+
val keyShape = ctx.model.expectShape(mapShape.key.target)
174+
val keySymbol = ctx.symbolProvider.toSymbol(keyShape)
171175
val elementShape = ctx.model.expectShape(mapShape.value.target)
172176
val isSparse = mapShape.isSparse
173177

@@ -187,21 +191,47 @@ open class DeserializeStructGenerator(
187191
ShapeType.TIMESTAMP,
188192
ShapeType.ENUM,
189193
ShapeType.INT_ENUM,
190-
-> renderEntry(elementShape, nestingLevel, isSparse, parentMemberName)
194+
-> renderEntry(keyShape, keySymbol, elementShape, nestingLevel, isSparse, parentMemberName)
191195

192196
ShapeType.SET,
193197
ShapeType.LIST,
194-
-> renderListEntry(rootMemberShape, elementShape as CollectionShape, nestingLevel, isSparse, parentMemberName)
198+
-> renderListEntry(
199+
rootMemberShape,
200+
keyShape,
201+
keySymbol,
202+
elementShape as CollectionShape,
203+
nestingLevel,
204+
isSparse,
205+
parentMemberName,
206+
)
207+
208+
ShapeType.MAP -> renderMapEntry(
209+
rootMemberShape,
210+
keyShape,
211+
keySymbol,
212+
elementShape as MapShape,
213+
nestingLevel,
214+
isSparse,
215+
parentMemberName,
216+
)
195217

196-
ShapeType.MAP -> renderMapEntry(rootMemberShape, elementShape as MapShape, nestingLevel, isSparse, parentMemberName)
197218
ShapeType.UNION,
198219
ShapeType.STRUCTURE,
199-
-> renderNestedStructureEntry(elementShape, nestingLevel, isSparse, parentMemberName)
220+
-> renderNestedStructureEntry(keyShape, keySymbol, elementShape, nestingLevel, isSparse, parentMemberName)
200221

201222
else -> error("Unhandled type ${elementShape.type}")
202223
}
203224
}
204225

226+
private fun writeKeyVal(keyShape: Shape, keySymbol: Symbol, keyName: String) {
227+
writer.writeInline("val $keyName = ")
228+
if (keyShape.isEnum) {
229+
writer.write("#T.fromValue(key())", keySymbol)
230+
} else {
231+
writer.write("key()")
232+
}
233+
}
234+
205235
/**
206236
* Renders the deserialization of a nested structure contained in a map. Example:
207237
*
@@ -212,6 +242,8 @@ open class DeserializeStructGenerator(
212242
* ```
213243
*/
214244
private fun renderNestedStructureEntry(
245+
keyShape: Shape,
246+
keySymbol: Symbol,
215247
elementShape: Shape,
216248
nestingLevel: Int,
217249
isSparse: Boolean,
@@ -226,7 +258,7 @@ open class DeserializeStructGenerator(
226258
writer.addImport(symbol)
227259
}
228260

229-
writer.write("val $keyName = key()")
261+
writeKeyVal(keyShape, keySymbol, keyName)
230262
writer.write("val $valueName = if (nextHasValue()) { $deserializerFn } else { deserializeNull()$populateNullValuePostfix }")
231263
writer.write("$parentMemberName[$keyName] = $valueName")
232264
}
@@ -247,6 +279,8 @@ open class DeserializeStructGenerator(
247279
*/
248280
private fun renderMapEntry(
249281
rootMemberShape: MemberShape,
282+
keyShape: Shape,
283+
keySymbol: Symbol,
250284
mapShape: MapShape,
251285
nestingLevel: Int,
252286
isSparse: Boolean,
@@ -260,14 +294,15 @@ open class DeserializeStructGenerator(
260294
val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP)
261295
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
262296

263-
writer.write("val $keyName = key()")
297+
writeKeyVal(keyShape, keySymbol, keyName)
264298
writer.withBlock("val $valueName =", "") {
265299
withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") {
266300
withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
267301
write(
268-
"val #L = #T<String, #T#L>()",
302+
"val #L = #T<#T, #T#L>()",
269303
memberName,
270304
KotlinTypes.Collections.mutableMapOf,
305+
keySymbol,
271306
ctx.symbolProvider.toSymbol(mapShape.value),
272307
nullabilitySuffix(mapShape.isSparse),
273308
)
@@ -298,6 +333,8 @@ open class DeserializeStructGenerator(
298333
*/
299334
private fun renderListEntry(
300335
rootMemberShape: MemberShape,
336+
keyShape: Shape,
337+
keySymbol: Symbol,
301338
collectionShape: CollectionShape,
302339
nestingLevel: Int,
303340
isSparse: Boolean,
@@ -311,7 +348,7 @@ open class DeserializeStructGenerator(
311348
val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
312349
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
313350

314-
writer.write("val $keyName = key()")
351+
writeKeyVal(keyShape, keySymbol, keyName)
315352
writer.withBlock("val $valueName =", "") {
316353
withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") {
317354
withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) {
@@ -340,13 +377,20 @@ open class DeserializeStructGenerator(
340377
* map0[k0] = el0
341378
* ```
342379
*/
343-
private fun renderEntry(elementShape: Shape, nestingLevel: Int, isSparse: Boolean, parentMemberName: String) {
380+
private fun renderEntry(
381+
keyShape: Shape,
382+
keySymbol: Symbol,
383+
elementShape: Shape,
384+
nestingLevel: Int,
385+
isSparse: Boolean,
386+
parentMemberName: String,
387+
) {
344388
val deserializerFn = deserializerForShape(elementShape)
345389
val keyName = nestingLevel.variableNameFor(NestedIdentifierType.KEY)
346390
val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE)
347391
val populateNullValuePostfix = if (isSparse) "" else "; continue"
348392

349-
writer.write("val $keyName = key()")
393+
writeKeyVal(keyShape, keySymbol, keyName)
350394
writer.write("val $valueName = if (nextHasValue()) { $deserializerFn } else { deserializeNull()$populateNullValuePostfix }")
351395
writer.write("$parentMemberName[$keyName] = $valueName")
352396
}
@@ -476,9 +520,10 @@ open class DeserializeStructGenerator(
476520

477521
writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
478522
write(
479-
"val #L = #T<String, #T#L>()",
523+
"val #L = #T<#T, #T#L>()",
480524
mapName,
481525
KotlinTypes.Collections.mutableMapOf,
526+
ctx.symbolProvider.toSymbol(mapShape.key),
482527
ctx.symbolProvider.toSymbol(mapShape.value),
483528
nullabilitySuffix(mapShape.isSparse),
484529
)

0 commit comments

Comments
 (0)