4 changes: 3 additions & 1 deletion native/spark-expr/benches/decimal_div.rs
@@ -20,7 +20,7 @@ use arrow::compute::cast;
use arrow::datatypes::DataType;
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion::physical_plan::ColumnarValue;
-use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div};
+use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div, EvalMode};
use std::hint::black_box;
use std::sync::Arc;

@@ -48,6 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) {
black_box(spark_decimal_div(
black_box(&args),
black_box(&DataType::Decimal128(10, 4)),
+EvalMode::Legacy,
))
})
});
@@ -57,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) {
black_box(spark_decimal_integral_div(
black_box(&args),
black_box(&DataType::Decimal128(10, 4)),
+EvalMode::Legacy,
))
})
});
5 changes: 3 additions & 2 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -133,13 +133,14 @@ pub fn create_comet_physical_fun_with_eval_mode(
make_comet_scalar_udf!("unhex", func, without data_type)
}
"decimal_div" => {
-make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
+make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type, eval_mode)
}
"decimal_integral_div" => {
make_comet_scalar_udf!(
"decimal_integral_div",
spark_decimal_integral_div,
-data_type
+data_type,
+eval_mode
)
}
"checked_add" => {
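
The hunk above only threads eval_mode into the make_comet_scalar_udf! invocation. As a rough sketch of what that wiring presumably amounts to (an assumption about the macro, not its actual expansion; wrap_decimal_div is a hypothetical name), the registered scalar function now captures the evaluation mode alongside the resolved return type before delegating to the kernel in div.rs:

    use arrow::datatypes::DataType;
    use datafusion::common::DataFusionError;
    use datafusion::physical_plan::ColumnarValue;
    use datafusion_comet_spark_expr::{spark_decimal_div, EvalMode};

    // Hypothetical helper (illustration only): capture the resolved return type and the
    // evaluation mode once, then forward every invocation to the decimal division kernel.
    // Assumes EvalMode is Copy, which its by-value use elsewhere in this diff suggests.
    fn wrap_decimal_div(
        data_type: DataType,
        eval_mode: EvalMode,
    ) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
        move |args: &[ColumnarValue]| spark_decimal_div(args, &data_type, eval_mode)
    }

The benchmark changes in the first file are the caller-side view of the same new parameter: they pass EvalMode::Legacy to keep the pre-existing non-ANSI behaviour.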
25 changes: 18 additions & 7 deletions native/spark-expr/src/math_funcs/div.rs
@@ -16,29 +16,33 @@
// under the License.

use crate::math_funcs::utils::get_precision_scale;
+use crate::{divide_by_zero_error, EvalMode};
use arrow::array::{Array, Decimal128Array};
use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};
+use arrow::error::ArrowError;
use arrow::{
array::{ArrayRef, AsArray},
datatypes::Decimal128Type,
};
use datafusion::common::DataFusionError;
use datafusion::physical_plan::ColumnarValue;
-use num::{BigInt, Signed, ToPrimitive};
+use num::{BigInt, Signed, ToPrimitive, Zero};
use std::sync::Arc;

pub fn spark_decimal_div(
args: &[ColumnarValue],
data_type: &DataType,
+eval_mode: EvalMode,
) -> Result<ColumnarValue, DataFusionError> {
-spark_decimal_div_internal(args, data_type, false)
+spark_decimal_div_internal(args, data_type, false, eval_mode)
}

pub fn spark_decimal_integral_div(
args: &[ColumnarValue],
data_type: &DataType,
+eval_mode: EvalMode,
) -> Result<ColumnarValue, DataFusionError> {
-spark_decimal_div_internal(args, data_type, true)
+spark_decimal_div_internal(args, data_type, true, eval_mode)
}

// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
@@ -50,6 +54,7 @@ fn spark_decimal_div_internal(
args: &[ColumnarValue],
data_type: &DataType,
is_integral_div: bool,
+eval_mode: EvalMode,
) -> Result<ColumnarValue, DataFusionError> {
let left = &args[0];
let right = &args[1];
@@ -80,9 +85,12 @@
let r_mul = ten.pow(r_exp);
let five = BigInt::from(5);
let zero = BigInt::from(0);
-arrow::compute::kernels::arity::binary(left, right, |l, r| {
+arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
let l = BigInt::from(l) * &l_mul;
let r = BigInt::from(r) * &r_mul;
+if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+}
let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
let res = if is_integral_div {
div
@@ -91,14 +99,17 @@
} else {
div + &five
} / &ten;
-res.to_i128().unwrap_or(i128::MAX)
+Ok(res.to_i128().unwrap_or(i128::MAX))
})?
} else {
let l_mul = 10_i128.pow(l_exp);
let r_mul = 10_i128.pow(r_exp);
-arrow::compute::kernels::arity::binary(left, right, |l, r| {
+arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
let l = l * l_mul;
let r = r * r_mul;
+if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+}
let div = if r == 0 { 0 } else { l / r };
let res = if is_integral_div {
div
@@ -107,7 +118,7 @@
} else {
div + 5
} / 10;
-res.to_i128().unwrap_or(i128::MAX)
+Ok(res.to_i128().unwrap_or(i128::MAX))
})?
};
let result = result.with_data_type(DataType::Decimal128(p3, s3));
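
To make the kernel change above concrete: arity::binary takes an infallible closure, so the old code could only fold a zero divisor into a placeholder value, while arity::try_binary lets the per-element closure return an error that aborts the whole kernel. Below is a minimal, self-contained sketch of that pattern, not the Comet implementation: the function name, the Int64 element type, and the error text are made up, whereas the real kernel works on Decimal128 values and raises divide_by_zero_error() only for integral division in ANSI mode.

    use arrow::array::{Int64Array, PrimitiveArray};
    use arrow::compute::kernels::arity::try_binary;
    use arrow::datatypes::Int64Type;
    use arrow::error::ArrowError;

    // Hypothetical stand-in for the integral-divide kernel: in ANSI mode a zero divisor
    // fails the whole kernel; otherwise it yields a placeholder value.
    fn integral_div(
        left: &Int64Array,
        right: &Int64Array,
        ansi: bool,
    ) -> Result<PrimitiveArray<Int64Type>, ArrowError> {
        try_binary(left, right, |l, r| {
            if r == 0 {
                if ansi {
                    return Err(ArrowError::ComputeError("Division by zero".to_string()));
                }
                return Ok(0); // placeholder for the non-ANSI path
            }
            Ok(l / r)
        })
    }

    fn main() -> Result<(), ArrowError> {
        let l = Int64Array::from(vec![10, 7]);
        let r = Int64Array::from(vec![2, 0]);
        assert_eq!(integral_div(&l, &r, false)?.value(0), 5); // legacy: no error surfaces
        assert!(integral_div(&l, &r, true).is_err()); // ANSI: divide-by-zero propagates
        Ok(())
    }

In legacy mode the sketch, like the real kernel, still produces a value for a zero divisor; that placeholder is never observed in practice because, outside ANSI mode, the planner wraps the divisor in NullIf (see the arithmetic.scala hunk below), so the row comes back NULL.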
10 changes: 1 addition & 9 deletions spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
@@ -180,14 +180,6 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {

object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {

-override def getSupportLevel(expr: IntegralDivide): SupportLevel = {
-if (expr.evalMode == EvalMode.ANSI) {
-Incompatible(Some("ANSI mode is not supported"))
-} else {
-Compatible(None)
-}
-}
-
override def convert(
expr: IntegralDivide,
inputs: Seq[Attribute],
@@ -206,7 +198,7 @@
if (expr.right.dataType.isInstanceOf[DecimalType]) expr.right
else Cast(expr.right, DecimalType(19, 0))

-val rightExpr = nullIfWhenPrimitive(right)
+val rightExpr = if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(right) else right

val dataType = (left.dataType, right.dataType) match {
case (l: DecimalType, r: DecimalType) =>
24 changes: 21 additions & 3 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -58,7 +58,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val ARITHMETIC_OVERFLOW_EXCEPTION_MSG =
"""org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error"""
val DIVIDE_BY_ZERO_EXCEPTION_MSG =
-"""org.apache.comet.CometNativeException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
+"""Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""

test("compare true/false to negative zero") {
Seq(false, true).foreach { dictionary =>
@@ -2929,7 +2929,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("ANSI support for divide (division by zero)") {
-// TODO : Support ANSI mode in Integral divide -
val data = Seq((Integer.MIN_VALUE, 0))
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
withParquetTable(data, "tbl") {
@@ -2950,7 +2949,6 @@
}

test("ANSI support for divide (division by zero) float division") {
-// TODO : Support ANSI mode in Integral divide -
val data = Seq((Float.MinPositiveValue, 0.0))
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
withParquetTable(data, "tbl") {
@@ -2970,6 +2968,26 @@
}
}

+test("ANSI support for integral divide (division by zero)") {
+val data = Seq((Integer.MIN_VALUE, 0))
+withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+withParquetTable(data, "tbl") {
+val res = spark.sql("""
+|SELECT
+| _1 div _2
+| from tbl
+| """.stripMargin)
+
+checkSparkMaybeThrows(res) match {
+case (Some(sparkExc), Some(cometExc)) =>
+assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+assert(sparkExc.getMessage.contains("Division by zero"))
+case _ => fail("Exception should be thrown")
+}
+}
+}
+}
+
test("test integral divide overflow for decimal") {
if (isSpark40Plus) {
Seq(true, false)