
Commit 065a822

rebase

1 parent 03c0626

5 files changed: +46 -22 lines
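
Taken together, the five diffs below wire ANSI-mode support into integral divide: the Scala serde no longer reports `IntegralDivide` as incompatible under ANSI and, in ANSI mode, stops rewriting a zero divisor to NULL; the native side threads an `EvalMode` through `spark_decimal_div` and `spark_decimal_integral_div` so the decimal kernels can raise a divide-by-zero error; the benchmark and the test suite are updated to match.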


native/spark-expr/benches/decimal_div.rs

Lines changed: 3 additions & 1 deletion

@@ -20,7 +20,7 @@ use arrow::compute::cast;
 use arrow::datatypes::DataType;
 use criterion::{criterion_group, criterion_main, Criterion};
 use datafusion::physical_plan::ColumnarValue;
-use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div};
+use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div, EvalMode};
 use std::hint::black_box;
 use std::sync::Arc;

@@ -48,6 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             black_box(spark_decimal_div(
                 black_box(&args),
                 black_box(&DataType::Decimal128(10, 4)),
+                EvalMode::Legacy,
             ))
         })
     });

@@ -57,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             black_box(spark_decimal_integral_div(
                 black_box(&args),
                 black_box(&DataType::Decimal128(10, 4)),
+                EvalMode::Legacy,
             ))
         })
     });
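
With the extra parameter, a direct call outside the Criterion harness has the same shape. A minimal sketch, mirroring the benchmark's Decimal128(10, 4) setup (the array values are illustrative):

use std::sync::Arc;

use arrow::array::Decimal128Array;
use arrow::datatypes::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_comet_spark_expr::{spark_decimal_div, EvalMode};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Two Decimal128(10, 4) columns, i.e. integer values scaled by 10^-4.
    let left = Decimal128Array::from(vec![12_345_i128, 67_890]).with_precision_and_scale(10, 4)?;
    let right = Decimal128Array::from(vec![20_000_i128, 50_000]).with_precision_and_scale(10, 4)?;
    let args = [
        ColumnarValue::Array(Arc::new(left)),
        ColumnarValue::Array(Arc::new(right)),
    ];

    // EvalMode::Legacy keeps the pre-ANSI behavior: no divide-by-zero errors.
    let result = spark_decimal_div(&args, &DataType::Decimal128(10, 4), EvalMode::Legacy)?;
    println!("{result:?}");
    Ok(())
}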

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 3 additions & 2 deletions

@@ -133,13 +133,14 @@ pub fn create_comet_physical_fun_with_eval_mode(
             make_comet_scalar_udf!("unhex", func, without data_type)
         }
         "decimal_div" => {
-            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
+            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type, eval_mode)
         }
         "decimal_integral_div" => {
             make_comet_scalar_udf!(
                 "decimal_integral_div",
                 spark_decimal_integral_div,
-                data_type
+                data_type,
+                eval_mode
             )
         }
         "checked_add" => {

native/spark-expr/src/math_funcs/div.rs

Lines changed: 18 additions & 7 deletions

@@ -16,29 +16,33 @@
 // under the License.

 use crate::math_funcs::utils::get_precision_scale;
+use crate::{divide_by_zero_error, EvalMode};
 use arrow::array::{Array, Decimal128Array};
 use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};
+use arrow::error::ArrowError;
 use arrow::{
     array::{ArrayRef, AsArray},
     datatypes::Decimal128Type,
 };
 use datafusion::common::DataFusionError;
 use datafusion::physical_plan::ColumnarValue;
-use num::{BigInt, Signed, ToPrimitive};
+use num::{BigInt, Signed, ToPrimitive, Zero};
 use std::sync::Arc;

 pub fn spark_decimal_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, false)
+    spark_decimal_div_internal(args, data_type, false, eval_mode)
 }

 pub fn spark_decimal_integral_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, true)
+    spark_decimal_div_internal(args, data_type, true, eval_mode)
 }

 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).

@@ -50,6 +54,7 @@ fn spark_decimal_div_internal(
     args: &[ColumnarValue],
     data_type: &DataType,
     is_integral_div: bool,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
     let left = &args[0];
     let right = &args[1];

@@ -80,9 +85,12 @@ fn spark_decimal_div_internal(
         let r_mul = ten.pow(r_exp);
         let five = BigInt::from(5);
         let zero = BigInt::from(0);
-        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+        arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
             let l = BigInt::from(l) * &l_mul;
             let r = BigInt::from(r) * &r_mul;
+            if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+                return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+            }
             let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
             let res = if is_integral_div {
                 div

@@ -91,14 +99,17 @@
             } else {
                 div + &five
             } / &ten;
-            res.to_i128().unwrap_or(i128::MAX)
+            Ok(res.to_i128().unwrap_or(i128::MAX))
         })?
     } else {
         let l_mul = 10_i128.pow(l_exp);
         let r_mul = 10_i128.pow(r_exp);
-        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+        arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
             let l = l * l_mul;
             let r = r * r_mul;
+            if eval_mode == EvalMode::Ansi && is_integral_div && r.is_zero() {
+                return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+            }
             let div = if r == 0 { 0 } else { l / r };
             let res = if is_integral_div {
                 div

@@ -107,7 +118,7 @@
             } else {
                 div + 5
             } / 10;
-            res.to_i128().unwrap_or(i128::MAX)
+            Ok(res.to_i128().unwrap_or(i128::MAX))
         })?
     };
     let result = result.with_data_type(DataType::Decimal128(p3, s3));
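
The load-bearing change in this file is the swap from the infallible binary kernel to try_binary: the per-element closure now returns Result, and the first Err fails the whole kernel, which is how a zero divisor under ANSI mode becomes a query-level error instead of a sentinel value. The same pattern in isolation, on plain Int32Arrays:

use arrow::array::Int32Array;
use arrow::compute::kernels::arity::try_binary;
use arrow::error::ArrowError;

fn main() {
    let dividends = Int32Array::from(vec![10, 20, 30]);
    let divisors = Int32Array::from(vec![2, 0, 5]);

    // try_binary short-circuits on the first Err, so the zero divisor in
    // row 1 fails the whole kernel instead of producing a sentinel value.
    let result: Result<Int32Array, ArrowError> = try_binary(&dividends, &divisors, |l, r| {
        if r == 0 {
            return Err(ArrowError::ComputeError("Division by zero".to_string()));
        }
        Ok(l / r)
    });

    assert!(result.is_err());
}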

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 1 addition & 9 deletions

@@ -180,14 +180,6 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {

 object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {

-  override def getSupportLevel(expr: IntegralDivide): SupportLevel = {
-    if (expr.evalMode == EvalMode.ANSI) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       expr: IntegralDivide,
       inputs: Seq[Attribute],

@@ -206,7 +198,7 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
       if (expr.right.dataType.isInstanceOf[DecimalType]) expr.right
       else Cast(expr.right, DecimalType(19, 0))

-    val rightExpr = nullIfWhenPrimitive(right)
+    val rightExpr = if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(right) else right

     val dataType = (left.dataType, right.dataType) match {
       case (l: DecimalType, r: DecimalType) =>
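
Two things change here: the getSupportLevel override that declared ANSI mode incompatible is removed, and the NullIf rewrite of the divisor is now skipped in ANSI mode. That branching matters: nullIfWhenPrimitive turns a zero divisor into NULL, which yields the legacy NULL-on-divide-by-zero result, so under ANSI the raw zero has to reach the native kernel for the error to fire. A sketch of the resulting native behavior (values and precision are illustrative):

use std::sync::Arc;

use arrow::array::Decimal128Array;
use arrow::datatypes::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_comet_spark_expr::{spark_decimal_integral_div, EvalMode};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let left = Decimal128Array::from(vec![42_i128]).with_precision_and_scale(19, 0)?;
    // The zero divisor is not rewritten to NULL in ANSI mode, so the kernel sees it.
    let right = Decimal128Array::from(vec![0_i128]).with_precision_and_scale(19, 0)?;
    let args = [
        ColumnarValue::Array(Arc::new(left)),
        ColumnarValue::Array(Arc::new(right)),
    ];

    let res = spark_decimal_integral_div(&args, &DataType::Decimal128(19, 0), EvalMode::Ansi);
    assert!(res.is_err()); // surfaces as a divide-by-zero error
    Ok(())
}

Under EvalMode::Legacy the same call succeeds, because the kernel maps a zero divisor to 0 and the NULL semantics are handled upstream by the NullIf rewrite.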

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 21 additions & 3 deletions

@@ -58,7 +58,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   val ARITHMETIC_OVERFLOW_EXCEPTION_MSG =
     """org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error"""
   val DIVIDE_BY_ZERO_EXCEPTION_MSG =
-    """org.apache.comet.CometNativeException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
+    """Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""

   test("compare true/false to negative zero") {
     Seq(false, true).foreach { dictionary =>

@@ -2929,7 +2929,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }

   test("ANSI support for divide (division by zero)") {
-    // TODO : Support ANSI mode in Integral divide -
     val data = Seq((Integer.MIN_VALUE, 0))
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
       withParquetTable(data, "tbl") {

@@ -2950,7 +2949,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }

   test("ANSI support for divide (division by zero) float division") {
-    // TODO : Support ANSI mode in Integral divide -
     val data = Seq((Float.MinPositiveValue, 0.0))
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
       withParquetTable(data, "tbl") {

@@ -2970,6 +2968,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

+  test("ANSI support for integral divide (division by zero)") {
+    val data = Seq((Integer.MIN_VALUE, 0))
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      withParquetTable(data, "tbl") {
+        val res = spark.sql("""
+            |SELECT
+            | _1 div _2
+            | from tbl
+            | """.stripMargin)
+
+        checkSparkMaybeThrows(res) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+            assert(sparkExc.getMessage.contains("Division by zero"))
+          case _ => fail("Exception should be thrown")
+        }
+      }
+    }
+  }
+
   test("test integral divide overflow for decimal") {
     if (isSpark40Plus) {
       Seq(true, false)
