Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
135014d
PTOPlanMemory: align with HIVM design (SPEC_LEVEL_3 + retry loop)
zhangstevenunity May 2, 2026
b4d023e
Co-authored-by:
zhangstevenunity May 2, 2026
615a3b7
multi-buffer: wire up HIVM-style multi event-id deduction (P0)
zhangstevenunity May 3, 2026
7e1a3e5
multi-buffer: harden plan/sync paths and add fallback (P1)
zhangstevenunity May 3, 2026
d7db3f8
multi-buffer: refine sync semantics and SPEC_LEVEL_1 reuse (P2)
zhangstevenunity May 3, 2026
a6b7b01
multi-buffer: require actual pipe conflict in SPEC_LEVEL_2 filter
zhangstevenunity May 3, 2026
e24b5a1
Merge branch 'main' into enhance-pto-plan-memory-spec-level3-retry
zhangstevenunity May 3, 2026
868d1a4
multi-buffer/sync: drop redundant same-pipe back-edge PIPE_BARRIER
zhangstevenunity May 3, 2026
1111ec1
Merge remote-tracking branch 'upstream/main' into claude/distracted-l…
zhangstevenunity May 4, 2026
63a08ec
Merge remote-tracking branch 'upstream/main' into claude/distracted-l…
zhangstevenunity May 4, 2026
6a1bc22
gss/multi-buffer: HIVM-style event-id deduction + dyn flag codegen
zhangstevenunity May 5, 2026
2c682fd
docs: add PTOAS multi-buffer design
zhangstevenunity May 5, 2026
97ebf4f
Revise PTOAS PR615 multi-buffer design documentation
zhangstevenunity May 5, 2026
9a7047c
Fix PR615 GSS multibuffer merge with main
zhangstevenunity May 9, 2026
bfacd43
Merge branch 'main' into enhance-pto-plan-memory-spec-level3-retry
zhangstevenunity May 9, 2026
87a809c
feat(tilebuf): support multi-buffer expression
zhangstevenunity May 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/PTO/Transforms/InsertSync/SyncCodegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ class SyncCodegen {
// Records the Op -> Sync mapping.
DenseMap<const Operation *, SyncPipeBuild> op2InsertSync;

// Records the Loop -> Counter mapping (cache).
DenseMap<Operation *, Value> loop2BufferCounter;
// Records the Loop -> (counter value, modulo N) mapping (cache).
DenseMap<Operation *, std::pair<Value, unsigned>> loop2BufferCounter;

// Records the SyncIndex -> EventID Value mapping (cache).
DenseMap<unsigned, Value> SyncIndex2SelectBuffer;
Expand All @@ -97,4 +97,4 @@ class SyncCodegen {
} // namespace pto
} // namespace mlir

#endif // MLIR_DIALECT_PTO_TRANSFORMS_INJECTSYNC_SYNCCODEGEN_HN_H
#endif // MLIR_DIALECT_PTO_TRANSFORMS_INJECTSYNC_SYNCCODEGEN_H
15 changes: 11 additions & 4 deletions include/PTO/Transforms/InsertSync/SyncCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ enum class TCoreType {
struct BaseMemInfo {
BaseMemInfo(
Value baseBuffer, Value rootBuffer, pto::AddressSpace scope,
SmallVector<uint64_t> baseAddresses, uint64_t allocateSize)
SmallVector<uint64_t> baseAddresses, uint64_t allocateSize,
bool hasVariableAddress = false)
: baseBuffer(baseBuffer), rootBuffer(rootBuffer), scope(scope),
baseAddresses(std::move(baseAddresses)), allocateSize(allocateSize) {}
baseAddresses(std::move(baseAddresses)), allocateSize(allocateSize),
hasVariableAddress(hasVariableAddress) {}

/// baseBuffer: 当前操作直接使用的 Buffer (可能是 View 或 Alias)
Value baseBuffer;
Expand All @@ -98,6 +100,8 @@ struct BaseMemInfo {
pto::AddressSpace scope;
SmallVector<uint64_t> baseAddresses; // 用于 Offset 分析
uint64_t allocateSize;
/// True when pointer/workspace addresses are not compile-time constants.
bool hasVariableAddress{false};

bool areVectorEqual(const SmallVector<uint64_t>& vec1,
const SmallVector<uint64_t>& vec2) const {
Expand All @@ -116,17 +120,20 @@ struct BaseMemInfo {
// 但为了保持原有逻辑,先保留。重点是 rootBuffer 必须一致。
if (allocateSize != other.allocateSize) return false;
if (baseBuffer != other.baseBuffer) return false;
if (hasVariableAddress != other.hasVariableAddress) return false;
return true;
}

std::unique_ptr<BaseMemInfo> clone() const {
return std::make_unique<BaseMemInfo>(
baseBuffer, rootBuffer, scope, baseAddresses, allocateSize);
baseBuffer, rootBuffer, scope, baseAddresses, allocateSize,
hasVariableAddress);
}

std::unique_ptr<BaseMemInfo> clone(Value cloneBaseBuffer) const {
return std::make_unique<BaseMemInfo>(
cloneBaseBuffer, rootBuffer, scope, baseAddresses, allocateSize);
cloneBaseBuffer, rootBuffer, scope, baseAddresses, allocateSize,
hasVariableAddress);
}
};

Expand Down
26 changes: 26 additions & 0 deletions include/PTO/Transforms/MultiBuffer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#ifndef PTO_TRANSFORMS_MULTIBUFFER_H
#define PTO_TRANSFORMS_MULTIBUFFER_H

#include "llvm/ADT/StringRef.h"

// Shared multi-buffer constants, placed in the mlir::pto namespace.
namespace mlir {
namespace pto {

/// Attribute name for multi-buffer depth on `memref.alloc` (integer slot count N>=2).
/// Declared `inline constexpr`, so including this header from several
/// translation units yields a single definition.
inline constexpr llvm::StringLiteral kPtoMultiBufferAttrName = "pto.multi_buffer";

/// Upper bound for N; must stay consistent with `MAX_MULTI_BUFFER_NUM` in insert-sync.
/// NOTE(review): nothing in this header enforces that consistency (no
/// static_assert or shared definition is visible here) -- confirm it is
/// checked where `MAX_MULTI_BUFFER_NUM` is defined.
inline constexpr unsigned kPtoMultiBufferMaxNum = 16;

} // namespace pto
} // namespace mlir

#endif // PTO_TRANSFORMS_MULTIBUFFER_H
1 change: 1 addition & 0 deletions include/PTO/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ std::unique_ptr<Pass>
createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {});

std::unique_ptr<Pass> createPTORemoveRedundantBarrierPass();
std::unique_ptr<Pass> createPTOEnableMultiBufferPass();
std::unique_ptr<Pass> createPTOViewToMemrefPass();
std::unique_ptr<Pass> createInferPTOLayoutPass();
std::unique_ptr<Pass> createPTOA5NormalizeTMovPass();
Expand Down
19 changes: 19 additions & 0 deletions include/PTO/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,25 @@ def PTOVerifyTFree : Pass<"pto-verify-tfree", "func::FuncOp"> {
];
}

// Pass record for `pto-enable-multi-buffer`, anchored on func::FuncOp.
// Instances are built by mlir::pto::createPTOEnableMultiBufferPass().
// NOTE(review): the "iv mod N" in the description is implemented as
// arith.remui on the raw induction variable; with a non-unit loop step the
// selected slot may not cycle through all N slots -- confirm this is intended.
def PTOEnableMultiBuffer : Pass<"pto-enable-multi-buffer", "func::FuncOp"> {
let summary = "Lower variadic pto.pointer_cast with multi-buffer addrs into "
"single-address casts plus a per-iteration arith.select";
let description = [{
Mirrors HIVM's `EnableMultiBuffer` lowering: takes a `pto.pointer_cast` with
N>1 address operands, hoists each address into its own single-address
`pto.pointer_cast` outside the parent `scf.for`, then replaces the original
multi-address cast with an N-way `arith.select` chain driven by `iv mod N`.
Runs after `pto-insert-sync` so the multi-address `pto.pointer_cast` stays
visible to dependency analysis.
}];
let constructor = "mlir::pto::createPTOEnableMultiBufferPass()";
// Dialects whose ops this pass may create or reference.
let dependentDialects = [
"mlir::pto::PTODialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"
];
}

def PTOViewToMemref : Pass<"pto-view-to-memref", "ModuleOp"> {
let summary = "Lower PTO views to memref with Metadata Binding";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions lib/PTO/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ add_mlir_dialect_library(PTOTransforms
InsertSync/RemoveRedundantSync.cpp
InsertSync/SyncEventIdAllocation.cpp
InsertSync/SyncCodegen.cpp
PTOEnableMultiBuffer.cpp
LoweringSyncToPipe.cpp
PTOVerifyTFreePass.cpp

Expand Down
87 changes: 57 additions & 30 deletions lib/PTO/Transforms/InsertSync/SyncCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,18 @@ static void createSetOrWaitFlagOp(IRRewriter &rewriter, Operation *op,
}
rewriter.create<pto::SetFlagOp>(op->getLoc(), srcPipe, dstPipe, eventId);
}


/// Emits the dynamic-event flavour of a flag op at the rewriter's current
/// insertion point, using the location of `op`.
///
/// A `pto::WaitFlagDynOp` is created when `sync` reports a wait type;
/// otherwise a `pto::SetFlagDynOp` is created. Both take the source pipe,
/// destination pipe and the SSA `eventIndex` value as operands.
static void createSetOrWaitFlagDynOp(IRRewriter &rewriter, Operation *op,
                                     SyncOperation *sync, pto::PipeAttr srcPipe,
                                     pto::PipeAttr dstPipe, Value eventIndex) {
  const bool isWait = sync->isSyncWaitType();
  if (!isWait) {
    rewriter.create<pto::SetFlagDynOp>(op->getLoc(), srcPipe, dstPipe,
                                       eventIndex);
    return;
  }
  rewriter.create<pto::WaitFlagDynOp>(op->getLoc(), srcPipe, dstPipe,
                                      eventIndex);
}

// ==============================================================================
// 2. SyncCodegen Implementation
// ==============================================================================
Expand Down Expand Up @@ -267,12 +278,12 @@ void SyncCodegen::SyncInsert(IRRewriter &rewriter, Operation *op,
if (sync->GetType() == SyncOperation::TYPE::PIPE_BARRIER) {
CreateBarrierOp(rewriter, insertAnchorOp, sync, forceBefore);
} else if (sync->isSyncSetType() || sync->isSyncWaitType()) {
if (sync->eventIds.size() == 1) {
CreateSetWaitOpForSingleBuffer(rewriter, insertAnchorOp, sync, forceBefore);
} else {
if (sync->eventIdNum > 1 && sync->eventIds.size() > 1) {
CreateSetWaitOpForMultiBuffer(rewriter, insertAnchorOp, sync, forceBefore);
} else {
CreateSetWaitOpForSingleBuffer(rewriter, insertAnchorOp, sync, forceBefore);
}
}
}
}

// [核心修改] 加强版 CreateBarrierOp
Expand Down Expand Up @@ -346,46 +357,62 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter,
Operation *op,
SyncOperation *sync,
bool beforeInsert) {
Value bufferSelected = GetBufferSelected(rewriter, op, sync);
(void)bufferSelected;

Value eventIdxDyn;
{
mlir::OpBuilder::InsertionGuard guard(rewriter);
eventIdxDyn = GetBufferSelected(rewriter, op, sync);
}
setSyncInsertionPoint(
rewriter, op, beforeInsert || op->hasTrait<OpTrait::IsTerminator>());
auto srcPipe = getPipeAttr(rewriter, sync->GetActualSrcPipe());
auto dstPipe = getPipeAttr(rewriter, sync->GetActualDstPipe());
auto eventId = getEventAttr(rewriter, sync->eventIds[0]);
setSyncInsertionPoint(rewriter, op,
beforeInsert || op->hasTrait<OpTrait::IsTerminator>());
createSetOrWaitFlagOp(rewriter, op, sync, srcPipe, dstPipe, eventId);
if (!eventIdxDyn) {
int id0 = sync->eventIds.empty() ? 0 : sync->eventIds[0];
auto eventId = getEventAttr(rewriter, id0);
createSetOrWaitFlagOp(rewriter, op, sync, srcPipe, dstPipe, eventId);
return;
}
createSetOrWaitFlagDynOp(rewriter, op, sync, srcPipe, dstPipe, eventIdxDyn);
}

Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op,
SyncOperation *sync) {
if (SyncIndex2SelectBuffer.count(sync->GetSyncIndex())) {
return SyncIndex2SelectBuffer[sync->GetSyncIndex()];
}


unsigned N = static_cast<unsigned>(sync->eventIdNum);
if (N <= 1 || sync->eventIds.size() < N)
return nullptr;

auto parentLoop = op->getParentOfType<scf::ForOp>();
if (!parentLoop) return nullptr;

if (!parentLoop)
return nullptr;

Value counter;
if (loop2BufferCounter.count(parentLoop)) {
counter = loop2BufferCounter[parentLoop];
auto loopIt = loop2BufferCounter.find(parentLoop);
if (loopIt != loop2BufferCounter.end() && loopIt->second.second == N) {
counter = loopIt->second.first;
} else {
rewriter.setInsertionPointToStart(parentLoop.getBody());
Value iv = parentLoop.getInductionVar();
Value c2 = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 2);
counter = rewriter.create<arith::RemUIOp>(op->getLoc(), iv, c2);
loop2BufferCounter[parentLoop] = counter;
Value cN = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), N);
counter = rewriter.create<arith::RemUIOp>(op->getLoc(), iv, cN);
loop2BufferCounter[parentLoop] = {counter, N};
}

rewriter.setInsertionPointAfter(counter.getDefiningOp());
Value id0 = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), sync->eventIds[0]);
Value id1 = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), sync->eventIds[1]);

Value isZero = rewriter.create<arith::CmpIOp>(op->getLoc(), arith::CmpIPredicate::eq, counter,
rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0));

Value selected = rewriter.create<arith::SelectOp>(op->getLoc(), isZero, id0, id1);

Value selected =
rewriter.create<arith::ConstantIndexOp>(op->getLoc(), sync->eventIds[0]);
for (unsigned i = 1; i < N; ++i) {
Value ci = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), i);
Value eq = rewriter.create<arith::CmpIOp>(op->getLoc(), arith::CmpIPredicate::eq,
counter, ci);
Value idv =
rewriter.create<arith::ConstantIndexOp>(op->getLoc(), sync->eventIds[i]);
selected = rewriter.create<arith::SelectOp>(op->getLoc(), eq, idv, selected);
}

SyncIndex2SelectBuffer[sync->GetSyncIndex()] = selected;
return selected;
}
104 changes: 104 additions & 0 deletions lib/PTO/Transforms/PTOEnableMultiBuffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#include "PTO/IR/PTO.h"
#include "PTO/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace pto {
#define GEN_PASS_DEF_PTOENABLEMULTIBUFFER
#include "PTO/Transforms/Passes.h.inc"
} // namespace pto
} // namespace mlir

using namespace mlir;
using namespace mlir::pto;

namespace {

static LogicalResult lowerMultiBufferPointerCast(IRRewriter &rewriter,
PointerCastOp op,
scf::ForOp forOp) {
ValueRange addrs = op.getAddrs();
unsigned n = static_cast<unsigned>(addrs.size());
assert(n >= 2);

Location loc = op.getLoc();
MemRefType resTy = op.getType();
Value validRow = op.getValidRow();
Value validCol = op.getValidCol();
std::optional<TileBufConfigAttr> config = op.getConfig();

rewriter.setInsertionPoint(forOp);
SmallVector<Value> slotBufs;
slotBufs.reserve(n);
for (unsigned i = 0; i < n; ++i) {
auto oneAddr = addrs.slice(i, 1);
PointerCastOp slot = rewriter.create<PointerCastOp>(
loc, resTy, oneAddr, validRow, validCol,
config.has_value()
Comment on lines +70 to +72
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Check hoisted pointer-cast dims are loop-invariant

The multi-buffer lowering hoists new single-address pto.pointer_cast ops before the loop, but those ops still consume valid_row/valid_col. Only addrs are checked for loop-invariance, so if either valid-dim operand is defined inside the loop (e.g., loop-dependent dynamic bounds), the transformed IR will reference out-of-scope values and break SSA correctness. The pass should either validate these operands are defined outside the loop or keep the cast inside the loop when they are not.

Useful? React with 👍 / 👎.

? static_cast<Attribute>(*config)
: Attribute());
slotBufs.push_back(slot.getResult());
}
Comment on lines +65 to +76
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The pass hoists pto.pointer_cast operations outside the enclosing scf.for loop without verifying that all operands (addresses, validRow, validCol) are loop-invariant. If any operand is defined inside the loop, this hoisting will result in invalid IR where a value is used before its definition. You should check if all operands are defined outside the loop before attempting to hoist.


rewriter.setInsertionPointToStart(forOp.getBody());
Value iv = forOp.getInductionVar();
Value cN = rewriter.create<arith::ConstantIndexOp>(loc, n);
Value rem = rewriter.create<arith::RemUIOp>(loc, iv, cN);

Value selected = slotBufs[0];
for (unsigned i = 1; i < n; ++i) {
Value ci = rewriter.create<arith::ConstantIndexOp>(loc, i);
Value eq = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, rem, ci);
selected =
rewriter.create<arith::SelectOp>(loc, eq, slotBufs[i], selected);
}

rewriter.replaceOp(op, selected);
return success();
}

/// Function pass that lowers every `pto.pointer_cast` carrying more than one
/// address operand via `lowerMultiBufferPointerCast`.
///
/// Candidates are collected first and rewritten afterwards, so the IR is not
/// mutated while it is being walked. A candidate is skipped with a warning
/// when:
///   * it has no enclosing `scf.for`, or
///   * any of its operands is defined inside that enclosing loop. The
///     lowering re-creates the cast in front of the loop, so an operand
///     produced inside the loop body would be used before its definition.
/// If the lowering itself reports failure the pass is marked as failed.
struct PTOEnableMultiBufferPass
    : public mlir::pto::impl::PTOEnableMultiBufferBase<
          PTOEnableMultiBufferPass> {
  void runOnOperation() override {
    func::FuncOp func = getOperation();
    SmallVector<PointerCastOp> work;
    func.walk([&](PointerCastOp op) {
      if (op.getAddrs().size() > 1)
        work.push_back(op);
    });

    IRRewriter rewriter(&getContext());
    for (PointerCastOp op : work) {
      auto forOp = op->getParentOfType<scf::ForOp>();
      if (!forOp) {
        op.emitWarning()
            << "pto-enable-multi-buffer: expected enclosing scf.for; skipping";
        continue;
      }

      // Every operand (addresses and the optional valid-row/valid-col values)
      // is reused by the casts created before the loop, so each one must
      // already be available there.
      bool operandsLoopInvariant = true;
      for (Value operand : op->getOperands()) {
        if (!forOp.isDefinedOutsideOfLoop(operand)) {
          operandsLoopInvariant = false;
          break;
        }
      }
      if (!operandsLoopInvariant) {
        op.emitWarning() << "pto-enable-multi-buffer: operand is defined "
                            "inside the enclosing scf.for and cannot be "
                            "hoisted; skipping";
        continue;
      }

      if (failed(lowerMultiBufferPointerCast(rewriter, op, forOp))) {
        signalPassFailure();
        return;
      }
    }
  }
};

} // namespace

/// Factory for the `pto-enable-multi-buffer` pass; returns a freshly
/// constructed `PTOEnableMultiBufferPass` owned by the caller.
std::unique_ptr<Pass> mlir::pto::createPTOEnableMultiBufferPass() {
  std::unique_ptr<Pass> pass = std::make_unique<PTOEnableMultiBufferPass>();
  return pass;
}
Loading
Loading