diff --git a/datafusion/common/src/spans.rs b/datafusion/common/src/spans.rs index 40ebdeffb601..5111e264123c 100644 --- a/datafusion/common/src/spans.rs +++ b/datafusion/common/src/spans.rs @@ -140,7 +140,7 @@ impl Span { /// the column a that comes from SELECT 1 AS a UNION ALL SELECT 2 AS a you'll /// need two spans. #[derive(Debug, Clone)] -// Store teh first [`Span`] on the stack because that is by far the most common +// Store the first [`Span`] on the stack because that is by far the most common // case. More will spill onto the heap. pub struct Spans(pub Vec<Span>); diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 7f20020c3457..9d00b45962bc 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -17,12 +17,13 @@ //! Interval arithmetic library -use crate::operator::Operator; -use crate::type_coercion::binary::BinaryTypeCoercer; use std::borrow::Borrow; use std::fmt::{self, Display, Formatter}; use std::ops::{AddAssign, SubAssign}; +use crate::operator::Operator; +use crate::type_coercion::binary::{comparison_coercion_numeric, BinaryTypeCoercer}; + use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{ DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, TimeUnit, @@ -168,7 +169,7 @@ macro_rules! value_transition { /// limits after any operation, they either become unbounded or they are fixed /// to the maximum/minimum value of the datatype, depending on the direction /// of the overflowing endpoint, opting for the safer choice. -/// +/// /// 4. **Floating-point special cases**: /// - `INF` values are converted to `NULL`s while constructing an interval to /// ensure consistency, with other data types. @@ -405,13 +406,18 @@ impl Interval { // There must be no way to create an interval whose endpoints have // different types. - assert!( + debug_assert!( lower_type == upper_type, "Interval bounds have different types: {lower_type} != {upper_type}" ); lower_type } + /// Checks if the interval is unbounded (on either side). + pub fn is_unbounded(&self) -> bool { + self.lower.is_null() || self.upper.is_null() + } + /// Casts this interval to `data_type` using `cast_options`. pub fn cast_to( &self, @@ -645,7 +651,7 @@ impl Interval { let upper = min_of_bounds(&self.upper, &rhs.upper); // New lower and upper bounds must always construct a valid interval. - assert!( + debug_assert!( (lower.is_null() || upper.is_null() || (lower <= upper)), "The intersection of two intervals can not be an invalid interval" ); @@ -653,26 +659,70 @@ impl Interval { Ok(Some(Self { lower, upper })) } - /// Decide if this interval certainly contains, possibly contains, or can't - /// contain a [`ScalarValue`] (`other`) by returning `[true, true]`, - /// `[false, true]` or `[false, false]` respectively. + /// Compute the union of this interval with the given interval. /// /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. 
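+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (mirroring the cases exercised in `union_test`
+    /// below); the union is the smallest interval covering both operands:
+    ///
+    /// ```ignore
+    /// let lhs = Interval::make(Some(0_i64), Some(10_i64))?;
+    /// let rhs = Interval::make(Some(5_i64), Some(20_i64))?;
+    /// assert_eq!(lhs.union(&rhs)?, Interval::make(Some(0_i64), Some(20_i64))?);
+    /// ```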
- pub fn contains_value<T: Borrow<ScalarValue>>(&self, other: T) -> Result<bool> { + pub fn union<T: Borrow<Self>>(&self, other: T) -> Result<Self> { let rhs = other.borrow(); if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Cannot calculate the union of intervals with different data types, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + let lower = if self.lower.is_null() + || (!rhs.lower.is_null() && self.lower <= rhs.lower) + { + self.lower.clone() + } else { + rhs.lower.clone() + }; + let upper = if self.upper.is_null() + || (!rhs.upper.is_null() && self.upper >= rhs.upper) + { + self.upper.clone() + } else { + rhs.upper.clone() + }; + + // New lower and upper bounds must always construct a valid interval. + debug_assert!( + (lower.is_null() || upper.is_null() || (lower <= upper)), + "The union of two intervals can not be an invalid interval" + ); + + Ok(Self { lower, upper }) + } + + /// Decide if this interval contains a [`ScalarValue`] (`other`) by returning `true` or `false`. + pub fn contains_value<T: Borrow<ScalarValue>>(&self, other: T) -> Result<bool> { + let rhs = other.borrow(); + + let (lhs_lower, lhs_upper, rhs) = if self.data_type().eq(&rhs.data_type()) { + (&self.lower, &self.upper, rhs) + } else if let Some(common_type) = + comparison_coercion_numeric(&self.data_type(), &rhs.data_type()) + { + ( + &self.lower.cast_to(&common_type)?, + &self.upper.cast_to(&common_type)?, + &rhs.cast_to(&common_type)?, + ) + } else { return internal_err!( "Data types must be compatible for containment checks, lhs:{}, rhs:{}", self.data_type(), rhs.data_type() ); - } + }; // We only check the upper bound for a `None` value because `None` // values are less than `Some` values according to Rust. - Ok(&self.lower <= rhs && (self.upper.is_null() || rhs <= &self.upper)) + Ok(lhs_lower <= rhs && (lhs_upper.is_null() || rhs <= lhs_upper)) } /// Decide if this interval is a superset of, overlaps with, or @@ -825,6 +875,17 @@ impl Interval { } } + /// Computes the width of this interval; i.e. the difference between its + /// bounds. For unbounded intervals, this function will return a `NULL` + /// `ScalarValue` If the underlying data type doesn't support subtraction, + /// this function will return an error. + pub fn width(&self) -> Result<ScalarValue> { + let dt = self.data_type(); + let width_dt = + BinaryTypeCoercer::new(&dt, &Operator::Minus, &dt).get_result_type()?; + Ok(sub_bounds::<true>(&width_dt, &self.upper, &self.lower)) + } + /// Returns the cardinality of this interval, which is the number of all /// distinct points inside it. This function returns `None` if: /// - The interval is unbounded from either side, or @@ -874,10 +935,10 @@ impl Interval { /// This method computes the arithmetic negation of the interval, reflecting /// it about the origin of the number line. This operation swaps and negates /// the lower and upper bounds of the interval. - pub fn arithmetic_negate(self) -> Result<Self> { + pub fn arithmetic_negate(&self) -> Result<Self> { Ok(Self { - lower: self.upper().clone().arithmetic_negate()?, - upper: self.lower().clone().arithmetic_negate()?, + lower: self.upper.arithmetic_negate()?, + upper: self.lower.arithmetic_negate()?, }) } } @@ -1119,11 +1180,11 @@ fn next_value_helper<const INC: bool>(value: ScalarValue) -> ScalarValue { match value { // f32/f64::NEG_INF/INF and f32/f64::NaN values should not emerge at this point. 
Float32(Some(val)) => { - assert!(val.is_finite(), "Non-standardized floating point usage"); + debug_assert!(val.is_finite(), "Non-standardized floating point usage"); Float32(Some(if INC { next_up(val) } else { next_down(val) })) } Float64(Some(val)) => { - assert!(val.is_finite(), "Non-standardized floating point usage"); + debug_assert!(val.is_finite(), "Non-standardized floating point usage"); Float64(Some(if INC { next_up(val) } else { next_down(val) })) } Int8(Some(val)) => Int8(Some(increment_decrement::<INC, i8>(val))), @@ -1275,7 +1336,7 @@ pub fn satisfy_greater( } else { right.upper.clone() }; - + // No possibility to create an invalid interval: Ok(Some(( Interval::new(new_left_lower, left.upper.clone()), Interval::new(right.lower.clone(), new_right_upper), @@ -1868,6 +1929,7 @@ mod tests { }; use arrow::datatypes::DataType; + use datafusion_common::rounding::{next_down, next_up}; use datafusion_common::{Result, ScalarValue}; #[test] @@ -2532,6 +2594,126 @@ mod tests { Ok(()) } + #[test] + fn union_test() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::<i64>(None, None)?, + Interval::make_unbounded(&DataType::Int64)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + Interval::make_unbounded(&DataType::Int64)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(2000_i64))?, + Interval::make_unbounded(&DataType::Int64)?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(500_i64), Some(2000_i64))?, + ), + ( + Interval::make::<i64>(None, None)?, + Interval::make::<i64>(None, None)?, + Interval::make::<i64>(None, None)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(0_i64))?, + Interval::make_unbounded(&DataType::Int64)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(999_i64))?, + Interval::make_unbounded(&DataType::Int64)?, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1499_i64))?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(3000_i64))?, + Interval::make(Some(0_i64), Some(3000_i64))?, + ), + ( + Interval::make(None, Some(2000_u64))?, + Interval::make(Some(500_u64), None)?, + Interval::make(Some(0_u64), None)?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), None)?, + Interval::make(Some(0_u64), None)?, + ), + ( + Interval::make(Some(1000.0_f32), None)?, + Interval::make(None, Some(1000.0_f32))?, + Interval::make_unbounded(&DataType::Float32)?, + ), + ( + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + prev_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + ScalarValue::Float32(Some(1.0)), + )?, + ), + ( + Interval::try_new( + 
next_value(ScalarValue::Float32(Some(1.0))), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::try_new( + ScalarValue::Float32(Some(1.0)), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + ), + ( + Interval::make(Some(16.0_f64), Some(32.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(16.0_f64), Some(64.0_f64))?, + ), + ]; + for (first, second, expected) in possible_cases { + println!("{}", first); + println!("{}", second); + assert_eq!(first.union(second)?, expected) + } + + Ok(()) + } + #[test] fn test_contains() -> Result<()> { let possible_cases = vec![ @@ -2594,6 +2776,43 @@ mod tests { Ok(()) } + #[test] + fn test_contains_value() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make(Some(0), Some(100))?, + ScalarValue::Int32(Some(50)), + true, + ), + ( + Interval::make(Some(0), Some(100))?, + ScalarValue::Int32(Some(150)), + false, + ), + ( + Interval::make(Some(0), Some(100))?, + ScalarValue::Float64(Some(50.)), + true, + ), + ( + Interval::make(Some(0), Some(100))?, + ScalarValue::Float64(Some(next_down(100.))), + true, + ), + ( + Interval::make(Some(0), Some(100))?, + ScalarValue::Float64(Some(next_up(100.))), + false, + ), + ]; + + for (interval, value, expected) in possible_cases { + assert_eq!(interval.contains_value(value)?, expected) + } + + Ok(()) + } + #[test] fn test_add() -> Result<()> { let cases = vec![ @@ -3208,6 +3427,53 @@ mod tests { Ok(()) } + #[test] + fn test_width_of_intervals() -> Result<()> { + let intervals = [ + ( + Interval::make(Some(0.25_f64), Some(0.50_f64))?, + ScalarValue::from(0.25_f64), + ), + ( + Interval::make(Some(0.5_f64), Some(1.0_f64))?, + ScalarValue::from(0.5_f64), + ), + ( + Interval::make(Some(1.0_f64), Some(2.0_f64))?, + ScalarValue::from(1.0_f64), + ), + ( + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + ScalarValue::from(32.0_f64), + ), + ( + Interval::make(Some(-0.50_f64), Some(-0.25_f64))?, + ScalarValue::from(0.25_f64), + ), + ( + Interval::make(Some(-32.0_f64), Some(-16.0_f64))?, + ScalarValue::from(16.0_f64), + ), + ( + Interval::make(Some(-0.50_f64), Some(0.25_f64))?, + ScalarValue::from(0.75_f64), + ), + ( + Interval::make(Some(-32.0_f64), Some(16.0_f64))?, + ScalarValue::from(48.0_f64), + ), + ( + Interval::make(Some(-32_i64), Some(16_i64))?, + ScalarValue::from(48_i64), + ), + ]; + for (interval, expected) in intervals { + assert_eq!(interval.width()?, expected); + } + + Ok(()) + } + #[test] fn test_cardinality_of_intervals() -> Result<()> { // In IEEE 754 standard for floating-point arithmetic, if we keep the sign and exponent fields same, diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index fede0bb8e57e..ee40038beb21 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -38,4 +38,5 @@ pub mod interval_arithmetic; pub mod operator; pub mod signature; pub mod sort_properties; +pub mod statistics; pub mod type_coercion; diff --git a/datafusion/expr-common/src/statistics.rs b/datafusion/expr-common/src/statistics.rs new file mode 100644 index 000000000000..7e0bc88087ef --- /dev/null +++ b/datafusion/expr-common/src/statistics.rs @@ -0,0 +1,1620 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::f64::consts::LN_2; + +use crate::interval_arithmetic::{apply_operator, Interval}; +use crate::operator::Operator; +use crate::type_coercion::binary::binary_numeric_coercion; + +use arrow::array::ArrowNativeTypeOp; +use arrow::datatypes::DataType; +use datafusion_common::rounding::alter_fp_rounding_mode; +use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; + +/// This object defines probabilistic distributions that encode uncertain +/// information about a single, scalar value. Currently, we support five core +/// statistical distributions. New variants will be added over time. +/// +/// This object is the lowest-level object in the statistics hierarchy, and it +/// is the main unit of calculus when evaluating expressions in a statistical +/// context. Notions like column and table statistics are built on top of this +/// object and the operations it supports. +#[derive(Clone, Debug, PartialEq)] +pub enum Distribution { + Uniform(UniformDistribution), + Exponential(ExponentialDistribution), + Gaussian(GaussianDistribution), + Bernoulli(BernoulliDistribution), + Generic(GenericDistribution), +} + +use Distribution::{Bernoulli, Exponential, Gaussian, Generic, Uniform}; + +impl Distribution { + /// Constructs a new [`Uniform`] distribution from the given [`Interval`]. + pub fn new_uniform(interval: Interval) -> Result<Self> { + UniformDistribution::try_new(interval).map(Uniform) + } + + /// Constructs a new [`Exponential`] distribution from the given rate/offset + /// pair, and validates the given parameters. + pub fn new_exponential( + rate: ScalarValue, + offset: ScalarValue, + positive_tail: bool, + ) -> Result<Self> { + ExponentialDistribution::try_new(rate, offset, positive_tail).map(Exponential) + } + + /// Constructs a new [`Gaussian`] distribution from the given mean/variance + /// pair, and validates the given parameters. + pub fn new_gaussian(mean: ScalarValue, variance: ScalarValue) -> Result<Self> { + GaussianDistribution::try_new(mean, variance).map(Gaussian) + } + + /// Constructs a new [`Bernoulli`] distribution from the given success + /// probability, and validates the given parameters. + pub fn new_bernoulli(p: ScalarValue) -> Result<Self> { + BernoulliDistribution::try_new(p).map(Bernoulli) + } + + /// Constructs a new [`Generic`] distribution from the given mean, median, + /// variance, and range values after validating the given parameters. + pub fn new_generic( + mean: ScalarValue, + median: ScalarValue, + variance: ScalarValue, + range: Interval, + ) -> Result<Self> { + GenericDistribution::try_new(mean, median, variance, range).map(Generic) + } + + /// Constructs a new [`Generic`] distribution from the given range. Other + /// parameters (mean, median and variance) are initialized with null values. 
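+    ///
+    /// For instance (an illustrative sketch), this can wrap a plain range when
+    /// no summary statistics are known:
+    ///
+    /// ```ignore
+    /// let range = Interval::make(Some(0.0_f64), Some(100.0_f64))?;
+    /// let dist = Distribution::new_from_interval(range)?;
+    /// // Only the range is known; mean, median and variance are all `NULL`.
+    /// assert!(dist.mean()?.is_null() && dist.variance()?.is_null());
+    /// ```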
+    pub fn new_from_interval(range: Interval) -> Result<Self> {
+        let null = ScalarValue::try_from(range.data_type())?;
+        Distribution::new_generic(null.clone(), null.clone(), null, range)
+    }
+
+    /// Extracts the mean value of this uncertain quantity, depending on its
+    /// distribution:
+    /// - A [`Uniform`] distribution's interval determines its mean value, which
+    ///   is the arithmetic average of the interval endpoints.
+    /// - An [`Exponential`] distribution's mean is calculable by the formula
+    ///   `offset + 1 / λ`, where `λ` is the (non-negative) rate.
+    /// - A [`Gaussian`] distribution contains the mean explicitly.
+    /// - A [`Bernoulli`] distribution's mean is equal to its success probability `p`.
+    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
+    ///   may be absent.
+    pub fn mean(&self) -> Result<ScalarValue> {
+        match &self {
+            Uniform(u) => u.mean(),
+            Exponential(e) => e.mean(),
+            Gaussian(g) => Ok(g.mean().clone()),
+            Bernoulli(b) => Ok(b.mean().clone()),
+            Generic(u) => Ok(u.mean().clone()),
+        }
+    }
+
+    /// Extracts the median value of this uncertain quantity, depending on its
+    /// distribution:
+    /// - A [`Uniform`] distribution's interval determines its median value, which
+    ///   is the arithmetic average of the interval endpoints.
+    /// - An [`Exponential`] distribution's median is calculable by the formula
+    ///   `offset + ln(2) / λ`, where `λ` is the (non-negative) rate.
+    /// - A [`Gaussian`] distribution's median is equal to its mean, which is
+    ///   specified explicitly.
+    /// - A [`Bernoulli`] distribution's median is `1` if `p > 0.5` and `0`
+    ///   otherwise, where `p` is the success probability.
+    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
+    ///   may be absent.
+    pub fn median(&self) -> Result<ScalarValue> {
+        match &self {
+            Uniform(u) => u.median(),
+            Exponential(e) => e.median(),
+            Gaussian(g) => Ok(g.median().clone()),
+            Bernoulli(b) => b.median(),
+            Generic(u) => Ok(u.median().clone()),
+        }
+    }
+
+    /// Extracts the variance value of this uncertain quantity, depending on
+    /// its distribution:
+    /// - A [`Uniform`] distribution's interval determines its variance value, which
+    ///   is calculable by the formula `(upper - lower) ^ 2 / 12`.
+    /// - An [`Exponential`] distribution's variance is calculable by the formula
+    ///   `1 / (λ ^ 2)`, where `λ` is the (non-negative) rate.
+    /// - A [`Gaussian`] distribution's variance is specified explicitly.
+    /// - A [`Bernoulli`] distribution's variance is given by the formula `p * (1 - p)`
+    ///   where `p` is the success probability.
+    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
+    ///   may be absent.
+    pub fn variance(&self) -> Result<ScalarValue> {
+        match &self {
+            Uniform(u) => u.variance(),
+            Exponential(e) => e.variance(),
+            Gaussian(g) => Ok(g.variance.clone()),
+            Bernoulli(b) => b.variance(),
+            Generic(u) => Ok(u.variance.clone()),
+        }
+    }
+
+    /// Extracts the range of this uncertain quantity, depending on its
+    /// distribution:
+    /// - A [`Uniform`] distribution's range is simply its interval.
+    /// - An [`Exponential`] distribution's range is `[offset, +∞)`.
+    /// - A [`Gaussian`] distribution's range is unbounded.
+    /// - A [`Bernoulli`] distribution's range is [`Interval::UNCERTAIN`], if
+    ///   `p` is neither `0` nor `1`. Otherwise, it is [`Interval::CERTAINLY_FALSE`]
+    ///   and [`Interval::CERTAINLY_TRUE`], respectively.
+ /// - A [`Generic`] distribution is unbounded by default, but more information + /// may be present. + pub fn range(&self) -> Result<Interval> { + match &self { + Uniform(u) => Ok(u.range().clone()), + Exponential(e) => e.range(), + Gaussian(g) => g.range(), + Bernoulli(b) => Ok(b.range()), + Generic(u) => Ok(u.range().clone()), + } + } + + /// Returns the data type of the statistical parameters comprising this + /// distribution. + pub fn data_type(&self) -> DataType { + match &self { + Uniform(u) => u.data_type(), + Exponential(e) => e.data_type(), + Gaussian(g) => g.data_type(), + Bernoulli(b) => b.data_type(), + Generic(u) => u.data_type(), + } + } + + pub fn target_type(args: &[&ScalarValue]) -> Result<DataType> { + let mut arg_types = args + .iter() + .filter(|&&arg| (arg != &ScalarValue::Null)) + .map(|&arg| arg.data_type()); + + let Some(dt) = arg_types.next().map_or_else( + || Some(DataType::Null), + |first| { + arg_types + .try_fold(first, |target, arg| binary_numeric_coercion(&target, &arg)) + }, + ) else { + return internal_err!("Can only evaluate statistics for numeric types"); + }; + Ok(dt) + } +} + +/// Uniform distribution, represented by its range. If the given range extends +/// towards infinity, the distribution will be improper -- which is OK. For a +/// more in-depth discussion, see: +/// +/// <https://en.wikipedia.org/wiki/Continuous_uniform_distribution> +/// <https://en.wikipedia.org/wiki/Prior_probability#Improper_priors> +#[derive(Clone, Debug, PartialEq)] +pub struct UniformDistribution { + interval: Interval, +} + +/// Exponential distribution with an optional shift. The probability density +/// function (PDF) is defined as follows: +/// +/// For a positive tail (when `positive_tail` is `true`): +/// +/// `f(x; λ, offset) = λ exp(-λ (x - offset)) for x ≥ offset` +/// +/// For a negative tail (when `positive_tail` is `false`): +/// +/// `f(x; λ, offset) = λ exp(-λ (offset - x)) for x ≤ offset` +/// +/// +/// In both cases, the PDF is `0` outside the specified domain. +/// +/// For more information, see: +/// +/// <https://en.wikipedia.org/wiki/Exponential_distribution> +#[derive(Clone, Debug, PartialEq)] +pub struct ExponentialDistribution { + rate: ScalarValue, + offset: ScalarValue, + /// Indicates whether the exponential distribution has a positive tail; + /// i.e. it extends towards positive infinity. + positive_tail: bool, +} + +/// Gaussian (normal) distribution, represented by its mean and variance. +/// For a more in-depth discussion, see: +/// +/// <https://en.wikipedia.org/wiki/Normal_distribution> +#[derive(Clone, Debug, PartialEq)] +pub struct GaussianDistribution { + mean: ScalarValue, + variance: ScalarValue, +} + +/// Bernoulli distribution with success probability `p`. If `p` has a null value, +/// the success probability is unknown. For a more in-depth discussion, see: +/// +/// <https://en.wikipedia.org/wiki/Bernoulli_distribution> +#[derive(Clone, Debug, PartialEq)] +pub struct BernoulliDistribution { + p: ScalarValue, +} + +/// A generic distribution whose functional form is not available, which is +/// approximated via some summary statistics. 
For a more in-depth discussion, see: +/// +/// <https://en.wikipedia.org/wiki/Summary_statistics> +#[derive(Clone, Debug, PartialEq)] +pub struct GenericDistribution { + mean: ScalarValue, + median: ScalarValue, + variance: ScalarValue, + range: Interval, +} + +impl UniformDistribution { + fn try_new(interval: Interval) -> Result<Self> { + if interval.data_type().eq(&DataType::Boolean) { + return internal_err!( + "Construction of a boolean `Uniform` distribution is prohibited, create a `Bernoulli` distribution instead." + ); + } + + Ok(Self { interval }) + } + + pub fn data_type(&self) -> DataType { + self.interval.data_type() + } + + /// Computes the mean value of this distribution. In case of improper + /// distributions (i.e. when the range is unbounded), the function returns + /// a `NULL` `ScalarValue`. + pub fn mean(&self) -> Result<ScalarValue> { + // TODO: Should we ensure that this always returns a real number data type? + let dt = self.data_type(); + let two = ScalarValue::from(2).cast_to(&dt)?; + let result = self + .interval + .lower() + .add_checked(self.interval.upper())? + .div(two); + debug_assert!( + !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null()) + ); + result + } + + pub fn median(&self) -> Result<ScalarValue> { + self.mean() + } + + /// Computes the variance value of this distribution. In case of improper + /// distributions (i.e. when the range is unbounded), the function returns + /// a `NULL` `ScalarValue`. + pub fn variance(&self) -> Result<ScalarValue> { + // TODO: Should we ensure that this always returns a real number data type? + let width = self.interval.width()?; + let dt = width.data_type(); + let twelve = ScalarValue::from(12).cast_to(&dt)?; + let result = width.mul_checked(&width)?.div(twelve); + debug_assert!( + !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null()) + ); + result + } + + pub fn range(&self) -> &Interval { + &self.interval + } +} + +impl ExponentialDistribution { + fn try_new( + rate: ScalarValue, + offset: ScalarValue, + positive_tail: bool, + ) -> Result<Self> { + let dt = rate.data_type(); + if offset.data_type() != dt { + internal_err!("Rate and offset must have the same data type") + } else if offset.is_null() { + internal_err!("Offset of an `ExponentialDistribution` cannot be null") + } else if rate.is_null() { + internal_err!("Rate of an `ExponentialDistribution` cannot be null") + } else if rate.le(&ScalarValue::new_zero(&dt)?) { + internal_err!("Rate of an `ExponentialDistribution` must be positive") + } else { + Ok(Self { + rate, + offset, + positive_tail, + }) + } + } + + pub fn data_type(&self) -> DataType { + self.rate.data_type() + } + + pub fn rate(&self) -> &ScalarValue { + &self.rate + } + + pub fn offset(&self) -> &ScalarValue { + &self.offset + } + + pub fn positive_tail(&self) -> bool { + self.positive_tail + } + + pub fn mean(&self) -> Result<ScalarValue> { + // TODO: Should we ensure that this always returns a real number data type? + let one = ScalarValue::new_one(&self.data_type())?; + let tail_mean = one.div(&self.rate)?; + if self.positive_tail { + self.offset.add_checked(tail_mean) + } else { + self.offset.sub_checked(tail_mean) + } + } + + pub fn median(&self) -> Result<ScalarValue> { + // TODO: Should we ensure that this always returns a real number data type? 
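+        // The median `m` solves `CDF(m) = 1/2`. For a positive tail, the CDF
+        // is `1 - exp(-λ (x - offset))`, which yields `m = offset + ln(2) / λ`;
+        // the negative-tail case mirrors this around the offset.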
+        let ln_two = ScalarValue::from(LN_2).cast_to(&self.data_type())?;
+        let tail_median = ln_two.div(&self.rate)?;
+        if self.positive_tail {
+            self.offset.add_checked(tail_median)
+        } else {
+            self.offset.sub_checked(tail_median)
+        }
+    }
+
+    pub fn variance(&self) -> Result<ScalarValue> {
+        // TODO: Should we ensure that this always returns a real number data type?
+        let one = ScalarValue::new_one(&self.data_type())?;
+        let rate_squared = self.rate.mul_checked(&self.rate)?;
+        one.div(rate_squared)
+    }
+
+    pub fn range(&self) -> Result<Interval> {
+        let end = ScalarValue::try_from(&self.data_type())?;
+        if self.positive_tail {
+            Interval::try_new(self.offset.clone(), end)
+        } else {
+            Interval::try_new(end, self.offset.clone())
+        }
+    }
+}
+
+impl GaussianDistribution {
+    fn try_new(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
+        let dt = mean.data_type();
+        if variance.data_type() != dt {
+            internal_err!("Mean and variance must have the same data type")
+        } else if variance.is_null() {
+            internal_err!("Variance of a `GaussianDistribution` cannot be null")
+        } else if variance.lt(&ScalarValue::new_zero(&dt)?) {
+            internal_err!("Variance of a `GaussianDistribution` cannot be negative")
+        } else {
+            Ok(Self { mean, variance })
+        }
+    }
+
+    pub fn data_type(&self) -> DataType {
+        self.mean.data_type()
+    }
+
+    pub fn mean(&self) -> &ScalarValue {
+        &self.mean
+    }
+
+    pub fn variance(&self) -> &ScalarValue {
+        &self.variance
+    }
+
+    pub fn median(&self) -> &ScalarValue {
+        self.mean()
+    }
+
+    pub fn range(&self) -> Result<Interval> {
+        Interval::make_unbounded(&self.data_type())
+    }
+}
+
+impl BernoulliDistribution {
+    fn try_new(p: ScalarValue) -> Result<Self> {
+        if p.is_null() {
+            Ok(Self { p })
+        } else {
+            let dt = p.data_type();
+            let zero = ScalarValue::new_zero(&dt)?;
+            let one = ScalarValue::new_one(&dt)?;
+            if p.ge(&zero) && p.le(&one) {
+                Ok(Self { p })
+            } else {
+                internal_err!(
+                    "Success probability of a `BernoulliDistribution` must be in [0, 1]"
+                )
+            }
+        }
+    }
+
+    pub fn data_type(&self) -> DataType {
+        self.p.data_type()
+    }
+
+    pub fn p_value(&self) -> &ScalarValue {
+        &self.p
+    }
+
+    pub fn mean(&self) -> &ScalarValue {
+        &self.p
+    }
+
+    /// Computes the median value of this distribution. In case of an unknown
+    /// success probability, the function returns a `NULL` `ScalarValue`.
+    pub fn median(&self) -> Result<ScalarValue> {
+        let dt = self.data_type();
+        if self.p.is_null() {
+            ScalarValue::try_from(&dt)
+        } else {
+            let one = ScalarValue::new_one(&dt)?;
+            if one.sub_checked(&self.p)?.lt(&self.p) {
+                ScalarValue::new_one(&dt)
+            } else {
+                ScalarValue::new_zero(&dt)
+            }
+        }
+    }
+
+    /// Computes the variance value of this distribution. In case of an unknown
+    /// success probability, the function returns a `NULL` `ScalarValue`.
+    pub fn variance(&self) -> Result<ScalarValue> {
+        let dt = self.data_type();
+        let one = ScalarValue::new_one(&dt)?;
+        let result = one.sub_checked(&self.p)?.mul_checked(&self.p);
+        debug_assert!(!self.p.is_null() || result.as_ref().is_ok_and(|r| r.is_null()));
+        result
+    }
+
+    pub fn range(&self) -> Interval {
+        let dt = self.data_type();
+        // Unwraps are safe as the constructor guarantees that the data type
+        // supports zero and one values.
+ if ScalarValue::new_zero(&dt).unwrap().eq(&self.p) { + Interval::CERTAINLY_FALSE + } else if ScalarValue::new_one(&dt).unwrap().eq(&self.p) { + Interval::CERTAINLY_TRUE + } else { + Interval::UNCERTAIN + } + } +} + +impl GenericDistribution { + fn try_new( + mean: ScalarValue, + median: ScalarValue, + variance: ScalarValue, + range: Interval, + ) -> Result<Self> { + if range.data_type().eq(&DataType::Boolean) { + return internal_err!( + "Construction of a boolean `Generic` distribution is prohibited, create a `Bernoulli` distribution instead." + ); + } + + let validate_location = |m: &ScalarValue| -> Result<bool> { + // Checks whether the given location estimate is within the range. + if m.is_null() { + Ok(true) + } else { + range.contains_value(m) + } + }; + + if !validate_location(&mean)? + || !validate_location(&median)? + || (!variance.is_null() + && variance.lt(&ScalarValue::new_zero(&variance.data_type())?)) + { + internal_err!("Tried to construct an invalid `GenericDistribution` instance") + } else { + Ok(Self { + mean, + median, + variance, + range, + }) + } + } + + pub fn data_type(&self) -> DataType { + self.mean.data_type() + } + + pub fn mean(&self) -> &ScalarValue { + &self.mean + } + + pub fn median(&self) -> &ScalarValue { + &self.median + } + + pub fn variance(&self) -> &ScalarValue { + &self.variance + } + + pub fn range(&self) -> &Interval { + &self.range + } +} + +/// This function takes a logical operator and two Bernoulli distributions, +/// and it returns a new Bernoulli distribution that represents the result of +/// the operation. Currently, only `AND` and `OR` operations are supported. +pub fn combine_bernoullis( + op: &Operator, + left: &BernoulliDistribution, + right: &BernoulliDistribution, +) -> Result<BernoulliDistribution> { + let left_p = left.p_value(); + let right_p = right.p_value(); + match op { + Operator::And => match (left_p.is_null(), right_p.is_null()) { + (false, false) => { + BernoulliDistribution::try_new(left_p.mul_checked(right_p)?) + } + (false, true) if left_p.eq(&ScalarValue::new_zero(&left_p.data_type())?) => { + Ok(left.clone()) + } + (true, false) + if right_p.eq(&ScalarValue::new_zero(&right_p.data_type())?) => + { + Ok(right.clone()) + } + _ => { + let dt = Distribution::target_type(&[left_p, right_p])?; + BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?) + } + }, + Operator::Or => match (left_p.is_null(), right_p.is_null()) { + (false, false) => { + let sum = left_p.add_checked(right_p)?; + let product = left_p.mul_checked(right_p)?; + let or_success = sum.sub_checked(product)?; + BernoulliDistribution::try_new(or_success) + } + (false, true) if left_p.eq(&ScalarValue::new_one(&left_p.data_type())?) => { + Ok(left.clone()) + } + (true, false) if right_p.eq(&ScalarValue::new_one(&right_p.data_type())?) => { + Ok(right.clone()) + } + _ => { + let dt = Distribution::target_type(&[left_p, right_p])?; + BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?) + } + }, + _ => { + not_impl_err!("Statistical evaluation only supports AND and OR operators") + } + } +} + +/// Applies the given operation to the given Gaussian distributions. Currently, +/// this function handles only addition and subtraction operations. If the +/// result is not a Gaussian random variable, it returns `None`. 
For details, +/// see: +/// +/// <https://en.wikipedia.org/wiki/Sum_of_normally_distributed_random_variables> +pub fn combine_gaussians( + op: &Operator, + left: &GaussianDistribution, + right: &GaussianDistribution, +) -> Result<Option<GaussianDistribution>> { + match op { + Operator::Plus => GaussianDistribution::try_new( + left.mean().add_checked(right.mean())?, + left.variance().add_checked(right.variance())?, + ) + .map(Some), + Operator::Minus => GaussianDistribution::try_new( + left.mean().sub_checked(right.mean())?, + left.variance().add_checked(right.variance())?, + ) + .map(Some), + _ => Ok(None), + } +} + +/// Creates a new `Bernoulli` distribution by computing the resulting probability. +/// Expects `op` to be a comparison operator, with `left` and `right` having +/// numeric distributions. The resulting distribution has the `Float64` data +/// type. +pub fn create_bernoulli_from_comparison( + op: &Operator, + left: &Distribution, + right: &Distribution, +) -> Result<Distribution> { + match (left, right) { + (Uniform(left), Uniform(right)) => { + match op { + Operator::Eq | Operator::NotEq => { + let (li, ri) = (left.range(), right.range()); + if let Some(intersection) = li.intersect(ri)? { + // If the ranges are not disjoint, calculate the probability + // of equality using cardinalities: + if let (Some(lc), Some(rc), Some(ic)) = ( + li.cardinality(), + ri.cardinality(), + intersection.cardinality(), + ) { + // Avoid overflow by widening the type temporarily: + let pairs = ((lc as u128) * (rc as u128)) as f64; + let p = (ic as f64).div_checked(pairs)?; + // Alternative approach that may be more stable: + // let p = (ic as f64) + // .div_checked(lc as f64)? + // .div_checked(rc as f64)?; + + let mut p_value = ScalarValue::from(p); + if op == &Operator::NotEq { + let one = ScalarValue::from(1.0); + p_value = alter_fp_rounding_mode::<false, _>( + &one, + &p_value, + |lhs, rhs| lhs.sub_checked(rhs), + )?; + }; + return Distribution::new_bernoulli(p_value); + } + } else if op == &Operator::Eq { + // If the ranges are disjoint, probability of equality is 0. + return Distribution::new_bernoulli(ScalarValue::from(0.0)); + } else { + // If the ranges are disjoint, probability of not-equality is 1. + return Distribution::new_bernoulli(ScalarValue::from(1.0)); + } + } + Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => { + // TODO: We can handle inequality operators and calculate a + // `p` value instead of falling back to an unknown Bernoulli + // distribution. Note that the strict and non-strict inequalities + // may require slightly different logic in case of real vs. + // integral data types. + } + _ => {} + } + } + (Gaussian(_), Gaussian(_)) => { + // TODO: We can handle Gaussian comparisons and calculate a `p` value + // instead of falling back to an unknown Bernoulli distribution. + } + _ => {} + } + let (li, ri) = (left.range()?, right.range()?); + let range_evaluation = apply_operator(op, &li, &ri)?; + if range_evaluation.eq(&Interval::CERTAINLY_FALSE) { + Distribution::new_bernoulli(ScalarValue::from(0.0)) + } else if range_evaluation.eq(&Interval::CERTAINLY_TRUE) { + Distribution::new_bernoulli(ScalarValue::from(1.0)) + } else if range_evaluation.eq(&Interval::UNCERTAIN) { + Distribution::new_bernoulli(ScalarValue::try_from(&DataType::Float64)?) 
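+        // (A `NULL` `Float64` scalar denotes an unknown probability.)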
+    } else {
+        internal_err!("This function must be called with a comparison operator")
+    }
+}
+
+/// Creates a new [`Generic`] distribution that represents the result of the
+/// given binary operation on two unknown quantities represented by their
+/// [`Distribution`] objects. The function computes the mean, median and
+/// variance if possible.
+pub fn new_generic_from_binary_op(
+    op: &Operator,
+    left: &Distribution,
+    right: &Distribution,
+) -> Result<Distribution> {
+    Distribution::new_generic(
+        compute_mean(op, left, right)?,
+        compute_median(op, left, right)?,
+        compute_variance(op, left, right)?,
+        apply_operator(op, &left.range()?, &right.range()?)?,
+    )
+}
+
+/// Computes the mean value for the result of the given binary operation on
+/// two unknown quantities represented by their [`Distribution`] objects.
+pub fn compute_mean(
+    op: &Operator,
+    left: &Distribution,
+    right: &Distribution,
+) -> Result<ScalarValue> {
+    let (left_mean, right_mean) = (left.mean()?, right.mean()?);
+
+    match op {
+        Operator::Plus => return left_mean.add_checked(right_mean),
+        Operator::Minus => return left_mean.sub_checked(right_mean),
+        // Note the independence assumption below:
+        Operator::Multiply => return left_mean.mul_checked(right_mean),
+        // TODO: We can calculate the mean for division when we support reciprocals,
+        // or know the distributions of the operands. For details, see:
+        //
+        // <https://en.wikipedia.org/wiki/Algebra_of_random_variables>
+        // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables>
+        //
+        // Fall back to an unknown mean value for division:
+        Operator::Divide => {}
+        // Fall back to an unknown mean value for other cases:
+        _ => {}
+    }
+    let target_type = Distribution::target_type(&[&left_mean, &right_mean])?;
+    ScalarValue::try_from(target_type)
+}
+
+/// Computes the median value for the result of the given binary operation on
+/// two unknown quantities represented by their [`Distribution`] objects.
+/// Currently, the median is calculable only for addition and subtraction
+/// operations on:
+/// - [`Uniform`] and [`Uniform`] distributions, and
+/// - [`Gaussian`] and [`Gaussian`] distributions.
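+///
+/// For example (an illustrative sketch, using the distributions from the tests
+/// below), the median of the sum of two independent uniform variables is the
+/// sum of their medians:
+///
+/// ```ignore
+/// let a = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
+/// let b = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;
+/// assert_eq!(compute_median(&Operator::Plus, &a, &b)?, ScalarValue::from(30.));
+/// ```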
+pub fn compute_median( + op: &Operator, + left: &Distribution, + right: &Distribution, +) -> Result<ScalarValue> { + match (left, right) { + (Uniform(lu), Uniform(ru)) => { + let (left_median, right_median) = (lu.median()?, ru.median()?); + // Under the independence assumption, the result is a symmetric + // triangular distribution, so we can simply add/subtract the + // median values: + match op { + Operator::Plus => return left_median.add_checked(right_median), + Operator::Minus => return left_median.sub_checked(right_median), + // Fall back to an unknown median value for other cases: + _ => {} + } + } + // Under the independence assumption, the result is another Gaussian + // distribution, so we can simply add/subtract the median values: + (Gaussian(lg), Gaussian(rg)) => match op { + Operator::Plus => return lg.mean().add_checked(rg.mean()), + Operator::Minus => return lg.mean().sub_checked(rg.mean()), + // Fall back to an unknown median value for other cases: + _ => {} + }, + // Fall back to an unknown median value for other cases: + _ => {} + } + + let (left_median, right_median) = (left.median()?, right.median()?); + let target_type = Distribution::target_type(&[&left_median, &right_median])?; + ScalarValue::try_from(target_type) +} + +/// Computes the variance value for the result of the given binary operation on +/// two unknown quantities represented by their [`Distribution`] objects. +pub fn compute_variance( + op: &Operator, + left: &Distribution, + right: &Distribution, +) -> Result<ScalarValue> { + let (left_variance, right_variance) = (left.variance()?, right.variance()?); + + match op { + // Note the independence assumption below: + Operator::Plus => return left_variance.add_checked(right_variance), + // Note the independence assumption below: + Operator::Minus => return left_variance.add_checked(right_variance), + // Note the independence assumption below: + Operator::Multiply => { + // For more details, along with an explanation of the formula below, see: + // + // <https://en.wikipedia.org/wiki/Distribution_of_the_product_of_two_random_variables> + let (left_mean, right_mean) = (left.mean()?, right.mean()?); + let left_mean_sq = left_mean.mul_checked(&left_mean)?; + let right_mean_sq = right_mean.mul_checked(&right_mean)?; + let left_sos = left_variance.add_checked(&left_mean_sq)?; + let right_sos = right_variance.add_checked(&right_mean_sq)?; + let pos = left_mean_sq.mul_checked(right_mean_sq)?; + return left_sos.mul_checked(right_sos)?.sub_checked(pos); + } + // TODO: We can calculate the variance for division when we support reciprocals, + // or know the distributions of the operands. 
For details, see: + // + // <https://en.wikipedia.org/wiki/Algebra_of_random_variables> + // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables> + // + // Fall back to an unknown variance value for division: + Operator::Divide => {} + // Fall back to an unknown variance value for other cases: + _ => {} + } + let target_type = Distribution::target_type(&[&left_variance, &right_variance])?; + ScalarValue::try_from(target_type) +} + +#[cfg(test)] +mod tests { + use super::{ + combine_bernoullis, combine_gaussians, compute_mean, compute_median, + compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op, + BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution, + }; + use crate::interval_arithmetic::{apply_operator, Interval}; + use crate::operator::Operator; + + use arrow::datatypes::DataType; + use datafusion_common::{HashSet, Result, ScalarValue}; + + #[test] + fn uniform_dist_is_valid_test() -> Result<()> { + assert_eq!( + Distribution::new_uniform(Interval::make_zero(&DataType::Int8)?)?, + Distribution::Uniform(UniformDistribution { + interval: Interval::make_zero(&DataType::Int8)?, + }) + ); + + assert!(Distribution::new_uniform(Interval::UNCERTAIN).is_err()); + Ok(()) + } + + #[test] + fn exponential_dist_is_valid_test() { + // This array collects test cases of the form (distribution, validity). + let exponentials = vec![ + ( + Distribution::new_exponential(ScalarValue::Null, ScalarValue::Null, true), + false, + ), + ( + Distribution::new_exponential( + ScalarValue::from(0_f32), + ScalarValue::from(1_f32), + true, + ), + false, + ), + ( + Distribution::new_exponential( + ScalarValue::from(100_f32), + ScalarValue::from(1_f32), + true, + ), + true, + ), + ( + Distribution::new_exponential( + ScalarValue::from(-100_f32), + ScalarValue::from(1_f32), + true, + ), + false, + ), + ]; + for case in exponentials { + assert_eq!(case.0.is_ok(), case.1); + } + } + + #[test] + fn gaussian_dist_is_valid_test() { + // This array collects test cases of the form (distribution, validity). + let gaussians = vec![ + ( + Distribution::new_gaussian(ScalarValue::Null, ScalarValue::Null), + false, + ), + ( + Distribution::new_gaussian( + ScalarValue::from(0_f32), + ScalarValue::from(0_f32), + ), + true, + ), + ( + Distribution::new_gaussian( + ScalarValue::from(0_f32), + ScalarValue::from(0.5_f32), + ), + true, + ), + ( + Distribution::new_gaussian( + ScalarValue::from(0_f32), + ScalarValue::from(-0.5_f32), + ), + false, + ), + ]; + for case in gaussians { + assert_eq!(case.0.is_ok(), case.1); + } + } + + #[test] + fn bernoulli_dist_is_valid_test() { + // This array collects test cases of the form (distribution, validity). 
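+        // A null `p` is accepted as an unknown probability; any non-null `p`
+        // (including integer-typed values) must lie in `[0, 1]`.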
+ let bernoullis = vec![ + (Distribution::new_bernoulli(ScalarValue::Null), true), + (Distribution::new_bernoulli(ScalarValue::from(0.)), true), + (Distribution::new_bernoulli(ScalarValue::from(0.25)), true), + (Distribution::new_bernoulli(ScalarValue::from(1.)), true), + (Distribution::new_bernoulli(ScalarValue::from(11.)), false), + (Distribution::new_bernoulli(ScalarValue::from(-11.)), false), + (Distribution::new_bernoulli(ScalarValue::from(0_i64)), true), + (Distribution::new_bernoulli(ScalarValue::from(1_i64)), true), + ( + Distribution::new_bernoulli(ScalarValue::from(11_i64)), + false, + ), + ( + Distribution::new_bernoulli(ScalarValue::from(-11_i64)), + false, + ), + ]; + for case in bernoullis { + assert_eq!(case.0.is_ok(), case.1); + } + } + + #[test] + fn generic_dist_is_valid_test() -> Result<()> { + // This array collects test cases of the form (distribution, validity). + let generic_dists = vec![ + // Using a boolean range to construct a Generic distribution is prohibited. + ( + Distribution::new_generic( + ScalarValue::Null, + ScalarValue::Null, + ScalarValue::Null, + Interval::UNCERTAIN, + ), + false, + ), + ( + Distribution::new_generic( + ScalarValue::Null, + ScalarValue::Null, + ScalarValue::Null, + Interval::make_zero(&DataType::Float32)?, + ), + true, + ), + ( + Distribution::new_generic( + ScalarValue::from(0_f32), + ScalarValue::Float32(None), + ScalarValue::Float32(None), + Interval::make_zero(&DataType::Float32)?, + ), + true, + ), + ( + Distribution::new_generic( + ScalarValue::Float64(None), + ScalarValue::from(0.), + ScalarValue::Float64(None), + Interval::make_zero(&DataType::Float32)?, + ), + true, + ), + ( + Distribution::new_generic( + ScalarValue::from(-10_f32), + ScalarValue::Float32(None), + ScalarValue::Float32(None), + Interval::make_zero(&DataType::Float32)?, + ), + false, + ), + ( + Distribution::new_generic( + ScalarValue::Float32(None), + ScalarValue::from(10_f32), + ScalarValue::Float32(None), + Interval::make_zero(&DataType::Float32)?, + ), + false, + ), + ( + Distribution::new_generic( + ScalarValue::Null, + ScalarValue::Null, + ScalarValue::Null, + Interval::make_zero(&DataType::Float32)?, + ), + true, + ), + ( + Distribution::new_generic( + ScalarValue::from(0), + ScalarValue::from(0), + ScalarValue::Int32(None), + Interval::make_zero(&DataType::Int32)?, + ), + true, + ), + ( + Distribution::new_generic( + ScalarValue::from(0_f32), + ScalarValue::from(0_f32), + ScalarValue::Float32(None), + Interval::make_zero(&DataType::Float32)?, + ), + true, + ), + ( + Distribution::new_generic( + ScalarValue::from(50.), + ScalarValue::from(50.), + ScalarValue::Float64(None), + Interval::make(Some(0.), Some(100.))?, + ), + true, + ), + ( + Distribution::new_generic( + ScalarValue::from(50.), + ScalarValue::from(50.), + ScalarValue::Float64(None), + Interval::make(Some(-100.), Some(0.))?, + ), + false, + ), + ( + Distribution::new_generic( + ScalarValue::Float64(None), + ScalarValue::Float64(None), + ScalarValue::from(1.), + Interval::make_zero(&DataType::Float64)?, + ), + true, + ), + ( + Distribution::new_generic( + ScalarValue::Float64(None), + ScalarValue::Float64(None), + ScalarValue::from(-1.), + Interval::make_zero(&DataType::Float64)?, + ), + false, + ), + ]; + for case in generic_dists { + assert_eq!(case.0.is_ok(), case.1, "{:?}", case.0); + } + + Ok(()) + } + + #[test] + fn mean_extraction_test() -> Result<()> { + // This array collects test cases of the form (distribution, mean value). 
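+        // (E.g., a `Uniform` over `[1, 100]` has mean `(1 + 100) / 2 = 50` under
+        // integer division, and an `Exponential` with rate `2` and offset `1`
+        // has mean `1 + 1/2 = 1.5`.)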
+ let dists = vec![ + ( + Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?), + ScalarValue::from(0_i64), + ), + ( + Distribution::new_uniform(Interval::make_zero(&DataType::Float64)?), + ScalarValue::from(0.), + ), + ( + Distribution::new_uniform(Interval::make(Some(1), Some(100))?), + ScalarValue::from(50), + ), + ( + Distribution::new_uniform(Interval::make(Some(-100), Some(-1))?), + ScalarValue::from(-50), + ), + ( + Distribution::new_uniform(Interval::make(Some(-100), Some(100))?), + ScalarValue::from(0), + ), + ( + Distribution::new_exponential( + ScalarValue::from(2.), + ScalarValue::from(0.), + true, + ), + ScalarValue::from(0.5), + ), + ( + Distribution::new_exponential( + ScalarValue::from(2.), + ScalarValue::from(1.), + true, + ), + ScalarValue::from(1.5), + ), + ( + Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)), + ScalarValue::from(0.), + ), + ( + Distribution::new_gaussian( + ScalarValue::from(-2.), + ScalarValue::from(0.5), + ), + ScalarValue::from(-2.), + ), + ( + Distribution::new_bernoulli(ScalarValue::from(0.5)), + ScalarValue::from(0.5), + ), + ( + Distribution::new_generic( + ScalarValue::from(42.), + ScalarValue::from(42.), + ScalarValue::Float64(None), + Interval::make(Some(25.), Some(50.))?, + ), + ScalarValue::from(42.), + ), + ]; + + for case in dists { + assert_eq!(case.0?.mean()?, case.1); + } + + Ok(()) + } + + #[test] + fn median_extraction_test() -> Result<()> { + // This array collects test cases of the form (distribution, median value). + let dists = vec![ + ( + Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?), + ScalarValue::from(0_i64), + ), + ( + Distribution::new_uniform(Interval::make(Some(25.), Some(75.))?), + ScalarValue::from(50.), + ), + ( + Distribution::new_exponential( + ScalarValue::from(2_f64.ln()), + ScalarValue::from(0.), + true, + ), + ScalarValue::from(1.), + ), + ( + Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)), + ScalarValue::from(2.), + ), + ( + Distribution::new_bernoulli(ScalarValue::from(0.25)), + ScalarValue::from(0.), + ), + ( + Distribution::new_bernoulli(ScalarValue::from(0.75)), + ScalarValue::from(1.), + ), + ( + Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)), + ScalarValue::from(2.), + ), + ( + Distribution::new_generic( + ScalarValue::from(12.), + ScalarValue::from(12.), + ScalarValue::Float64(None), + Interval::make(Some(0.), Some(25.))?, + ), + ScalarValue::from(12.), + ), + ]; + + for case in dists { + assert_eq!(case.0?.median()?, case.1); + } + + Ok(()) + } + + #[test] + fn variance_extraction_test() -> Result<()> { + // This array collects test cases of the form (distribution, variance value). 
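+        // (E.g., a `Uniform` over `[0, 12]` has variance `(12 - 0)^2 / 12 = 12`,
+        // and an `Exponential` with rate `10` has variance `1 / 10^2 = 0.01`.)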
+ let dists = vec![ + ( + Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?), + ScalarValue::from(12.), + ), + ( + Distribution::new_exponential( + ScalarValue::from(10.), + ScalarValue::from(0.), + true, + ), + ScalarValue::from(0.01), + ), + ( + Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)), + ScalarValue::from(1.), + ), + ( + Distribution::new_bernoulli(ScalarValue::from(0.5)), + ScalarValue::from(0.25), + ), + ( + Distribution::new_generic( + ScalarValue::Float64(None), + ScalarValue::Float64(None), + ScalarValue::from(0.02), + Interval::make_zero(&DataType::Float64)?, + ), + ScalarValue::from(0.02), + ), + ]; + + for case in dists { + assert_eq!(case.0?.variance()?, case.1); + } + + Ok(()) + } + + #[test] + fn test_calculate_generic_properties_gauss_gauss() -> Result<()> { + let dist_a = + Distribution::new_gaussian(ScalarValue::from(10.), ScalarValue::from(0.0))?; + let dist_b = + Distribution::new_gaussian(ScalarValue::from(20.), ScalarValue::from(0.0))?; + + let test_data = vec![ + // Mean: + ( + compute_mean(&Operator::Plus, &dist_a, &dist_b)?, + ScalarValue::from(30.), + ), + ( + compute_mean(&Operator::Minus, &dist_a, &dist_b)?, + ScalarValue::from(-10.), + ), + // Median: + ( + compute_median(&Operator::Plus, &dist_a, &dist_b)?, + ScalarValue::from(30.), + ), + ( + compute_median(&Operator::Minus, &dist_a, &dist_b)?, + ScalarValue::from(-10.), + ), + ]; + for (actual, expected) in test_data { + assert_eq!(actual, expected); + } + + Ok(()) + } + + #[test] + fn test_combine_bernoullis_and_op() -> Result<()> { + let op = Operator::And; + let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?; + let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?; + let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?; + let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?; + + assert_eq!( + combine_bernoullis(&op, &left, &right)?.p_value(), + &ScalarValue::from(0.5 * 0.4) + ); + assert_eq!( + combine_bernoullis(&op, &left_null, &right)?.p_value(), + &ScalarValue::Float64(None) + ); + assert_eq!( + combine_bernoullis(&op, &left, &right_null)?.p_value(), + &ScalarValue::Float64(None) + ); + assert_eq!( + combine_bernoullis(&op, &left_null, &left_null)?.p_value(), + &ScalarValue::Null + ); + + Ok(()) + } + + #[test] + fn test_combine_bernoullis_or_op() -> Result<()> { + let op = Operator::Or; + let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?; + let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?; + let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?; + let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?; + + assert_eq!( + combine_bernoullis(&op, &left, &right)?.p_value(), + &ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4)) + ); + assert_eq!( + combine_bernoullis(&op, &left_null, &right)?.p_value(), + &ScalarValue::Float64(None) + ); + assert_eq!( + combine_bernoullis(&op, &left, &right_null)?.p_value(), + &ScalarValue::Float64(None) + ); + assert_eq!( + combine_bernoullis(&op, &left_null, &left_null)?.p_value(), + &ScalarValue::Null + ); + + Ok(()) + } + + #[test] + fn test_combine_bernoullis_unsupported_ops() -> Result<()> { + let mut operator_set = operator_set(); + operator_set.remove(&Operator::And); + operator_set.remove(&Operator::Or); + + let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?; + let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?; + for op in operator_set { + assert!( + combine_bernoullis(&op, 
&left, &right).is_err(), + "Operator {op} should not be supported for Bernoulli distributions" + ); + } + + Ok(()) + } + + #[test] + fn test_combine_gaussians_addition() -> Result<()> { + let op = Operator::Plus; + let left = GaussianDistribution::try_new( + ScalarValue::from(3.0), + ScalarValue::from(2.0), + )?; + let right = GaussianDistribution::try_new( + ScalarValue::from(4.0), + ScalarValue::from(1.0), + )?; + + let result = combine_gaussians(&op, &left, &right)?.unwrap(); + + assert_eq!(result.mean(), &ScalarValue::from(7.0)); // 3.0 + 4.0 + assert_eq!(result.variance(), &ScalarValue::from(3.0)); // 2.0 + 1.0 + Ok(()) + } + + #[test] + fn test_combine_gaussians_subtraction() -> Result<()> { + let op = Operator::Minus; + let left = GaussianDistribution::try_new( + ScalarValue::from(7.0), + ScalarValue::from(2.0), + )?; + let right = GaussianDistribution::try_new( + ScalarValue::from(4.0), + ScalarValue::from(1.0), + )?; + + let result = combine_gaussians(&op, &left, &right)?.unwrap(); + + assert_eq!(result.mean(), &ScalarValue::from(3.0)); // 7.0 - 4.0 + assert_eq!(result.variance(), &ScalarValue::from(3.0)); // 2.0 + 1.0 + + Ok(()) + } + + #[test] + fn test_combine_gaussians_unsupported_ops() -> Result<()> { + let mut operator_set = operator_set(); + operator_set.remove(&Operator::Plus); + operator_set.remove(&Operator::Minus); + + let left = GaussianDistribution::try_new( + ScalarValue::from(7.0), + ScalarValue::from(2.0), + )?; + let right = GaussianDistribution::try_new( + ScalarValue::from(4.0), + ScalarValue::from(1.0), + )?; + for op in operator_set { + assert!( + combine_gaussians(&op, &left, &right)?.is_none(), + "Operator {op} should not be supported for Gaussian distributions" + ); + } + + Ok(()) + } + + // Expected test results were calculated in Wolfram Mathematica, by using: + // + // *METHOD_NAME*[TransformedDistribution[ + // x *op* y, + // {x ~ *DISTRIBUTION_X*[..], y ~ *DISTRIBUTION_Y*[..]} + // ]] + #[test] + fn test_calculate_generic_properties_uniform_uniform() -> Result<()> { + let dist_a = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?; + let dist_b = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?; + + let test_data = vec![ + // Mean: + ( + compute_mean(&Operator::Plus, &dist_a, &dist_b)?, + ScalarValue::from(30.), + ), + ( + compute_mean(&Operator::Minus, &dist_a, &dist_b)?, + ScalarValue::from(-18.), + ), + ( + compute_mean(&Operator::Multiply, &dist_a, &dist_b)?, + ScalarValue::from(144.), + ), + // Median: + ( + compute_median(&Operator::Plus, &dist_a, &dist_b)?, + ScalarValue::from(30.), + ), + ( + compute_median(&Operator::Minus, &dist_a, &dist_b)?, + ScalarValue::from(-18.), + ), + // Variance: + ( + compute_variance(&Operator::Plus, &dist_a, &dist_b)?, + ScalarValue::from(60.), + ), + ( + compute_variance(&Operator::Minus, &dist_a, &dist_b)?, + ScalarValue::from(60.), + ), + ( + compute_variance(&Operator::Multiply, &dist_a, &dist_b)?, + ScalarValue::from(9216.), + ), + ]; + for (actual, expected) in test_data { + assert_eq!(actual, expected); + } + + Ok(()) + } + + /// Test for `Uniform`-`Uniform`, `Uniform`-`Generic`, `Generic`-`Uniform`, + /// `Generic`-`Generic` pairs, where range is always present. 
+ #[test] + fn test_compute_range_where_present() -> Result<()> { + let a = &Interval::make(Some(0.), Some(12.0))?; + let b = &Interval::make(Some(0.), Some(12.0))?; + let mean = ScalarValue::from(6.0); + for (dist_a, dist_b) in [ + ( + Distribution::new_uniform(a.clone())?, + Distribution::new_uniform(b.clone())?, + ), + ( + Distribution::new_generic( + mean.clone(), + mean.clone(), + ScalarValue::Float64(None), + a.clone(), + )?, + Distribution::new_uniform(b.clone())?, + ), + ( + Distribution::new_uniform(a.clone())?, + Distribution::new_generic( + mean.clone(), + mean.clone(), + ScalarValue::Float64(None), + b.clone(), + )?, + ), + ( + Distribution::new_generic( + mean.clone(), + mean.clone(), + ScalarValue::Float64(None), + a.clone(), + )?, + Distribution::new_generic( + mean.clone(), + mean.clone(), + ScalarValue::Float64(None), + b.clone(), + )?, + ), + ] { + use super::Operator::{ + Divide, Eq, Gt, GtEq, Lt, LtEq, Minus, Multiply, NotEq, Plus, + }; + for op in [Plus, Minus, Multiply, Divide] { + assert_eq!( + new_generic_from_binary_op(&op, &dist_a, &dist_b)?.range()?, + apply_operator(&op, a, b)?, + "Failed for {:?} {op} {:?}", + dist_a, + dist_b + ); + } + for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] { + assert_eq!( + create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?.range()?, + apply_operator(&op, a, b)?, + "Failed for {:?} {op} {:?}", + dist_a, + dist_b + ); + } + } + + Ok(()) + } + + fn operator_set() -> HashSet<Operator> { + use super::Operator::*; + + let all_ops = vec![ + And, + Or, + Eq, + NotEq, + Gt, + GtEq, + Lt, + LtEq, + Plus, + Minus, + Multiply, + Divide, + Modulo, + IsDistinctFrom, + IsNotDistinctFrom, + RegexMatch, + RegexIMatch, + RegexNotMatch, + RegexNotIMatch, + LikeMatch, + ILikeMatch, + NotLikeMatch, + NotILikeMatch, + BitwiseAnd, + BitwiseOr, + BitwiseXor, + BitwiseShiftRight, + BitwiseShiftLeft, + StringConcat, + AtArrow, + ArrowAt, + ]; + + all_ops.into_iter().collect() + } +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index d2ea6e809150..02684928bac7 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -51,7 +51,6 @@ pub mod function; pub mod groups_accumulator { pub use datafusion_expr_common::groups_accumulator::*; } - pub mod interval_arithmetic { pub use datafusion_expr_common::interval_arithmetic::*; } @@ -62,6 +61,9 @@ pub mod simplify; pub mod sort_properties { pub use datafusion_expr_common::sort_properties::*; } +pub mod statistics { + pub use datafusion_expr_common::statistics::*; +} pub mod test; pub mod tree_node; pub mod type_coercion; diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index b1b889136b35..cc2ff2f24790 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -26,10 +26,13 @@ use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, not_impl_err, Result}; +use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; +use datafusion_expr_common::statistics::Distribution; + +use itertools::izip; /// Shared [`PhysicalExpr`]. 
 pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
 
@@ -98,11 +101,16 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
     /// Computes the output interval for the expression, given the input
     /// intervals.
     ///
-    /// # Arguments
+    /// # Parameters
     ///
     /// * `children` are the intervals for the children (inputs) of this
     ///   expression.
     ///
+    /// # Returns
+    ///
+    /// A `Result` containing the output interval for the expression in
+    /// case of success, or an error object in case of failure.
+    ///
     /// # Example
     ///
     /// If the expression is `a + b`, and the input intervals are `a: [1, 2]`
@@ -116,19 +124,20 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
     ///
     /// This is used to propagate constraints down through an expression tree.
     ///
-    /// # Arguments
+    /// # Parameters
     ///
     /// * `interval` is the currently known interval for this expression.
     /// * `children` are the current intervals for the children of this expression.
     ///
     /// # Returns
     ///
-    /// A `Vec` of new intervals for the children, in order.
+    /// A `Result` containing a `Vec` of new intervals for the children (in order)
+    /// in case of success, or an error object in case of failure.
    ///
     /// If constraint propagation reveals an infeasibility for any child, returns
-    /// [`None`]. If none of the children intervals change as a result of propagation,
-    /// may return an empty vector instead of cloning `children`. This is the default
-    /// (and conservative) return value.
+    /// [`None`]. If none of the children intervals change as a result of
+    /// propagation, may return an empty vector instead of cloning `children`.
+    /// This is the default (and conservative) return value.
     ///
     /// # Example
     ///
@@ -144,6 +153,111 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
         Ok(Some(vec![]))
     }
 
+    /// Computes the output statistics for the expression, given the input
+    /// statistics.
+    ///
+    /// # Parameters
+    ///
+    /// * `children` are the statistics for the children (inputs) of this
+    ///   expression.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing the output statistics for the expression in
+    /// case of success, or an error object in case of failure.
+    ///
+    /// Expressions should implement this function by utilizing the independence
+    /// assumption, matching on children distribution types, and computing the
+    /// output statistics accordingly. The default implementation simply creates
+    /// an unknown output distribution by combining input ranges. This logic
+    /// loses distribution information, but is a safe default.
+    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
+        let children_ranges = children
+            .iter()
+            .map(|c| c.range())
+            .collect::<Result<Vec<_>>>()?;
+        let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
+        let output_interval = self.evaluate_bounds(children_ranges_refs.as_slice())?;
+        let dt = output_interval.data_type();
+        if dt.eq(&DataType::Boolean) {
+            let p = if output_interval.eq(&Interval::CERTAINLY_TRUE) {
+                ScalarValue::new_one(&dt)
+            } else if output_interval.eq(&Interval::CERTAINLY_FALSE) {
+                ScalarValue::new_zero(&dt)
+            } else {
+                ScalarValue::try_from(&dt)
+            }?;
+            Distribution::new_bernoulli(p)
+        } else {
+            Distribution::new_from_interval(output_interval)
+        }
+    }
+
+    /// Updates children statistics using the given parent statistics for this
+    /// expression.
+    ///
+    /// This is used to propagate statistics down through an expression tree.
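+    ///
+    /// For example, if this expression is `a > b` and we learn that its parent
+    /// statistics is a `Bernoulli` distribution that is certainly `true`, an
+    /// implementation can narrow the children's statistics so that their
+    /// ranges only admit values satisfying the comparison.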
+    ///
+    /// # Parameters
+    ///
+    /// * `parent` is the currently known statistics for this expression.
+    /// * `children` are the current statistics for the children of this expression.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing a `Vec` of new statistics for the children (in order)
+    /// in case of success, or an error object in case of failure.
+    ///
+    /// If statistics propagation reveals an infeasibility for any child, returns
+    /// [`None`]. If none of the children statistics change as a result of
+    /// propagation, may return an empty vector instead of cloning `children`.
+    /// This is the default (and conservative) return value.
+    ///
+    /// Expressions should implement this function by applying Bayes' rule to
+    /// reconcile and update parent/children statistics. This involves utilizing
+    /// the independence assumption and matching on distribution types. The
+    /// default implementation simply creates an unknown distribution if it can
+    /// narrow the range by propagating ranges. This logic loses distribution
+    /// information, but is a safe default.
+    fn propagate_statistics(
+        &self,
+        parent: &Distribution,
+        children: &[&Distribution],
+    ) -> Result<Option<Vec<Distribution>>> {
+        let children_ranges = children
+            .iter()
+            .map(|c| c.range())
+            .collect::<Result<Vec<_>>>()?;
+        let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
+        let parent_range = parent.range()?;
+        let Some(propagated_children) =
+            self.propagate_constraints(&parent_range, children_ranges_refs.as_slice())?
+        else {
+            return Ok(None);
+        };
+        izip!(propagated_children.into_iter(), children_ranges, children)
+            .map(|(new_interval, old_interval, child)| {
+                if new_interval == old_interval {
+                    // We weren't able to narrow the range, preserve the old statistics.
+                    Ok((*child).clone())
+                } else if new_interval.data_type().eq(&DataType::Boolean) {
+                    let dt = old_interval.data_type();
+                    let p = if new_interval.eq(&Interval::CERTAINLY_TRUE) {
+                        ScalarValue::new_one(&dt)
+                    } else if new_interval.eq(&Interval::CERTAINLY_FALSE) {
+                        ScalarValue::new_zero(&dt)
+                    } else {
+                        unreachable!("Given that we have a range reduction for a boolean interval, we should have certainty")
+                    }?;
+                    Distribution::new_bernoulli(p)
+                } else {
+                    Distribution::new_from_interval(new_interval)
+                }
+            })
+            .collect::<Result<_>>()
+            .map(Some)
+    }
+
     /// Calculates the properties of this [`PhysicalExpr`] based on its
     /// children's properties (i.e. order and range), recursively aggregating
     /// the information from its children. In cases where the [`PhysicalExpr`]
@@ -155,7 +269,7 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
 }
 
 /// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object
-/// safe. To ease implementation blanket implementation is provided for [`Eq`] types.
+/// safe. To ease implementation, a blanket implementation is provided for [`Eq`] types.
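+/// Implementors compare themselves against a type-erased `&dyn Any` argument,
+/// typically by downcasting it to their own concrete type.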
pub trait DynEq { fn dyn_eq(&self, other: &dyn Any) -> bool; } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 052054bad6c1..1f16c5471ed7 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,6 +20,7 @@ mod kernels; use std::hash::Hash; use std::{any::Any, sync::Arc}; +use crate::expressions::binary::kernels::concat_elements_utf8view; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::PhysicalExpr; @@ -36,10 +37,14 @@ use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::binary::BinaryTypeCoercer; use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::sort_properties::ExprProperties; +use datafusion_expr::statistics::Distribution::{Bernoulli, Gaussian}; +use datafusion_expr::statistics::{ + combine_bernoullis, combine_gaussians, create_bernoulli_from_comparison, + new_generic_from_binary_op, Distribution, +}; use datafusion_expr::{ColumnarValue, Operator}; use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested}; -use crate::expressions::binary::kernels::concat_elements_utf8view; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, @@ -486,6 +491,37 @@ impl PhysicalExpr for BinaryExpr { } } + fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> { + let (left, right) = (children[0], children[1]); + + if self.op.is_numerical_operators() { + // We might be able to construct the output statistics more accurately, + // without falling back to an unknown distribution, if we are dealing + // with Gaussian distributions and numerical operations. + if let (Gaussian(left), Gaussian(right)) = (left, right) { + if let Some(result) = combine_gaussians(&self.op, left, right)? { + return Ok(Gaussian(result)); + } + } + } else if self.op.is_logic_operator() { + // If we are dealing with logical operators, we expect (and can only + // operate on) Bernoulli distributions. + return if let (Bernoulli(left), Bernoulli(right)) = (left, right) { + combine_bernoullis(&self.op, left, right).map(Bernoulli) + } else { + internal_err!( + "Logical operators are only compatible with `Bernoulli` distributions" + ) + }; + } else if self.op.supports_propagation() { + // If we are handling comparison operators, we expect (and can only + // operate on) numeric distributions. + return create_bernoulli_from_comparison(&self.op, left, right); + } + // Fall back to an unknown distribution with only summary statistics: + new_generic_from_binary_op(&self.op, left, right) + } + /// For each operator, [`BinaryExpr`] has distinct rules. /// TODO: There may be rules specific to some data types and expression ranges. fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> { @@ -732,6 +768,7 @@ pub fn similar_to( mod tests { use super::*; use crate::expressions::{col, lit, try_cast, Column, Literal}; + use datafusion_common::plan_datafusion_err; /// Performs a binary operation, applying any type coercion necessary @@ -4379,4 +4416,260 @@ mod tests { ) .unwrap(); } + + pub fn binary_expr( + left: Arc<dyn PhysicalExpr>, + op: Operator, + right: Arc<dyn PhysicalExpr>, + schema: &Schema, + ) -> Result<BinaryExpr> { + Ok(binary_op(left, op, right, schema)? 
+ .as_any() + .downcast_ref::<BinaryExpr>() + .unwrap() + .clone()) + } + + /// Test for Uniform-Uniform, Unknown-Uniform, Uniform-Unknown and Unknown-Unknown evaluation. + #[test] + fn test_evaluate_statistics_combination_of_range_holders() -> Result<()> { + let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let a = Arc::new(Column::new("a", 0)) as _; + let b = lit(ScalarValue::from(12.0)); + + let left_interval = Interval::make(Some(0.0), Some(12.0))?; + let right_interval = Interval::make(Some(12.0), Some(36.0))?; + let (left_mean, right_mean) = (ScalarValue::from(6.0), ScalarValue::from(24.0)); + let (left_med, right_med) = (ScalarValue::from(6.0), ScalarValue::from(24.0)); + + for children in [ + vec![ + &Distribution::new_uniform(left_interval.clone())?, + &Distribution::new_uniform(right_interval.clone())?, + ], + vec![ + &Distribution::new_generic( + left_mean.clone(), + left_med.clone(), + ScalarValue::Float64(None), + left_interval.clone(), + )?, + &Distribution::new_uniform(right_interval.clone())?, + ], + vec![ + &Distribution::new_uniform(right_interval.clone())?, + &Distribution::new_generic( + right_mean.clone(), + right_med.clone(), + ScalarValue::Float64(None), + right_interval.clone(), + )?, + ], + vec![ + &Distribution::new_generic( + left_mean.clone(), + left_med.clone(), + ScalarValue::Float64(None), + left_interval.clone(), + )?, + &Distribution::new_generic( + right_mean.clone(), + right_med.clone(), + ScalarValue::Float64(None), + right_interval.clone(), + )?, + ], + ] { + let ops = vec![ + Operator::Plus, + Operator::Minus, + Operator::Multiply, + Operator::Divide, + ]; + + for op in ops { + let expr = binary_expr(Arc::clone(&a), op, Arc::clone(&b), schema)?; + assert_eq!( + expr.evaluate_statistics(&children)?, + new_generic_from_binary_op(&op, children[0], children[1])? + ); + } + } + Ok(()) + } + + #[test] + fn test_evaluate_statistics_bernoulli() -> Result<()> { + let schema = &Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Int64, false), + ]); + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + let eq = Arc::new(binary_expr( + Arc::clone(&a), + Operator::Eq, + Arc::clone(&b), + schema, + )?); + let neq = Arc::new(binary_expr(a, Operator::NotEq, b, schema)?); + + let left_stat = &Distribution::new_uniform(Interval::make(Some(0), Some(7))?)?; + let right_stat = &Distribution::new_uniform(Interval::make(Some(4), Some(11))?)?; + + // Intervals: [0, 7], [4, 11]. + // The intersection is [4, 7], so the probability of equality is 4 / 64 = 1 / 16. + assert_eq!( + eq.evaluate_statistics(&[left_stat, right_stat])?, + Distribution::new_bernoulli(ScalarValue::from(1.0 / 16.0))? + ); + + // The probability of being distinct is 1 - 1 / 16 = 15 / 16. + assert_eq!( + neq.evaluate_statistics(&[left_stat, right_stat])?, + Distribution::new_bernoulli(ScalarValue::from(15.0 / 16.0))? 
+ ); + + Ok(()) + } + + #[test] + fn test_propagate_statistics_combination_of_range_holders_arithmetic() -> Result<()> { + let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let a = Arc::new(Column::new("a", 0)) as _; + let b = lit(ScalarValue::from(12.0)); + + let left_interval = Interval::make(Some(0.0), Some(12.0))?; + let right_interval = Interval::make(Some(12.0), Some(36.0))?; + + let parent = Distribution::new_uniform(Interval::make(Some(-432.), Some(432.))?)?; + let children = vec![ + vec![ + Distribution::new_uniform(left_interval.clone())?, + Distribution::new_uniform(right_interval.clone())?, + ], + vec![ + Distribution::new_generic( + ScalarValue::from(6.), + ScalarValue::from(6.), + ScalarValue::Float64(None), + left_interval.clone(), + )?, + Distribution::new_uniform(right_interval.clone())?, + ], + vec![ + Distribution::new_uniform(left_interval.clone())?, + Distribution::new_generic( + ScalarValue::from(12.), + ScalarValue::from(12.), + ScalarValue::Float64(None), + right_interval.clone(), + )?, + ], + vec![ + Distribution::new_generic( + ScalarValue::from(6.), + ScalarValue::from(6.), + ScalarValue::Float64(None), + left_interval.clone(), + )?, + Distribution::new_generic( + ScalarValue::from(12.), + ScalarValue::from(12.), + ScalarValue::Float64(None), + right_interval.clone(), + )?, + ], + ]; + + let ops = vec![ + Operator::Plus, + Operator::Minus, + Operator::Multiply, + Operator::Divide, + ]; + + for child_view in children { + let child_refs = child_view.iter().collect::<Vec<_>>(); + for op in &ops { + let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?; + assert_eq!( + expr.propagate_statistics(&parent, child_refs.as_slice())?, + Some(child_view.clone()) + ); + } + } + Ok(()) + } + + #[test] + fn test_propagate_statistics_combination_of_range_holders_comparison() -> Result<()> { + let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let a = Arc::new(Column::new("a", 0)) as _; + let b = lit(ScalarValue::from(12.0)); + + let left_interval = Interval::make(Some(0.0), Some(12.0))?; + let right_interval = Interval::make(Some(6.0), Some(18.0))?; + + let one = ScalarValue::from(1.0); + let parent = Distribution::new_bernoulli(one)?; + let children = vec![ + vec![ + Distribution::new_uniform(left_interval.clone())?, + Distribution::new_uniform(right_interval.clone())?, + ], + vec![ + Distribution::new_generic( + ScalarValue::from(6.), + ScalarValue::from(6.), + ScalarValue::Float64(None), + left_interval.clone(), + )?, + Distribution::new_uniform(right_interval.clone())?, + ], + vec![ + Distribution::new_uniform(left_interval.clone())?, + Distribution::new_generic( + ScalarValue::from(12.), + ScalarValue::from(12.), + ScalarValue::Float64(None), + right_interval.clone(), + )?, + ], + vec![ + Distribution::new_generic( + ScalarValue::from(6.), + ScalarValue::from(6.), + ScalarValue::Float64(None), + left_interval.clone(), + )?, + Distribution::new_generic( + ScalarValue::from(12.), + ScalarValue::from(12.), + ScalarValue::Float64(None), + right_interval.clone(), + )?, + ], + ]; + + let ops = vec![ + Operator::Eq, + Operator::Gt, + Operator::GtEq, + Operator::Lt, + Operator::LtEq, + ]; + + for child_view in children { + let child_refs = child_view.iter().collect::<Vec<_>>(); + for op in &ops { + let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?; + assert!(expr + .propagate_statistics(&parent, child_refs.as_slice())? 
+ .is_some()); + } + } + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index dc863ccff511..8795545274a2 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -28,9 +28,12 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{plan_err, Result}; +use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; +use datafusion_expr::statistics::Distribution::{ + self, Bernoulli, Exponential, Gaussian, Generic, Uniform, +}; use datafusion_expr::{ type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, ColumnarValue, @@ -89,14 +92,13 @@ impl PhysicalExpr for NegativeExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { - let arg = self.arg.evaluate(batch)?; - match arg { + match self.arg.evaluate(batch)? { ColumnarValue::Array(array) => { let result = neg_wrapping(array.as_ref())?; Ok(ColumnarValue::Array(result)) } ColumnarValue::Scalar(scalar) => { - Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?)) + Ok(ColumnarValue::Scalar(scalar.arithmetic_negate()?)) } } } @@ -116,10 +118,7 @@ impl PhysicalExpr for NegativeExpr { /// It replaces the upper and lower bounds after multiplying them with -1. /// Ex: `(a, b]` => `[-b, -a)` fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> { - Interval::try_new( - children[0].upper().arithmetic_negate()?, - children[0].lower().arithmetic_negate()?, - ) + children[0].arithmetic_negate() } /// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that @@ -129,17 +128,37 @@ impl PhysicalExpr for NegativeExpr { interval: &Interval, children: &[&Interval], ) -> Result<Option<Vec<Interval>>> { - let child_interval = children[0]; - let negated_interval = Interval::try_new( - interval.upper().arithmetic_negate()?, - interval.lower().arithmetic_negate()?, - )?; + let negated_interval = interval.arithmetic_negate()?; - Ok(child_interval + Ok(children[0] .intersect(negated_interval)? .map(|result| vec![result])) } + fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> { + match children[0] { + Uniform(u) => Distribution::new_uniform(u.range().arithmetic_negate()?), + Exponential(e) => Distribution::new_exponential( + e.rate().clone(), + e.offset().arithmetic_negate()?, + !e.positive_tail(), + ), + Gaussian(g) => Distribution::new_gaussian( + g.mean().arithmetic_negate()?, + g.variance().clone(), + ), + Bernoulli(_) => { + internal_err!("NegativeExpr cannot operate on Boolean datatypes") + } + Generic(u) => Distribution::new_generic( + u.mean().arithmetic_negate()?, + u.median().arithmetic_negate()?, + u.variance().clone(), + u.range().arithmetic_negate()?, + ), + } + } + /// The ordering of a [`NegativeExpr`] is simply the reverse of its child. 
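+    /// For example, an ascending child ordering becomes descending after
+    /// negation, and a child range of `[a, b]` maps to `[-b, -a]`.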
fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> { Ok(ExprProperties { @@ -181,7 +200,7 @@ mod tests { use arrow::datatypes::DataType::{Float32, Float64, Int16, Int32, Int64, Int8}; use arrow::datatypes::*; use datafusion_common::cast::as_primitive_array; - use datafusion_common::DataFusionError; + use datafusion_common::{DataFusionError, ScalarValue}; use paste::paste; @@ -233,6 +252,67 @@ mod tests { Ok(()) } + #[test] + fn test_evaluate_statistics() -> Result<()> { + let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0))); + + // Uniform + assert_eq!( + negative_expr.evaluate_statistics(&[&Distribution::new_uniform( + Interval::make(Some(-2.), Some(3.))? + )?])?, + Distribution::new_uniform(Interval::make(Some(-3.), Some(2.))?)? + ); + + // Bernoulli + assert!(negative_expr + .evaluate_statistics(&[&Distribution::new_bernoulli(ScalarValue::from( + 0.75 + ))?]) + .is_err()); + + // Exponential + assert_eq!( + negative_expr.evaluate_statistics(&[&Distribution::new_exponential( + ScalarValue::from(1.), + ScalarValue::from(1.), + true + )?])?, + Distribution::new_exponential( + ScalarValue::from(1.), + ScalarValue::from(-1.), + false + )? + ); + + // Gaussian + assert_eq!( + negative_expr.evaluate_statistics(&[&Distribution::new_gaussian( + ScalarValue::from(15), + ScalarValue::from(225), + )?])?, + Distribution::new_gaussian(ScalarValue::from(-15), ScalarValue::from(225),)? + ); + + // Unknown + assert_eq!( + negative_expr.evaluate_statistics(&[&Distribution::new_generic( + ScalarValue::from(15), + ScalarValue::from(15), + ScalarValue::from(10), + Interval::make(Some(10), Some(20))? + )?])?, + Distribution::new_generic( + ScalarValue::from(-15), + ScalarValue::from(-15), + ScalarValue::from(10), + Interval::make(Some(-20), Some(-10))? + )? 
+ ); + + Ok(()) + } + #[test] fn test_propagate_constraints() -> Result<()> { let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0))); @@ -249,6 +329,35 @@ mod tests { Ok(()) } + #[test] + fn test_propagate_statistics_range_holders() -> Result<()> { + let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0))); + let original_child_interval = Interval::make(Some(-2), Some(3))?; + let after_propagation = Interval::make(Some(-2), Some(0))?; + + let parent = Distribution::new_uniform(Interval::make(Some(0), Some(4))?)?; + let children: Vec<Vec<Distribution>> = vec![ + vec![Distribution::new_uniform(original_child_interval.clone())?], + vec![Distribution::new_generic( + ScalarValue::from(0), + ScalarValue::from(0), + ScalarValue::Int32(None), + original_child_interval.clone(), + )?], + ]; + + for child_view in children { + let child_refs: Vec<_> = child_view.iter().collect(); + let actual = negative_expr.propagate_statistics(&parent, &child_refs)?; + let expected = Some(vec![Distribution::new_from_interval( + after_propagation.clone(), + )?]); + assert_eq!(actual, expected); + } + + Ok(()) + } + #[test] fn test_negation_valid_types() -> Result<()> { let negatable_types = [ diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 440c4e9557bd..ddf7c739b692 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -23,10 +23,12 @@ use std::hash::Hash; use std::sync::Arc; use crate::PhysicalExpr; + use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{cast::as_boolean_array, Result, ScalarValue}; +use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::statistics::Distribution::{self, Bernoulli}; use datafusion_expr::ColumnarValue; /// Not expression @@ -82,8 +84,7 @@ impl PhysicalExpr for NotExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { - let evaluate_arg = self.arg.evaluate(batch)?; - match evaluate_arg { + match self.arg.evaluate(batch)? { ColumnarValue::Array(array) => { let array = as_boolean_array(&array)?; Ok(ColumnarValue::Array(Arc::new( @@ -95,9 +96,7 @@ impl PhysicalExpr for NotExpr { return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); } let bool_value: bool = scalar.try_into()?; - Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( - !bool_value, - )))) + Ok(ColumnarValue::Scalar(ScalarValue::from(!bool_value))) } } } @@ -112,9 +111,70 @@ impl PhysicalExpr for NotExpr { ) -> Result<Arc<dyn PhysicalExpr>> { Ok(Arc::new(NotExpr::new(Arc::clone(&children[0])))) } + fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> { children[0].not() } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result<Option<Vec<Interval>>> { + let complemented_interval = interval.not()?; + + Ok(children[0] + .intersect(complemented_interval)? + .map(|result| vec![result])) + } + + fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> { + match children[0] { + Bernoulli(b) => { + let p_value = b.p_value(); + if p_value.is_null() { + Ok(children[0].clone()) + } else { + let one = ScalarValue::new_one(&p_value.data_type())?; + Distribution::new_bernoulli(one.sub_checked(p_value)?) 
+ } + } + _ => internal_err!("NotExpr can only operate on Boolean datatypes"), + } + } + + fn propagate_statistics( + &self, + parent: &Distribution, + children: &[&Distribution], + ) -> Result<Option<Vec<Distribution>>> { + match (parent, children[0]) { + (Bernoulli(parent), Bernoulli(child)) => { + let parent_range = parent.range(); + let result = if parent_range == Interval::CERTAINLY_TRUE { + if child.range() == Interval::CERTAINLY_TRUE { + None + } else { + Some(vec![Distribution::new_bernoulli(ScalarValue::new_zero( + &child.data_type(), + )?)?]) + } + } else if parent_range == Interval::CERTAINLY_FALSE { + if child.range() == Interval::CERTAINLY_FALSE { + None + } else { + Some(vec![Distribution::new_bernoulli(ScalarValue::new_one( + &child.data_type(), + )?)?]) + } + } else { + Some(vec![]) + }; + Ok(result) + } + _ => internal_err!("NotExpr can only operate on Boolean datatypes"), + } + } } /// Creates a unary expression NOT @@ -124,10 +184,12 @@ pub fn not(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> { #[cfg(test)] mod tests { + use std::sync::LazyLock; + use super::*; - use crate::expressions::col; + use crate::expressions::{col, Column}; + use arrow::{array::BooleanArray, datatypes::*}; - use std::sync::LazyLock; #[test] fn neg_op() -> Result<()> { @@ -182,10 +244,81 @@ mod tests { expected_interval: Interval, ) -> Result<()> { let not_expr = not(col("a", &schema())?)?; + assert_eq!(not_expr.evaluate_bounds(&[&interval])?, expected_interval); + Ok(()) + } + + #[test] + fn test_evaluate_statistics() -> Result<()> { + let _schema = &Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let a = Arc::new(Column::new("a", 0)) as _; + let expr = not(a)?; + + // Uniform with non-boolean bounds + assert!(expr + .evaluate_statistics(&[&Distribution::new_uniform( + Interval::make_unbounded(&DataType::Float64)? + )?]) + .is_err()); + + // Exponential + assert!(expr + .evaluate_statistics(&[&Distribution::new_exponential( + ScalarValue::from(1.0), + ScalarValue::from(1.0), + true + )?]) + .is_err()); + + // Gaussian + assert!(expr + .evaluate_statistics(&[&Distribution::new_gaussian( + ScalarValue::from(1.0), + ScalarValue::from(1.0), + )?]) + .is_err()); + + // Bernoulli assert_eq!( - not_expr.evaluate_bounds(&[&interval]).unwrap(), - expected_interval + expr.evaluate_statistics(&[&Distribution::new_bernoulli( + ScalarValue::from(0.0), + )?])?, + Distribution::new_bernoulli(ScalarValue::from(1.))? ); + + assert_eq!( + expr.evaluate_statistics(&[&Distribution::new_bernoulli( + ScalarValue::from(1.0), + )?])?, + Distribution::new_bernoulli(ScalarValue::from(0.))? + ); + + assert_eq!( + expr.evaluate_statistics(&[&Distribution::new_bernoulli( + ScalarValue::from(0.25), + )?])?, + Distribution::new_bernoulli(ScalarValue::from(0.75))? + ); + + assert!(expr + .evaluate_statistics(&[&Distribution::new_generic( + ScalarValue::Null, + ScalarValue::Null, + ScalarValue::Null, + Interval::make_unbounded(&DataType::UInt8)? + )?]) + .is_err()); + + // Unknown with non-boolean interval as range + assert!(expr + .evaluate_statistics(&[&Distribution::new_generic( + ScalarValue::Null, + ScalarValue::Null, + ScalarValue::Null, + Interval::make_unbounded(&DataType::Float64)? 
+            )?])
+            .is_err());
+
         Ok(())
     }
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index cb29109684fe..a53814c3ad2b 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -15,7 +15,130 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Constraint propagator/solver for custom PhysicalExpr graphs.
+//! Constraint propagator/solver for custom [`PhysicalExpr`] graphs.
+//!
+//! The constraint propagator/solver in DataFusion uses interval arithmetic to
+//! perform mathematical operations on intervals, which represent a range of
+//! possible values rather than a single point value. This allows for the
+//! propagation of ranges through mathematical operations, and can be used to
+//! compute bounds for a complicated expression. The key idea is that by
+//! breaking down a complicated expression into simpler terms, and then
+//! combining the bounds for those simpler terms, one can obtain bounds for the
+//! overall expression.
+//!
+//! This way of using interval arithmetic to compute bounds for a complex
+//! expression by combining the bounds for the constituent terms within the
+//! original expression allows us to reason about the range of possible values
+//! of the expression. This information can later be used in range pruning of
+//! the provably unnecessary parts of `RecordBatch`es.
+//!
+//! # Example
+//!
+//! For example, consider a mathematical expression such as `x^2 + y = 4` \[1\].
+//! Since this expression would be a binary tree in [`PhysicalExpr`] notation,
+//! this type of hierarchical computation is well-suited for a graph-based
+//! implementation. In such an implementation, an equation system `f(x) = 0` is
+//! represented by a directed acyclic expression graph (DAEG).
+//!
+//! In order to use interval arithmetic to compute bounds for this expression,
+//! one would first determine intervals that represent the possible values of
+//! `x` and `y`. Let's say that the interval for `x` is `[1, 2]` and the
+//! interval for `y` is `[-3, 1]`. In the chart below, you can see how the
+//! computation takes place.
+//!
+//! # References
+//!
+//! 1. Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval
+//!    Arithmetic Based Approach, Chapter 4. Stanford University, 2015.
+//! 2. Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966.
+//! 3. F. Messine, "Deterministic global optimization using interval constraint
+//!    propagation techniques," RAIRO-Operations Research, vol. 38, no. 04,
+//!    pp. 277-293, 2004.
+//!
+//! # Illustration
+//!
+//! ## Computing bounds for an expression using interval arithmetic
+//!
+//! ```text
+//!          +-----+                            +-----+
+//!     +----|  +  |----+                  +----|  +  |----+
+//!     |    |     |    |                  |    |     |    |
+//!     |    +-----+    |                  |    +-----+    |
+//!     |               |                  |               |
+//!  +-----+         +-----+            +-----+         +-----+
+//!  |   2 |         |  y  |            |   2 | [1, 4]  |  y  |
+//!  |[.]  |         |     |            |[.]  |         |     |
+//!  +-----+         +-----+            +-----+         +-----+
+//!     |                                   |
+//!     |                                   |
+//!   +---+                               +---+
+//!   | x | [1, 2]                        | x | [1, 2]
+//!   +---+                               +---+
+//!
+//!  (a) Bottom-up evaluation: Step 1    (b) Bottom-up evaluation: Step 2
+//!
+//!                                     [1 - 3, 4 + 1] = [-2, 5]
+//!          +-----+                            +-----+
+//!     +----|  +  |----+                  +----|  +  |----+
+//!     |    |     |    |                  |    |     |    |
+//!     |    +-----+    |                  |    +-----+    |
+//!     |               |                  |               |
+//!  +-----+         +-----+            +-----+         +-----+
+//!  |   2 | [1, 4]  |  y  |            |   2 | [1, 4]  |  y  |
+//!  |[.]  |         |     |            |[.]  |         |     |
+//!  +-----+         +-----+            +-----+         +-----+
+//!     |            [-3, 1]               |            [-3, 1]
+//!     |                                   |
+//!   +---+                               +---+
+//!   | x | [1, 2]                        | x | [1, 2]
+//!   +---+                               +---+
+//!
+//!  (c) Bottom-up evaluation: Step 3    (d) Bottom-up evaluation: Step 4
+//! ```
+//!
+//! ## Top-down constraint propagation using inverse semantics
+//!
+//! ```text
+//!  [-2, 5] ∩ [4, 4] = [4, 4]                  [4, 4]
+//!          +-----+                            +-----+
+//!     +----|  +  |----+                  +----|  +  |----+
+//!     |    |     |    |                  |    |     |    |
+//!     |    +-----+    |                  |    +-----+    |
+//!     |               |                  |               |
+//!  +-----+         +-----+            +-----+         +-----+
+//!  |   2 | [1, 4]  |  y  |            |   2 | [1, 4]  |  y  | [0, 1]*
+//!  |[.]  |         |     |            |[.]  |         |     |
+//!  +-----+         +-----+            +-----+         +-----+
+//!     |            [-3, 1]               |
+//!     |                                   |
+//!   +---+                               +---+
+//!   | x | [1, 2]                        | x | [1, 2]
+//!   +---+                               +---+
+//!
+//!  (a) Top-down propagation: Step 1    (b) Top-down propagation: Step 2
+//!
+//!  [1 - 3, 4 + 1] = [-2, 5]
+//!          +-----+                            +-----+
+//!     +----|  +  |----+                  +----|  +  |----+
+//!     |    |     |    |                  |    |     |    |
+//!     |    +-----+    |                  |    +-----+    |
+//!     |               |                  |               |
+//!  +-----+         +-----+            +-----+         +-----+
+//!  |   2 | [3, 4]**|  y  |            |   2 | [3, 4]  |  y  |
+//!  |[.]  |         |     |            |[.]  |         |     |
+//!  +-----+         +-----+            +-----+         +-----+
+//!     |            [0, 1]                |            [-3, 1]
+//!     |                                   |
+//!   +---+                               +---+
+//!   | x | [1, 2]                        | x | [sqrt(3), 2]***
+//!   +---+                               +---+
+//!
+//!  (c) Top-down propagation: Step 3    (d) Top-down propagation: Step 4
+//!
+//!  * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1]
+//!  ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4]
+//!  *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2]
+//! ```
 
 use std::collections::HashSet;
 use std::fmt::{Display, Formatter};
@@ -39,84 +162,6 @@ use petgraph::stable_graph::{DefaultIx, StableGraph};
 use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef};
 use petgraph::Outgoing;
 
-// Interval arithmetic provides a way to perform mathematical operations on
-// intervals, which represent a range of possible values rather than a single
-// point value. This allows for the propagation of ranges through mathematical
-// operations, and can be used to compute bounds for a complicated expression.
-// The key idea is that by breaking down a complicated expression into simpler
-// terms, and then combining the bounds for those simpler terms, one can
-// obtain bounds for the overall expression.
-//
-// For example, consider a mathematical expression such as x^2 + y = 4. Since
-// it would be a binary tree in [PhysicalExpr] notation, this type of an
-// hierarchical computation is well-suited for a graph based implementation.
-// In such an implementation, an equation system f(x) = 0 is represented by a
-// directed acyclic expression graph (DAEG).
-//
-// In order to use interval arithmetic to compute bounds for this expression,
-// one would first determine intervals that represent the possible values of x
-// and y. Let's say that the interval for x is [1, 2] and the interval for y
-// is [-3, 1]. In the chart below, you can see how the computation takes place.
-//
-// This way of using interval arithmetic to compute bounds for a complex
-// expression by combining the bounds for the constituent terms within the
-// original expression allows us to reason about the range of possible values
-// of the expression. This information later can be used in range pruning of
-// the provably unnecessary parts of `RecordBatch`es.
-//
-// References
-// 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval
-// Arithmetic Based Approach, Chapter 4. Stanford University, 2015.
-// 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966.
-// 3 - F. Messine, "Deterministic global optimization using interval constraint
-// propagation techniques," RAIRO-Operations Research, vol. 38, no.
04, -// pp. 277{293, 2004. -// -// ``` text -// Computing bounds for an expression using interval arithmetic. Constraint propagation through a top-down evaluation of the expression -// graph using inverse semantics. -// -// [-2, 5] ∩ [4, 4] = [4, 4] [4, 4] -// +-----+ +-----+ +-----+ +-----+ -// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ -// | | | | | | | | | | | | | | | | -// | +-----+ | | +-----+ | | +-----+ | | +-----+ | -// | | | | | | | | -// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -// | 2 | | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]* -// |[.] | | | |[.] | | | |[.] | | | |[.] | | | -// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -// | | | [-3, 1] | -// | | | | -// +---+ +---+ +---+ +---+ -// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [1, 2] -// +---+ +---+ +---+ +---+ -// -// (a) Bottom-up evaluation: Step1 (b) Bottom up evaluation: Step2 (a) Top-down propagation: Step1 (b) Top-down propagation: Step2 -// -// [1 - 3, 4 + 1] = [-2, 5] [1 - 3, 4 + 1] = [-2, 5] -// +-----+ +-----+ +-----+ +-----+ -// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ -// | | | | | | | | | | | | | | | | -// | +-----+ | | +-----+ | | +-----+ | | +-----+ | -// | | | | | | | | -// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -// | 2 |[1, 4] | y | | 2 |[1, 4] | y | | 2 |[3, 4]** | y | | 2 |[1, 4] | y | -// |[.] | | | |[.] | | | |[.] | | | |[.] | | | -// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -// | [-3, 1] | [-3, 1] | [0, 1] | [-3, 1] -// | | | | -// +---+ +---+ +---+ +---+ -// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [sqrt(3), 2]*** -// +---+ +---+ +---+ +---+ -// -// (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4 (c) Top-down propagation: Step3 (d) Top-down propagation: Step4 -// -// * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1] -// ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4] -// *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2] -// ``` - /// This object implements a directed acyclic expression graph (DAEG) that /// is used to compute ranges for expressions through interval arithmetic. #[derive(Clone, Debug)] @@ -125,18 +170,6 @@ pub struct ExprIntervalGraph { root: NodeIndex, } -impl ExprIntervalGraph { - /// Estimate size of bytes including `Self`. - pub fn size(&self) -> usize { - let node_memory_usage = self.graph.node_count() - * (size_of::<ExprIntervalGraphNode>() + size_of::<NodeIndex>()); - let edge_memory_usage = - self.graph.edge_count() * (size_of::<usize>() + size_of::<NodeIndex>() * 2); - - size_of_val(self) + node_memory_usage + edge_memory_usage - } -} - /// This object encapsulates all possible constraint propagation results. #[derive(PartialEq, Debug)] pub enum PropagationResult { @@ -153,6 +186,12 @@ pub struct ExprIntervalGraphNode { interval: Interval, } +impl PartialEq for ExprIntervalGraphNode { + fn eq(&self, other: &Self) -> bool { + self.expr.eq(&other.expr) + } +} + impl Display for ExprIntervalGraphNode { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.expr) @@ -160,7 +199,7 @@ impl Display for ExprIntervalGraphNode { } impl ExprIntervalGraphNode { - /// Constructs a new DAEG node with an [-∞, ∞] range. + /// Constructs a new DAEG node with an `[-∞, ∞]` range. 
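+    /// Endpoints of an unbounded interval are represented by `NULL` values of
+    /// the given datatype.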
    pub fn new_unbounded(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> {
         Interval::make_unbounded(dt)
             .map(|interval| ExprIntervalGraphNode { expr, interval })
@@ -178,7 +217,7 @@ impl ExprIntervalGraphNode {
 
     /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`]
     /// object. Literals are created with definite, singleton intervals while
-    /// any other expression starts with an indefinite interval ([-∞, ∞]).
+    /// any other expression starts with an indefinite interval (`[-∞, ∞]`).
     pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
         let expr = Arc::clone(&node.expr);
         if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
@@ -192,30 +231,24 @@ impl ExprIntervalGraphNode {
     }
 }
 
-impl PartialEq for ExprIntervalGraphNode {
-    fn eq(&self, other: &Self) -> bool {
-        self.expr.eq(&other.expr)
-    }
-}
-
 /// This function refines intervals `left_child` and `right_child` by applying
 /// constraint propagation through `parent` via operation. The main idea is
 /// that we can shrink ranges of variables x and y using parent interval p.
 ///
-/// Assuming that x,y and p has ranges [xL, xU], [yL, yU], and [pL, pU], we
+/// Assuming that x, y, and p have ranges `[xL, xU]`, `[yL, yU]`, and `[pL, pU]`, we
 /// apply the following operations:
 /// - For plus operation, specifically, we would first do
-///     - [xL, xU] <- ([pL, pU] - [yL, yU]) ∩ [xL, xU], and then
-///     - [yL, yU] <- ([pL, pU] - [xL, xU]) ∩ [yL, yU].
+///     - `[xL, xU]` <- (`[pL, pU]` - `[yL, yU]`) ∩ `[xL, xU]`, and then
+///     - `[yL, yU]` <- (`[pL, pU]` - `[xL, xU]`) ∩ `[yL, yU]`.
 /// - For minus operation, specifically, we would first do
-///     - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then
-///     - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU].
+///     - `[xL, xU]` <- (`[yL, yU]` + `[pL, pU]`) ∩ `[xL, xU]`, and then
+///     - `[yL, yU]` <- (`[xL, xU]` - `[pL, pU]`) ∩ `[yL, yU]`.
 /// - For multiplication operation, specifically, we would first do
-///     - [xL, xU] <- ([pL, pU] / [yL, yU]) ∩ [xL, xU], and then
-///     - [yL, yU] <- ([pL, pU] / [xL, xU]) ∩ [yL, yU].
+///     - `[xL, xU]` <- (`[pL, pU]` / `[yL, yU]`) ∩ `[xL, xU]`, and then
+///     - `[yL, yU]` <- (`[pL, pU]` / `[xL, xU]`) ∩ `[yL, yU]`.
 /// - For division operation, specifically, we would first do
-///     - [xL, xU] <- ([yL, yU] * [pL, pU]) ∩ [xL, xU], and then
-///     - [yL, yU] <- ([xL, xU] / [pL, pU]) ∩ [yL, yU].
+///     - `[xL, xU]` <- (`[yL, yU]` * `[pL, pU]`) ∩ `[xL, xU]`, and then
+///     - `[yL, yU]` <- (`[xL, xU]` / `[pL, pU]`) ∩ `[yL, yU]`.
 pub fn propagate_arithmetic(
     op: &Operator,
     parent: &Interval,
@@ -361,18 +394,30 @@ impl ExprIntervalGraph {
         self.graph.node_count()
     }
 
+    /// Estimates the size of this object in bytes, including `Self`.
+    pub fn size(&self) -> usize {
+        let node_memory_usage = self.graph.node_count()
+            * (size_of::<ExprIntervalGraphNode>() + size_of::<NodeIndex>());
+        let edge_memory_usage =
+            self.graph.edge_count() * (size_of::<usize>() + size_of::<NodeIndex>() * 2);
+
+        size_of_val(self) + node_memory_usage + edge_memory_usage
+    }
+
     // Sometimes, we do not want to calculate and/or propagate intervals all
     // way down to leaf expressions. For example, assume that we have a
     // `SymmetricHashJoin` which has a child with an output ordering like:
     //
+    // ```text
     // PhysicalSortExpr {
     //     expr: BinaryExpr('a', +, 'b'),
     //     sort_option: ..
     // }
+    // ```
     //
-    // i.e. its output order comes from a clause like "ORDER BY a + b".
In such - // a case, we must calculate the interval for the BinaryExpr('a', +, 'b') - // instead of the columns inside this BinaryExpr, because this interval + // i.e. its output order comes from a clause like `ORDER BY a + b`. In such + // a case, we must calculate the interval for the `BinaryExpr(a, +, b)` + // instead of the columns inside this `BinaryExpr`, because this interval // decides whether we prune or not. Therefore, children `PhysicalExpr`s of // this `BinaryExpr` may be pruned for performance. The figure below // explains this example visually. @@ -510,9 +555,6 @@ impl ExprIntervalGraph { /// Computes bounds for an expression using interval arithmetic via a /// bottom-up traversal. /// - /// # Arguments - /// * `leaf_bounds` - &[(usize, Interval)]. Provide NodeIndex, Interval tuples for leaf variables. - /// /// # Examples /// /// ``` @@ -570,7 +612,7 @@ impl ExprIntervalGraph { self.graph[node].expr.evaluate_bounds(&children_intervals)?; } } - Ok(&self.graph[self.root].interval) + Ok(self.graph[self.root].interval()) } /// Updates/shrinks bounds for leaf expressions using interval arithmetic @@ -579,8 +621,6 @@ impl ExprIntervalGraph { &mut self, given_range: Interval, ) -> Result<PropagationResult> { - let mut bfs = Bfs::new(&self.graph, self.root); - // Adjust the root node with the given range: if let Some(interval) = self.graph[self.root].interval.intersect(given_range)? { self.graph[self.root].interval = interval; @@ -588,6 +628,8 @@ impl ExprIntervalGraph { return Ok(PropagationResult::Infeasible); } + let mut bfs = Bfs::new(&self.graph, self.root); + while let Some(node) = bfs.next(&self.graph) { let neighbors = self.graph.neighbors_directed(node, Outgoing); let mut children = neighbors.collect::<Vec<_>>(); diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index b68d10905cab..0a448fa6a2e9 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -40,6 +40,7 @@ pub mod udf { #[allow(deprecated)] pub use crate::scalar_function::create_physical_expr; } +pub mod statistics; pub mod utils; pub mod window; diff --git a/datafusion/physical-expr/src/statistics/mod.rs b/datafusion/physical-expr/src/statistics/mod.rs new file mode 100644 index 000000000000..02897e059457 --- /dev/null +++ b/datafusion/physical-expr/src/statistics/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Statistics and constraint propagation library + +pub mod stats_solver; diff --git a/datafusion/physical-expr/src/statistics/stats_solver.rs b/datafusion/physical-expr/src/statistics/stats_solver.rs new file mode 100644 index 000000000000..ec58076caf3b --- /dev/null +++ b/datafusion/physical-expr/src/statistics/stats_solver.rs @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::expressions::Literal; +use crate::intervals::cp_solver::PropagationResult; +use crate::physical_expr::PhysicalExpr; +use crate::utils::{build_dag, ExprTreeNode}; + +use arrow::datatypes::{DataType, Schema}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::statistics::Distribution; +use datafusion_expr_common::interval_arithmetic::Interval; + +use petgraph::adj::DefaultIx; +use petgraph::prelude::Bfs; +use petgraph::stable_graph::{NodeIndex, StableGraph}; +use petgraph::visit::DfsPostOrder; +use petgraph::Outgoing; + +/// This object implements a directed acyclic expression graph (DAEG) that +/// is used to compute statistics/distributions for expressions hierarchically. +#[derive(Clone, Debug)] +pub struct ExprStatisticsGraph { + graph: StableGraph<ExprStatisticsGraphNode, usize>, + root: NodeIndex, +} + +/// This is a node in the DAEG; it encapsulates a reference to the actual +/// [`PhysicalExpr`] as well as its statistics/distribution. +#[derive(Clone, Debug)] +pub struct ExprStatisticsGraphNode { + expr: Arc<dyn PhysicalExpr>, + dist: Distribution, +} + +impl ExprStatisticsGraphNode { + /// Constructs a new DAEG node based on the given interval with a + /// `Uniform` distribution. + fn new_uniform(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Result<Self> { + Distribution::new_uniform(interval) + .map(|dist| ExprStatisticsGraphNode { expr, dist }) + } + + /// Constructs a new DAEG node with a `Bernoulli` distribution having an + /// unknown success probability. + fn new_bernoulli(expr: Arc<dyn PhysicalExpr>) -> Result<Self> { + Distribution::new_bernoulli(ScalarValue::Float64(None)) + .map(|dist| ExprStatisticsGraphNode { expr, dist }) + } + + /// Constructs a new DAEG node with a `Generic` distribution having no + /// definite summary statistics. + fn new_generic(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> { + let interval = Interval::make_unbounded(dt)?; + let dist = Distribution::new_from_interval(interval)?; + Ok(ExprStatisticsGraphNode { expr, dist }) + } + + /// Get the [`Distribution`] object representing the statistics of the + /// expression. + pub fn distribution(&self) -> &Distribution { + &self.dist + } + + /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`] + /// object. 
Literals are created with `Uniform` distributions with a
+    /// definite, singleton interval. Expressions with a `Boolean` data type
+    /// result in a `Bernoulli` distribution with an unknown success probability.
+    /// Any other expression starts with an `Unknown` distribution with an
+    /// indefinite range (i.e. `[-∞, ∞]`).
+    pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
+        let expr = Arc::clone(&node.expr);
+        if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
+            let value = literal.value();
+            Interval::try_new(value.clone(), value.clone())
+                .and_then(|interval| Self::new_uniform(expr, interval))
+        } else {
+            expr.data_type(schema).and_then(|dt| {
+                if dt.eq(&DataType::Boolean) {
+                    Self::new_bernoulli(expr)
+                } else {
+                    Self::new_generic(expr, &dt)
+                }
+            })
+        }
+    }
+}
+
+impl ExprStatisticsGraph {
+    pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
+        // Build the full graph:
+        let (root, graph) = build_dag(expr, &|node| {
+            ExprStatisticsGraphNode::make_node(node, schema)
+        })?;
+        Ok(Self { graph, root })
+    }
+
+    /// This function assigns given distributions to expressions in the DAEG.
+    /// The argument `assignments` associates indices of sought expressions
+    /// with their corresponding new distributions.
+    pub fn assign_statistics(&mut self, assignments: &[(usize, Distribution)]) {
+        for (index, stats) in assignments {
+            let node_index = NodeIndex::from(*index as DefaultIx);
+            self.graph[node_index].dist = stats.clone();
+        }
+    }
+
+    /// Computes statistics/distributions for an expression via a bottom-up
+    /// traversal.
+    pub fn evaluate_statistics(&mut self) -> Result<&Distribution> {
+        let mut dfs = DfsPostOrder::new(&self.graph, self.root);
+        while let Some(idx) = dfs.next(&self.graph) {
+            let neighbors = self.graph.neighbors_directed(idx, Outgoing);
+            let mut children_statistics = neighbors
+                .map(|child| self.graph[child].distribution())
+                .collect::<Vec<_>>();
+            // Note that all distributions are assumed to be independent.
+            if !children_statistics.is_empty() {
+                // Reverse to align with `PhysicalExpr`'s children:
+                children_statistics.reverse();
+                self.graph[idx].dist = self.graph[idx]
+                    .expr
+                    .evaluate_statistics(&children_statistics)?;
+            }
+        }
+        Ok(self.graph[self.root].distribution())
+    }
+
+    /// Runs a propagation mechanism in a top-down manner to update statistics
+    /// of leaf nodes.
+    pub fn propagate_statistics(
+        &mut self,
+        given_stats: Distribution,
+    ) -> Result<PropagationResult> {
+        // Adjust the root node with the given statistics:
+        let root_range = self.graph[self.root].dist.range()?;
+        let given_range = given_stats.range()?;
+        if let Some(interval) = root_range.intersect(&given_range)? {
+            if interval != root_range {
+                // If the given statistics enable us to obtain a more precise
+                // range for the root, update it:
+                let subset = root_range.contains(given_range)?;
+                self.graph[self.root].dist = if subset == Interval::CERTAINLY_TRUE {
+                    // The given statistics are strictly more informative, so use
+                    // them as is:
+                    given_stats
+                } else {
+                    // Intersecting ranges gives us a more precise range:
+                    Distribution::new_from_interval(interval)?
+                };
+            }
+        } else {
+            return Ok(PropagationResult::Infeasible);
+        }
+
+        let mut bfs = Bfs::new(&self.graph, self.root);
+
+        while let Some(node) = bfs.next(&self.graph) {
+            let neighbors = self.graph.neighbors_directed(node, Outgoing);
+            let mut children = neighbors.collect::<Vec<_>>();
+            // If the current expression is a leaf, its statistics are now final.
+ // So, just continue with the propagation procedure: + if children.is_empty() { + continue; + } + // Reverse to align with `PhysicalExpr`'s children: + children.reverse(); + let children_stats = children + .iter() + .map(|child| self.graph[*child].distribution()) + .collect::<Vec<_>>(); + let node_statistics = self.graph[node].distribution(); + let propagated_statistics = self.graph[node] + .expr + .propagate_statistics(node_statistics, &children_stats)?; + if let Some(propagated_stats) = propagated_statistics { + for (child_idx, stats) in children.into_iter().zip(propagated_stats) { + self.graph[child_idx].dist = stats; + } + } else { + // The constraint is infeasible, report: + return Ok(PropagationResult::Infeasible); + } + } + Ok(PropagationResult::Success) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::expressions::{binary, try_cast, Column}; + use crate::intervals::cp_solver::PropagationResult; + use crate::statistics::stats_solver::ExprStatisticsGraph; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr_common::interval_arithmetic::Interval; + use datafusion_expr_common::operator::Operator; + use datafusion_expr_common::statistics::Distribution; + use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + pub fn binary_expr( + left: Arc<dyn PhysicalExpr>, + op: Operator, + right: Arc<dyn PhysicalExpr>, + schema: &Schema, + ) -> Result<Arc<dyn PhysicalExpr>> { + let left_type = left.data_type(schema)?; + let right_type = right.data_type(schema)?; + let binary_type_coercer = BinaryTypeCoercer::new(&left_type, &op, &right_type); + let (lhs, rhs) = binary_type_coercer.get_input_types()?; + + let left_expr = try_cast(left, schema, lhs)?; + let right_expr = try_cast(right, schema, rhs)?; + binary(left_expr, op, right_expr, schema) + } + + #[test] + fn test_stats_integration() -> Result<()> { + let schema = &Schema::new(vec![ + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + Field::new("c", DataType::Float64, false), + Field::new("d", DataType::Float64, false), + ]); + + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + let c = Arc::new(Column::new("c", 2)) as _; + let d = Arc::new(Column::new("d", 3)) as _; + + let left = binary_expr(a, Operator::Plus, b, schema)?; + let right = binary_expr(c, Operator::Minus, d, schema)?; + let expr = binary_expr(left, Operator::Eq, right, schema)?; + + let mut graph = ExprStatisticsGraph::try_new(expr, schema)?; + // 2, 5 and 6 are BinaryExpr + graph.assign_statistics(&[ + ( + 0usize, + Distribution::new_uniform(Interval::make(Some(0.), Some(1.))?)?, + ), + ( + 1usize, + Distribution::new_uniform(Interval::make(Some(0.), Some(2.))?)?, + ), + ( + 3usize, + Distribution::new_uniform(Interval::make(Some(1.), Some(3.))?)?, + ), + ( + 4usize, + Distribution::new_uniform(Interval::make(Some(1.), Some(5.))?)?, + ), + ]); + let ev_stats = graph.evaluate_statistics()?; + assert_eq!( + ev_stats, + &Distribution::new_bernoulli(ScalarValue::Float64(None))? 
+ ); + + let one = ScalarValue::new_one(&DataType::Float64)?; + assert_eq!( + graph.propagate_statistics(Distribution::new_bernoulli(one)?)?, + PropagationResult::Success + ); + Ok(()) + } +} diff --git a/test-utils/src/array_gen/string.rs b/test-utils/src/array_gen/string.rs index e2a983612b8b..ac659ae67bc0 100644 --- a/test-utils/src/array_gen/string.rs +++ b/test-utils/src/array_gen/string.rs @@ -97,7 +97,7 @@ fn random_string(rng: &mut StdRng, max_len: usize) -> String { let len = rng.gen_range(1..=max_len); rng.sample_iter::<char, _>(rand::distributions::Standard) .take(len) - .collect::<String>() + .collect() } } }