diff --git a/native/spark-expr/benches/decimal_div.rs b/native/spark-expr/benches/decimal_div.rs
index 4262e81238..3ca3e42eb5 100644
--- a/native/spark-expr/benches/decimal_div.rs
+++ b/native/spark-expr/benches/decimal_div.rs
@@ -20,7 +20,7 @@
 use arrow::compute::cast;
 use arrow::datatypes::DataType;
 use criterion::{criterion_group, criterion_main, Criterion};
 use datafusion::physical_plan::ColumnarValue;
-use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div};
+use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div, EvalMode};
 use std::hint::black_box;
 use std::sync::Arc;
@@ -48,6 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             black_box(spark_decimal_div(
                 black_box(&args),
                 black_box(&DataType::Decimal128(10, 4)),
+                EvalMode::Legacy,
             ))
         })
     });
@@ -57,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             black_box(spark_decimal_integral_div(
                 black_box(&args),
                 black_box(&DataType::Decimal128(10, 4)),
+                EvalMode::Legacy,
             ))
         })
     });
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index f96ddffce9..a440501818 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -133,13 +133,14 @@ pub fn create_comet_physical_fun_with_eval_mode(
             make_comet_scalar_udf!("unhex", func, without data_type)
         }
         "decimal_div" => {
-            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
+            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type, eval_mode)
         }
         "decimal_integral_div" => {
             make_comet_scalar_udf!(
                 "decimal_integral_div",
                 spark_decimal_integral_div,
-                data_type
+                data_type,
+                eval_mode
             )
         }
         "checked_add" => {
diff --git a/native/spark-expr/src/math_funcs/div.rs b/native/spark-expr/src/math_funcs/div.rs
index 9fc6692c03..933b28c094 100644
--- a/native/spark-expr/src/math_funcs/div.rs
+++ b/native/spark-expr/src/math_funcs/div.rs
@@ -16,29 +16,33 @@
 // under the License.
 
 use crate::math_funcs::utils::get_precision_scale;
+use crate::{divide_by_zero_error, EvalMode};
 use arrow::array::{Array, Decimal128Array};
 use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};
+use arrow::error::ArrowError;
 use arrow::{
     array::{ArrayRef, AsArray},
     datatypes::Decimal128Type,
 };
 use datafusion::common::DataFusionError;
 use datafusion::physical_plan::ColumnarValue;
-use num::{BigInt, Signed, ToPrimitive};
+use num::{BigInt, Signed, ToPrimitive, Zero};
 use std::sync::Arc;
 
 pub fn spark_decimal_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, false)
+    spark_decimal_div_internal(args, data_type, false, eval_mode)
 }
 
 pub fn spark_decimal_integral_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, true)
+    spark_decimal_div_internal(args, data_type, true, eval_mode)
 }
 
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
@@ -50,6 +54,7 @@ fn spark_decimal_div_internal(
     args: &[ColumnarValue],
     data_type: &DataType,
     is_integral_div: bool,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
     let left = &args[0];
     let right = &args[1];
@@ -80,9 +85,12 @@ fn spark_decimal_div_internal(
         let r_mul = ten.pow(r_exp);
         let five = BigInt::from(5);
         let zero = BigInt::from(0);
-        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+        arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
             let l = BigInt::from(l) * &l_mul;
             let r = BigInt::from(r) * &r_mul;
+            if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+                return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+            }
             let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
             let res = if is_integral_div {
                 div
@@ -91,14 +99,17 @@
             } else {
                 div + &five
             } / &ten;
-            res.to_i128().unwrap_or(i128::MAX)
+            Ok(res.to_i128().unwrap_or(i128::MAX))
         })?
     } else {
         let l_mul = 10_i128.pow(l_exp);
         let r_mul = 10_i128.pow(r_exp);
-        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+        arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
             let l = l * l_mul;
             let r = r * r_mul;
+            if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+                return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+            }
             let div = if r == 0 { 0 } else { l / r };
             let res = if is_integral_div {
                 div
@@ -107,7 +118,7 @@
             } else {
                 div + 5
             } / 10;
-            res.to_i128().unwrap_or(i128::MAX)
+            Ok(res.to_i128().unwrap_or(i128::MAX))
         })?
     };
     let result = result.with_data_type(DataType::Decimal128(p3, s3));
diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
index 4507dc1073..e34d9d5bd9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
@@ -180,14 +180,6 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
 
 object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
 
-  override def getSupportLevel(expr: IntegralDivide): SupportLevel = {
-    if (expr.evalMode == EvalMode.ANSI) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       expr: IntegralDivide,
       inputs: Seq[Attribute],
@@ -206,7 +198,7 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with Mat
       if (expr.right.dataType.isInstanceOf[DecimalType]) expr.right
       else Cast(expr.right, DecimalType(19, 0))
 
-    val rightExpr = nullIfWhenPrimitive(right)
+    val rightExpr = if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(right) else right
 
     val dataType = (left.dataType, right.dataType) match {
       case (l: DecimalType, r: DecimalType) =>
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index daf0e45cc8..e711d06d4d 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -58,7 +58,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   val ARITHMETIC_OVERFLOW_EXCEPTION_MSG =
     """org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer overflow.
If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error"""
   val DIVIDE_BY_ZERO_EXCEPTION_MSG =
-    """org.apache.comet.CometNativeException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
+    """Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
 
   test("compare true/false to negative zero") {
     Seq(false, true).foreach { dictionary =>
@@ -2929,7 +2929,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("ANSI support for divide (division by zero)") {
-    // TODO : Support ANSI mode in Integral divide -
     val data = Seq((Integer.MIN_VALUE, 0))
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
       withParquetTable(data, "tbl") {
@@ -2950,7 +2949,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("ANSI support for divide (division by zero) float division") {
-    // TODO : Support ANSI mode in Integral divide -
     val data = Seq((Float.MinPositiveValue, 0.0))
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
       withParquetTable(data, "tbl") {
@@ -2970,6 +2968,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for integral divide (division by zero)") {
+    val data = Seq((Integer.MIN_VALUE, 0))
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      withParquetTable(data, "tbl") {
+        val res = spark.sql("""
+            |SELECT
+            |  _1 div _2
+            | from tbl
+            | """.stripMargin)
+
+        checkSparkMaybeThrows(res) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+            assert(sparkExc.getMessage.contains("Division by zero"))
+          case _ => fail("Exception should be thrown")
+        }
+      }
+    }
+  }
+
   test("test integral divide overflow for decimal") {
     if (isSpark40Plus) {
       Seq(true, false)
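
Note for reviewers (not part of the patch): the core mechanic above is the switch from arrow's infallible `arrow::compute::kernels::arity::binary` kernel to `try_binary`, whose closure returns a `Result` and can therefore abort the whole computation when a zero divisor is seen in ANSI mode. The sketch below shows that mechanic in isolation against the public arrow-rs API; it is not Comet code. It uses plain `Int32Array`s, a boolean `ansi` flag standing in for `EvalMode::Ansi`, and an ad-hoc error string rather than Comet's `divide_by_zero_error()`.

    use arrow::array::Int32Array;
    use arrow::compute::kernels::arity::try_binary;
    use arrow::datatypes::Int32Type;
    use arrow::error::ArrowError;

    fn main() {
        let l = Int32Array::from(vec![10, 7, 4]);
        let r = Int32Array::from(vec![2, 0, 1]);
        let ansi = true; // stands in for eval_mode == EvalMode::Ansi

        // Because the closure returns Result, one bad element fails the
        // whole kernel -- this is what lets ANSI mode surface an error
        // instead of emitting an in-band 0 like the old `binary` path.
        let result = try_binary::<_, _, _, Int32Type>(&l, &r, |l, r| {
            if ansi && r == 0 {
                return Err(ArrowError::ComputeError("Division by zero".to_string()));
            }
            // Legacy behaviour: map x/0 to 0 and let the caller turn it
            // into NULL (or keep it out entirely via a rewritten plan).
            Ok(if r == 0 { 0 } else { l / r })
        });

        assert!(result.is_err()); // the 7 / 0 element aborts the kernel
    }

This also motivates the serde change above: in ANSI mode `CometIntegralDivide` no longer wraps the divisor in `nullIfWhenPrimitive`, so zero divisors actually reach the native kernel and trigger the divide-by-zero error, while legacy mode keeps the NULL-producing rewrite.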