diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 52b8eb6a30..2f8ba01b32 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -43,6 +43,7 @@ use datafusion_comet_proto::spark_operator::Operator; use datafusion_spark::function::bitwise::bit_get::SparkBitGet; use datafusion_spark::function::datetime::date_add::SparkDateAdd; use datafusion_spark::function::datetime::date_sub::SparkDateSub; +use datafusion_spark::function::hash::sha1::SparkSha1; use datafusion_spark::function::hash::sha2::SparkSha2; use datafusion_spark::function::math::expm1::SparkExpm1; use datafusion_spark::function::string::char::CharFunc; @@ -307,6 +308,7 @@ fn prepare_datafusion_session_context( session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default())); // Must be the last one to override existing functions with the same name datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 4d1daacd61..8f4a77bad0 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -154,7 +154,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Md5] -> CometScalarFunction("md5"), classOf[Murmur3Hash] -> CometMurmur3Hash, classOf[Sha2] -> CometSha2, - classOf[XxHash64] -> CometXxHash64) + classOf[XxHash64] -> CometXxHash64, + classOf[Sha1] -> CometSha1) private val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[Ascii] -> CometScalarFunction("ascii"), diff --git a/spark/src/main/scala/org/apache/comet/serde/hash.scala b/spark/src/main/scala/org/apache/comet/serde/hash.scala index 5c45a25936..523095011f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/hash.scala +++ b/spark/src/main/scala/org/apache/comet/serde/hash.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, Sha2, XxHash64} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, Sha1, Sha2, XxHash64} import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, StringType} import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -85,6 +85,19 @@ object CometSha2 extends CometExpressionSerde[Sha2] { } } +object CometSha1 extends CometExpressionSerde[Sha1] { + override def convert( + expr: Sha1, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + if (!HashUtils.isSupportedType(expr)) { + return None + } + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + scalarFunctionExprToProtoWithReturnType("sha1", StringType, childExpr) + } +} + private object HashUtils { def isSupportedType(expr: Expression): Boolean = { for (child <- expr.children) { diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index f391d52f78..905473553f 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2023,7 +2023,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |md5(col), md5(cast(a as string)), md5(cast(b as string)), |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), + |sha1(col), sha1(cast(a as string)), sha1(cast(b as string)) |from test |""".stripMargin) } @@ -2135,7 +2136,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |md5(col), md5(cast(a as string)), --md5(cast(b as string)), |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), + |sha1(col), sha1(cast(a as string)) |from test |""".stripMargin) }