@@ -1416,6 +1416,7 @@ pub(crate) mod tests {
1416
1416
use crate :: datasource:: object_store:: ObjectStoreUrl ;
1417
1417
use crate :: datasource:: physical_plan:: { CsvExec , FileScanConfig , ParquetExec } ;
1418
1418
use crate :: physical_optimizer:: enforce_sorting:: EnforceSorting ;
1419
+ use crate :: physical_optimizer:: sanity_checker:: check_plan_sanity;
1419
1420
use crate :: physical_optimizer:: test_utils:: {
1420
1421
check_integrity, coalesce_partitions_exec, repartition_exec,
1421
1422
} ;
@@ -1426,11 +1427,14 @@ pub(crate) mod tests {
1426
1427
use crate :: physical_plan:: limit:: { GlobalLimitExec , LocalLimitExec } ;
1427
1428
use crate :: physical_plan:: sorts:: sort:: SortExec ;
1428
1429
use crate :: physical_plan:: { displayable, DisplayAs , DisplayFormatType , Statistics } ;
1430
+ use datafusion_execution:: { SendableRecordBatchStream , TaskContext } ;
1431
+ use datafusion_functions_aggregate:: sum:: sum_udaf;
1432
+ use datafusion_physical_expr:: aggregate:: AggregateExprBuilder ;
1429
1433
use datafusion_physical_optimizer:: output_requirements:: OutputRequirements ;
1430
1434
1431
1435
use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
1432
1436
use datafusion_common:: ScalarValue ;
1433
- use datafusion_expr:: Operator ;
1437
+ use datafusion_expr:: { AggregateUDF , Operator } ;
1434
1438
use datafusion_physical_expr:: expressions:: { BinaryExpr , Literal } ;
1435
1439
use datafusion_physical_expr:: {
1436
1440
expressions:: binary, expressions:: lit, LexOrdering , PhysicalSortExpr ,
@@ -1526,8 +1530,8 @@ pub(crate) mod tests {
1526
1530
fn execute (
1527
1531
& self ,
1528
1532
_partition : usize ,
1529
- _context : Arc < crate :: execution :: context :: TaskContext > ,
1530
- ) -> Result < crate :: physical_plan :: SendableRecordBatchStream > {
1533
+ _context : Arc < TaskContext > ,
1534
+ ) -> Result < SendableRecordBatchStream > {
1531
1535
unreachable ! ( ) ;
1532
1536
}
1533
1537
@@ -1643,6 +1647,15 @@ pub(crate) mod tests {
1643
1647
fn aggregate_exec_with_alias (
1644
1648
input : Arc < dyn ExecutionPlan > ,
1645
1649
alias_pairs : Vec < ( String , String ) > ,
1650
+ ) -> Arc < dyn ExecutionPlan > {
1651
+ aggregate_exec_with_aggr_expr_and_alias ( input, vec ! [ ] , alias_pairs)
1652
+ }
1653
+
1654
+ #[ expect( clippy:: type_complexity) ]
1655
+ fn aggregate_exec_with_aggr_expr_and_alias (
1656
+ input : Arc < dyn ExecutionPlan > ,
1657
+ aggr_expr : Vec < ( Arc < AggregateUDF > , Vec < Arc < dyn PhysicalExpr > > ) > ,
1658
+ alias_pairs : Vec < ( String , String ) > ,
1646
1659
) -> Arc < dyn ExecutionPlan > {
1647
1660
let schema = schema ( ) ;
1648
1661
let mut group_by_expr: Vec < ( Arc < dyn PhysicalExpr > , String ) > = vec ! [ ] ;
@@ -1664,18 +1677,33 @@ pub(crate) mod tests {
1664
1677
. collect :: < Vec < _ > > ( ) ;
1665
1678
let final_grouping = PhysicalGroupBy :: new_single ( final_group_by_expr) ;
1666
1679
1680
+ let aggr_expr = aggr_expr
1681
+ . into_iter ( )
1682
+ . map ( |( udaf, exprs) | {
1683
+ AggregateExprBuilder :: new ( udaf. clone ( ) , exprs)
1684
+ . alias ( udaf. name ( ) )
1685
+ . schema ( Arc :: clone ( & schema) )
1686
+ . build ( )
1687
+ . map ( Arc :: new)
1688
+ . unwrap ( )
1689
+ } )
1690
+ . collect :: < Vec < _ > > ( ) ;
1691
+ let filter_exprs = std:: iter:: repeat ( None )
1692
+ . take ( aggr_expr. len ( ) )
1693
+ . collect :: < Vec < _ > > ( ) ;
1694
+
1667
1695
Arc :: new (
1668
1696
AggregateExec :: try_new (
1669
1697
AggregateMode :: FinalPartitioned ,
1670
1698
final_grouping,
1671
- vec ! [ ] ,
1672
- vec ! [ ] ,
1699
+ aggr_expr . clone ( ) ,
1700
+ filter_exprs . clone ( ) ,
1673
1701
Arc :: new (
1674
1702
AggregateExec :: try_new (
1675
1703
AggregateMode :: Partial ,
1676
1704
group_by,
1677
- vec ! [ ] ,
1678
- vec ! [ ] ,
1705
+ aggr_expr ,
1706
+ filter_exprs ,
1679
1707
input,
1680
1708
schema. clone ( ) ,
1681
1709
)
@@ -3436,6 +3464,296 @@ pub(crate) mod tests {
3436
3464
Ok ( ( ) )
3437
3465
}
3438
3466
3467
+ #[ test]
3468
+ fn repartitions_for_aggregate_after_union ( ) -> Result < ( ) > {
3469
+ let union = union_exec ( vec ! [ parquet_exec( ) ; 2 ] ) ;
3470
+ let plan =
3471
+ aggregate_exec_with_alias ( union, vec ! [ ( "a" . to_string( ) , "a1" . to_string( ) ) ] ) ;
3472
+
3473
+ // distribution error without repartitioning
3474
+ let err = check_plan_sanity ( plan. clone ( ) , & Default :: default ( ) ) . unwrap_err ( ) ;
3475
+ assert ! ( err. message( ) . contains( "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]\" ] does not satisfy distribution requirements: HashPartitioned[[a1@0]]). Child-0 output partitioning: UnknownPartitioning(2)" ) ) ;
3476
+
3477
+ let expected = & [
3478
+ "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]" ,
3479
+ "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10" ,
3480
+ "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]" ,
3481
+ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2" ,
3482
+ "UnionExec" ,
3483
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3484
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3485
+ ] ;
3486
+
3487
+ assert_optimized ! ( expected, plan. clone( ) , true ) ;
3488
+ assert_optimized ! ( expected, plan. clone( ) , false ) ;
3489
+
3490
+ Ok ( ( ) )
3491
+ }
3492
+
3493
+ #[ derive( Debug , Clone ) ]
3494
+ struct MyExtensionNode {
3495
+ input : Arc < dyn ExecutionPlan > ,
3496
+ cache : PlanProperties ,
3497
+ }
3498
+ impl MyExtensionNode {
3499
+ fn new ( input : Arc < dyn ExecutionPlan > ) -> Self {
3500
+ let cache = PlanProperties :: new (
3501
+ EquivalenceProperties :: new ( schema ( ) ) ,
3502
+ Partitioning :: UnknownPartitioning ( 1 ) , // from our extension node
3503
+ input. pipeline_behavior ( ) ,
3504
+ input. boundedness ( ) ,
3505
+ ) ;
3506
+ Self { cache, input }
3507
+ }
3508
+ }
3509
+ impl ExecutionPlan for MyExtensionNode {
3510
+ fn required_input_distribution ( & self ) -> Vec < Distribution > {
3511
+ // from our extension node
3512
+ vec ! [ Distribution :: SinglePartition ]
3513
+ }
3514
+ fn name ( & self ) -> & str {
3515
+ "MyExtensionNode"
3516
+ }
3517
+ fn as_any ( & self ) -> & dyn std:: any:: Any {
3518
+ self
3519
+ }
3520
+ fn properties ( & self ) -> & PlanProperties {
3521
+ & self . cache
3522
+ }
3523
+ fn children ( & self ) -> Vec < & Arc < dyn ExecutionPlan > > {
3524
+ vec ! [ & self . input]
3525
+ }
3526
+ fn with_new_children (
3527
+ self : Arc < Self > ,
3528
+ children : Vec < Arc < dyn ExecutionPlan > > ,
3529
+ ) -> Result < Arc < dyn ExecutionPlan > > {
3530
+ assert_eq ! ( children. len( ) , 1 ) ;
3531
+ Ok ( Arc :: new ( Self :: new ( children[ 0 ] . clone ( ) ) ) )
3532
+ }
3533
+ fn execute (
3534
+ & self ,
3535
+ _partition : usize ,
3536
+ _context : Arc < TaskContext > ,
3537
+ ) -> Result < SendableRecordBatchStream > {
3538
+ unimplemented ! ( )
3539
+ }
3540
+ }
3541
+ impl DisplayAs for MyExtensionNode {
3542
+ fn fmt_as (
3543
+ & self ,
3544
+ _t : DisplayFormatType ,
3545
+ f : & mut std:: fmt:: Formatter ,
3546
+ ) -> std:: fmt:: Result {
3547
+ write ! ( f, "MyExtensionNode" )
3548
+ }
3549
+ }
3550
+
3551
+ #[ test]
3552
+ fn repartitions_for_extension_node_with_aggregate_after_union ( ) -> Result < ( ) > {
3553
+ let union = union_exec ( vec ! [ parquet_exec( ) ; 2 ] ) ;
3554
+ let plan =
3555
+ aggregate_exec_with_alias ( union, vec ! [ ( "a" . to_string( ) , "a1" . to_string( ) ) ] ) ;
3556
+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( MyExtensionNode :: new ( plan) ) ;
3557
+
3558
+ // same plan as before, but with the extension node on top
3559
+ let err = check_plan_sanity ( plan. clone ( ) , & Default :: default ( ) ) . unwrap_err ( ) ;
3560
+ assert ! ( err. message( ) . contains( "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]\" ] does not satisfy distribution requirements: SinglePartition. Child-0 output partitioning: UnknownPartitioning(2)" ) ) ;
3561
+
3562
+ let expected = & [
3563
+ "MyExtensionNode" ,
3564
+ "CoalescePartitionsExec" ,
3565
+ "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]" ,
3566
+ "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10" ,
3567
+ "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]" ,
3568
+ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2" ,
3569
+ "UnionExec" ,
3570
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3571
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3572
+ ] ;
3573
+
3574
+ assert_optimized ! ( expected, plan. clone( ) , true ) ;
3575
+ assert_optimized ! ( expected, plan. clone( ) , false ) ;
3576
+
3577
+ Ok ( ( ) )
3578
+ }
3579
+
3580
+ #[ test]
3581
+ fn repartitions_for_aggregate_after_sorted_union ( ) -> Result < ( ) > {
3582
+ let union = union_exec ( vec ! [ parquet_exec( ) ; 2 ] ) ;
3583
+ let schema = schema ( ) ;
3584
+ let sort_key = LexOrdering :: new ( vec ! [ PhysicalSortExpr {
3585
+ expr: col( "a" , & schema) . unwrap( ) ,
3586
+ options: SortOptions :: default ( ) ,
3587
+ } ] ) ;
3588
+ let sort = sort_exec ( sort_key, union, false ) ;
3589
+ let plan =
3590
+ aggregate_exec_with_alias ( sort, vec ! [ ( "a" . to_string( ) , "a1" . to_string( ) ) ] ) ;
3591
+
3592
+ // with the sort, there is no error
3593
+ let checker = check_plan_sanity ( plan. clone ( ) , & Default :: default ( ) ) ;
3594
+ assert ! ( checker. is_ok( ) ) ;
3595
+
3596
+ // it still repartitions
3597
+ let expected_after_first_run = & [
3598
+ "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted" ,
3599
+ "SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]" ,
3600
+ "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10" ,
3601
+ "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted" ,
3602
+ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1" ,
3603
+ "SortPreservingMergeExec: [a@0 ASC]" ,
3604
+ "UnionExec" ,
3605
+ "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]" ,
3606
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3607
+ "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]" ,
3608
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3609
+ ] ;
3610
+ assert_optimized ! ( expected_after_first_run, plan. clone( ) , true ) ;
3611
+
3612
+ let expected_after_second_run = & [
3613
+ "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted" ,
3614
+ "SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]" ,
3615
+ "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10" ,
3616
+ "SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]" ,
3617
+ "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted" ,
3618
+ "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]" , // adds another sort
3619
+ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2" ,
3620
+ "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]" , // removes the SPM
3621
+ "UnionExec" ,
3622
+ "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]" ,
3623
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3624
+ "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]" ,
3625
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3626
+ ] ;
3627
+ assert_optimized ! ( expected_after_second_run, plan. clone( ) , false ) ;
3628
+
3629
+ Ok ( ( ) )
3630
+ }
3631
+
3632
+ #[ test]
3633
+ fn repartition_for_aggregate_after_sorted_union_projection ( ) -> Result < ( ) > {
3634
+ let union = union_exec ( vec ! [ parquet_exec( ) ; 2 ] ) ;
3635
+ let projection = projection_exec_with_alias (
3636
+ union,
3637
+ vec ! [
3638
+ ( "a" . to_string( ) , "a" . to_string( ) ) ,
3639
+ ( "b" . to_string( ) , "value" . to_string( ) ) ,
3640
+ ] ,
3641
+ ) ;
3642
+ let schema = schema ( ) ;
3643
+ let sort_key = LexOrdering :: new ( vec ! [ PhysicalSortExpr {
3644
+ expr: col( "a" , & schema) . unwrap( ) ,
3645
+ options: SortOptions :: default ( ) ,
3646
+ } ] ) ;
3647
+ let sort = sort_exec ( sort_key, projection, false ) ;
3648
+ let plan =
3649
+ aggregate_exec_with_alias ( sort, vec ! [ ( "a" . to_string( ) , "a1" . to_string( ) ) ] ) ;
3650
+
3651
+ // with the sort, there is no error
3652
+ let checker = check_plan_sanity ( plan. clone ( ) , & Default :: default ( ) ) ;
3653
+ assert ! ( checker. is_ok( ) ) ;
3654
+
3655
+ // it still repartitions
3656
+ let expected_after_first_run = & [
3657
+ "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted" ,
3658
+ "SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]" ,
3659
+ "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10" ,
3660
+ "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted" ,
3661
+ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1" ,
3662
+ "SortPreservingMergeExec: [a@0 ASC]" ,
3663
+ "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]" ,
3664
+ "ProjectionExec: expr=[a@0 as a, b@1 as value]" ,
3665
+ "UnionExec" ,
3666
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3667
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3668
+ ] ;
3669
+ assert_optimized ! ( expected_after_first_run, plan. clone( ) , true ) ;
3670
+
3671
+ let expected_after_second_run = & [
3672
+ "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted" ,
3673
+ "SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]" ,
3674
+ "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10" ,
3675
+ "SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]" , // adds another sort
3676
+ "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted" ,
3677
+ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1" ,
3678
+ // removes the SPM
3679
+ "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]" ,
3680
+ "CoalescePartitionsExec" , // adds the coalesce
3681
+ "ProjectionExec: expr=[a@0 as a, b@1 as value]" ,
3682
+ "UnionExec" ,
3683
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3684
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3685
+ ] ;
3686
+ assert_optimized ! ( expected_after_second_run, plan. clone( ) , false ) ;
3687
+
3688
+ Ok ( ( ) )
3689
+ }
3690
+
3691
+ #[ test]
3692
+ fn repartition_for_aggregate_sum_after_sorted_union_projection ( ) -> Result < ( ) > {
3693
+ let union = union_exec ( vec ! [ parquet_exec( ) ; 2 ] ) ;
3694
+ let projection = projection_exec_with_alias (
3695
+ union,
3696
+ vec ! [
3697
+ ( "a" . to_string( ) , "a" . to_string( ) ) ,
3698
+ ( "b" . to_string( ) , "b" . to_string( ) ) ,
3699
+ ] ,
3700
+ ) ;
3701
+ let schema = schema ( ) ;
3702
+ let sort_key = LexOrdering :: new ( vec ! [ PhysicalSortExpr {
3703
+ expr: col( "a" , & schema) . unwrap( ) ,
3704
+ options: SortOptions :: default ( ) ,
3705
+ } ] ) ;
3706
+ let sort = sort_exec ( sort_key, projection, false ) ;
3707
+ let plan = aggregate_exec_with_aggr_expr_and_alias (
3708
+ sort,
3709
+ vec ! [ ( sum_udaf( ) , vec![ col( "b" , & schema) ?] ) ] ,
3710
+ vec ! [ ( "a" . to_string( ) , "a1" . to_string( ) ) ] ,
3711
+ ) ;
3712
+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( MyExtensionNode :: new ( plan) ) ;
3713
+
3714
+ // with the sort, there is no error
3715
+ let checker = check_plan_sanity ( plan. clone ( ) , & Default :: default ( ) ) ;
3716
+ assert ! ( checker. is_ok( ) ) ;
3717
+
3718
+ // it still repartitions
3719
+ let expected_after_first_run = & [
3720
+ "MyExtensionNode" ,
3721
+ "CoalescePartitionsExec" ,
3722
+ "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[sum], ordering_mode=Sorted" ,
3723
+ "SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]" ,
3724
+ "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10" ,
3725
+ "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[sum], ordering_mode=Sorted" ,
3726
+ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1" ,
3727
+ "SortPreservingMergeExec: [a@0 ASC]" ,
3728
+ "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]" ,
3729
+ "ProjectionExec: expr=[a@0 as a, b@1 as b]" ,
3730
+ "UnionExec" ,
3731
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3732
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3733
+ ] ;
3734
+ assert_optimized ! ( expected_after_first_run, plan. clone( ) , true ) ;
3735
+
3736
+ let expected_after_second_run = & [
3737
+ "MyExtensionNode" ,
3738
+ "CoalescePartitionsExec" ,
3739
+ "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[sum], ordering_mode=Sorted" ,
3740
+ "SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]" ,
3741
+ "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10" ,
3742
+ "SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]" ,
3743
+ "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[sum], ordering_mode=Sorted" ,
3744
+ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1" ,
3745
+ "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]" ,
3746
+ "CoalescePartitionsExec" ,
3747
+ "ProjectionExec: expr=[a@0 as a, b@1 as b]" ,
3748
+ "UnionExec" ,
3749
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3750
+ "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]" ,
3751
+ ] ;
3752
+ assert_optimized ! ( expected_after_second_run, plan. clone( ) , false ) ;
3753
+
3754
+ Ok ( ( ) )
3755
+ }
3756
+
3439
3757
#[ test]
3440
3758
fn repartition_through_sort_preserving_merge ( ) -> Result < ( ) > {
3441
3759
// sort preserving merge with non-sorted input
0 commit comments