Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type coercion for unsigned and signed integers (Int64 vs UInt64, etc) #15341

Merged
merged 3 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 131 additions & 68 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,6 @@ pub fn binary_numeric_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
if !lhs_type.is_numeric() || !rhs_type.is_numeric() {
return None;
};
Expand All @@ -841,39 +840,7 @@ pub fn binary_numeric_coercion(
return Some(t);
}

// These are ordered from most informative to least informative so
// that the coercion does not lose information via truncation
match (lhs_type, rhs_type) {
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
// The following match arms encode the following logic: Given the two
// integral types, we choose the narrowest possible integral type that
// accommodates all values of both types. Note that to avoid information
// loss when combining UInt64 with signed integers we use Decimal128(20, 0).
(UInt64, Int64 | Int32 | Int16 | Int8)
| (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)),
(UInt64, _) | (_, UInt64) => Some(UInt64),
(Int64, _)
| (_, Int64)
| (UInt32, Int8)
| (Int8, UInt32)
| (UInt32, Int16)
| (Int16, UInt32)
| (UInt32, Int32)
| (Int32, UInt32) => Some(Int64),
(Int32, _)
| (_, Int32)
| (UInt16, Int16)
| (Int16, UInt16)
| (UInt16, Int8)
| (Int8, UInt16) => Some(Int32),
(UInt32, _) | (_, UInt32) => Some(UInt32),
(Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
(UInt16, _) | (_, UInt16) => Some(UInt16),
(Int8, _) | (_, Int8) => Some(Int8),
(UInt8, _) | (_, UInt8) => Some(UInt8),
_ => None,
}
numerical_coercion(lhs_type, rhs_type)
}

/// Decimal coercion rules.
Expand Down Expand Up @@ -1034,15 +1001,36 @@ fn mathematics_numerical_coercion(
(_, Dictionary(_, value_type)) => {
mathematics_numerical_coercion(lhs_type, value_type)
}
_ => numerical_coercion(lhs_type, rhs_type),
}
}

/// A common set of numerical coercions that are applied for mathematical and binary ops
/// to `lhs_type` and `rhs_type`.
fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;

match (lhs_type, rhs_type) {
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
(Int64, _) | (_, Int64) => Some(Int64),
(Int32, _) | (_, Int32) => Some(Int32),
(Int16, _) | (_, Int16) => Some(Int16),
(Int8, _) | (_, Int8) => Some(Int8),
// The following match arms encode the following logic: Given the two
// integral types, we choose the narrowest possible integral type that
// accommodates all values of both types. Note that to avoid information
// loss when combining UInt64 with signed integers we use Decimal128(20, 0).
(UInt64, Int64 | Int32 | Int16 | Int8)
| (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)),
(UInt64, _) | (_, UInt64) => Some(UInt64),
(Int64, _)
| (_, Int64)
| (UInt32, Int32 | Int16 | Int8)
| (Int32 | Int16 | Int8, UInt32) => Some(Int64),
(UInt32, _) | (_, UInt32) => Some(UInt32),
(Int32, _) | (_, Int32) | (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => {
Some(Int32)
}
(UInt16, _) | (_, UInt16) => Some(UInt16),
(Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
(Int8, _) | (_, Int8) => Some(Int8),
(UInt8, _) | (_, UInt8) => Some(UInt8),
_ => None,
}
Expand Down Expand Up @@ -1621,7 +1609,7 @@ mod tests {

/// Test coercion rules for binary operators
///
/// Applies coercion rules for `$LHS_TYPE $OP $RHS_TYPE` and asserts that the
/// Applies coercion rules for `$LHS_TYPE $OP $RHS_TYPE` and asserts that
/// the result type is `$RESULT_TYPE`
macro_rules! test_coercion_binary_rule {
($LHS_TYPE:expr, $RHS_TYPE:expr, $OP:expr, $RESULT_TYPE:expr) => {{
Expand All @@ -1632,6 +1620,26 @@ mod tests {
}};
}

/// Test coercion rules for binary operators
///
/// Applies coercion rules for each RHS_TYPE in $RHS_TYPES such that
/// `$LHS_TYPE $OP RHS_TYPE` and asserts that the result type is `$RESULT_TYPE`.
/// Also tests that the inverse `RHS_TYPE $OP $LHS_TYPE` is true
macro_rules! test_coercion_binary_rule_multiple {
($LHS_TYPE:expr, $RHS_TYPES:expr, $OP:expr, $RESULT_TYPE:expr) => {{
for rh_type in $RHS_TYPES {
let (lhs, rhs) = BinaryTypeCoercer::new(&$LHS_TYPE, &$OP, &rh_type)
.get_input_types()?;
assert_eq!(lhs, $RESULT_TYPE);
assert_eq!(rhs, $RESULT_TYPE);

BinaryTypeCoercer::new(&rh_type, &$OP, &$LHS_TYPE).get_input_types()?;
assert_eq!(lhs, $RESULT_TYPE);
assert_eq!(rhs, $RESULT_TYPE);
}
}};
}

/// Test coercion rules for like
///
/// Applies coercion rules for both
Expand Down Expand Up @@ -1991,39 +1999,94 @@ mod tests {

#[test]
fn test_type_coercion_arithmetic() -> Result<()> {
// integer
test_coercion_binary_rule!(
DataType::Int32,
DataType::UInt32,
use DataType::*;

// (Float64, _) | (_, Float64) => Some(Float64),
test_coercion_binary_rule_multiple!(
Float64,
[
Float64, Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16,
Int8, UInt8
],
Operator::Plus,
DataType::Int32
);
test_coercion_binary_rule!(
DataType::Int32,
DataType::UInt16,
Float64
);
// (_, Float32) | (Float32, _) => Some(Float32),
test_coercion_binary_rule_multiple!(
Float32,
[
Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8,
UInt8
],
Operator::Plus,
Float32
);
// (UInt64, Int64 | Int32 | Int16 | Int8) | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)),
test_coercion_binary_rule_multiple!(
UInt64,
[Int64, Int32, Int16, Int8],
Operator::Divide,
Decimal128(20, 0)
);
// (UInt64, _) | (_, UInt64) => Some(UInt64),
test_coercion_binary_rule_multiple!(
UInt64,
[UInt64, UInt32, UInt16, UInt8],
Operator::Modulo,
UInt64
);
// (Int64, _) | (_, Int64) => Some(Int64),
test_coercion_binary_rule_multiple!(
Int64,
[Int64, Int32, UInt32, Int16, UInt16, Int8, UInt8],
Operator::Modulo,
Int64
);
// (UInt32, Int32 | Int16 | Int8) | (Int32 | Int16 | Int8, UInt32) => Some(Int64)
test_coercion_binary_rule_multiple!(
UInt32,
[Int32, Int16, Int8],
Operator::Modulo,
Int64
);
// (UInt32, _) | (_, UInt32) => Some(UInt32),
test_coercion_binary_rule_multiple!(
UInt32,
[UInt32, UInt16, UInt8],
Operator::Modulo,
UInt32
);
// (Int32, _) | (_, Int32) => Some(Int32),
test_coercion_binary_rule_multiple!(
Int32,
[Int32, Int16, Int8],
Operator::Modulo,
Int32
);
// (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => Some(Int32)
test_coercion_binary_rule_multiple!(
UInt16,
[Int16, Int8],
Operator::Minus,
DataType::Int32
Int32
);
test_coercion_binary_rule!(
DataType::Int8,
DataType::Int64,
Operator::Multiply,
DataType::Int64
);
// float
test_coercion_binary_rule!(
DataType::Float32,
DataType::Int32,
// (UInt16, _) | (_, UInt16) => Some(UInt16),
test_coercion_binary_rule_multiple!(
UInt16,
[UInt16, UInt8, UInt8],
Operator::Plus,
DataType::Float32
);
test_coercion_binary_rule!(
DataType::Float32,
DataType::Float64,
Operator::Multiply,
DataType::Float64
);
// TODO add other data type
UInt16
);
// (Int16, _) | (_, Int16) => Some(Int16),
test_coercion_binary_rule_multiple!(Int16, [Int16, Int8], Operator::Plus, Int16);
// (UInt8, Int8) | (Int8, UInt8) => Some(Int16)
test_coercion_binary_rule!(Int8, UInt8, Operator::Minus, Int16);
test_coercion_binary_rule!(UInt8, Int8, Operator::Multiply, Int16);
// (UInt8, _) | (_, UInt8) => Some(UInt8),
test_coercion_binary_rule!(UInt8, UInt8, Operator::Minus, UInt8);
// (Int8, _) | (_, Int8) => Some(Int8),
test_coercion_binary_rule!(Int8, Int8, Operator::Plus, Int8);

Ok(())
}

Expand Down
8 changes: 4 additions & 4 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ mod tests {

let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\
\n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\
\n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand Down Expand Up @@ -497,9 +497,9 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ fn push_down_filter_groupby_expr_contains_alias() {
let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3";
let plan = test_sql(sql).unwrap();
let expected = "Projection: test.col_int32 + test.col_uint32 AS c, count(Int64(1)) AS count(*)\
\n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[count(Int64(1))]]\
\n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\
\n Aggregate: groupBy=[[CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64)]], aggr=[[count(Int64(1))]]\
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this might be slower-- as now the larger column type is used (so it needs to do a 64 bit comparison rather than 32 bit) 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but it probably also doesn't lose precision

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correctness > performance.

\n Filter: CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64) > Int64(3)\
\n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{plan}"));
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1020,9 +1020,9 @@ mod tests {
DataType::UInt32,
vec![1u32, 2u32],
Operator::Plus,
Int32Array,
DataType::Int32,
[2i32, 4i32],
Int64Array,
DataType::Int64,
[2i64, 4i64],
);
test_coercion!(
Int32Array,
Expand Down
10 changes: 5 additions & 5 deletions datafusion/sqllogictest/test_files/math.slt
Original file line number Diff line number Diff line change
Expand Up @@ -164,22 +164,22 @@ INSERT into test_nullable_integer values(127, 32767, 2147483647, 922337203685477
----
1

query IIIIIIII
query IIIIIIIR
SELECT c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 FROM test_nullable_integer where dataset = 'nulls'
----
NULL NULL NULL NULL NULL NULL NULL NULL

query IIIIIIII
query IIIIIIIR
SELECT c1/0, c2/0, c3/0, c4/0, c5/0, c6/0, c7/0, c8/0 FROM test_nullable_integer where dataset = 'nulls'
----
NULL NULL NULL NULL NULL NULL NULL NULL

query IIIIIIII
query IIIIIIIR
SELECT c1%0, c2%0, c3%0, c4%0, c5%0, c6%0, c7%0, c8%0 FROM test_nullable_integer where dataset = 'nulls'
----
NULL NULL NULL NULL NULL NULL NULL NULL

query IIIIIIII rowsort
query IIIIIIIR rowsort
select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_nullable_integer where dataset != 'maxs'
----
0 0 0 0 0 0 0 0
Expand Down Expand Up @@ -300,7 +300,7 @@ INSERT INTO test_non_nullable_integer VALUES(1, 1, 1, 1, 1, 1, 1, 1)
----
1

query IIIIIIII rowsort
query IIIIIIIR rowsort
select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_non_nullable_integer
----
0 0 0 0 0 0 0 0
Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/operator.slt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ select
arrow_typeof(decimal + 2)
from numeric_types;
----
Int64 Int64 Int64 Int64 Int64 Int64 Int64 Int64 Float32 Float64 Decimal128(23, 2)
Int64 Int64 Int64 Int64 Int64 Int64 Int64 Decimal128(21, 0) Float32 Float64 Decimal128(23, 2)

# Plus with literal decimal
query TTTTTTTTTTT
Expand Down Expand Up @@ -127,7 +127,7 @@ select
arrow_typeof(decimal - 2)
from numeric_types;
----
Int64 Int64 Int64 Int64 Int64 Int64 Int64 Int64 Float32 Float64 Decimal128(23, 2)
Int64 Int64 Int64 Int64 Int64 Int64 Int64 Decimal128(21, 0) Float32 Float64 Decimal128(23, 2)

# Minus with literal decimal
query TTTTTTTTTTT
Expand Down Expand Up @@ -184,7 +184,7 @@ select
arrow_typeof(decimal * 2)
from numeric_types;
----
Int64 Int64 Int64 Int64 Int64 Int64 Int64 Int64 Float32 Float64 Decimal128(26, 2)
Int64 Int64 Int64 Int64 Int64 Int64 Int64 Decimal128(38, 0) Float32 Float64 Decimal128(26, 2)

# Multiply with literal decimal
query TTTTTTTTTTT
Expand Down Expand Up @@ -242,7 +242,7 @@ select
arrow_typeof(decimal / 2)
from numeric_types;
----
Int64 Int64 Int64 Int64 Int64 Int64 Int64 Int64 Float32 Float64 Decimal128(9, 6)
Int64 Int64 Int64 Int64 Int64 Int64 Int64 Decimal128(24, 4) Float32 Float64 Decimal128(9, 6)

# Divide with literal decimal
query TTTTTTTTTTT
Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2448,15 +2448,15 @@ EXPLAIN SELECT c5, c9, rn1 FROM (SELECT c5, c9,
LIMIT 5
----
logical_plan
01)Sort: rn1 ASC NULLS LAST, CAST(aggregate_test_100.c9 AS Int32) + aggregate_test_100.c5 DESC NULLS FIRST, fetch=5
01)Sort: rn1 ASC NULLS LAST, CAST(aggregate_test_100.c9 AS Decimal128(20, 0)) + CAST(aggregate_test_100.c5 AS Decimal128(20, 0)) DESC NULLS FIRST, fetch=5
02)--Projection: aggregate_test_100.c5, aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1
03)----WindowAggr: windowExpr=[[row_number() ORDER BY [CAST(aggregate_test_100.c9 AS Int32) + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
03)----WindowAggr: windowExpr=[[row_number() ORDER BY [CAST(aggregate_test_100.c9 AS Decimal128(20, 0)) + CAST(aggregate_test_100.c5 AS Decimal128(20, 0)) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
04)------TableScan: aggregate_test_100 projection=[c5, c9]
physical_plan
01)ProjectionExec: expr=[c5@0 as c5, c9@1 as c9, row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rn1]
02)--GlobalLimitExec: skip=0, fetch=5
03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]
04)------SortExec: expr=[CAST(c9@1 AS Int32) + c5@0 DESC], preserve_partitioning=[false]
03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Decimal128(None,21,0)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]
04)------SortExec: expr=[CAST(c9@1 AS Decimal128(20, 0)) + CAST(c5@0 AS Decimal128(20, 0)) DESC], preserve_partitioning=[false]
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c5, c9], file_type=csv, has_header=true

# Ordering equivalence should be preserved during cast expression
Expand Down