61 changes: 46 additions & 15 deletions native/core/src/execution/planner.rs
@@ -238,62 +238,61 @@ impl PhysicalPlanner {
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
match spark_expr.expr_struct.as_ref().unwrap() {
ExprStruct::Add(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
+                // TODO respect ANSI eval mode
+                // https://github.com/apache/datafusion-comet/issues/536
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.return_type.as_ref(),
DataFusionOperator::Plus,
input_schema,
+                    eval_mode,
)
}
ExprStruct::Subtract(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
+                // TODO respect ANSI eval mode
+                // https://github.com/apache/datafusion-comet/issues/535
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.return_type.as_ref(),
DataFusionOperator::Minus,
input_schema,
+                    eval_mode,
)
}
ExprStruct::Multiply(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
+                // TODO respect ANSI eval mode
+                // https://github.com/apache/datafusion-comet/issues/534
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.return_type.as_ref(),
DataFusionOperator::Multiply,
input_schema,
+                    eval_mode,
)
}
ExprStruct::Divide(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
+                // TODO respect ANSI eval mode
+                // https://github.com/apache/datafusion-comet/issues/533
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.return_type.as_ref(),
DataFusionOperator::Divide,
input_schema,
+                    eval_mode,
)
}
ExprStruct::IntegralDivide(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/533
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
self.create_binary_expr_with_options(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
@@ -303,6 +302,7 @@
BinaryExprOptions {
is_integral_div: true,
},
+                    eval_mode,
)
}
ExprStruct::Remainder(expr) => {
@@ -1004,6 +1004,7 @@ impl PhysicalPlanner {
return_type: Option<&spark_expression::DataType>,
op: DataFusionOperator,
input_schema: SchemaRef,
+        eval_mode: EvalMode,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
self.create_binary_expr_with_options(
left,
@@ -1012,9 +1013,11 @@
op,
input_schema,
BinaryExprOptions::default(),
+            eval_mode,
)
}

+    #[allow(clippy::too_many_arguments)]
fn create_binary_expr_with_options(
&self,
left: &Expr,
@@ -1023,6 +1026,7 @@
op: DataFusionOperator,
input_schema: SchemaRef,
options: BinaryExprOptions,
+        eval_mode: EvalMode,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let left = self.create_expr(left, Arc::clone(&input_schema))?;
let right = self.create_expr(right, Arc::clone(&input_schema))?;
@@ -1087,7 +1091,34 @@
Arc::new(Field::new(func_name, data_type, true)),
)))
}
-            _ => Ok(Arc::new(BinaryExpr::new(left, op, right))),
+            _ => {
+                let data_type = return_type.map(to_arrow_datatype).unwrap();
+                if eval_mode == EvalMode::Try && data_type.is_integer() {
+                    let op_str = match op {
+                        DataFusionOperator::Plus => "checked_add",
+                        DataFusionOperator::Minus => "checked_sub",
+                        DataFusionOperator::Multiply => "checked_mul",
+                        DataFusionOperator::Divide => "checked_div",
+                        _ => {
+                            todo!("Operator not yet implemented for TRY eval mode");
+                        }
+                    };
+                    let fun_expr = create_comet_physical_fun(
+                        op_str,
+                        data_type.clone(),
+                        &self.session_ctx.state(),
+                        None,
+                    )?;
+                    Ok(Arc::new(ScalarFunctionExpr::new(
+                        op_str,
+                        fun_expr,
+                        vec![left, right],
+                        Arc::new(Field::new(op_str, data_type, true)),
+                    )))
+                } else {
+                    Ok(Arc::new(BinaryExpr::new(left, op, right)))
+                }
+            }
}
}
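For reference, the `checked_*` kernels this dispatch routes to are built on arrow-rs's `ArrowNativeTypeOp` trait, whose `*_checked` operations return an `Err` on overflow or division by zero instead of wrapping. A minimal standalone sketch of that behavior (assuming only the `arrow` crate as a dependency):

```rust
use arrow::array::ArrowNativeTypeOp;

fn main() {
    // Well-defined results behave like the plain operators.
    assert_eq!(2i32.add_checked(3).unwrap(), 5);
    // Overflow is reported as an Err instead of wrapping; the TRY-mode
    // kernels below turn this Err into a NULL via `.ok()`.
    assert!(i32::MAX.add_checked(1).is_err());
    // Division by zero is likewise surfaced as an Err rather than a panic.
    assert!(1i64.div_checked(0).is_err());
}
```

This is what makes Spark's TRY semantics (`try_add` returns NULL on overflow) expressible as an ordinary scalar function over the operands.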

13 changes: 13 additions & 0 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -16,6 +16,7 @@
// under the License.

use crate::hash_funcs::*;
+use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
@@ -115,6 +116,18 @@ pub fn create_comet_physical_fun(
data_type
)
}
"checked_add" => {
make_comet_scalar_udf!("checked_add", checked_add, data_type)
}
"checked_sub" => {
make_comet_scalar_udf!("checked_sub", checked_sub, data_type)
}
"checked_mul" => {
make_comet_scalar_udf!("checked_mul", checked_mul, data_type)
}
"checked_div" => {
make_comet_scalar_udf!("checked_div", checked_div, data_type)
}
"murmur3_hash" => {
let func = Arc::new(spark_murmur3_hash);
make_comet_scalar_udf!("murmur3_hash", func, without data_type)
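A test-style sketch of how the new entries behave when called directly (hypothetical test code, not part of the PR; it would have to live inside the spark-expr crate, since the module is declared `pub(crate)` below): overflowing rows and NULL inputs both come back as NULL.

```rust
use std::sync::Arc;

use arrow::array::{Array, Int32Array};
use arrow::datatypes::DataType;
use datafusion::common::DataFusionError;
use datafusion::physical_plan::ColumnarValue;

use crate::math_funcs::checked_arithmetic::checked_add;

fn checked_add_demo() -> Result<(), DataFusionError> {
    let left = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(i32::MAX), // overflows when 1 is added
        Some(1),        // well-defined
        None,           // NULL input
    ])));
    let right = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(1),
        Some(2),
        Some(3),
    ])));
    // Expected output column: [NULL, 3, NULL]
    let result = checked_add(&[left, right], &DataType::Int32)?;
    if let ColumnarValue::Array(arr) = result {
        assert!(arr.is_null(0));
        assert!(arr.is_null(2));
    }
    Ok(())
}
```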
150 changes: 150 additions & 0 deletions native/spark-expr/src/math_funcs/checked_arithmetic.rs
@@ -0,0 +1,150 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, ArrowNativeTypeOp, PrimitiveArray, PrimitiveBuilder};
use arrow::array::{ArrayRef, AsArray};

use arrow::datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type};
use datafusion::common::DataFusionError;
use datafusion::physical_plan::ColumnarValue;
use std::sync::Arc;

pub fn try_arithmetic_kernel<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
op: &str,
) -> Result<ArrayRef, DataFusionError>
where
T: ArrowPrimitiveType,
{
let len = left.len();
let mut builder = PrimitiveBuilder::<T>::with_capacity(len);
match op {
"checked_add" => {
for i in 0..len {
if left.is_null(i) || right.is_null(i) {
builder.append_null();
} else {
builder.append_option(left.value(i).add_checked(right.value(i)).ok());
}
}
}
"checked_sub" => {
for i in 0..len {
if left.is_null(i) || right.is_null(i) {
builder.append_null();
} else {
builder.append_option(left.value(i).sub_checked(right.value(i)).ok());
}
}
}
"checked_mul" => {
for i in 0..len {
if left.is_null(i) || right.is_null(i) {
builder.append_null();
} else {
builder.append_option(left.value(i).mul_checked(right.value(i)).ok());
}
}
}
"checked_div" => {
for i in 0..len {
if left.is_null(i) || right.is_null(i) {
builder.append_null();
} else {
builder.append_option(left.value(i).div_checked(right.value(i)).ok());
}
}
}
_ => {
return Err(DataFusionError::Internal(format!(
"Unsupported operation: {:?}",
op
)))
}
}

Ok(Arc::new(builder.finish()) as ArrayRef)
}

pub fn checked_add(
args: &[ColumnarValue],
data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
checked_arithmetic_internal(args, data_type, "checked_add")
}

pub fn checked_sub(
args: &[ColumnarValue],
data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
checked_arithmetic_internal(args, data_type, "checked_sub")
}

pub fn checked_mul(
args: &[ColumnarValue],
data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
checked_arithmetic_internal(args, data_type, "checked_mul")
}

pub fn checked_div(
args: &[ColumnarValue],
data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
checked_arithmetic_internal(args, data_type, "checked_div")
}

fn checked_arithmetic_internal(
args: &[ColumnarValue],
data_type: &DataType,
op: &str,
) -> Result<ColumnarValue, DataFusionError> {
let left = &args[0];
let right = &args[1];

let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (left, right) {
(ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
(ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
(l.to_array_of_size(r.len())?, Arc::clone(r))
}
(ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
(Arc::clone(l), r.to_array_of_size(l.len())?)
> Review comment from andygrove (Member): We may eventually want to have a specialized version of the kernel for the scalar case to avoid the overhead of creating an array from the scalar. This does not need to happen as part of this PR, though. (See the sketch after this file.)
>
> Reply from the PR author: Sure! I will create a follow-up enhancement to track changes for a scalar impl. Thank you for the feedback @andygrove.
}
(ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
};

// Checked arithmetic is currently only wired up for Int32 and Int64;
// any other type falls through to the error below.
let result_array = match data_type {
DataType::Int32 => try_arithmetic_kernel::<Int32Type>(
left_arr.as_primitive::<Int32Type>(),
right_arr.as_primitive::<Int32Type>(),
op,
),
DataType::Int64 => try_arithmetic_kernel::<Int64Type>(
left_arr.as_primitive::<Int64Type>(),
right_arr.as_primitive::<Int64Type>(),
op,
),
_ => Err(DataFusionError::Internal(format!(
"Unsupported data type: {:?}",
data_type
))),
};

Ok(ColumnarValue::Array(result_array?))
}
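As the review thread above suggests, a follow-up could specialize the scalar case instead of expanding the scalar with `to_array_of_size`. An illustrative sketch (the function name and dispatch are hypothetical, not part of this PR) using `PrimitiveArray::unary_opt`, which yields NULL wherever the closure returns `None`:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, ArrowNativeTypeOp, PrimitiveArray};
use arrow::datatypes::ArrowPrimitiveType;

// Hypothetical scalar-specialized kernel: applies checked addition against a
// scalar right-hand side without materializing it as an array first.
fn checked_add_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> ArrayRef
where
    T: ArrowPrimitiveType,
    T::Native: ArrowNativeTypeOp,
{
    // `unary_opt` preserves existing NULLs and adds a NULL wherever the
    // closure returns None (i.e. on overflow).
    let result: PrimitiveArray<T> = left.unary_opt(|v| v.add_checked(right).ok());
    Arc::new(result)
}
```

`checked_arithmetic_internal` could then match the `(Array, Scalar)` and `(Scalar, Array)` cases and route them here instead of expanding the scalar operand.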
1 change: 1 addition & 0 deletions native/spark-expr/src/math_funcs/mod.rs
@@ -16,6 +16,7 @@
// under the License.

mod ceil;
+pub(crate) mod checked_arithmetic;
mod div;
mod floor;
pub(crate) mod hex;
21 changes: 0 additions & 21 deletions spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
@@ -94,10 +94,6 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase {
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
return None
}
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }
createMathExpression(
expr,
expr.left,
@@ -119,10 +115,6 @@ object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
return None
}
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }
createMathExpression(
expr,
expr.left,
@@ -144,10 +136,6 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
return None
}
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }
createMathExpression(
expr,
expr.left,
@@ -169,15 +157,10 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
// See https://github.com/apache/arrow-datafusion/pull/6792
// For now, use NullIf to swap zeros with nulls.
val rightExpr = nullIfWhenPrimitive(expr.right)

if (!supportedDataType(expr.left.dataType)) {
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
return None
}
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }
createMathExpression(
expr,
expr.left,
@@ -199,10 +182,6 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
return None
}
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }

// Precision is set to 19 (max precision for a numerical data type except DecimalType)
