diff --git a/include/PTO/Transforms/InsertSync/SyncCommon.h b/include/PTO/Transforms/InsertSync/SyncCommon.h index 09aa4dc9d..a253a3e8a 100644 --- a/include/PTO/Transforms/InsertSync/SyncCommon.h +++ b/include/PTO/Transforms/InsertSync/SyncCommon.h @@ -83,6 +83,21 @@ enum class TCoreType { /// Meminfo of the target buffer /// 用于追踪 Buffer 的别名和根节点 +struct StaticMemRegion { + int64_t elemSizeBytes{1}; + int64_t baseOffsetBytes{0}; + // Per-dimension coordinates in the root layout. sizes are conservative + // bounding spans, so strided regions stay alias-safe. + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + + bool isPrecise() const { + return !offsets.empty() && offsets.size() == sizes.size() && + offsets.size() == strides.size(); + } +}; + struct BaseMemInfo { BaseMemInfo( Value baseBuffer, Value rootBuffer, pto::AddressSpace scope, @@ -98,6 +113,7 @@ struct BaseMemInfo { pto::AddressSpace scope; SmallVector baseAddresses; // 用于 Offset 分析 uint64_t allocateSize; + std::optional preciseRegion; bool areVectorEqual(const SmallVector& vec1, const SmallVector& vec2) const { @@ -120,13 +136,17 @@ struct BaseMemInfo { } std::unique_ptr clone() const { - return std::make_unique( + auto cloned = std::make_unique( baseBuffer, rootBuffer, scope, baseAddresses, allocateSize); + cloned->preciseRegion = preciseRegion; + return cloned; } std::unique_ptr clone(Value cloneBaseBuffer) const { - return std::make_unique( + auto cloned = std::make_unique( cloneBaseBuffer, rootBuffer, scope, baseAddresses, allocateSize); + cloned->preciseRegion = preciseRegion; + return cloned; } }; diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp index 6cec030a4..7a9ad7977 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp @@ -12,6 +12,7 @@ // See LICENSE in the root of the software repository for the full text of the License. 
#include "PTO/Transforms/InsertSync/InsertSyncAnalysis.h" +#include "PTO/Transforms/InsertSync/InsertSyncDebug.h" #include "PTO/Transforms/InsertSync/SyncCommon.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -36,6 +37,144 @@ static bool isValidPipeIndex(PipelineType pipe) { return static_cast(pipe) < kPipeStateSize; } +static bool isTraceEnabled() { + return isInsertSyncDebugEnabled(InsertSyncDebugLevel::Trace); +} + +static llvm::StringRef getPipelineName(PipelineType pipe) { + switch (pipe) { + case PipelineType::PIPE_S: + return "PIPE_S"; + case PipelineType::PIPE_V: + return "PIPE_V"; + case PipelineType::PIPE_M: + return "PIPE_M"; + case PipelineType::PIPE_MTE1: + return "PIPE_MTE1"; + case PipelineType::PIPE_MTE2: + return "PIPE_MTE2"; + case PipelineType::PIPE_MTE3: + return "PIPE_MTE3"; + case PipelineType::PIPE_ALL: + return "PIPE_ALL"; + case PipelineType::PIPE_MTE4: + return "PIPE_MTE4"; + case PipelineType::PIPE_MTE5: + return "PIPE_MTE5"; + case PipelineType::PIPE_V2: + return "PIPE_V2"; + case PipelineType::PIPE_FIX: + return "PIPE_FIX"; + case PipelineType::VIRTUAL_PIPE_MTE2_L1A: + return "VIRTUAL_PIPE_MTE2_L1A"; + case PipelineType::VIRTUAL_PIPE_MTE2_L1B: + return "VIRTUAL_PIPE_MTE2_L1B"; + case PipelineType::PIPE_NUM: + return "PIPE_NUM"; + case PipelineType::PIPE_UNASSIGNED: + return "PIPE_UNASSIGNED"; + } + return "PIPE_UNKNOWN"; +} + +static const char *getScopeName(pto::AddressSpace scope) { + switch (scope) { + case pto::AddressSpace::Zero: + return "Zero"; + case pto::AddressSpace::GM: + return "GM"; + case pto::AddressSpace::VEC: + return "VEC"; + case pto::AddressSpace::MAT: + return "MAT"; + case pto::AddressSpace::ACC: + return "ACC"; + case pto::AddressSpace::LEFT: + return "LEFT"; + case pto::AddressSpace::RIGHT: + return "RIGHT"; + case pto::AddressSpace::BIAS: + return "BIAS"; + case pto::AddressSpace::SCALING: + return "SCALING"; + } + return "UNKNOWN"; +} + +static void 
dumpBases(llvm::raw_ostream &os, ArrayRef bases) { + os << "["; + for (size_t i = 0; i < bases.size(); ++i) { + os << bases[i]; + if (i + 1 != bases.size()) + os << ", "; + } + os << "]"; +} + +static void dumpInt64List(llvm::raw_ostream &os, ArrayRef values) { + os << "["; + for (size_t i = 0; i < values.size(); ++i) { + os << values[i]; + if (i + 1 != values.size()) + os << ", "; + } + os << "]"; +} + +static void dumpCompoundTrace(llvm::StringRef tag, + const CompoundInstanceElement *compound) { + llvm::errs() << " " << tag << ": "; + if (!compound) { + llvm::errs() << "\n"; + return; + } + llvm::errs() << "idx=" << compound->GetIndex() + << " op=" << compound->opName.getStringRef() + << " pipe=" << getPipelineName(compound->kPipeValue) << "\n"; +} + +static void dumpMemInfoTrace(llvm::StringRef tag, const BaseMemInfo *info) { + llvm::errs() << " " << tag << ": "; + if (!info) { + llvm::errs() << "\n"; + return; + } + llvm::errs() << "scope=" << getScopeName(info->scope) << " bases="; + dumpBases(llvm::errs(), info->baseAddresses); + llvm::errs() << " sizeBytes=" << info->allocateSize; + if (info->rootBuffer) { + llvm::errs() << " rootType=" << info->rootBuffer.getType(); + } + if (info->baseBuffer) { + llvm::errs() << " baseType=" << info->baseBuffer.getType(); + } + llvm::errs() << "\n"; + if (info->preciseRegion && info->preciseRegion->isPrecise()) { + llvm::errs() << " region offsets="; + dumpInt64List(llvm::errs(), info->preciseRegion->offsets); + llvm::errs() << " sizes="; + dumpInt64List(llvm::errs(), info->preciseRegion->sizes); + llvm::errs() << " strides="; + dumpInt64List(llvm::errs(), info->preciseRegion->strides); + llvm::errs() << " elemBytes=" << info->preciseRegion->elemSizeBytes + << "\n"; + } +} + +static void dumpDepPairs(llvm::StringRef kind, + const DepBaseMemInfoPairVec &depPairs, + size_t beginIndex) { + if (!isTraceEnabled()) + return; + llvm::errs() << " [InsertSync][Dependency] kind=" << kind + << " pairs=" << (depPairs.size() - 
beginIndex) << "\n"; + for (size_t i = beginIndex; i < depPairs.size(); ++i) { + llvm::errs() << " pair#" << (i - beginIndex) << "\n"; + dumpMemInfoTrace("now-side", depPairs[i].first); + dumpMemInfoTrace("front-side", depPairs[i].second); + } +} + // ============================================================================== // 1. Entry Point // ============================================================================== @@ -313,12 +452,27 @@ bool InsertSyncAnalysis::IsMemInfoHasDependency( CompoundInstanceElement *frontCompound, DepBaseMemInfoPairVec &depBaseMemInfosVec) { bool hasDependency = false; - hasDependency |= memAnalyzer_.DepBetween(nowCompound->useVec, frontCompound->defVec, - depBaseMemInfosVec); - hasDependency |= memAnalyzer_.DepBetween(nowCompound->defVec, frontCompound->useVec, - depBaseMemInfosVec); - hasDependency |= memAnalyzer_.DepBetween(nowCompound->defVec, frontCompound->defVec, - depBaseMemInfosVec); + auto checkDep = [&](llvm::StringRef kind, + const SmallVector &nowSide, + const SmallVector &frontSide) { + size_t before = depBaseMemInfosVec.size(); + bool dep = memAnalyzer_.DepBetween(nowSide, frontSide, + depBaseMemInfosVec); + if (dep && isTraceEnabled()) { + llvm::errs() << "\n[InsertSync][MemAnalyze] dependency found\n"; + dumpCompoundTrace("front", frontCompound); + dumpCompoundTrace("now", nowCompound); + dumpDepPairs(kind, depBaseMemInfosVec, before); + } + return dep; + }; + + hasDependency |= checkDep("RAW: now.use overlaps front.def", + nowCompound->useVec, frontCompound->defVec); + hasDependency |= checkDep("WAR: now.def overlaps front.use", + nowCompound->defVec, frontCompound->useVec); + hasDependency |= checkDep("WAW: now.def overlaps front.def", + nowCompound->defVec, frontCompound->defVec); // Special hazard: ACC (L0C) read/read cross-pipe ordering. 
// @@ -328,12 +482,20 @@ bool InsertSyncAnalysis::IsMemInfoHasDependency( DepBaseMemInfoPairVec rrDepVec; if (memAnalyzer_.DepBetween(nowCompound->useVec, frontCompound->useVec, rrDepVec)) { + size_t before = depBaseMemInfosVec.size(); for (auto &pair : rrDepVec) { if (!pair.first) continue; if (pair.first->scope != pto::AddressSpace::ACC) continue; depBaseMemInfosVec.push_back(pair); hasDependency = true; } + if (depBaseMemInfosVec.size() != before && isTraceEnabled()) { + llvm::errs() << "\n[InsertSync][MemAnalyze] dependency found\n"; + dumpCompoundTrace("front", frontCompound); + dumpCompoundTrace("now", nowCompound); + dumpDepPairs("ACC RR: now.use overlaps front.use", depBaseMemInfosVec, + before); + } } } @@ -349,6 +511,18 @@ void InsertSyncAnalysis::InsertSyncOperation( if (nowPipe == frontPipe) { unsigned insertBarrierId = nowCompound->GetIndex(); + if (isTraceEnabled()) { + llvm::errs() << "\n[InsertSync][InsertOperation] same-pipe dependency " + "creates PIPE_BARRIER\n"; + dumpCompoundTrace("front", frontCompound); + dumpCompoundTrace("now", nowCompound); + llvm::errs() << " syncIndex=" << syncIndex_ + << " insertBeforeIdx=" << insertBarrierId + << " pipe=" << getPipelineName(nowPipe) << "\n"; + dumpDepPairs("barrier dependency", depBaseMemInfosVec, 0); + if (nowPipe == PipelineType::PIPE_MTE2) + llvm::errs() << " result=pipe_barrier(PIPE_MTE2)\n"; + } auto barrierOp = std::make_unique( SyncOperation::TYPE::PIPE_BARRIER, frontPipe, nowPipe, syncIndex_, insertBarrierId, forEndIndex); @@ -362,6 +536,18 @@ void InsertSyncAnalysis::InsertSyncOperation( } else { unsigned insertWaitId = nowCompound->GetIndex(); unsigned insertSetId = frontCompound->GetIndex(); + if (isTraceEnabled()) { + llvm::errs() << "\n[InsertSync][InsertOperation] cross-pipe dependency " + "creates set/wait\n"; + dumpCompoundTrace("front", frontCompound); + dumpCompoundTrace("now", nowCompound); + llvm::errs() << " syncIndex=" << syncIndex_ + << " setAfterIdx=" << insertSetId + << " 
waitBeforeIdx=" << insertWaitId << " pipe=" + << getPipelineName(frontPipe) << "->" + << getPipelineName(nowPipe) << "\n"; + dumpDepPairs("event dependency", depBaseMemInfosVec, 0); + } auto setOp = std::make_unique( SyncOperation::TYPE::SET_EVENT, frontPipe, nowPipe, syncIndex_, insertSetId, forEndIndex); diff --git a/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp b/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp index 08c529246..8a35adf7f 100644 --- a/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp +++ b/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp @@ -43,6 +43,131 @@ static void printValueDebug(const char* tag, Value v) { } llvm::errs() << " | Type: " << v.getType() << "\n"; } + +static const char *getScopeName(pto::AddressSpace scope) { + switch (scope) { + case pto::AddressSpace::Zero: + return "Zero"; + case pto::AddressSpace::GM: + return "GM"; + case pto::AddressSpace::VEC: + return "VEC"; + case pto::AddressSpace::MAT: + return "MAT"; + case pto::AddressSpace::ACC: + return "ACC"; + case pto::AddressSpace::LEFT: + return "LEFT"; + case pto::AddressSpace::RIGHT: + return "RIGHT"; + case pto::AddressSpace::BIAS: + return "BIAS"; + case pto::AddressSpace::SCALING: + return "SCALING"; + } + return "UNKNOWN"; +} + +static void printBaseList(ArrayRef bases) { + llvm::errs() << "["; + for (size_t i = 0; i < bases.size(); ++i) { + llvm::errs() << bases[i]; + if (i + 1 != bases.size()) + llvm::errs() << ", "; + } + llvm::errs() << "]"; +} + +static void printInt64List(ArrayRef values) { + llvm::errs() << "["; + for (size_t i = 0; i < values.size(); ++i) { + llvm::errs() << values[i]; + if (i + 1 != values.size()) + llvm::errs() << ", "; + } + llvm::errs() << "]"; +} + +static void printRegionDebug(const char *tag, + const std::optional ®ion) { + if (!isTraceEnabled()) + return; + llvm::errs() << tag << ": "; + if (!region || !region->isPrecise()) { + llvm::errs() << "\n"; + return; + } + llvm::errs() << "offsets="; + 
printInt64List(region->offsets); + llvm::errs() << " sizes="; + printInt64List(region->sizes); + llvm::errs() << " strides="; + printInt64List(region->strides); + llvm::errs() << " elemBytes=" << region->elemSizeBytes + << " baseOffsetBytes=" << region->baseOffsetBytes << "\n"; +} + +static bool hasComparablePreciseRegions(const BaseMemInfo *a, + const BaseMemInfo *b) { + if (!a || !b || !a->preciseRegion || !b->preciseRegion) + return false; + const StaticMemRegion &ar = *a->preciseRegion; + const StaticMemRegion &br = *b->preciseRegion; + if (!ar.isPrecise() || !br.isPrecise()) + return false; + if (ar.offsets.size() != br.offsets.size()) + return false; + if (ar.elemSizeBytes != br.elemSizeBytes) + return false; + return true; +} + +static bool arePreciseRegionsProvenDisjoint(const BaseMemInfo *a, + const BaseMemInfo *b) { + if (!hasComparablePreciseRegions(a, b)) + return false; + + const StaticMemRegion &ar = *a->preciseRegion; + const StaticMemRegion &br = *b->preciseRegion; + for (size_t dim = 0; dim < ar.offsets.size(); ++dim) { + if (ar.sizes[dim] < 0 || br.sizes[dim] < 0) + return false; + int64_t aStart = ar.offsets[dim]; + int64_t bStart = br.offsets[dim]; + int64_t aEnd = aStart + ar.sizes[dim]; + int64_t bEnd = bStart + br.sizes[dim]; + if (aEnd <= bStart || bEnd <= aStart) { + if (isTraceEnabled()) { + llvm::errs() << " [RegionOverlap] precise regions are disjoint " + << "by dim=" << dim << " A=[" << aStart << ", " << aEnd + << ") B=[" << bStart << ", " << bEnd << ")\n"; + } + return true; + } + } + + if (isTraceEnabled()) + llvm::errs() << " [RegionOverlap] precise regions not proven " + "disjoint; falling back to flat range\n"; + return false; +} + +static void printMemInfoDebug(const char *tag, const BaseMemInfo *info) { + if (!isTraceEnabled()) + return; + llvm::errs() << tag << ": "; + if (!info) { + llvm::errs() << "\n"; + return; + } + llvm::errs() << "scope=" << getScopeName(info->scope) + << " bases="; + printBaseList(info->baseAddresses); + 
llvm::errs() << " sizeBytes=" << info->allocateSize << "\n"; + printValueDebug(" base", info->baseBuffer); + printValueDebug(" root", info->rootBuffer); + printRegionDebug(" region", info->preciseRegion); +} // [Fix & Debug] 增强版 GetRealRoot static Value GetRealRoot(Value v) { @@ -141,6 +266,8 @@ bool MemoryDependentAnalyzer::MemAlias(const BaseMemInfo *a, printValueDebug(" Root B", b->rootBuffer); llvm::errs() << " Scope A: " << (int)as << ", Scope B: " << (int)bs << "\n"; + printMemInfoDebug(" MemInfo A", a); + printMemInfoDebug(" MemInfo B", b); } if (as != bs) { @@ -157,8 +284,22 @@ bool MemoryDependentAnalyzer::MemAlias(const BaseMemInfo *a, // 2. Local Memory (UB/L1) if (a->rootBuffer == b->rootBuffer) { - if (a->baseAddresses.empty() || b->baseAddresses.empty()) return true; - return isBufferAddressRangeOverlap(a, b); + if (arePreciseRegionsProvenDisjoint(a, b)) { + if (isTraceEnabled()) + llvm::errs() << " -> Same root precise regions disjoint. " + "Alias=false.\n"; + return false; + } + if (a->baseAddresses.empty() || b->baseAddresses.empty()) { + if (isTraceEnabled()) + llvm::errs() << " -> Same root but unknown base list. Alias=true.\n"; + return true; + } + bool overlap = isBufferAddressRangeOverlap(a, b); + if (isTraceEnabled()) + llvm::errs() << " -> Same root address range overlap=" + << (overlap ? "true" : "false") << "\n"; + return overlap; } // 2.2 深层比较:穿透 View @@ -172,8 +313,15 @@ bool MemoryDependentAnalyzer::MemAlias(const BaseMemInfo *a, } if (realRootA == realRootB && realRootA != nullptr) { + if (arePreciseRegionsProvenDisjoint(a, b)) { + if (isTraceEnabled()) + llvm::errs() << " -> MATCH, but precise regions are disjoint. " + "Alias=false.\n"; + return false; + } if (isTraceEnabled()) - llvm::errs() << " -> MATCH! Real roots are the same.\n"; + llvm::errs() << " -> MATCH! Real roots are the same. 
Alias=true " + "without refined range check.\n"; return true; } else { if (isTraceEnabled()) @@ -190,13 +338,38 @@ bool MemoryDependentAnalyzer::isGMBufferOverlap(const BaseMemInfo *a, Value realRootB = GetRealRoot(b->rootBuffer); if (realRootA != realRootB) { + if (isTraceEnabled()) + llvm::errs() << " -> GM real roots differ. Alias=false.\n"; return false; } + if (arePreciseRegionsProvenDisjoint(a, b)) { + if (isTraceEnabled()) + llvm::errs() << " -> GM real roots match, but precise regions are " + "disjoint. Alias=false.\n"; + return false; + } + if (isTraceEnabled()) + llvm::errs() << " -> GM real roots match. Alias=true without " + "refined range check.\n"; return true; } - if (a->baseAddresses.empty() || b->baseAddresses.empty()) return true; - if (a->allocateSize == 0 || b->allocateSize == 0) return true; + if (arePreciseRegionsProvenDisjoint(a, b)) { + if (isTraceEnabled()) + llvm::errs() << " -> GM precise regions disjoint. Alias=false.\n"; + return false; + } + + if (a->baseAddresses.empty() || b->baseAddresses.empty()) { + if (isTraceEnabled()) + llvm::errs() << " -> GM unknown base list. Alias=true.\n"; + return true; + } + if (a->allocateSize == 0 || b->allocateSize == 0) { + if (isTraceEnabled()) + llvm::errs() << " -> GM unknown allocation size. Alias=true.\n"; + return true; + } return isBufferAddressRangeOverlap(a, b); } @@ -227,5 +400,17 @@ bool MemoryDependentAnalyzer::isBufferOverlap(const BaseMemInfo *a, uint64_t maxStart = std::max(aStart, bStart); uint64_t minEnd = std::min(aEnd, bEnd); - return maxStart < minEnd; + bool overlap = maxStart < minEnd; + if (isTraceEnabled()) { + llvm::errs() << " [RangeOverlap] A[" << aIndex << "]=[" << aStart + << ", " << aEnd << ") size=" << a->allocateSize + << " B[" << bIndex << "]=[" << bStart << ", " << bEnd + << ") size=" << b->allocateSize + << " maxStart=" << maxStart << " minEnd=" << minEnd + << " => " << (overlap ? 
"overlap" : "disjoint"); + if (overlap) + llvm::errs() << " overlapBytes=" << (minEnd - maxStart); + llvm::errs() << "\n"; + } + return overlap; } diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 8ba4f265b..9937b8720 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -12,12 +12,14 @@ // See LICENSE in the root of the software repository for the full text of the License. #include "PTO/Transforms/InsertSync/PTOIRTranslator.h" +#include "PTO/Transforms/InsertSync/InsertSyncDebug.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "llvm/Support/Debug.h" #include "mlir/IR/AsmState.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Matchers.h" // [P0 新增] 引入副作用接口和 PTO 接口 @@ -27,65 +29,270 @@ using namespace mlir; using namespace mlir::pto; + +static bool isInsertSyncTraceEnabled() { + return isInsertSyncDebugEnabled(InsertSyncDebugLevel::Trace); +} + +static void dumpInt64List(llvm::raw_ostream &os, ArrayRef values) { + os << "["; + for (size_t i = 0; i < values.size(); ++i) { + os << values[i]; + if (i + 1 != values.size()) + os << ", "; + } + os << "]"; +} + +static void dumpValueSummary(llvm::raw_ostream &os, Value value) { + if (!value) { + os << ""; + return; + } + if (Operation *op = value.getDefiningOp()) + os << op->getName(); + else + os << "block_arg"; + os << ":" << value.getType(); +} + +struct StaticAliasInfo { + bool hasStaticOffset{true}; + bool preserveParentRegion{true}; + int64_t offsetBytes{0}; + int64_t sizeBytes{-1}; + std::optional relativeRegion; +}; + +static int64_t getElementSizeBytes(Type type) { + if (auto shapedType = dyn_cast(type)) + type = shapedType.getElementType(); + if (auto intType = dyn_cast(type)) { + int64_t bitWidth = 
intType.getWidth(); + return bitWidth > 0 ? std::max(1, bitWidth / 8) : 1; + } + if (auto floatType = dyn_cast(type)) { + int64_t bitWidth = floatType.getWidth(); + return bitWidth > 0 ? std::max(1, bitWidth / 8) : 1; + } + return 1; +} + +static bool hasOnlyStaticValues(ArrayRef values) { + for (int64_t value : values) + if (value == ShapedType::kDynamic) + return false; + return true; +} + +static std::optional getStaticBoundingSpan(int64_t size, + int64_t stride) { + if (size == ShapedType::kDynamic || stride == ShapedType::kDynamic || + size < 0 || stride < 0) + return std::nullopt; + if (size == 0) + return 0; + return (size - 1) * stride + 1; +} + +static bool shouldPreserveParentRegion(Operation *op) { + if (!op) + return true; + return isa(op); +} + +static std::optional +makeRootRegionFromMemRefType(MemRefType type) { + if (!type || !type.hasStaticShape()) + return std::nullopt; + + int64_t baseOffset; + SmallVector strides; + if (failed(mlir::getStridesAndOffset(type, strides, baseOffset))) + return std::nullopt; + if (strides.size() != static_cast(type.getRank()) || + !hasOnlyStaticValues(strides)) + return std::nullopt; + + StaticMemRegion region; + region.elemSizeBytes = getElementSizeBytes(type); + region.baseOffsetBytes = 0; + region.offsets.assign(type.getRank(), 0); + region.strides.assign(strides.begin(), strides.end()); + region.sizes.reserve(type.getRank()); + ArrayRef shape = type.getShape(); + for (size_t i = 0; i < shape.size(); ++i) { + std::optional span = getStaticBoundingSpan(shape[i], strides[i]); + if (!span) + return std::nullopt; + region.sizes.push_back(*span); + } + return region; +} + +static std::optional +composeSubviewRegion(const StaticMemRegion &relativeRegion, + const std::optional &parentRegion) { + if (!relativeRegion.isPrecise() || !parentRegion || + !parentRegion->isPrecise()) + return std::nullopt; + if (parentRegion->offsets.size() != relativeRegion.offsets.size()) + return std::nullopt; + if 
(parentRegion->elemSizeBytes != relativeRegion.elemSizeBytes) + return std::nullopt; + + StaticMemRegion region; + region.elemSizeBytes = relativeRegion.elemSizeBytes; + region.baseOffsetBytes = + parentRegion->baseOffsetBytes + relativeRegion.baseOffsetBytes; + region.offsets.reserve(relativeRegion.offsets.size()); + region.sizes.reserve(relativeRegion.sizes.size()); + region.strides.reserve(relativeRegion.strides.size()); + + for (size_t i = 0; i < relativeRegion.offsets.size(); ++i) { + int64_t parentStride = parentRegion->strides[i]; + int64_t childStride = relativeRegion.strides[i]; + if (parentStride == ShapedType::kDynamic || + childStride == ShapedType::kDynamic || parentStride < 0 || + childStride < 0) + return std::nullopt; + + int64_t rootStride = parentStride * childStride; + std::optional span = + getStaticBoundingSpan(relativeRegion.sizes[i], rootStride); + if (!span) + return std::nullopt; + + region.offsets.push_back(parentRegion->offsets[i] + + relativeRegion.offsets[i] * parentStride); + region.sizes.push_back(*span); + region.strides.push_back(rootStride); + } + return region; +} // [辅助函数] 尝试从 Operation 中计算相对于 Source 的字节偏移量和新大小 // 返回值: pair // 如果无法计算静态值,返回 {-1, -1} 表示这是动态的 -static std::pair getStaticOffsetAndSize(Operation *op, Value src) { +static StaticAliasInfo getStaticAliasInfo(Operation *op, Value src) { + StaticAliasInfo aliasInfo; + aliasInfo.preserveParentRegion = shouldPreserveParentRegion(op); auto srcType = dyn_cast(src.getType()); - if (!srcType) return {0, 0}; + if (!srcType) + return aliasInfo; - int64_t elemSize = srcType.getElementType().getIntOrFloatBitWidth() / 8; - if (elemSize == 0) elemSize = 1; + int64_t elemSize = getElementSizeBytes(srcType); // === Case 1: memref.subview === if (auto subView = dyn_cast(op)) { + aliasInfo.preserveParentRegion = false; int64_t baseOffset; SmallVector strides; if (failed(mlir::getStridesAndOffset(srcType, strides, baseOffset))) { - return {-1, -1}; + aliasInfo.hasStaticOffset = false; + 
return aliasInfo; } - int64_t newSize = 1; - for (int64_t s : subView.getStaticSizes()) { - if (s == ShapedType::kDynamic) return {-1, -1}; - newSize *= s; + auto staticSizes = subView.getStaticSizes(); + auto staticOffsets = subView.getStaticOffsets(); + auto staticSubStrides = subView.getStaticStrides(); + + if (staticOffsets.empty() || staticOffsets.size() > strides.size() || + staticSizes.size() != staticOffsets.size() || + staticSubStrides.size() != staticOffsets.size()) { + aliasInfo.hasStaticOffset = false; + return aliasInfo; + } + + int64_t flatSpan = 0; + bool hasZeroSize = false; + for (size_t i = 0; i < staticSizes.size(); ++i) { + int64_t s = staticSizes[i]; + if (s == ShapedType::kDynamic) { + aliasInfo.hasStaticOffset = false; + return aliasInfo; + } + if (s == 0) + hasZeroSize = true; } - newSize *= elemSize; int64_t totalOffset = 0; - auto staticOffsets = subView.getStaticOffsets(); - - if (staticOffsets.empty()) return {-1, -1}; - if (staticOffsets.size() > strides.size()) return {-1, -1}; - for (size_t i = 0; i < staticOffsets.size(); ++i) { int64_t off = staticOffsets[i]; - if (off == ShapedType::kDynamic) return {-1, -1}; + int64_t subStride = staticSubStrides[i]; + if (off == ShapedType::kDynamic || + subStride == ShapedType::kDynamic || subStride < 0) { + aliasInfo.hasStaticOffset = false; + return aliasInfo; + } int64_t stride = 1; if (i < strides.size() && strides[i] != ShapedType::kDynamic) { stride = strides[i]; } else { - return {-1, -1}; + aliasInfo.hasStaticOffset = false; + return aliasInfo; } totalOffset += off * stride; + if (!hasZeroSize) + flatSpan += (staticSizes[i] - 1) * stride * subStride; } - return {totalOffset * elemSize, newSize}; + int64_t byteOffset = totalOffset * elemSize; + aliasInfo.offsetBytes = byteOffset; + aliasInfo.sizeBytes = hasZeroSize ? 
0 : (flatSpan + 1) * elemSize; + StaticMemRegion region; + region.elemSizeBytes = elemSize; + region.baseOffsetBytes = byteOffset; + region.offsets.assign(staticOffsets.begin(), staticOffsets.end()); + region.sizes.assign(staticSizes.begin(), staticSizes.end()); + region.strides.assign(staticSubStrides.begin(), staticSubStrides.end()); + aliasInfo.relativeRegion = region; + if (isInsertSyncTraceEnabled()) { + llvm::errs() << " [AliasRange] memref.subview flat range\n"; + llvm::errs() << " src="; + dumpValueSummary(llvm::errs(), src); + llvm::errs() << "\n"; + llvm::errs() << " srcType=" << srcType << " elemBytes=" << elemSize + << " baseOffset=" << baseOffset << "\n"; + llvm::errs() << " staticOffsets="; + dumpInt64List(llvm::errs(), staticOffsets); + llvm::errs() << " staticSizes="; + dumpInt64List(llvm::errs(), staticSizes); + llvm::errs() << " staticSubStrides="; + dumpInt64List(llvm::errs(), staticSubStrides); + llvm::errs() << " sourceStrides="; + dumpInt64List(llvm::errs(), strides); + llvm::errs() << "\n"; + llvm::errs() << " computedFlatOffsetBytes=" << byteOffset + << " computedFlatSizeBytes=" << aliasInfo.sizeBytes + << " (size = strided bounding span * elemBytes)\n"; + } + return aliasInfo; } // === Case 2: memref.reinterpret_cast === if (auto castOp = dyn_cast(op)) { + aliasInfo.preserveParentRegion = false; auto staticOffsets = castOp.getStaticOffsets(); if (staticOffsets.empty() || staticOffsets[0] == ShapedType::kDynamic) { - return {0, 0}; + return aliasInfo; } - return {staticOffsets[0] * elemSize, 0}; + int64_t byteOffset = staticOffsets[0] * elemSize; + aliasInfo.offsetBytes = byteOffset; + if (isInsertSyncTraceEnabled()) { + llvm::errs() << " [AliasRange] memref.reinterpret_cast offset\n"; + llvm::errs() << " src="; + dumpValueSummary(llvm::errs(), src); + llvm::errs() << " staticOffsets="; + dumpInt64List(llvm::errs(), staticOffsets); + llvm::errs() << " computedOffsetBytes=" << byteOffset << "\n"; + } + return aliasInfo; } - return {0, 0}; + 
return aliasInfo; } // ============================================================================ @@ -300,6 +507,7 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) SmallVector{0}, sizeInBytes ); + newMemInfo->preciseRegion = makeRootRegionFromMemRefType(memRefType); buffer2MemInfoMap_[res].emplace_back(newMemInfo->clone()); return success(); @@ -339,6 +547,7 @@ PTOIRTranslator::UpdateDeclareTileMemRefOpMemInfo(pto::DeclareTileMemRefOp op) { space, SmallVector{0}, sizeInBytes); + newMemInfo->preciseRegion = makeRootRegionFromMemRefType(memRefType); buffer2MemInfoMap_[res].emplace_back(newMemInfo->clone()); return success(); @@ -552,21 +761,20 @@ void PTOIRTranslator::UpdateAliasBufferInfo(Value result, Value source) { if (!result || !source) return; if (!buffer2MemInfoMap_.contains(source)) return; - int64_t deltaOffset = 0; - int64_t newSize = -1; + StaticAliasInfo aliasInfo; if (auto op = result.getDefiningOp()) { - auto info = getStaticOffsetAndSize(op, source); - if (info.first != -1) { - deltaOffset = info.first; - if (info.second > 0) newSize = info.second; - } + aliasInfo = getStaticAliasInfo(op, source); } + + int64_t deltaOffset = aliasInfo.hasStaticOffset ? 
aliasInfo.offsetBytes : 0; + int64_t newSize = aliasInfo.sizeBytes; auto &resultMemInfoVec = buffer2MemInfoMap_[result]; for (auto &parentInfo : buffer2MemInfoMap_[source]) { auto newInfo = parentInfo->clone(result); + SmallVector parentBases = parentInfo->baseAddresses; if (!newInfo->baseAddresses.empty()) { newInfo->baseAddresses[0] += deltaOffset; @@ -577,6 +785,56 @@ void PTOIRTranslator::UpdateAliasBufferInfo(Value result, Value source) { if (newSize > 0) { newInfo->allocateSize = newSize; } + + if (aliasInfo.relativeRegion) { + newInfo->preciseRegion = + composeSubviewRegion(*aliasInfo.relativeRegion, + parentInfo->preciseRegion); + } else if (!aliasInfo.preserveParentRegion) { + newInfo->preciseRegion.reset(); + } + + if (isInsertSyncTraceEnabled()) { + llvm::errs() << " [AliasInfo] "; + if (Operation *op = result.getDefiningOp()) + llvm::errs() << "op=" << op->getName() << " "; + llvm::errs() << "source="; + dumpValueSummary(llvm::errs(), source); + llvm::errs() << " result="; + dumpValueSummary(llvm::errs(), result); + llvm::errs() << "\n"; + llvm::errs() << " root="; + dumpValueSummary(llvm::errs(), newInfo->rootBuffer); + llvm::errs() << " scope=" << static_cast(newInfo->scope) + << " deltaOffsetBytes=" << deltaOffset + << " newSizeBytes=" << newSize << "\n"; + llvm::errs() << " parentBases=["; + for (size_t i = 0; i < parentBases.size(); ++i) { + llvm::errs() << parentBases[i]; + if (i + 1 != parentBases.size()) + llvm::errs() << ", "; + } + llvm::errs() << "] parentSizeBytes=" << parentInfo->allocateSize + << " resultBases=["; + for (size_t i = 0; i < newInfo->baseAddresses.size(); ++i) { + llvm::errs() << newInfo->baseAddresses[i]; + if (i + 1 != newInfo->baseAddresses.size()) + llvm::errs() << ", "; + } + llvm::errs() << "] resultSizeBytes=" << newInfo->allocateSize << "\n"; + if (newInfo->preciseRegion && newInfo->preciseRegion->isPrecise()) { + llvm::errs() << " preciseRegion offsets="; + dumpInt64List(llvm::errs(), 
newInfo->preciseRegion->offsets); + llvm::errs() << " sizes="; + dumpInt64List(llvm::errs(), newInfo->preciseRegion->sizes); + llvm::errs() << " strides="; + dumpInt64List(llvm::errs(), newInfo->preciseRegion->strides); + llvm::errs() << " elemBytes=" + << newInfo->preciseRegion->elemSizeBytes << "\n"; + } else { + llvm::errs() << " preciseRegion=\n"; + } + } resultMemInfoVec.emplace_back(std::move(newInfo)); } @@ -621,6 +879,7 @@ LogicalResult PTOIRTranslator::UpdateMemrefAllocOpMemInfo(memref::AllocOp op) { SmallVector{0}, // Base Addresses (Offset 0) sizeInBytes ); + newMemInfo->preciseRegion = makeRootRegionFromMemRefType(memRefType); buffer2MemInfoMap_[res].emplace_back(newMemInfo->clone()); return success(); diff --git a/test/lit/pto/issue667_strided_subview_mte2_barrier_alias.pto b/test/lit/pto/issue667_strided_subview_mte2_barrier_alias.pto new file mode 100644 index 000000000..fb114fb15 --- /dev/null +++ b/test/lit/pto/issue667_strided_subview_mte2_barrier_alias.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+
+// RUN: ptoas --pto-arch=a5 --enable-insert-sync --emit-pto-ir %s -o - | FileCheck %s
+
+// Two strided subviews of one 32xf32 buffer:
+//   %a covers elements {0, 2, ..., 14}; %b covers elements {8, 10, ..., 22}.
+// They overlap on elements 8..14, so the sync analysis must keep an MTE2
+// barrier between the two tloads even though both subviews are strided.
+module {
+  func.func @issue667_strided_subview_tloads(
+    %src0: memref<8xf32, #pto.address_space>,
+    %src1: memref<8xf32, #pto.address_space>) {
+    %buf = memref.alloc() : memref<32xf32, #pto.address_space>
+    // Strided subview: offset 0, 8 elements, stride 2.
+    %a = memref.subview %buf[0] [8] [2] :
+      memref<32xf32, #pto.address_space>
+      to memref<8xf32, strided<[2]>, #pto.address_space>
+    // Strided subview: offset 8, 8 elements, stride 2 — overlaps %a.
+    %b = memref.subview %buf[8] [8] [2] :
+      memref<32xf32, #pto.address_space>
+      to memref<8xf32, strided<[2], offset: 8>, #pto.address_space>
+
+    pto.tload ins(%src0 : memref<8xf32, #pto.address_space>)
+      outs(%a : memref<8xf32, strided<[2]>, #pto.address_space>)
+    pto.tload ins(%src1 : memref<8xf32, #pto.address_space>)
+      outs(%b : memref<8xf32, strided<[2], offset: 8>, #pto.address_space>)
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @issue667_strided_subview_tloads
+// CHECK: pto.tload
+// CHECK: pto.barrier
+// CHECK: pto.tload
diff --git a/test/lit/pto/issue667_subview_mte2_barrier_alias.pto b/test/lit/pto/issue667_subview_mte2_barrier_alias.pto
new file mode 100644
index 000000000..90f9191dc
--- /dev/null
+++ b/test/lit/pto/issue667_subview_mte2_barrier_alias.pto
@@ -0,0 +1,108 @@
+// Copyright (c) 2026 Huawei Technologies Co., Ltd.
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+// CANN Open Software License Agreement Version 2.0 (the "License").
+// Please refer to the License for details. You may not use this file except in compliance with the License.
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+// See LICENSE in the root of the software repository for the full text of the License. 
+
+// RUN: ptoas --pto-arch=a5 --enable-insert-sync %s -o - | FileCheck %s --check-prefix=DISJOINT
+// RUN: ptoas --pto-arch=a5 --enable-insert-sync %s -o - | FileCheck %s --check-prefix=OVERLAP
+
+module {
+  // Four 16x128 subviews of one tile at column offsets 0/128/256/384.
+  // Their column ranges [0,128), [128,256), [256,384), [384,512) are
+  // pairwise disjoint, so no MTE2 barrier is needed between the tloads.
+  func.func @issue667_disjoint_subview_tloads(
+    %src0: memref<16x128xf32, #pto.address_space>,
+    %src1: memref<16x128xf32, #pto.address_space>,
+    %src2: memref<16x128xf32, #pto.address_space>,
+    %src3: memref<16x128xf32, #pto.address_space>) {
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c256 = arith.constant 256 : index
+    %c384 = arith.constant 384 : index
+
+    %tile = pto.alloc_tile :
+      !pto.tile_buf
+    %s0 = pto.subview %tile[%c0, %c0] sizes [16, 128] :
+      !pto.tile_buf
+      -> !pto.tile_buf
+    %s1 = pto.subview %tile[%c0, %c128] sizes [16, 128] :
+      !pto.tile_buf
+      -> !pto.tile_buf
+    %s2 = pto.subview %tile[%c0, %c256] sizes [16, 128] :
+      !pto.tile_buf
+      -> !pto.tile_buf
+    %s3 = pto.subview %tile[%c0, %c384] sizes [16, 128] :
+      !pto.tile_buf
+      -> !pto.tile_buf
+
+    pto.tload ins(%src0 : memref<16x128xf32, #pto.address_space>)
+      outs(%s0 : !pto.tile_buf)
+    pto.tload ins(%src1 : memref<16x128xf32, #pto.address_space>)
+      outs(%s1 : !pto.tile_buf)
+    pto.tload ins(%src2 : memref<16x128xf32, #pto.address_space>)
+      outs(%s2 : !pto.tile_buf)
+    pto.tload ins(%src3 : memref<16x128xf32, #pto.address_space>)
+      outs(%s3 : !pto.tile_buf)
+    return
+  }
+
+  // Two 16x128 subviews at column offsets 0 and 64: columns [64,128) are
+  // written by both, so a pipe_barrier(PIPE_MTE2) must separate the tloads.
+  func.func @issue667_overlapping_subview_tloads(
+    %src0: memref<16x128xf32, #pto.address_space>,
+    %src1: memref<16x128xf32, #pto.address_space>) {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+
+    %tile = pto.alloc_tile :
+      !pto.tile_buf
+    %s0 = pto.subview %tile[%c0, %c0] sizes [16, 128] :
+      !pto.tile_buf
+      -> !pto.tile_buf
+    %s1 = pto.subview %tile[%c0, %c64] sizes [16, 128] :
+      !pto.tile_buf
+      -> !pto.tile_buf
+
+    pto.tload ins(%src0 : memref<16x128xf32, #pto.address_space>)
+      outs(%s0 : !pto.tile_buf)
+    pto.tload ins(%src1 : memref<16x128xf32, #pto.address_space>)
+      outs(%s1 : !pto.tile_buf)
+    return
+  }
+
+}
+
+// DISJOINT-LABEL: AICORE void issue667_disjoint_subview_tloads
+// DISJOINT-NOT: pipe_barrier(PIPE_MTE2)
+// DISJOINT: TLOAD(
+// DISJOINT-NOT: pipe_barrier(PIPE_MTE2)
+// DISJOINT: TLOAD(
+// DISJOINT-NOT: pipe_barrier(PIPE_MTE2)
+// DISJOINT: TLOAD(
+// DISJOINT-NOT: pipe_barrier(PIPE_MTE2)
+// DISJOINT: TLOAD(
+// DISJOINT-NOT: pipe_barrier(PIPE_MTE2)
+// DISJOINT-LABEL: AICORE void issue667_overlapping_subview_tloads
+
+// OVERLAP-LABEL: AICORE void issue667_overlapping_subview_tloads
+// OVERLAP: TLOAD(
+// OVERLAP: pipe_barrier(PIPE_MTE2)
+// OVERLAP: TLOAD(