Skip to content

Commit b0e3ae2

Browse files
authored
feat: add convenience methods for union members (#639)
1 parent e8e4d0c commit b0e3ae2

File tree

4 files changed

+131
-4
lines changed

4 files changed

+131
-4
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "02852edb-3c2d-4f29-a1ad-5f7e1b884268",
3+
"type": "feature",
4+
"description": "Add convenience getters for union members",
5+
"issues": [
6+
"awslabs/aws-sdk-kotlin#393"
7+
]
8+
}

smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/UnionGenerator.kt

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ class UnionGenerator(
4444
writer.renderMemberDocumentation(model, it)
4545
writer.renderAnnotations(it)
4646
val variantName = it.unionVariantName()
47-
val targetType = model.expectShape(it.target).type
48-
writer.writeInline("data class #L(val value: #Q) : #Q()", variantName, symbolProvider.toSymbol(it), symbol)
49-
when (targetType) {
47+
val variantSymbol = symbolProvider.toSymbol(it)
48+
writer.writeInline("data class #L(val value: #Q) : #Q()", variantName, variantSymbol, symbol)
49+
when (model.expectShape(it.target).type) {
5050
ShapeType.BLOB -> {
5151
writer.withBlock(" {", "}") {
5252
renderHashCode(model, listOf(it), symbolProvider, this)
@@ -56,9 +56,62 @@ class UnionGenerator(
5656
else -> writer.write("")
5757
}
5858
}
59+
5960
// generate the unknown which will always be last
6061
writer.write("object SdkUnknown : #Q()", symbol)
62+
63+
members.sortedBy { it.memberName }.forEach {
64+
val variantName = it.unionVariantName()
65+
val variantSymbol = symbolProvider.toSymbol(it)
66+
67+
writer.write("")
68+
writer.dokka {
69+
write(
70+
"""
71+
Casts this [#T] as a [#L] and retrieves its [#Q] value. Throws an exception if the [#T] is not a
72+
[#L].
73+
""".trimIndent(),
74+
symbol,
75+
variantName,
76+
variantSymbol,
77+
symbol,
78+
variantName,
79+
)
80+
}
81+
writer.write("fun as#L(): #Q = (this as #T.#L).value", variantName, variantSymbol, symbol, variantName)
82+
83+
writer.write("")
84+
writer.dokka {
85+
write(
86+
"Casts this [#T] as a [#L] and retrieves its [#Q] value. Returns null if the [#T] is not a [#L].",
87+
symbol,
88+
variantName,
89+
variantSymbol,
90+
symbol,
91+
variantName,
92+
)
93+
}
94+
writer.write(
95+
"fun as#LOrNull(): #Q? = (this as? #T.#L)?.value",
96+
variantName,
97+
variantSymbol,
98+
symbol,
99+
variantName,
100+
)
101+
}
102+
61103
writer.closeBlock("}").write("")
104+
105+
members.sortedBy { it.memberName }.forEach {
106+
val variantName = it.unionVariantName()
107+
val variantSymbol = symbolProvider.toSymbol(it)
108+
109+
writer.write("")
110+
writer.dokka {
111+
write("Casts this [#T] as a [#L] and retrieves its [#Q] value.", symbol, variantName, variantSymbol)
112+
}
113+
writer.write("val #T.#L get() = (this as #T.#L).value", symbol, variantName, symbol, variantName)
114+
}
62115
}
63116

64117
// generate a `hashCode()` implementation

smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/UnionGeneratorTest.kt

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,61 @@ class UnionGeneratorTest {
7272
data class Foo(val value: kotlin.String) : test.model.MyUnion()
7373
data class MyStruct(val value: test.model.MyStruct) : test.model.MyUnion()
7474
object SdkUnknown : test.model.MyUnion()
75+
76+
/**
77+
* Casts this [MyUnion] as a [Bar] and retrieves its [kotlin.Int] value. Throws an exception if the [MyUnion] is not a
78+
* [Bar].
79+
*/
80+
fun asBar(): kotlin.Int = (this as MyUnion.Bar).value
81+
82+
/**
83+
* Casts this [MyUnion] as a [Bar] and retrieves its [kotlin.Int] value. Returns null if the [MyUnion] is not a [Bar].
84+
*/
85+
fun asBarOrNull(): kotlin.Int? = (this as? MyUnion.Bar)?.value
86+
87+
/**
88+
* Casts this [MyUnion] as a [Baz] and retrieves its [kotlin.Int] value. Throws an exception if the [MyUnion] is not a
89+
* [Baz].
90+
*/
91+
fun asBaz(): kotlin.Int = (this as MyUnion.Baz).value
92+
93+
/**
94+
* Casts this [MyUnion] as a [Baz] and retrieves its [kotlin.Int] value. Returns null if the [MyUnion] is not a [Baz].
95+
*/
96+
fun asBazOrNull(): kotlin.Int? = (this as? MyUnion.Baz)?.value
97+
98+
/**
99+
* Casts this [MyUnion] as a [Blz] and retrieves its [kotlin.ByteArray] value. Throws an exception if the [MyUnion] is not a
100+
* [Blz].
101+
*/
102+
fun asBlz(): kotlin.ByteArray = (this as MyUnion.Blz).value
103+
104+
/**
105+
* Casts this [MyUnion] as a [Blz] and retrieves its [kotlin.ByteArray] value. Returns null if the [MyUnion] is not a [Blz].
106+
*/
107+
fun asBlzOrNull(): kotlin.ByteArray? = (this as? MyUnion.Blz)?.value
108+
109+
/**
110+
* Casts this [MyUnion] as a [Foo] and retrieves its [kotlin.String] value. Throws an exception if the [MyUnion] is not a
111+
* [Foo].
112+
*/
113+
fun asFoo(): kotlin.String = (this as MyUnion.Foo).value
114+
115+
/**
116+
* Casts this [MyUnion] as a [Foo] and retrieves its [kotlin.String] value. Returns null if the [MyUnion] is not a [Foo].
117+
*/
118+
fun asFooOrNull(): kotlin.String? = (this as? MyUnion.Foo)?.value
119+
120+
/**
121+
* Casts this [MyUnion] as a [MyStruct] and retrieves its [test.model.MyStruct] value. Throws an exception if the [MyUnion] is not a
122+
* [MyStruct].
123+
*/
124+
fun asMyStruct(): test.model.MyStruct = (this as MyUnion.MyStruct).value
125+
126+
/**
127+
* Casts this [MyUnion] as a [MyStruct] and retrieves its [test.model.MyStruct] value. Returns null if the [MyUnion] is not a [MyStruct].
128+
*/
129+
fun asMyStructOrNull(): test.model.MyStruct? = (this as? MyUnion.MyStruct)?.value
75130
}
76131
""".trimIndent()
77132

@@ -166,6 +221,17 @@ class UnionGeneratorTest {
166221
sealed class MyUnion {
167222
data class Foo(val value: test.model.MyStruct) : test.model.MyUnion()
168223
object SdkUnknown : test.model.MyUnion()
224+
225+
/**
226+
* Casts this [MyUnion] as a [Foo] and retrieves its [test.model.MyStruct] value. Throws an exception if the [MyUnion] is not a
227+
* [Foo].
228+
*/
229+
fun asFoo(): test.model.MyStruct = (this as MyUnion.Foo).value
230+
231+
/**
232+
* Casts this [MyUnion] as a [Foo] and retrieves its [test.model.MyStruct] value. Returns null if the [MyUnion] is not a [Foo].
233+
*/
234+
fun asFooOrNull(): test.model.MyStruct? = (this as? MyUnion.Foo)?.value
169235
}
170236
""".trimIndent()
171237

tests/compile/src/test/kotlin/software/amazon/smithy/kotlin/codegen/SmithySdkTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class SmithySdkTest {
3131
val compilationResult = compileSdkAndTest(model = model, outputSink = compileOutputStream, emitSourcesToTmp = Debug.emitSourcesToTemp)
3232
compileOutputStream.flush()
3333

34-
assertEquals(compilationResult.exitCode, KotlinCompilation.ExitCode.OK, compileOutputStream.toString())
34+
assertEquals(KotlinCompilation.ExitCode.OK, compilationResult.exitCode, compileOutputStream.toString())
3535
}
3636

3737
// FIXME - disabled until we invest time into improving the extraneous warnings we get for things like parameter never used, etc

0 commit comments

Comments
 (0)