feat: support merge for Distribution #15296

Closed
wants to merge 3 commits into from
Changes from 2 commits
115 changes: 115 additions & 0 deletions datafusion/expr-common/src/statistics.rs
@@ -203,6 +203,121 @@ impl Distribution {
        };
        Ok(dt)
    }

    /// Merges two distributions into a single distribution that represents their combined statistics.
    /// This creates a more general distribution that approximates the mixture of the input distributions.
    pub fn merge(&self, other: &Self) -> Result<Self> {
        let range_a = self.range()?;
        let range_b = other.range()?;

        // Union the two ranges to obtain the combined range
        let combined_range = range_a.union(&range_b)?;

        // Calculate weights for the mixture distribution
        let (weight_a, weight_b) = match (range_a.cardinality(), range_b.cardinality()) {
            (Some(ca), Some(cb)) => {
                // Convert before adding so the sum cannot overflow the integer type
                let total = (ca as f64) + (cb as f64);
                ((ca as f64) / total, (cb as f64) / total)
            }
            _ => (0.5, 0.5), // Equal weights if cardinalities not available
        };

        // Get the original statistics
        let mean_a = self.mean()?;
        let mean_b = other.mean()?;
        let median_a = self.median()?;
        let median_b = other.median()?;
        let var_a = self.variance()?;
        let var_b = other.variance()?;

        // Always use Float64 for intermediate calculations to avoid truncation.
        // This assumes that the target type is always numeric.
        // TODO: maybe we can keep all `ScalarValue` as `Float64` in `Distribution`?
        let calc_type = DataType::Float64;

@berkaysynnada berkaysynnada Mar 21, 2025

Why Float64? Decimals can carry higher precision. We have thought about this a lot: relaxing the data type is not a good approach, either during computations or when representing intermediate or final results. Rather than presuming a target type, we need to rely on the data type of the original quantity and its standard coercions.
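
To make the precision concern concrete, here is a minimal standalone sketch. It uses plain Rust integers rather than `ScalarValue`, and the helper name `round_trip_through_f64` is hypothetical. An `f64` has a 53-bit significand, so a 64-bit integer (and, by the same argument, a wide decimal value) can change when it is cast to `Float64` and back.

```rust
/// Hypothetical helper: casts an integer to f64 and back again.
fn round_trip_through_f64(value: i64) -> i64 {
    (value as f64) as i64
}

fn main() {
    // 2^53 + 1 is the smallest positive integer that f64 cannot represent exactly.
    let exact: i64 = (1_i64 << 53) + 1;
    let lossy = round_trip_through_f64(exact);

    // The round trip silently drops the low bit.
    assert_ne!(exact, lossy);
    println!("{exact} became {lossy} after the Float64 round trip");

    // Small values survive unchanged, which is why the problem is easy to miss.
    assert_eq!(round_trip_through_f64(42), 42);
}
```

Statistics that start out as large integers or high-precision decimals could therefore be altered by an intermediate `Float64` cast even when no arithmetic overflow occurs.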


        // Create weight scalars using Float64 to avoid truncation
        let weight_a_scalar = ScalarValue::from(weight_a);
        let weight_b_scalar = ScalarValue::from(weight_b);

        // Calculate combined mean
        let combined_mean = if mean_a.is_null() || mean_b.is_null() {
            if mean_a.is_null() {
                mean_b.clone()
            } else {
                mean_a.clone()
            }
        } else {
            // Cast to Float64 for calculation
            let mean_a_f64 = mean_a.cast_to(&calc_type)?;
            let mean_b_f64 = mean_b.cast_to(&calc_type)?;

            // Calculate weighted mean
            mean_a_f64
                .mul_checked(&weight_a_scalar)?
                .add_checked(&mean_b_f64.mul_checked(&weight_b_scalar)?)?
        };

        // Calculate combined median
        let combined_median = if median_a.is_null() || median_b.is_null() {
            if median_a.is_null() {
                median_b
            } else {
                median_a
            }
        } else {
            // Cast to Float64 for calculation
            let median_a_f64 = median_a.cast_to(&calc_type)?;
            let median_b_f64 = median_b.cast_to(&calc_type)?;

            // Approximate the combined median by the weighted average of the medians
            // (this is a heuristic; it is not the exact median of a mixture in general)
            median_a_f64
                .mul_checked(&weight_a_scalar)?
                .add_checked(&median_b_f64.mul_checked(&weight_b_scalar)?)?
        };

        // Calculate combined variance
        let combined_variance = if var_a.is_null() || var_b.is_null() {
            if var_a.is_null() {
                var_b
            } else {
                var_a
            }
        } else {
            // Cast to Float64 for calculation
            let var_a_f64 = var_a.cast_to(&calc_type)?;
            let var_b_f64 = var_b.cast_to(&calc_type)?;

            let weighted_var_a = var_a_f64.mul_checked(&weight_a_scalar)?;
            let weighted_var_b = var_b_f64.mul_checked(&weight_b_scalar)?;

            // Add cross-term if both means are available
            if !mean_a.is_null() && !mean_b.is_null() {
                // Cast means to Float64 for calculation
                let mean_a_f64 = mean_a.cast_to(&calc_type)?;
                let mean_b_f64 = mean_b.cast_to(&calc_type)?;

                let mean_diff = mean_a_f64.sub_checked(&mean_b_f64)?;
                let mean_diff_squared = mean_diff.mul_checked(&mean_diff)?;
                let cross_term = weight_a_scalar
                    .mul_checked(&weight_b_scalar)?
                    .mul_checked(&mean_diff_squared)?;

                weighted_var_a
                    .add_checked(&weighted_var_b)?
                    .add_checked(&cross_term)?
            } else {
                weighted_var_a.add_checked(&weighted_var_b)?
            }
        };

        // Create a Generic distribution with the combined statistics
        Distribution::new_generic(
            combined_mean,
            combined_median,
            combined_variance,
            combined_range,
        )
    }
}
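
For readers following the arithmetic in `merge`, the mean and variance steps correspond to the standard two-component mixture formulas. The sketch below restates them on plain `f64` values. The `Moments` struct and `merge_moments` function are hypothetical names used only for illustration and are not part of DataFusion. The median is left out on purpose, because the median of a mixture is not, in general, a weighted average of the component medians.

```rust
/// Hypothetical summary of a distribution: just its mean and variance.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Moments {
    mean: f64,
    variance: f64,
}

/// Combines two summaries as a mixture in which `a` has weight `weight_a`
/// and `b` has weight `1 - weight_a`.
fn merge_moments(a: Moments, b: Moments, weight_a: f64) -> Moments {
    let weight_b = 1.0 - weight_a;

    // Mixture mean: weighted average of the component means.
    let mean = weight_a * a.mean + weight_b * b.mean;

    // Mixture variance: weighted variances plus a cross-term that accounts
    // for the distance between the component means.
    let diff = a.mean - b.mean;
    let variance =
        weight_a * a.variance + weight_b * b.variance + weight_a * weight_b * diff * diff;

    Moments { mean, variance }
}

fn main() {
    let a = Moments { mean: 0.0, variance: 1.0 };
    let b = Moments { mean: 2.0, variance: 1.0 };

    let merged = merge_moments(a, b, 0.5);
    // Mean: 0.5 * 0 + 0.5 * 2 = 1; variance: 0.5 + 0.5 + 0.25 * 4 = 2.
    assert_eq!(merged, Moments { mean: 1.0, variance: 2.0 });
    println!("{merged:?}");
}
```

The cross-term is what makes the merged variance larger than either input variance when the two means differ.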

/// Uniform distribution, represented by its range. If the given range extends