diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md
index 5eea5c4e5d..529ea423d0 100644
--- a/docs/source/user-guide/latest/configs.md
+++ b/docs/source/user-guide/latest/configs.md
@@ -232,6 +232,7 @@ These settings can be used to determine which parts of the plan are accelerated
 | `spark.comet.expression.Cosh.enabled` | Enable Comet acceleration for `Cosh` | true |
 | `spark.comet.expression.Cot.enabled` | Enable Comet acceleration for `Cot` | true |
 | `spark.comet.expression.CreateArray.enabled` | Enable Comet acceleration for `CreateArray` | true |
+| `spark.comet.expression.CreateMap.enabled` | Enable Comet acceleration for `CreateMap` | true |
 | `spark.comet.expression.CreateNamedStruct.enabled` | Enable Comet acceleration for `CreateNamedStruct` | true |
 | `spark.comet.expression.DateAdd.enabled` | Enable Comet acceleration for `DateAdd` | true |
 | `spark.comet.expression.DateFormatClass.enabled` | Enable Comet acceleration for `DateFormatClass` | true |
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 75c53198b8..8704d49f00 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -46,6 +46,7 @@ use datafusion_spark::function::datetime::date_add::SparkDateAdd;
 use datafusion_spark::function::datetime::date_sub::SparkDateSub;
 use datafusion_spark::function::hash::sha1::SparkSha1;
 use datafusion_spark::function::hash::sha2::SparkSha2;
+use datafusion_spark::function::map::map_from_arrays::MapFromArrays;
 use datafusion_spark::function::math::expm1::SparkExpm1;
 use datafusion_spark::function::math::hex::SparkHex;
 use datafusion_spark::function::string::char::CharFunc;
@@ -349,6 +350,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromArrays::default()));
 }
 
 /// Prepares arrow arrays for output.
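The two hunks above register the pieces the rest of the patch relies on: the `spark.comet.expression.CreateMap.enabled` flag (on by default) and the native `map_from_arrays` UDF in the DataFusion session. As a quick way to exercise the flag, a minimal sketch (the session setup is illustrative; only the config key comes from the table above):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative session; in a real deployment Comet is enabled via its plugin jars.
val spark = SparkSession.builder().appName("create-map-demo").getOrCreate()

// The flag is true by default; setting it to false forces map(...) to fall
// back to Spark while the rest of the plan can still run in Comet.
spark.conf.set("spark.comet.expression.CreateMap.enabled", "false")
spark.sql("SELECT map('a', 1, 'b', 2) AS m").show()
```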
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 3569559df8..221a090dd9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -126,7 +126,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[MapKeys] -> CometMapKeys,
     classOf[MapEntries] -> CometMapEntries,
     classOf[MapValues] -> CometMapValues,
-    classOf[MapFromArrays] -> CometMapFromArrays)
+    classOf[MapFromArrays] -> CometMapFromArrays,
+    classOf[CreateMap] -> CometCreateMap)
 
   private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[CreateNamedStruct] -> CometCreateNamedStruct,
diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala
index 2e217f6af0..e7ec1f09a8 100644
--- a/spark/src/main/scala/org/apache/comet/serde/maps.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala
@@ -20,7 +20,8 @@ package org.apache.comet.serde
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ArrayType, MapType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.types.DataTypes.BinaryType
 
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
@@ -89,3 +90,49 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] {
     optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*)
   }
 }
+
+object CometCreateMap extends CometExpressionSerde[CreateMap] with MapBase {
+  val keyUnsupportedReason = "Using BinaryType as map keys is not supported in create_map"
+  val valueUnsupportedReason =
+    "Using BinaryType as map values is not supported in create_map"
+
+  override def getSupportLevel(expr: CreateMap): SupportLevel = {
+    if (containsBinary(expr.dataType.keyType)) {
+      return Incompatible(Some(keyUnsupportedReason))
+    }
+    if (containsBinary(expr.dataType.valueType)) {
+      return Incompatible(Some(valueUnsupportedReason))
+    }
+    Compatible(None)
+  }
+
+  override def convert(
+      expr: CreateMap,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val keysArray = CreateArray(expr.keys)
+    val valuesArray = CreateArray(expr.values)
+    val keysExprProto = exprToProtoInternal(keysArray, inputs, binding)
+    val valuesExprProto = exprToProtoInternal(valuesArray, inputs, binding)
+    val createMapExprProto =
+      scalarFunctionExprToProtoWithReturnType(
+        "map_from_arrays",
+        expr.dataType,
+        false,
+        keysExprProto,
+        valuesExprProto)
+    optExprWithInfo(createMapExprProto, expr, expr.children: _*)
+  }
+}
+
+sealed trait MapBase {
+
+  protected def containsBinary(dataType: DataType): Boolean = {
+    dataType match {
+      case BinaryType => true
+      case StructType(fields) => fields.exists(field => containsBinary(field.dataType))
+      case ArrayType(elementType, _) => containsBinary(elementType)
+      case _ => false
+    }
+  }
+}
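The serde change is the heart of the patch: `CometCreateMap.convert` never serializes `CreateMap` directly; it wraps the keys and values in `CreateArray` and reuses the native `map_from_arrays` function registered above. The rewrite corresponds to the Spark-level identity below (a sketch, assuming an active `spark` session):

```scala
// map(k1, v1, k2, v2, ...) is equivalent to
// map_from_arrays(array(k1, k2, ...), array(v1, v2, ...)),
// which is exactly the shape CometCreateMap hands to the native engine.
val df = spark.sql("""
  SELECT
    map('a', 1, 'b', 2) AS via_create_map,
    map_from_arrays(array('a', 'b'), array(1, 2)) AS via_map_from_arrays
""")
df.show(truncate = false) // both columns print {a -> 1, b -> 2}
```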
diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
index 9276a20348..675dde2f6e 100644
--- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
@@ -25,7 +25,9 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.BinaryType
 
+import org.apache.comet.serde.CometCreateMap
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
 
 class CometMapExpressionSuite extends CometTestBase {
@@ -157,4 +159,45 @@ class CometMapExpressionSuite extends CometTestBase {
     }
   }
 
+  test("create_map") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val schemaGenOptions =
+          SchemaGenOptions(
+            generateArray = true,
+            generateStruct = true,
+            primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes.filterNot(_ == BinaryType))
+        val dataGenOptions = DataGenOptions(allowNull = false, generateNegativeZero = false)
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          schemaGenOptions,
+          dataGenOptions)
+      }
+      val df = spark.read.parquet(filename)
+      df.createOrReplaceTempView("t1")
+      for (fieldName <- df.schema.fieldNames) {
+        checkSparkAnswerAndOperator(spark.sql(s"SELECT map($fieldName, $fieldName) FROM t1"))
+      }
+    }
+  }
+
+  test("create_map - fallback for binary type") {
+    val table = "t2"
+    withTable(table) {
+      sql(
+        s"create table $table using parquet as select cast('abc' as binary) as c1 from range(10)")
+      checkSparkAnswerAndFallbackReason(
+        sql(s"select map(c1, 1) from $table"),
+        CometCreateMap.keyUnsupportedReason)
+      checkSparkAnswerAndFallbackReason(
+        sql(s"select map(1, c1) from $table"),
+        CometCreateMap.valueUnsupportedReason)
+    }
+  }
 }
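Note that `containsBinary` recurses through struct and array types, so the fallback is broader than the two flat cases the new test covers. A sketch (assuming an active `spark` session) of a nested case that would also fall back:

```scala
// The map key type here is struct<b: binary>; containsBinary finds the
// BinaryType field, so CometCreateMap reports keyUnsupportedReason and the
// expression is evaluated by Spark instead of Comet.
spark
  .sql("SELECT cast('abc' AS binary) AS b FROM range(10)")
  .createOrReplaceTempView("bin")
spark.sql("SELECT map(named_struct('b', b), 1) FROM bin").show()
```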
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometMapExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometMapExpressionBenchmark.scala
new file mode 100644
index 0000000000..e0a16d3295
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometMapExpressionBenchmark.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.benchmark
+
+import org.apache.comet.CometConf
+
+/**
+ * Configuration for a map expression benchmark.
+ *
+ * @param name
+ *   Name for the benchmark
+ * @param query
+ *   SQL query to benchmark
+ * @param extraCometConfigs
+ *   Additional Comet configurations for the scan+exec case
+ */
+case class MapExprConfig(
+    name: String,
+    query: String,
+    extraCometConfigs: Map[String, String] = Map.empty)
+
+/**
+ * Benchmark to measure performance of Comet map expressions. To run this benchmark:
+ * {{{
+ *   SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometMapExpressionBenchmark
+ * }}}
+ * Results will be written to "spark/benchmarks/CometMapExpressionBenchmark-**results.txt".
+ */
+object CometMapExpressionBenchmark extends CometBenchmarkBase {
+
+  private val mapExpressions = List(
+    MapExprConfig(
+      "create_map",
+      "select map(c1, c1, c2, c2, c3, c3, c4, c4, c5, c5) from parquetV1Table"))
+
+  override def runCometBenchmark(args: Array[String]): Unit = {
+    runBenchmarkWithTable("Map expressions", 1024) { v =>
+      withTempPath { dir =>
+        withTempTable("parquetV1Table") {
+          prepareTable(
+            dir,
+            spark.sql(
+              "SELECT " +
+                "(value + 0) AS C1, " +
+                "(value + 10) AS C2, " +
+                "(value + 20) AS C3, " +
+                "(value + 30) AS C4, " +
+                s"(value + 40) AS C5 FROM $tbl"))
+
+          val extraConfigs = Map(CometConf.COMET_CASE_CONVERSION_ENABLED.key -> "true")
+
+          mapExpressions.foreach { config =>
+            val allConfigs = extraConfigs ++ config.extraCometConfigs
+            runBenchmark(config.name) {
+              runExpressionBenchmark(config.name, v, config.query, allConfigs)
+            }
+          }
+        }
+      }
+    }
+  }
+}
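Since `MapExprConfig` takes arbitrary SQL, the benchmark list is easy to extend. A hypothetical follow-up entry (not part of this patch; the query text is an assumption) that would isolate the cost of the extra `CreateArray` wrapping by benchmarking the direct `map_from_arrays` path alongside the rewritten `create_map` path:

```scala
// Hypothetical extension of mapExpressions for comparison runs.
private val mapExpressions = List(
  MapExprConfig(
    "create_map",
    "select map(c1, c1, c2, c2, c3, c3, c4, c4, c5, c5) from parquetV1Table"),
  MapExprConfig(
    "map_from_arrays",
    "select map_from_arrays(array(c1, c2), array(c3, c4)) from parquetV1Table"))
```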