55package software.amazon.smithy.kotlin.codegen.rendering.serde
66
77import software.amazon.smithy.codegen.core.CodegenException
8+ import software.amazon.smithy.codegen.core.Symbol
89import software.amazon.smithy.kotlin.codegen.core.*
910import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
1011import software.amazon.smithy.kotlin.codegen.model.*
@@ -145,9 +146,10 @@ open class DeserializeStructGenerator(
145146 .indent()
146147 .withBlock(" deserializer.#T($descriptorName ) {" , " }" , RuntimeTypes .Serde .deserializeMap) {
147148 write(
148- " val #L = #T<String , #T#L>()" ,
149+ " val #L = #T<#T , #T#L>()" ,
149150 mutableCollectionName,
150151 KotlinTypes .Collections .mutableMapOf,
152+ ctx.symbolProvider.toSymbol(targetShape.key),
151153 ctx.symbolProvider.toSymbol(targetShape.value),
152154 nullabilitySuffix(targetShape.isSparse),
153155 )
@@ -168,6 +170,8 @@ open class DeserializeStructGenerator(
168170 nestingLevel : Int ,
169171 parentMemberName : String ,
170172 ) {
173+ val keyShape = ctx.model.expectShape(mapShape.key.target)
174+ val keySymbol = ctx.symbolProvider.toSymbol(keyShape)
171175 val elementShape = ctx.model.expectShape(mapShape.value.target)
172176 val isSparse = mapShape.isSparse
173177
@@ -187,21 +191,47 @@ open class DeserializeStructGenerator(
187191 ShapeType .TIMESTAMP ,
188192 ShapeType .ENUM ,
189193 ShapeType .INT_ENUM ,
190- -> renderEntry(elementShape, nestingLevel, isSparse, parentMemberName)
194+ -> renderEntry(keyShape, keySymbol, elementShape, nestingLevel, isSparse, parentMemberName)
191195
192196 ShapeType .SET ,
193197 ShapeType .LIST ,
194- -> renderListEntry(rootMemberShape, elementShape as CollectionShape , nestingLevel, isSparse, parentMemberName)
198+ -> renderListEntry(
199+ rootMemberShape,
200+ keyShape,
201+ keySymbol,
202+ elementShape as CollectionShape ,
203+ nestingLevel,
204+ isSparse,
205+ parentMemberName,
206+ )
207+
208+ ShapeType .MAP -> renderMapEntry(
209+ rootMemberShape,
210+ keyShape,
211+ keySymbol,
212+ elementShape as MapShape ,
213+ nestingLevel,
214+ isSparse,
215+ parentMemberName,
216+ )
195217
196- ShapeType .MAP -> renderMapEntry(rootMemberShape, elementShape as MapShape , nestingLevel, isSparse, parentMemberName)
197218 ShapeType .UNION ,
198219 ShapeType .STRUCTURE ,
199- -> renderNestedStructureEntry(elementShape, nestingLevel, isSparse, parentMemberName)
220+ -> renderNestedStructureEntry(keyShape, keySymbol, elementShape, nestingLevel, isSparse, parentMemberName)
200221
201222 else -> error(" Unhandled type ${elementShape.type} " )
202223 }
203224 }
204225
226+ private fun writeKeyVal (keyShape : Shape , keySymbol : Symbol , keyName : String ) {
227+ writer.writeInline(" val $keyName = " )
228+ if (keyShape.isEnum) {
229+ writer.write(" #T.fromValue(key())" , keySymbol)
230+ } else {
231+ writer.write(" key()" )
232+ }
233+ }
234+
205235 /* *
206236 * Renders the deserialization of a nested structure contained in a map. Example:
207237 *
@@ -212,6 +242,8 @@ open class DeserializeStructGenerator(
212242 * ```
213243 */
214244 private fun renderNestedStructureEntry (
245+ keyShape : Shape ,
246+ keySymbol : Symbol ,
215247 elementShape : Shape ,
216248 nestingLevel : Int ,
217249 isSparse : Boolean ,
@@ -226,7 +258,7 @@ open class DeserializeStructGenerator(
226258 writer.addImport(symbol)
227259 }
228260
229- writer.write( " val $keyName = key() " )
261+ writeKeyVal(keyShape, keySymbol, keyName )
230262 writer.write(" val $valueName = if (nextHasValue()) { $deserializerFn } else { deserializeNull()$populateNullValuePostfix }" )
231263 writer.write(" $parentMemberName [$keyName ] = $valueName " )
232264 }
@@ -247,6 +279,8 @@ open class DeserializeStructGenerator(
247279 */
248280 private fun renderMapEntry (
249281 rootMemberShape : MemberShape ,
282+ keyShape : Shape ,
283+ keySymbol : Symbol ,
250284 mapShape : MapShape ,
251285 nestingLevel : Int ,
252286 isSparse : Boolean ,
@@ -260,14 +294,15 @@ open class DeserializeStructGenerator(
260294 val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType .MAP )
261295 val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
262296
263- writer.write( " val $keyName = key() " )
297+ writeKeyVal(keyShape, keySymbol, keyName )
264298 writer.withBlock(" val $valueName =" , " " ) {
265299 withBlock(" if (nextHasValue()) {" , " } else { deserializeNull()$populateNullValuePostfix }" ) {
266300 withBlock(" deserializer.#T($descriptorName ) {" , " }" , RuntimeTypes .Serde .deserializeMap) {
267301 write(
268- " val #L = #T<String , #T#L>()" ,
302+ " val #L = #T<#T , #T#L>()" ,
269303 memberName,
270304 KotlinTypes .Collections .mutableMapOf,
305+ keySymbol,
271306 ctx.symbolProvider.toSymbol(mapShape.value),
272307 nullabilitySuffix(mapShape.isSparse),
273308 )
@@ -298,6 +333,8 @@ open class DeserializeStructGenerator(
298333 */
299334 private fun renderListEntry (
300335 rootMemberShape : MemberShape ,
336+ keyShape : Shape ,
337+ keySymbol : Symbol ,
301338 collectionShape : CollectionShape ,
302339 nestingLevel : Int ,
303340 isSparse : Boolean ,
@@ -311,7 +348,7 @@ open class DeserializeStructGenerator(
311348 val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType .COLLECTION )
312349 val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
313350
314- writer.write( " val $keyName = key() " )
351+ writeKeyVal(keyShape, keySymbol, keyName )
315352 writer.withBlock(" val $valueName =" , " " ) {
316353 withBlock(" if (nextHasValue()) {" , " } else { deserializeNull()$populateNullValuePostfix }" ) {
317354 withBlock(" deserializer.#T($descriptorName ) {" , " }" , RuntimeTypes .Serde .deserializeList) {
@@ -340,13 +377,20 @@ open class DeserializeStructGenerator(
340377 * map0[k0] = el0
341378 * ```
342379 */
343- private fun renderEntry (elementShape : Shape , nestingLevel : Int , isSparse : Boolean , parentMemberName : String ) {
380+ private fun renderEntry (
381+ keyShape : Shape ,
382+ keySymbol : Symbol ,
383+ elementShape : Shape ,
384+ nestingLevel : Int ,
385+ isSparse : Boolean ,
386+ parentMemberName : String ,
387+ ) {
344388 val deserializerFn = deserializerForShape(elementShape)
345389 val keyName = nestingLevel.variableNameFor(NestedIdentifierType .KEY )
346390 val valueName = nestingLevel.variableNameFor(NestedIdentifierType .VALUE )
347391 val populateNullValuePostfix = if (isSparse) " " else " ; continue"
348392
349- writer.write( " val $keyName = key() " )
393+ writeKeyVal(keyShape, keySymbol, keyName )
350394 writer.write(" val $valueName = if (nextHasValue()) { $deserializerFn } else { deserializeNull()$populateNullValuePostfix }" )
351395 writer.write(" $parentMemberName [$keyName ] = $valueName " )
352396 }
@@ -476,9 +520,10 @@ open class DeserializeStructGenerator(
476520
477521 writer.withBlock(" val $elementName = deserializer.#T($descriptorName ) {" , " }" , RuntimeTypes .Serde .deserializeMap) {
478522 write(
479- " val #L = #T<String , #T#L>()" ,
523+ " val #L = #T<#T , #T#L>()" ,
480524 mapName,
481525 KotlinTypes .Collections .mutableMapOf,
526+ ctx.symbolProvider.toSymbol(mapShape.key),
482527 ctx.symbolProvider.toSymbol(mapShape.value),
483528 nullabilitySuffix(mapShape.isSparse),
484529 )
0 commit comments