Skip to content

Commit 8ee56aa

Browse files
committed
Fixes 17630 by not adding the generated window column twice
1 parent 9e36ec4 commit 8ee56aa

File tree

2 files changed

+65
-14
lines changed

2 files changed

+65
-14
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use crate::physical_plan::{
4343
use crate::prelude::SessionContext;
4444
use std::any::Any;
4545
use std::borrow::Cow;
46-
use std::collections::HashMap;
46+
use std::collections::{HashMap, HashSet};
4747
use std::sync::Arc;
4848

4949
use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
@@ -2023,31 +2023,38 @@ impl DataFrame {
20232023
pub fn with_column(self, name: &str, expr: Expr) -> Result<DataFrame> {
20242024
let window_func_exprs = find_window_exprs([&expr]);
20252025

2026-
let (window_fn_str, plan) = if window_func_exprs.is_empty() {
2027-
(None, self.plan)
2026+
let original_names: HashSet<String> = self
2027+
.plan
2028+
.schema()
2029+
.iter()
2030+
.map(|(_, f)| f.name().clone())
2031+
.collect();
2032+
2033+
// Maybe build window plan
2034+
let plan = if window_func_exprs.is_empty() {
2035+
self.plan
20282036
} else {
2029-
(
2030-
Some(window_func_exprs[0].to_string()),
2031-
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?,
2032-
)
2037+
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
20332038
};
20342039

2035-
let mut col_exists = false;
20362040
let new_column = expr.alias(name);
2041+
let mut col_exists = false;
2042+
20372043
let mut fields: Vec<(Expr, bool)> = plan
20382044
.schema()
20392045
.iter()
20402046
.filter_map(|(qualifier, field)| {
2047+
// Skip new fields introduced by window_plan
2048+
if !original_names.contains(field.name()) {
2049+
return None;
2050+
}
2051+
20412052
if field.name() == name {
20422053
col_exists = true;
20432054
Some((new_column.clone(), true))
20442055
} else {
20452056
let e = col(Column::from((qualifier, field)));
2046-
window_fn_str
2047-
.as_ref()
2048-
.filter(|s| *s == &e.to_string())
2049-
.is_none()
2050-
.then_some((e, self.projection_requires_validation))
2057+
Some((e, self.projection_requires_validation))
20512058
}
20522059
})
20532060
.collect();

datafusion/core/tests/dataframe/mod.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use datafusion_functions_aggregate::expr_fn::{
3838
array_agg, avg, count, count_distinct, max, median, min, sum,
3939
};
4040
use datafusion_functions_nested::make_array::make_array_udf;
41-
use datafusion_functions_window::expr_fn::{first_value, row_number};
41+
use datafusion_functions_window::expr_fn::{first_value, lead, row_number};
4242
use insta::assert_snapshot;
4343
use object_store::local::LocalFileSystem;
4444
use std::collections::HashMap;
@@ -85,6 +85,9 @@ use datafusion_physical_expr::Partitioning;
8585
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
8686
use datafusion_physical_plan::{displayable, ExecutionPlanProperties};
8787

88+
use datafusion::error::Result as DataFusionResult;
89+
use datafusion_functions_window::expr_fn::lag;
90+
8891
// Get string representation of the plan
8992
async fn physical_plan_to_string(df: &DataFrame) -> String {
9093
let physical_plan = df
@@ -152,6 +155,47 @@ async fn test_array_agg_ord_schema() -> Result<()> {
152155
Ok(())
153156
}
154157

158+
#[tokio::test]
159+
async fn with_column_lag() -> DataFusionResult<()> {
160+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
161+
162+
let batch = RecordBatch::try_new(
163+
Arc::new(schema.clone()),
164+
vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
165+
)?;
166+
167+
let ctx = SessionContext::new();
168+
169+
let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;
170+
ctx.register_table("t", Arc::new(provider))?;
171+
172+
// Define test cases: (expr builder, alias name)
173+
let test_cases: Vec<(Box<dyn Fn() -> Expr>, &str)> = vec![
174+
(Box::new(|| lag(col("a"), Some(1), None)), "lag_val"),
175+
(Box::new(|| lead(col("a"), Some(1), None)), "lead_val"),
176+
(Box::new(|| row_number()), "row_num"),
177+
];
178+
179+
for (make_expr, alias) in test_cases {
180+
let df = ctx.table("t").await?;
181+
let expr = make_expr();
182+
let df_with = df.with_column(alias, expr)?;
183+
let df_schema = df_with.schema().clone();
184+
185+
// Assert schema contains the alias column
186+
assert!(
187+
df_schema.has_column_with_unqualified_name(alias),
188+
"Schema does not contain expected column {}",
189+
alias
190+
);
191+
192+
// Schema should have exactly 2 columns: original + alias
193+
assert_eq!(2, df_schema.columns().len());
194+
}
195+
196+
Ok(())
197+
}
198+
155199
#[tokio::test]
156200
async fn test_coalesce_schema() -> Result<()> {
157201
let ctx = SessionContext::new();

0 commit comments

Comments
 (0)