Skip to content

Commit 8b67f36

Browse files
authored
[mlir] [arith] Fix ceildivsi lowering in arith-expand (#133774)
This fixes the current lowering of `arith.ceildivsi` in the arith-expand pass, which was previously incorrect. The new version is based on the lowering of `arith.floordivsi`, and will not introduce new undefined behavior or poison during the lowering. It also replaces one division with a multiplication. The previous lowering of `ceildivsi(n, m)` was the following: ``` x = (m > 0) ? -1 : 1 (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) ``` This caused two problems: * In the case where `n` is INT_MIN and `m` is positive, the result would be poison instead of an actual value * In the case where `n` is INT_MAX and `m` is `-1`, this would trigger undefined behavior, while the original code wouldn't. This is because `n+x` would be equal to `INT_MIN` (`INT_MAX + 1`), so the `(n+x) / m` division would overflow and trigger UB.
1 parent e84b57d commit 8b67f36

File tree

4 files changed

+166
-93
lines changed

4 files changed

+166
-93
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

+25-35
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,13 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
5858
}
5959
};
6060

61-
/// Expands CeilDivSIOp (n, m) into
62-
/// 1) x = (m > 0) ? -1 : 1
63-
/// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
61+
/// Expands CeilDivSIOp (a, b) into
62+
/// z = a / b
63+
/// if (z * b != a && (a < 0) == (b < 0)) {
64+
/// return z + 1;
65+
/// } else {
66+
/// return z;
67+
/// }
6468
struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
6569
using OpRewritePattern::OpRewritePattern;
6670
LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
@@ -69,43 +73,29 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
6973
Type type = op.getType();
7074
Value a = op.getLhs();
7175
Value b = op.getRhs();
72-
Value plusOne = createConst(loc, type, 1, rewriter);
76+
7377
Value zero = createConst(loc, type, 0, rewriter);
74-
Value minusOne = createConst(loc, type, -1, rewriter);
75-
// Compute x = (b>0) ? -1 : 1.
76-
Value compare =
77-
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
78-
Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
79-
// Compute positive res: 1 + ((x+a)/b).
80-
Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
81-
Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
82-
Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
83-
// Compute negative res: - ((-a)/b).
84-
Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85-
Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86-
Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
87-
// Result is (a*b>0) ? pos result : neg result.
88-
// Note, we want to avoid using a*b because of possible overflow.
89-
// The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
90-
// not particuliarly care if a*b<0 is true or false when b is zero
91-
// as this will result in an illegal divide. So `a*b<0` can be reformulated
92-
// as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
93-
// We pick the first expression here.
78+
Value one = createConst(loc, type, 1, rewriter);
79+
80+
Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
81+
Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
82+
Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
83+
loc, arith::CmpIPredicate::ne, a, product);
84+
9485
Value aNeg =
9586
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
96-
Value aPos =
97-
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
9887
Value bNeg =
9988
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
100-
Value bPos =
101-
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
102-
Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
103-
Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
104-
Value compareRes =
105-
rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
106-
// Perform substitution and return success.
107-
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
108-
negRes);
89+
90+
Value signEqual = rewriter.create<arith::CmpIOp>(
91+
loc, arith::CmpIPredicate::eq, aNeg, bNeg);
92+
Value cond =
93+
rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
94+
95+
Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
96+
97+
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
98+
quotient);
10999
return success();
110100
}
111101
};

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

+14-22
Original file line numberDiff line numberDiff line change
@@ -600,28 +600,20 @@ func.func @select_complex(%arg0 : i1, %arg1 : complex<f32>, %arg2 : complex<f32>
600600
// -----
601601

602602
// CHECK-LABEL: @ceildivsi
603-
// CHECK-SAME: %[[ARG0:.*]]: i64) -> i64
604-
func.func @ceildivsi(%arg0 : i64) -> i64 {
605-
// CHECK: %[[CST0:.*]] = llvm.mlir.constant(1 : i64) : i64
606-
// CHECK: %[[CST1:.*]] = llvm.mlir.constant(0 : i64) : i64
607-
// CHECK: %[[CST2:.*]] = llvm.mlir.constant(-1 : i64) : i64
608-
// CHECK: %[[CMP0:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
609-
// CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[CST2]], %[[CST0]] : i1, i64
610-
// CHECK: %[[ADD0:.*]] = llvm.add %[[SEL0]], %[[ARG0]] : i64
611-
// CHECK: %[[DIV0:.*]] = llvm.sdiv %[[ADD0]], %[[ARG0]] : i64
612-
// CHECK: %[[ADD1:.*]] = llvm.add %[[DIV0]], %[[CST0]] : i64
613-
// CHECK: %[[SUB0:.*]] = llvm.sub %[[CST1]], %[[ARG0]] : i64
614-
// CHECK: %[[DIV1:.*]] = llvm.sdiv %[[SUB0]], %[[ARG0]] : i64
615-
// CHECK: %[[SUB1:.*]] = llvm.sub %[[CST1]], %[[DIV1]] : i64
616-
// CHECK: %[[CMP1:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
617-
// CHECK: %[[CMP2:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
618-
// CHECK: %[[CMP3:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
619-
// CHECK: %[[CMP4:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
620-
// CHECK: %[[AND0:.*]] = llvm.and %[[CMP1]], %[[CMP3]] : i1
621-
// CHECK: %[[AND1:.*]] = llvm.and %[[CMP2]], %[[CMP4]] : i1
622-
// CHECK: %[[OR:.*]] = llvm.or %[[AND0]], %[[AND1]] : i1
623-
// CHECK: %[[SEL1:.*]] = llvm.select %[[OR]], %[[ADD1]], %[[SUB1]] : i1, i64
624-
%0 = arith.ceildivsi %arg0, %arg0 : i64
603+
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) -> i64
604+
func.func @ceildivsi(%arg0 : i64, %arg1 : i64) -> i64 {
605+
// CHECK: %[[ZERO:.+]] = llvm.mlir.constant(0 : i64) : i64
606+
// CHECK: %[[ONE:.+]] = llvm.mlir.constant(1 : i64) : i64
607+
// CHECK: %[[DIV:.+]] = llvm.sdiv %[[ARG0]], %[[ARG1]] : i64
608+
// CHECK: %[[MUL:.+]] = llvm.mul %[[DIV]], %[[ARG1]] : i64
609+
// CHECK: %[[NEXACT:.+]] = llvm.icmp "ne" %[[ARG0]], %[[MUL]] : i64
610+
// CHECK: %[[NNEG:.+]] = llvm.icmp "slt" %[[ARG0]], %[[ZERO]] : i64
611+
// CHECK: %[[MNEG:.+]] = llvm.icmp "slt" %[[ARG1]], %[[ZERO]] : i64
612+
// CHECK: %[[SAMESIGN:.+]] = llvm.icmp "eq" %[[NNEG]], %[[MNEG]] : i1
613+
// CHECK: %[[SHOULDROUND:.+]] = llvm.and %[[NEXACT]], %[[SAMESIGN]] : i1
614+
// CHECK: %[[CEIL:.+]] = llvm.add %[[DIV]], %[[ONE]] : i64
615+
// CHECK: %[[RES:.+]] = llvm.select %[[SHOULDROUND]], %[[CEIL]], %[[DIV]] : i1, i64
616+
%0 = arith.ceildivsi %arg0, %arg1 : i64
625617
return %0: i64
626618
}
627619

mlir/test/Dialect/Arith/expand-ops.mlir

+21-36
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,17 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
77
%res = arith.ceildivsi %arg0, %arg1 : i32
88
return %res : i32
99

10-
// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
1110
// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32
12-
// CHECK: [[MINONE:%.+]] = arith.constant -1 : i32
13-
// CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
14-
// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : i32
15-
// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
16-
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
17-
// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
18-
// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
19-
// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
20-
// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
21-
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
22-
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
23-
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
24-
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
25-
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
26-
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
27-
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
28-
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
11+
// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
12+
// CHECK: [[DIV:%.+]] = arith.divsi %arg0, %arg1 : i32
13+
// CHECK: [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : i32
14+
// CHECK: [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : i32
15+
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : i32
16+
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : i32
17+
// CHECK: [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
18+
// CHECK: [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
19+
// CHECK: [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
20+
// CHECK: [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : i32
2921
}
3022

3123
// -----
@@ -37,25 +29,18 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
3729
%res = arith.ceildivsi %arg0, %arg1 : index
3830
return %res : index
3931

40-
// CHECK: [[ONE:%.+]] = arith.constant 1 : index
4132
// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
42-
// CHECK: [[MINONE:%.+]] = arith.constant -1 : index
43-
// CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
44-
// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : index
45-
// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
46-
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
47-
// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
48-
// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
49-
// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
50-
// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
51-
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
52-
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
53-
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
54-
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
55-
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
56-
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
57-
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
58-
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
33+
// CHECK: [[ONE:%.+]] = arith.constant 1 : index
34+
// CHECK: [[DIV:%.+]] = arith.divsi %arg0, %arg1 : index
35+
// CHECK: [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : index
36+
// CHECK: [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : index
37+
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : index
38+
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : index
39+
// CHECK: [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
40+
// CHECK: [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
41+
// CHECK: [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : index
42+
// CHECK: [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : index
43+
5944
}
6045

6146
// -----
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Check that the ceildivsi lowering is correct.
2+
// We do not check any poison or UB values, as it is not possible to catch them.
3+
4+
// RUN: mlir-opt %s --convert-vector-to-llvm \
5+
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
6+
// RUN: mlir-runner -e entry --entry-point-result=void \
7+
// RUN: --shared-libs=%mlir_c_runner_utils | \
8+
// RUN: FileCheck %s --match-full-lines
9+
10+
func.func @check_ceildivsi(%lhs : i32, %rhs : i32) -> () {
11+
%res = arith.ceildivsi %lhs, %rhs : i32
12+
vector.print %res : i32
13+
return
14+
}
15+
16+
func.func @entry() {
17+
%int_min = arith.constant -2147483648 : i32
18+
%int_max = arith.constant 2147483647 : i32
19+
%minus_three = arith.constant -3 : i32
20+
%minus_two = arith.constant -2 : i32
21+
%minus_one = arith.constant -1 : i32
22+
%zero = arith.constant 0 : i32
23+
%one = arith.constant 1 : i32
24+
%two = arith.constant 2 : i32
25+
%three = arith.constant 3 : i32
26+
27+
// INT_MAX divided by values.
28+
// CHECK: 1
29+
func.call @check_ceildivsi(%int_max, %int_max) : (i32, i32) -> ()
30+
// CHECK-NEXT: 0
31+
func.call @check_ceildivsi(%int_max, %int_min) : (i32, i32) -> ()
32+
// CHECK-NEXT: -2147483647
33+
func.call @check_ceildivsi(%int_max, %minus_one) : (i32, i32) -> ()
34+
// CHECK-NEXT: -1073741823
35+
func.call @check_ceildivsi(%int_max, %minus_two) : (i32, i32) -> ()
36+
// CHECK-NEXT: 2147483647
37+
func.call @check_ceildivsi(%int_max, %one) : (i32, i32) -> ()
38+
// CHECK-NEXT: 1073741824
39+
func.call @check_ceildivsi(%int_max, %two) : (i32, i32) -> ()
40+
41+
// INT_MIN divided by values.
42+
// We do not check the result of INT_MIN divided by -1, as it is UB.
43+
// CHECK-NEXT: 1
44+
func.call @check_ceildivsi(%int_min, %int_min) : (i32, i32) -> ()
45+
// CHECK-NEXT: -1
46+
func.call @check_ceildivsi(%int_min, %int_max) : (i32, i32) -> ()
47+
// CHECK-NEXT: 1073741824
48+
func.call @check_ceildivsi(%int_min, %minus_two) : (i32, i32) -> ()
49+
// CHECK-NEXT: -2147483648
50+
func.call @check_ceildivsi(%int_min, %one) : (i32, i32) -> ()
51+
// CHECK-NEXT: -1073741824
52+
func.call @check_ceildivsi(%int_min, %two) : (i32, i32) -> ()
53+
54+
// Divide values by INT_MIN.
55+
// CHECK-NEXT: 0
56+
func.call @check_ceildivsi(%one, %int_min) : (i32, i32) -> ()
57+
// CHECK-NEXT: 0
58+
func.call @check_ceildivsi(%two, %int_min) : (i32, i32) -> ()
59+
// CHECK-NEXT: 1
60+
func.call @check_ceildivsi(%minus_one, %int_min) : (i32, i32) -> ()
61+
// CHECK-NEXT: 1
62+
func.call @check_ceildivsi(%minus_two, %int_min) : (i32, i32) -> ()
63+
64+
// Divide values by INT_MAX.
65+
// CHECK-NEXT: 1
66+
func.call @check_ceildivsi(%one, %int_max) : (i32, i32) -> ()
67+
// CHECK-NEXT: 1
68+
func.call @check_ceildivsi(%two, %int_max) : (i32, i32) -> ()
69+
// CHECK-NEXT: 0
70+
func.call @check_ceildivsi(%minus_one, %int_max) : (i32, i32) -> ()
71+
// CHECK-NEXT: 0
72+
func.call @check_ceildivsi(%minus_two, %int_max) : (i32, i32) -> ()
73+
74+
// Check divisions by 2.
75+
// CHECK-NEXT: -1
76+
func.call @check_ceildivsi(%minus_three, %two) : (i32, i32) -> ()
77+
// CHECK-NEXT: -1
78+
func.call @check_ceildivsi(%minus_two, %two) : (i32, i32) -> ()
79+
// CHECK-NEXT: 0
80+
func.call @check_ceildivsi(%minus_one, %two) : (i32, i32) -> ()
81+
// CHECK-NEXT: 0
82+
func.call @check_ceildivsi(%zero, %two) : (i32, i32) -> ()
83+
// CHECK-NEXT: 1
84+
func.call @check_ceildivsi(%one, %two) : (i32, i32) -> ()
85+
// CHECK-NEXT: 1
86+
func.call @check_ceildivsi(%two, %two) : (i32, i32) -> ()
87+
// CHECK-NEXT: 2
88+
func.call @check_ceildivsi(%three, %two) : (i32, i32) -> ()
89+
90+
// Check divisions by -2.
91+
// CHECK-NEXT: 2
92+
func.call @check_ceildivsi(%minus_three, %minus_two) : (i32, i32) -> ()
93+
// CHECK-NEXT: 1
94+
func.call @check_ceildivsi(%minus_two, %minus_two) : (i32, i32) -> ()
95+
// CHECK-NEXT: 1
96+
func.call @check_ceildivsi(%minus_one, %minus_two) : (i32, i32) -> ()
97+
// CHECK-NEXT: 0
98+
func.call @check_ceildivsi(%zero, %minus_two) : (i32, i32) -> ()
99+
// CHECK-NEXT: 0
100+
func.call @check_ceildivsi(%one, %minus_two) : (i32, i32) -> ()
101+
// CHECK-NEXT: -1
102+
func.call @check_ceildivsi(%two, %minus_two) : (i32, i32) -> ()
103+
// CHECK-NEXT: -1
104+
func.call @check_ceildivsi(%three, %minus_two) : (i32, i32) -> ()
105+
return
106+
}

0 commit comments

Comments
 (0)