diff --git a/datafusion/expr-common/src/statistics.rs b/datafusion/expr-common/src/statistics.rs index 7e0bc88087ef..98ddea850f3b 100644 --- a/datafusion/expr-common/src/statistics.rs +++ b/datafusion/expr-common/src/statistics.rs @@ -203,6 +203,138 @@ impl Distribution { }; Ok(dt) } + + /// Merges two distributions into a single distribution that represents their combined statistics. + /// This creates a more general distribution that approximates the mixture of the input distributions. + /// + /// # Important Notes + /// + /// - The resulting mean, median, and variance are approximations of the mixture + /// distribution parameters. They are calculated using weighted averages based on + /// the input distributions. Users should not make definitive assumptions based on these values. + /// + /// - The range of the merged distribution is computed as the union of the input ranges + /// and its accuracy directly depends on the accuracy of the input ranges. + /// + /// - The result is always a [`Generic`] distribution, which may lose some specific + /// properties of the original distribution types. + /// + /// # Returns + /// + /// Returns a new [`Distribution`] that approximates the combined statistics of the + /// input distributions. + pub fn merge(&self, other: &Self) -> Result { + let range_a = self.range()?; + let range_b = other.range()?; + + // Determine data type and create combined range + let combined_range = range_a.union(&range_b)?; + + // Calculate weights for the mixture distribution + let (weight_a, weight_b) = match (range_a.cardinality(), range_b.cardinality()) { + (Some(ca), Some(cb)) => { + let total = (ca + cb) as f64; + ((ca as f64) / total, (cb as f64) / total) + } + _ => (0.5, 0.5), // Equal weights if cardinalities not available + }; + + // Get the original statistics + let mean_a = self.mean()?; + let mean_b = other.mean()?; + let median_a = self.median()?; + let median_b = other.median()?; + let var_a = self.variance()?; + let var_b = other.variance()?; + + // Always use Float64 for intermediate calculations to avoid truncation + // I assume that the target type is always numeric + // Todo: maybe we can keep all `ScalarValue` as `Float64` in `Distribution`? + let calc_type = DataType::Float64; + + // Create weight scalars using Float64 to avoid truncation + let weight_a_scalar = ScalarValue::from(weight_a); + let weight_b_scalar = ScalarValue::from(weight_b); + + // Calculate combined mean + let combined_mean = if mean_a.is_null() || mean_b.is_null() { + if mean_a.is_null() { + mean_b.clone() + } else { + mean_a.clone() + } + } else { + // Cast to Float64 for calculation + let mean_a_f64 = mean_a.cast_to(&calc_type)?; + let mean_b_f64 = mean_b.cast_to(&calc_type)?; + + // Calculate weighted mean + mean_a_f64 + .mul_checked(&weight_a_scalar)? + .add_checked(&mean_b_f64.mul_checked(&weight_b_scalar)?)? + }; + + // Calculate combined median + let combined_median = if median_a.is_null() || median_b.is_null() { + if median_a.is_null() { + median_b + } else { + median_a + } + } else { + // Cast to Float64 for calculation + let median_a_f64 = median_a.cast_to(&calc_type)?; + let median_b_f64 = median_b.cast_to(&calc_type)?; + + // Calculate weighted median + median_a_f64 + .mul_checked(&weight_a_scalar)? + .add_checked(&median_b_f64.mul_checked(&weight_b_scalar)?)? + }; + + // Calculate combined variance + let combined_variance = if var_a.is_null() || var_b.is_null() { + if var_a.is_null() { + var_b + } else { + var_a + } + } else { + // Cast to Float64 for calculation + let var_a_f64 = var_a.cast_to(&calc_type)?; + let var_b_f64 = var_b.cast_to(&calc_type)?; + + let weighted_var_a = var_a_f64.mul_checked(&weight_a_scalar)?; + let weighted_var_b = var_b_f64.mul_checked(&weight_b_scalar)?; + + // Add cross-term if both means are available + if !mean_a.is_null() && !mean_b.is_null() { + // Cast means to Float64 for calculation + let mean_a_f64 = mean_a.cast_to(&calc_type)?; + let mean_b_f64 = mean_b.cast_to(&calc_type)?; + + let mean_diff = mean_a_f64.sub_checked(&mean_b_f64)?; + let mean_diff_squared = mean_diff.mul_checked(&mean_diff)?; + let cross_term = weight_a_scalar + .mul_checked(&weight_b_scalar)? + .mul_checked(&mean_diff_squared)?; + + weighted_var_a + .add_checked(&weighted_var_b)? + .add_checked(&cross_term)? + } else { + weighted_var_a.add_checked(&weighted_var_b)? + } + }; + + // Create a Generic distribution with the combined statistics + Distribution::new_generic( + combined_mean, + combined_median, + combined_variance, + combined_range, + ) + } } /// Uniform distribution, represented by its range. If the given range extends