|
| 1 | +@file:OptIn(ExperimentalTypeInference::class) |
| 2 | + |
| 3 | +package org.jetbrains.kotlinx.dataframe.examples.multik |
| 4 | + |
| 5 | +import org.jetbrains.kotlinx.dataframe.AnyFrame |
| 6 | +import org.jetbrains.kotlinx.dataframe.ColumnSelector |
| 7 | +import org.jetbrains.kotlinx.dataframe.ColumnsSelector |
| 8 | +import org.jetbrains.kotlinx.dataframe.DataColumn |
| 9 | +import org.jetbrains.kotlinx.dataframe.DataFrame |
| 10 | +import org.jetbrains.kotlinx.dataframe.api.ValueProperty |
| 11 | +import org.jetbrains.kotlinx.dataframe.api.cast |
| 12 | +import org.jetbrains.kotlinx.dataframe.api.colsOf |
| 13 | +import org.jetbrains.kotlinx.dataframe.api.column |
| 14 | +import org.jetbrains.kotlinx.dataframe.api.dataFrameOf |
| 15 | +import org.jetbrains.kotlinx.dataframe.api.getColumn |
| 16 | +import org.jetbrains.kotlinx.dataframe.api.getColumns |
| 17 | +import org.jetbrains.kotlinx.dataframe.api.map |
| 18 | +import org.jetbrains.kotlinx.dataframe.api.named |
| 19 | +import org.jetbrains.kotlinx.dataframe.api.toColumn |
| 20 | +import org.jetbrains.kotlinx.dataframe.api.toDataFrame |
| 21 | +import org.jetbrains.kotlinx.multik.api.mk |
| 22 | +import org.jetbrains.kotlinx.multik.api.ndarray |
| 23 | +import org.jetbrains.kotlinx.multik.ndarray.complex.Complex |
| 24 | +import org.jetbrains.kotlinx.multik.ndarray.data.D1Array |
| 25 | +import org.jetbrains.kotlinx.multik.ndarray.data.D2Array |
| 26 | +import org.jetbrains.kotlinx.multik.ndarray.data.get |
| 27 | +import org.jetbrains.kotlinx.multik.ndarray.operations.toList |
| 28 | +import kotlin.experimental.ExperimentalTypeInference |
| 29 | +import kotlin.reflect.KClass |
| 30 | +import kotlin.reflect.full.isSubtypeOf |
| 31 | +import kotlin.reflect.typeOf |
| 32 | + |
| 33 | +// region 1D |
| 34 | + |
/**
 * Builds a [DataColumn] out of the elements of this one-dimensional array ([D1Array]).
 *
 * @param name the name assigned to the resulting column; an empty string when omitted.
 * @return a column holding the array's elements, named [name].
 */
inline fun <reified N> D1Array<N>.convertToColumn(name: String = ""): DataColumn<N> {
    val elements = toList()
    return column(elements).named(name)
}
| 37 | + |
/**
 * Turns this numeric [DataColumn] into a one-dimensional Multik array ([D1Array]).
 *
 * The column's values are collected with `toList()` and handed to `mk.ndarray`.
 */
@JvmName("convertNumberColumnToMultik")
inline fun <reified N> DataColumn<N>.convertToMultik(): D1Array<N> where N : Number, N : Comparable<N> {
    val elements: List<N> = toList()
    return mk.ndarray(elements)
}
| 42 | + |
/**
 * Turns this [DataColumn] of [Complex] values into a one-dimensional Multik array ([D1Array]).
 *
 * The column's values are collected with `toList()` and handed to `mk.ndarray`.
 */
@JvmName("convertComplexColumnToMultik")
inline fun <reified N : Complex> DataColumn<N>.convertToMultik(): D1Array<N> {
    val elements: List<N> = toList()
    return mk.ndarray(elements)
}
| 46 | + |
/**
 * Picks a single numeric column from this [DataFrame] and converts it to a
 * one-dimensional Multik array ([D1Array]).
 *
 * @param column selector resolving to exactly one column of type [N].
 * @return the selected column's values as a [D1Array].
 */
@JvmName("convertNumberColumnFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N> DataFrame<T>.convertToMultik(
    crossinline column: ColumnSelector<T, N>,
): D1Array<N> where N : Number, N : Comparable<N> {
    val selected: DataColumn<N> = getColumn { column(it) }
    return selected.convertToMultik()
}
| 53 | + |
/**
 * Picks a single [Complex] column from this [DataFrame] and converts it to a
 * one-dimensional Multik array ([D1Array]).
 *
 * @param column selector resolving to exactly one column of type [N].
 * @return the selected column's values as a [D1Array].
 */
@JvmName("convertComplexColumnFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N : Complex> DataFrame<T>.convertToMultik(
    crossinline column: ColumnSelector<T, N>,
): D1Array<N> {
    val selected: DataColumn<N> = getColumn { column(it) }
    return selected.convertToMultik()
}
| 58 | + |
/**
 * Wraps a one-dimensional array ([D1Array]) of type [N] in a single-column [DataFrame].
 *
 * The only column takes its name from the `value` property of [ValueProperty], and every
 * element of the array ends up as one row of that column.
 *
 * @return a [DataFrame] typed with the [ValueProperty] schema, holding the array's elements.
 */
inline fun <reified N> D1Array<N>.convertToDataFrame(): DataFrame<ValueProperty<N>> {
    // Derive the column name from the schema property so the two cannot drift apart.
    val valueColumnName = ValueProperty<*>::value.name
    val untypedFrame = dataFrameOf(valueColumnName to column(toList()))
    return untypedFrame.cast()
}
| 68 | + |
| 69 | +// endregion |
| 70 | + |
| 71 | +// region 2D |
| 72 | + |
/**
 * Converts a two-dimensional array ([D2Array]) into a DataFrame.
 *
 * The result is meant to have `shape[0]` rows and `shape[1]` columns, such that
 * `multikArray[x][y] == dataframe[x][y]`.
 *
 * @param columnNameGenerator produces the name of a column from its index;
 *   by default columns are called `col0`, `col1`, ...
 * @return a DataFrame with one column per second-axis index of the array.
 */
inline fun <reified N> D2Array<N>.convertToDataFrame(columnNameGenerator: (Int) -> String = { "col$it" }): AnyFrame {
    val allRows = 0..<shape[0]
    val columnIndices = 0..<shape[1]

    val columns = columnIndices.map { col ->
        // Slice out one column of the array first, then ask for its name.
        val columnValues = this[allRows, col].toList()
        columnValues.toColumn(columnNameGenerator(col))
    }
    return columns.toDataFrame()
}
| 87 | + |
/**
 * Converts the columns of this DataFrame selected by `colsOf<N>()` into a
 * two-dimensional Multik array ([D2Array]) of [Complex] values.
 *
 * @param _klass not read by the body; presumably only present to tell this overload
 *   apart from the [Number] variant — leave it at its default.
 */
@JvmName("convertToMultikOfComplex")
inline fun <reified N : Complex> AnyFrame.convertToMultikOf(_klass: KClass<Complex> = Complex::class): D2Array<N> {
    return convertToMultik { colsOf<N>() }
}
| 91 | + |
/**
 * Converts the columns of this DataFrame selected by `colsOf<N>()` into a
 * two-dimensional Multik array ([D2Array]) of numbers.
 *
 * @param _klass not read by the body; presumably only present to tell this overload
 *   apart from the [Complex] variant — leave it at its default.
 */
@JvmName("convertToMultikOfNumber")
inline fun <reified N> AnyFrame.convertToMultikOf(
    _klass: KClass<Number> = Number::class,
): D2Array<N> where N : Number, N : Comparable<N> {
    return convertToMultik { colsOf<N>() }
}
| 96 | + |
/**
 * Converts this DataFrame into a two-dimensional Multik array ([D2Array]),
 * working out the element type from the column types.
 *
 * All columns must share one type: either exactly [Complex], or a subtype of
 * [Byte], [Short], [Int], [Long], [Float] or [Double] (checked in that order).
 *
 * @throws IllegalStateException if the columns do not share a single type,
 *   or if that type is none of the supported ones.
 */
@JvmName("convertToMultikGuess")
fun AnyFrame.convertToMultik(): D2Array<*> {
    val distinctTypes = columnTypes().distinct()
    // Zero or several distinct types means there is no single element type to convert to.
    val onlyType = distinctTypes.singleOrNull() ?: error("found column types: $distinctTypes")

    return when {
        onlyType == typeOf<Complex>() -> convertToMultik { colsOf<Complex>() }

        onlyType.isSubtypeOf(typeOf<Byte>()) -> convertToMultik { colsOf<Byte>() }

        onlyType.isSubtypeOf(typeOf<Short>()) -> convertToMultik { colsOf<Short>() }

        onlyType.isSubtypeOf(typeOf<Int>()) -> convertToMultik { colsOf<Int>() }

        onlyType.isSubtypeOf(typeOf<Long>()) -> convertToMultik { colsOf<Long>() }

        onlyType.isSubtypeOf(typeOf<Float>()) -> convertToMultik { colsOf<Float>() }

        onlyType.isSubtypeOf(typeOf<Double>()) -> convertToMultik { colsOf<Double>() }

        else -> error("found column types: $distinctTypes")
    }
}
| 112 | + |
/**
 * Picks several numeric columns from this [DataFrame] and converts them to a
 * two-dimensional Multik array ([D2Array]).
 *
 * @param columns selector resolving to the columns of type [N] to convert.
 * @return the selected columns as a [D2Array].
 */
@JvmName("convertNumberColumnsFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N> DataFrame<T>.convertToMultik(
    crossinline columns: ColumnsSelector<T, N>,
): D2Array<N> where N : Number, N : Comparable<N> {
    val selectedColumns: List<DataColumn<N>> = getColumns { columns(it) }
    return selectedColumns.convertToMultik()
}
| 119 | + |
/**
 * Picks several [Complex] columns from this [DataFrame] and converts them to a
 * two-dimensional Multik array ([D2Array]).
 *
 * @param columns selector resolving to the columns of type [N] to convert.
 * @return the selected columns as a [D2Array].
 */
@JvmName("convertComplexColumnsFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N : Complex> DataFrame<T>.convertToMultik(
    crossinline columns: ColumnsSelector<T, N>,
): D2Array<N> {
    val selectedColumns: List<DataColumn<N>> = getColumns { columns(it) }
    return selectedColumns.convertToMultik()
}
| 125 | + |
/**
 * Converts a list of numeric [DataColumn]s into a two-dimensional Multik array ([D2Array]).
 *
 * The columns are first combined into a DataFrame; each of its rows then becomes
 * one inner list passed to `mk.ndarray`.
 */
@JvmName("convertNumberColumnsToMultik")
inline fun <reified N> List<DataColumn<N>>.convertToMultik(): D2Array<N> where N : Number, N : Comparable<N> {
    // Row values are exposed untyped, so they are cast back to the columns' element type.
    val rows = toDataFrame().map { row -> row.values() as List<N> }
    return mk.ndarray(rows)
}
| 131 | + |
/**
 * Converts a list of [Complex] [DataColumn]s into a two-dimensional Multik array ([D2Array]).
 *
 * The columns are first combined into a DataFrame; each of its rows then becomes
 * one inner list passed to `mk.ndarray`.
 */
@JvmName("convertComplexColumnsToMultik")
inline fun <reified N : Complex> List<DataColumn<N>>.convertToMultik(): D2Array<N> {
    // Row values are exposed untyped, so they are cast back to the columns' element type.
    val rows = toDataFrame().map { row -> row.values() as List<N> }
    return mk.ndarray(rows)
}
| 137 | + |
| 138 | +// endregion |
0 commit comments