From 47d2cf827f03bf02413f88b1be480f6e1df0745f Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 20 Mar 2025 20:28:17 +0000 Subject: [PATCH 1/3] type coercion fix for uint/int's. --- .../expr-common/src/type_coercion/binary.rs | 173 +++++++++++++----- .../src/single_distinct_to_groupby.rs | 8 +- .../optimizer/tests/optimizer_integration.rs | 4 +- .../physical-expr/src/expressions/binary.rs | 6 +- datafusion/sqllogictest/test_files/math.slt | 10 +- .../sqllogictest/test_files/operator.slt | 8 +- datafusion/sqllogictest/test_files/window.slt | 8 +- 7 files changed, 148 insertions(+), 69 deletions(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index fb559e163bb1..40ed04ff2dd6 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -855,21 +855,14 @@ pub fn binary_numeric_coercion( (UInt64, _) | (_, UInt64) => Some(UInt64), (Int64, _) | (_, Int64) - | (UInt32, Int8) - | (Int8, UInt32) - | (UInt32, Int16) - | (Int16, UInt32) - | (UInt32, Int32) - | (Int32, UInt32) => Some(Int64), - (Int32, _) - | (_, Int32) - | (UInt16, Int16) - | (Int16, UInt16) - | (UInt16, Int8) - | (Int8, UInt16) => Some(Int32), + | (UInt32, Int32 | Int16 | Int8) + | (Int32 | Int16 | Int8, UInt32) => Some(Int64), (UInt32, _) | (_, UInt32) => Some(UInt32), - (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16), + (Int32, _) | (_, Int32) | (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => { + Some(Int32) + } (UInt16, _) | (_, UInt16) => Some(UInt16), + (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16), (Int8, _) | (_, Int8) => Some(Int8), (UInt8, _) | (_, UInt8) => Some(UInt8), _ => None, @@ -1036,13 +1029,24 @@ fn mathematics_numerical_coercion( } (Float64, _) | (_, Float64) => Some(Float64), (_, Float32) | (Float32, _) => Some(Float32), - (Int64, _) | (_, Int64) => Some(Int64), - (Int32, _) | (_, Int32) => 
Some(Int32), - (Int16, _) | (_, Int16) => Some(Int16), - (Int8, _) | (_, Int8) => Some(Int8), + // The following match arms encode the following logic: Given the two + // integral types, we choose the narrowest possible integral type that + // accommodates all values of both types. Note that to avoid information + // loss when combining UInt64 with signed integers we use Decimal128(20, 0). + (UInt64, Int64 | Int32 | Int16 | Int8) + | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)), (UInt64, _) | (_, UInt64) => Some(UInt64), + (Int64, _) + | (_, Int64) + | (UInt32, Int32 | Int16 | Int8) + | (Int32 | Int16 | Int8, UInt32) => Some(Int64), (UInt32, _) | (_, UInt32) => Some(UInt32), + (Int32, _) | (_, Int32) | (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => { + Some(Int32) + } (UInt16, _) | (_, UInt16) => Some(UInt16), + (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16), + (Int8, _) | (_, Int8) => Some(Int8), (UInt8, _) | (_, UInt8) => Some(UInt8), _ => None, } @@ -1621,7 +1625,7 @@ mod tests { /// Test coercion rules for binary operators /// - /// Applies coercion rules for `$LHS_TYPE $OP $RHS_TYPE` and asserts that the + /// Applies coercion rules for `$LHS_TYPE $OP $RHS_TYPE` and asserts that /// the result type is `$RESULT_TYPE` macro_rules! test_coercion_binary_rule { ($LHS_TYPE:expr, $RHS_TYPE:expr, $OP:expr, $RESULT_TYPE:expr) => {{ @@ -1632,6 +1636,26 @@ mod tests { }}; } + /// Test coercion rules for binary operators + /// + /// Applies coercion rules for each RHS_TYPE in $RHS_TYPES such that + /// `$LHS_TYPE $OP RHS_TYPE` and asserts that the result type is `$RESULT_TYPE`. + /// Also tests that the inverse `RHS_TYPE $OP $LHS_TYPE` is true + macro_rules! 
test_coercion_binary_rule_multiple { + ($LHS_TYPE:expr, $RHS_TYPES:expr, $OP:expr, $RESULT_TYPE:expr) => {{ + for rh_type in $RHS_TYPES { + let (lhs, rhs) = BinaryTypeCoercer::new(&$LHS_TYPE, &$OP, &rh_type) + .get_input_types()?; + assert_eq!(lhs, $RESULT_TYPE); + assert_eq!(rhs, $RESULT_TYPE); + + let inverse = BinaryTypeCoercer::new(&rh_type, &$OP, &$LHS_TYPE); + let (lhs, rhs) = inverse.get_input_types()?; + assert_eq!((lhs, rhs), ($RESULT_TYPE, $RESULT_TYPE)); + } + }}; + } + /// Test coercion rules for like /// /// Applies coercion rules for both @@ -1991,39 +2015,94 @@ mod tests { #[test] fn test_type_coercion_arithmetic() -> Result<()> { - // integer - test_coercion_binary_rule!( - DataType::Int32, - DataType::UInt32, + use DataType::*; + + // (Float64, _) | (_, Float64) => Some(Float64), + test_coercion_binary_rule_multiple!( + Float64, + [ + Float64, Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, + Int8, UInt8 + ], Operator::Plus, - DataType::Int32 - ); - test_coercion_binary_rule!( - DataType::Int32, - DataType::UInt16, + Float64 + ); + // (_, Float32) | (Float32, _) => Some(Float32), + test_coercion_binary_rule_multiple!( + Float32, + [ + Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, + UInt8 + ], + Operator::Plus, + Float32 + ); + // (UInt64, Int64 | Int32 | Int16 | Int8) | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)), + test_coercion_binary_rule_multiple!( + UInt64, + [Int64, Int32, Int16, Int8], + Operator::Divide, + Decimal128(20, 0) + ); + // (UInt64, _) | (_, UInt64) => Some(UInt64), + test_coercion_binary_rule_multiple!( + UInt64, + [UInt64, UInt32, UInt16, UInt8], + Operator::Modulo, + UInt64 + ); + // (Int64, _) | (_, Int64) => Some(Int64), + test_coercion_binary_rule_multiple!( + Int64, + [Int64, Int32, UInt32, Int16, UInt16, Int8, UInt8], + Operator::Modulo, + Int64 + ); + // (UInt32, Int32 | Int16 | Int8) | (Int32 | Int16 | Int8, UInt32) => Some(Int64) + test_coercion_binary_rule_multiple!( + UInt32, + [Int32, 
Int16, Int8], + Operator::Modulo, + Int64 + ); + // (UInt32, _) | (_, UInt32) => Some(UInt32), + test_coercion_binary_rule_multiple!( + UInt32, + [UInt32, UInt16, UInt8], + Operator::Modulo, + UInt32 + ); + // (Int32, _) | (_, Int32) => Some(Int32), + test_coercion_binary_rule_multiple!( + Int32, + [Int32, Int16, UInt16, Int8, UInt8], + Operator::Modulo, + Int32 + ); + // (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => Some(Int32) + test_coercion_binary_rule_multiple!( + UInt16, + [Int16, Int8], Operator::Minus, - DataType::Int32 + Int32 ); - test_coercion_binary_rule!( - DataType::Int8, - DataType::Int64, - Operator::Multiply, - DataType::Int64 - ); - // float - test_coercion_binary_rule!( - DataType::Float32, - DataType::Int32, + // (UInt16, _) | (_, UInt16) => Some(UInt16), + test_coercion_binary_rule_multiple!( + UInt16, + [UInt16, UInt8], Operator::Plus, - DataType::Float32 - ); - test_coercion_binary_rule!( - DataType::Float32, - DataType::Float64, - Operator::Multiply, - DataType::Float64 - ); - // TODO add other data type + UInt16 + ); + // (Int16, _) | (_, Int16) => Some(Int16), + test_coercion_binary_rule_multiple!(Int16, [Int16, Int8], Operator::Plus, Int16); + // (UInt8, Int8) | (Int8, UInt8) => Some(Int16) + test_coercion_binary_rule!(Int8, UInt8, Operator::Minus, Int16); + test_coercion_binary_rule!(UInt8, Int8, Operator::Multiply, Int16); + // (UInt8, _) | (_, UInt8) => Some(UInt8), + test_coercion_binary_rule!(UInt8, UInt8, Operator::Minus, UInt8); + // (Int8, _) | (_, Int8) => Some(Int8), + test_coercion_binary_rule!(Int8, Int8, Operator::Plus, Int8); + Ok(()) } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 191377fc2759..7337d2ffce5c 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -410,7 +410,7 @@ mod tests { let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) 
[count(DISTINCT Int32(2) * test.b):Int64]\ \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ - \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ + \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -497,9 +497,9 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\ + \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]\ + \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 5e66c7ec0313..13d6b8de79cf 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -267,8 +267,8 @@ fn push_down_filter_groupby_expr_contains_alias() { let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3"; let plan = test_sql(sql).unwrap(); let expected = "Projection: test.col_int32 + test.col_uint32 AS c, count(Int64(1)) AS count(*)\ - \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS 
Int32)]], aggr=[[count(Int64(1))]]\ - \n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\ + \n Aggregate: groupBy=[[CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64)]], aggr=[[count(Int64(1))]]\ + \n Filter: CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64) > Int64(3)\ \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan}")); } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index a00d135ef3c1..f33ddf685cbf 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -1020,9 +1020,9 @@ mod tests { DataType::UInt32, vec![1u32, 2u32], Operator::Plus, - Int32Array, - DataType::Int32, - [2i32, 4i32], + Int64Array, + DataType::Int64, + [2i64, 4i64], ); test_coercion!( Int32Array, diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index a49e0a642106..e206aa16b8a9 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -164,22 +164,22 @@ INSERT into test_nullable_integer values(127, 32767, 2147483647, 922337203685477 ---- 1 -query IIIIIIII +query IIIIIIIR SELECT c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 FROM test_nullable_integer where dataset = 'nulls' ---- NULL NULL NULL NULL NULL NULL NULL NULL -query IIIIIIII +query IIIIIIIR SELECT c1/0, c2/0, c3/0, c4/0, c5/0, c6/0, c7/0, c8/0 FROM test_nullable_integer where dataset = 'nulls' ---- NULL NULL NULL NULL NULL NULL NULL NULL -query IIIIIIII +query IIIIIIIR SELECT c1%0, c2%0, c3%0, c4%0, c5%0, c6%0, c7%0, c8%0 FROM test_nullable_integer where dataset = 'nulls' ---- NULL NULL NULL NULL NULL NULL NULL NULL -query IIIIIIII rowsort +query IIIIIIIR rowsort select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_nullable_integer where dataset != 'maxs' ---- 0 0 0 0 0 0 0 0 @@ -300,7 +300,7 @@ INSERT 
INTO test_non_nullable_integer VALUES(1, 1, 1, 1, 1, 1, 1, 1) ---- 1 -query IIIIIIII rowsort +query IIIIIIIR rowsort select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_non_nullable_integer ---- 0 0 0 0 0 0 0 0 diff --git a/datafusion/sqllogictest/test_files/operator.slt b/datafusion/sqllogictest/test_files/operator.slt index 8fd0a7a61033..a651eda99684 100644 --- a/datafusion/sqllogictest/test_files/operator.slt +++ b/datafusion/sqllogictest/test_files/operator.slt @@ -70,7 +70,7 @@ select arrow_typeof(decimal + 2) from numeric_types; ---- -Int64 Int64 Int64 Int64 Int64 Int64 Int64 Int64 Float32 Float64 Decimal128(23, 2) +Int64 Int64 Int64 Int64 Int64 Int64 Int64 Decimal128(21, 0) Float32 Float64 Decimal128(23, 2) # Plus with literal decimal query TTTTTTTTTTT @@ -127,7 +127,7 @@ select arrow_typeof(decimal - 2) from numeric_types; ---- -Int64 Int64 Int64 Int64 Int64 Int64 Int64 Int64 Float32 Float64 Decimal128(23, 2) +Int64 Int64 Int64 Int64 Int64 Int64 Int64 Decimal128(21, 0) Float32 Float64 Decimal128(23, 2) # Minus with literal decimal query TTTTTTTTTTT @@ -184,7 +184,7 @@ select arrow_typeof(decimal * 2) from numeric_types; ---- -Int64 Int64 Int64 Int64 Int64 Int64 Int64 Int64 Float32 Float64 Decimal128(26, 2) +Int64 Int64 Int64 Int64 Int64 Int64 Int64 Decimal128(38, 0) Float32 Float64 Decimal128(26, 2) # Multiply with literal decimal query TTTTTTTTTTT @@ -242,7 +242,7 @@ select arrow_typeof(decimal / 2) from numeric_types; ---- -Int64 Int64 Int64 Int64 Int64 Int64 Int64 Int64 Float32 Float64 Decimal128(9, 6) +Int64 Int64 Int64 Int64 Int64 Int64 Int64 Decimal128(24, 4) Float32 Float64 Decimal128(9, 6) # Divide with literal decimal query TTTTTTTTTTT diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index fd623b67fe9f..76e3751e4b8e 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -2448,15 +2448,15 @@ EXPLAIN SELECT c5, c9, rn1 FROM 
(SELECT c5, c9, LIMIT 5 ---- logical_plan -01)Sort: rn1 ASC NULLS LAST, CAST(aggregate_test_100.c9 AS Int32) + aggregate_test_100.c5 DESC NULLS FIRST, fetch=5 +01)Sort: rn1 ASC NULLS LAST, CAST(aggregate_test_100.c9 AS Decimal128(20, 0)) + CAST(aggregate_test_100.c5 AS Decimal128(20, 0)) DESC NULLS FIRST, fetch=5 02)--Projection: aggregate_test_100.c5, aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -03)----WindowAggr: windowExpr=[[row_number() ORDER BY [CAST(aggregate_test_100.c9 AS Int32) + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [CAST(aggregate_test_100.c9 AS Decimal128(20, 0)) + CAST(aggregate_test_100.c5 AS Decimal128(20, 0)) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 04)------TableScan: aggregate_test_100 projection=[c5, c9] physical_plan 01)ProjectionExec: expr=[c5@0 as c5, c9@1 as c9, row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rn1] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: 
Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[CAST(c9@1 AS Int32) + c5@0 DESC], preserve_partitioning=[false] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Decimal128(None,21,0)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[CAST(c9@1 AS Decimal128(20, 0)) + CAST(c5@0 AS Decimal128(20, 0)) DESC], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c5, c9], file_type=csv, has_header=true # Ordering equivalence should be preserved during cast expression From ced045e96dde863512d4ebeef911f407b7264cc8 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 21 Mar 2025 16:16:13 +0000 Subject: [PATCH 2/3] Refactored common numerical coercion logic into a single function. 
--- .../expr-common/src/type_coercion/binary.rs | 38 ++++++------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 40ed04ff2dd6..70ad1232428a 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -827,7 +827,6 @@ pub fn binary_numeric_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option<DataType> { - use arrow::datatypes::DataType::*; if !lhs_type.is_numeric() || !rhs_type.is_numeric() { return None; }; @@ -841,32 +840,7 @@ pub fn binary_numeric_coercion( return Some(t); } - // These are ordered from most informative to least informative so - // that the coercion does not lose information via truncation - match (lhs_type, rhs_type) { - (Float64, _) | (_, Float64) => Some(Float64), - (_, Float32) | (Float32, _) => Some(Float32), - // The following match arms encode the following logic: Given the two - // integral types, we choose the narrowest possible integral type that - // accommodates all values of both types. Note that to avoid information - // loss when combining UInt64 with signed integers we use Decimal128(20, 0). - (UInt64, Int64 | Int32 | Int16 | Int8) - | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)), - (UInt64, _) | (_, UInt64) => Some(UInt64), - (Int64, _) - | (_, Int64) - | (UInt32, Int32 | Int16 | Int8) - | (Int32 | Int16 | Int8, UInt32) => Some(Int64), - (UInt32, _) | (_, UInt32) => Some(UInt32), - (Int32, _) | (_, Int32) | (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => { - Some(Int32) - } - (UInt16, _) | (_, UInt16) => Some(UInt16), - (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16), - (Int8, _) | (_, Int8) => Some(Int8), - (UInt8, _) | (_, UInt8) => Some(UInt8), - _ => None, - } + numerical_coercion(lhs_type, rhs_type) } /// Decimal coercion rules. 
@@ -1027,6 +1001,16 @@ fn mathematics_numerical_coercion( (_, Dictionary(_, value_type)) => { mathematics_numerical_coercion(lhs_type, value_type) } + _ => numerical_coercion(lhs_type, rhs_type), + } +} + +/// A common set of numerical coercions that are applied for mathematical and binary ops +/// to lhs_type` and `rhs_type`. +fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { + use arrow::datatypes::DataType::*; + + match (lhs_type, rhs_type) { (Float64, _) | (_, Float64) => Some(Float64), (_, Float32) | (Float32, _) => Some(Float32), // The following match arms encode the following logic: Given the two From 08effa8f8a2d095b2c1760e16a3401585e2b9ff4 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 21 Mar 2025 16:48:16 +0000 Subject: [PATCH 3/3] Cargo fmt. --- datafusion/expr-common/src/type_coercion/binary.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 70ad1232428a..77807538f9d3 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -1006,7 +1006,7 @@ fn mathematics_numerical_coercion( } /// A common set of numerical coercions that are applied for mathematical and binary ops -/// to lhs_type` and `rhs_type`. +/// to `lhs_type` and `rhs_type`. fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { use arrow::datatypes::DataType::*;