feat: support merge for Distribution #15296

Closed · wants to merge 3 commits
137 changes: 137 additions & 0 deletions datafusion/expr-common/src/statistics.rs
@@ -857,6 +857,143 @@ pub fn compute_variance(
    ScalarValue::try_from(target_type)
}

/// Merges two distributions into a single distribution that represents their combined statistics.
/// This creates a more general distribution that approximates the mixture of the input distributions.
pub fn merge_distributions(a: &Distribution, b: &Distribution) -> Result<Distribution> {
    let range_a = a.range()?;
    let range_b = b.range()?;

    // Determine data type and create combined range
    let combined_range = if range_a.is_unbounded() || range_b.is_unbounded() {
@xudong963 (Member, Author) commented on Mar 19, 2025:

Great, one concern is that I found that Interval::union only works with intervals of the same data type.

It seems we could loosen that requirement: for example, Int64 with Int32, or int with float, could also be unioned.
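A minimal sketch of that loosening (hypothetical, not part of the diff): widen both intervals to a common numeric type before unioning. It assumes an Interval::union(&self, &other) -> Result<Interval> signature, which should be checked against interval_arithmetic.rs; ScalarValue::cast_to and Interval::try_new are used the same way as in the diff below.

// Hypothetical helper: cast both intervals to a shared type so that
// Interval::union (which requires matching data types) can be applied.
fn union_as_f64(a: &Interval, b: &Interval) -> Result<Interval> {
    // Assumption: Float64 is an acceptable common type for the numeric
    // cases mentioned above (Int64 with Int32, int with float).
    let common = DataType::Float64;
    let widen = |iv: &Interval| -> Result<Interval> {
        Interval::try_new(
            iv.lower().cast_to(&common)?,
            iv.upper().cast_to(&common)?,
        )
    };
    widen(a)?.union(&widen(b)?)
}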

        Interval::make_unbounded(&range_a.data_type())?
    } else {
        // Take the widest possible range conservatively
        let lower_a = range_a.lower();
        let lower_b = range_b.lower();
        let upper_a = range_a.upper();
        let upper_b = range_b.upper();

        let combined_lower = if lower_a.lt(lower_b) {
            lower_a.clone()
        } else {
            lower_b.clone()
        };

        let combined_upper = if upper_a.gt(upper_b) {
            upper_a.clone()
        } else {
            upper_b.clone()
        };

        Interval::try_new(combined_lower, combined_upper)?
    };

    // Calculate weights for the mixture distribution
Contributor commented:

What does "mixture distribution" mean in this context?

It seems like this code weights the input distributions by their number of distinct values (cardinality), which seems not right. For example, if we have two inputs:

  1. 1M rows, 3 distinct values
  2. 10 rows, 10 distinct values

I think this code is going to assume the mean is close to the second input's, even though there are only 10 values.
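For concreteness: the cardinality-based weights used below would come out to 3/13 ≈ 0.23 for the 1M-row input and 10/13 ≈ 0.77 for the 10-row input, so the merged mean would sit close to the second input's mean; weighting by row count instead would give about 1,000,000/1,000,010 ≈ 0.99999 and 10/1,000,010 ≈ 0.00001.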

@xudong963 (Member, Author) replied:

Your point is correct.

IMO, the best way to compute the weight is based on the count of each interval, but that count is unknown.

After thinking about it, I have a new idea: maybe we can use the variance to approximate the weight. That is, lower variance generally indicates more samples:

let (weight_a, weight_b) = {
    // Lower variance generally indicates more samples
    let var_a = self.variance()?.cast_to(&DataType::Float64)?;
    let var_b = other.variance()?.cast_to(&DataType::Float64)?;
    
    match (var_a, var_b) {
        (ScalarValue::Float64(Some(va)), ScalarValue::Float64(Some(vb))) => {
            // Weighting inversely by variance (with safeguards against division by zero)
            let va_safe = va.max(f64::EPSILON);
            let vb_safe = vb.max(f64::EPSILON);
            let wa = 1.0 / va_safe;
            let wb = 1.0 / vb_safe;
            let total = wa + wb;
            (wa / total, wb / total)
        }
        _ => (0.5, 0.5)  // Fall back to equal weights
    }
};
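As a quick sanity check of this proposal: with var_a = 0.1 and var_b = 10.0, the inverse-variance weights are wa = 10.0 and wb = 0.1, so (weight_a, weight_b) ≈ (0.99, 0.01) and the tighter distribution dominates the merge.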

@xudong963 (Member, Author) added:

Also pinging @kosiew: do you have any thoughts on the new way to compute the weight?

    let (weight_a, weight_b) = match (range_a.cardinality(), range_b.cardinality()) {
        (Some(ca), Some(cb)) => {
            let total = (ca + cb) as f64;
            ((ca as f64) / total, (cb as f64) / total)
        }
        _ => (0.5, 0.5), // Equal weights if cardinalities not available
    };

    // Get the original statistics
    let mean_a = a.mean()?;
    let mean_b = b.mean()?;
    let median_a = a.median()?;
    let median_b = b.median()?;
    let var_a = a.variance()?;
    let var_b = b.variance()?;

    // Always use Float64 for intermediate calculations to avoid truncation
    // I assume that the target type is always numeric
    // Todo: maybe we can keep all `ScalarValue` as `Float64` in `Distribution`?
    let calc_type = DataType::Float64;

    // Create weight scalars using Float64 to avoid truncation
    let weight_a_scalar = ScalarValue::from(weight_a);
    let weight_b_scalar = ScalarValue::from(weight_b);

    // Calculate combined mean
    let combined_mean = if mean_a.is_null() || mean_b.is_null() {
        if mean_a.is_null() {
            mean_b.clone()
        } else {
            mean_a.clone()
        }
    } else {
        // Cast to Float64 for calculation
        let mean_a_f64 = mean_a.cast_to(&calc_type)?;
        let mean_b_f64 = mean_b.cast_to(&calc_type)?;

        // Calculate weighted mean
        mean_a_f64
            .mul_checked(&weight_a_scalar)?
            .add_checked(&mean_b_f64.mul_checked(&weight_b_scalar)?)?
    };

    // Calculate combined median
    let combined_median = if median_a.is_null() || median_b.is_null() {
        if median_a.is_null() {
            median_b
        } else {
            median_a
        }
    } else {
        // Cast to Float64 for calculation
        let median_a_f64 = median_a.cast_to(&calc_type)?;
        let median_b_f64 = median_b.cast_to(&calc_type)?;

        // Calculate weighted median
        median_a_f64
            .mul_checked(&weight_a_scalar)?
            .add_checked(&median_b_f64.mul_checked(&weight_b_scalar)?)?
    };

    // Calculate combined variance
    let combined_variance = if var_a.is_null() || var_b.is_null() {
        if var_a.is_null() {
            var_b
        } else {
            var_a
        }
    } else {
        // Cast to Float64 for calculation
        let var_a_f64 = var_a.cast_to(&calc_type)?;
        let var_b_f64 = var_b.cast_to(&calc_type)?;

        let weighted_var_a = var_a_f64.mul_checked(&weight_a_scalar)?;
        let weighted_var_b = var_b_f64.mul_checked(&weight_b_scalar)?;

        // Add cross-term if both means are available
        if !mean_a.is_null() && !mean_b.is_null() {
            // Cast means to Float64 for calculation
            let mean_a_f64 = mean_a.cast_to(&calc_type)?;
            let mean_b_f64 = mean_b.cast_to(&calc_type)?;

            let mean_diff = mean_a_f64.sub_checked(&mean_b_f64)?;
            let mean_diff_squared = mean_diff.mul_checked(&mean_diff)?;
            let cross_term = weight_a_scalar
                .mul_checked(&weight_b_scalar)?
                .mul_checked(&mean_diff_squared)?;

            weighted_var_a
                .add_checked(&weighted_var_b)?
                .add_checked(&cross_term)?
        } else {
            weighted_var_a.add_checked(&weighted_var_b)?
        }
    };

    // Create a Generic distribution with the combined statistics
    Distribution::new_generic(
        combined_mean,
        combined_median,
        combined_variance,
        combined_range,
    )
}

#[cfg(test)]
mod tests {
    use super::{
        …
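For reference, the combined mean and variance computed by merge_distributions are the standard moments of a two-component mixture with weights $w_a + w_b = 1$:

$\mu = w_a \mu_a + w_b \mu_b$

$\sigma^2 = w_a \sigma_a^2 + w_b \sigma_b^2 + w_a w_b (\mu_a - \mu_b)^2$

The last term is the cross_term in the code. The weighted median, by contrast, is only a heuristic: the median of a mixture is generally not the weighted average of the component medians.

A minimal usage sketch (hypothetical, not part of the diff; Distribution::new_uniform and Interval::make are assumed from datafusion-expr-common):

// Hypothetical usage; the constructor calls are assumptions, not from this PR.
fn demo() -> Result<()> {
    let a = Distribution::new_uniform(Interval::make(Some(0_i64), Some(10_i64))?)?;
    let b = Distribution::new_uniform(Interval::make(Some(5_i64), Some(20_i64))?)?;
    let merged = merge_distributions(&a, &b)?;
    // Per the bounded branch above, the merged range is the widest span: [0, 20].
    assert_eq!(merged.range()?, Interval::make(Some(0_i64), Some(20_i64))?);
    Ok(())
}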