-
Notifications
You must be signed in to change notification settings - Fork 13.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] [arith] Fix ceildivsi lowering in arith-expand #133774
Conversation
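@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Fehr Mathieu (math-fehr) Changes: This fixes the current lowering of `arith.ceildivsi` in the arith-expand pass, which was previously incorrect. The previous lowering of `ceildivsi(n, m)` was the following: 1) `x = (m > 0) ? -1 : 1`, 2) `(n*m>0) ? ((n+x) / m) + 1 : - (-n / m)`.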
This caused two problems:
Full diff: https://github.com/llvm/llvm-project/pull/133774.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 54be644a71011..2d627e523cde5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -58,9 +58,13 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
}
};
-/// Expands CeilDivSIOp (n, m) into
-/// 1) x = (m > 0) ? -1 : 1
-/// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
+/// Expands CeilDivSIOp (a, b) into
+/// z = a / b
+/// if (z * b != a && (a < 0) == (b < 0)) {
+/// return z + 1;
+/// } else {
+/// return z;
+/// }
struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
@@ -69,43 +73,29 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
Type type = op.getType();
Value a = op.getLhs();
Value b = op.getRhs();
- Value plusOne = createConst(loc, type, 1, rewriter);
+
Value zero = createConst(loc, type, 0, rewriter);
- Value minusOne = createConst(loc, type, -1, rewriter);
- // Compute x = (b>0) ? -1 : 1.
- Value compare =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
- Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
- // Compute positive res: 1 + ((x+a)/b).
- Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
- Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
- Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
- // Compute negative res: - ((-a)/b).
- Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
- Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
- Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
- // Result is (a*b>0) ? pos result : neg result.
- // Note, we want to avoid using a*b because of possible overflow.
- // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
- // not particuliarly care if a*b<0 is true or false when b is zero
- // as this will result in an illegal divide. So `a*b<0` can be reformulated
- // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
- // We pick the first expression here.
+ Value one = createConst(loc, type, 1, rewriter);
+
+ Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
+ Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
+ Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, a, product);
+
Value aNeg =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
- Value aPos =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
Value bNeg =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
- Value bPos =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
- Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
- Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
- Value compareRes =
- rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
- // Perform substitution and return success.
- rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
- negRes);
+
+ Value signEqual = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, aNeg, bNeg);
+ Value cond =
+ rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
+
+ Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
+ quotient);
return success();
}
};
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 7daf4ef8717bc..e0d974ea74041 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -600,28 +600,20 @@ func.func @select_complex(%arg0 : i1, %arg1 : complex<f32>, %arg2 : complex<f32>
// -----
// CHECK-LABEL: @ceildivsi
-// CHECK-SAME: %[[ARG0:.*]]: i64) -> i64
-func.func @ceildivsi(%arg0 : i64) -> i64 {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(1 : i64) : i64
- // CHECK: %[[CST1:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST2:.*]] = llvm.mlir.constant(-1 : i64) : i64
- // CHECK: %[[CMP0:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[CST2]], %[[CST0]] : i1, i64
- // CHECK: %[[ADD0:.*]] = llvm.add %[[SEL0]], %[[ARG0]] : i64
- // CHECK: %[[DIV0:.*]] = llvm.sdiv %[[ADD0]], %[[ARG0]] : i64
- // CHECK: %[[ADD1:.*]] = llvm.add %[[DIV0]], %[[CST0]] : i64
- // CHECK: %[[SUB0:.*]] = llvm.sub %[[CST1]], %[[ARG0]] : i64
- // CHECK: %[[DIV1:.*]] = llvm.sdiv %[[SUB0]], %[[ARG0]] : i64
- // CHECK: %[[SUB1:.*]] = llvm.sub %[[CST1]], %[[DIV1]] : i64
- // CHECK: %[[CMP1:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[CMP2:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[CMP3:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[CMP4:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[AND0:.*]] = llvm.and %[[CMP1]], %[[CMP3]] : i1
- // CHECK: %[[AND1:.*]] = llvm.and %[[CMP2]], %[[CMP4]] : i1
- // CHECK: %[[OR:.*]] = llvm.or %[[AND0]], %[[AND1]] : i1
- // CHECK: %[[SEL1:.*]] = llvm.select %[[OR]], %[[ADD1]], %[[SUB1]] : i1, i64
- %0 = arith.ceildivsi %arg0, %arg0 : i64
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) -> i64
+func.func @ceildivsi(%arg0 : i64, %arg1 : i64) -> i64 {
+ // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[ONE:.+]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[DIV:.+]] = llvm.sdiv %[[ARG0]], %[[ARG1]] : i64
+ // CHECK: %[[MUL:.+]] = llvm.mul %[[DIV]], %[[ARG1]] : i64
+ // CHECK: %[[NEXACT:.+]] = llvm.icmp "ne" %[[ARG0]], %[[MUL]] : i64
+ // CHECK: %[[NNEG:.+]] = llvm.icmp "slt" %[[ARG0]], %[[ZERO]] : i64
+ // CHECK: %[[MNEG:.+]] = llvm.icmp "slt" %[[ARG1]], %[[ZERO]] : i64
+ // CHECK: %[[SAMESIGN:.+]] = llvm.icmp "eq" %[[NNEG]], %[[MNEG]] : i1
+ // CHECK: %[[SHOULDROUND:.+]] = llvm.and %[[NEXACT]], %[[SAMESIGN]] : i1
+ // CHECK: %[[CEIL:.+]] = llvm.add %[[DIV]], %[[ONE]] : i64
+ // CHECK: %[[RES:.+]] = llvm.select %[[SHOULDROUND]], %[[CEIL]], %[[DIV]] : i1, i64
+ %0 = arith.ceildivsi %arg0, %arg1 : i64
return %0: i64
}
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 174eb468cc004..bdf022642b717 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -7,25 +7,17 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
%res = arith.ceildivsi %arg0, %arg1 : i32
return %res : i32
-// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32
-// CHECK: [[MINONE:%.+]] = arith.constant -1 : i32
-// CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : i32
-// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
-// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
-// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
-// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
-// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
-// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
-// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
-// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
-// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
+// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
+// CHECK: [[DIV:%.+]] = arith.divsi %arg0, %arg1 : i32
+// CHECK: [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : i32
+// CHECK: [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : i32
+// CHECK: [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : i32
+// CHECK: [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : i32
+// CHECK: [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
+// CHECK: [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
+// CHECK: [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
+// CHECK: [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : i32
}
// -----
@@ -37,25 +29,18 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
%res = arith.ceildivsi %arg0, %arg1 : index
return %res : index
-// CHECK: [[ONE:%.+]] = arith.constant 1 : index
// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
-// CHECK: [[MINONE:%.+]] = arith.constant -1 : index
-// CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : index
-// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
-// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
-// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
-// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
-// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
-// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
-// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
-// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
-// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
+// CHECK: [[ONE:%.+]] = arith.constant 1 : index
+// CHECK: [[DIV:%.+]] = arith.divsi %arg0, %arg1 : index
+// CHECK: [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : index
+// CHECK: [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : index
+// CHECK: [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : index
+// CHECK: [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : index
+// CHECK: [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
+// CHECK: [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
+// CHECK: [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : index
+// CHECK: [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : index
+
}
// -----
|
This fixes the current lowering of arith.ceildivsi in the arith-expand pass, which was previously incorrect. The new version is based on the lowering of arith.floordivsi, and will not introduce new undefined behavior or poison during the lowering. It also performs one fewer division.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems reasonable to me, approved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't checked the math but looks good overall.
I wonder if we should add some integration tests to exercise some known corner cases -- we do that for wide integer emulation which was also non-trivial to get right: https://github.com/llvm/llvm-project/tree/main/mlir/test/Integration/Dialect/Arith/CPU ?
Or alternatively, could we use something like sympy to show that this expansion is correct in isolation?
WDYT @math-fehr ?
I would be really happy to figure out a way to properly test it, as it's really easy to get it wrong. The way I checked the correctness of this (and found the bug) was by using my SMT tool, but it relies on my own implementation of ceildivsi in SMT-LIB, so it has the exact same problem. I thought about trying to use Alive2 in LLVM to prove it correct, but LLVM doesn't define ceildivsi, so that's not really possible :/ What I can easily do is add integration tests for some values that do not trigger poison/UB if you want, but I'm just not sure how to proceed with poison/UB. |
Yeah this sounds good. I know we can't test for poison/UB, but let's at least make sure that the non-ub results are correct. |
My worry would be that something that attempts to fix UB/poison propagation could just as well break the correctness for well-defined inputs and it's hard to tell if the math checks out or not. |
BTW, if you would like to test your tool on more passes @math-fehr, I'd be surprised if wide integer emulation propagates poison properly: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp. The code predates poison semantics in mlir. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved as that seems to be correct and does fix an existing issue, but I think that even the new version, albeit correct, is still unnecessarily complicated. How about implementing this instead? Can be a follow-up PR, or an open TODO.
int ceildiv(int a, int b) {
int c = b - 1;
int d = (a >= INT_MAX - c) ? INT_MAX : a + c; // d = saturating_add(a, c)
return d / b;
}
EDIT: oops, that only works when b > 0. OK, maybe once that is fixed, it won't really be less complicated anymore. Your choice.
This would also not work when |
I'll try to do that this week or next week, will report the bugs and send some patches ;) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % minor nits
mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-ceildivsi.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-ceildivsi.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-ceildivsi.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-ceildivsi.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-ceildivsi.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-ceildivsi.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-ceildivsi.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-ceildivsi.mlir
Outdated
Show resolved
Hide resolved
Co-authored-by: Jakub Kuderski <[email protected]>
This fixes the current lowering of `arith.ceildivsi` in the arith-expand pass, which was previously incorrect. The new version is based on the lowering of `arith.floordivsi`, and will not introduce new undefined behavior or poison during the lowering. It also replaces one division with a multiplication.
The previous lowering of `ceildivsi(n, m)` was the following:
1) x = (m > 0) ? -1 : 1
2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
This caused two problems:
- When `n` is INT_MIN and `m` is positive, the result would be poison instead of an actual value.
- When `n` is INT_MAX and `m` is `-1`, this would trigger undefined behavior, while the original code wouldn't. This is because `n+x` would be equal to `INT_MIN` (`INT_MAX + 1`), so the `(n+x) / m` division would overflow and trigger UB.