From 029dbd43994166818d88631285f7096ab7b39d89 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Wed, 1 Jan 2025 22:01:42 +0100 Subject: [PATCH] [Comb] Don't try to canonicalize muxes indefinitely (#8023) --- lib/Dialect/Comb/CombFolds.cpp | 22 ++++++++++++++-------- test/Dialect/Comb/canonicalization.mlir | 11 +++++++---- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/lib/Dialect/Comb/CombFolds.cpp b/lib/Dialect/Comb/CombFolds.cpp index 691ae579c9cb..316234320b0b 100644 --- a/lib/Dialect/Comb/CombFolds.cpp +++ b/lib/Dialect/Comb/CombFolds.cpp @@ -1921,7 +1921,7 @@ OpFoldResult MuxOp::fold(FoldAdaptor adaptor) { return {}; // mux (c, b, b) -> b - if (getTrueValue() == getFalseValue()) + if (getTrueValue() == getFalseValue() && getTrueValue() != getResult()) return getTrueValue(); if (auto tv = adaptor.getTrueValue()) if (tv == adaptor.getFalseValue()) @@ -2183,6 +2183,9 @@ static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)` // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)` if (auto subMux = dyn_cast(subExpr)) { + if (subMux == op) + return false; + Value otherValue; Value subCond = subMux.getCond(); @@ -2514,8 +2517,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op, } } - if (auto falseMux = - dyn_cast_or_null(op.getFalseValue().getDefiningOp())) { + if (auto falseMux = op.getFalseValue().getDefiningOp(); + falseMux && falseMux != op) { // mux(selector, x, mux(selector, y, z) = mux(selector, x, z) if (op.getCond() == falseMux.getCond()) { replaceOpWithNewOpAndCopyName( @@ -2529,8 +2532,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op, return success(); } - if (auto trueMux = - dyn_cast_or_null(op.getTrueValue().getDefiningOp())) { + if (auto trueMux = op.getTrueValue().getDefiningOp(); + trueMux && trueMux != op) { // mux(selector, mux(selector, a, b), c) = mux(selector, a, c) if (op.getCond() == trueMux.getCond()) { replaceOpWithNewOpAndCopyName( @@ -2548,7 +2551,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op, if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); trueMux && falseMux && trueMux.getCond() == falseMux.getCond() && - trueMux.getTrueValue() == falseMux.getTrueValue()) { + trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op && + falseMux != op) { auto subMux = rewriter.create( rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}), op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue()); @@ -2562,7 +2566,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op, if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); trueMux && falseMux && trueMux.getCond() == falseMux.getCond() && - trueMux.getFalseValue() == falseMux.getFalseValue()) { + trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op && + falseMux != op) { auto subMux = rewriter.create( rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}), op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue()); @@ -2577,7 +2582,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op, falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); trueMux && falseMux && trueMux.getTrueValue() == falseMux.getTrueValue() && - trueMux.getFalseValue() == falseMux.getFalseValue()) { + trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op && + falseMux != op) { auto subMux = rewriter.create( rewriter.getFusedLoc( {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}), diff --git a/test/Dialect/Comb/canonicalization.mlir b/test/Dialect/Comb/canonicalization.mlir index e7cda4bd24c4..92cb591b7032 100644 --- a/test/Dialect/Comb/canonicalization.mlir +++ b/test/Dialect/Comb/canonicalization.mlir @@ -1230,7 +1230,7 @@ hw.module @muxConstantsFold(in %cond: i1, out o: i25) { hw.module @muxCommon(in %cond: i1, in %cond2: i1, in %arg0 : i32, in %arg1 : i32, in %arg2: i32, in %arg3: i32, out o1: i32, out o2: i32, out o3: i32, out o4: i32, - out o5: i32, out orResult: i32, out o6: i32, out o7: i32) { + out o5: i32, out orResult: i32, out o6: i32, out o7: i32, out o8 : i1) { %allones = hw.constant -1 : i32 %notArg0 = comb.xor %arg0, %allones : i32 @@ -1275,10 +1275,13 @@ hw.module @muxCommon(in %cond: i1, in %cond2: i1, %1 = comb.mux %cond, %arg1, %arg0 : i32 %o7 = comb.mux %cond2, %1, %arg0 : i32 + /// CHECK: [[O8:%.+]] = comb.mux [[O8]], [[O8]], [[O8]] : i1 + %o8 = comb.mux %o8, %o8, %o8 : i1 + // CHECK: hw.output [[O1]], [[O2]], [[O3]], [[O4]], [[O5]], [[ORRESULT]], - // CHECK: [[O6]], [[O7]] - hw.output %o1, %o2, %o3, %o4, %o5, %orResult, %o6, %o7 - : i32, i32, i32, i32, i32, i32, i32, i32 + // CHECK: [[O6]], [[O7]], [[O8]] + hw.output %o1, %o2, %o3, %o4, %o5, %orResult, %o6, %o7, %o8 + : i32, i32, i32, i32, i32, i32, i32, i32, i1 } // CHECK-LABEL: @flatten_multi_use_and