Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions include/PTO/Transforms/InsertSync/InsertSyncAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,13 @@ class InsertSyncAnalysis {
const CompoundInstanceElement *frontCompound,
bool isBackwardDep) const;

/// 获取依赖对涉及的 Event ID 数量 (用于 Multi-Buffer 分析)
int GetEventIdNum(const DepBaseMemInfoPairVec &depBaseMemInfosVec);
/// Multi-buffer event-id deduction (HIVM-style). `backEdgeForLoop`, when
/// non-null, is the scf.for whose back-edge this dependency crosses; the
/// deduction additionally requires every involved buffer to live directly
/// under that loop. If null (forward dep), the deduction is a no-op and
/// returns 1.
int GetEventIdNum(const DepBaseMemInfoPairVec &depBaseMemInfosVec,
Operation *backEdgeForLoop = nullptr);

/// 辅助函数:获取所有涉及的 Buffer (用于 LCA 计算,虽然现在简化了,保留接口)
SmallVector<Value> GetMemInfoBuffers(const DepBaseMemInfoPairVec &depBaseMemInfosVec);
Expand Down
20 changes: 14 additions & 6 deletions include/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,29 @@ class MemoryDependentAnalyzer {
public:
MemoryDependentAnalyzer() = default;
~MemoryDependentAnalyzer() = default;

// 检查两组内存信息之间是否存在依赖
bool DepBetween(const SmallVector<const BaseMemInfo *> &a,
const SmallVector<const BaseMemInfo *> &b,
DepBaseMemInfoPairVec &depBaseMemInfosVec);

// 检查两个具体的 MemInfo 是否别名
bool MemAlias(const BaseMemInfo *a, const BaseMemInfo *b);


/// Multi-buffer eligibility for a dependent pair (HIVM-style). Both sides
/// must have compile-time-constant addresses, a non-zero allocate size, and
/// the same number N>=2 of byte-offset slots; **every same-index slot must
/// overlap** (the real cross-iteration dep) and **no different-index slot may
/// overlap** (so consecutive iterations land in disjoint physical buffers).
/// Returns N when eligible, otherwise 0 (also when either pointer is null).
unsigned getMultiBufferSlotCount(const BaseMemInfo *a,
const BaseMemInfo *b);

private:
bool isGMBufferOverlap(const BaseMemInfo *a, const BaseMemInfo *b);

bool isBufferAddressRangeOverlap(const BaseMemInfo *a, const BaseMemInfo *b);
bool isBufferOverlap(const BaseMemInfo *a, const BaseMemInfo *b,

bool isBufferOverlap(const BaseMemInfo *a, const BaseMemInfo *b,
int aIndex, int bIndex);
};

Expand Down
6 changes: 3 additions & 3 deletions include/PTO/Transforms/InsertSync/SyncCodegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ class SyncCodegen {
// 记录 Op -> Sync 的映射
DenseMap<const Operation *, SyncPipeBuild> op2InsertSync;

// 记录 Loop -> Counter 的映射 (缓存)
DenseMap<Operation *, Value> loop2BufferCounter;
// 记录 Loop -> ( Counter value , modulo N ) 的映射 (缓存)
DenseMap<Operation *, std::pair<Value, unsigned>> loop2BufferCounter;

// 记录 SyncIndex -> EventID Value 的映射 (缓存)
DenseMap<unsigned, Value> SyncIndex2SelectBuffer;
Expand All @@ -97,4 +97,4 @@ class SyncCodegen {
} // namespace pto
} // namespace mlir

#endif // MLIR_DIALECT_PTO_TRANSFORMS_INJECTSYNC_SYNCCODEGEN_HN_H
#endif // MLIR_DIALECT_PTO_TRANSFORMS_INJECTSYNC_SYNCCODEGEN_H
15 changes: 11 additions & 4 deletions include/PTO/Transforms/InsertSync/SyncCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ enum class TCoreType {
struct BaseMemInfo {
BaseMemInfo(
Value baseBuffer, Value rootBuffer, pto::AddressSpace scope,
SmallVector<uint64_t> baseAddresses, uint64_t allocateSize)
SmallVector<uint64_t> baseAddresses, uint64_t allocateSize,
bool hasVariableAddress = false)
: baseBuffer(baseBuffer), rootBuffer(rootBuffer), scope(scope),
baseAddresses(std::move(baseAddresses)), allocateSize(allocateSize) {}
baseAddresses(std::move(baseAddresses)), allocateSize(allocateSize),
hasVariableAddress(hasVariableAddress) {}

/// baseBuffer: 当前操作直接使用的 Buffer (可能是 View 或 Alias)
Value baseBuffer;
Expand All @@ -98,6 +100,8 @@ struct BaseMemInfo {
pto::AddressSpace scope;
SmallVector<uint64_t> baseAddresses; // 用于 Offset 分析
uint64_t allocateSize;
/// True when pointer/workspace addresses are not compile-time constants.
bool hasVariableAddress{false};

bool areVectorEqual(const SmallVector<uint64_t>& vec1,
const SmallVector<uint64_t>& vec2) const {
Expand All @@ -116,17 +120,20 @@ struct BaseMemInfo {
// 但为了保持原有逻辑,先保留。重点是 rootBuffer 必须一致。
if (allocateSize != other.allocateSize) return false;
if (baseBuffer != other.baseBuffer) return false;
if (hasVariableAddress != other.hasVariableAddress) return false;
return true;
}

std::unique_ptr<BaseMemInfo> clone() const {
return std::make_unique<BaseMemInfo>(
baseBuffer, rootBuffer, scope, baseAddresses, allocateSize);
baseBuffer, rootBuffer, scope, baseAddresses, allocateSize,
hasVariableAddress);
}

std::unique_ptr<BaseMemInfo> clone(Value cloneBaseBuffer) const {
return std::make_unique<BaseMemInfo>(
cloneBaseBuffer, rootBuffer, scope, baseAddresses, allocateSize);
cloneBaseBuffer, rootBuffer, scope, baseAddresses, allocateSize,
hasVariableAddress);
}
};

Expand Down
26 changes: 26 additions & 0 deletions include/PTO/Transforms/MultiBuffer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#ifndef PTO_TRANSFORMS_MULTIBUFFER_H
#define PTO_TRANSFORMS_MULTIBUFFER_H

#include "llvm/ADT/StringRef.h"

namespace mlir::pto {

/// Name of the attribute that records the multi-buffer depth on a
/// `memref.alloc`; its value is the integer slot count N (N >= 2).
/// Declared `inline constexpr` so that every translation unit including this
/// header shares one definition.
inline constexpr llvm::StringLiteral kPtoMultiBufferAttrName = "pto.multi_buffer";

/// Largest slot count N that is accepted. Keep this value in sync with
/// `MAX_MULTI_BUFFER_NUM` used by insert-sync.
inline constexpr unsigned kPtoMultiBufferMaxNum{16};

} // namespace mlir::pto

#endif // PTO_TRANSFORMS_MULTIBUFFER_H
1 change: 1 addition & 0 deletions include/PTO/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ std::unique_ptr<Pass>
createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {});

std::unique_ptr<Pass> createPTORemoveRedundantBarrierPass();
std::unique_ptr<Pass> createPTOEnableMultiBufferPass();
std::unique_ptr<Pass> createPTOViewToMemrefPass();
std::unique_ptr<Pass> createInferPTOLayoutPass();
std::unique_ptr<Pass> createPTOA5NormalizeTMovPass();
Expand Down
19 changes: 19 additions & 0 deletions include/PTO/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,25 @@ def PTOVerifyTFree : Pass<"pto-verify-tfree", "func::FuncOp"> {
];
}

// Declares the `pto-enable-multi-buffer` pass, anchored on `func::FuncOp`.
// Instances are created through `mlir::pto::createPTOEnableMultiBufferPass()`.
// The PTO, SCF and Arith dialects are listed as dependent dialects so that
// they are loaded before the pass runs.
def PTOEnableMultiBuffer : Pass<"pto-enable-multi-buffer", "func::FuncOp"> {
let summary = "Lower variadic pto.pointer_cast with multi-buffer addrs into "
"single-address casts plus a per-iteration arith.select";
let description = [{
Mirrors HIVM's `EnableMultiBuffer` lowering: takes a `pto.pointer_cast` with
N>1 address operands, hoists each address into its own single-address
`pto.pointer_cast` outside the parent `scf.for`, then replaces the original
multi-address cast with an N-way `arith.select` chain driven by `iv mod N`.
Runs after `pto-insert-sync` so the multi-address `pto.pointer_cast` stays
visible to dependency analysis.
}];
let constructor = "mlir::pto::createPTOEnableMultiBufferPass()";
let dependentDialects = [
"mlir::pto::PTODialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"
];
}

def PTOViewToMemref : Pass<"pto-view-to-memref", "ModuleOp"> {
let summary = "Lower PTO views to memref with Metadata Binding";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions lib/PTO/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ add_mlir_dialect_library(PTOTransforms
InsertSync/RemoveRedundantSync.cpp
InsertSync/SyncEventIdAllocation.cpp
InsertSync/SyncCodegen.cpp
PTOEnableMultiBuffer.cpp
LoweringSyncToPipe.cpp
PTOVerifyTFreePass.cpp

Expand Down
89 changes: 78 additions & 11 deletions lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,19 @@ void InsertSyncAnalysis::InsertSyncOperation(
setOp->SetDepSyncIRIndex(frontCompound->GetIndex());
waitOp->SetDepSyncIRIndex(frontCompound->GetIndex());

// Back-edge dependencies may require multi-buffer event IDs.
// Back-edge dependencies may require multi-buffer event IDs. Resolve the
// owning scf.for so GetEventIdNum can verify that the dep buffer rotates
// on the right loop's induction variable (B1).
if (forEndIndex.has_value()) {
int eventIdNum = GetEventIdNum(depBaseMemInfosVec);
Operation *backEdgeForOp = nullptr;
if (forEndIndex.value() < syncIR_.size()) {
InstanceElement *loopElem = syncIR_[forEndIndex.value()].get();
if (loopElem) {
// For LOOP_END elements, elementOp points at the originating scf.for.
backEdgeForOp = loopElem->elementOp;
}
}
int eventIdNum = GetEventIdNum(depBaseMemInfosVec, backEdgeForOp);
setOp->eventIdNum = eventIdNum;
waitOp->eventIdNum = eventIdNum;
}
Expand Down Expand Up @@ -535,18 +545,75 @@ SmallVector<Value> InsertSyncAnalysis::GetMemInfoBuffers(
return result;
}

// Finds the scf.for associated with `value` by climbing the operation
// hierarchy, or returns a null ForOp when there is none (or `value` is null).
//
// The climb starts at the op that defines `value`; for a block argument
// (e.g. a loop iter_arg) it starts at the op owning the argument's block.
// Note that the starting op itself is tested, so a value produced *by* an
// scf.for resolves to that very loop.
//
// Used to satisfy HIVM's constraint that every multi-buffer dependency pair
// share a common scf.for ancestor (so a single `iv mod N` selector is valid).
static scf::ForOp getEnclosingScfFor(Value value) {
  if (!value)
    return nullptr;

  Operation *cursor = value.getDefiningOp();
  if (!cursor) {
    // No defining op: `value` is a block argument, so begin at the op that
    // owns its block (if the block is attached to one).
    Block *ownerBlock = value.getParentBlock();
    if (ownerBlock)
      cursor = ownerBlock->getParentOp();
  }

  for (; cursor; cursor = cursor->getParentOp()) {
    if (auto loop = dyn_cast<scf::ForOp>(cursor))
      return loop;
  }
  return nullptr;
}

int InsertSyncAnalysis::GetEventIdNum(
const DepBaseMemInfoPairVec &depBaseMemInfosVec) {
const DepBaseMemInfoPairVec &depBaseMemInfosVec,
Operation *backEdgeForLoop) {
// HIVM `GetEventIdNum` semantics: only deduce N>1 when EVERY dependent pair
// is multi-buffer-eligible (same slot count, same-index slots overlap,
// different-index slots are disjoint), all pairs agree on N, and every
// involved buffer (checked through its `baseBuffer`, not `rootBuffer`)
// resolves to the back-edge scf.for. Any failure collapses to
// single-buffer (eventIdNum = 1).
//
// Forward dependencies (no enclosing back-edge) are unconditionally
// single-buffer: multi-event-id only buys parallelism by breaking
// loop-carried sync, so it's meaningless without a back-edge.
if (depBaseMemInfosVec.empty())
return 1;
auto backEdgeFor = dyn_cast_or_null<scf::ForOp>(backEdgeForLoop);
if (!backEdgeFor)
return 1;

unsigned commonN = 0;
for (const auto &pair : depBaseMemInfosVec) {
bool isLocalA =
pair.first && (pair.first->scope == pto::AddressSpace::MAT ||
pair.first->scope == pto::AddressSpace::VEC);
bool isLocalB =
pair.second && (pair.second->scope == pto::AddressSpace::MAT ||
pair.second->scope == pto::AddressSpace::VEC);
if (isLocalA || isLocalB) return 1;
unsigned n = memAnalyzer_.getMultiBufferSlotCount(pair.first, pair.second);
if (n < 2)
return 1;
if (commonN == 0)
commonN = n;
else if (commonN != n)
return 1;

// B1: every involved buffer's enclosing scf.for must match the back-edge
// loop. A buffer that lives in an *inner* loop nested inside the back-edge
// loop would rotate slots on the wrong iv (inner.iv mod N), giving the
// wrong physical slot for a backward dep that crosses the outer
// back-edge. A buffer in an *outer* loop never rotates with the back-edge
// we care about. Either case must collapse to single-buffer.
auto checkLoop = [&](Value buffer) -> bool {
auto forOp = getEnclosingScfFor(buffer);
return forOp && forOp.getOperation() == backEdgeFor.getOperation();
};
// Use `baseBuffer` (the alloc-like SSA result inside the loop body)
// rather than `rootBuffer`, which for pto.pointer_cast is the i64 base
// address at function top and has no enclosing scf.for.
if (!checkLoop(pair.first->baseBuffer) ||
!checkLoop(pair.second->baseBuffer))
return 1;
}
return 1;

if (commonN == 0 || commonN > MAX_MULTI_BUFFER_NUM)
return 1;
return static_cast<int>(commonN);
}

bool InsertSyncAnalysis::IsGMHazard(
Expand Down
32 changes: 31 additions & 1 deletion lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,36 @@ bool MemoryDependentAnalyzer::isBufferOverlap(const BaseMemInfo *a,

uint64_t maxStart = std::max(aStart, bStart);
uint64_t minEnd = std::min(aEnd, bEnd);

return maxStart < minEnd;
}

/// Determines whether the dependent pair (`a`, `b`) qualifies for
/// multi-buffering and, if so, with how many slots.
///
/// Returns the common slot count N (N >= 2) when all of the following hold,
/// and 0 otherwise:
///   - both pointers are non-null;
///   - neither side is flagged `hasVariableAddress`;
///   - both sides list the same number of base addresses, at least two;
///   - both sides have a non-zero `allocateSize`;
///   - for every slot i, slot i of `a` overlaps slot i of `b`;
///   - for every i != j, slot i of `a` does not overlap slot j of `b`.
unsigned MemoryDependentAnalyzer::getMultiBufferSlotCount(
    const BaseMemInfo *a, const BaseMemInfo *b) {
  if (!a || !b)
    return 0;

  // With non-constant addresses the disjoint-slot invariant cannot be proven.
  if (a->hasVariableAddress || b->hasVariableAddress)
    return 0;

  const auto slotCount = static_cast<unsigned>(a->baseAddresses.size());
  if (b->baseAddresses.size() != a->baseAddresses.size() || slotCount < 2)
    return 0;

  if (a->allocateSize == 0 || b->allocateSize == 0)
    return 0;

  // Matching indices must alias (that is the genuine backward dependency on
  // one physical buffer across iterations), whereas mismatched indices must
  // stay disjoint (otherwise consecutive iterations would alias, making
  // multi-buffering unsafe).
  for (unsigned row = 0; row != slotCount; ++row) {
    const int aSlot = static_cast<int>(row);
    if (!isBufferOverlap(a, b, aSlot, aSlot))
      return 0;
    for (unsigned col = 0; col != slotCount; ++col) {
      if (col != row && isBufferOverlap(a, b, aSlot, static_cast<int>(col)))
        return 0;
    }
  }
  return slotCount;
}
Loading
Loading