-
Notifications
You must be signed in to change notification settings - Fork 50
Multi Buffer Support #615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Multi Buffer Support #615
Changes from 2 commits
135014d
b4d023e
615a3b7
7e1a3e5
d7db3f8
a6b7b01
e24b5a1
868d1a4
1111ec1
63a08ec
6a1bc22
2c682fd
97ebf4f
9a7047c
bfacd43
87a809c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| // Copyright (c) 2026 Huawei Technologies Co., Ltd. | ||
| // This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||
| // CANN Open Software License Agreement Version 2.0 (the "License"). | ||
| // Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| // See LICENSE in the root of the software repository for the full text of the License. | ||
|
|
||
| #ifndef PTO_TRANSFORMS_MULTIBUFFER_H | ||
| #define PTO_TRANSFORMS_MULTIBUFFER_H | ||
|
|
||
| #include "llvm/ADT/StringRef.h" | ||
|
|
||
| namespace mlir { | ||
| namespace pto { | ||
|
|
||
| /// Attribute name for multi-buffer depth on `memref.alloc` (integer slot count N>=2). | ||
| inline constexpr llvm::StringLiteral kPtoMultiBufferAttrName = "pto.multi_buffer"; | ||
|
|
||
| /// Upper bound for N; must stay consistent with `MAX_MULTI_BUFFER_NUM` in insert-sync. | ||
| inline constexpr unsigned kPtoMultiBufferMaxNum = 16; | ||
|
|
||
| } // namespace pto | ||
| } // namespace mlir | ||
|
|
||
| #endif // PTO_TRANSFORMS_MULTIBUFFER_H |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| // Copyright (c) 2026 Huawei Technologies Co., Ltd. | ||
| // This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||
| // CANN Open Software License Agreement Version 2.0 (the "License"). | ||
| // Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| // See LICENSE in the root of the software repository for the full text of the License. | ||
|
|
||
| #include "PTO/IR/PTO.h" | ||
| #include "PTO/Transforms/Passes.h" | ||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||
| #include "mlir/Dialect/Func/IR/FuncOps.h" | ||
| #include "mlir/Dialect/SCF/IR/SCF.h" | ||
| #include "mlir/IR/PatternMatch.h" | ||
| #include "mlir/Pass/Pass.h" | ||
|
|
||
| namespace mlir { | ||
| namespace pto { | ||
| #define GEN_PASS_DEF_PTOENABLEMULTIBUFFER | ||
| #include "PTO/Transforms/Passes.h.inc" | ||
| } // namespace pto | ||
| } // namespace mlir | ||
|
|
||
| using namespace mlir; | ||
| using namespace mlir::pto; | ||
|
|
||
| namespace { | ||
|
|
||
| static LogicalResult lowerMultiBufferPointerCast(IRRewriter &rewriter, | ||
| PointerCastOp op, | ||
| scf::ForOp forOp) { | ||
| ValueRange addrs = op.getAddrs(); | ||
| unsigned n = static_cast<unsigned>(addrs.size()); | ||
| assert(n >= 2); | ||
|
|
||
| Location loc = op.getLoc(); | ||
| MemRefType resTy = op.getType(); | ||
| Value validRow = op.getValidRow(); | ||
| Value validCol = op.getValidCol(); | ||
| std::optional<TileBufConfigAttr> config = op.getConfig(); | ||
|
|
||
| rewriter.setInsertionPoint(forOp); | ||
| SmallVector<Value> slotBufs; | ||
| slotBufs.reserve(n); | ||
| for (unsigned i = 0; i < n; ++i) { | ||
| auto oneAddr = addrs.slice(i, 1); | ||
| PointerCastOp slot = rewriter.create<PointerCastOp>( | ||
| loc, resTy, oneAddr, validRow, validCol, | ||
| config.has_value() | ||
| ? static_cast<Attribute>(*config) | ||
| : Attribute()); | ||
| slotBufs.push_back(slot.getResult()); | ||
| } | ||
|
Comment on lines
+65
to
+76
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The pass hoists |
||
|
|
||
| rewriter.setInsertionPointToStart(forOp.getBody()); | ||
| Value iv = forOp.getInductionVar(); | ||
| Value cN = rewriter.create<arith::ConstantIndexOp>(loc, n); | ||
| Value rem = rewriter.create<arith::RemUIOp>(loc, iv, cN); | ||
|
|
||
| Value selected = slotBufs[0]; | ||
| for (unsigned i = 1; i < n; ++i) { | ||
| Value ci = rewriter.create<arith::ConstantIndexOp>(loc, i); | ||
| Value eq = rewriter.create<arith::CmpIOp>( | ||
| loc, arith::CmpIPredicate::eq, rem, ci); | ||
| selected = | ||
| rewriter.create<arith::SelectOp>(loc, eq, slotBufs[i], selected); | ||
| } | ||
|
|
||
| rewriter.replaceOp(op, selected); | ||
| return success(); | ||
| } | ||
|
|
||
| struct PTOEnableMultiBufferPass | ||
| : public mlir::pto::impl::PTOEnableMultiBufferBase< | ||
| PTOEnableMultiBufferPass> { | ||
| void runOnOperation() override { | ||
| func::FuncOp func = getOperation(); | ||
| SmallVector<PointerCastOp> work; | ||
| func.walk([&](PointerCastOp op) { | ||
| if (op.getAddrs().size() > 1) | ||
| work.push_back(op); | ||
| }); | ||
|
|
||
| IRRewriter rewriter(&getContext()); | ||
| for (PointerCastOp op : work) { | ||
| auto forOp = op->getParentOfType<scf::ForOp>(); | ||
| if (!forOp) { | ||
| op.emitWarning() | ||
| << "pto-enable-multi-buffer: expected enclosing scf.for; skipping"; | ||
| continue; | ||
| } | ||
| if (failed(lowerMultiBufferPointerCast(rewriter, op, forOp))) { | ||
| signalPassFailure(); | ||
| return; | ||
| } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| std::unique_ptr<Pass> mlir::pto::createPTOEnableMultiBufferPass() { | ||
| return std::make_unique<PTOEnableMultiBufferPass>(); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The multi-buffer lowering hoists new single-address `pto.pointer_cast` ops before the loop, but those ops still consume `valid_row`/`valid_col`. Only `addrs` are checked for loop-invariance, so if either valid-dim operand is defined inside the loop (e.g., loop-dependent dynamic bounds), the transformed IR will reference out-of-scope values and break SSA correctness. The pass should either validate these operands are defined outside the loop or keep the cast inside the loop when they are not.
Useful? React with 👍 / 👎.