Skip to content

Commit

Permalink
[MooreToCore] Properly handle OOB accesses of moore.extract (#8182)
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart authored Feb 3, 2025
1 parent 8880e69 commit 2b0c30a
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 9 deletions.
104 changes: 95 additions & 9 deletions lib/Conversion/MooreToCore/MooreToCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,27 +635,113 @@ struct ExtractOpConversion : public OpConversionPattern<ExtractOp> {
LogicalResult
matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: properly handle out-of-bounds accesses
// TODO: return X if the domain is four-valued for out-of-bounds accesses
// once we support four-valued lowering
Type resultType = typeConverter->convertType(op.getResult().getType());
Type inputType = adaptor.getInput().getType();
int32_t low = adaptor.getLowBit();

if (isa<IntegerType>(inputType)) {
rewriter.replaceOpWithNewOp<comb::ExtractOp>(
op, resultType, adaptor.getInput(), adaptor.getLowBit());
int32_t inputWidth = inputType.getIntOrFloatBitWidth();
int32_t resultWidth = resultType.getIntOrFloatBitWidth();
int32_t high = low + resultWidth;

SmallVector<Value> toConcat;
if (low < 0)
toConcat.push_back(rewriter.create<hw::ConstantOp>(
op.getLoc(), APInt(std::min(-low, resultWidth), 0)));

if (low < inputWidth && high > 0) {
int32_t lowIdx = std::max(low, 0);
Value middle = rewriter.createOrFold<comb::ExtractOp>(
op.getLoc(),
rewriter.getIntegerType(
std::min(resultWidth, std::min(high, inputWidth) - lowIdx)),
adaptor.getInput(), lowIdx);
toConcat.push_back(middle);
}

int32_t diff = high - inputWidth;
if (diff > 0) {
Value val =
rewriter.create<hw::ConstantOp>(op.getLoc(), APInt(diff, 0));
toConcat.push_back(val);
}

Value concat =
rewriter.createOrFold<comb::ConcatOp>(op.getLoc(), toConcat);
rewriter.replaceOp(op, concat);
return success();
}

if (auto arrTy = dyn_cast<hw::ArrayType>(inputType)) {
int64_t width = llvm::Log2_64_Ceil(arrTy.getNumElements());
Value idx = rewriter.create<hw::ConstantOp>(
op.getLoc(), rewriter.getIntegerType(width), adaptor.getLowBit());
if (isa<hw::ArrayType>(resultType)) {
rewriter.replaceOpWithNewOp<hw::ArraySliceOp>(op, resultType,
adaptor.getInput(), idx);
int32_t width = llvm::Log2_64_Ceil(arrTy.getNumElements());
int32_t inputWidth = arrTy.getNumElements();

if (auto resArrTy = dyn_cast<hw::ArrayType>(resultType)) {
int32_t elementWidth = hw::getBitWidth(arrTy.getElementType());
if (elementWidth < 0)
return failure();

int32_t high = low + resArrTy.getNumElements();
int32_t resWidth = resArrTy.getNumElements();

SmallVector<Value> toConcat;
if (low < 0) {
Value val = rewriter.create<hw::ConstantOp>(
op.getLoc(),
APInt(std::min((-low) * elementWidth, resWidth * elementWidth),
0));
Value res = rewriter.createOrFold<hw::BitcastOp>(
op.getLoc(), hw::ArrayType::get(arrTy.getElementType(), -low),
val);
toConcat.push_back(res);
}

if (low < inputWidth && high > 0) {
int32_t lowIdx = std::max(0, low);
Value lowIdxVal = rewriter.create<hw::ConstantOp>(
op.getLoc(), rewriter.getIntegerType(width), lowIdx);
Value middle = rewriter.createOrFold<hw::ArraySliceOp>(
op.getLoc(),
hw::ArrayType::get(
arrTy.getElementType(),
std::min(resWidth, std::min(inputWidth, high) - lowIdx)),
adaptor.getInput(), lowIdxVal);
toConcat.push_back(middle);
}

int32_t diff = high - inputWidth;
if (diff > 0) {
Value constZero = rewriter.create<hw::ConstantOp>(
op.getLoc(), APInt(diff * elementWidth, 0));
Value val = rewriter.create<hw::BitcastOp>(
op.getLoc(), hw::ArrayType::get(arrTy.getElementType(), diff),
constZero);
toConcat.push_back(val);
}

Value concat =
rewriter.createOrFold<hw::ArrayConcatOp>(op.getLoc(), toConcat);
rewriter.replaceOp(op, concat);
return success();
}

// Otherwise, it has to be the array's element type
if (low < 0 || low >= inputWidth) {
int32_t bw = hw::getBitWidth(resultType);
if (bw < 0)
return failure();

Value val = rewriter.create<hw::ConstantOp>(op.getLoc(), APInt(bw, 0));
Value bitcast =
rewriter.createOrFold<hw::BitcastOp>(op.getLoc(), resultType, val);
rewriter.replaceOp(op, bitcast);
return success();
}

Value idx = rewriter.create<hw::ConstantOp>(
op.getLoc(), rewriter.getIntegerType(width), adaptor.getLowBit());
rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(op, adaptor.getInput(), idx);
return success();
}
Expand Down
56 changes: 56 additions & 0 deletions test/Conversion/MooreToCore/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,62 @@ func.func @Expressions(%arg0: !moore.i1, %arg1: !moore.l1, %arg2: !moore.i6, %ar
// CHECK-NEXT: hw.array_get %arg5[[[V0]]] : !hw.array<5xi32>
moore.extract %arg5 from 2 : !moore.array<5 x i32> -> i32

// CHECK-NEXT: [[C0:%.+]] = hw.constant 0 : i2
// CHECK-NEXT: [[C1:%.+]] = hw.constant 0 : i2
// CHECK-NEXT: comb.concat [[C0]], %arg2, [[C1]] : i2, i6, i2
moore.extract %arg2 from -2 : !moore.i6 -> !moore.i10

// CHECK-NEXT: [[V0:%.+]] = comb.extract %arg2 from 4 : (i6) -> i2
// CHECK-NEXT: [[C0:%.+]] = hw.constant 0 : i2
// CHECK-NEXT: comb.concat [[V0]], [[C0]] : i2, i2
moore.extract %arg2 from 4 : !moore.i6 -> !moore.i4

// CHECK-NEXT: [[C0:%.+]] = hw.constant 0 : i2
// CHECK-NEXT: [[V0:%.+]] = comb.extract %arg2 from 0 : (i6) -> i2
// CHECK-NEXT: comb.concat [[C0]], [[V0]] : i2, i2
moore.extract %arg2 from -2 : !moore.i6 -> !moore.i4

// CHECK-NEXT: hw.constant 0 : i4
moore.extract %arg2 from -6 : !moore.i6 -> !moore.i4

// CHECK-NEXT: hw.constant 0 : i4
moore.extract %arg2 from 6 : !moore.i6 -> !moore.i4

// CHECK-NEXT: [[C0:%.+]] = hw.constant 0 : i64
// CHECK-NEXT: [[V0:%..+]] = hw.bitcast [[C0]] : (i64) -> !hw.array<2xi32>
// CHECK-NEXT: hw.constant 0 : i3
// CHECK-NEXT: [[C1:%.+]] = hw.constant 0 : i64
// CHECK-NEXT: [[V1:%.+]] = hw.bitcast [[C1]] : (i64) -> !hw.array<2xi32>
// CHECK-NEXT: hw.array_concat [[V0]], %arg5, [[V1]] : !hw.array<2xi32>, !hw.array<5xi32>, !hw.array<2xi32>
moore.extract %arg5 from -2 : !moore.array<5 x i32> -> !moore.array<9 x i32>

// CHECK-NEXT: [[IDX:%.+]] = hw.constant 2 : i3
// CHECK-NEXT: [[V0:%.+]] = hw.array_slice %arg5[[[IDX]]] : (!hw.array<5xi32>) -> !hw.array<3xi32>
// CHECK-NEXT: [[C0:%.+]] = hw.constant 0 : i32
// CHECK-NEXT: [[V1:%.+]] = hw.bitcast [[C0]] : (i32) -> !hw.array<1xi32>
// CHECK-NEXT: hw.array_concat [[V0]], [[V1]] : !hw.array<3xi32>, !hw.array<1xi32>
moore.extract %arg5 from 2 : !moore.array<5 x i32> -> !moore.array<4 x i32>

// CHECK-NEXT: [[C0:%.+]] = hw.constant 0 : i32
// CHECK-NEXT: [[V0:%.+]] = hw.bitcast [[C0]] : (i32) -> !hw.array<1xi32>
// CHECK-NEXT: [[IDX:%.+]] = hw.constant 0 : i3
// CHECK-NEXT: [[V1:%.+]] = hw.array_slice %arg5[[[IDX]]] : (!hw.array<5xi32>) -> !hw.array<1xi32>
// CHECK-NEXT: hw.array_concat [[V0]], [[V1]] : !hw.array<1xi32>, !hw.array<1xi32>
moore.extract %arg5 from -1 : !moore.array<5 x i32> -> !moore.array<2 x i32>

// CHECK-NEXT: [[C0:%.+]] = hw.constant 0 : i64
// CHECK-NEXT: hw.bitcast [[C0]] : (i64) -> !hw.array<2xi32>
moore.extract %arg5 from -2 : !moore.array<5 x i32> -> !moore.array<2 x i32>

// CHECK-NEXT: [[C0:%.+]] = hw.constant 0 : i64
// CHECK-NEXT: hw.bitcast [[C0]] : (i64) -> !hw.array<2xi32>
moore.extract %arg5 from 5 : !moore.array<5 x i32> -> !moore.array<2 x i32>

// CHECK-NEXT: hw.constant 0 : i32
moore.extract %arg5 from -2 : !moore.array<5 x i32> -> i32
// CHECK-NEXT: hw.constant 0 : i32
moore.extract %arg5 from 6 : !moore.array<5 x i32> -> i32

// CHECK-NEXT: [[V0:%.+]] = hw.constant 0 : i0
// CHECK-NEXT: llhd.sig.extract %arg6 from [[V0]] : (!hw.inout<i1>) -> !hw.inout<i1>
moore.extract_ref %arg6 from 0 : !moore.ref<!moore.i1> -> !moore.ref<!moore.i1>
Expand Down

0 comments on commit 2b0c30a

Please sign in to comment.