Skip to content

Commit b30ec02

Browse files
committed
Use worklist instead of recursion
1 parent 03a41aa commit b30ec02

File tree

2 files changed

+95
-66
lines changed

2 files changed

+95
-66
lines changed

Diff for: lib/Dialect/Verif/Transforms/LowerContracts.cpp

+68-39
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include <ranges>
14+
1315
#include "circt/Dialect/Verif/VerifOps.h"
1416
#include "circt/Dialect/Verif/VerifPasses.h"
1517
#include "mlir/IR/IRMapping.h"
@@ -53,10 +55,6 @@ Operation *replaceContractOp(OpBuilder &builder, RequireLike op,
5355
return nullptr;
5456
}
5557

56-
LogicalResult cloneFanIn(OpBuilder &builder, Operation *opToClone,
57-
IRMapping &mapping, DenseSet<Operation *> &seen,
58-
bool assumeContract);
59-
6058
LogicalResult cloneContractOp(OpBuilder &builder, Operation *opToClone,
6159
IRMapping &mapping, bool assumeContract) {
6260
Operation *clonedOp;
@@ -78,6 +76,10 @@ LogicalResult cloneContractOp(OpBuilder &builder, Operation *opToClone,
7876
return llvm::success();
7977
}
8078

79+
LogicalResult cloneFanIn(OpBuilder &builder, Operation *opToClone,
80+
IRMapping &mapping, DenseSet<Operation *> &seen,
81+
bool assumeContract);
82+
8183
LogicalResult cloneContractBody(ContractOp &contract, OpBuilder &builder,
8284
IRMapping &mapping, DenseSet<Operation *> &seen,
8385
bool assumeContract, bool shouldCloneFanIn) {
@@ -105,57 +107,84 @@ LogicalResult inlineContract(ContractOp &contract, OpBuilder &builder,
105107
shouldCloneFanIn);
106108
}
107109

108-
LogicalResult cloneOperands(OpBuilder &builder, Operation *opToClone,
109-
IRMapping &mapping, DenseSet<Operation *> &seen,
110-
bool assumeContract, Operation *parent = nullptr) {
111-
for (auto operand : opToClone->getOperands()) {
110+
void buildOpsToClone(OpBuilder &builder, IRMapping &mapping, Operation *op,
111+
SmallVector<Operation *> &opsToClone,
112+
std::queue<Operation *> &workList,
113+
DenseSet<Operation *> &seen, Operation *parent = nullptr) {
114+
if (auto contract = dyn_cast<ContractOp>(*op)) {
115+
for (auto result : contract.getResults()) {
116+
auto sym =
117+
builder.create<SymbolicValueOp>(result.getLoc(), result.getType());
118+
mapping.map(result, sym);
119+
}
120+
// Assume it holds
121+
// TODO: merge w/ parent/walk logic?
122+
auto &contractOps = contract.getBody().front().getOperations();
123+
for (auto it = contractOps.rbegin(); it != contractOps.rend(); ++it) {
124+
if (!seen.contains(&*it)) {
125+
workList.push(&*it);
126+
}
127+
}
128+
return;
129+
}
130+
for (auto operand : op->getOperands()) {
112131
if (mapping.contains(operand))
113132
continue;
114-
133+
// Don't need to clone if defining op is within parent op
134+
if (parent && parent->isAncestor(operand.getParentBlock()->getParentOp()))
135+
continue;
115136
if (auto *definingOp = operand.getDefiningOp()) {
116-
// Don't need to clone if defining op is within parent op
117-
if (parent && parent->isAncestor(operand.getParentBlock()->getParentOp()))
118-
continue;
119-
// Recurse and clone defining op
120-
if (failed(
121-
cloneFanIn(builder, definingOp, mapping, seen, assumeContract)))
122-
return failure();
137+
if (!seen.contains(definingOp)) {
138+
workList.push(definingOp);
139+
}
123140
} else {
124141
// Create symbolic values for arguments
125142
auto sym = builder.create<verif::SymbolicValueOp>(operand.getLoc(),
126143
operand.getType());
127144
mapping.map(operand, sym);
128145
}
129146
}
130-
return success();
131147
}
132148

133149
LogicalResult cloneFanIn(OpBuilder &builder, Operation *opToClone,
134150
IRMapping &mapping, DenseSet<Operation *> &seen,
135151
bool assumeContract) {
136-
if (seen.contains(opToClone))
137-
return llvm::success();
138-
seen.insert(opToClone);
139-
140-
if (failed(cloneOperands(builder, opToClone, mapping, seen, assumeContract)))
141-
return failure();
142-
// Ensure all operands have been mapped
143-
if (opToClone
144-
->walk([&](Operation *nestedOp) {
145-
if (failed(cloneOperands(builder, nestedOp, mapping, seen,
146-
assumeContract, opToClone)))
147-
return WalkResult::interrupt();
148-
return WalkResult::advance();
149-
})
150-
.wasInterrupted())
151-
return failure();
152-
153-
if (auto contract = dyn_cast<ContractOp>(opToClone)) {
154-
// Assume it holds
155-
return inlineContract(contract, builder, mapping, seen, true);
152+
SmallVector<Operation *> opsToClone;
153+
std::queue<Operation *> workList;
154+
workList.push(opToClone);
155+
while (!workList.empty()) {
156+
auto *currentOp = workList.front();
157+
workList.pop();
158+
if (seen.contains(currentOp))
159+
continue;
160+
seen.insert(currentOp);
161+
buildOpsToClone(builder, mapping, currentOp, opsToClone, workList, seen);
162+
if (auto contract = dyn_cast<ContractOp>(*currentOp))
163+
continue;
164+
currentOp->walk([&](Operation *nestedOp) {
165+
buildOpsToClone(builder, mapping, nestedOp, opsToClone, workList, seen,
166+
currentOp);
167+
});
168+
opsToClone.push_back(currentOp);
156169
}
157170

158-
return cloneContractOp(builder, opToClone, mapping, assumeContract);
171+
for (auto it = opsToClone.rbegin(); it != opsToClone.rend(); ++it) {
172+
Operation *op = *it;
173+
Operation *clonedOp;
174+
if (auto requireLike = dyn_cast<RequireLike>(*op)) {
175+
clonedOp =
176+
replaceContractOp(builder, requireLike, mapping, assumeContract);
177+
if (!clonedOp) {
178+
return failure();
179+
}
180+
} else {
181+
clonedOp = builder.clone(*op, mapping);
182+
}
183+
for (auto [x, y] : llvm::zip(op->getResults(), clonedOp->getResults())) {
184+
mapping.map(x, y);
185+
}
186+
}
187+
return success();
159188
}
160189

161190
LogicalResult runOnHWModule(HWModuleOp hwModule, ModuleOp mlirModule) {
@@ -186,7 +215,7 @@ LogicalResult runOnHWModule(HWModuleOp hwModule, ModuleOp mlirModule) {
186215
// Clone fan in cone for contract operands
187216
for (auto operand : contract.getOperands()) {
188217
auto *definingOp = operand.getDefiningOp();
189-
if (failed(cloneFanIn(formalBuilder, definingOp, mapping, seen, false)))
218+
if (failed(cloneFanIn(formalBuilder, definingOp, mapping, seen, true)))
190219
return failure();
191220
}
192221

Diff for: test/Dialect/Verif/lower-contracts.mlir

+27-27
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ hw.module @Mul9(in %a: i42, out z: i42) {
4949
// CHECK-NEXT: %3 = comb.xor %0, %1, %2 : i42
5050
// CHECK-NEXT: %4 = comb.extract %0 from 0 : (i42) -> i41
5151
// CHECK-NEXT: %5 = comb.extract %1 from 0 : (i42) -> i41
52-
// CHECK-NEXT: %6 = comb.and %4, %5 : i41
53-
// CHECK-NEXT: %7 = comb.or %4, %5 : i41
54-
// CHECK-NEXT: %8 = comb.extract %2 from 0 : (i42) -> i41
55-
// CHECK-NEXT: %9 = comb.and %7, %8 : i41
56-
// CHECK-NEXT: %10 = comb.or %6, %9 : i41
52+
// CHECK-NEXT: %6 = comb.or %4, %5 : i41
53+
// CHECK-NEXT: %7 = comb.extract %2 from 0 : (i42) -> i41
54+
// CHECK-NEXT: %8 = comb.and %6, %7 : i41
55+
// CHECK-NEXT: %9 = comb.and %4, %5 : i41
56+
// CHECK-NEXT: %10 = comb.or %9, %8 : i41
5757
// CHECK-NEXT: %11 = comb.concat %10, %false : i41, i1
5858
// CHECK-NEXT: %12 = comb.add %0, %1, %2 : i42
5959
// CHECK-NEXT: %13 = comb.add %3, %11 : i42
@@ -95,22 +95,22 @@ hw.module @CarrySaveCompress3to2(
9595
// CHECK-NEXT: %c0_i4 = hw.constant 0 : i4
9696
// CHECK-NEXT: %c8_i8 = hw.constant 8 : i8
9797
// CHECK-NEXT: %0 = verif.symbolic_value : i8
98-
// CHECK-NEXT: %1 = comb.extract %0 from 0 : (i8) -> i1
99-
// CHECK-NEXT: %2 = comb.extract %0 from 1 : (i8) -> i1
100-
// CHECK-NEXT: %3 = comb.extract %0 from 2 : (i8) -> i1
101-
// CHECK-NEXT: %4 = verif.symbolic_value : i8
102-
// CHECK-NEXT: %5 = comb.extract %4 from 0 : (i8) -> i4
103-
// CHECK-NEXT: %6 = comb.concat %5, %c0_i4 : i4, i4
104-
// CHECK-NEXT: %7 = comb.mux %3, %6, %4 : i8
105-
// CHECK-NEXT: %8 = comb.extract %7 from 0 : (i8) -> i6
106-
// CHECK-NEXT: %9 = comb.concat %8, %c0_i2 : i6, i2
107-
// CHECK-NEXT: %10 = comb.mux %2, %9, %7 : i8
108-
// CHECK-NEXT: %11 = comb.extract %10 from 0 : (i8) -> i7
109-
// CHECK-NEXT: %12 = comb.concat %11, %false : i7, i1
110-
// CHECK-NEXT: %13 = comb.mux %1, %12, %10 : i8
98+
// CHECK-NEXT: %1 = verif.symbolic_value : i8
99+
// CHECK-NEXT: %2 = comb.extract %1 from 0 : (i8) -> i4
100+
// CHECK-NEXT: %3 = comb.concat %2, %c0_i4 : i4, i4
101+
// CHECK-NEXT: %4 = comb.extract %0 from 2 : (i8) -> i1
102+
// CHECK-NEXT: %5 = comb.mux %4, %3, %1 : i8
103+
// CHECK-NEXT: %6 = comb.extract %5 from 0 : (i8) -> i6
104+
// CHECK-NEXT: %7 = comb.concat %6, %c0_i2 : i6, i2
105+
// CHECK-NEXT: %8 = comb.extract %0 from 1 : (i8) -> i1
106+
// CHECK-NEXT: %9 = comb.mux %8, %7, %5 : i8
107+
// CHECK-NEXT: %10 = comb.extract %9 from 0 : (i8) -> i7
108+
// CHECK-NEXT: %11 = comb.concat %10, %false : i7, i1
109+
// CHECK-NEXT: %12 = comb.extract %0 from 0 : (i8) -> i1
110+
// CHECK-NEXT: %13 = comb.mux %12, %11, %9 : i8
111111
// CHECK-NEXT: %14 = comb.icmp ult %0, %c8_i8 : i8
112112
// CHECK-NEXT: verif.assume %14 : i1
113-
// CHECK-NEXT: %15 = comb.shl %4, %0 : i8
113+
// CHECK-NEXT: %15 = comb.shl %1, %0 : i8
114114
// CHECK-NEXT: %16 = comb.icmp eq %13, %15 : i8
115115
// CHECK-NEXT: verif.assert %16 : i1
116116
// CHECK-NEXT: }
@@ -187,11 +187,11 @@ hw.module @NoContract(in %a: i42, out z: i42) {
187187
// CHECK-NEXT: %c2_i42 = hw.constant 2 : i42
188188
// CHECK-NEXT: %0 = verif.symbolic_value : i42
189189
// CHECK-NEXT: %1 = verif.symbolic_value : i42
190-
// CHECK-NEXT: %2 = comb.icmp ult %0, %c2_i42 : i42
190+
// CHECK-NEXT: %2 = comb.icmp ult %1, %c2_i42 : i42
191191
// CHECK-NEXT: verif.assert %2 : i1
192-
// CHECK-NEXT: %3 = comb.extract %0 from 0 : (i42) -> i41
192+
// CHECK-NEXT: %3 = comb.extract %1 from 0 : (i42) -> i41
193193
// CHECK-NEXT: %4 = comb.concat %3, %false : i41, i1
194-
// CHECK-NEXT: %5 = comb.icmp eq %1, %4 : i42
194+
// CHECK-NEXT: %5 = comb.icmp eq %0, %4 : i42
195195
// CHECK-NEXT: verif.assume %5 : i1
196196
// CHECK-NEXT: verif.assert %5 : i1
197197
// CHECK-NEXT: }
@@ -237,14 +237,14 @@ hw.module @TwoContracts(in %a: i42, out z: i42) {
237237
// CHECK-NEXT: %false = hw.constant false
238238
// CHECK-NEXT: %0 = verif.symbolic_value : i1
239239
// CHECK-NEXT: %1 = verif.symbolic_value : i42
240-
// CHECK-NEXT: %2 = comb.extract %1 from 0 : (i42) -> i41
241-
// CHECK-NEXT: %3 = comb.concat %2, %false : i41, i1
242-
// CHECK-NEXT: %4 = comb.mul %1, %1 : i42
240+
// CHECK-NEXT: %2 = comb.mul %1, %1 : i42
241+
// CHECK-NEXT: %3 = comb.extract %1 from 0 : (i42) -> i41
242+
// CHECK-NEXT: %4 = comb.concat %3, %false : i41, i1
243243
// CHECK-NEXT: %5 = scf.if %0 -> (i42) {
244-
// CHECK-NEXT: %8 = comb.add %3, %4 : i42
244+
// CHECK-NEXT: %8 = comb.add %4, %2 : i42
245245
// CHECK-NEXT: scf.yield %8 : i42
246246
// CHECK-NEXT: } else {
247-
// CHECK-NEXT: %8 = comb.mul %3, %4 : i42
247+
// CHECK-NEXT: %8 = comb.mul %4, %2 : i42
248248
// CHECK-NEXT: scf.yield %8 : i42
249249
// CHECK-NEXT: }
250250
// CHECK-NEXT: %6 = verif.symbolic_value : i42

0 commit comments

Comments
 (0)