From 1c5826fcd3486e12f5d36d6a443cdf318f8f70b3 Mon Sep 17 00:00:00 2001 From: Jiahan Xie <88367305+jiahanxie353@users.noreply.github.com> Date: Wed, 8 Jan 2025 22:05:17 -0500 Subject: [PATCH] [MemoryBanking] Add a new field to keep track of memory reference values that need to be cleaned up after memory banking is complete (#8039) --- lib/Transforms/MemoryBanking.cpp | 29 +++++++---- test/Transforms/memory_banking.mlir | 76 +++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 10 deletions(-) diff --git a/lib/Transforms/MemoryBanking.cpp b/lib/Transforms/MemoryBanking.cpp index c4e974d0a64b..ccca52dcbde4 100644 --- a/lib/Transforms/MemoryBanking.cpp +++ b/lib/Transforms/MemoryBanking.cpp @@ -50,6 +50,9 @@ struct MemoryBankingPass // map from original memory definition to newly allocated banks DenseMap> memoryToBanks; DenseSet opsToErase; + // Track memory references that need to be cleaned up after memory banking is + // complete. + DenseSet oldMemRefVals; }; } // namespace @@ -134,10 +137,11 @@ struct BankAffineLoadPattern : public OpRewritePattern { BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor, unsigned bankingDimension, - DenseMap> &memoryToBanks) + DenseMap> &memoryToBanks, + DenseSet &oldMemRefVals) : OpRewritePattern(context), bankingFactor(bankingFactor), bankingDimension(bankingDimension), - memoryToBanks(memoryToBanks) {} + memoryToBanks(memoryToBanks), oldMemRefVals(oldMemRefVals) {} LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp, PatternRewriter &rewriter) const override { @@ -187,6 +191,10 @@ struct BankAffineLoadPattern auto defaultValue = rewriter.create(loc, zeroAttr); rewriter.create(loc, defaultValue.getResult()); + // We track Load's memory reference only if it is a block argument - this is + // the only case where the reference isn't replaced. + if (Value memRef = loadOp.getMemref(); isa(memRef)) + oldMemRefVals.insert(memRef); rewriter.replaceOp(loadOp, switchOp.getResult(0)); return success(); @@ -196,6 +204,7 @@ struct BankAffineLoadPattern uint64_t bankingFactor; unsigned bankingDimension; DenseMap> &memoryToBanks; + DenseSet &oldMemRefVals; }; // Replace the original store operations with newly created memory banks @@ -205,11 +214,12 @@ struct BankAffineStorePattern unsigned bankingDimension, DenseMap> &memoryToBanks, DenseSet &opsToErase, - DenseSet &processedOps) + DenseSet &processedOps, + DenseSet &oldMemRefVals) : OpRewritePattern(context), bankingFactor(bankingFactor), bankingDimension(bankingDimension), memoryToBanks(memoryToBanks), opsToErase(opsToErase), - processedOps(processedOps) {} + processedOps(processedOps), oldMemRefVals(oldMemRefVals) {} LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp, PatternRewriter &rewriter) const override { @@ -262,6 +272,7 @@ struct BankAffineStorePattern processedOps.insert(storeOp); opsToErase.insert(storeOp); + oldMemRefVals.insert(storeOp.getMemref()); return success(); } @@ -272,6 +283,7 @@ struct BankAffineStorePattern DenseMap> &memoryToBanks; DenseSet &opsToErase; DenseSet &processedOps; + DenseSet &oldMemRefVals; }; // Replace the original return operation with newly created memory banks @@ -388,9 +400,10 @@ void MemoryBankingPass::runOnOperation() { DenseSet processedOps; patterns.add(ctx, bankingFactor, bankingDimension, - memoryToBanks); + memoryToBanks, oldMemRefVals); patterns.add(ctx, bankingFactor, bankingDimension, - memoryToBanks, opsToErase, processedOps); + memoryToBanks, opsToErase, processedOps, + oldMemRefVals); patterns.add(ctx, memoryToBanks); GreedyRewriteConfig config; @@ -401,10 +414,6 @@ void MemoryBankingPass::runOnOperation() { } // Clean up the old memref values - DenseSet oldMemRefVals; - for (const auto &[memory, _] : memoryToBanks) - oldMemRefVals.insert(memory); - if (failed(cleanUpOldMemRefs(oldMemRefVals, opsToErase))) { signalPassFailure(); } diff --git a/test/Transforms/memory_banking.mlir b/test/Transforms/memory_banking.mlir index b859ccbe1770..be015e2b4973 100644 --- a/test/Transforms/memory_banking.mlir +++ b/test/Transforms/memory_banking.mlir @@ -1,6 +1,7 @@ // RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix UNROLL-BY-2 // RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=1" | FileCheck %s --check-prefix UNROLL-BY-1 // RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=8" | FileCheck %s --check-prefix UNROLL-BY-8 +// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix ALLOC-UNROLL-2 // ----- @@ -259,3 +260,78 @@ func.func @bank_one_dim_unroll8(%arg0: memref<8xf32>, %arg1: memref<8xf32>) -> ( } return %mem : memref<8xf32> } + +// ----- + +// ALLOC-UNROLL-2: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 mod 2)> +// ALLOC-UNROLL-2: #[[$ATTR_1:.+]] = affine_map<(d0) -> (d0 floordiv 2)> + + +// ALLOC-UNROLL-2-LABEL: func.func @alloc_unroll2() -> (memref<4xf32>, memref<4xf32>) { +// ALLOC-UNROLL-2: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32 +// ALLOC-UNROLL-2: %[[VAL_1:.*]] = memref.alloc() : memref<4xf32> +// ALLOC-UNROLL-2: %[[VAL_2:.*]] = memref.alloc() : memref<4xf32> +// ALLOC-UNROLL-2: %[[VAL_3:.*]] = memref.alloc() : memref<4xf32> +// ALLOC-UNROLL-2: %[[VAL_4:.*]] = memref.alloc() : memref<4xf32> +// ALLOC-UNROLL-2: %[[VAL_5:.*]] = memref.alloc() : memref<4xf32> +// ALLOC-UNROLL-2: %[[VAL_6:.*]] = memref.alloc() : memref<4xf32> +// ALLOC-UNROLL-2: affine.parallel (%[[VAL_7:.*]]) = (0) to (8) { +// ALLOC-UNROLL-2: %[[VAL_8:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]]) +// ALLOC-UNROLL-2: %[[VAL_9:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]]) +// ALLOC-UNROLL-2: %[[VAL_10:.*]] = scf.index_switch %[[VAL_8]] -> f32 +// ALLOC-UNROLL-2: case 0 { +// ALLOC-UNROLL-2: %[[VAL_11:.*]] = affine.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<4xf32> +// ALLOC-UNROLL-2: scf.yield %[[VAL_11]] : f32 +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: case 1 { +// ALLOC-UNROLL-2: %[[VAL_12:.*]] = affine.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<4xf32> +// ALLOC-UNROLL-2: scf.yield %[[VAL_12]] : f32 +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: default { +// ALLOC-UNROLL-2: scf.yield %[[VAL_0]] : f32 +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: %[[VAL_13:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]]) +// ALLOC-UNROLL-2: %[[VAL_14:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]]) +// ALLOC-UNROLL-2: %[[VAL_15:.*]] = scf.index_switch %[[VAL_13]] -> f32 +// ALLOC-UNROLL-2: case 0 { +// ALLOC-UNROLL-2: %[[VAL_16:.*]] = affine.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<4xf32> +// ALLOC-UNROLL-2: scf.yield %[[VAL_16]] : f32 +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: case 1 { +// ALLOC-UNROLL-2: %[[VAL_17:.*]] = affine.load %[[VAL_4]]{{\[}}%[[VAL_14]]] : memref<4xf32> +// ALLOC-UNROLL-2: scf.yield %[[VAL_17]] : f32 +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: default { +// ALLOC-UNROLL-2: scf.yield %[[VAL_0]] : f32 +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: %[[VAL_18:.*]] = arith.mulf %[[VAL_10]], %[[VAL_15]] : f32 +// ALLOC-UNROLL-2: %[[VAL_19:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]]) +// ALLOC-UNROLL-2: %[[VAL_20:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]]) +// ALLOC-UNROLL-2: scf.index_switch %[[VAL_19]] +// ALLOC-UNROLL-2: case 0 { +// ALLOC-UNROLL-2: affine.store %[[VAL_18]], %[[VAL_5]]{{\[}}%[[VAL_20]]] : memref<4xf32> +// ALLOC-UNROLL-2: scf.yield +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: case 1 { +// ALLOC-UNROLL-2: affine.store %[[VAL_18]], %[[VAL_6]]{{\[}}%[[VAL_20]]] : memref<4xf32> +// ALLOC-UNROLL-2: scf.yield +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: default { +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: } +// ALLOC-UNROLL-2: return %[[VAL_5]], %[[VAL_6]] : memref<4xf32>, memref<4xf32> +// ALLOC-UNROLL-2: } + +func.func @alloc_unroll2() -> (memref<8xf32>) { + %arg0 = memref.alloc() : memref<8xf32> + %arg1 = memref.alloc() : memref<8xf32> + %mem = memref.alloc() : memref<8xf32> + affine.parallel (%i) = (0) to (8) { + %1 = affine.load %arg0[%i] : memref<8xf32> + %2 = affine.load %arg1[%i] : memref<8xf32> + %3 = arith.mulf %1, %2 : f32 + affine.store %3, %mem[%i] : memref<8xf32> + } + return %mem : memref<8xf32> +} +