Skip to content

Commit a27a789

Browse files
committed
disable name hack by default again, added JCP case for auto-applying the expression encoder without spark-connect
1 parent 0c8f4b1 commit a27a789

File tree

3 files changed

+173
-23
lines changed

3 files changed

+173
-23
lines changed

buildSrc/src/main/kotlin/Versions.kt

+6-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ object Versions {
99
inline val scala get() = System.getProperty("scala") as String
1010
inline val sparkMinor get() = spark.substringBeforeLast('.')
1111
inline val scalaCompat get() = scala.substringBeforeLast('.')
12+
13+
// TODO
14+
const val sparkConnect = false
15+
1216
const val jupyter = "0.12.0-32-1"
1317

1418
const val kotest = "5.5.4"
@@ -25,14 +29,15 @@ object Versions {
2529
const val jacksonDatabind = "2.13.4.2"
2630
const val kotlinxDateTime = "0.6.0-RC.2"
2731

28-
inline val versionMap
32+
inline val versionMap: Map<String, String>
2933
get() = mapOf(
3034
"kotlin" to kotlin,
3135
"scala" to scala,
3236
"scalaCompat" to scalaCompat,
3337
"spark" to spark,
3438
"sparkMinor" to sparkMinor,
3539
"version" to project,
40+
"sparkConnect" to sparkConnect.toString(),
3641
)
3742

3843
}

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt

+28-20
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,15 @@ fun <T : Any> kotlinEncoderFor(
6969
arguments: List<KTypeProjection> = emptyList(),
7070
nullable: Boolean = false,
7171
annotations: List<Annotation> = emptyList()
72-
): Encoder<T> = ExpressionEncoder.apply(
73-
KotlinTypeInference.encoderFor(
74-
kClass = kClass,
75-
arguments = arguments,
76-
nullable = nullable,
77-
annotations = annotations,
72+
): Encoder<T> =
73+
applyEncoder(
74+
KotlinTypeInference.encoderFor(
75+
kClass = kClass,
76+
arguments = arguments,
77+
nullable = nullable,
78+
annotations = annotations,
79+
)
7880
)
79-
)
8081

8182
/**
8283
* Main method of API, which gives you seamless integration with Spark:
@@ -88,15 +89,26 @@ fun <T : Any> kotlinEncoderFor(
8889
* @return generated encoder
8990
*/
9091
inline fun <reified T> kotlinEncoderFor(): Encoder<T> =
91-
ExpressionEncoder.apply(
92-
KotlinTypeInference.encoderFor<T>()
92+
kotlinEncoderFor(
93+
typeOf<T>()
9394
)
9495

9596
fun <T> kotlinEncoderFor(kType: KType): Encoder<T> =
96-
ExpressionEncoder.apply(
97+
applyEncoder(
9798
KotlinTypeInference.encoderFor(kType)
9899
)
99100

101+
/**
102+
* For spark-connect, no ExpressionEncoder is needed, so we can just return the AgnosticEncoder.
103+
*/
104+
private fun <T> applyEncoder(agnosticEncoder: AgnosticEncoder<T>): Encoder<T> {
105+
//#if sparkConnect == false
106+
return ExpressionEncoder.apply(agnosticEncoder)
107+
//#else
108+
//$return agnosticEncoder
109+
//#endif
110+
}
111+
100112

101113
@Deprecated("Use kotlinEncoderFor instead", ReplaceWith("kotlinEncoderFor<T>()"))
102114
inline fun <reified T> encoder(): Encoder<T> = kotlinEncoderFor(typeOf<T>())
@@ -112,7 +124,7 @@ object KotlinTypeInference {
112124
// TODO this hack is a WIP and can give errors
113125
// TODO it's to make data classes get column names like "age" with functions like "getAge"
114126
// TODO instead of column names like "getAge"
115-
var DO_NAME_HACK = true
127+
var DO_NAME_HACK = false
116128

117129
/**
118130
* @param kClass the class for which to infer the encoder.
@@ -151,7 +163,6 @@ object KotlinTypeInference {
151163
currentType = kType,
152164
seenTypeSet = emptySet(),
153165
typeVariables = emptyMap(),
154-
isTopLevel = true,
155166
) as AgnosticEncoder<T>
156167

157168

@@ -218,7 +229,6 @@ object KotlinTypeInference {
218229

219230
// how the generic types of the data class (like T, S) are filled in for this instance of the class
220231
typeVariables: Map<String, KType>,
221-
isTopLevel: Boolean = false,
222232
): AgnosticEncoder<*> {
223233
val kClass =
224234
currentType.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $currentType")
@@ -328,7 +338,7 @@ object KotlinTypeInference {
328338
AgnosticEncoders.UDTEncoder(udt, udt.javaClass)
329339
}
330340

331-
currentType.isSubtypeOf<scala.Option<*>>() -> {
341+
currentType.isSubtypeOf<scala.Option<*>?>() -> {
332342
val elementEncoder = encoderFor(
333343
currentType = tArguments.first().type!!,
334344
seenTypeSet = seenTypeSet,
@@ -506,7 +516,6 @@ object KotlinTypeInference {
506516

507517
DirtyProductEncoderField(
508518
doNameHack = DO_NAME_HACK,
509-
isTopLevel = isTopLevel,
510519
columnName = paramName,
511520
readMethodName = readMethodName,
512521
writeMethodName = writeMethodName,
@@ -525,7 +534,7 @@ object KotlinTypeInference {
525534
if (currentType in seenTypeSet) throw IllegalStateException("Circular reference detected for type $currentType")
526535
val constructorParams = currentType.getScalaConstructorParameters(typeVariables, kClass)
527536

528-
val params: List<AgnosticEncoders.EncoderField> = constructorParams.map { (paramName, paramType) ->
537+
val params = constructorParams.map { (paramName, paramType) ->
529538
val encoder = encoderFor(
530539
currentType = paramType,
531540
seenTypeSet = seenTypeSet + currentType,
@@ -564,7 +573,6 @@ internal open class DirtyProductEncoderField(
564573
private val readMethodName: String, // the name of the method used to read the value
565574
private val writeMethodName: String?,
566575
private val doNameHack: Boolean,
567-
private val isTopLevel: Boolean,
568576
encoder: AgnosticEncoder<*>,
569577
nullable: Boolean,
570578
metadata: Metadata = Metadata.empty(),
@@ -577,18 +585,18 @@ internal open class DirtyProductEncoderField(
577585
/* writeMethod = */ writeMethodName.toOption(),
578586
), Serializable {
579587

580-
private var isFirstNameCall = true
588+
private var noNameCalls = 0
581589

582590
/**
583591
* This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder]
584592
* creates an [Invoke] using [name] first and then calls [name] again to retrieve
585593
* the name of the column. This way, we can alternate between the two names.
586594
*/
587595
override fun name(): String =
588-
if (doNameHack && !isFirstNameCall) {
596+
if (doNameHack && noNameCalls > 0) {
589597
columnName
590598
} else {
591-
isFirstNameCall = false
599+
noNameCalls++
592600
readMethodName
593601
}
594602

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt

+139-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ package org.jetbrains.kotlinx.spark.api
2222
import ch.tutteli.atrium.api.fluent.en_GB.*
2323
import ch.tutteli.atrium.api.verbs.expect
2424
import io.kotest.core.spec.style.ShouldSpec
25-
import io.kotest.matchers.collections.shouldContain
2625
import io.kotest.matchers.collections.shouldContainExactly
2726
import io.kotest.matchers.shouldBe
27+
import io.kotest.matchers.string.shouldContain
2828
import org.apache.spark.sql.Dataset
2929
import org.apache.spark.sql.types.Decimal
3030
import org.apache.spark.unsafe.types.CalendarInterval
@@ -210,7 +210,7 @@ class EncodingTest : ShouldSpec({
210210
context("schema") {
211211
withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
212212

213-
context("Give proper names to columns of data classe") {
213+
context("Give proper names to columns of data classes") {
214214
val old = KotlinTypeInference.DO_NAME_HACK
215215
KotlinTypeInference.DO_NAME_HACK = true
216216

@@ -240,6 +240,142 @@ class EncodingTest : ShouldSpec({
240240
dataset.collectAsList() shouldBe pairs
241241
}
242242

243+
should("Be able to serialize pairs of pairs of pairs") {
244+
val pairs = listOf(
245+
1 to (1 to (1 to "1")),
246+
2 to (2 to (2 to "2")),
247+
3 to (3 to (3 to "3")),
248+
)
249+
val dataset = pairs.toDS()
250+
dataset.show()
251+
dataset.printSchema()
252+
dataset.columns().shouldContainExactly("first", "second")
253+
dataset.select("second.*").columns().shouldContainExactly("first", "second")
254+
dataset.select("second.second.*").columns().shouldContainExactly("first", "second")
255+
dataset.collectAsList() shouldBe pairs
256+
}
257+
258+
should("Be able to serialize lists of pairs") {
259+
val pairs = listOf(
260+
listOf(1 to "1", 2 to "2"),
261+
listOf(3 to "3", 4 to "4"),
262+
)
263+
val dataset = pairs.toDS()
264+
dataset.show()
265+
dataset.printSchema()
266+
dataset.schema().toString().let {
267+
it shouldContain "first"
268+
it shouldContain "second"
269+
}
270+
dataset.collectAsList() shouldBe pairs
271+
}
272+
273+
should("Be able to serialize lists of lists of pairs") {
274+
val pairs = listOf(
275+
listOf(
276+
listOf(1 to "1", 2 to "2"),
277+
listOf(3 to "3", 4 to "4")
278+
)
279+
)
280+
val dataset = pairs.toDS()
281+
dataset.show()
282+
dataset.printSchema()
283+
dataset.schema().toString().let {
284+
it shouldContain "first"
285+
it shouldContain "second"
286+
}
287+
dataset.collectAsList() shouldBe pairs
288+
}
289+
290+
should("Be able to serialize lists of lists of lists of pairs") {
291+
val pairs = listOf(
292+
listOf(
293+
listOf(
294+
listOf(1 to "1", 2 to "2"),
295+
listOf(3 to "3", 4 to "4"),
296+
)
297+
)
298+
)
299+
val dataset = pairs.toDS()
300+
dataset.show()
301+
dataset.printSchema()
302+
dataset.schema().toString().let {
303+
it shouldContain "first"
304+
it shouldContain "second"
305+
}
306+
dataset.collectAsList() shouldBe pairs
307+
}
308+
309+
should("Be able to serialize lists of lists of lists of pairs of pairs") {
310+
val pairs = listOf(
311+
listOf(
312+
listOf(
313+
listOf(1 to ("1" to 3.0), 2 to ("2" to 3.0)),
314+
listOf(3 to ("3" to 3.0), 4 to ("4" to 3.0)),
315+
)
316+
)
317+
)
318+
val dataset = pairs.toDS()
319+
dataset.show()
320+
dataset.printSchema()
321+
dataset.schema().toString().let {
322+
it shouldContain "first"
323+
it shouldContain "second"
324+
}
325+
dataset.collectAsList() shouldBe pairs
326+
}
327+
328+
should("Be able to serialize arrays of pairs") {
329+
val pairs = arrayOf(
330+
arrayOf(1 to "1", 2 to "2"),
331+
arrayOf(3 to "3", 4 to "4"),
332+
)
333+
val dataset = pairs.toDS()
334+
dataset.show()
335+
dataset.printSchema()
336+
dataset.schema().toString().let {
337+
it shouldContain "first"
338+
it shouldContain "second"
339+
}
340+
dataset.collectAsList() shouldBe pairs
341+
}
342+
343+
should("Be able to serialize arrays of arrays of pairs") {
344+
val pairs = arrayOf(
345+
arrayOf(
346+
arrayOf(1 to "1", 2 to "2"),
347+
arrayOf(3 to "3", 4 to "4")
348+
)
349+
)
350+
val dataset = pairs.toDS()
351+
dataset.show()
352+
dataset.printSchema()
353+
dataset.schema().toString().let {
354+
it shouldContain "first"
355+
it shouldContain "second"
356+
}
357+
dataset.collectAsList() shouldBe pairs
358+
}
359+
360+
should("Be able to serialize arrays of arrays of arrays of pairs") {
361+
val pairs = arrayOf(
362+
arrayOf(
363+
arrayOf(
364+
arrayOf(1 to "1", 2 to "2"),
365+
arrayOf(3 to "3", 4 to "4"),
366+
)
367+
)
368+
)
369+
val dataset = pairs.toDS()
370+
dataset.show()
371+
dataset.printSchema()
372+
dataset.schema().toString().let {
373+
it shouldContain "first"
374+
it shouldContain "second"
375+
}
376+
dataset.collectAsList() shouldBe pairs
377+
}
378+
243379
KotlinTypeInference.DO_NAME_HACK = old
244380
}
245381

@@ -351,6 +487,7 @@ class EncodingTest : ShouldSpec({
351487
listOf(SomeClass(intArrayOf(1, 2, 3), 4)),
352488
listOf(SomeClass(intArrayOf(3, 2, 1), 0)),
353489
)
490+
dataset.printSchema()
354491

355492
val (first, second) = dataset.collectAsList()
356493

0 commit comments

Comments (0)