Skip to content

Commit

Permalink
[MemoryBanking] Operation-granularity banking configurations by attac…
Browse files Browse the repository at this point in the history
…hing attributes (#8133)
  • Loading branch information
jiahanxie353 authored Jan 27, 2025
1 parent 335e50c commit 1c70fce
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 22 deletions.
194 changes: 172 additions & 22 deletions lib/Transforms/MemoryBanking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,22 @@ DenseSet<Value> collectMemRefs(mlir::affine::AffineParallelOp parOp) {
return memrefVals;
}

MemRefType computeBankedMemRefType(MemRefType originalType,
uint64_t bankingFactor,
unsigned bankingDimension) {
// Verify the banking configuration with different conditions.
void verifyBankingConfigurations(unsigned bankingDimension,
unsigned bankingFactor,
MemRefType originalType) {
ArrayRef<int64_t> originalShape = originalType.getShape();
assert(!originalShape.empty() && "memref shape should not be empty");

assert(bankingDimension < originalType.getRank() &&
"dimension must be within the memref rank");
assert(originalShape[bankingDimension] % bankingFactor == 0 &&
"memref shape must be evenly divided by the banking factor");
}

MemRefType computeBankedMemRefType(MemRefType originalType,
uint64_t bankingFactor,
unsigned bankingDimension) {
ArrayRef<int64_t> originalShape = originalType.getShape();
SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
newShape[bankingDimension] /= bankingFactor;
MemRefType newMemRefType =
Expand Down Expand Up @@ -200,8 +206,9 @@ SmallVector<Value, 4> handleGetGlobalOp(memref::GetGlobalOp getGlobalOp,
return banks;
}

unsigned getBankingDimension(std::optional<int> bankingDimensionOpt,
int64_t rank, ArrayRef<int64_t> shape) {
unsigned getSpecifiedOrDefaultBankingDim(std::optional<int> bankingDimensionOpt,
int64_t rank,
ArrayRef<int64_t> shape) {
// If the banking dimension is already specified, return it.
// Note, the banking dimension will always be nonempty because TableGen will
// assign it with a default value -1 if it's not specified by the user. Thus,
Expand All @@ -226,13 +233,114 @@ unsigned getBankingDimension(std::optional<int> bankingDimensionOpt,
return static_cast<unsigned>(bankingDimension);
}

SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor,
// Retrieve potentially specified banking factor/dimension attributes and
// overwrite the command line or the default ones.
void resolveBankingAttributes(Value originalMem, unsigned &bankingFactor,
unsigned &bankingDimension) {
if (auto *originalDef = originalMem.getDefiningOp()) {
if (auto attrFactor = dyn_cast_if_present<IntegerAttr>(
originalDef->getAttr("banking.factor")))
bankingFactor = attrFactor.getInt();
if (auto attrDimension = dyn_cast_if_present<IntegerAttr>(
originalDef->getAttr("banking.dimension")))
bankingDimension = attrDimension.getInt();

return;
}

if (isa<BlockArgument>(originalMem)) {
auto blockArg = cast<BlockArgument>(originalMem);
auto *parentOp = blockArg.getOwner()->getParentOp();

auto funcOp = dyn_cast<func::FuncOp>(parentOp);
assert(funcOp &&
"Expected the original memory to be a FuncOp block argument!");

unsigned argIndex = blockArg.getArgNumber();
if (auto argAttrs = funcOp.getArgAttrDict(argIndex)) {
if (auto attrFactor =
dyn_cast_if_present<IntegerAttr>(argAttrs.get("banking.factor")))
bankingFactor = attrFactor.getInt();
if (auto attrDimension = dyn_cast_if_present<IntegerAttr>(
argAttrs.get("banking.dimension")))
bankingDimension = attrDimension.getInt();
}

return;
}
}

// Update the argument types of `funcOp` by inserting `numInsertedArgs` number
// of `newMemRefType` after `argIndex`.
void updateFuncOpArgumentTypes(func::FuncOp funcOp, unsigned argIndex,
MemRefType newMemRefType,
unsigned numInsertedArgs) {
auto originalArgTypes = funcOp.getFunctionType().getInputs();
SmallVector<Type, 4> updatedArgTypes;

// Rebuild the argument types, inserting new types for the newly added
// arguments
for (unsigned i = 0; i < originalArgTypes.size(); ++i) {
updatedArgTypes.push_back(originalArgTypes[i]);

// Insert new argument types after the specified argument index
if (i == argIndex) {
for (unsigned j = 0; j < numInsertedArgs; ++j) {
updatedArgTypes.push_back(newMemRefType);
}
}
}

// Update the function type with the new argument types
auto resultTypes = funcOp.getFunctionType().getResults();
auto newFuncType =
FunctionType::get(funcOp.getContext(), updatedArgTypes, resultTypes);
funcOp.setType(newFuncType);
}

// Update `funcOp`'s "arg_attrs" by inserting `numInsertedArgs` number of empty
// DictionaryAttr after `argIndex`.
void updateFuncOpArgAttrs(func::FuncOp funcOp, unsigned argIndex,
unsigned numInsertedArgs) {
ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
SmallVector<Attribute, 4> updatedArgAttrs;
unsigned numArguments = funcOp.getNumArguments();
unsigned newNumArguments = numArguments + numInsertedArgs;
updatedArgAttrs.resize(newNumArguments);

// Copy existing attributes, adjusting for the new arguments
for (unsigned i = 0; i < numArguments; ++i) {
// Shift attributes for arguments after the inserted ones.
unsigned newIndex = (i > argIndex) ? i + numInsertedArgs : i;
updatedArgAttrs[newIndex] = existingArgAttrs
? existingArgAttrs[i]
: DictionaryAttr::get(funcOp.getContext());
}

// Initialize new attributes for the inserted arguments as empty dictionaries
for (unsigned i = 0; i < numInsertedArgs; ++i) {
updatedArgAttrs[argIndex + 1 + i] =
DictionaryAttr::get(funcOp.getContext());
}

// Set the updated attributes.
funcOp->setAttr("arg_attrs",
ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
}

SmallVector<Value, 4> createBanks(Value originalMem, unsigned bankingFactor,
std::optional<int> bankingDimensionOpt) {
MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
unsigned rank = originalMemRefType.getRank();
ArrayRef<int64_t> shape = originalMemRefType.getShape();

auto bankingDimension = getBankingDimension(bankingDimensionOpt, rank, shape);
unsigned bankingDimension =
getSpecifiedOrDefaultBankingDim(bankingDimensionOpt, rank, shape);

resolveBankingAttributes(originalMem, bankingFactor, bankingDimension);

verifyBankingConfigurations(bankingDimension, bankingFactor,
originalMemRefType);

MemRefType newMemRefType = computeBankedMemRefType(
originalMemRefType, bankingFactor, bankingDimension);
Expand All @@ -248,6 +356,16 @@ SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor,
auto blockArgs =
block->getArguments().slice(blockArgNum + 1, bankingFactor);
banks.append(blockArgs.begin(), blockArgs.end());

auto *parentOp = block->getParentOp();
auto funcOp = dyn_cast<func::FuncOp>(parentOp);
assert(funcOp && "BlockArgument is not part of a FuncOp");
// Update the ArgumentTypes of `funcOp` so that we can correctly get
// `getArgAttrDict` when resolving banking attributes across the iterations
// of creating new banks.
updateFuncOpArgumentTypes(funcOp, blockArgNum, newMemRefType,
bankingFactor);
updateFuncOpArgAttrs(funcOp, blockArgNum, bankingFactor);
} else {
Operation *originalDef = originalMem.getDefiningOp();
Location loc = originalDef->getLoc();
Expand Down Expand Up @@ -295,13 +413,20 @@ struct BankAffineLoadPattern
LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp,
PatternRewriter &rewriter) const override {
Location loc = loadOp.getLoc();
auto banks = memoryToBanks[loadOp.getMemref()];
auto originalMem = loadOp.getMemref();
auto banks = memoryToBanks[originalMem];
auto loadIndices = loadOp.getIndices();
int64_t memrefRank = loadOp.getMemRefType().getRank();
ArrayRef<int64_t> shape = loadOp.getMemRefType().getShape();
MemRefType originalMemRefType = loadOp.getMemRefType();
int64_t memrefRank = originalMemRefType.getRank();
ArrayRef<int64_t> shape = originalMemRefType.getShape();

auto bankingDimension =
getBankingDimension(bankingDimensionOpt, memrefRank, shape);
getSpecifiedOrDefaultBankingDim(bankingDimensionOpt, memrefRank, shape);

resolveBankingAttributes(originalMem, bankingFactor, bankingDimension);

verifyBankingConfigurations(bankingDimension, bankingFactor,
originalMemRefType);

auto modMap = AffineMap::get(
/*dimCount=*/memrefRank, /*symbolCount=*/0,
Expand Down Expand Up @@ -355,8 +480,8 @@ struct BankAffineLoadPattern
}

private:
uint64_t bankingFactor;
std::optional<int> bankingDimensionOpt;
mutable unsigned bankingFactor;
mutable std::optional<int> bankingDimensionOpt;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
DenseSet<Value> &oldMemRefVals;
};
Expand All @@ -381,13 +506,20 @@ struct BankAffineStorePattern
return failure();
}
Location loc = storeOp.getLoc();
auto banks = memoryToBanks[storeOp.getMemref()];
auto originalMem = storeOp.getMemref();
auto banks = memoryToBanks[originalMem];
auto storeIndices = storeOp.getIndices();
int64_t memrefRank = storeOp.getMemRefType().getRank();
ArrayRef<int64_t> shape = storeOp.getMemRefType().getShape();
auto originalMemRefType = storeOp.getMemRefType();
int64_t memrefRank = originalMemRefType.getRank();
ArrayRef<int64_t> shape = originalMemRefType.getShape();

auto bankingDimension =
getBankingDimension(bankingDimensionOpt, memrefRank, shape);
getSpecifiedOrDefaultBankingDim(bankingDimensionOpt, memrefRank, shape);

resolveBankingAttributes(originalMem, bankingFactor, bankingDimension);

verifyBankingConfigurations(bankingDimension, bankingFactor,
originalMemRefType);

auto modMap = AffineMap::get(
/*dimCount=*/memrefRank, /*symbolCount=*/0,
Expand Down Expand Up @@ -436,8 +568,8 @@ struct BankAffineStorePattern
}

private:
uint64_t bankingFactor;
std::optional<int> bankingDimensionOpt;
mutable unsigned bankingFactor;
mutable std::optional<int> bankingDimensionOpt;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
DenseSet<Operation *> &opsToErase;
DenseSet<Operation *> &processedOps;
Expand Down Expand Up @@ -495,12 +627,15 @@ LogicalResult cleanUpOldMemRefs(DenseSet<Value> &oldMemRefVals,
DenseSet<Operation *> &opsToErase) {
DenseSet<func::FuncOp> funcsToModify;
SmallVector<Value, 4> valuesToErase;
DenseMap<func::FuncOp, SmallVector<unsigned, 4>> erasedArgIndices;
for (auto &memrefVal : oldMemRefVals) {
valuesToErase.push_back(memrefVal);
if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
if (auto funcOp =
dyn_cast<func::FuncOp>(blockArg.getOwner()->getParentOp()))
dyn_cast<func::FuncOp>(blockArg.getOwner()->getParentOp())) {
funcsToModify.insert(funcOp);
erasedArgIndices[funcOp].push_back(blockArg.getArgNumber());
}
}
}

Expand All @@ -517,8 +652,23 @@ LogicalResult cleanUpOldMemRefs(DenseSet<Value> &oldMemRefVals,
}
}

// Modify the function type accordingly
// Modify the function argument attributes and function type accordingly
for (auto funcOp : funcsToModify) {
ArrayAttr existingArgAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
if (existingArgAttrs) {
SmallVector<Attribute, 4> updatedArgAttrs;
auto erasedIndices = erasedArgIndices[funcOp];
DenseSet<unsigned> indicesToErase(erasedIndices.begin(),
erasedIndices.end());
for (unsigned i = 0; i < existingArgAttrs.size(); ++i) {
if (!indicesToErase.contains(i))
updatedArgAttrs.push_back(existingArgAttrs[i]);
}

funcOp->setAttr("arg_attrs",
ArrayAttr::get(funcOp.getContext(), updatedArgAttrs));
}

SmallVector<Type, 4> newArgTypes;
for (BlockArgument arg : funcOp.getArguments()) {
newArgTypes.push_back(arg.getType());
Expand Down
Loading

0 comments on commit 1c70fce

Please sign in to comment.