Skip to content

Commit 79eac6c

Browse files
committed
wip Multik sample
1 parent 464393f commit 79eac6c

File tree

7 files changed

+355
-0
lines changed

7 files changed

+355
-0
lines changed

examples/idea-examples/unsupported-data-sources/build.gradle.kts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ dependencies {
3030
compileOnly(libs.spark)
3131
implementation(libs.log4j.core)
3232
implementation(libs.log4j.api)
33+
34+
// multik support
35+
implementation(libs.multik.core)
36+
implementation(libs.multik.default)
3337
}
3438

3539
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
@file:OptIn(ExperimentalTypeInference::class)
2+
3+
package org.jetbrains.kotlinx.dataframe.examples.multik
4+
5+
import org.jetbrains.kotlinx.dataframe.AnyFrame
6+
import org.jetbrains.kotlinx.dataframe.ColumnSelector
7+
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
8+
import org.jetbrains.kotlinx.dataframe.DataColumn
9+
import org.jetbrains.kotlinx.dataframe.DataFrame
10+
import org.jetbrains.kotlinx.dataframe.api.ValueProperty
11+
import org.jetbrains.kotlinx.dataframe.api.cast
12+
import org.jetbrains.kotlinx.dataframe.api.colsOf
13+
import org.jetbrains.kotlinx.dataframe.api.column
14+
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
15+
import org.jetbrains.kotlinx.dataframe.api.getColumn
16+
import org.jetbrains.kotlinx.dataframe.api.getColumns
17+
import org.jetbrains.kotlinx.dataframe.api.map
18+
import org.jetbrains.kotlinx.dataframe.api.named
19+
import org.jetbrains.kotlinx.dataframe.api.toColumn
20+
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
21+
import org.jetbrains.kotlinx.multik.api.mk
22+
import org.jetbrains.kotlinx.multik.api.ndarray
23+
import org.jetbrains.kotlinx.multik.ndarray.complex.Complex
24+
import org.jetbrains.kotlinx.multik.ndarray.data.D1Array
25+
import org.jetbrains.kotlinx.multik.ndarray.data.D2Array
26+
import org.jetbrains.kotlinx.multik.ndarray.data.get
27+
import org.jetbrains.kotlinx.multik.ndarray.operations.toList
28+
import kotlin.experimental.ExperimentalTypeInference
29+
import kotlin.reflect.KClass
30+
import kotlin.reflect.full.isSubtypeOf
31+
import kotlin.reflect.typeOf
32+
33+
// region 1D
34+
35+
/**
 * Turns this one-dimensional array ([D1Array]) into a [DataColumn].
 *
 * @param name the name given to the resulting column; defaults to an empty string.
 * @return a column holding the elements of this array, in order.
 */
inline fun <reified N> D1Array<N>.convertToColumn(name: String = ""): DataColumn<N> {
    val unnamed = column(toList())
    return unnamed.named(name)
}
37+
38+
/**
 * Turns this numeric [DataColumn] into a one-dimensional array ([D1Array]).
 *
 * @return an array built from the values of this column, in order.
 */
@JvmName("convertNumberColumnToMultik")
inline fun <reified N> DataColumn<N>.convertToMultik(): D1Array<N> where N : Number, N : Comparable<N> {
    val elements = toList()
    return mk.ndarray(elements)
}
42+
43+
/**
 * Turns this [DataColumn] of [Complex] values into a one-dimensional array ([D1Array]).
 *
 * @return an array built from the values of this column, in order.
 */
@JvmName("convertComplexColumnToMultik")
inline fun <reified N : Complex> DataColumn<N>.convertToMultik(): D1Array<N> {
    val elements = toList()
    return mk.ndarray(elements)
}
46+
47+
/**
 * Selects a single numeric column from this [DataFrame] and converts it
 * to a one-dimensional array ([D1Array]).
 *
 * @param column selector that picks the column to convert.
 * @return an array built from the values of the selected column.
 */
@JvmName("convertNumberColumnFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N> DataFrame<T>.convertToMultik(
    crossinline column: ColumnSelector<T, N>,
): D1Array<N> where N : Number, N : Comparable<N> {
    val selected = getColumn { column(it) }
    return selected.convertToMultik()
}
53+
54+
/**
 * Selects a single [Complex] column from this [DataFrame] and converts it
 * to a one-dimensional array ([D1Array]).
 *
 * @param column selector that picks the column to convert.
 * @return an array built from the values of the selected column.
 */
@JvmName("convertComplexColumnFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N : Complex> DataFrame<T>.convertToMultik(
    crossinline column: ColumnSelector<T, N>,
): D1Array<N> {
    val selected = getColumn { column(it) }
    return selected.convertToMultik()
}
58+
59+
/**
 * Turns this one-dimensional array ([D1Array]) of type [N] into a DataFrame.
 *
 * The result has exactly one column, named after the `value` property of [ValueProperty];
 * every element of the array becomes one row of that column.
 *
 * @return a DataFrame typed with the [ValueProperty] schema that holds the elements of this array.
 */
inline fun <reified N> D1Array<N>.convertToDataFrame(): DataFrame<ValueProperty<N>> {
    val untyped = dataFrameOf(ValueProperty<*>::value.name to column(toList()))
    return untyped.cast()
}
68+
69+
// endregion
70+
71+
// region 2D
72+
73+
/**
 * Turns this two-dimensional array ([D2Array]) into a DataFrame
 * with `shape[0]` rows and `shape[1]` columns.
 *
 * Each array column becomes one DataFrame column, so that
 * `multikArray[x][y] == dataframe[x][y]` holds after the conversion.
 *
 * @param columnNameGenerator produces the name of a column from its index; defaults to `"col<index>"`.
 * @return the resulting DataFrame.
 */
inline fun <reified N> D2Array<N>.convertToDataFrame(columnNameGenerator: (Int) -> String = { "col$it" }): AnyFrame {
    val rowRange = 0..<shape[0]
    val columns = (0..<shape[1]).map { columnIndex ->
        // slice out one full column of the array and wrap it as a named DataColumn
        val values = this[rowRange, columnIndex].toList()
        values.toColumn(columnNameGenerator(columnIndex))
    }
    return columns.toDataFrame()
}
87+
88+
/**
 * Converts all columns of type [N] in this DataFrame to a two-dimensional array ([D2Array]).
 *
 * @param _klass not read by the body; presumably present only to tell this overload
 *   apart from the [Number] variant — TODO confirm.
 * @return an array built from the columns selected by `colsOf<N>()`.
 */
@JvmName("convertToMultikOfComplex")
inline fun <reified N : Complex> AnyFrame.convertToMultikOf(_klass: KClass<Complex> = Complex::class): D2Array<N> {
    return convertToMultik { colsOf<N>() }
}
91+
92+
/**
 * Converts all columns of numeric type [N] in this DataFrame to a two-dimensional array ([D2Array]).
 *
 * @param _klass not read by the body; presumably present only to tell this overload
 *   apart from the [Complex] variant — TODO confirm.
 * @return an array built from the columns selected by `colsOf<N>()`.
 */
@JvmName("convertToMultikOfNumber")
inline fun <reified N> AnyFrame.convertToMultikOf(
    _klass: KClass<Number> = Number::class,
): D2Array<N> where N : Number, N : Comparable<N> {
    return convertToMultik { colsOf<N>() }
}
96+
97+
/**
 * Converts this DataFrame to a two-dimensional array ([D2Array]), guessing the element type
 * from the column types.
 *
 * All columns must share exactly one type, and that type must be [Complex] or one of
 * [Byte], [Short], [Int], [Long], [Float], [Double].
 *
 * NOTE(review): the [Complex] branch compares for exact type equality, so columns typed as a
 * concrete `Complex` implementation presumably end up in the "unsupported" branch — confirm.
 *
 * @return an array holding the values of all columns.
 * @throws IllegalStateException if the columns do not share exactly one type
 *   (including when there are no columns), or if that type is not supported.
 */
@JvmName("convertToMultikGuess")
fun AnyFrame.convertToMultik(): D2Array<*> {
    val columnTypes = columnTypes().distinct()
    // singleOrNull() is null both for "no columns" and for "more than one distinct type"
    val type = columnTypes.singleOrNull()
        ?: error(
            "Cannot convert DataFrame to Multik: all columns must share exactly one type, " +
                "but found column types: $columnTypes",
        )
    return when {
        type == typeOf<Complex>() -> convertToMultik { colsOf<Complex>() }
        type.isSubtypeOf(typeOf<Byte>()) -> convertToMultik { colsOf<Byte>() }
        type.isSubtypeOf(typeOf<Short>()) -> convertToMultik { colsOf<Short>() }
        type.isSubtypeOf(typeOf<Int>()) -> convertToMultik { colsOf<Int>() }
        type.isSubtypeOf(typeOf<Long>()) -> convertToMultik { colsOf<Long>() }
        type.isSubtypeOf(typeOf<Float>()) -> convertToMultik { colsOf<Float>() }
        type.isSubtypeOf(typeOf<Double>()) -> convertToMultik { colsOf<Double>() }
        else -> error(
            "Cannot convert DataFrame to Multik: unsupported column type, " +
                "found column types: $columnTypes",
        )
    }
}
112+
113+
/**
 * Selects numeric columns from this [DataFrame] and converts them
 * to a two-dimensional array ([D2Array]).
 *
 * @param columns selector that picks the columns to convert.
 * @return an array built from the selected columns.
 */
@JvmName("convertNumberColumnsFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N> DataFrame<T>.convertToMultik(
    crossinline columns: ColumnsSelector<T, N>,
): D2Array<N> where N : Number, N : Comparable<N> {
    val selected = getColumns { columns(it) }
    return selected.convertToMultik()
}
119+
120+
/**
 * Selects [Complex] columns from this [DataFrame] and converts them
 * to a two-dimensional array ([D2Array]).
 *
 * @param columns selector that picks the columns to convert.
 * @return an array built from the selected columns.
 */
@JvmName("convertComplexColumnsFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N : Complex> DataFrame<T>.convertToMultik(
    crossinline columns: ColumnsSelector<T, N>,
): D2Array<N> {
    val selected = getColumns { columns(it) }
    return selected.convertToMultik()
}
125+
126+
/**
 * Converts this list of numeric columns to a two-dimensional array ([D2Array]).
 *
 * The columns are first combined into a DataFrame; each row of that frame
 * then becomes one row of the resulting array.
 *
 * @return an array built row by row from the values of these columns.
 */
@JvmName("convertNumberColumnsToMultik")
inline fun <reified N> List<DataColumn<N>>.convertToMultik(): D2Array<N> where N : Number, N : Comparable<N> {
    // row values are untyped at this point; every column is a DataColumn<N>, hence the cast
    val rows = toDataFrame().map { row -> row.values() as List<N> }
    return mk.ndarray(rows)
}
131+
132+
/**
 * Converts this list of [Complex] columns to a two-dimensional array ([D2Array]).
 *
 * The columns are first combined into a DataFrame; each row of that frame
 * then becomes one row of the resulting array.
 *
 * @return an array built row by row from the values of these columns.
 */
@JvmName("convertComplexColumnsToMultik")
inline fun <reified N : Complex> List<DataColumn<N>>.convertToMultik(): D2Array<N> {
    // row values are untyped at this point; every column is a DataColumn<N>, hence the cast
    val rows = toDataFrame().map { row -> row.values() as List<N> }
    return mk.ndarray(rows)
}
137+
138+
// endregion
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.jetbrains.kotlinx.dataframe.examples.multik
2+
3+
import org.jetbrains.kotlinx.dataframe.api.print
4+
import org.jetbrains.kotlinx.multik.api.io.readNPY
5+
import org.jetbrains.kotlinx.multik.api.mk
6+
import org.jetbrains.kotlinx.multik.ndarray.data.D1
7+
import java.io.File
8+
9+
/**
 * Multik can read/write data from NPY/NPZ files.
 * We can use this from DataFrame too!
 *
 * Loads the classpath resource `a1d.npy` as a one-dimensional array of [Long],
 * converts it to a DataFrame, and prints it.
 *
 * @throws IllegalStateException if the resource cannot be found on the classpath.
 */
fun main() {
    val npyFilename = "a1d.npy"
    // fail with a message naming the resource instead of a bare NullPointerException
    val npyResource = object {}.javaClass.classLoader.getResource(npyFilename)
        ?: error("Resource '$npyFilename' not found on the classpath")
    val npyFile = File(npyResource.toURI())

    val mk1 = mk.readNPY<Long, D1>(npyFile)
    val df1 = mk1.convertToDataFrame()

    df1.print(borders = true, columnTypes = true)
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package org.jetbrains.kotlinx.dataframe.examples.multik
2+
3+
import org.jetbrains.kotlinx.dataframe.api.describe
4+
import org.jetbrains.kotlinx.dataframe.api.mean
5+
import org.jetbrains.kotlinx.dataframe.api.print
6+
import org.jetbrains.kotlinx.dataframe.api.value
7+
import org.jetbrains.kotlinx.multik.api.mk
8+
import org.jetbrains.kotlinx.multik.api.rand
9+
import org.jetbrains.kotlinx.multik.ndarray.data.get
10+
11+
/**
 * Entry point of the sample: shows a few ways to combine Multik with Kotlin DataFrame.
 *
 * The conversions used here come from compatibilityLayer.kt.
 */
fun main() {
    // 1D ndarray <-> column / DataFrame
    oneDimension()
    // 2D ndarray <-> DataFrame
    twoDimensions()
}
20+
21+
/** Demonstrates round-tripping a 1D Multik array through a DataFrame column and a DataFrame. */
fun oneDimension() {
    // a 1D ndarray can become a DataFrame column:
    val randomArray = mk.rand<Double>(50)
    val col1 by randomArray.convertToColumn()
    println(col1)

    // ...or a whole DataFrame, whose single column is called `value`
    val frame = randomArray.convertToDataFrame()
    println(frame)

    // from here on, every DataFrame operation is available:
    println(frame.mean { value })
    frame.describe().print(borders = true)

    // going back to Multik works via a column selector...
    val roundTripped = frame.convertToMultik { value }
    // ...or directly from the column
    frame.value.convertToMultik()

    println(roundTripped)
}
42+
43+
/** Demonstrates round-tripping a 2D Multik array through a DataFrame. */
fun twoDimensions() {
    // a 2D ndarray can become a DataFrame as well
    val randomMatrix = mk.rand<Int>(5, 10)
    println(randomMatrix)

    val frame = randomMatrix.convertToDataFrame()
    frame.print()

    // and back again, stating the expected element type
    val roundTripped = frame.convertToMultikOf<Int>()
    println(roundTripped)
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package org.jetbrains.kotlinx.dataframe.examples.multik
2+
3+
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
4+
import org.jetbrains.kotlinx.dataframe.api.print
5+
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
6+
import org.jetbrains.kotlinx.dataframe.io.toStandaloneHtml
7+
import org.jetbrains.kotlinx.multik.api.identity
8+
import org.jetbrains.kotlinx.multik.api.mk
9+
import org.jetbrains.kotlinx.multik.ndarray.data.D2Array
10+
import org.jetbrains.kotlinx.multik.ndarray.data.set
11+
import kotlin.math.cos
12+
import kotlin.math.sin
13+
import kotlin.math.tan
14+
15+
/**
 * One step of a 2D affine transformation sequence, stored as a DataFrame row.
 *
 * @property type the kind of transformation this step performs.
 * @property parameters the named values the step was built from (e.g. `"x"`, `"y"`, `"angle"`).
 * @property note human-readable description of what the step is for.
 * @property matrix the Multik matrix that represents the step.
 */
@DataSchema
data class Transformation(
    val type: TransformationType,
    val parameters: Map<String, Double>,
    val note: String,
    val matrix: D2Array<Double>,
)
22+
23+
/** The kinds of transformation a [Transformation] can be tagged with. */
enum class TransformationType {
    IDENTITY,
    TRANSLATION,
    SCALING,
    ROTATION,
    SHEARING,
    REFLECTION_ABOUT_ORIGIN,
    REFLECTION_ABOUT_X_AXIS,
    REFLECTION_ABOUT_Y_AXIS,
}
33+
34+
/**
 * Shows that a DataFrame can hold Multik matrices in its cells.
 *
 * Builds a sequence of affine [Transformation] steps, converts the list to a DataFrame,
 * prints it, and opens it as standalone HTML in the browser.
 */
fun main() {
    // A DataFrame cell can hold any object, Multik nd arrays included.
    // That makes it a handy place to keep matrices around for later lookup,
    // for instance the affine transformations used in 2D graphics
    // (https://en.wikipedia.org/wiki/Affine_transformation)

    // The sequence below rotates and scales an image in place.
    // The image starts out 100x50, with its left bottom corner at (x=10, y=0)
    val moveToOrigin = Transformation(
        type = TransformationType.TRANSLATION,
        parameters = mapOf("x" to -10.0, "y" to 0.0),
        note = "Translate so left-bottom touches origin",
        matrix = translationMatrixOf(x = -10.0, y = 0.0),
    )
    val scaleUp = Transformation(
        type = TransformationType.SCALING,
        parameters = mapOf("w" to 2.0, "h" to 2.0),
        note = "Scale by x2",
        matrix = scaleMatrixOf(w = 2.0, h = 2.0),
    )
    val centerOnOrigin = Transformation(
        type = TransformationType.TRANSLATION,
        parameters = mapOf("x" to -100.0, "y" to -50.0),
        note = "Translate so the new image center is at the origin",
        matrix = translationMatrixOf(x = -100.0, y = -50.0),
    )
    val rotate = Transformation(
        type = TransformationType.ROTATION,
        parameters = mapOf("angle" to 45.0),
        note = "Rotate by 45 degrees",
        matrix = rotationMatrixOf(angle = 45.0),
    )
    val moveBack = Transformation(
        type = TransformationType.TRANSLATION,
        parameters = mapOf("x" to 10.0 + 50.0, "y" to 0.0 + 25.0),
        note = "Translate back so the center is at the same original position",
        matrix = translationMatrixOf(x = 10.0 + 50.0, y = 0.0 + 25.0),
    )

    val steps = listOf(moveToOrigin, scaleUp, centerOnOrigin, rotate, moveBack)
    val transformations = steps.toDataFrame()

    transformations.print(borders = true)
    transformations.toStandaloneHtml().openInBrowser()
}
81+
82+
/** Returns a new 3x3 identity matrix. */
fun identityMatrix(): D2Array<Double> {
    return mk.identity(3)
}
83+
84+
/**
 * Builds a 3x3 affine transformation matrix for a translation by ([x], [y]).
 *
 * Starts from the identity matrix and writes the offsets into the last column.
 */
fun translationMatrixOf(x: Double = 0.0, y: Double = 0.0): D2Array<Double> {
    val matrix = identityMatrix()
    matrix[0, 2] = x
    matrix[1, 2] = y
    return matrix
}
90+
91+
/**
 * Builds a 3x3 affine transformation matrix for scaling by ([w], [h]) about the origin.
 *
 * Starts from the identity matrix and writes the factors onto the diagonal.
 */
fun scaleMatrixOf(w: Double = 1.0, h: Double = 1.0): D2Array<Double> {
    val matrix = identityMatrix()
    matrix[0, 0] = w
    matrix[1, 1] = h
    return matrix
}
97+
98+
/**
 * Returns a 3x3 affine transformation matrix that rotates by [angle] degrees about the origin.
 *
 * @param angle the rotation angle, in degrees.
 * @return the identity matrix with its upper-left 2x2 part replaced by the rotation.
 */
fun rotationMatrixOf(angle: Double): D2Array<Double> {
    // kotlin.math.cos/sin expect radians, while the angle is given in degrees
    val radians = Math.toRadians(angle)
    val cosine = cos(radians)
    val sine = sin(radians)
    return identityMatrix().apply {
        this[0, 0] = cosine
        this[0, 1] = -sine
        this[1, 0] = sine
        this[1, 1] = cosine
    }
}
109+
110+
/**
 * Builds a 3x3 affine transformation matrix for shearing by [x] and [y].
 *
 * Both arguments are passed to [tan] unchanged, so they are interpreted as angles in radians.
 */
fun shearingMatrixOf(x: Double = 0.0, y: Double = 0.0): D2Array<Double> {
    val matrix = identityMatrix()
    matrix[0, 1] = tan(x)
    matrix[1, 0] = tan(y)
    return matrix
}
116+
117+
/**
 * Builds a 3x3 affine transformation matrix for a reflection about the origin.
 *
 * Starts from the identity matrix and negates the first two diagonal entries.
 */
fun reflectionAboutOriginMatrix(): D2Array<Double> {
    val matrix = identityMatrix()
    matrix[0, 0] = -1.0
    matrix[1, 1] = -1.0
    return matrix
}
123+
124+
/**
 * Builds a 3x3 affine transformation matrix for a reflection about the x-axis.
 *
 * Starts from the identity matrix and negates the entry at `[1, 1]`.
 */
fun reflectionAboutXAxisMatrix(): D2Array<Double> {
    val matrix = identityMatrix()
    matrix[1, 1] = -1.0
    return matrix
}
129+
130+
/**
 * Builds a 3x3 affine transformation matrix for a reflection about the y-axis.
 *
 * Starts from the identity matrix and negates the entry at `[0, 0]`.
 */
fun reflectionAboutYAxisMatrix(): D2Array<Double> {
    val matrix = identityMatrix()
    matrix[0, 0] = -1.0
    return matrix
}
Binary file not shown.

gradle/libs.versions.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ exposed = "1.0.0-beta-2"
6969
kotlin-spark = "1.2.4"
7070
spark = "3.3.2"
7171

72+
multik = "0.2.3"
73+
7274
[libraries]
7375
ksp-gradle = { group = "com.google.devtools.ksp", name = "symbol-processing-gradle-plugin", version.ref = "ksp" }
7476
ksp-api = { group = "com.google.devtools.ksp", name = "symbol-processing-api", version.ref = "ksp" }
@@ -175,6 +177,9 @@ spark = { group = "org.apache.spark", name = "spark-sql_2.13", version.ref = "sp
175177
log4j-core = { group = "org.apache.logging.log4j", name = "log4j-core", version.ref = "log4j" }
176178
log4j-api = { group = "org.apache.logging.log4j", name = "log4j-api", version.ref = "log4j" }
177179

180+
multik-core = { group = "org.jetbrains.kotlinx", name = "multik-core", version.ref = "multik" }
181+
multik-default = { group = "org.jetbrains.kotlinx", name = "multik-default", version.ref = "multik" }
182+
178183
[plugins]
179184
jupyter-api = { id = "org.jetbrains.kotlin.jupyter.api", version.ref = "kotlinJupyter" }
180185
ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" }

0 commit comments

Comments
 (0)