
Commit 28866ca

added working kotlin spark sample
1 parent 211412a commit 28866ca

File tree

6 files changed (+455 −32 lines)


build.gradle.kts

Lines changed: 1 addition & 1 deletion

@@ -196,7 +196,7 @@ allprojects {
         logger.warn("Could not set ktlint config on :${this.name}")
     }

-    // set the java toolchain version to 11 for all subprojects for CI stability
+    // set the java toolchain version to 21 for all subprojects for CI stability
     extensions.findByType<KotlinJvmProjectExtension>()?.jvmToolchain(21)

     // Attempts to configure buildConfig for each sub-project that uses it

examples/idea-examples/unsupported-data-sources/build.gradle.kts

Lines changed: 18 additions & 2 deletions

@@ -31,8 +31,24 @@ dependencies {
     // (kotlin) spark support
     implementation(libs.kotlin.spark)
     compileOnly(libs.spark)
+    implementation(libs.log4j.core)
+    implementation(libs.log4j.api)
 }

-tasks.withType<KotlinCompile> {
-    compilerOptions.jvmTarget = JvmTarget.JVM_1_8
+/**
+ * Runs the kotlinSpark/typedDataset example with java 11.
+ */
+val runKotlinSparkTypedDataset by tasks.registering(JavaExec::class) {
+    classpath = sourceSets["main"].runtimeClasspath
+    javaLauncher = javaToolchains.launcherFor { languageVersion = JavaLanguageVersion.of(11) }
+    mainClass = "org.jetbrains.kotlinx.dataframe.examples.kotlinSpark.TypedDatasetKt"
+}
+
+/**
+ * Runs the kotlinSpark/untypedDataset example with java 11.
+ */
+val runKotlinSparkUntypedDataset by tasks.registering(JavaExec::class) {
+    classpath = sourceSets["main"].runtimeClasspath
+    javaLauncher = javaToolchains.launcherFor { languageVersion = JavaLanguageVersion.of(11) }
+    mainClass = "org.jetbrains.kotlinx.dataframe.examples.kotlinSpark.UntypedDatasetKt"
 }
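
With these tasks registered, each example runs on a Java 11 launcher straight from Gradle, e.g. ./gradlew runKotlinSparkTypedDataset (prefix the task with its project path when invoking from the repository root). Pinning the launcher to 11 is presumably needed because the Spark version used by these examples targets Java 11, while the rest of the build now uses the Java 21 toolchain.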

Lines changed: 313 additions & 0 deletions

@@ -0,0 +1,313 @@
package org.jetbrains.kotlinx.dataframe.examples.kotlinSpark

import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.CalendarInterval
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.api.rows
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
import org.jetbrains.kotlinx.spark.api.toRDD
import java.math.BigDecimal
import java.math.BigInteger
import java.sql.Date
import java.sql.Timestamp
import java.time.Instant
import java.time.LocalDate
import kotlin.reflect.KType
import kotlin.reflect.KTypeProjection
import kotlin.reflect.full.createType
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.full.withNullability
import kotlin.reflect.typeOf

// region Spark to DataFrame

/**
 * Converts an untyped Spark [Dataset] (DataFrame) to a Kotlin [DataFrame].
 * [StructTypes][StructType] are converted to [ColumnGroups][ColumnGroup].
 *
 * DataFrame supports type inference to do the conversion automatically.
 * This is usually fine for smaller data sets, but when working with larger datasets, a type map might be a good idea.
 * See [convertToDataFrame] for more information.
 */
fun Dataset<Row>.convertToDataFrameByInference(
    schema: StructType = schema(),
    prefix: List<String> = emptyList(),
): AnyFrame {
    val columns = schema.fields().map { field ->
        val name = field.name()
        when (val dataType = field.dataType()) {
            is StructType ->
                DataColumn.createColumnGroup(
                    name = name,
                    df = convertToDataFrameByInference(dataType, prefix + name),
                )

            else ->
                DataColumn.createByInference(
                    name = name,
                    values = select((prefix + name).joinToString("."))
                        .collectAsList()
                        .map { it[0] },
                    suggestedType = TypeSuggestion.Infer,
                    nullable = field.nullable(),
                )
        }
    }
    return columns.toDataFrame()
}
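
// A minimal usage sketch (an illustrative addition, not in the original commit):
// builds a small Dataset<Row> by hand and converts it purely by inference. The
// session setup and data are assumptions; only convertToDataFrameByInference
// comes from this file.
@Suppress("unused")
private fun inferenceUsageSketch() {
    val spark = SparkSession.builder()
        .master("local[*]")
        .appName("inference example")
        .getOrCreate()
    val sparkSchema = DataTypes.createStructType(
        listOf(
            DataTypes.createStructField("id", DataTypes.IntegerType, false),
            DataTypes.createStructField("name", DataTypes.StringType, false),
        ),
    )
    val ds: Dataset<Row> = spark.createDataFrame(
        listOf(RowFactory.create(1, "Alice"), RowFactory.create(2, "Bob")),
        sparkSchema,
    )
    val df = ds.convertToDataFrameByInference()
    println(df.schema()) // id: Int, name: String
}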

/**
 * Converts an untyped Spark [Dataset] (DataFrame) to a Kotlin [DataFrame].
 * [StructTypes][StructType] are converted to [ColumnGroups][ColumnGroup].
 *
 * This version uses a [type map][DataType.convertToDataFrame] to convert the schemas, with a fallback to inference.
 * For smaller data sets, inference is usually fine too.
 * See [convertToDataFrameByInference] for more information.
 */
fun Dataset<Row>.convertToDataFrame(schema: StructType = schema(), prefix: List<String> = emptyList()): AnyFrame {
    val columns = schema.fields().map { field ->
        val name = field.name()
        when (val dataType = field.dataType()) {
            is StructType ->
                DataColumn.createColumnGroup(
                    name = name,
                    df = convertToDataFrame(dataType, prefix + name),
                )

            else ->
                DataColumn.createByInference(
                    name = name,
                    values = select((prefix + name).joinToString("."))
                        .collectAsList()
                        .map { it[0] },
                    suggestedType =
                        dataType.convertToDataFrame()
                            ?.let(TypeSuggestion::Use)
                            ?: TypeSuggestion.Infer, // fall back to inference if needed
                    nullable = field.nullable(),
                )
        }
    }
    return columns.toDataFrame()
}
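
// A minimal comparison sketch (an illustrative addition, not in the original commit):
// the type-map version resolves column types from the Spark schema itself, so values
// only need to be inspected when a type falls outside the map.
@Suppress("unused")
private fun typeMapUsageSketch(ds: Dataset<Row>) {
    val byInference = ds.convertToDataFrameByInference()
    val byTypeMap = ds.convertToDataFrame()
    // for types covered by the map, both approaches should yield the same schema
    println(byInference.schema())
    println(byTypeMap.schema())
}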

/**
 * Returns the corresponding Kotlin type for a given Spark [DataType].
 *
 * This list may be incomplete, but it can at least give you a good start.
 *
 * @return the [KType] that corresponds to the Spark [DataType], or `null` if no matching [KType] is found.
 */
fun DataType.convertToDataFrame(): KType? =
    when {
        this == DataTypes.ByteType -> typeOf<Byte>()

        this == DataTypes.ShortType -> typeOf<Short>()

        this == DataTypes.IntegerType -> typeOf<Int>()

        this == DataTypes.LongType -> typeOf<Long>()

        this == DataTypes.BooleanType -> typeOf<Boolean>()

        this == DataTypes.FloatType -> typeOf<Float>()

        this == DataTypes.DoubleType -> typeOf<Double>()

        this == DataTypes.StringType -> typeOf<String>()

        this == DataTypes.DateType -> typeOf<Date>()

        this == DataTypes.TimestampType -> typeOf<Timestamp>()

        this is DecimalType -> typeOf<Decimal>()

        this == DataTypes.CalendarIntervalType -> typeOf<CalendarInterval>()

        this == DataTypes.NullType -> nullableNothingType

        this == DataTypes.BinaryType -> typeOf<ByteArray>()

        this is ArrayType -> {
            when (elementType()) {
                DataTypes.ShortType -> typeOf<ShortArray>()
                DataTypes.IntegerType -> typeOf<IntArray>()
                DataTypes.LongType -> typeOf<LongArray>()
                DataTypes.FloatType -> typeOf<FloatArray>()
                DataTypes.DoubleType -> typeOf<DoubleArray>()
                DataTypes.BooleanType -> typeOf<BooleanArray>()
                else -> null
            }
        }

        this is MapType -> {
            val key = keyType().convertToDataFrame() ?: return null
            val value = valueType().convertToDataFrame() ?: return null
            Map::class.createType(
                listOf(
                    KTypeProjection.invariant(key),
                    KTypeProjection.invariant(value.withNullability(valueContainsNull())),
                ),
            )
        }

        else -> null
    }
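
// A few spot checks of the type map above (an illustrative addition, not in the
// original commit):
@Suppress("unused")
private fun typeMapSpotChecks() {
    println(DataTypes.IntegerType.convertToDataFrame()) // kotlin.Int
    println(DataTypes.createArrayType(DataTypes.DoubleType).convertToDataFrame()) // kotlin.DoubleArray
    // valueContainsNull = true makes the map's value type nullable:
    println(
        DataTypes.createMapType(DataTypes.StringType, DataTypes.LongType, true).convertToDataFrame(),
    ) // kotlin.collections.Map<kotlin.String, kotlin.Long?>
}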

// endregion

// region DataFrame to Spark

/**
 * Converts the [DataFrame] to a Spark [Dataset] of [Rows][Row] using the given [SparkSession] and [JavaSparkContext].
 *
 * Spark needs both the data and the schema to be converted to create a correct [Dataset].
 *
 * @param spark the [SparkSession] used to create the [Dataset].
 * @param sc the [JavaSparkContext] used to convert the rows to an RDD.
 * @return a [Dataset] of [Rows][Row] representing the converted [DataFrame].
 */
fun DataFrame<*>.convertToSpark(spark: SparkSession, sc: JavaSparkContext): Dataset<Row> {
    val rows = sc.toRDD(rows().map { it.convertToSpark() })
    return spark.createDataFrame(rows, schema().convertToSpark())
}
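
// A minimal sketch of the opposite direction (an illustrative addition, not in the
// original commit). The data class and its contents are assumptions; toDataFrame()
// for data classes is already imported at the top of this file.
private data class NameAge(val name: String, val age: Int)

@Suppress("unused")
private fun toSparkUsageSketch(spark: SparkSession) {
    val sc = JavaSparkContext(spark.sparkContext())
    // name: String, age: Int map to StringType and IntegerType via KType.convertToSpark()
    val df = listOf(NameAge("Alice", 25), NameAge("Bob", 30)).toDataFrame()
    val ds: Dataset<Row> = df.convertToSpark(spark, sc)
    ds.show() // two rows with columns name: string, age: int
}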

/**
 * Converts a [DataRow] to a Spark [Row].
 *
 * @return the converted Spark [Row].
 */
fun DataRow<*>.convertToSpark(): Row =
    RowFactory.create(
        *values().map {
            when (it) {
                is DataRow<*> -> it.convertToSpark()
                else -> it
            }
        }.toTypedArray(),
    )

/**
 * Converts a [DataFrameSchema] to a Spark [StructType].
 *
 * @return the converted Spark [StructType].
 */
fun DataFrameSchema.convertToSpark(): StructType =
    DataTypes.createStructType(
        columns.map { (name, schema) ->
            DataTypes.createStructField(name, schema.convertToSpark(), schema.nullable)
        },
    )

/**
 * Converts a [ColumnSchema] to a Spark [DataType].
 *
 * @return the Spark [DataType] corresponding to the given [ColumnSchema].
 * @throws IllegalStateException if the column type or kind is unknown.
 */
fun ColumnSchema.convertToSpark(): DataType =
    when (this) {
        is ColumnSchema.Value -> type.convertToSpark() ?: error("unknown data type: $type")
        is ColumnSchema.Group -> schema.convertToSpark()
        is ColumnSchema.Frame -> error("nested dataframes are not supported")
        else -> error("unknown column kind: $this")
    }

/**
 * Returns the corresponding Spark [DataType] for a given Kotlin [KType].
 *
 * This list may be incomplete, but it can at least give you a good start.
 *
 * @return the Spark [DataType] that corresponds to the Kotlin type, or `null` if no matching [DataType] is found.
 */
fun KType.convertToSpark(): DataType? =
    when {
        isSubtypeOf(typeOf<Byte?>()) -> DataTypes.ByteType

        isSubtypeOf(typeOf<Short?>()) -> DataTypes.ShortType

        isSubtypeOf(typeOf<Int?>()) -> DataTypes.IntegerType

        isSubtypeOf(typeOf<Long?>()) -> DataTypes.LongType

        isSubtypeOf(typeOf<Boolean?>()) -> DataTypes.BooleanType

        isSubtypeOf(typeOf<Float?>()) -> DataTypes.FloatType

        isSubtypeOf(typeOf<Double?>()) -> DataTypes.DoubleType

        isSubtypeOf(typeOf<String?>()) -> DataTypes.StringType

        isSubtypeOf(typeOf<LocalDate?>()) -> DataTypes.DateType

        isSubtypeOf(typeOf<Date?>()) -> DataTypes.DateType

        isSubtypeOf(typeOf<Timestamp?>()) -> DataTypes.TimestampType

        isSubtypeOf(typeOf<Instant?>()) -> DataTypes.TimestampType

        isSubtypeOf(typeOf<Decimal?>()) -> DecimalType.SYSTEM_DEFAULT()

        isSubtypeOf(typeOf<BigDecimal?>()) -> DecimalType.SYSTEM_DEFAULT()

        isSubtypeOf(typeOf<BigInteger?>()) -> DecimalType.SYSTEM_DEFAULT()

        isSubtypeOf(typeOf<CalendarInterval?>()) -> DataTypes.CalendarIntervalType

        isSubtypeOf(nullableNothingType) -> DataTypes.NullType

        isSubtypeOf(typeOf<ByteArray?>()) -> DataTypes.BinaryType

        isSubtypeOf(typeOf<ShortArray?>()) -> DataTypes.createArrayType(DataTypes.ShortType, false)

        isSubtypeOf(typeOf<IntArray?>()) -> DataTypes.createArrayType(DataTypes.IntegerType, false)

        isSubtypeOf(typeOf<LongArray?>()) -> DataTypes.createArrayType(DataTypes.LongType, false)

        isSubtypeOf(typeOf<FloatArray?>()) -> DataTypes.createArrayType(DataTypes.FloatType, false)

        isSubtypeOf(typeOf<DoubleArray?>()) -> DataTypes.createArrayType(DataTypes.DoubleType, false)

        isSubtypeOf(typeOf<BooleanArray?>()) -> DataTypes.createArrayType(DataTypes.BooleanType, false)

        isSubtypeOf(typeOf<Array<*>>()) ->
            error("non-primitive arrays are not supported for now, you can add it yourself")

        isSubtypeOf(typeOf<List<*>>()) -> error("lists are not supported for now, you can add it yourself")

        isSubtypeOf(typeOf<Set<*>>()) -> error("sets are not supported for now, you can add it yourself")

        classifier == Map::class -> {
            val (key, value) = arguments
            DataTypes.createMapType(
                key.type?.convertToSpark(),
                value.type?.convertToSpark(),
                value.type?.isMarkedNullable ?: true,
            )
        }

        else -> null
    }

private val nullableNothingType: KType = typeOf<List<Nothing?>>().arguments.first().type!!

// endregion
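
Taken together, the two regions give a full round trip. A minimal sketch (an illustrative addition, not part of the commit; assumes a local SparkSession, and originalDataset is a hypothetical Dataset<Row> limited to the supported types above):

    val spark = SparkSession.builder().master("local[*]").appName("round trip").getOrCreate()
    val sc = JavaSparkContext(spark.sparkContext())
    val kotlinDf = originalDataset.convertToDataFrame() // Spark -> Kotlin DataFrame
    val back = kotlinDf.convertToSpark(spark, sc)       // Kotlin DataFrame -> Spark
    back.printSchema()                                  // should match originalDataset for supported types

Decimal precision may differ on the way back, since DecimalType.SYSTEM_DEFAULT() is used regardless of the original precision.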
