Skip to content

Commit

Permalink
[HW][circt-synth] Implement AggregateToComb pass and add to circt-syinth
Browse files Browse the repository at this point in the history
pipeline
  • Loading branch information
uenoku committed Jan 15, 2025
1 parent 2d87d21 commit abb9d77
Show file tree
Hide file tree
Showing 12 changed files with 363 additions and 56 deletions.
9 changes: 9 additions & 0 deletions include/circt/Dialect/Comb/CombOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ Value createOrFoldNot(Location loc, Value value, OpBuilder &builder,
Value createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
bool twoState = false);

/// Extract bits from a value.
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl<Value> &bits);

/// Construct a mux tree for given leaf nodes. `selectors` is the selector for
/// each level of the tree. Currently the selector is tested from MSB to LSB.
Value constructMuxTree(OpBuilder &builder, Location loc,
ArrayRef<Value> selectors, ArrayRef<Value> leafNodes,
Value outOfBoundsValue);

} // namespace comb
} // namespace circt

Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/HW/HWPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ std::unique_ptr<mlir::Pass> createFlattenIOPass(bool recursiveFlag = true,
std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
std::unique_ptr<mlir::Pass> createFooWiresPass();
std::unique_ptr<mlir::Pass> createHWAggregateToCombPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
12 changes: 12 additions & 0 deletions include/circt/Dialect/HW/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,16 @@ def FooWires : Pass<"hw-foo-wires", "hw::HWModuleOp"> {
let constructor = "circt::hw::createFooWiresPass()";
}

def HWAggregateToComb : Pass<"hw-aggregate-to-comb", "hw::HWModuleOp"> {
let summary = "Lower aggregate operations to comb operations";
let constructor = "circt::hw::createHWAggregateToCombPass()";

let description = [{
This pass lowers aggregate *operations* to comb operations within modules.
This pass does not lower ports, as ports are handled by FlattenIO. This pass
will also change the behavior of out-of-bounds access of arrays.
}];
let dependentDialects = ["comb::CombDialect"];
}

#endif // CIRCT_DIALECT_HW_PASSES_TD
18 changes: 17 additions & 1 deletion integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// REQUIRES: libz3
// REQUIRES: circt-lec-jit

// RUN: circt-opt %s --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir
// RUN: circt-opt %s --hw-aggregate-to-comb --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir
// RUN: circt-lec %t.mlir %s -c1=bit_logical -c2=bit_logical --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_BIT_LOGICAL
// COMB_BIT_LOGICAL: c1 == c2
hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32,
Expand Down Expand Up @@ -78,3 +78,19 @@ hw.module @shift5(in %lhs: i5, in %rhs: i5, out out_shl: i5, out out_shr: i5, ou
%2 = comb.shrs %lhs, %rhs : i5
hw.output %0, %1, %2 : i5, i5, i5
}

// RUN: circt-lec %t.mlir %s -c1=array -c2=array --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ARRAY
// COMB_ARRAY: c1 == c2
hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, in %sel1: i2, in %sel2: i2, out out1: i2, out out2: i2) {
%0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2
%1 = hw.array_get %0[%sel1] : !hw.array<4xi2>, i2
%2 = hw.array_create %arg0, %arg1, %arg2 : i2
%c3_i2 = hw.constant 3 : i2
// NOTE: If the index is out of bounds, the result value is undefined.
// In LEC such value is lowered into unbounded SMT variable and cause
// the LEC to fail. So just asssume that the index is in bounds.
%inbound = comb.icmp ult %sel2, %c3_i2 : i2
verif.assume %inbound : i1
%3 = hw.array_get %2[%sel2] : !hw.array<3xi2>, i2
hw.output %1, %3 : i2, i2
}
71 changes: 16 additions & 55 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,64 +29,13 @@ using namespace comb;
// Utility Functions
//===----------------------------------------------------------------------===//

// Extract individual bits from a value
static SmallVector<Value> extractBits(ConversionPatternRewriter &rewriter,
Value val) {
assert(val.getType().isInteger() && "expected integer");
auto width = val.getType().getIntOrFloatBitWidth();
// A wrapper for comb::extractBits that returns a SmallVector<Value>.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
SmallVector<Value> bits;
bits.reserve(width);

// Check if we can reuse concat operands
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
if (concat.getNumOperands() == width &&
llvm::all_of(concat.getOperandTypes(), [](Type type) {
return type.getIntOrFloatBitWidth() == 1;
})) {
// Reverse the operands to match the bit order
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
std::make_reverse_iterator(concat.getOperands().begin()));
return bits;
}
}

// Extract individual bits
for (int64_t i = 0; i < width; ++i)
bits.push_back(
rewriter.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));

comb::extractBits(builder, val, bits);
return bits;
}

// Construct a mux tree for given leaf nodes. `selectors` is the selector for
// each level of the tree. Currently the selector is tested from MSB to LSB.
static Value constructMuxTree(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> selectors,
ArrayRef<Value> leafNodes,
Value outOfBoundsValue) {
// Recursive helper function to construct the mux tree
std::function<Value(size_t, size_t)> constructTreeHelper =
[&](size_t id, size_t level) -> Value {
// Base case: at the lowest level, return the result
if (level == 0) {
// Return the result for the given index. If the index is out of bounds,
// return the out-of-bound value.
return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
}

auto selector = selectors[level - 1];

// Recursive case: create muxes for true and false branches
auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
auto falseVal = constructTreeHelper(2 * id, level - 1);

// Combine the results with a mux
return rewriter.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
};

return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
}

// Construct a mux tree for shift operations. `isLeftShift` controls the
// direction of the shift operation and is used to determine order of the
// padding and extracted bits. Callbacks `getPadding` and `getExtract` are used
Expand Down Expand Up @@ -128,7 +77,8 @@ static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
assert(outOfBoundsValue && "outOfBoundsValue must be valid");

// Construct mux tree for shift operation
auto result = constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
auto result =
comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);

// Add bounds checking
auto inBound = rewriter.createOrFold<comb::ICmpOp>(
Expand Down Expand Up @@ -667,10 +617,21 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {

void ConvertCombToAIGPass::runOnOperation() {
ConversionTarget target(getContext());

// Comb is source dialect.
target.addIllegalDialect<comb::CombDialect>();
// Keep data movement operations like Extract, Concat and Replicate.
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
hw::BitcastOp, hw::ConstantOp>();

// Treat array operations as illegal. Strictly speaking, other than array get
// operation with non-const index are legal in AIG but array types prevent a
// bunch of optimizations so just lower them to integer operations. It's
// required to run HWAggregateToComb pass before this pass.
target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
hw::AggregateConstantOp>();

// AIG is target dialect.
target.addLegalDialect<aig::AIGDialect>();

// This is a test only option to add logical ops.
Expand Down
55 changes: 55 additions & 0 deletions lib/Dialect/Comb/CombOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,61 @@ Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
return createOrFoldNot(builder.getLoc(), value, builder, twoState);
}

// Extract individual bits from a value
void comb::extractBits(OpBuilder &builder, Value val,
SmallVectorImpl<Value> &bits) {
assert(val.getType().isInteger() && "expected integer");
auto width = val.getType().getIntOrFloatBitWidth();
bits.reserve(width);

// Check if we can reuse concat operands
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
if (concat.getNumOperands() == width &&
llvm::all_of(concat.getOperandTypes(), [](Type type) {
return type.getIntOrFloatBitWidth() == 1;
})) {
// Reverse the operands to match the bit order
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
std::make_reverse_iterator(concat.getOperands().begin()));
return;
}
}

// Extract individual bits
for (int64_t i = 0; i < width; ++i)
bits.push_back(
builder.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));
}

// Construct a mux tree for given leaf nodes. `selectors` is the selector for
// each level of the tree. Currently the selector is tested from MSB to LSB.
Value comb::constructMuxTree(OpBuilder &builder, Location loc,
ArrayRef<Value> selectors,
ArrayRef<Value> leafNodes,
Value outOfBoundsValue) {
// Recursive helper function to construct the mux tree
std::function<Value(size_t, size_t)> constructTreeHelper =
[&](size_t id, size_t level) -> Value {
// Base case: at the lowest level, return the result
if (level == 0) {
// Return the result for the given index. If the index is out of bounds,
// return the out-of-bound value.
return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
}

auto selector = selectors[level - 1];

// Recursive case: create muxes for true and false branches
auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
auto falseVal = constructTreeHelper(2 * id, level - 1);

// Combine the results with a mux
return builder.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
};

return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
}

//===----------------------------------------------------------------------===//
// ICmpOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/HW/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_circt_dialect_library(CIRCTHWTransforms
HWAggregateToComb.cpp
HWPrintInstanceGraph.cpp
HWSpecialize.cpp
PrintHWModuleGraph.cpp
Expand Down
Loading

0 comments on commit abb9d77

Please sign in to comment.