diff --git a/src/query/expression/src/types.rs b/src/query/expression/src/types.rs index 395b6e79c51f6..0f82cf065153f 100755 --- a/src/query/expression/src/types.rs +++ b/src/query/expression/src/types.rs @@ -66,7 +66,6 @@ pub use self::bitmap::BitmapType; pub use self::boolean::Bitmap; pub use self::boolean::BooleanType; pub use self::boolean::MutableBitmap; -pub use self::compute_view::StringConvert; pub use self::date::DateType; pub use self::decimal::*; pub use self::empty_array::EmptyArrayType; diff --git a/src/query/expression/src/types/compute_view.rs b/src/query/expression/src/types/compute_view.rs index 1b2a27af355d3..869aa564b91d9 100644 --- a/src/query/expression/src/types/compute_view.rs +++ b/src/query/expression/src/types/compute_view.rs @@ -20,20 +20,10 @@ use std::ops::Range; use databend_common_exception::Result; use num_traits::AsPrimitive; -use super::column_type_error; -use super::domain_type_error; -use super::scalar_type_error; -use super::string::StringDomain; -use super::string::StringIterator; use super::AccessType; -use super::AnyType; -use super::ArgType; use super::Number; use super::NumberType; use super::SimpleDomain; -use super::StringColumn; -use super::StringType; -use crate::display::scalar_ref_to_string; use crate::Column; use crate::Domain; use crate::ScalarRef; @@ -177,124 +167,3 @@ where SimpleDomain { min, max } } } - -/// For number convert -pub type StringConvertView = ComputeView; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct OwnedStringType; - -impl AccessType for OwnedStringType { - type Scalar = String; - type ScalarRef<'a> = String; - type Column = StringColumn; - type Domain = StringDomain; - type ColumnIterator<'a> = std::iter::Map, fn(&str) -> String>; - - fn to_owned_scalar(scalar: Self::ScalarRef<'_>) -> Self::Scalar { - scalar.to_string() - } - - fn to_scalar_ref(scalar: &Self::Scalar) -> Self::ScalarRef<'_> { - scalar.clone() - } - - fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Result> { - scalar - .as_string() - .map(|s| s.to_string()) - .ok_or_else(|| scalar_type_error::(scalar)) - } - - fn try_downcast_column(col: &Column) -> Result { - col.as_string() - .cloned() - .ok_or_else(|| column_type_error::(col)) - } - - fn try_downcast_domain(domain: &Domain) -> Result { - domain - .as_string() - .cloned() - .ok_or_else(|| domain_type_error::(domain)) - } - - fn column_len(col: &Self::Column) -> usize { - col.len() - } - - fn index_column(col: &Self::Column, index: usize) -> Option> { - col.index(index).map(str::to_string) - } - - #[inline] - unsafe fn index_column_unchecked(col: &Self::Column, index: usize) -> Self::ScalarRef<'_> { - col.value_unchecked(index).to_string() - } - - fn slice_column(col: &Self::Column, range: Range) -> Self::Column { - col.clone().sliced(range.start, range.end - range.start) - } - - fn iter_column(col: &Self::Column) -> Self::ColumnIterator<'_> { - col.iter().map(str::to_string) - } - - fn scalar_memory_size(scalar: &Self::ScalarRef<'_>) -> usize { - scalar.len() - } - - fn column_memory_size(col: &Self::Column) -> usize { - col.memory_size() - } - - #[inline(always)] - fn compare(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> Ordering { - left.cmp(&right) - } - - #[inline(always)] - fn equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool { - left == right - } - - #[inline(always)] - fn not_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool { - left != right - } - - #[inline(always)] - fn greater_than(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool { - left > right - } - - #[inline(always)] - fn greater_than_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool { - left >= right - } - - #[inline(always)] - fn less_than(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool { - left < right - } - - #[inline(always)] - fn less_than_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool { - left <= right - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct StringConvert; - -impl Compute for StringConvert { - fn compute<'a>( - value: ::ScalarRef<'a>, - ) -> ::ScalarRef<'a> { - scalar_ref_to_string(&value) - } - - fn compute_domain(_: &::Domain) -> StringDomain { - StringType::full_domain() - } -} diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs index cf8fa1ea19817..722b630ec8f6c 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs @@ -44,6 +44,7 @@ use itertools::Itertools; use super::batch_merge1; use super::batch_serialize1; use super::AggregateFunctionSortDesc; +use super::SerializeInfo; use super::StateSerde; #[derive(Debug, Clone)] @@ -52,7 +53,7 @@ pub struct SortAggState { } impl StateSerde for SortAggState { - fn serialize_type(_: Option<&dyn super::FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(None)] } diff --git a/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs b/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs index 63e06b963ddfb..cf70334655470 100644 --- a/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs +++ b/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs @@ -36,7 +36,7 @@ use super::AggregateFunction; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::UnaryState; @@ -48,11 +48,7 @@ where T: ValueType, T::Scalar: Hash, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { self.add_object(&T::to_owned_scalar(other)); Ok(()) } @@ -65,7 +61,7 @@ where fn merge_result( &mut self, mut builder: BuilderMut<'_, UInt64Type>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { builder.push(self.count() as u64); Ok(()) @@ -73,7 +69,7 @@ where } impl StateSerde for AggregateApproxCountDistinctState { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(None)] } @@ -144,7 +140,7 @@ fn create_templated( let return_type = DataType::Number(NumberDataType::UInt64); with_number_mapped_type!(|NUM_TYPE| match &arguments[0] { DataType::Number(NumberDataType::NUM_TYPE) => { - AggregateUnaryFunction::, NumberType, UInt64Type>::create( + AggregateUnaryFunction::, NumberType, UInt64Type>::new( display_name, return_type, ) @@ -152,7 +148,7 @@ fn create_templated( .finish() } DataType::String => { - AggregateUnaryFunction::, StringType, UInt64Type>::create( + AggregateUnaryFunction::, StringType, UInt64Type>::new( display_name, return_type, ) @@ -160,7 +156,7 @@ fn create_templated( .finish() } DataType::Date => { - AggregateUnaryFunction::, DateType, UInt64Type>::create( + AggregateUnaryFunction::, DateType, UInt64Type>::new( display_name, return_type, ) @@ -168,7 +164,7 @@ fn create_templated( .finish() } DataType::Timestamp => { - AggregateUnaryFunction::, TimestampType, UInt64Type>::create( + AggregateUnaryFunction::, TimestampType, UInt64Type>::new( display_name, return_type, ) @@ -176,7 +172,7 @@ fn create_templated( .finish() } _ => { - AggregateUnaryFunction::, AnyType, UInt64Type>::create( + AggregateUnaryFunction::, AnyType, UInt64Type>::new( display_name, return_type, ) diff --git a/src/query/functions/src/aggregates/aggregate_array_agg.rs b/src/query/functions/src/aggregates/aggregate_array_agg.rs index 8ab9f4c5867fb..ad37f011aea1c 100644 --- a/src/query/functions/src/aggregates/aggregate_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_array_agg.rs @@ -63,7 +63,7 @@ use super::AggrStateLoc; use super::AggregateFunction; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; -use super::FunctionData; +use super::SerializeInfo; use super::StateAddr; use super::StateSerde; @@ -132,8 +132,8 @@ where Self: ScalarStateFunc, T: ValueType, { - fn serialize_type(function_data: Option<&dyn FunctionData>) -> Vec { - let return_type = function_data + fn serialize_type(info: Option<&dyn SerializeInfo>) -> Vec { + let return_type = info .and_then(|data| data.as_any().downcast_ref::()) .cloned() .unwrap(); @@ -234,8 +234,8 @@ where T: SimpleType + Debug impl StateSerde for ArrayAggStateSimple where T: SimpleType { - fn serialize_type(function_data: Option<&dyn FunctionData>) -> Vec { - let data_type = function_data + fn serialize_type(info: Option<&dyn SerializeInfo>) -> Vec { + let data_type = info .and_then(|data| data.as_any().downcast_ref::()) .and_then(|ty| ty.as_array()) .unwrap() @@ -336,7 +336,7 @@ where V: ZeroSizeType } impl StateSerde for ArrayAggStateZST { - fn serialize_type(_function_data: Option<&dyn super::FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![ArrayType::::data_type().into()] } @@ -464,8 +464,8 @@ where T: ArgType + Debug + std::marker::Send impl StateSerde for ArrayAggStateBinary where T: ArgType + std::marker::Send { - fn serialize_type(function_data: Option<&dyn FunctionData>) -> Vec { - let data_type = function_data + fn serialize_type(info: Option<&dyn SerializeInfo>) -> Vec { + let data_type = info .and_then(|data| data.as_any().downcast_ref::()) .and_then(|ty| ty.as_array()) .unwrap() diff --git a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs index e17b7eee16b4c..2509aced68de6 100644 --- a/src/query/functions/src/aggregates/aggregate_array_moving.rs +++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs @@ -46,6 +46,7 @@ use super::AggregateFunction; use super::AggregateFunctionDescription; use super::AggregateFunctionRef; use super::AggregateFunctionSortDesc; +use super::SerializeInfo; use super::StateAddr; use super::StateSerde; @@ -221,7 +222,7 @@ where Self: SumState, T: Number, { - fn serialize_type(_function_data: Option<&dyn super::FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![ArrayType::>::data_type().into()] } @@ -452,7 +453,7 @@ where Self: SumState, T: Decimal, { - fn serialize_type(_function_data: Option<&dyn super::FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Array(Box::new(DataType::Decimal(T::default_decimal_size()))).into()] } diff --git a/src/query/functions/src/aggregates/aggregate_avg.rs b/src/query/functions/src/aggregates/aggregate_avg.rs index 051d2d5bd8905..e5fff7a29b0a4 100644 --- a/src/query/functions/src/aggregates/aggregate_avg.rs +++ b/src/query/functions/src/aggregates/aggregate_avg.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::any::Any; use std::marker::PhantomData; use databend_common_exception::ErrorCode; @@ -38,7 +37,7 @@ use super::AggregateFunctionDescription; use super::AggregateFunctionRef; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::UnaryState; @@ -73,11 +72,7 @@ where T::Scalar: Number + AsPrimitive, TSum::Scalar: Number + AsPrimitive + std::ops::AddAssign, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { self.count += 1; let other = T::to_owned_scalar(other).as_(); self.value += other; @@ -93,7 +88,7 @@ where fn merge_result( &mut self, mut builder: BuilderMut<'_, Float64Type>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { let value = self.value.as_() / (self.count as f64); builder.push(F64::from(value)); @@ -108,7 +103,7 @@ where T::Scalar: Number + AsPrimitive, TSum::Scalar: Number + AsPrimitive + std::ops::AddAssign, { - fn serialize_type(_: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { std::vec![ StateSerdeItem::DataType(TSum::data_type()), StateSerdeItem::DataType(UInt64Type::data_type()) @@ -158,12 +153,6 @@ struct DecimalAvgData { pub scale_add: u8, } -impl FunctionData for DecimalAvgData { - fn as_any(&self) -> &dyn Any { - self - } -} - struct DecimalAvgState { pub value: T, pub count: u64, @@ -202,7 +191,9 @@ impl UnaryState, DecimalType> for DecimalAvgState where T: Decimal + std::ops::AddAssign { - fn add(&mut self, other: T, _function_data: Option<&dyn FunctionData>) -> Result<()> { + type FunctionInfo = DecimalAvgData; + + fn add(&mut self, other: T, _: &Self::FunctionInfo) -> Result<()> { self.add_internal(1, other) } @@ -213,16 +204,8 @@ where T: Decimal + std::ops::AddAssign fn merge_result( &mut self, mut builder: as ValueType>::ColumnBuilderMut<'_>, - function_data: Option<&dyn FunctionData>, + decimal_avg_data: &DecimalAvgData, ) -> Result<()> { - // # Safety - // `downcast_ref_unchecked` will check type in debug mode using dynamic dispatch, - let decimal_avg_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; match self .value .checked_mul(T::e(decimal_avg_data.scale_add)) @@ -244,7 +227,7 @@ where T: Decimal + std::ops::AddAssign impl StateSerde for DecimalAvgState where T: Decimal + std::ops::AddAssign { - fn serialize_type(_: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { std::vec![ DataType::Decimal(T::default_decimal_size()).into(), UInt64Type::data_type().into() @@ -308,7 +291,6 @@ pub fn try_create_aggregate_avg_function( NumberType, Float64Type, >::create(display_name, return_type) - .finish() } DataType::Decimal(s) => { with_decimal_mapped_type!(|DECIMAL| match s.data_kind() { @@ -326,16 +308,22 @@ pub fn try_create_aggregate_avg_function( DecimalAvgState, DecimalType, DecimalType, - >::create(display_name, return_type) - .with_function_data(Box::new(DecimalAvgData { scale_add })) + >::with_function_info( + display_name, + return_type, + DecimalAvgData { scale_add }, + ) .finish() } else { AggregateUnaryFunction::< DecimalAvgState, DecimalType, DecimalType, - >::create(display_name, return_type) - .with_function_data(Box::new(DecimalAvgData { scale_add })) + >::with_function_info( + display_name, + return_type, + DecimalAvgData { scale_add }, + ) .finish() } } diff --git a/src/query/functions/src/aggregates/aggregate_boolean.rs b/src/query/functions/src/aggregates/aggregate_boolean.rs index 05f35cbe9a59c..4c0b3a42183ef 100644 --- a/src/query/functions/src/aggregates/aggregate_boolean.rs +++ b/src/query/functions/src/aggregates/aggregate_boolean.rs @@ -22,6 +22,7 @@ use databend_common_expression::AggrStateLoc; use databend_common_expression::AggregateFunctionRef; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; +use databend_common_expression::ColumnView; use databend_common_expression::Scalar; use databend_common_expression::StateAddr; use databend_common_expression::StateSerdeItem; @@ -34,7 +35,7 @@ use super::AggrState; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::UnaryState; @@ -85,7 +86,7 @@ pub fn boolean_batch(inner: Bitmap, validity: Option<&Bitmap } impl UnaryState for BooleanState { - fn add(&mut self, other: bool, _function_data: Option<&dyn FunctionData>) -> Result<()> { + fn add(&mut self, other: bool, _: &Self::FunctionInfo) -> Result<()> { if IS_AND { self.value &= other; } else { @@ -96,10 +97,14 @@ impl UnaryState for BooleanState, validity: Option<&Bitmap>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { + let other = match other { + ColumnView::Const(b, n) => Bitmap::new_constant(b, n), + ColumnView::Column(column) => column, + }; if IS_AND { self.value &= boolean_batch::(other, validity); } else { @@ -120,7 +125,7 @@ impl UnaryState for BooleanState, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { builder.push(self.value); Ok(()) @@ -128,7 +133,7 @@ impl UnaryState for BooleanState StateSerde for BooleanState { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Boolean.into()] } @@ -179,7 +184,6 @@ pub fn try_create_aggregate_boolean_function( display_name, return_type, ) - .finish() } _ => Err(ErrorCode::BadDataValueType(format!( diff --git a/src/query/functions/src/aggregates/aggregate_distinct_state.rs b/src/query/functions/src/aggregates/aggregate_distinct_state.rs index adadcf4fa455d..3500952ad19ac 100644 --- a/src/query/functions/src/aggregates/aggregate_distinct_state.rs +++ b/src/query/functions/src/aggregates/aggregate_distinct_state.rs @@ -43,7 +43,7 @@ use siphasher::sip128::SipHasher24; use super::batch_merge1; use super::batch_serialize1; use super::borsh_partial_deserialize; -use super::FunctionData; +use super::SerializeInfo; use super::StateAddr; use super::StateSerde; @@ -146,7 +146,7 @@ impl DistinctStateFunc for AggregateDistinctState { } impl StateSerde for AggregateDistinctState { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Array(Box::new(DataType::Binary)).into()] } @@ -190,6 +190,7 @@ pub struct AggregateDistinctStringState { impl DistinctStateFunc for AggregateDistinctStringState { fn new() -> Self { + #![allow(clippy::arc_with_non_send_sync)] AggregateDistinctStringState { set: ShortStringHashSet::<[u8]>::with_capacity(4, Arc::new(Bump::new())), } @@ -251,7 +252,7 @@ impl DistinctStateFunc for AggregateDistinctStringState { } impl StateSerde for AggregateDistinctStringState { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Array(Box::new(DataType::Binary)).into()] } @@ -356,7 +357,7 @@ where T: Number + HashtableKeyable impl StateSerde for AggregateDistinctNumberState where T: Number + HashtableKeyable { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Array(Box::new(NumberType::::data_type())).into()] } @@ -493,7 +494,7 @@ impl AggregateUniqStringState { } impl StateSerde for AggregateUniqStringState { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(None)] } diff --git a/src/query/functions/src/aggregates/aggregate_histogram.rs b/src/query/functions/src/aggregates/aggregate_histogram.rs index 46258afe3c30b..201cfd381f19e 100644 --- a/src/query/functions/src/aggregates/aggregate_histogram.rs +++ b/src/query/functions/src/aggregates/aggregate_histogram.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::any::Any; use std::collections::btree_map::BTreeMap; use std::collections::btree_map::Entry; use std::fmt::Display; @@ -42,20 +41,14 @@ use super::AggrState; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::StateSerdeItem; use super::UnaryState; -struct HistogramData { - pub max_num_buckets: u64, - pub data_type: DataType, -} - -impl FunctionData for HistogramData { - fn as_any(&self) -> &dyn Any { - self - } +pub struct HistogramData { + max_num_buckets: u64, + data_type: DataType, } #[derive(BorshSerialize, BorshDeserialize)] @@ -84,11 +77,9 @@ where T: ValueType, T::Scalar: Ord + BorshSerialize + BorshDeserialize + Serialize + Display, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + type FunctionInfo = HistogramData; + + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { let other = T::to_owned_scalar(other); match self.value_map.entry(other) { Entry::Occupied(o) => *o.into_mut() += 1, @@ -115,15 +106,8 @@ where fn merge_result( &mut self, mut builder: BuilderMut<'_, StringType>, - function_data: Option<&dyn FunctionData>, + histogram_data: &HistogramData, ) -> Result<()> { - let histogram_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; - let mut buckets = build_histogram(&self.value_map, histogram_data.max_num_buckets); let format_scalar = |scalar| { @@ -154,7 +138,7 @@ where T: ValueType, T::Scalar: BorshSerialize + BorshDeserialize + Ord + Display + Serialize, { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(None)] } @@ -202,11 +186,14 @@ pub fn try_create_aggregate_histogram_function( HistogramState>, NumberType, StringType, - >::create(display_name, DataType::String) - .with_function_data(Box::new(HistogramData { - max_num_buckets, - data_type, - })) + >::with_function_info( + display_name, + DataType::String, + HistogramData { + max_num_buckets, + data_type, + }, + ) .with_need_drop(true) .finish() } @@ -217,25 +204,28 @@ pub fn try_create_aggregate_histogram_function( HistogramState>, DecimalType, StringType, - >::create(display_name, DataType::String) - .with_function_data(Box::new(HistogramData { - max_num_buckets, - data_type, - })) + >::with_function_info( + display_name, + DataType::String, + HistogramData { + max_num_buckets, + data_type, + }, + ) .with_need_drop(true) .finish() } }) } DataType::String => { - AggregateUnaryFunction::, StringType, StringType>::create( + AggregateUnaryFunction::, StringType, StringType>::with_function_info( display_name, DataType::String, + HistogramData { + max_num_buckets, + data_type, + }, ) - .with_function_data(Box::new(HistogramData { - max_num_buckets, - data_type, - })) .with_need_drop(true) .finish() } @@ -244,25 +234,26 @@ pub fn try_create_aggregate_histogram_function( HistogramState, TimestampType, StringType, - >::create( - display_name, DataType::String + >::with_function_info( + display_name, + DataType::String, + HistogramData { + max_num_buckets, + data_type, + }, ) - .with_function_data(Box::new(HistogramData { - max_num_buckets, - data_type, - })) .with_need_drop(true) .finish() } DataType::Date => { - AggregateUnaryFunction::, DateType, StringType>::create( + AggregateUnaryFunction::, DateType, StringType>::with_function_info( display_name, DataType::String, + HistogramData { + max_num_buckets, + data_type, + }, ) - .with_function_data(Box::new(HistogramData { - max_num_buckets, - data_type, - })) .with_need_drop(true) .finish() } @@ -284,7 +275,7 @@ pub fn aggregate_histogram_function_desc() -> AggregateFunctionDescription { ) } -fn get_max_num_buckets(params: &Vec, display_name: &str) -> Result { +fn get_max_num_buckets(params: &[Scalar], display_name: &str) -> Result { if params.len() != 1 { return Ok(128); } diff --git a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs index 5a712a2cdd178..575e8cedb65b3 100644 --- a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs @@ -141,7 +141,7 @@ where T: ValueType, T::Scalar: BorshSerialize + BorshDeserialize, { - fn serialize_type(_function_data: Option<&dyn super::FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn super::SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(None)] } diff --git a/src/query/functions/src/aggregates/aggregate_kurtosis.rs b/src/query/functions/src/aggregates/aggregate_kurtosis.rs index e91e21d4e1d0b..9fcfd0cf3b25e 100644 --- a/src/query/functions/src/aggregates/aggregate_kurtosis.rs +++ b/src/query/functions/src/aggregates/aggregate_kurtosis.rs @@ -37,7 +37,7 @@ use super::AggregateFunctionDescription; use super::AggregateFunctionRef; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::UnaryState; @@ -55,11 +55,7 @@ where T: AccessType, T::Scalar: AsPrimitive, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { let other = T::to_owned_scalar(other).as_(); self.n += 1; self.sum += other; @@ -84,7 +80,7 @@ where fn merge_result( &mut self, mut builder: BuilderMut<'_, Float64Type>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { if self.n <= 3 { builder.push(F64::from(0_f64)); @@ -124,7 +120,7 @@ where } impl StateSerde for KurtosisState { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(Some(40))] } @@ -172,7 +168,6 @@ pub fn try_create_aggregate_kurtosis_function( NumberConvertView, Float64Type, >::create(display_name, return_type) - .finish() } DataType::Decimal(s) => { with_decimal_mapped_type!(|DECIMAL| match s.data_kind() { @@ -181,10 +176,7 @@ pub fn try_create_aggregate_kurtosis_function( KurtosisState, DecimalF64View, Float64Type, - >::create( - display_name, return_type - ) - .finish() + >::create(display_name, return_type) } }) } diff --git a/src/query/functions/src/aggregates/aggregate_min_max_any.rs b/src/query/functions/src/aggregates/aggregate_min_max_any.rs index 454ec43593677..60613815259ab 100644 --- a/src/query/functions/src/aggregates/aggregate_min_max_any.rs +++ b/src/query/functions/src/aggregates/aggregate_min_max_any.rs @@ -27,6 +27,7 @@ use databend_common_expression::with_number_mapped_type; use databend_common_expression::AggrStateLoc; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; +use databend_common_expression::ColumnView; use databend_common_expression::Scalar; use databend_common_expression::StateAddr; use databend_common_expression::SELECTIVITY_THRESHOLD; @@ -50,7 +51,7 @@ use super::AggregateFunctionDescription; use super::AggregateFunctionFeatures; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::StateSerdeItem; use super::UnaryState; @@ -68,7 +69,7 @@ where C: ChangeIf impl UnaryState for MinMaxStringState where C: ChangeIf { - fn add(&mut self, other: &str, _function_data: Option<&dyn FunctionData>) -> Result<()> { + fn add(&mut self, other: &str, _: &Self::FunctionInfo) -> Result<()> { match &self.value { Some(v) => { if C::change_if(&StringType::to_scalar_ref(v), &other) { @@ -84,16 +85,16 @@ where C: ChangeIf fn add_batch( &mut self, - other: StringColumn, + other: ColumnView, validity: Option<&Bitmap>, - function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { - let column_len = StringType::column_len(&other); + let column_len = other.len(); if column_len == 0 { return Ok(()); } - let column_iter = 0..other.len(); + let column_iter = 0..column_len; if let Some(validity) = validity { if validity.null_count() == column_len { return Ok(()); @@ -103,25 +104,33 @@ where C: ChangeIf .filter(|(_, valid)| *valid) .map(|(idx, _)| idx) .reduce(|l, r| { - if !C::change_if_ordering(StringColumn::compare(&other, l, &other, r)) { + let ordering = match &other { + ColumnView::Const(_, _) => std::cmp::Ordering::Equal, + ColumnView::Column(other) => StringColumn::compare(other, l, other, r), + }; + if !C::change_if_ordering(ordering) { l } else { r } }); if let Some(v) = v { - let _ = self.add(other.index(v).unwrap(), function_data); + self.add(other.index(v).unwrap(), &())?; } } else { let v = column_iter.reduce(|l, r| { - if !C::change_if_ordering(StringColumn::compare(&other, l, &other, r)) { + let ordering = match &other { + ColumnView::Const(_, _) => std::cmp::Ordering::Equal, + ColumnView::Column(other) => StringColumn::compare(other, l, other, r), + }; + if !C::change_if_ordering(ordering) { l } else { r } }); if let Some(v) = v { - let _ = self.add(other.index(v).unwrap(), function_data); + self.add(other.index(v).unwrap(), &())?; } } Ok(()) @@ -129,7 +138,7 @@ where C: ChangeIf fn merge(&mut self, rhs: &Self) -> Result<()> { if let Some(v) = &rhs.value { - self.add(v.as_str(), None)?; + self.add(v.as_str(), &())?; } Ok(()) } @@ -137,7 +146,7 @@ where C: ChangeIf fn merge_result( &mut self, mut builder: BuilderMut<'_, StringType>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { if let Some(v) = &self.value { builder.push_item(v.as_str()); @@ -151,7 +160,7 @@ where C: ChangeIf impl StateSerde for MinMaxStringState where C: ChangeIf { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::String.wrap_nullable().into()] } @@ -224,11 +233,7 @@ where T::Scalar: BorshSerialize + BorshDeserialize, C: ChangeIf, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { match &self.value { Some(v) => { if C::change_if(&T::to_scalar_ref(v), &other) { @@ -244,16 +249,16 @@ where fn add_batch( &mut self, - other: T::Column, + other: ColumnView, validity: Option<&Bitmap>, - function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { - let column_len = T::column_len(&other); + let column_len = other.len(); if column_len == 0 { return Ok(()); } - let column_iter = T::iter_column(&other); + let column_iter = other.iter(); if let Some(v) = validity { if v.true_count() as f64 / v.len() as f64 >= SELECTIVITY_THRESHOLD { let value = column_iter @@ -263,18 +268,17 @@ where .reduce(|l, r| if !C::change_if(&l, &r) { l } else { r }); if let Some(value) = value { - self.add(value, function_data)?; + self.add(value, &())?; } } else { for idx in TrueIdxIter::new(v.len(), Some(v)) { - let v = unsafe { T::index_column_unchecked(&other, idx) }; - self.add(v, function_data)?; + self.add(unsafe { other.index_unchecked(idx) }, &())?; } }; } else { let v = column_iter.reduce(|l, r| if !C::change_if(&l, &r) { l } else { r }); if let Some(v) = v { - self.add(v, function_data)?; + self.add(v, &())?; } } Ok(()) @@ -282,7 +286,7 @@ where fn merge(&mut self, rhs: &Self) -> Result<()> { if let Some(v) = &rhs.value { - self.add(T::to_scalar_ref(v), None)?; + self.add(T::to_scalar_ref(v), &())?; } Ok(()) } @@ -290,7 +294,7 @@ where fn merge_result( &mut self, mut builder: T::ColumnBuilderMut<'_>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { if let Some(v) = &self.value { builder.push_item(T::to_scalar_ref(v)); @@ -308,8 +312,8 @@ where T::Scalar: BorshSerialize + BorshDeserialize, C: ChangeIf, { - fn serialize_type(function_data: Option<&dyn FunctionData>) -> Vec { - let data_type = function_data + fn serialize_type(info: Option<&dyn SerializeInfo>) -> Vec { + let data_type = info .and_then(|data| data.as_any().downcast_ref::()) .cloned() .unwrap(); @@ -393,22 +397,19 @@ pub fn try_create_aggregate_min_max_any_function( with_simple_no_number_no_string_mapped_type!(|T| match data_type { DataType::T => { let return_type = data_type.clone(); - AggregateUnaryFunction::, T, T>::create( + AggregateUnaryFunction::, T, T>::new( display_name, return_type, ) + .with_serialize_info(Box::new(data_type)) .with_need_drop(need_drop) - .with_function_data(Box::new(data_type)) .finish() } DataType::String => { let return_type = data_type.clone(); - AggregateUnaryFunction::< - MinMaxStringState, - StringType, - StringType, - >::create( - display_name, return_type + AggregateUnaryFunction::, StringType, StringType>::new( + display_name, + return_type, ) .with_need_drop(need_drop) .finish() @@ -421,8 +422,8 @@ pub fn try_create_aggregate_min_max_any_function( MinMaxAnyState, CMP>, NumberType, NumberType, - >::create(display_name, return_type) - .with_function_data(Box::new(data_type)) + >::new(display_name, return_type) + .with_serialize_info(Box::new(data_type)) .finish() } }) @@ -436,22 +437,17 @@ pub fn try_create_aggregate_min_max_any_function( DecimalType, DecimalType, >::create(display_name, return_type) - .with_function_data(Box::new(data_type)) - .finish() } }) } _ => { let return_type = data_type.clone(); - AggregateUnaryFunction::< - MinMaxAnyState, - AnyType, - AnyType, - >::create( - display_name, return_type + AggregateUnaryFunction::, AnyType, AnyType>::new( + display_name, + return_type, ) + .with_serialize_info(Box::new(data_type)) .with_need_drop(need_drop) - .with_function_data(Box::new(data_type)) .finish() } }) diff --git a/src/query/functions/src/aggregates/aggregate_min_max_any_decimal.rs b/src/query/functions/src/aggregates/aggregate_min_max_any_decimal.rs index 6b19daa8b055c..5df0c5952c414 100644 --- a/src/query/functions/src/aggregates/aggregate_min_max_any_decimal.rs +++ b/src/query/functions/src/aggregates/aggregate_min_max_any_decimal.rs @@ -21,12 +21,13 @@ use databend_common_expression::types::*; use databend_common_expression::AggrStateLoc; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; +use databend_common_expression::ColumnView; use databend_common_expression::StateAddr; use super::aggregate_scalar_state::ChangeIf; use super::batch_merge1; use super::batch_serialize1; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::StateSerdeItem; use super::UnaryState; @@ -61,11 +62,7 @@ where T::Scalar: Decimal, C: ChangeIf, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { match self.value { Some(v) => { let v = T::Scalar::from_u64_array(v); @@ -82,16 +79,16 @@ where fn add_batch( &mut self, - other: T::Column, + other: ColumnView, validity: Option<&Bitmap>, - function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { - let column_len = T::column_len(&other); + let column_len = other.len(); if column_len == 0 { return Ok(()); } - let column_iter = T::iter_column(&other); + let column_iter = other.iter(); match validity { Some(validity) if validity.null_count() > 0 && validity.null_count() < validity.len() => @@ -102,14 +99,14 @@ where .map(|(v, _)| v) .reduce(|l, r| if !C::change_if(&l, &r) { l } else { r }); if let Some(v) = v { - let _ = self.add(v, function_data); + self.add(v, &())?; } } Some(validity) if validity.null_count() == validity.len() => {} _ => { let v = column_iter.reduce(|l, r| if !C::change_if(&l, &r) { l } else { r }); if let Some(v) = v { - let _ = self.add(v, function_data); + self.add(v, &())?; } } } @@ -120,7 +117,7 @@ where fn merge(&mut self, rhs: &Self) -> Result<()> { if let Some(v) = rhs.value { let v = T::Scalar::from_u64_array(v); - self.add(T::to_scalar_ref(&v), None)?; + self.add(T::to_scalar_ref(&v), &())?; } Ok(()) } @@ -128,7 +125,7 @@ where fn merge_result( &mut self, mut builder: T::ColumnBuilderMut<'_>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { if let Some(v) = self.value { let v = T::Scalar::from_u64_array(v); @@ -146,7 +143,7 @@ where T::Scalar: Decimal, C: ChangeIf, { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Decimal(T::Scalar::default_decimal_size()) .wrap_nullable() .into()] diff --git a/src/query/functions/src/aggregates/aggregate_mode.rs b/src/query/functions/src/aggregates/aggregate_mode.rs index 12a647228e0f4..5b8d568b132e0 100644 --- a/src/query/functions/src/aggregates/aggregate_mode.rs +++ b/src/query/functions/src/aggregates/aggregate_mode.rs @@ -37,7 +37,7 @@ use super::batch_merge1; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::StateSerdeItem; use super::UnaryState; @@ -68,11 +68,7 @@ where T: ValueType, T::Scalar: Ord + Hash + BorshSerialize + BorshDeserialize, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { let other = T::to_owned_scalar(other); match self.frequency_map.entry(other) { Entry::Occupied(o) => *o.into_mut() += 1, @@ -100,7 +96,7 @@ where fn merge_result( &mut self, mut builder: T::ColumnBuilderMut<'_>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { if self.frequency_map.is_empty() { builder.push_default(); @@ -122,7 +118,7 @@ where T: ValueType, T::Scalar: Ord + Hash + BorshSerialize + BorshDeserialize, { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(None)] } @@ -169,9 +165,7 @@ pub fn try_create_aggregate_mode_function( ModeState>, NumberType, NumberType, - >::create( - display_name, data_type.clone() - ) + >::new(display_name, data_type.clone()) .with_need_drop(true) .finish() } @@ -182,14 +176,14 @@ pub fn try_create_aggregate_mode_function( ModeState>, DecimalType, DecimalType, - >::create(display_name, data_type.clone()) + >::new(display_name, data_type.clone()) .with_need_drop(true) .finish() } }) } _ => { - AggregateUnaryFunction::, AnyType, AnyType>::create( + AggregateUnaryFunction::, AnyType, AnyType>::new( display_name, data_type.clone(), ) diff --git a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs index f3e08bee39e48..1eb41cb1692d8 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::any::Any; - use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::array::ArrayColumnBuilderMut; @@ -39,7 +37,7 @@ use super::AggregateFunctionDescription; use super::AggregateFunctionRef; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::StateSerdeItem; use super::UnaryState; @@ -51,11 +49,6 @@ pub(crate) struct QuantileData { pub(crate) levels: Vec, } -impl FunctionData for QuantileData { - fn as_any(&self) -> &dyn Any { - self - } -} #[derive(Default)] struct QuantileContState { pub value: Vec, @@ -82,11 +75,9 @@ where T::Scalar: Number + AsPrimitive, R: ValueType, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + type FunctionInfo = QuantileData; + + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { let other = T::to_owned_scalar(other).as_(); self.value.push(other.into()); Ok(()) @@ -100,15 +91,9 @@ where fn merge_result( &mut self, mut builder: R::ColumnBuilderMut<'_>, - function_data: Option<&dyn FunctionData>, + quantile_cont_data: &QuantileData, ) -> Result<()> { let value_len = self.value.len(); - let quantile_cont_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; if quantile_cont_data.levels.len() > 1 { let indices = quantile_cont_data .levels @@ -150,7 +135,7 @@ where } impl StateSerde for QuantileContState { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Array(Box::new(Float64Type::data_type())).into()] } @@ -242,11 +227,9 @@ where T: ValueType, T::Scalar: Decimal, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + type FunctionInfo = QuantileData; + + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { self.value.push(T::to_owned_scalar(other)); Ok(()) } @@ -263,15 +246,9 @@ where fn merge_result( &mut self, mut builder: ArrayColumnBuilderMut<'_, T>, - function_data: Option<&dyn FunctionData>, + quantile_cont_data: &QuantileData, ) -> Result<()> { let value_len = self.value.len(); - let quantile_cont_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; if quantile_cont_data.levels.len() > 1 { let indices = quantile_cont_data @@ -301,11 +278,9 @@ where T: ValueType, T::Scalar: Decimal, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + type FunctionInfo = QuantileData; + + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { self.value.push(T::to_owned_scalar(other)); Ok(()) } @@ -322,15 +297,9 @@ where fn merge_result( &mut self, mut builder: T::ColumnBuilderMut<'_>, - function_data: Option<&dyn FunctionData>, + quantile_cont_data: &QuantileData, ) -> Result<()> { let value_len = self.value.len(); - let quantile_cont_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; let (frac, whole) = libm::modf((value_len - 1) as f64 * quantile_cont_data.levels[0]); let whole = whole as usize; @@ -350,7 +319,7 @@ where T: ValueType, T::Scalar: Decimal, { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Array(Box::new(DataType::Decimal( T::Scalar::default_decimal_size(), ))) @@ -418,8 +387,9 @@ pub fn try_create_aggregate_quantile_cont_function( QuantileContState, NumberType, ArrayType, - >::create(display_name, return_type) - .with_function_data(Box::new(QuantileData { levels })) + >::with_function_info(display_name, return_type, QuantileData { + levels, + }) .with_need_drop(true) .finish() } else { @@ -428,10 +398,9 @@ pub fn try_create_aggregate_quantile_cont_function( QuantileContState, NumberType, Float64Type, - >::create( - display_name, return_type + >::with_function_info( + display_name, return_type, QuantileData { levels } ) - .with_function_data(Box::new(QuantileData { levels })) .with_need_drop(true) .finish() } @@ -446,10 +415,11 @@ pub fn try_create_aggregate_quantile_cont_function( DecimalQuantileContState>, DecimalType, ArrayType>, - >::create( - display_name, DataType::Array(Box::new(data_type)) + >::with_function_info( + display_name, + DataType::Array(Box::new(data_type)), + QuantileData { levels }, ) - .with_function_data(Box::new(QuantileData { levels })) .with_need_drop(true) .finish() } else { @@ -457,15 +427,15 @@ pub fn try_create_aggregate_quantile_cont_function( DecimalQuantileContState>, DecimalType, DecimalType, - >::create(display_name, data_type) - .with_function_data(Box::new(QuantileData { levels })) + >::with_function_info( + display_name, data_type, QuantileData { levels } + ) .with_need_drop(true) .finish() } } }) } - _ => Err(ErrorCode::BadDataValueType(format!( "{} does not support type '{:?}'", display_name, arguments[0] diff --git a/src/query/functions/src/aggregates/aggregate_quantile_disc.rs b/src/query/functions/src/aggregates/aggregate_quantile_disc.rs index 9c56ff900b27a..5ed139697721a 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_disc.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_disc.rs @@ -34,8 +34,8 @@ use super::AggregateFunctionDescription; use super::AggregateFunctionRef; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; use super::QuantileData; +use super::SerializeInfo; use super::StateSerde; use super::StateSerdeItem; use super::UnaryState; @@ -80,11 +80,9 @@ where T: ValueType, T::Scalar: BorshSerialize + BorshDeserialize + Ord, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + type FunctionInfo = QuantileData; + + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { self.value.push(T::to_owned_scalar(other)); Ok(()) } @@ -96,15 +94,9 @@ where fn merge_result( &mut self, mut builder: ArrayColumnBuilderMut<'_, T>, - function_data: Option<&dyn FunctionData>, + quantile_disc_data: &QuantileData, ) -> Result<()> { let value_len = self.value.len(); - let quantile_disc_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; if quantile_disc_data.levels.len() > 1 { let indices = quantile_disc_data .levels @@ -131,7 +123,7 @@ where T: ValueType, T::Scalar: BorshSerialize + BorshDeserialize, { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(None)] } @@ -167,11 +159,9 @@ where T: ValueType, T::Scalar: BorshSerialize + BorshDeserialize + Ord, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + type FunctionInfo = QuantileData; + + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { self.value.push(T::to_owned_scalar(other)); Ok(()) } @@ -188,15 +178,9 @@ where fn merge_result( &mut self, mut builder: T::ColumnBuilderMut<'_>, - function_data: Option<&dyn FunctionData>, + quantile_disc_data: &QuantileData, ) -> Result<()> { let value_len = self.value.len(); - let quantile_disc_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; let idx = ((value_len - 1) as f64 * quantile_disc_data.levels[0]).floor() as usize; if idx >= value_len { @@ -229,10 +213,11 @@ pub fn try_create_aggregate_quantile_disc_function( QuantileState>, NumberType, ArrayType>, - >::create( - display_name, DataType::Array(Box::new(data_type)) + >::with_function_info( + display_name, + DataType::Array(Box::new(data_type)), + QuantileData { levels }, ) - .with_function_data(Box::new(QuantileData { levels })) .with_need_drop(true) .finish() } else { @@ -240,8 +225,9 @@ pub fn try_create_aggregate_quantile_disc_function( QuantileState>, NumberType, NumberType, - >::create(display_name, data_type) - .with_function_data(Box::new(QuantileData { levels })) + >::with_function_info( + display_name, data_type, QuantileData { levels } + ) .with_need_drop(true) .finish() } @@ -257,10 +243,11 @@ pub fn try_create_aggregate_quantile_disc_function( QuantileState>, DecimalType, ArrayType>, - >::create( - display_name, DataType::Array(Box::new(data_type)) + >::with_function_info( + display_name, + DataType::Array(Box::new(data_type)), + QuantileData { levels }, ) - .with_function_data(Box::new(QuantileData { levels })) .with_need_drop(true) .finish() } else { @@ -268,8 +255,9 @@ pub fn try_create_aggregate_quantile_disc_function( QuantileState>, DecimalType, DecimalType, - >::create(display_name, data_type) - .with_function_data(Box::new(QuantileData { levels })) + >::with_function_info( + display_name, data_type, QuantileData { levels } + ) .with_need_drop(true) .finish() } diff --git a/src/query/functions/src/aggregates/aggregate_range_bound.rs b/src/query/functions/src/aggregates/aggregate_range_bound.rs index c91c93ac708d8..750f00fb81917 100644 --- a/src/query/functions/src/aggregates/aggregate_range_bound.rs +++ b/src/query/functions/src/aggregates/aggregate_range_bound.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::any::Any; - use borsh::BorshDeserialize; use borsh::BorshSerialize; use databend_common_exception::ErrorCode; @@ -29,6 +27,7 @@ use databend_common_expression::AggrStateLoc; use databend_common_expression::AggregateFunctionRef; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; +use databend_common_expression::ColumnView; use databend_common_expression::Scalar; use databend_common_expression::StateAddr; use rand::prelude::SliceRandom; @@ -43,24 +42,18 @@ use super::AggrState; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::StateSerdeItem; use super::UnaryState; use crate::with_simple_no_number_mapped_type; -struct RangeBoundData { +pub struct RangeBoundData { partitions: usize, sample_size: usize, data_type: DataType, } -impl FunctionData for RangeBoundData { - fn as_any(&self) -> &dyn Any { - self - } -} - #[derive(BorshSerialize, BorshDeserialize)] pub struct RangeBoundState where @@ -91,18 +84,9 @@ where T: ReturnType, T::Scalar: Ord + BorshSerialize + BorshDeserialize, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - function_data: Option<&dyn FunctionData>, - ) -> Result<()> { - let range_bound_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; + type FunctionInfo = RangeBoundData; + fn add(&mut self, other: T::ScalarRef<'_>, range_bound_data: &RangeBoundData) -> Result<()> { let total_sample_size = std::cmp::min( range_bound_data.sample_size * range_bound_data.partitions, 10_000, @@ -130,23 +114,17 @@ where fn add_batch( &mut self, - other: T::Column, + other: ColumnView, validity: Option<&Bitmap>, - function_data: Option<&dyn FunctionData>, + range_bound_data: &RangeBoundData, ) -> Result<()> { - let column_len = T::column_len(&other); + let column_len = other.len(); let unset_bits = validity.map_or(0, |v| v.null_count()); if unset_bits == column_len { return Ok(()); } let valid_size = column_len - unset_bits; - let range_bound_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; let sample_size = std::cmp::max(valid_size / 100, range_bound_data.sample_size); let mut indices = validity.map_or_else( @@ -169,7 +147,7 @@ where let sample_values = sampled_indices .iter() - .map(|i| T::to_owned_scalar(unsafe { T::index_column_unchecked(&other, *i) })) + .map(|i| T::to_owned_scalar(unsafe { other.index_unchecked(*i) })) .collect::>(); self.total_rows += valid_size; @@ -188,14 +166,8 @@ where fn merge_result( &mut self, mut builder: ArrayColumnBuilderMut<'_, T>, - function_data: Option<&dyn FunctionData>, + range_bound_data: &RangeBoundData, ) -> Result<()> { - let range_bound_data = unsafe { - function_data - .unwrap() - .as_any() - .downcast_ref_unchecked::() - }; let step = self.total_rows as f64 / range_bound_data.partitions as f64; let values = std::mem::take(&mut self.values); @@ -248,7 +220,7 @@ where T: ReturnType, T::Scalar: BorshSerialize + BorshDeserialize + Ord, { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(None)] } @@ -287,16 +259,16 @@ pub fn try_create_aggregate_range_bound_function( ) -> Result { assert_unary_arguments(display_name, arguments.len())?; let data_type = arguments[0].clone().remove_nullable(); - let function_data = get_partitions(¶ms, display_name, data_type.clone())?; + let function_info = get_partitions(¶ms, display_name, data_type.clone())?; let return_type = DataType::Array(Box::new(data_type.clone())); with_simple_no_number_mapped_type!(|T| match data_type { DataType::T => { - AggregateUnaryFunction::, T, ArrayType>::create( + AggregateUnaryFunction::, T, ArrayType>::with_function_info( display_name, return_type, + function_info, ) - .with_function_data(Box::new(function_data)) .with_need_drop(true) .finish() } @@ -307,8 +279,9 @@ pub fn try_create_aggregate_range_bound_function( RangeBoundState>, NumberType, ArrayType>, - >::create(display_name, return_type) - .with_function_data(Box::new(function_data)) + >::with_function_info( + display_name, return_type, function_info + ) .with_need_drop(true) .finish() } @@ -321,20 +294,20 @@ pub fn try_create_aggregate_range_bound_function( RangeBoundState>, DecimalType, ArrayType>, - >::create(display_name, return_type) - .with_function_data(Box::new(function_data)) + >::with_function_info( + display_name, return_type, function_info + ) .with_need_drop(true) .finish() } }) } DataType::Binary => { - AggregateUnaryFunction::< - RangeBoundState, - BinaryType, - ArrayType, - >::create(display_name, return_type) - .with_function_data(Box::new(function_data)) + AggregateUnaryFunction::, BinaryType, ArrayType>::with_function_info( + display_name, + return_type, + function_info, + ) .with_need_drop(true) .finish() } diff --git a/src/query/functions/src/aggregates/aggregate_skewness.rs b/src/query/functions/src/aggregates/aggregate_skewness.rs index a3bffc9d7f2ea..33732d03b8f79 100644 --- a/src/query/functions/src/aggregates/aggregate_skewness.rs +++ b/src/query/functions/src/aggregates/aggregate_skewness.rs @@ -39,7 +39,7 @@ use super::batch_merge1; use super::AggrState; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; #[derive(Default, BorshSerialize, BorshDeserialize)] @@ -55,11 +55,7 @@ where T: AccessType, T::Scalar: AsPrimitive, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { let other = T::to_owned_scalar(other).as_(); self.n += 1; self.sum += other; @@ -82,7 +78,7 @@ where fn merge_result( &mut self, mut builder: BuilderMut<'_, Float64Type>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { if self.n <= 2 { builder.push(F64::from(0_f64)); @@ -111,7 +107,7 @@ where } impl StateSerde for SkewnessStateV2 { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(Some(32))] } @@ -159,7 +155,6 @@ pub fn try_create_aggregate_skewness_function( NumberConvertView, Float64Type, >::create(display_name, return_type) - .finish() } DataType::Decimal(s) => { @@ -169,10 +164,7 @@ pub fn try_create_aggregate_skewness_function( SkewnessStateV2, DecimalF64View, Float64Type, - >::create( - display_name, return_type - ) - .finish() + >::create(display_name, return_type) } }) } diff --git a/src/query/functions/src/aggregates/aggregate_st_collect.rs b/src/query/functions/src/aggregates/aggregate_st_collect.rs index 7ed5f0e95f5fb..0df5f643ca2c4 100644 --- a/src/query/functions/src/aggregates/aggregate_st_collect.rs +++ b/src/query/functions/src/aggregates/aggregate_st_collect.rs @@ -189,7 +189,7 @@ where T: ArgType impl StateSerde for StCollectState where T: ArgType { - fn serialize_type(_function_data: Option<&dyn super::FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn super::SerializeInfo>) -> Vec { vec![DataType::Array(Box::new(T::data_type())).into()] } diff --git a/src/query/functions/src/aggregates/aggregate_stddev.rs b/src/query/functions/src/aggregates/aggregate_stddev.rs index 8a37790374e22..96d133e84f9f5 100644 --- a/src/query/functions/src/aggregates/aggregate_stddev.rs +++ b/src/query/functions/src/aggregates/aggregate_stddev.rs @@ -48,7 +48,7 @@ use super::AggregateFunction; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; use super::UnaryState; @@ -130,11 +130,7 @@ where T: AccessType, T::Scalar: Into, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: T::ScalarRef<'_>, _: &Self::FunctionInfo) -> Result<()> { let value = T::to_owned_scalar(other).into(); self.state_add(value) } @@ -146,14 +142,14 @@ where fn merge_result( &mut self, builder: NullableColumnBuilderMut<'_, Float64Type>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { self.state_merge_result(builder) } } impl StateSerde for StddevState { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![StateSerdeItem::Binary(Some(24))] } @@ -201,7 +197,6 @@ pub fn try_create_aggregate_stddev_pop_function( NumberConvertView, NullableType, >::create(display_name, return_type) - .finish() } DataType::Decimal(s) => { with_decimal_mapped_type!(|DECIMAL| match s.data_kind() { @@ -211,7 +206,6 @@ pub fn try_create_aggregate_stddev_pop_function( DecimalF64View, NullableType, >::create(display_name, return_type) - .finish() } }) } diff --git a/src/query/functions/src/aggregates/aggregate_string_agg.rs b/src/query/functions/src/aggregates/aggregate_string_agg.rs index 9c929f68f2c39..8b35d0c9a110f 100644 --- a/src/query/functions/src/aggregates/aggregate_string_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_string_agg.rs @@ -12,233 +12,128 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; -use std::fmt; -use std::sync::Arc; +use std::fmt::Display; +use std::fmt::Write; use databend_common_exception::ErrorCode; use databend_common_exception::Result; -use databend_common_expression::types::compute_view::StringConvertView; +use databend_common_expression::display::scalar_ref_to_string; use databend_common_expression::types::AccessType; +use databend_common_expression::types::AnyType; use databend_common_expression::types::Bitmap; +use databend_common_expression::types::BooleanType; +use databend_common_expression::types::BuilderMut; use databend_common_expression::types::DataType; +use databend_common_expression::types::Number; +use databend_common_expression::types::NumberDataType; +use databend_common_expression::types::NumberType; use databend_common_expression::types::StringType; -use databend_common_expression::types::ValueType; -use databend_common_expression::AggrStateRegistry; -use databend_common_expression::AggrStateType; +use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggregateFunctionRef; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; -use databend_common_expression::DataBlock; -use databend_common_expression::EvaluateOptions; -use databend_common_expression::Evaluator; -use databend_common_expression::FunctionContext; -use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::StateSerdeItem; use super::assert_variadic_arguments; use super::batch_merge1; -use super::AggrState; -use super::AggrStateLoc; -use super::AggregateFunction; +use super::batch_serialize1; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; -use super::StateAddr; -use crate::BUILTIN_FUNCTIONS; +use super::AggregateUnaryFunction; +use super::SerializeInfo; +use super::StateSerde; +use super::UnaryState; -#[derive(Debug)] -pub struct StringAggState { +#[derive(Default)] +struct StringAggState { values: String, } -#[derive(Clone)] -pub struct AggregateStringAggFunction { - display_name: String, - delimiter: String, - value_type: DataType, -} - -impl AggregateFunction for AggregateStringAggFunction { - fn name(&self) -> &str { - "AggregateStringAggFunction" - } - - fn return_type(&self) -> Result { - Ok(DataType::String) - } +impl UnaryState for StringAggState +where T: ToStringType +{ + type FunctionInfo = String; - fn init_state(&self, place: AggrState) { - place.write(|| StringAggState { - values: String::new(), - }); + fn add(&mut self, other: T::ScalarRef<'_>, delimiter: &String) -> Result<()> { + write!(self.values, "{}{delimiter}", T::format(&other)).unwrap(); + Ok(()) } - fn register_state(&self, registry: &mut AggrStateRegistry) { - registry.register(AggrStateType::Custom(Layout::new::())); + fn merge(&mut self, rhs: &Self) -> Result<()> { + self.values.push_str(&rhs.values); + Ok(()) } - fn accumulate( - &self, - place: AggrState, - entries: ProjectedBlock, - validity: Option<&Bitmap>, - _input_rows: usize, + fn merge_result( + &mut self, + mut builder: BuilderMut<'_, StringType>, + delimiter: &String, ) -> Result<()> { - let column = if self.value_type != DataType::String { - let block = DataBlock::new(vec![entries[0].clone()], entries[0].len()); - let func_ctx = &FunctionContext::default(); - let evaluator = Evaluator::new(&block, func_ctx, &BUILTIN_FUNCTIONS); - let value = evaluator.run_cast( - None, - &self.value_type, - &DataType::String, - entries[0].value(), - None, - &|| "(string_aggr)".to_string(), - &mut EvaluateOptions::default(), - )?; - BlockEntry::new(value, || (DataType::String, block.num_rows())) - .downcast::() - .unwrap() + if self.values.is_empty() { + builder.put_and_commit(""); } else { - entries[0].downcast::().unwrap() - }; - let state = place.get::(); - match validity { - Some(validity) => { - column.iter().zip(validity.iter()).for_each(|(v, b)| { - if b { - state.values.push_str(v); - state.values.push_str(&self.delimiter); - } - }); - } - None => { - column.iter().for_each(|v| { - state.values.push_str(v); - state.values.push_str(&self.delimiter); - }); - } + let len = self.values.len() - delimiter.len(); + builder.put_and_commit(&self.values[..len]); } Ok(()) } +} - fn accumulate_keys( - &self, - places: &[StateAddr], - loc: &[AggrStateLoc], - columns: ProjectedBlock, - _input_rows: usize, - ) -> Result<()> { - StringConvertView::iter_column(&columns[0].to_column()) - .zip(places.iter()) - .for_each(|(v, place)| { - let state = AggrState::new(*place, loc).get::(); - state.values.push_str(v.as_str()); - state.values.push_str(&self.delimiter); - }); - Ok(()) +trait ToStringType: AccessType { + fn format(v: &Self::ScalarRef<'_>) -> impl Display; +} + +impl ToStringType for BooleanType { + fn format(v: &Self::ScalarRef<'_>) -> impl Display { + v } +} - fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> { - let view = columns[0].downcast::().unwrap(); - let v = view.index(row).unwrap(); - let state = place.get::(); - state.values.push_str(v); - state.values.push_str(&self.delimiter); - Ok(()) +impl ToStringType for StringType { + fn format(v: &Self::ScalarRef<'_>) -> impl Display { + v + } +} + +impl ToStringType for NumberType { + fn format(v: &Self::ScalarRef<'_>) -> impl Display { + v } +} + +impl ToStringType for AnyType { + fn format(v: &Self::ScalarRef<'_>) -> impl Display { + scalar_ref_to_string(v) + } +} - fn serialize_type(&self) -> Vec { +impl StateSerde for StringAggState { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::String.into()] } fn batch_serialize( - &self, - places: &[StateAddr], - loc: &[AggrStateLoc], + places: &[super::StateAddr], + loc: &[super::AggrStateLoc], builders: &mut [ColumnBuilder], ) -> Result<()> { - let builder = builders[0].as_string_mut().unwrap(); - for place in places { - let state = AggrState::new(*place, loc).get::(); - builder.put_str(&state.values); - builder.commit_row(); - } - Ok(()) + batch_serialize1::(places, loc, builders, |state, builder| { + builder.put_and_commit(&state.values); + Ok(()) + }) } fn batch_merge( - &self, - places: &[StateAddr], - loc: &[AggrStateLoc], + places: &[super::StateAddr], + loc: &[super::AggrStateLoc], state: &BlockEntry, filter: Option<&Bitmap>, ) -> Result<()> { - batch_merge1::( - places, - loc, - state, - filter, - |state, values| { - state.values.push_str(values); - Ok(()) - }, - ) - } - - fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { - let state = place.get::(); - let other = rhs.get::(); - state.values.push_str(&other.values); - Ok(()) - } - - fn merge_result( - &self, - place: AggrState, - _read_only: bool, - builder: &mut ColumnBuilder, - ) -> Result<()> { - let state = place.get::(); - let mut builder = StringType::downcast_builder(builder); - if !state.values.is_empty() { - let len = state.values.len() - self.delimiter.len(); - builder.put_and_commit(&state.values[..len]); - } else { - builder.put_and_commit(""); - } - Ok(()) - } - - fn need_manual_drop_state(&self) -> bool { - true - } - - unsafe fn drop_state(&self, place: AggrState) { - let state = place.get::(); - std::ptr::drop_in_place(state); - } -} - -impl fmt::Display for AggregateStringAggFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.display_name) - } -} - -impl AggregateStringAggFunction { - fn try_create( - display_name: &str, - delimiter: String, - value_type: DataType, - ) -> Result> { - let func = AggregateStringAggFunction { - display_name: display_name.to_string(), - delimiter, - value_type, - }; - Ok(Arc::new(func)) + batch_merge1::(places, loc, state, filter, |state, values| { + state.values.push_str(values); + Ok(()) + }) } } @@ -247,31 +142,62 @@ pub fn try_create_aggregate_string_agg_function( params: Vec, argument_types: Vec, _sort_descs: Vec, -) -> Result> { +) -> Result { assert_variadic_arguments(display_name, argument_types.len(), (1, 2))?; let value_type = argument_types[0].remove_nullable(); - if !matches!( - value_type, - DataType::Boolean - | DataType::String - | DataType::Number(_) - | DataType::Decimal(_) - | DataType::Timestamp - | DataType::Date - | DataType::Variant - | DataType::Interval - ) { - return Err(ErrorCode::BadDataValueType(format!( - "{} does not support type '{:?}'", - display_name, value_type - ))); - } let delimiter = if params.len() == 1 { params[0].as_string().unwrap().clone() } else { String::new() }; - AggregateStringAggFunction::try_create(display_name, delimiter, value_type) + + match_template::match_template! { + T = [ + Boolean => BooleanType, + String => StringType, + ], + match value_type { + DataType::T => { + AggregateUnaryFunction::::with_function_info( + display_name, + DataType::String, + delimiter, + ) + .with_need_drop(true) + .finish() + }, + DataType::Number(num_type) => { + with_number_mapped_type!(|NUM| match num_type { + NumberDataType::NUM => { + AggregateUnaryFunction::, StringType>::with_function_info( + display_name, + DataType::String, + delimiter, + ) + .with_need_drop(true) + .finish() + } + }) + }, + DataType::Decimal(_) + | DataType::Timestamp + | DataType::Date + | DataType::Variant + | DataType::Interval => { + AggregateUnaryFunction::::with_function_info( + display_name, + DataType::String, + delimiter, + ) + .with_need_drop(true) + .finish() + }, + _ => Err(ErrorCode::BadDataValueType(format!( + "{} does not support type '{:?}'", + display_name, value_type + ))), + } + } } pub fn aggregate_string_agg_function_desc() -> AggregateFunctionDescription { diff --git a/src/query/functions/src/aggregates/aggregate_sum.rs b/src/query/functions/src/aggregates/aggregate_sum.rs index 4dcc33ac7a55d..953fcbfe318ee 100644 --- a/src/query/functions/src/aggregates/aggregate_sum.rs +++ b/src/query/functions/src/aggregates/aggregate_sum.rs @@ -29,6 +29,7 @@ use databend_common_expression::with_number_mapped_type; use databend_common_expression::AggregateFunctionRef; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; +use databend_common_expression::ColumnView; use databend_common_expression::Scalar; use databend_common_expression::StateAddr; use databend_common_expression::StateSerdeItem; @@ -44,7 +45,7 @@ use super::AggrStateLoc; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; -use super::FunctionData; +use super::SerializeInfo; use super::StateSerde; pub struct NumberSumState @@ -99,32 +100,35 @@ where } } -impl UnaryState for NumberSumState +impl UnaryState, NumberType> for NumberSumState> where - T: ArgType, - N: ArgType, - T::Scalar: Number + AsPrimitive, - N::Scalar: Number + AsPrimitive + std::ops::AddAssign, - for<'a> T::ScalarRef<'a>: Number + AsPrimitive, + T: Number + AsPrimitive, + N: Number + AsPrimitive + std::ops::AddAssign, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: T, _: &Self::FunctionInfo) -> Result<()> { self.value += other.as_(); Ok(()) } fn add_batch( &mut self, - other: T::Column, + other: ColumnView>, validity: Option<&Bitmap>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { - let col = T::upcast_column(other); - let buffer = NumberType::::try_downcast_column(&col).unwrap(); - self.value += sum_batch::(buffer, validity); + match other { + ColumnView::Const(v, n) => { + let sum: N = v.as_(); + let count = validity.map(|v| v.true_count()).unwrap_or(n); + for _ in 0..count { + self.value += sum; + } + } + ColumnView::Column(buffer) => { + self.value += sum_batch::(buffer, validity); + } + } + Ok(()) } @@ -135,10 +139,10 @@ where fn merge_result( &mut self, - mut builder: N::ColumnBuilderMut<'_>, - _function_data: Option<&dyn FunctionData>, + mut builder: as ValueType>::ColumnBuilderMut<'_>, + _: &Self::FunctionInfo, ) -> Result<()> { - builder.push_item(N::to_scalar_ref(&self.value)); + builder.push(self.value); Ok(()) } } @@ -148,7 +152,7 @@ where N: ArgType, N::Scalar: Number + AsPrimitive + std::ops::AddAssign, { - fn serialize_type(_: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { std::vec![N::data_type().into()] } @@ -196,7 +200,7 @@ impl UnaryState, DecimalTyp for DecimalSumState where T: Decimal + std::ops::AddAssign { - fn add(&mut self, other: T, _function_data: Option<&dyn FunctionData>) -> Result<()> { + fn add(&mut self, other: T, _: &Self::FunctionInfo) -> Result<()> { let mut value = T::from_u64_array(self.value); value += other; @@ -214,9 +218,9 @@ where T: Decimal + std::ops::AddAssign fn add_batch( &mut self, - other: Buffer, + other: ColumnView>, validity: Option<&Bitmap>, - function_data: Option<&dyn FunctionData>, + function_data: &Self::FunctionInfo, ) -> Result<()> { if !SHOULD_CHECK_OVERFLOW { let mut sum = T::from_u64_array(self.value); @@ -225,13 +229,13 @@ where T: Decimal + std::ops::AddAssign Some(validity) if validity.null_count() > 0 => { buffer.iter().zip(validity.iter()).for_each(|(t, b)| { if b { - sum += *t; + sum += t; } }); } _ => { buffer.iter().for_each(|t| { - sum += *t; + sum += t; }); } } @@ -241,13 +245,13 @@ where T: Decimal + std::ops::AddAssign Some(validity) => { for (data, valid) in other.iter().zip(validity.iter()) { if valid { - self.add(*data, function_data)?; + self.add(data, function_data)?; } } } None => { for value in other.iter() { - self.add(*value, function_data)?; + self.add(value, function_data)?; } } } @@ -257,13 +261,13 @@ where T: Decimal + std::ops::AddAssign fn merge(&mut self, rhs: &Self) -> Result<()> { let v = T::from_u64_array(rhs.value); - self.add(v, None) + self.add(v, &()) } fn merge_result( &mut self, mut builder: BuilderMut<'_, DecimalType>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { let v = T::from_u64_array(self.value); builder.push(v); @@ -274,7 +278,7 @@ where T: Decimal + std::ops::AddAssign impl StateSerde for DecimalSumState where T: Decimal + std::ops::AddAssign { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Decimal(T::default_decimal_size()).into()] } @@ -310,41 +314,11 @@ pub struct IntervalSumState { } impl UnaryState for IntervalSumState { - fn add( - &mut self, - other: months_days_micros, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { + fn add(&mut self, other: months_days_micros, _: &Self::FunctionInfo) -> Result<()> { self.value += other; Ok(()) } - fn add_batch( - &mut self, - other: Buffer, - validity: Option<&Bitmap>, - _function_data: Option<&dyn FunctionData>, - ) -> Result<()> { - let col = IntervalType::upcast_column_with_type(other, &DataType::Interval); - let buffer = IntervalType::try_downcast_column(&col).unwrap(); - match validity { - Some(validity) if validity.null_count() > 0 => { - buffer.iter().zip(validity.iter()).for_each(|(t, b)| { - if b { - self.value += *t; - } - }); - } - _ => { - buffer.iter().for_each(|t| { - self.value += *t; - }); - } - } - - Ok(()) - } - fn merge(&mut self, rhs: &Self) -> Result<()> { let res = self.value.total_micros() + rhs.value.total_micros(); self.value = months_days_micros(res as i128); @@ -354,7 +328,7 @@ impl UnaryState for IntervalSumState { fn merge_result( &mut self, mut builder: BuilderMut<'_, IntervalType>, - _function_data: Option<&dyn FunctionData>, + _: &Self::FunctionInfo, ) -> Result<()> { builder.push_item(IntervalType::to_scalar_ref(&self.value)); Ok(()) @@ -362,7 +336,7 @@ impl UnaryState for IntervalSumState { } impl StateSerde for IntervalSumState { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec { + fn serialize_type(_: Option<&dyn SerializeInfo>) -> Vec { vec![DataType::Interval.into()] } @@ -414,7 +388,6 @@ pub fn try_create_aggregate_sum_function( NumberType, NumberType, >::create(display_name, return_type) - .finish() } DataType::Interval => { let return_type = DataType::Interval; @@ -422,7 +395,6 @@ pub fn try_create_aggregate_sum_function( display_name, return_type, ) - .finish() } DataType::Decimal(s) => { with_decimal_mapped_type!(|DECIMAL| match s.data_kind() { @@ -439,14 +411,12 @@ pub fn try_create_aggregate_sum_function( DecimalType, DecimalType, >::create(display_name, return_type) - .finish() } else { AggregateUnaryFunction::< DecimalSumState, DecimalType, DecimalType, >::create(display_name, return_type) - .finish() } } }) diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index 158a93eae992f..f7d5264f15ec9 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::alloc::Layout; -use std::any::Any; use std::fmt::Display; use std::fmt::Formatter; use std::marker::PhantomData; @@ -30,12 +29,14 @@ use databend_common_expression::AggregateFunction; use databend_common_expression::AggregateFunctionRef; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; +use databend_common_expression::ColumnView; use databend_common_expression::ProjectedBlock; use databend_common_expression::StateAddr; use databend_common_expression::StateSerdeItem; use super::AggrState; use super::AggrStateLoc; +use super::SerializeInfo; use super::StateSerde; pub(super) trait UnaryState: StateSerde + Default + Send + 'static @@ -43,28 +44,26 @@ where T: AccessType, R: ValueType, { - fn add( - &mut self, - other: T::ScalarRef<'_>, - function_data: Option<&dyn FunctionData>, - ) -> Result<()>; + type FunctionInfo: Send + Sync = (); + + fn add(&mut self, other: T::ScalarRef<'_>, function_data: &Self::FunctionInfo) -> Result<()>; fn add_batch( &mut self, - other: T::Column, + other: ColumnView, validity: Option<&Bitmap>, - function_data: Option<&dyn FunctionData>, + function_data: &Self::FunctionInfo, ) -> Result<()> { match validity { Some(validity) => { - for (data, valid) in T::iter_column(&other).zip(validity.iter()) { + for (data, valid) in other.iter().zip(validity.iter()) { if valid { self.add(data, function_data)?; } } } None => { - for value in T::iter_column(&other) { + for value in other.iter() { self.add(value, function_data)?; } } @@ -77,14 +76,10 @@ where fn merge_result( &mut self, builder: R::ColumnBuilderMut<'_>, - function_data: Option<&dyn FunctionData>, + function_info: &Self::FunctionInfo, ) -> Result<()>; } -pub trait FunctionData: Send + Sync { - fn as_any(&self) -> &dyn Any; -} - pub(super) struct AggregateUnaryFunction where S: UnaryState, @@ -93,7 +88,8 @@ where { display_name: String, return_type: DataType, - function_data: Option>, + function_info: S::FunctionInfo, + serialize_info: Option>, need_drop: bool, _p: PhantomData, } @@ -109,45 +105,59 @@ where } } +impl AggregateUnaryFunction +where + S: UnaryState + 'static, + T: AccessType, + R: ValueType, +{ + pub fn new(display_name: &str, return_type: DataType) -> Self { + Self::with_function_info(display_name, return_type, ()) + } + + pub fn create(display_name: &str, return_type: DataType) -> Result { + Self::with_function_info(display_name, return_type, ()).finish() + } +} + impl AggregateUnaryFunction where S: UnaryState + 'static, T: AccessType, R: ValueType, { - pub(crate) fn create( + pub fn with_function_info( display_name: &str, return_type: DataType, - ) -> AggregateUnaryFunction { + function_info: S::FunctionInfo, + ) -> Self { AggregateUnaryFunction { display_name: display_name.to_string(), return_type, - function_data: None, + function_info, + serialize_info: None, need_drop: false, _p: PhantomData, } } - pub(crate) fn with_function_data( - mut self, - function_data: Box, - ) -> AggregateUnaryFunction { - self.function_data = Some(function_data); + pub fn with_serialize_info(mut self, serialize_info: Box) -> Self { + self.serialize_info = Some(serialize_info); self } - pub(crate) fn with_need_drop(mut self, need_drop: bool) -> AggregateUnaryFunction { + pub fn with_need_drop(mut self, need_drop: bool) -> Self { self.need_drop = need_drop; self } - pub(crate) fn finish(self) -> Result { + pub fn finish(self) -> Result { Ok(Arc::new(self)) } fn do_merge_result(&self, state: &mut S, builder: &mut ColumnBuilder) -> Result<()> { let builder = R::downcast_builder(builder); - state.merge_result(builder, self.function_data.as_deref()) + state.merge_result(builder, &self.function_info) } } @@ -180,10 +190,10 @@ where validity: Option<&Bitmap>, _input_rows: usize, ) -> Result<()> { - let column = T::try_downcast_column(&columns[0].to_column()).unwrap(); + let column = columns[0].downcast().unwrap(); let state: &mut S = place.get::(); - state.add_batch(column, validity, self.function_data.as_deref()) + state.add_batch(column, validity, &self.function_info) } fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> { @@ -191,7 +201,7 @@ where let value = view.index(row).unwrap(); let state: &mut S = place.get::(); - state.add(value, self.function_data.as_deref())?; + state.add(value, &self.function_info)?; Ok(()) } @@ -206,14 +216,14 @@ where for (v, place) in view.iter().zip(places.iter()) { let state: &mut S = AggrState::new(*place, loc).get::(); - state.add(v, self.function_data.as_deref())?; + state.add(v, &self.function_info)?; } Ok(()) } fn serialize_type(&self) -> Vec { - S::serialize_type(self.function_data.as_deref()) + S::serialize_type(self.serialize_info.as_deref()) } fn batch_serialize( diff --git a/src/query/functions/src/aggregates/mod.rs b/src/query/functions/src/aggregates/mod.rs index 9f87884be4c6f..e565471040417 100644 --- a/src/query/functions/src/aggregates/mod.rs +++ b/src/query/functions/src/aggregates/mod.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::any::Any; + use adaptors::*; pub use aggregate_count::AggregateCountFunction; pub use aggregate_function::*; @@ -30,7 +32,7 @@ use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; trait StateSerde { - fn serialize_type(_function_data: Option<&dyn FunctionData>) -> Vec; + fn serialize_type(info: Option<&dyn SerializeInfo>) -> Vec; fn batch_serialize( places: &[StateAddr], @@ -46,7 +48,11 @@ trait StateSerde { ) -> Result<()>; } -impl FunctionData for DataType { +trait SerializeInfo: Send + Sync { + fn as_any(&self) -> &dyn Any; +} + +impl SerializeInfo for DataType { fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/src/query/functions/src/lib.rs b/src/query/functions/src/lib.rs index 81a4b18d19f61..50b4eb9e6e586 100644 --- a/src/query/functions/src/lib.rs +++ b/src/query/functions/src/lib.rs @@ -12,14 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![allow(clippy::arc_with_non_send_sync)] -#![allow(clippy::uninlined_format_args)] -#![allow(clippy::ptr_arg)] -#![allow(clippy::type_complexity)] #![feature(box_patterns)] #![feature(type_ascription)] #![feature(try_blocks)] #![feature(downcast_unchecked)] +#![feature(associated_type_defaults)] use aggregates::AggregateFunctionFactory; use ctor::ctor; diff --git a/src/query/functions/src/srfs/string.rs b/src/query/functions/src/srfs/string.rs index aeda5c7fb821f..bdf4231d6ea57 100644 --- a/src/query/functions/src/srfs/string.rs +++ b/src/query/functions/src/srfs/string.rs @@ -80,6 +80,7 @@ pub fn regexp_split_to_vec( pattern = &pattern[4..]; } let mut literal_mode_from_flags = false; + #[allow(clippy::type_complexity)] let mut builder_config_fns: Vec &mut RegexBuilder>> = Vec::new();