Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions include/PTO/Transforms/InsertSync/SyncCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,21 @@ enum class TCoreType {

/// Meminfo of the target buffer
/// Used to track a buffer's aliases and its root node.
// Per-dimension description of a memory region inside a root layout.
struct StaticMemRegion {
  // Size of a single element, in bytes. Defaults to 1.
  int64_t elemSizeBytes = 1;
  // Base offset, in bytes. Defaults to 0.
  // NOTE(review): presumably measured from the start of the root buffer —
  // confirm against the code that fills this in.
  int64_t baseOffsetBytes = 0;
  // Coordinates expressed per dimension of the root layout. The entries of
  // `sizes` are conservative bounding spans, which keeps strided regions
  // alias-safe.
  SmallVector<int64_t> offsets;
  SmallVector<int64_t> sizes;
  SmallVector<int64_t> strides;

  // True when `offsets` is non-empty and `offsets`, `sizes` and `strides`
  // all have the same number of entries.
  bool isPrecise() const {
    const size_t rank = offsets.size();
    if (rank == 0)
      return false;
    return sizes.size() == rank && strides.size() == rank;
  }
};

struct BaseMemInfo {
BaseMemInfo(
Value baseBuffer, Value rootBuffer, pto::AddressSpace scope,
Expand All @@ -98,6 +113,7 @@ struct BaseMemInfo {
pto::AddressSpace scope;
SmallVector<uint64_t> baseAddresses; // 用于 Offset 分析
uint64_t allocateSize;
std::optional<StaticMemRegion> preciseRegion;

bool areVectorEqual(const SmallVector<uint64_t>& vec1,
const SmallVector<uint64_t>& vec2) const {
Expand All @@ -120,13 +136,17 @@ struct BaseMemInfo {
}

std::unique_ptr<BaseMemInfo> clone() const {
return std::make_unique<BaseMemInfo>(
auto cloned = std::make_unique<BaseMemInfo>(
baseBuffer, rootBuffer, scope, baseAddresses, allocateSize);
cloned->preciseRegion = preciseRegion;
return cloned;
}

/// Creates a heap-allocated copy of this BaseMemInfo whose base buffer is
/// replaced by `cloneBaseBuffer`.
///
/// rootBuffer, scope, baseAddresses and allocateSize are taken from this
/// object. preciseRegion is not one of the arguments passed to the
/// constructor here, so it is copied onto the new object afterwards.
///
/// @param cloneBaseBuffer value used as the base buffer of the copy.
/// @return owning pointer to the copy.
std::unique_ptr<BaseMemInfo> clone(Value cloneBaseBuffer) const {
  auto cloned = std::make_unique<BaseMemInfo>(
      cloneBaseBuffer, rootBuffer, scope, baseAddresses, allocateSize);
  cloned->preciseRegion = preciseRegion;
  return cloned;
}
};

Expand Down
198 changes: 192 additions & 6 deletions lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See LICENSE in the root of the software repository for the full text of the License.

#include "PTO/Transforms/InsertSync/InsertSyncAnalysis.h"
#include "PTO/Transforms/InsertSync/InsertSyncDebug.h"
#include "PTO/Transforms/InsertSync/SyncCommon.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
Expand All @@ -36,6 +37,144 @@ static bool isValidPipeIndex(PipelineType pipe) {
return static_cast<unsigned>(pipe) < kPipeStateSize;
}

/// Reports whether InsertSync debugging is enabled at the Trace level.
static bool isTraceEnabled() {
  const bool traceOn = isInsertSyncDebugEnabled(InsertSyncDebugLevel::Trace);
  return traceOn;
}

/// Maps a PipelineType enumerator to its textual name.
///
/// @param pipe pipeline enumerator to describe.
/// @return the enumerator's spelling, or "PIPE_UNKNOWN" when `pipe` matches
///         none of the cases listed below.
static llvm::StringRef getPipelineName(PipelineType pipe) {
  // Start from the fallback; each matching case overwrites it.
  llvm::StringRef name = "PIPE_UNKNOWN";
  switch (pipe) {
  case PipelineType::PIPE_S:
    name = "PIPE_S";
    break;
  case PipelineType::PIPE_V:
    name = "PIPE_V";
    break;
  case PipelineType::PIPE_M:
    name = "PIPE_M";
    break;
  case PipelineType::PIPE_MTE1:
    name = "PIPE_MTE1";
    break;
  case PipelineType::PIPE_MTE2:
    name = "PIPE_MTE2";
    break;
  case PipelineType::PIPE_MTE3:
    name = "PIPE_MTE3";
    break;
  case PipelineType::PIPE_ALL:
    name = "PIPE_ALL";
    break;
  case PipelineType::PIPE_MTE4:
    name = "PIPE_MTE4";
    break;
  case PipelineType::PIPE_MTE5:
    name = "PIPE_MTE5";
    break;
  case PipelineType::PIPE_V2:
    name = "PIPE_V2";
    break;
  case PipelineType::PIPE_FIX:
    name = "PIPE_FIX";
    break;
  case PipelineType::VIRTUAL_PIPE_MTE2_L1A:
    name = "VIRTUAL_PIPE_MTE2_L1A";
    break;
  case PipelineType::VIRTUAL_PIPE_MTE2_L1B:
    name = "VIRTUAL_PIPE_MTE2_L1B";
    break;
  case PipelineType::PIPE_NUM:
    name = "PIPE_NUM";
    break;
  case PipelineType::PIPE_UNASSIGNED:
    name = "PIPE_UNASSIGNED";
    break;
  }
  return name;
}

/// Maps a pto::AddressSpace enumerator to its textual name.
///
/// @param scope address space to describe.
/// @return the enumerator's spelling, or "UNKNOWN" when `scope` matches none
///         of the cases listed below.
static const char *getScopeName(pto::AddressSpace scope) {
  // Fallback used when no case below matches.
  const char *label = "UNKNOWN";
  switch (scope) {
  case pto::AddressSpace::Zero:
    label = "Zero";
    break;
  case pto::AddressSpace::GM:
    label = "GM";
    break;
  case pto::AddressSpace::VEC:
    label = "VEC";
    break;
  case pto::AddressSpace::MAT:
    label = "MAT";
    break;
  case pto::AddressSpace::ACC:
    label = "ACC";
    break;
  case pto::AddressSpace::LEFT:
    label = "LEFT";
    break;
  case pto::AddressSpace::RIGHT:
    label = "RIGHT";
    break;
  case pto::AddressSpace::BIAS:
    label = "BIAS";
    break;
  case pto::AddressSpace::SCALING:
    label = "SCALING";
    break;
  }
  return label;
}

/// Writes `bases` to `os` as a bracketed, comma-separated list, e.g.
/// "[1, 2, 3]"; an empty list is written as "[]".
///
/// @param os    stream that receives the text.
/// @param bases values to print, in order.
static void dumpBases(llvm::raw_ostream &os, ArrayRef<uint64_t> bases) {
  os << "[";
  // The separator is empty before the first element and ", " afterwards, so
  // no trailing separator is ever emitted.
  const char *separator = "";
  for (uint64_t base : bases) {
    os << separator << base;
    separator = ", ";
  }
  os << "]";
}

/// Writes `values` to `os` as a bracketed, comma-separated list, e.g.
/// "[0, -4, 16]"; an empty list is written as "[]".
///
/// @param os     stream that receives the text.
/// @param values signed 64-bit values to print, in order.
static void dumpInt64List(llvm::raw_ostream &os, ArrayRef<int64_t> values) {
  os << "[";
  bool needComma = false;
  for (int64_t value : values) {
    // Every element except the first is preceded by ", ".
    if (needComma)
      os << ", ";
    os << value;
    needComma = true;
  }
  os << "]";
}

/// Prints a one-line summary of `compound` to llvm::errs(), prefixed by
/// `tag`. A null `compound` is reported as "<null>".
///
/// @param tag      label printed before the summary.
/// @param compound element to describe; may be null.
static void dumpCompoundTrace(llvm::StringRef tag,
                              const CompoundInstanceElement *compound) {
  auto &err = llvm::errs();
  err << " " << tag << ": ";
  if (compound == nullptr) {
    err << "<null>\n";
    return;
  }
  // Index, operation name and pipeline name, all on a single line.
  err << "idx=" << compound->GetIndex();
  err << " op=" << compound->opName.getStringRef();
  err << " pipe=" << getPipelineName(compound->kPipeValue) << "\n";
}

/// Prints a description of `info` to llvm::errs(), prefixed by `tag`.
///
/// The first line carries the scope name, the base addresses, the allocate
/// size and, when the corresponding values are set, the types of the root
/// and base buffers. A second line with the region's offsets, sizes,
/// strides and element size follows only when `preciseRegion` is present
/// and reports isPrecise(). A null `info` is reported as "<null>".
///
/// @param tag  label printed before the description.
/// @param info memory info to describe; may be null.
static void dumpMemInfoTrace(llvm::StringRef tag, const BaseMemInfo *info) {
  auto &err = llvm::errs();
  err << " " << tag << ": ";
  if (info == nullptr) {
    err << "<null>\n";
    return;
  }

  err << "scope=" << getScopeName(info->scope) << " bases=";
  dumpBases(err, info->baseAddresses);
  err << " sizeBytes=" << info->allocateSize;
  if (info->rootBuffer)
    err << " rootType=" << info->rootBuffer.getType();
  if (info->baseBuffer)
    err << " baseType=" << info->baseBuffer.getType();
  err << "\n";

  // The region line is optional: skip it unless a precise region exists.
  if (!info->preciseRegion || !info->preciseRegion->isPrecise())
    return;

  const auto &region = *info->preciseRegion;
  err << " region offsets=";
  dumpInt64List(err, region.offsets);
  err << " sizes=";
  dumpInt64List(err, region.sizes);
  err << " strides=";
  dumpInt64List(err, region.strides);
  err << " elemBytes=" << region.elemSizeBytes << "\n";
}

/// Prints the dependency pairs `depPairs[beginIndex..]` to llvm::errs().
/// Does nothing unless trace output is enabled.
///
/// The output is a header line with `kind` and the number of pairs, then
/// for each pair its position relative to `beginIndex` followed by the
/// "now-side" (pair.first) and "front-side" (pair.second) memory infos.
///
/// @param kind       label describing the kind of dependency.
/// @param depPairs   vector holding the dependency pairs.
/// @param beginIndex index of the first pair to print. Values past the end
///                   are treated as "nothing to print".
static void dumpDepPairs(llvm::StringRef kind,
                         const DepBaseMemInfoPairVec &depPairs,
                         size_t beginIndex) {
  if (!isTraceEnabled())
    return;

  const size_t total = depPairs.size();
  // Clamp the start: `total - beginIndex` is unsigned and would wrap around
  // to a huge count if beginIndex were larger than the vector.
  const size_t first = beginIndex < total ? beginIndex : total;

  auto &err = llvm::errs();
  err << " [InsertSync][Dependency] kind=" << kind
      << " pairs=" << (total - first) << "\n";
  for (size_t i = first; i < total; ++i) {
    err << " pair#" << (i - first) << "\n";
    dumpMemInfoTrace("now-side", depPairs[i].first);
    dumpMemInfoTrace("front-side", depPairs[i].second);
  }
}

// ==============================================================================
// 1. Entry Point
// ==============================================================================
Expand Down Expand Up @@ -313,12 +452,27 @@ bool InsertSyncAnalysis::IsMemInfoHasDependency(
CompoundInstanceElement *frontCompound,
DepBaseMemInfoPairVec &depBaseMemInfosVec) {
bool hasDependency = false;
hasDependency |= memAnalyzer_.DepBetween(nowCompound->useVec, frontCompound->defVec,
depBaseMemInfosVec);
hasDependency |= memAnalyzer_.DepBetween(nowCompound->defVec, frontCompound->useVec,
depBaseMemInfosVec);
hasDependency |= memAnalyzer_.DepBetween(nowCompound->defVec, frontCompound->defVec,
depBaseMemInfosVec);
auto checkDep = [&](llvm::StringRef kind,
const SmallVector<const BaseMemInfo *> &nowSide,
const SmallVector<const BaseMemInfo *> &frontSide) {
size_t before = depBaseMemInfosVec.size();
bool dep = memAnalyzer_.DepBetween(nowSide, frontSide,
depBaseMemInfosVec);
if (dep && isTraceEnabled()) {
llvm::errs() << "\n[InsertSync][MemAnalyze] dependency found\n";
dumpCompoundTrace("front", frontCompound);
dumpCompoundTrace("now", nowCompound);
dumpDepPairs(kind, depBaseMemInfosVec, before);
}
return dep;
};

hasDependency |= checkDep("RAW: now.use overlaps front.def",
nowCompound->useVec, frontCompound->defVec);
hasDependency |= checkDep("WAR: now.def overlaps front.use",
nowCompound->defVec, frontCompound->useVec);
hasDependency |= checkDep("WAW: now.def overlaps front.def",
nowCompound->defVec, frontCompound->defVec);

// Special hazard: ACC (L0C) read/read cross-pipe ordering.
//
Expand All @@ -328,12 +482,20 @@ bool InsertSyncAnalysis::IsMemInfoHasDependency(
DepBaseMemInfoPairVec rrDepVec;
if (memAnalyzer_.DepBetween(nowCompound->useVec, frontCompound->useVec,
rrDepVec)) {
size_t before = depBaseMemInfosVec.size();
for (auto &pair : rrDepVec) {
if (!pair.first) continue;
if (pair.first->scope != pto::AddressSpace::ACC) continue;
depBaseMemInfosVec.push_back(pair);
hasDependency = true;
}
if (depBaseMemInfosVec.size() != before && isTraceEnabled()) {
llvm::errs() << "\n[InsertSync][MemAnalyze] dependency found\n";
dumpCompoundTrace("front", frontCompound);
dumpCompoundTrace("now", nowCompound);
dumpDepPairs("ACC RR: now.use overlaps front.use", depBaseMemInfosVec,
before);
}
}
}

Expand All @@ -349,6 +511,18 @@ void InsertSyncAnalysis::InsertSyncOperation(

if (nowPipe == frontPipe) {
unsigned insertBarrierId = nowCompound->GetIndex();
if (isTraceEnabled()) {
llvm::errs() << "\n[InsertSync][InsertOperation] same-pipe dependency "
"creates PIPE_BARRIER\n";
dumpCompoundTrace("front", frontCompound);
dumpCompoundTrace("now", nowCompound);
llvm::errs() << " syncIndex=" << syncIndex_
<< " insertBeforeIdx=" << insertBarrierId
<< " pipe=" << getPipelineName(nowPipe) << "\n";
dumpDepPairs("barrier dependency", depBaseMemInfosVec, 0);
if (nowPipe == PipelineType::PIPE_MTE2)
llvm::errs() << " result=pipe_barrier(PIPE_MTE2)\n";
}
auto barrierOp = std::make_unique<SyncOperation>(
SyncOperation::TYPE::PIPE_BARRIER, frontPipe, nowPipe, syncIndex_,
insertBarrierId, forEndIndex);
Expand All @@ -362,6 +536,18 @@ void InsertSyncAnalysis::InsertSyncOperation(
} else {
unsigned insertWaitId = nowCompound->GetIndex();
unsigned insertSetId = frontCompound->GetIndex();
if (isTraceEnabled()) {
llvm::errs() << "\n[InsertSync][InsertOperation] cross-pipe dependency "
"creates set/wait\n";
dumpCompoundTrace("front", frontCompound);
dumpCompoundTrace("now", nowCompound);
llvm::errs() << " syncIndex=" << syncIndex_
<< " setAfterIdx=" << insertSetId
<< " waitBeforeIdx=" << insertWaitId << " pipe="
<< getPipelineName(frontPipe) << "->"
<< getPipelineName(nowPipe) << "\n";
dumpDepPairs("event dependency", depBaseMemInfosVec, 0);
}
auto setOp = std::make_unique<SyncOperation>(
SyncOperation::TYPE::SET_EVENT, frontPipe, nowPipe, syncIndex_,
insertSetId, forEndIndex);
Expand Down
Loading
Loading