From 1c70fcec53c628a184b38816868c6443efba0523 Mon Sep 17 00:00:00 2001 From: Jiahan Xie <88367305+jiahanxie353@users.noreply.github.com> Date: Mon, 27 Jan 2025 16:48:24 -0500 Subject: [PATCH] [MemoryBanking] Operation-granularity banking configurations by attaching attributes (#8133) --- lib/Transforms/MemoryBanking.cpp | 194 +++++++++++++++++++--- test/Transforms/memory_banking_attrs.mlir | 123 ++++++++++++++ 2 files changed, 295 insertions(+), 22 deletions(-) create mode 100644 test/Transforms/memory_banking_attrs.mlir diff --git a/lib/Transforms/MemoryBanking.cpp b/lib/Transforms/MemoryBanking.cpp index 39a3a41dcbd6..eea1cf0c1e70 100644 --- a/lib/Transforms/MemoryBanking.cpp +++ b/lib/Transforms/MemoryBanking.cpp @@ -71,16 +71,22 @@ DenseSet collectMemRefs(mlir::affine::AffineParallelOp parOp) { return memrefVals; } -MemRefType computeBankedMemRefType(MemRefType originalType, - uint64_t bankingFactor, - unsigned bankingDimension) { +// Verify the banking configuration with different conditions. +void verifyBankingConfigurations(unsigned bankingDimension, + unsigned bankingFactor, + MemRefType originalType) { ArrayRef originalShape = originalType.getShape(); assert(!originalShape.empty() && "memref shape should not be empty"); - assert(bankingDimension < originalType.getRank() && "dimension must be within the memref rank"); assert(originalShape[bankingDimension] % bankingFactor == 0 && "memref shape must be evenly divided by the banking factor"); +} + +MemRefType computeBankedMemRefType(MemRefType originalType, + uint64_t bankingFactor, + unsigned bankingDimension) { + ArrayRef originalShape = originalType.getShape(); SmallVector newShape(originalShape.begin(), originalShape.end()); newShape[bankingDimension] /= bankingFactor; MemRefType newMemRefType = @@ -200,8 +206,9 @@ SmallVector handleGetGlobalOp(memref::GetGlobalOp getGlobalOp, return banks; } -unsigned getBankingDimension(std::optional bankingDimensionOpt, - int64_t rank, ArrayRef shape) { +unsigned getSpecifiedOrDefaultBankingDim(std::optional bankingDimensionOpt, + int64_t rank, + ArrayRef shape) { // If the banking dimension is already specified, return it. // Note, the banking dimension will always be nonempty because TableGen will // assign it with a default value -1 if it's not specified by the user. Thus, @@ -226,13 +233,114 @@ unsigned getBankingDimension(std::optional bankingDimensionOpt, return static_cast(bankingDimension); } -SmallVector createBanks(Value originalMem, uint64_t bankingFactor, +// Retrieve potentially specified banking factor/dimension attributes and +// overwrite the command line or the default ones. +void resolveBankingAttributes(Value originalMem, unsigned &bankingFactor, + unsigned &bankingDimension) { + if (auto *originalDef = originalMem.getDefiningOp()) { + if (auto attrFactor = dyn_cast_if_present( + originalDef->getAttr("banking.factor"))) + bankingFactor = attrFactor.getInt(); + if (auto attrDimension = dyn_cast_if_present( + originalDef->getAttr("banking.dimension"))) + bankingDimension = attrDimension.getInt(); + + return; + } + + if (isa(originalMem)) { + auto blockArg = cast(originalMem); + auto *parentOp = blockArg.getOwner()->getParentOp(); + + auto funcOp = dyn_cast(parentOp); + assert(funcOp && + "Expected the original memory to be a FuncOp block argument!"); + + unsigned argIndex = blockArg.getArgNumber(); + if (auto argAttrs = funcOp.getArgAttrDict(argIndex)) { + if (auto attrFactor = + dyn_cast_if_present(argAttrs.get("banking.factor"))) + bankingFactor = attrFactor.getInt(); + if (auto attrDimension = dyn_cast_if_present( + argAttrs.get("banking.dimension"))) + bankingDimension = attrDimension.getInt(); + } + + return; + } +} + +// Update the argument types of `funcOp` by inserting `numInsertedArgs` number +// of `newMemRefType` after `argIndex`. +void updateFuncOpArgumentTypes(func::FuncOp funcOp, unsigned argIndex, + MemRefType newMemRefType, + unsigned numInsertedArgs) { + auto originalArgTypes = funcOp.getFunctionType().getInputs(); + SmallVector updatedArgTypes; + + // Rebuild the argument types, inserting new types for the newly added + // arguments + for (unsigned i = 0; i < originalArgTypes.size(); ++i) { + updatedArgTypes.push_back(originalArgTypes[i]); + + // Insert new argument types after the specified argument index + if (i == argIndex) { + for (unsigned j = 0; j < numInsertedArgs; ++j) { + updatedArgTypes.push_back(newMemRefType); + } + } + } + + // Update the function type with the new argument types + auto resultTypes = funcOp.getFunctionType().getResults(); + auto newFuncType = + FunctionType::get(funcOp.getContext(), updatedArgTypes, resultTypes); + funcOp.setType(newFuncType); +} + +// Update `funcOp`'s "arg_attrs" by inserting `numInsertedArgs` number of empty +// DictionaryAttr after `argIndex`. +void updateFuncOpArgAttrs(func::FuncOp funcOp, unsigned argIndex, + unsigned numInsertedArgs) { + ArrayAttr existingArgAttrs = funcOp->getAttrOfType("arg_attrs"); + SmallVector updatedArgAttrs; + unsigned numArguments = funcOp.getNumArguments(); + unsigned newNumArguments = numArguments + numInsertedArgs; + updatedArgAttrs.resize(newNumArguments); + + // Copy existing attributes, adjusting for the new arguments + for (unsigned i = 0; i < numArguments; ++i) { + // Shift attributes for arguments after the inserted ones. + unsigned newIndex = (i > argIndex) ? i + numInsertedArgs : i; + updatedArgAttrs[newIndex] = existingArgAttrs + ? existingArgAttrs[i] + : DictionaryAttr::get(funcOp.getContext()); + } + + // Initialize new attributes for the inserted arguments as empty dictionaries + for (unsigned i = 0; i < numInsertedArgs; ++i) { + updatedArgAttrs[argIndex + 1 + i] = + DictionaryAttr::get(funcOp.getContext()); + } + + // Set the updated attributes. + funcOp->setAttr("arg_attrs", + ArrayAttr::get(funcOp.getContext(), updatedArgAttrs)); +} + +SmallVector createBanks(Value originalMem, unsigned bankingFactor, std::optional bankingDimensionOpt) { MemRefType originalMemRefType = cast(originalMem.getType()); unsigned rank = originalMemRefType.getRank(); ArrayRef shape = originalMemRefType.getShape(); - auto bankingDimension = getBankingDimension(bankingDimensionOpt, rank, shape); + unsigned bankingDimension = + getSpecifiedOrDefaultBankingDim(bankingDimensionOpt, rank, shape); + + resolveBankingAttributes(originalMem, bankingFactor, bankingDimension); + + verifyBankingConfigurations(bankingDimension, bankingFactor, + originalMemRefType); MemRefType newMemRefType = computeBankedMemRefType( originalMemRefType, bankingFactor, bankingDimension); @@ -248,6 +356,16 @@ SmallVector createBanks(Value originalMem, uint64_t bankingFactor, auto blockArgs = block->getArguments().slice(blockArgNum + 1, bankingFactor); banks.append(blockArgs.begin(), blockArgs.end()); + + auto *parentOp = block->getParentOp(); + auto funcOp = dyn_cast(parentOp); + assert(funcOp && "BlockArgument is not part of a FuncOp"); + // Update the ArgumentTypes of `funcOp` so that we can correctly get + // `getArgAttrDict` when resolving banking attributes across the iterations + // of creating new banks. + updateFuncOpArgumentTypes(funcOp, blockArgNum, newMemRefType, + bankingFactor); + updateFuncOpArgAttrs(funcOp, blockArgNum, bankingFactor); } else { Operation *originalDef = originalMem.getDefiningOp(); Location loc = originalDef->getLoc(); @@ -295,13 +413,20 @@ struct BankAffineLoadPattern LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp, PatternRewriter &rewriter) const override { Location loc = loadOp.getLoc(); - auto banks = memoryToBanks[loadOp.getMemref()]; + auto originalMem = loadOp.getMemref(); + auto banks = memoryToBanks[originalMem]; auto loadIndices = loadOp.getIndices(); - int64_t memrefRank = loadOp.getMemRefType().getRank(); - ArrayRef shape = loadOp.getMemRefType().getShape(); + MemRefType originalMemRefType = loadOp.getMemRefType(); + int64_t memrefRank = originalMemRefType.getRank(); + ArrayRef shape = originalMemRefType.getShape(); auto bankingDimension = - getBankingDimension(bankingDimensionOpt, memrefRank, shape); + getSpecifiedOrDefaultBankingDim(bankingDimensionOpt, memrefRank, shape); + + resolveBankingAttributes(originalMem, bankingFactor, bankingDimension); + + verifyBankingConfigurations(bankingDimension, bankingFactor, + originalMemRefType); auto modMap = AffineMap::get( /*dimCount=*/memrefRank, /*symbolCount=*/0, @@ -355,8 +480,8 @@ struct BankAffineLoadPattern } private: - uint64_t bankingFactor; - std::optional bankingDimensionOpt; + mutable unsigned bankingFactor; + mutable std::optional bankingDimensionOpt; DenseMap> &memoryToBanks; DenseSet &oldMemRefVals; }; @@ -381,13 +506,20 @@ struct BankAffineStorePattern return failure(); } Location loc = storeOp.getLoc(); - auto banks = memoryToBanks[storeOp.getMemref()]; + auto originalMem = storeOp.getMemref(); + auto banks = memoryToBanks[originalMem]; auto storeIndices = storeOp.getIndices(); - int64_t memrefRank = storeOp.getMemRefType().getRank(); - ArrayRef shape = storeOp.getMemRefType().getShape(); + auto originalMemRefType = storeOp.getMemRefType(); + int64_t memrefRank = originalMemRefType.getRank(); + ArrayRef shape = originalMemRefType.getShape(); auto bankingDimension = - getBankingDimension(bankingDimensionOpt, memrefRank, shape); + getSpecifiedOrDefaultBankingDim(bankingDimensionOpt, memrefRank, shape); + + resolveBankingAttributes(originalMem, bankingFactor, bankingDimension); + + verifyBankingConfigurations(bankingDimension, bankingFactor, + originalMemRefType); auto modMap = AffineMap::get( /*dimCount=*/memrefRank, /*symbolCount=*/0, @@ -436,8 +568,8 @@ struct BankAffineStorePattern } private: - uint64_t bankingFactor; - std::optional bankingDimensionOpt; + mutable unsigned bankingFactor; + mutable std::optional bankingDimensionOpt; DenseMap> &memoryToBanks; DenseSet &opsToErase; DenseSet &processedOps; @@ -495,12 +627,15 @@ LogicalResult cleanUpOldMemRefs(DenseSet &oldMemRefVals, DenseSet &opsToErase) { DenseSet funcsToModify; SmallVector valuesToErase; + DenseMap> erasedArgIndices; for (auto &memrefVal : oldMemRefVals) { valuesToErase.push_back(memrefVal); if (auto blockArg = dyn_cast(memrefVal)) { if (auto funcOp = - dyn_cast(blockArg.getOwner()->getParentOp())) + dyn_cast(blockArg.getOwner()->getParentOp())) { funcsToModify.insert(funcOp); + erasedArgIndices[funcOp].push_back(blockArg.getArgNumber()); + } } } @@ -517,8 +652,23 @@ LogicalResult cleanUpOldMemRefs(DenseSet &oldMemRefVals, } } - // Modify the function type accordingly + // Modify the function argument attributes and function type accordingly for (auto funcOp : funcsToModify) { + ArrayAttr existingArgAttrs = funcOp->getAttrOfType("arg_attrs"); + if (existingArgAttrs) { + SmallVector updatedArgAttrs; + auto erasedIndices = erasedArgIndices[funcOp]; + DenseSet indicesToErase(erasedIndices.begin(), + erasedIndices.end()); + for (unsigned i = 0; i < existingArgAttrs.size(); ++i) { + if (!indicesToErase.contains(i)) + updatedArgAttrs.push_back(existingArgAttrs[i]); + } + + funcOp->setAttr("arg_attrs", + ArrayAttr::get(funcOp.getContext(), updatedArgAttrs)); + } + SmallVector newArgTypes; for (BlockArgument arg : funcOp.getArguments()) { newArgTypes.push_back(arg.getType()); diff --git a/test/Transforms/memory_banking_attrs.mlir b/test/Transforms/memory_banking_attrs.mlir new file mode 100644 index 000000000000..fd7fa3f06a44 --- /dev/null +++ b/test/Transforms/memory_banking_attrs.mlir @@ -0,0 +1,123 @@ +// RUN: circt-opt %s -memory-banking="factor=3 dimension=1" + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1 mod 5)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d1 floordiv 5)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0 mod 5)> +// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0, d1) -> (d0 floordiv 5)> + +// CHECK-LABEL: func.func @main( +// CHECK-SAME: %[[VAL_0:arg0]]: memref<3x1xf32>, +// CHECK-SAME: %[[VAL_1:arg1]]: memref<3x1xf32>, +// CHECK-SAME: %[[VAL_2:arg2]]: memref<3x1xf32>, +// CHECK-SAME: %[[VAL_3:arg3]]: memref<3x1xf32>, +// CHECK-SAME: %[[VAL_4:arg4]]: memref<3x1xf32>, +// CHECK-SAME: %[[VAL_5:arg5]]: memref<1x3xf32>, +// CHECK-SAME: %[[VAL_6:arg6]]: memref<1x3xf32>, +// CHECK-SAME: %[[VAL_7:arg7]]: memref<1x3xf32>, +// CHECK-SAME: %[[VAL_8:arg8]]: memref<1x3xf32>, +// CHECK-SAME: %[[VAL_9:arg9]]: memref<1x3xf32>) -> (memref<1x3xf32>, memref<1x3xf32>, memref<1x3xf32>, memref<1x3xf32>, memref<1x3xf32>) { +// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1x3xf32> +// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<1x3xf32> +// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<1x3xf32> +// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<1x3xf32> +// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<1x3xf32> +// CHECK: affine.parallel (%[[VAL_16:.*]]) = (0) to (5) { +// CHECK: affine.parallel (%[[VAL_17:.*]]) = (0) to (3) { +// CHECK: %[[VAL_18:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_17]], %[[VAL_16]]) +// CHECK: %[[VAL_19:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_17]], %[[VAL_16]]) +// CHECK: %[[VAL_20:.*]] = scf.index_switch %[[VAL_18]] -> f32 +// CHECK: case 0 { +// CHECK: %[[VAL_21:.*]] = affine.load %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_19]]] : memref<3x1xf32> +// CHECK: scf.yield %[[VAL_21]] : f32 +// CHECK: } +// CHECK: case 1 { +// CHECK: %[[VAL_22:.*]] = affine.load %[[VAL_1]]{{\[}}%[[VAL_17]], %[[VAL_19]]] : memref<3x1xf32> +// CHECK: scf.yield %[[VAL_22]] : f32 +// CHECK: } +// CHECK: case 2 { +// CHECK: %[[VAL_23:.*]] = affine.load %[[VAL_2]]{{\[}}%[[VAL_17]], %[[VAL_19]]] : memref<3x1xf32> +// CHECK: scf.yield %[[VAL_23]] : f32 +// CHECK: } +// CHECK: case 3 { +// CHECK: %[[VAL_24:.*]] = affine.load %[[VAL_3]]{{\[}}%[[VAL_17]], %[[VAL_19]]] : memref<3x1xf32> +// CHECK: scf.yield %[[VAL_24]] : f32 +// CHECK: } +// CHECK: case 4 { +// CHECK: %[[VAL_25:.*]] = affine.load %[[VAL_4]]{{\[}}%[[VAL_17]], %[[VAL_19]]] : memref<3x1xf32> +// CHECK: scf.yield %[[VAL_25]] : f32 +// CHECK: } +// CHECK: default { +// CHECK: scf.yield %[[VAL_10]] : f32 +// CHECK: } +// CHECK: %[[VAL_26:.*]] = affine.apply #[[$ATTR_2]](%[[VAL_16]], %[[VAL_17]]) +// CHECK: %[[VAL_27:.*]] = affine.apply #[[$ATTR_3]](%[[VAL_16]], %[[VAL_17]]) +// CHECK: %[[VAL_28:.*]] = scf.index_switch %[[VAL_26]] -> f32 +// CHECK: case 0 { +// CHECK: %[[VAL_29:.*]] = affine.load %[[VAL_5]]{{\[}}%[[VAL_27]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield %[[VAL_29]] : f32 +// CHECK: } +// CHECK: case 1 { +// CHECK: %[[VAL_30:.*]] = affine.load %[[VAL_6]]{{\[}}%[[VAL_27]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield %[[VAL_30]] : f32 +// CHECK: } +// CHECK: case 2 { +// CHECK: %[[VAL_31:.*]] = affine.load %[[VAL_7]]{{\[}}%[[VAL_27]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield %[[VAL_31]] : f32 +// CHECK: } +// CHECK: case 3 { +// CHECK: %[[VAL_32:.*]] = affine.load %[[VAL_8]]{{\[}}%[[VAL_27]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield %[[VAL_32]] : f32 +// CHECK: } +// CHECK: case 4 { +// CHECK: %[[VAL_33:.*]] = affine.load %[[VAL_9]]{{\[}}%[[VAL_27]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield %[[VAL_33]] : f32 +// CHECK: } +// CHECK: default { +// CHECK: scf.yield %[[VAL_10]] : f32 +// CHECK: } +// CHECK: %[[VAL_34:.*]] = arith.mulf %[[VAL_20]], %[[VAL_28]] : f32 +// CHECK: %[[VAL_35:.*]] = affine.apply #[[$ATTR_2]](%[[VAL_16]], %[[VAL_17]]) +// CHECK: %[[VAL_36:.*]] = affine.apply #[[$ATTR_3]](%[[VAL_16]], %[[VAL_17]]) +// CHECK: scf.index_switch %[[VAL_35]] +// CHECK: case 0 { +// CHECK: affine.store %[[VAL_34]], %[[VAL_11]]{{\[}}%[[VAL_36]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield +// CHECK: } +// CHECK: case 1 { +// CHECK: affine.store %[[VAL_34]], %[[VAL_12]]{{\[}}%[[VAL_36]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield +// CHECK: } +// CHECK: case 2 { +// CHECK: affine.store %[[VAL_34]], %[[VAL_13]]{{\[}}%[[VAL_36]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield +// CHECK: } +// CHECK: case 3 { +// CHECK: affine.store %[[VAL_34]], %[[VAL_14]]{{\[}}%[[VAL_36]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield +// CHECK: } +// CHECK: case 4 { +// CHECK: affine.store %[[VAL_34]], %[[VAL_15]]{{\[}}%[[VAL_36]], %[[VAL_17]]] : memref<1x3xf32> +// CHECK: scf.yield +// CHECK: } +// CHECK: default { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : memref<1x3xf32>, memref<1x3xf32>, memref<1x3xf32>, memref<1x3xf32>, memref<1x3xf32> +// CHECK: } + +module { + func.func @main(%arg0: memref<3x5xf32> {banking.factor=5, banking.dimension=1}, %arg1: memref<5x3xf32>{banking.factor=5, banking.dimension=0}) -> (memref<5x3xf32>) { + %mem = memref.alloc() {banking.factor=5, banking.dimension=0} : memref<5x3xf32> + affine.parallel (%i) = (0) to (5) { + affine.parallel (%j) = (0) to (3) { + %1 = affine.load %arg0[%j, %i] : memref<3x5xf32> + %2 = affine.load %arg1[%i, %j] : memref<5x3xf32> + %3 = arith.mulf %1, %2 : f32 + affine.store %3, %mem[%i, %j] : memref<5x3xf32> + } + } + return %mem : memref<5x3xf32> + } +}