Skip to content

Commit

Permalink
[MemoryBanking] Add a new field to keep track of memory reference val…
Browse files Browse the repository at this point in the history
…ues that need to be cleaned up after memory banking is complete (#8039)
  • Loading branch information
jiahanxie353 authored Jan 9, 2025
1 parent b650d5e commit 1c5826f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 10 deletions.
29 changes: 19 additions & 10 deletions lib/Transforms/MemoryBanking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ struct MemoryBankingPass
// map from original memory definition to newly allocated banks
DenseMap<Value, SmallVector<Value>> memoryToBanks;
DenseSet<Operation *> opsToErase;
// Track memory references that need to be cleaned up after memory banking is
// complete.
DenseSet<Value> oldMemRefVals;
};
} // namespace

Expand Down Expand Up @@ -134,10 +137,11 @@ struct BankAffineLoadPattern
: public OpRewritePattern<mlir::affine::AffineLoadOp> {
BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor,
unsigned bankingDimension,
DenseMap<Value, SmallVector<Value>> &memoryToBanks)
DenseMap<Value, SmallVector<Value>> &memoryToBanks,
DenseSet<Value> &oldMemRefVals)
: OpRewritePattern<mlir::affine::AffineLoadOp>(context),
bankingFactor(bankingFactor), bankingDimension(bankingDimension),
memoryToBanks(memoryToBanks) {}
memoryToBanks(memoryToBanks), oldMemRefVals(oldMemRefVals) {}

LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -187,6 +191,10 @@ struct BankAffineLoadPattern
auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());

// We track Load's memory reference only if it is a block argument - this is
// the only case where the reference isn't replaced.
if (Value memRef = loadOp.getMemref(); isa<BlockArgument>(memRef))
oldMemRefVals.insert(memRef);
rewriter.replaceOp(loadOp, switchOp.getResult(0));

return success();
Expand All @@ -196,6 +204,7 @@ struct BankAffineLoadPattern
uint64_t bankingFactor;
unsigned bankingDimension;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
DenseSet<Value> &oldMemRefVals;
};

// Replace the original store operations with newly created memory banks
Expand All @@ -205,11 +214,12 @@ struct BankAffineStorePattern
unsigned bankingDimension,
DenseMap<Value, SmallVector<Value>> &memoryToBanks,
DenseSet<Operation *> &opsToErase,
DenseSet<Operation *> &processedOps)
DenseSet<Operation *> &processedOps,
DenseSet<Value> &oldMemRefVals)
: OpRewritePattern<mlir::affine::AffineStoreOp>(context),
bankingFactor(bankingFactor), bankingDimension(bankingDimension),
memoryToBanks(memoryToBanks), opsToErase(opsToErase),
processedOps(processedOps) {}
processedOps(processedOps), oldMemRefVals(oldMemRefVals) {}

LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -262,6 +272,7 @@ struct BankAffineStorePattern

processedOps.insert(storeOp);
opsToErase.insert(storeOp);
oldMemRefVals.insert(storeOp.getMemref());

return success();
}
Expand All @@ -272,6 +283,7 @@ struct BankAffineStorePattern
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
DenseSet<Operation *> &opsToErase;
DenseSet<Operation *> &processedOps;
DenseSet<Value> &oldMemRefVals;
};

// Replace the original return operation with newly created memory banks
Expand Down Expand Up @@ -388,9 +400,10 @@ void MemoryBankingPass::runOnOperation() {

DenseSet<Operation *> processedOps;
patterns.add<BankAffineLoadPattern>(ctx, bankingFactor, bankingDimension,
memoryToBanks);
memoryToBanks, oldMemRefVals);
patterns.add<BankAffineStorePattern>(ctx, bankingFactor, bankingDimension,
memoryToBanks, opsToErase, processedOps);
memoryToBanks, opsToErase, processedOps,
oldMemRefVals);
patterns.add<BankReturnPattern>(ctx, memoryToBanks);

GreedyRewriteConfig config;
Expand All @@ -401,10 +414,6 @@ void MemoryBankingPass::runOnOperation() {
}

// Clean up the old memref values
DenseSet<Value> oldMemRefVals;
for (const auto &[memory, _] : memoryToBanks)
oldMemRefVals.insert(memory);

if (failed(cleanUpOldMemRefs(oldMemRefVals, opsToErase))) {
signalPassFailure();
}
Expand Down
76 changes: 76 additions & 0 deletions test/Transforms/memory_banking.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix UNROLL-BY-2
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=1" | FileCheck %s --check-prefix UNROLL-BY-1
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=8" | FileCheck %s --check-prefix UNROLL-BY-8
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix ALLOC-UNROLL-2

// -----

Expand Down Expand Up @@ -259,3 +260,78 @@ func.func @bank_one_dim_unroll8(%arg0: memref<8xf32>, %arg1: memref<8xf32>) -> (
}
return %mem : memref<8xf32>
}

// -----

// ALLOC-UNROLL-2: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 mod 2)>
// ALLOC-UNROLL-2: #[[$ATTR_1:.+]] = affine_map<(d0) -> (d0 floordiv 2)>


// ALLOC-UNROLL-2-LABEL: func.func @alloc_unroll2() -> (memref<4xf32>, memref<4xf32>) {
// ALLOC-UNROLL-2: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32
// ALLOC-UNROLL-2: %[[VAL_1:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_2:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_3:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_4:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_5:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_6:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: affine.parallel (%[[VAL_7:.*]]) = (0) to (8) {
// ALLOC-UNROLL-2: %[[VAL_8:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_9:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_10:.*]] = scf.index_switch %[[VAL_8]] -> f32
// ALLOC-UNROLL-2: case 0 {
// ALLOC-UNROLL-2: %[[VAL_11:.*]] = affine.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield %[[VAL_11]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: case 1 {
// ALLOC-UNROLL-2: %[[VAL_12:.*]] = affine.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield %[[VAL_12]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: default {
// ALLOC-UNROLL-2: scf.yield %[[VAL_0]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: %[[VAL_13:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_14:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_15:.*]] = scf.index_switch %[[VAL_13]] -> f32
// ALLOC-UNROLL-2: case 0 {
// ALLOC-UNROLL-2: %[[VAL_16:.*]] = affine.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield %[[VAL_16]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: case 1 {
// ALLOC-UNROLL-2: %[[VAL_17:.*]] = affine.load %[[VAL_4]]{{\[}}%[[VAL_14]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield %[[VAL_17]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: default {
// ALLOC-UNROLL-2: scf.yield %[[VAL_0]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: %[[VAL_18:.*]] = arith.mulf %[[VAL_10]], %[[VAL_15]] : f32
// ALLOC-UNROLL-2: %[[VAL_19:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_20:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]])
// ALLOC-UNROLL-2: scf.index_switch %[[VAL_19]]
// ALLOC-UNROLL-2: case 0 {
// ALLOC-UNROLL-2: affine.store %[[VAL_18]], %[[VAL_5]]{{\[}}%[[VAL_20]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: case 1 {
// ALLOC-UNROLL-2: affine.store %[[VAL_18]], %[[VAL_6]]{{\[}}%[[VAL_20]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: default {
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: return %[[VAL_5]], %[[VAL_6]] : memref<4xf32>, memref<4xf32>
// ALLOC-UNROLL-2: }

func.func @alloc_unroll2() -> (memref<8xf32>) {
%arg0 = memref.alloc() : memref<8xf32>
%arg1 = memref.alloc() : memref<8xf32>
%mem = memref.alloc() : memref<8xf32>
affine.parallel (%i) = (0) to (8) {
%1 = affine.load %arg0[%i] : memref<8xf32>
%2 = affine.load %arg1[%i] : memref<8xf32>
%3 = arith.mulf %1, %2 : f32
affine.store %3, %mem[%i] : memref<8xf32>
}
return %mem : memref<8xf32>
}

0 comments on commit 1c5826f

Please sign in to comment.