diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 393f57662e..2b41a4bdcb 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -22,7 +22,7 @@ use crate::{
     spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
     spark_hex, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round,
     spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
-    SparkDateTrunc, SparkStringSpace,
+    SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -191,7 +191,6 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
     vec![
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
-        Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
     ]
 }
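
Note: with `SparkDateTrunc` gone from `all_scalar_functions()`, nothing in this registry answers to the name `date_trunc` any more. For orientation, a minimal sketch of the registration pattern this list feeds — assuming a plain DataFusion `SessionContext`, with the crate import path being an assumption for illustration:

```rust
// Sketch of the registry pattern behind all_scalar_functions(): each
// ScalarUDFImpl is wrapped in a ScalarUDF and registered under its name().
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::ScalarUDF;
use datafusion_comet_spark_expr::SparkStringSpace; // assumed import path

fn register_example(ctx: &SessionContext) {
    // After this call the UDF is resolvable by its name; "date_trunc"
    // now has to be provided by some other registration path.
    ctx.register_udf(ScalarUDF::new_from_impl(SparkStringSpace::default()));
}
```
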
diff --git a/native/spark-expr/src/datetime_funcs/date_trunc.rs b/native/spark-expr/src/datetime_funcs/date_trunc.rs
deleted file mode 100644
index 6d36b0975c..0000000000
--- a/native/spark-expr/src/datetime_funcs/date_trunc.rs
+++ /dev/null
@@ -1,89 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use arrow::datatypes::DataType;
-use datafusion::common::{utils::take_function_args, DataFusionError, Result, ScalarValue::Utf8};
-use datafusion::logical_expr::{
-    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
-};
-use std::any::Any;
-
-use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};
-
-#[derive(Debug, PartialEq, Eq, Hash)]
-pub struct SparkDateTrunc {
-    signature: Signature,
-    aliases: Vec<String>,
-}
-
-impl SparkDateTrunc {
-    pub fn new() -> Self {
-        Self {
-            signature: Signature::exact(
-                vec![DataType::Date32, DataType::Utf8],
-                Volatility::Immutable,
-            ),
-            aliases: vec![],
-        }
-    }
-}
-
-impl Default for SparkDateTrunc {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl ScalarUDFImpl for SparkDateTrunc {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn name(&self) -> &str {
-        "date_trunc"
-    }
-
-    fn signature(&self) -> &Signature {
-        &self.signature
-    }
-
-    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
-        Ok(DataType::Date32)
-    }
-
-    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
-        let [date, format] = take_function_args(self.name(), args.args)?;
-        match (date, format) {
-            (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
-                let result = date_trunc_dyn(&date, format)?;
-                Ok(ColumnarValue::Array(result))
-            }
-            (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => {
-                let result = date_trunc_array_fmt_dyn(&date, &formats)?;
-                Ok(ColumnarValue::Array(result))
-            }
-            _ => Err(DataFusionError::Execution(
-                "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \
-                    (PrimitiveArray, StringArray)".to_string(),
-            )),
-        }
-    }
-
-    fn aliases(&self) -> &[String] {
-        &self.aliases
-    }
-}
diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs
index ef8041e5fe..d53201a174 100644
--- a/native/spark-expr/src/datetime_funcs/mod.rs
+++ b/native/spark-expr/src/datetime_funcs/mod.rs
@@ -15,11 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-mod date_trunc;
 mod extract_date_part;
 mod timestamp_trunc;
 
-pub use date_trunc::SparkDateTrunc;
 pub use extract_date_part::SparkHour;
 pub use extract_date_part::SparkMinute;
 pub use extract_date_part::SparkSecond;
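
For reference, the deleted `invoke_with_args` reduced to a two-way dispatch on the format argument. A standalone sketch of that shape, with the kernel calls stubbed out via `todo!` since the kernels themselves are removed in the hunks below:

```rust
// Shape of the removed dispatch: one scalar format for the whole batch,
// or a per-row format array. take_function_args is the same helper the
// deleted code used to destructure the argument list.
use datafusion::common::{utils::take_function_args, DataFusionError, Result, ScalarValue};
use datafusion::logical_expr::ColumnarValue;

fn date_trunc_dispatch(args: Vec<ColumnarValue>) -> Result<ColumnarValue> {
    let [date, format] = take_function_args("date_trunc", args)?;
    match (date, format) {
        // TRUNC(col, 'MONTH'): a single format string for every row
        (ColumnarValue::Array(_dates), ColumnarValue::Scalar(ScalarValue::Utf8(Some(_fmt)))) => {
            todo!("was: date_trunc_dyn(&_dates, _fmt)")
        }
        // TRUNC(col, fmt_col): one format per row
        (ColumnarValue::Array(_dates), ColumnarValue::Array(_formats)) => {
            todo!("was: date_trunc_array_fmt_dyn(&_dates, &_formats)")
        }
        _ => Err(DataFusionError::Execution(
            "expected (array, scalar format) or (array, format array)".to_string(),
        )),
    }
}
```
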
diff --git a/native/spark-expr/src/kernels/temporal.rs b/native/spark-expr/src/kernels/temporal.rs
index 09e2c905c7..10f6abef45 100644
--- a/native/spark-expr/src/kernels/temporal.rs
+++ b/native/spark-expr/src/kernels/temporal.rs
@@ -17,7 +17,7 @@
 
 //! temporal kernels
 
-use chrono::{DateTime, Datelike, Duration, NaiveDateTime, Timelike, Utc};
+use chrono::{DateTime, Datelike, Duration, Timelike, Utc};
 
 use std::sync::Arc;
 
@@ -25,7 +25,7 @@ use arrow::array::{
     downcast_dictionary_array, downcast_temporal_array,
     temporal_conversions::*,
     timezone::Tz,
-    types::{ArrowDictionaryKeyType, ArrowTemporalType, Date32Type, TimestampMicrosecondType},
+    types::{ArrowDictionaryKeyType, ArrowTemporalType, TimestampMicrosecondType},
     ArrowNumericType,
 };
 use arrow::{
@@ -42,53 +42,6 @@ macro_rules! return_compute_error_with {
     };
 }
 
-// The number of days between the beginning of the proleptic gregorian calendar (0001-01-01)
-// and the beginning of the Unix Epoch (1970-01-01)
-const DAYS_TO_UNIX_EPOCH: i32 = 719_163;
-
-// Copied from arrow_arith/temporal.rs with modification to the output datatype
-// Transforms an array of NaiveDate to an array of Date32 after applying an operation
-fn as_datetime_with_op<A: ArrayAccessor<Item = T::Native>, T: ArrowTemporalType, F>(
-    iter: ArrayIter<A>,
-    mut builder: PrimitiveBuilder<Date32Type>,
-    op: F,
-) -> Date32Array
-where
-    F: Fn(NaiveDateTime) -> i32,
-    i64: From<T::Native>,
-{
-    iter.into_iter().for_each(|value| {
-        if let Some(value) = value {
-            match as_datetime::<T>(i64::from(value)) {
-                Some(dt) => builder.append_value(op(dt)),
-                None => builder.append_null(),
-            }
-        } else {
-            builder.append_null();
-        }
-    });
-
-    builder.finish()
-}
-
-#[inline]
-fn as_datetime_with_op_single<T: ArrowTemporalType, F>(
-    value: Option<T::Native>,
-    builder: &mut PrimitiveBuilder<Date32Type>,
-    op: F,
-) where
-    F: Fn(NaiveDateTime) -> i32,
-{
-    if let Some(value) = value {
-        match as_datetime::<T>(i64::from(value)) {
-            Some(dt) => builder.append_value(op(dt)),
-            None => builder.append_null(),
-        }
-    } else {
-        builder.append_null();
-    }
-}
-
 // Based on arrow_arith/temporal.rs:extract_component_from_datetime_array
 // Transforms an array of DateTime to an array of TimestampMicrosecond after applying an
 // operation
@@ -143,11 +96,6 @@ where
     Ok(())
 }
 
-#[inline]
-fn as_days_from_unix_epoch(dt: Option<NaiveDateTime>) -> i32 {
-    dt.unwrap().num_days_from_ce() - DAYS_TO_UNIX_EPOCH
-}
-
 // Apply the Tz to the NaiveDateTime, convert to UTC, and return as microseconds in Unix epoch
 #[inline]
 fn as_micros_from_unix_epoch_utc(dt: Option<DateTime<Tz>>) -> i64 {
@@ -244,250 +192,6 @@ fn trunc_date_to_microsec<T>(dt: T) -> Option<T>
 {
     Some(dt).and_then(|d| d.with_nanosecond(1_000 * (d.nanosecond() / 1_000)))
 }
 
-///
-/// Implements the spark [TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#trunc)
-/// function where the specified format is a scalar value
-///
-/// array is an array of Date32 values. The array may be a dictionary array.
-///
-/// format is a scalar string specifying the format to apply to the timestamp value.
-pub(crate) fn date_trunc_dyn(array: &dyn Array, format: String) -> Result<ArrayRef> {
-    match array.data_type().clone() {
-        DataType::Dictionary(_, _) => {
-            downcast_dictionary_array!(
-                array => {
-                    let truncated_values = date_trunc_dyn(array.values(), format)?;
-                    Ok(Arc::new(array.with_values(truncated_values)))
-                }
-                dt => return_compute_error_with!("date_trunc does not support", dt),
-            )
-        }
-        _ => {
-            downcast_temporal_array!(
-                array => {
-                    date_trunc(array, format)
-                        .map(|a| Arc::new(a) as ArrayRef)
-                }
-                dt => return_compute_error_with!("date_trunc does not support", dt),
-            )
-        }
-    }
-}
-
-pub(crate) fn date_trunc<T>(
-    array: &PrimitiveArray<T>,
-    format: String,
-) -> Result<Date32Array>
-where
-    T: ArrowTemporalType + ArrowNumericType,
-    i64: From<T::Native>,
-{
-    let builder = Date32Builder::with_capacity(array.len());
-    let iter = ArrayIter::new(array);
-    match array.data_type() {
-        DataType::Date32 => match format.to_uppercase().as_str() {
-            "YEAR" | "YYYY" | "YY" => Ok(as_datetime_with_op::<&PrimitiveArray<T>, T, _>(
-                iter,
-                builder,
-                |dt| as_days_from_unix_epoch(trunc_date_to_year(dt)),
-            )),
-            "QUARTER" => Ok(as_datetime_with_op::<&PrimitiveArray<T>, T, _>(
-                iter,
-                builder,
-                |dt| as_days_from_unix_epoch(trunc_date_to_quarter(dt)),
-            )),
-            "MONTH" | "MON" | "MM" => Ok(as_datetime_with_op::<&PrimitiveArray<T>, T, _>(
-                iter,
-                builder,
-                |dt| as_days_from_unix_epoch(trunc_date_to_month(dt)),
-            )),
-            "WEEK" => Ok(as_datetime_with_op::<&PrimitiveArray<T>, T, _>(
-                iter,
-                builder,
-                |dt| as_days_from_unix_epoch(trunc_date_to_week(dt)),
-            )),
-            _ => Err(SparkError::Internal(format!(
-                "Unsupported format: {format:?} for function 'date_trunc'"
-            ))),
-        },
-        dt => return_compute_error_with!(
-            "Unsupported input type '{:?}' for function 'date_trunc'",
-            dt
-        ),
-    }
-}
-
-///
-/// Implements the spark [TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#trunc)
-/// function where the specified format may be an array
-///
-/// array is an array of Date32 values. The array may be a dictionary array.
-///
-/// format is an array of strings specifying the format to apply to the corresponding date value.
-/// The array may be a dictionary array.
-pub(crate) fn date_trunc_array_fmt_dyn(
-    array: &dyn Array,
-    formats: &dyn Array,
-) -> Result<ArrayRef> {
-    match (array.data_type().clone(), formats.data_type().clone()) {
-        (DataType::Dictionary(_, v), DataType::Dictionary(_, f)) => {
-            if !matches!(*v, DataType::Date32) {
-                return_compute_error_with!("date_trunc does not support", v)
-            }
-            if !matches!(*f, DataType::Utf8) {
-                return_compute_error_with!("date_trunc does not support format type ", f)
-            }
-            downcast_dictionary_array!(
-                formats => {
-                    downcast_dictionary_array!(
-                        array => {
-                            date_trunc_array_fmt_dict_dict(
-                                &array.downcast_dict::<Date32Array>().unwrap(),
-                                &formats.downcast_dict::<StringArray>().unwrap())
-                            .map(|a| Arc::new(a) as ArrayRef)
-                        }
-                        dt => return_compute_error_with!("date_trunc does not support", dt)
-                    )
-                }
-                fmt => return_compute_error_with!("date_trunc does not support format type", fmt),
-            )
-        }
-        (DataType::Dictionary(_, v), DataType::Utf8) => {
-            if !matches!(*v, DataType::Date32) {
-                return_compute_error_with!("date_trunc does not support", v)
-            }
-            downcast_dictionary_array!(
-                array => {
-                    date_trunc_array_fmt_dict_plain(
-                        &array.downcast_dict::<Date32Array>().unwrap(),
-                        formats.as_any().downcast_ref::<StringArray>()
-                            .expect("Unexpected value type in formats"))
-                    .map(|a| Arc::new(a) as ArrayRef)
-                }
-                dt => return_compute_error_with!("date_trunc does not support", dt),
-            )
-        }
-        (DataType::Date32, DataType::Dictionary(_, f)) => {
-            if !matches!(*f, DataType::Utf8) {
-                return_compute_error_with!("date_trunc does not support format type ", f)
-            }
-            downcast_dictionary_array!(
-                formats => {
-                    downcast_temporal_array!(array => {
-                        date_trunc_array_fmt_plain_dict(
-                            array.as_any().downcast_ref::<Date32Array>()
-                                .expect("Unexpected error in casting date array"),
-                            &formats.downcast_dict::<StringArray>().unwrap())
-                        .map(|a| Arc::new(a) as ArrayRef)
-                    }
-                    dt => return_compute_error_with!("date_trunc does not support", dt),
-                    )
-                }
-                fmt => return_compute_error_with!("date_trunc does not support format type", fmt),
-            )
-        }
-        (DataType::Date32, DataType::Utf8) => date_trunc_array_fmt_plain_plain(
-            array
-                .as_any()
-                .downcast_ref::<Date32Array>()
-                .expect("Unexpected error in casting date array"),
-            formats
-                .as_any()
-                .downcast_ref::<StringArray>()
-                .expect("Unexpected value type in formats"),
-        )
-        .map(|a| Arc::new(a) as ArrayRef),
-        (dt, fmt) => Err(SparkError::Internal(format!(
-            "Unsupported datatype: {dt:}, format: {fmt:?} for function 'date_trunc'"
-        ))),
-    }
-}
-
-macro_rules! date_trunc_array_fmt_helper {
-    ($array: ident, $formats: ident, $datatype: ident) => {{
-        let mut builder = Date32Builder::with_capacity($array.len());
-        let iter = $array.into_iter();
-        match $datatype {
-            DataType::Date32 => {
-                for (index, val) in iter.enumerate() {
-                    let op_result = match $formats.value(index).to_uppercase().as_str() {
-                        "YEAR" | "YYYY" | "YY" => {
-                            Ok(as_datetime_with_op_single(val, &mut builder, |dt| {
-                                as_days_from_unix_epoch(trunc_date_to_year(dt))
-                            }))
-                        }
-                        "QUARTER" => Ok(as_datetime_with_op_single(val, &mut builder, |dt| {
-                            as_days_from_unix_epoch(trunc_date_to_quarter(dt))
-                        })),
-                        "MONTH" | "MON" | "MM" => {
-                            Ok(as_datetime_with_op_single(val, &mut builder, |dt| {
-                                as_days_from_unix_epoch(trunc_date_to_month(dt))
-                            }))
-                        }
-                        "WEEK" => Ok(as_datetime_with_op_single(val, &mut builder, |dt| {
-                            as_days_from_unix_epoch(trunc_date_to_week(dt))
-                        })),
-                        _ => Err(SparkError::Internal(format!(
-                            "Unsupported format: {:?} for function 'date_trunc'",
-                            $formats.value(index)
-                        ))),
-                    };
-                    op_result?
-                }
-                Ok(builder.finish())
-            }
-            dt => return_compute_error_with!(
-                "Unsupported input type '{:?}' for function 'date_trunc'",
-                dt
-            ),
-        }
-    }};
-}
-
-fn date_trunc_array_fmt_plain_plain(
-    array: &Date32Array,
-    formats: &StringArray,
-) -> Result<Date32Array>
-where
-{
-    let data_type = array.data_type();
-    date_trunc_array_fmt_helper!(array, formats, data_type)
-}
-
-fn date_trunc_array_fmt_plain_dict<K>(
-    array: &Date32Array,
-    formats: &TypedDictionaryArray<K, StringArray>,
-) -> Result<Date32Array>
-where
-    K: ArrowDictionaryKeyType,
-{
-    let data_type = array.data_type();
-    date_trunc_array_fmt_helper!(array, formats, data_type)
-}
-
-fn date_trunc_array_fmt_dict_plain<K>(
-    array: &TypedDictionaryArray<K, Date32Array>,
-    formats: &StringArray,
-) -> Result<Date32Array>
-where
-    K: ArrowDictionaryKeyType,
-{
-    let data_type = array.values().data_type();
-    date_trunc_array_fmt_helper!(array, formats, data_type)
-}
-
-fn date_trunc_array_fmt_dict_dict<K, F>(
-    array: &TypedDictionaryArray<K, Date32Array>,
-    formats: &TypedDictionaryArray<F, StringArray>,
-) -> Result<Date32Array>
-where
-    K: ArrowDictionaryKeyType,
-    F: ArrowDictionaryKeyType,
-{
-    let data_type = array.values().data_type();
-    date_trunc_array_fmt_helper!(array, formats, data_type)
-}
-
 ///
 /// Implements the spark [DATE_TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc)
 /// function where the specified format is a scalar value
@@ -806,156 +510,15 @@ where
 
 #[cfg(test)]
 mod tests {
-    use crate::kernels::temporal::{
-        date_trunc, date_trunc_array_fmt_dyn, timestamp_trunc, timestamp_trunc_array_fmt_dyn,
-    };
+    use crate::kernels::temporal::{timestamp_trunc, timestamp_trunc_array_fmt_dyn};
     use arrow::array::{
         builder::{PrimitiveDictionaryBuilder, StringDictionaryBuilder},
         iterator::ArrayIter,
-        types::{Date32Type, Int32Type, TimestampMicrosecondType},
-        Array, Date32Array, PrimitiveArray, StringArray, TimestampMicrosecondArray,
+        types::{Int32Type, TimestampMicrosecondType},
+        Array, PrimitiveArray, StringArray, TimestampMicrosecondArray,
     };
     use std::sync::Arc;
 
-    #[test]
-    #[cfg_attr(miri, ignore)] // test takes too long with miri
-    fn test_date_trunc() {
-        let size = 1000;
-        let mut vec: Vec<i32> = Vec::with_capacity(size);
-        for i in 0..size {
-            vec.push(i as i32);
-        }
-        let array = Date32Array::from(vec);
-        for fmt in [
-            "YEAR", "YYYY", "YY", "QUARTER", "MONTH", "MON", "MM", "WEEK",
-        ] {
-            match date_trunc(&array, fmt.to_string()) {
-                Ok(a) => {
-                    for i in 0..size {
-                        assert!(array.values().get(i) >= a.values().get(i))
-                    }
-                }
-                _ => unreachable!(),
-            }
-        }
-    }
-
-    #[test]
-    // This test only verifies that the various input array types work. Actual correctness to
Actually correctness to - // ensure this produces the same results as spark is verified in the JVM tests - fn test_date_trunc_array_fmt_dyn() { - let size = 10; - let formats = [ - "YEAR", "YYYY", "YY", "QUARTER", "MONTH", "MON", "MM", "WEEK", - ]; - let mut vec: Vec = Vec::with_capacity(size * formats.len()); - let mut fmt_vec: Vec<&str> = Vec::with_capacity(size * formats.len()); - for i in 0..size { - for fmt_value in &formats { - vec.push(i as i32 * 1_000_001); - fmt_vec.push(fmt_value); - } - } - - // timestamp array - let array = Date32Array::from(vec); - - // formats array - let fmt_array = StringArray::from(fmt_vec); - - // timestamp dictionary array - let mut date_dict_builder = PrimitiveDictionaryBuilder::::new(); - for v in array.iter() { - date_dict_builder - .append(v.unwrap()) - .expect("Error in building timestamp array"); - } - let mut array_dict = date_dict_builder.finish(); - // apply timezone - array_dict = array_dict.with_values(Arc::new( - array_dict - .values() - .as_any() - .downcast_ref::() - .unwrap() - .clone(), - )); - - // formats dictionary array - let mut formats_dict_builder = StringDictionaryBuilder::::new(); - for v in fmt_array.iter() { - formats_dict_builder - .append(v.unwrap()) - .expect("Error in building formats array"); - } - let fmt_dict = formats_dict_builder.finish(); - - // verify input arrays - let iter = ArrayIter::new(&array); - let mut dict_iter = array_dict - .downcast_dict::>() - .unwrap() - .into_iter(); - for val in iter { - assert_eq!( - dict_iter - .next() - .expect("array and dictionary array do not match"), - val - ) - } - - // verify input format arrays - let fmt_iter = ArrayIter::new(&fmt_array); - let mut fmt_dict_iter = fmt_dict.downcast_dict::().unwrap().into_iter(); - for val in fmt_iter { - assert_eq!( - fmt_dict_iter - .next() - .expect("formats and dictionary formats do not match"), - val - ) - } - - // test cases - if let Ok(a) = date_trunc_array_fmt_dyn(&array, &fmt_array) { - for i in 0..array.len() { - assert!( - array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) - ) - } - } else { - unreachable!() - } - if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_array) { - for i in 0..array.len() { - assert!( - array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) - ) - } - } else { - unreachable!() - } - if let Ok(a) = date_trunc_array_fmt_dyn(&array, &fmt_dict) { - for i in 0..array.len() { - assert!( - array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) - ) - } - } else { - unreachable!() - } - if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_dict) { - for i in 0..array.len() { - assert!( - array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) - ) - } - } else { - unreachable!() - } - } - #[test] #[cfg_attr(miri, ignore)] // test takes too long with miri fn test_timestamp_trunc() { diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 932fcbe53d..c576057717 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -68,7 +68,7 @@ pub use comet_scalar_funcs::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, register_all_comet_functions, }; -pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr}; +pub use datetime_funcs::{SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr}; pub use error::{SparkError, SparkResult}; pub use hash_funcs::*; pub use json_funcs::ToJson; diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala 
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 932fcbe53d..c576057717 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -68,7 +68,7 @@ pub use comet_scalar_funcs::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
     register_all_comet_functions,
 };
-pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
+pub use datetime_funcs::{SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
 pub use error::{SparkError, SparkResult};
 pub use hash_funcs::*;
 pub use json_funcs::ToJson;
diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
index 9473ee30e4..e7b1ff54dd 100644
--- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -20,7 +20,7 @@ package org.apache.comet.serde
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, WeekDay, WeekOfYear, Year}
-import org.apache.spark.sql.types.{DateType, IntegerType}
+import org.apache.spark.sql.types.IntegerType
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.CometGetDateField.CometGetDateField
@@ -255,18 +255,7 @@ object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
 
 object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
 
-object CometTruncDate extends CometExpressionSerde[TruncDate] {
-  override def convert(
-      expr: TruncDate,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val childExpr = exprToProtoInternal(expr.date, inputs, binding)
-    val formatExpr = exprToProtoInternal(expr.format, inputs, binding)
-    val optExpr =
-      scalarFunctionExprToProtoWithReturnType("date_trunc", DateType, childExpr, formatExpr)
-    optExprWithInfo(optExpr, expr, expr.date, expr.format)
-  }
-}
+object CometTruncDate extends CometScalarFunction[TruncDate]("date_trunc")
 
 object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] {
   override def convert(
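
On the Scala side, `CometTruncDate` now goes through the generic `CometScalarFunction` serde, which serializes a plain scalar-function call named `date_trunc` instead of hand-building the protobuf with an explicit `DateType` return type. That only works if the native side can resolve `date_trunc` by name at execution time — presumably from a Spark-compatible implementation registered outside this diff. A hedged sketch of that name-based contract against a DataFusion registry (the `FunctionRegistry` import path may vary by DataFusion version):

```rust
// The contract the new serde path relies on: "date_trunc" must be
// resolvable by name in the native function registry.
use datafusion::execution::context::SessionContext;
use datafusion::execution::FunctionRegistry;

fn resolve_date_trunc(ctx: &SessionContext) -> datafusion::common::Result<()> {
    // FunctionRegistry::udf returns the registered ScalarUDF or an error.
    let udf = ctx.udf("date_trunc")?;
    assert_eq!(udf.name(), "date_trunc");
    Ok(())
}
```
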