diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index 9216cfc14..86c7a7884 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -168,7 +168,53 @@ dimension counts FP4 pairs stored per byte, not logical scalar FP4 elements. --- -### 2.6 `!pto.local_array` +### 2.6 `!pto.multi_tile_buf` + +A **multi-buffer tile** representing N physically-distinct slots that share +one `tile_buf` shape. Only the underlying physical address differs across +slots; rank, valid shape, dtype, memory space, and config are identical. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `slotType` | `!pto.tile_buf<...>` | The per-slot tile_buf type | +| `count` | unsigned `[2, 16]` | Number of physical slots N | + +**Constraints (enforced by the type verifier):** +- `2 <= count <= 16` (`kPtoMultiBufferMaxNum`) +- `slotType` follows all the same constraints as a single-slot `tile_buf` +- `multi_tile_buf` does not appear on function arguments or returns in the + initial release; the design's multi-buffer ownership stays inside PTOAS + +**Two compatible spellings:** + +```mlir +// Compact (preferred): the slot tile_buf is described inline. +!pto.multi_tile_buf + +// Verbose: spell out the slot tile_buf explicitly. +!pto.multi_tile_buf, count=2> +``` + +**Associated ops** (see Section 4 -- multi-buffer expression and slot +selection): + +- `pto.alloc_multi_tile` -- allocate an N-slot multi-buffer tile +- `pto.multi_tile_get` -- pick one slot of a multi_tile_buf, yielding a + regular `tile_buf` that flows through every existing DMA / compute / view + op unchanged + +The N-way physical fan-out lives on the `pto.multi_buffer = N : i32` +attribute that PTOViewToMemref writes onto the lowered `memref.alloc`; +downstream passes (PlanMemory / InsertSync / GraphSyncSolver) consume that +attribute. The per-use slot index threaded through `pto.multi_tile_get` is +forwarded to the memref layer via the internal `pto.slot_marker` view op. 
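As a rough illustration of the range rules this manual documents for `multi_tile_buf` (`2 <= count <= 16`, and a constant slot index in `[0, count)` while a runtime slot index is not checked statically), the standalone C++ sketch below mirrors them outside of MLIR. It is not the PTOAS type verifier: `SlotIndex`, `constantSlot`, `dynamicSlot`, `isValidSlotCount`, and `isSlotAccepted` are hypothetical names introduced for this example, and the bounds are hard-coded copies of the documented limits.

```cpp
#include <cassert>
#include <cstdint>

// Documented bounds on the slot count N (copied here for illustration;
// the real constants live in the PTO headers).
constexpr std::uint32_t kMinSlots = 2;
constexpr std::uint32_t kMaxSlots = 16;

// Hypothetical model of a slot index: either a compile-time constant or a
// runtime SSA value whose numeric value is unknown to the verifier.
struct SlotIndex {
  bool isConstant;
  std::int64_t value;  // meaningful only when isConstant is true
};

inline SlotIndex constantSlot(std::int64_t v) { return SlotIndex{true, v}; }
inline SlotIndex dynamicSlot() { return SlotIndex{false, 0}; }

// Type-level rule: 2 <= count <= 16.
inline bool isValidSlotCount(std::uint32_t count) {
  return count >= kMinSlots && count <= kMaxSlots;
}

// Use-site rule: a constant slot must lie in [0, count); a dynamic slot is
// accepted here because only the frontend can guarantee its range.
inline bool isSlotAccepted(const SlotIndex &slot, std::uint32_t count) {
  if (!slot.isConstant)
    return true;
  return slot.value >= 0 && slot.value < static_cast<std::int64_t>(count);
}
```

Under these assumptions, a `count=2` buffer accepts constant slots 0 and 1, rejects constant slot 2, and accepts any runtime slot expression without a static check.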
+ +See `docs/designs/ptoas-multi-buffer-explicit-design.md` for the full +design. + +--- + +### 2.7 `!pto.local_array` A **C++ stack-local statically-shaped array**. Lowers to a plain `T a[D1][D2]...;` declaration in the emitted C++ — the array's address is decided by the host C++ @@ -536,6 +582,91 @@ result = alloc_tile(base_addr, valid_row, valid_col) // operands are optional %tb3 = pto.alloc_tile addr = %ad : !pto.tile_buf ``` +##### `pto.alloc_multi_tile` - Allocate N-Slot Multi-Buffer Tile + +**Summary:** Declares the lifetime of an N-slot multi-buffer tile. Each slot has the same `tile_buf` shape; only the underlying physical address differs. The N physical slots are reserved by `PTOPlanMemory` from the `pto.multi_buffer = N` attribute written onto the lowered `memref.alloc`. An explicit `addr` operand is intentionally NOT supported -- multi-buffer addresses are always compiler-decided. + +**Semantics:** + +``` +result = alloc_multi_tile(valid_row, valid_col) // operands are optional +``` + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `valid_row` | `Optional` | Dynamic valid row count (required when slot `v_row` is `?`) | +| `valid_col` | `Optional` | Dynamic valid column count (required when slot `v_col` is `?`) | + +**Results:** `!pto.multi_tile_buf<...>` + +**Constraints & Verification:** + +- The result type must have `count` in `[2, 16]`. +- The slot tile type (rank, valid shape, dtype, memory space, config) is verified the same way as `pto.alloc_tile` for a single slot. +- No `addr` operand: the user cannot pin physical addresses on a multi-buffer alloc. + +**Hardware Mapping:** + +- No hardware pipeline (allocation/metadata op). N-way physical fan-out is realized by PlanMemory. 
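The operand rule for `pto.alloc_multi_tile` stated above (`valid_row` is required when the slot `v_row` is `?`, and likewise for `valid_col`) can be pictured with a small standalone C++ sketch. This is only a model under stated assumptions, not the op's actual verifier: `SlotValidShape` and `checkValidOperands` are hypothetical names, and whether a surplus operand on a static dimension is rejected is deliberately not modeled because the text above does not spell it out.

```cpp
#include <cassert>
#include <string>

// Hypothetical description of the slot tile's valid shape: a dimension is
// "dynamic" when it is spelled `?` in the slot tile_buf type.
struct SlotValidShape {
  bool rowIsDynamic;
  bool colIsDynamic;
};

// Returns an empty string when the documented "required when dynamic" rule
// is satisfied, otherwise a short diagnostic naming the first missing
// operand. Surplus operands are not checked in this sketch.
inline std::string checkValidOperands(const SlotValidShape &shape,
                                      bool hasValidRow, bool hasValidCol) {
  if (shape.rowIsDynamic && !hasValidRow)
    return "valid_row is required when slot v_row is ?";
  if (shape.colIsDynamic && !hasValidCol)
    return "valid_col is required when slot v_col is ?";
  return "";
}
```

Because every slot shares one valid shape, a single pair of operands covers all N slots; the sketch therefore takes no slot count.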
+ +**Basic Example:** + +```mlir +%mb = pto.alloc_multi_tile : !pto.multi_tile_buf +%mb2 = pto.alloc_multi_tile : !pto.multi_tile_buf, count=3> +``` + +##### `pto.multi_tile_get` - Select One Slot Of A Multi-Buffer Tile + +**Summary:** Returns a single-slot view of a `multi_tile_buf`. The frontend is the source of truth for which slot a given use refers to; the slot index `%k` is an `index` value (constant or any SSA expression) in `[0, count)`. PTOAS does NOT synthesize `iv mod N` for users -- the user expression IS the slot selector. Downstream sync and event-id allocation analyze the slot expressions and emit static `set_flag` / `wait_flag` for constant slots or `set_flag_dyn` / `wait_flag_dyn` for runtime slots. + +**Semantics:** + +``` +result = multi_tile_get(source, slot) +``` + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `source` | `MultiTileBufType` | The N-slot multi-buffer tile | +| `slot` | `Index` | Slot index in `[0, count)` | + +**Results:** `!pto.tile_buf<...>` (must equal `source.slotType`) + +**Constraints & Verification:** + +- Result `tile_buf` must equal `source.slotType` (rank, valid shape, dtype, memory space, config all identical). +- If `slot` is a constant, the verifier checks `0 <= slot < count`. +- Pure view op -- no data movement, no extra address arithmetic. +- `multi_tile_buf` is not allowed on function arguments or results in the initial release. + +**Hardware Mapping:** + +- No hardware pipeline (metadata-only view). + +**Basic Example:** + +```mlir +%mb = pto.alloc_multi_tile : !pto.multi_tile_buf + +// constant-slot selection +%c0 = arith.constant 0 : index +%s0 = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + +// dynamic-slot selection (e.g. 
prefetch with %k from the loop body) +%s_k = pto.multi_tile_get %mb[%k] + : !pto.multi_tile_buf + -> !pto.tile_buf +``` + +See `docs/designs/ptoas-multi-buffer-explicit-design.md` for the full design (sync/event-id derivation, downstream pass interplay, and end-to-end usage examples). + ##### `pto.subview` - Tile SubView **Summary:** Create a logical subview from a parent tile. The subview window is expressed by `offsets + sizes`, and the result tile type shape equals `sizes`. diff --git a/docs/designs/ptoas-multi-buffer-explicit-design.md b/docs/designs/ptoas-multi-buffer-explicit-design.md new file mode 100644 index 000000000..fb991de45 --- /dev/null +++ b/docs/designs/ptoas-multi-buffer-explicit-design.md @@ -0,0 +1,478 @@ +# PTOAS Multi-Buffer: Explicit Expression + Automatic Synchronization Design + +## 1. Scope + +This document designs a new multi-buffer expression scheme for PTOAS, covering: + +- frontend expression at the tile_buf level; +- physical address planning for multi-buffer; +- lowering of manual slot selection; +- automatic derivation of synchronization and event ids. + +This design borrows several pieces of infrastructure from PR615 (multi-address `pto.pointer_cast`, `set_flag_dyn` / `wait_flag_dyn`, `MAX_MULTI_BUFFER_NUM = 16`, etc.), but its IR surface, user interface, and sync derivation path are all independent; it is not layered on top of PR615. + +## 2. Design goals + +1. **Express multi-buffer explicitly**: at the alloc site the frontend declares "this logical tile has N physical slots". +2. **Select the buffer manually**: at every use site (tload / tstore / compute) the frontend states explicitly "use the k-th slot"; the slot index may be a constant or any SSA expression. +3. **Automatic sync and automatic event ids**: cross-slot RAW/WAR/WAW and cross-pipe sync relations are derived automatically by `PTOInsertSync` / `PTOGraphSyncSolver`, which also allocate the event ids (using `set_flag_dyn` / `wait_flag_dyn` where necessary). +4. **No "fully automatic `iv mod N` injection"**: multi-buffer users must supply the slot expression explicitly. The frontend may pass `arith.remui iv, N` as the slot expression, but that is the frontend's responsibility, not a compiler default. + +## 3.
Key differences from PR615 + +| Decision point | PR615 | This design | +|---|---|---| +| Multi-buffer carrier | `tile_buf` type field `multi_buffer=N` | Standalone type `!pto.multi_tile_buf<..., count=N>` | +| Slot selection | Fully automatic `iv mod N` by the compiler | Frontend must use an explicit `pto.multi_tile_get` | +| Depends on `scf.for`? | An enclosing `scf.for` is required | No dependency; unroll, while, and cross-block uses are all supported | +| Sync derivation | Multi-address `pto.pointer_cast` ⇒ dyn event id (slot key derived automatically) | Every use carries a slot tag ⇒ sync is derived from the tags | +| CLI switch | `--enable-multi-buffer-lowering` | Not needed; slot selection is part of the IR | + +Borrowed from PR615: +- `MAX_MULTI_BUFFER_NUM = 16`; +- the overall shape of PlanMemory assigning N addresses to a multi-buffer alloc and emitting a multi-address `pto.pointer_cast`; +- `set_flag_dyn` / `wait_flag_dyn` as the codegen for dyn event ids; +- inplace merging takes the max. + +Not carried over from PR615: +- multi-buffer no longer uses a type field of `tile_buf`; +- there is no `PTOEnableMultiBuffer` "automatic `iv mod N` injection" path; +- there is no `--enable-multi-buffer-lowering` switch; +- PlanMemory's SPEC_LEVEL_1 reuse strategy does not assume `iv mod N` ordering (see §5.2). + +## 4. Type-level design + +### 4.1 New type `!pto.multi_tile_buf` + +```tablegen +def MultiTileBufType : TypeDef { + let mnemonic = "multi_tile_buf"; + let summary = "An array of N physical slots sharing one tile_buf shape"; + let parameters = (ins + "mlir::pto::TileBufType":$slotType, // the single tile_buf type of each slot + "uint32_t":$count // N, 2 <= N <= 16 + ); + let assemblyFormat = "`<` $slotType `,` `count` `=` $count `>`"; +} +``` + +Compact spelling (syntactic sugar, modeled on the compact form of PR615's tile_buf): + +```mlir +!pto.multi_tile_buf +``` + +Full spelling: + +```mlir +!pto.multi_tile_buf, count=2> +``` + +Design trade-off: making multi-buffer a type of its own keeps "single slot" and "N-slot container" clearly distinct at the type level. All `tload` / `tstore` / `subview` / compute ops continue to see only `tile_buf`, so no signature changes are needed. The only entry point into multi-buffer is `alloc_multi_tile`; the only exit is `multi_tile_get`. + +### 4.2 New op `pto.alloc_multi_tile` + +```tablegen +def AllocMultiTileOp : PTO_Op<"alloc_multi_tile", [AttrSizedOperandSegments]> { + let arguments = (ins + Optional:$valid_row, + Optional:$valid_col + ); + let results = (outs MultiTileBufType:$result); + let hasVerifier = 1; + let assemblyFormat = [{ + (`valid_row` `=` $valid_row^)? (`valid_col` `=` $valid_col^)?
+ attr-dict `:` qualified(type($result)) + }]; +} +``` + +Notes: +- There is **no** `addr` operand: the physical slots of a multi-buffer are always planned by ptoas. +- valid_row / valid_col keep the `alloc_tile` semantics and apply to all N slots (every slot shares the same valid_shape). + +### 4.3 New op `pto.multi_tile_get` (manual slot selection) + +```tablegen +def MultiTileGetOp : PTO_Op<"multi_tile_get", [Pure, ViewLikeOpInterface]> { + let arguments = (ins + MultiTileBufType:$source, + Index:$slot // 0 <= slot < count + ); + let results = (outs TileBufType:$result); + let hasVerifier = 1; + let assemblyFormat = [{ + $source `[` $slot `]` attr-dict + `:` qualified(type($source)) `->` qualified(type($result)) + }]; + let extraClassDeclaration = [{ + ::mlir::Value getViewSource() { return getSource(); } + }]; +} +``` + +Constraints: +- The result `tile_buf` must equal `source.slotType` (shape / valid_shape / dtype / memSpace / config all identical). +- When slot is a constant, the verifier checks `0 ≤ slot < count`; a dynamic slot is not checked statically and must be guaranteed by the frontend. +- The result is an ordinary single-slot `tile_buf` that every existing op (subview / set_validshape / tload / compute) can use with zero changes. + +### 4.4 Composition rules for view ops + +| Composition | Allowed? | Notes | +|---|---|---| +| `multi_tile_get → subview` | Yes | Recommended: pin the slot first, then cut the window | +| `multi_tile_get → set_validshape` | Yes | The slot dimension is orthogonal to the valid_shape dimensions | +| `subview → multi_tile_get` | **No** | subview operates on a (single-slot) `tile_buf`; the multi-buffer information is already lost | +| Nested `multi_tile_get` | **No** | The input must be a `multi_tile_buf`; rejected by the verifier | + +## 5. Pipeline + +```mermaid +flowchart LR + A["alloc_multi_tile count=N
multi_tile_get [%k]"] --> B["PTOViewToMemref"] + B --> B2["memref.alloc {pto.multi_buffer=N}
+ pto.slot_marker[%k]"] + B2 --> C["PTOPlanMemory
(N-address pointer_cast)"] + C --> D["PTOInsertSync / GSS
(sees slot_marker; baseAddresses distinguished per slot)"] + D --> E["PTOResolveBufferSelect (new)"] + E --> F["set_flag / set_flag_dyn / EmitC"] +``` + +**Key ordering change**: `PTOResolveBufferSelect` runs **after** Sync, so that sync sees `pto.slot_marker` directly and gets the const-slot disjoint optimization automatically through slot-narrowing of `BaseMemInfo.baseAddresses`. During sync a dyn-slot takes the conservative path (the addresses of all slots are kept); it is materialized into `arith.select` only after Resolve. + +### 5.1 PTOViewToMemref + +- `pto.alloc_multi_tile : !pto.multi_tile_buf` + → `memref.alloc {pto.multi_buffer = N : i32}` (the type has the physical size of a single slot). +- `pto.multi_tile_get %mb[%k] : ... -> S` + → wrapped in an internal view op at the memref layer: + ```mlir + %slot_mem = pto.slot_marker %alloc_mem [%k] + : memref<16x16xf16, #pto.address_space> -> + memref<16x16xf16, #pto.address_space> + ``` + +`pto.slot_marker` is a ptoas-internal op (invisible to the frontend); it only attaches the slot SSA to the memref so that later passes can recognize it. + +### 5.2 PTOPlanMemory + +Reuses PR615's multi-address planning: +- an alloc with `pto.multi_buffer = N` is planned with N slots; +- `StorageEntry::multiBufferNum` + `relationOtherBuffers` hold the sibling slots; +- inplace merging takes the max; +- emit `pto.pointer_cast(addr0, ..., addrN-1)`. + +PR615's SPEC_LEVEL_1 reuse strategy, which "assumes `iv mod N` ordering", is **not carried over**. This design allows arbitrary slot expressions, so reuse has to be more conservative: the N slots of one alloc must be placed in mutually disjoint physical segments, and alias merging is not allowed. If a slot expression of the form `iv mod N` can later be proven, this may be relaxed behind a separate switch. + +### 5.3 New pass `PTOResolveBufferSelect` + +Position: after both PlanMemory and Sync have run. During Sync, `pto.slot_marker` is still an IR node and is recognized by `MemoryDependentAnalyzer`. After Resolve, downstream EmitC sees only single-slot, single-address references. + +Each `slot_marker` is handled in turn: + +1. **slot is `arith.constant c`**: rewrite the memref references on that use chain to the single-address `pto.pointer_cast(addr_c)`. +2.
**slot is an SSA value**: generate an N-way `arith.select` before the use site, indexed by the user's own SSA value: + ```mlir + %p0 = pto.pointer_cast(%addr0) : memref<...> + %p1 = pto.pointer_cast(%addr1) : memref<...> + %is1 = arith.cmpi eq, %k, %c1 : index + %slot_mem = arith.select %is1, %p1, %p0 : memref<...> + ``` + **`%k` is not replaced with `iv mod N`**: the frontend's expression is used exactly as given. + +Invariant: when this pass finishes, every data op sees a single-slot memref. The multi-address `pto.pointer_cast` is kept only as the "alloc anchor" for sync analysis (already fully consumed by the preceding sync). + +### 5.4 InsertSync / GraphSyncSolver + +#### 5.4.1 Multi-slot representation in BaseMemInfo (implemented) + +`PTOIRTranslator::UpdatePointerCastOpMemInfo` already supports multiple addresses: + +| `pto.pointer_cast` operands | rootBuffer | baseAddresses | +|---|---|---| +| Single address `(addr0)` | `addr0` (i64 SSA) | `{0}` | +| Multiple addresses `(addr0, …, addrN-1)` (all constant) | `op.getResult()` (cast SSA) | `{addr0, addr1, …, addrN-1}` parsed as uint64 | +| Multiple addresses, some non-constant | `addr0` | `{0}` (conservative fallback to the old path) | + +`UpdateSlotMarkerAliasBufferInfo` handles `pto.slot_marker`: + +- Constant slot `k`: narrow the parent BaseMemInfo's baseAddresses to `{addrK}`. +- Dynamic slot: keep all of the parent's baseAddresses (conservative). + +`MemAlias`'s existing `isBufferAddressRangeOverlap` benefits automatically: +- different constant slots ⇒ disjoint baseAddresses ⇒ no conflict +- same constant slot ⇒ identical baseAddresses ⇒ real conflict +- either side dyn ⇒ baseAddresses cover all slots ⇒ overlaps with any slot, conservative sync + +#### 5.4.2 Full SlotInfo (follow-up) + +The end goal is an explicit `SlotInfo` that lets sync dispatch dyn event ids on the dynamic-slot path: + +```cpp +struct SlotInfo { + enum Kind { kSingle, kConstSlot, kDynSlot }; + Kind kind; + uint32_t slotCount; // == N + uint32_t constSlot; // for kConstSlot + Value dynSlotExpr; // for kDynSlot +}; +``` + +Conflict decision table: + +| Producer | Consumer | Current implementation | Target | +|---|---|---|---| +| Single ↔ Single | – | ✅ ordinary sync | same | +| Const(a) ↔ Const(b), a==b | – | ✅ real conflict, static event id | same | +| Const(a) ↔ Const(b), a≠b | – | ✅ baseAddresses disjoint, no sync | same | +| Const(a) ↔ Dyn(%k) | – | ⚠️ conservative alias, full sync | dyn event id (sync only when %k==a) | +| Dyn(%j) ↔ Dyn(%k), same expression | – | ⚠️ conservative alias | real conflict + dyn event id | +| Dyn(%j) ↔ Dyn(%k), provably disjoint | – | ⚠️ conservative alias | no conflict within an iter, dyn event id across iters | +| Dyn ↔ Dyn, not provable | – | ⚠️ conservative
alias | N dyn event ids | + +Allocating dyn event ids and generating `set_flag_dyn` / `wait_flag_dyn` are left as a follow-up (this requires extending `SyncEventIdAllocation` and `SyncCodegen`). + +Fallback when resources run out (following PR615's approach): N → even number → 2 → 1 → `PIPE_ALL` barrier. + +### 5.5 CLI + +No new switch is needed. The presence of `alloc_multi_tile` is what drives `PTOResolveBufferSelect`. + +## 6. Verification rules + +`AllocMultiTileOp::verify`: +- `2 ≤ count ≤ 16`; +- the valid_row/valid_col operands agree with the valid_shape of slotType (same logic as alloc_tile). + +`MultiTileGetOp::verify`: +- result = `source.slotType`; +- a constant slot lies in `[0, count)`; +- the input must be a `multi_tile_buf` (prevents nested multi_tile_get). + +Consistency at the `PTOViewToMemref` stage: +- every use originating from one `alloc_multi_tile` must go through `multi_tile_get` (already guaranteed at the type level). +- `multi_tile_buf` must not appear on a function arg / return (multi-buffer ownership is confined to ptoas). + +## 7. Usage examples (tile_buf-level IR) + +### 7.1 Example 1: parallel load and compute with static slots + +Two MTE2 loads land in slot 0 and slot 1 respectively, and the vector compute consumes both slots. All slots are constants. + +```mlir +func.func @static_parallel( + %gm0 : memref<16x16xf16, #pto.address_space>, + %gm1 : memref<16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + %s0 = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + %s1 = pto.multi_tile_get %mb[%c1] + : !pto.multi_tile_buf + -> !pto.tile_buf + + pto.tload ins(%gm0 : memref<16x16xf16, #pto.address_space>) + outs(%s0 : !pto.tile_buf) + pto.tload ins(%gm1 : memref<16x16xf16, #pto.address_space>) + outs(%s1 : !pto.tile_buf) + + pto.tadd ins(%s0, %s1 : !pto.tile_buf, + !pto.tile_buf) + outs(%s0 : !pto.tile_buf) + return +} +``` + +What ptoas does automatically: +- 2 physical addresses, addr0/addr1; +- one static event id for MTE2→V on slot 0 and one for slot 1; no RAW between the two slots; +- only static `set_flag` / `wait_flag`, **no dyn flag is produced**. + +### 7.2 Example 2: double-buffer prefetch (dynamic slot) + +```mlir +func.func @double_prefetch( + %gm : memref>, %n : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + // preload iter 0 -> slot0 + %pre =
pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : ...) outs(%pre : !pto.tile_buf) + + scf.for %i = %c0 to %n step %c1 { + %next = arith.addi %i, %c1 : index + %cur_idx = arith.remui %i, %c2 : index + %next_idx = arith.remui %next, %c2 : index + + // prefetch into the other slot; the slot index is controlled by the frontend + %s_next = pto.multi_tile_get %mb[%next_idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : ...) outs(%s_next : !pto.tile_buf) + + %s_cur = pto.multi_tile_get %mb[%cur_idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tadd ins(%s_cur, %s_cur : ..., ...) outs(...) + } + return +} +``` + +What ptoas does automatically: +- `PTOResolveBufferSelect` uses `%cur_idx` / `%next_idx` directly as the `arith.select` index; +- sync analysis: producer slot expression `(iv+1)%2`, consumer slot `iv%2` → disjoint within an iter / conflicting across iters → 2 dyn event ids; +- emit `set_flag_dyn` / `wait_flag_dyn`, with the event id value derived from the same source as the slot expression. + +### 7.3 Example 3: N=4 rotation with a shared slot expression + +```mlir +%mb = pto.alloc_multi_tile + : !pto.multi_tile_buf +%c4 = arith.constant 4 : index + +scf.for %i = %c0 to %n step %c1 { + %k = arith.remui %i, %c4 : index + %slot = pto.multi_tile_get %mb[%k] + : !pto.multi_tile_buf + -> !pto.tile_buf + + pto.tload ins(%gm : ...) outs(%slot : !pto.tile_buf) + pto.tadd ins(%slot, %other : ..., ...) outs(...) +} +``` + +What ptoas does automatically: +- 4 physical addresses; +- producer/consumer share the slot expression → RAW within the same slot; progress across iters goes through 4 dyn event ids. + +### 7.4 Example 4: manual unroll without `scf.for` + +```mlir +%mb = pto.alloc_multi_tile + : !pto.multi_tile_buf +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index + +%s0 = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf +pto.tload ins(%gm0 : ...) outs(%s0 : !pto.tile_buf) + +%s1 = pto.multi_tile_get %mb[%c1] + : !pto.multi_tile_buf + -> !pto.tile_buf +pto.tload ins(%gm1 : ...) outs(%s1 : !pto.tile_buf) + +pto.tadd ins(%s0, %s0 : ..., ...) outs(...) +pto.tadd ins(%s1, %s1 : ..., ...) outs(...)
+``` + +Multi-buffer also works without a loop: the producer/consumer of each of the two constant slots synchronize independently. This is a form that PR615's automatic path does not support natively. + +### 7.5 Example 5: composing with subview + +```mlir +%mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + +scf.for %i = ... { + %k = arith.remui %i, %c2 : index + %slot = pto.multi_tile_get %mb[%k] + : !pto.multi_tile_buf + -> !pto.tile_buf + %win = pto.subview %slot[%c0, %c0] sizes=[16, 16] + : !pto.tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : ...) outs(%win : !pto.tile_buf) + pto.tadd ins(%win, %win : ..., ...) outs(...) +} +``` + +The slot tag is pinned on `%slot`; `subview` inherits the `slot_marker` at the memref layer, and sync analysis works as usual. + +## 8. Test coverage + +| Test | Coverage | +|---|---| +| `multi_tile_buf_type_verify.pto` | count out of range, slot out of range, nested multi_tile_get is rejected | +| `alloc_multi_tile_view_to_memref.pto` | type → `memref.alloc {pto.multi_buffer=N}` | +| `multi_tile_get_const_slot_resolve.pto` | Example 1: constant slot → single-address pointer_cast | +| `multi_tile_get_dyn_slot_resolve.pto` | Example 2: dynamic slot → dyn select chain | +| `multi_tile_prefetch_insert_sync.pto` | Example 2: InsertSync derives the prefetch dyn flags with two event ids | +| `multi_tile_prefetch_gss.pto` | Same as above, GraphSyncSolver path | +| `multi_tile_n4_rotate.pto` | Example 3: N=4 rotation with a shared slot expression | +| `multi_tile_no_loop_unroll.pto` | Example 4: manual unroll without scf.for | +| `multi_tile_subview_compose.pto` | Example 5: composing with subview | +| `multi_tile_disjoint_const_slots.pto` | Example 1: two constant slots do not conflict, no sync emitted | + +Suggested verification commands: + +```bash +lit test/lit/pto/multi_tile_buf_type_verify.pto +lit test/lit/pto/alloc_multi_tile_view_to_memref.pto +lit test/lit/pto/multi_tile_get_const_slot_resolve.pto +lit test/lit/pto/multi_tile_get_dyn_slot_resolve.pto +lit test/lit/pto/multi_tile_prefetch_insert_sync.pto +lit test/lit/pto/multi_tile_prefetch_gss.pto +``` + +## 9.
Current implementation status & follow-ups + +### Implemented (as of this batch) + +| Stage | Status | Files | +|---|---|---| +| `!pto.multi_tile_buf` type, N ∈ [2, 16] | ✅ | `include/PTO/IR/PTOTypeDefs.td`, `lib/PTO/IR/PTOTypeDefs.cpp` | +| `pto.alloc_multi_tile` / `pto.multi_tile_get` / `pto.slot_marker` ops | ✅ | `include/PTO/IR/PTOOps.td`, `lib/PTO/IR/PTO.cpp` | +| Type / op verification (count range, slot range, nesting forbidden) | ✅ | `lib/PTO/IR/PTOTypeDefs.cpp`, `lib/PTO/IR/PTO.cpp` | +| `PTOViewToMemref` lowers alloc_multi_tile/multi_tile_get → `memref.alloc {pto.multi_buffer=N}` + `pto.slot_marker` | ✅ | `lib/PTO/Transforms/PTOViewToMemref.cpp` | +| `PTOPlanMemory` N-way multi-slot planning: `StorageEntry.relationOtherBuffers` list + `ExpandMultiBufferStorageEntry` N-way sibling expansion + `UpdateBuffer2Offsets` write-back in slot order | ✅ N ∈ [2, 16] | `lib/PTO/Transforms/PTOPlanMemory.cpp`, `lib/PTO/Transforms/PTOPlanMemory.h` | +| `AllocToPointerCast` emits the N-address `pto.pointer_cast(addr0..addrN-1)` | ✅ | `lib/PTO/Transforms/AllocToPointerCast.cpp` (pre-existing support) | +| `PTOResolveBufferSelect` pass (constant slot → single-address cast, dynamic slot → arith.select chain) | ✅ | `lib/PTO/Transforms/PTOResolveBufferSelect.cpp` | +| ptoas pipeline hookup (PlanMemory → ResolveBufferSelect → Sync; initial order, superseded by the reordering in the next row) | ✅ | `tools/ptoas/ptoas.cpp` | +| **Pipeline reordering**: Sync runs before `PTOResolveBufferSelect` → sync sees `pto.slot_marker` directly | ✅ | `tools/ptoas/ptoas.cpp` | +| **Multi-slot awareness in InsertSync**: for a multi-address cast, `UpdatePointerCastOpMemInfo` fills the physical offsets of the N slots into `baseAddresses`; `UpdateSlotMarkerAliasBufferInfo` narrows according to whether the slot is constant or dynamic | ✅ | `lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp` | +| **`MemAlias` recognizes const-slot disjointness automatically**: accesses to different constant slots have disjoint baseAddresses → no sync emitted | ✅ | `lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp` (reuses the existing range-overlap logic) | +| **GSS alias chain looks through `pto.bind_tile` / `pto.slot_marker`** | ✅ | `lib/PTO/Transforms/Utils.cpp` (`getOperationAliasInfo`) | +| **InsertSync handles arith.select-on-memref**: kept as defensive logic (with Resolve moved after sync, this case no longer actually occurs) | ✅ | `lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp` | +| **`GetEventIdNum` multi-slot
N inference**: returns N when `baseAddresses.size() == N` on both sides of a back-edge dep | ✅ | `lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp` | +| **`SyncOperation` carries the slot SSA**: `slotSSAExpr` + `slotCount` fields; set and wait each record the slot SSA of their own side | ✅ | `include/PTO/Transforms/InsertSync/SyncCommon.h`, `lib/PTO/Transforms/InsertSync/{SyncCommon,InsertSyncAnalysis}.cpp` | +| **dyn event-id codegen**: `CreateSetWaitOpForMultiBuffer` emits `pto.set_flag_dyn` / `pto.wait_flag_dyn`; the event id is picked from the N allocated static event ids by an N-way `arith.select` chain keyed on `slot % N` | ✅ | `lib/PTO/Transforms/InsertSync/SyncCodegen.cpp` | +| **`SyncEventIdAllocation` allocation of N event ids**: reuses the existing `eventIdNum > 1` path (which already supports N) and automatically allocates N hardware event ids to each set/wait pair | ✅ pre-existing | `lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp` | +| **GSS slot-aware**: `SyncSolverIRTranslator::tracebackMemValsStep` stops at `pto.slot_marker`; `MemInfo::getMemInfoForSlotMarker` narrows `PointerLikeInfo::addresses` for a constant slot and sets `parentLoop` from the slot_marker's enclosing loop, so that `getMultiBufferEventIdInfo` recognizes the multi-buffer and allocates N event ids | ✅ | `lib/PTO/Transforms/GraphSyncSolver/{SyncSolverIRTranslator,MemInfo}.cpp` | +| **GSS slotSSAExpr lands on SetWaitOp**: `findSlotSSAExprForRWOp` walks back along `bind_tile` to `pto.slot_marker.slot`; the set side takes the slot SSA of `op1`, the wait side that of `op2` | ✅ | `include/PTO/Transforms/GraphSyncSolver/SyncSolverIR.h`, `lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp` | +| **GSS dyn flag codegen**: when `eventIds.size() > 1 && slotSSAExpr`, `SyncSolverCodeGen::emitSyncOp` folds into a single `pto.set_flag_dyn` / `pto.wait_flag_dyn`, with event_id picked by an N-way `arith.select` chain keyed on `slot % N`; the `allAtOnce` prime/drain still uses the N-way static fanout (the semantics require one prime / drain per slot) | ✅ | `lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp` | +| **Shared affine slot analysis `SlotAffineAnalysis`**: `findSlotMarkerExpr` + three-state `compareSlotSSA` (kEqual / kDisjoint / kUnknown), covering `iv % N`, `(iv ± c) % N`, pure constants, identical SSA, and similar forms; shared by InsertSync / GSS / EmitC | ✅ | `include/PTO/Transforms/SlotAffineAnalysis.h`, `lib/PTO/Transforms/SlotAffineAnalysis.cpp` | +|
**InsertSync drops same-iter forward deps early**: `MemAnalyze` runs `isForwardDepDroppableBySlotAffine` on forward deps; pairs that affine analysis proves disjoint are removed as a whole, saving one same-iter set/wait pair in the loop body | ✅ | `lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp` | +| **GSS drops same-iter forward deps early**: on the non-back-edge path, `checkMemoryConflictsForOcc` likewise filters out the whole (corePipeSrc, corePipeDst) group based on affine disjointness | ✅ | `lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp` | +| **GSS same-SSA behavior aligned with InsertSync**: `getMultiBufferEventIdInfo` no longer exits early on all-equal slot SSA, so a producer/consumer with the same SSA also takes the N dyn event id path; for the same SSA, GSS now emits exactly the same prefetch pipeline as InsertSync | ✅ | `lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp` | +| **EmitC looks through `arith.select`-on-memref**: `PTOMaterializeTileHandles::computeExplicitAddress` recurses along both branches of the select to compute the i64 address, matching the select chain of the dyn slot path; EmitC TASSIGN gets the correct address | ✅ | `lib/PTO/Transforms/PTOMaterializeTileHandles.cpp` | +| **`multi_tile_get` lowering safeguard**: the source is read via `op->getOperand(0)` (bypassing the type cast), because after alloc_multi_tile is replaced the source is already a memref and the typed accessor `getSource()` would assert | ✅ | `lib/PTO/Transforms/PTOViewToMemref.cpp` | +| lit tests (17): parse/print, verifier, const slot lowering, dyn slot lowering, N=3 / N=4 end-to-end, loop-free unroll, const-slot sync disjoint, dyn-slot sync compile, GSS multi-buffer compile, prefetch dyn event-id (InsertSync), prefetch GSS dyn flag, affine disjoint slots, const preload + dyn loop select, preload + loop set/wait, unknown slot GSS conservative fallback | ✅ | `test/lit/pto/multi_tile_*.pto` | + +### Current limitations + +- **Affine analysis covers only a few core forms**: `compareSlotSSA` can currently prove `iv % N` / `(iv ± c) % N` / identical SSA / pure constants; it cannot prove `(iv * c) % N`, SSA equivalence across functions or loops, or slot expressions not wrapped in `arith.remui`. On a miss it falls back to kUnknown / the conservative N dyn event ids. +- **PlanMemory does not reuse Stage1 for N>2**: sibling slots with N>2 skip the SPEC_LEVEL_1 "ping/pong adjacent placement" optimization and use more memory. The N=2 path is unchanged. +- The initial version supports only `loc=vec` / `loc=mat` local memory. +- `multi_tile_buf` on a function argument / return is not supported (multi-buffer ownership is confined to ptoas). +- workspace / preload / CV multi-buffer are outside the scope of this design. + +### Follow-up PRs + +1.
**Affine analysis extensions**: add `(iv * c) % N`, SSA equivalence across loop iter_args, and non-`arith.remui` forms (`arith.divui` / affine.apply). +2. **PlanMemory SPEC_LEVEL_1 reuse for N > 2**: extend ping/pong adjacent placement to N-way adjacent placement to reduce physical memory pressure for N > 2. +3. **Python bindings + samples**: expose `alloc_multi_tile` / `multi_tile_get`, with a matching prefetch example under `test/samples`. +4. **Cross-function ABI**: support `multi_tile_buf` as a func arg / return, together with cross-function propagation of sync (a v2 topic). diff --git a/include/PTO/IR/PTOMultiBuffer.h b/include/PTO/IR/PTOMultiBuffer.h new file mode 100644 index 000000000..53c89f6c3 --- /dev/null +++ b/include/PTO/IR/PTOMultiBuffer.h @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOMultiBuffer.h - Shared constants for multi-buffer ----*- C++ -*-===// +// +// Shared constants for the multi-buffer expression scheme: +// - `kPtoMultiBufferAttrName` is the memref-level attribute name written by +// PTOViewToMemref when lowering an `alloc_multi_tile` op. PlanMemory and +// downstream passes read it to reserve N physical slots. +// - `kPtoMultiBufferMaxNum` is the upper bound on the slot count N. It is +// kept in lock-step with the InsertSync `MAX_MULTI_BUFFER_NUM`.
+// +//===----------------------------------------------------------------------===// + +#ifndef PTO_IR_PTOMULTIBUFFER_H +#define PTO_IR_PTOMULTIBUFFER_H + +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace pto { + +/// Attribute name for multi-buffer depth (integer slot count N>=2). +inline constexpr llvm::StringLiteral kPtoMultiBufferAttrName = + "pto.multi_buffer"; + +/// Upper bound for N; must stay consistent with `MAX_MULTI_BUFFER_NUM` in +/// insert-sync. +inline constexpr unsigned kPtoMultiBufferMaxNum = 16; + +/// Lower bound for N (a multi_tile_buf must have at least 2 slots). +inline constexpr unsigned kPtoMultiBufferMinNum = 2; + +} // namespace pto +} // namespace mlir + +#endif // PTO_IR_PTOMULTIBUFFER_H diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 689938a06..2d098c27a 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -232,7 +232,7 @@ def AllocTileOp : PTO_Op<"alloc_tile", [AttrSizedOperandSegments]> { let assemblyFormat = [{ (`addr` `=` $addr^)? - (`valid_row` `=` $valid_row^)? + (`valid_row` `=` $valid_row^)? (`valid_col` `=` $valid_col^)? attr-dict `:` qualified(type($result)) }]; @@ -242,6 +242,95 @@ def AllocTileOp : PTO_Op<"alloc_tile", [AttrSizedOperandSegments]> { }]; } +//===----------------------------------------------------------------------===// +// Multi-buffer tile allocation and slot selection +//===----------------------------------------------------------------------===// +// +// `pto.alloc_multi_tile` declares an N-slot logical tile buffer. The actual +// physical allocation of N slots is performed by PTOPlanMemory once the op +// has been lowered to `memref.alloc {pto.multi_buffer = N : i32}` by +// PTOViewToMemref. +// +// `pto.multi_tile_get` is the *only* way to consume a `multi_tile_buf`: it +// returns a regular `tile_buf` view onto the chosen slot. 
The slot index +// may be a constant or any SSA `index` value; PTOAS does not synthesize +// `iv mod N` automatically -- the user expression IS the slot selector. +// +// `pto.slot_marker` is an internal op materialized by PTOViewToMemref to +// thread the slot SSA from `multi_tile_get` through the memref layer down +// to PlanMemory / sync analysis. It is not intended for direct frontend +// use. + +def AllocMultiTileOp : PTO_Op<"alloc_multi_tile", [AttrSizedOperandSegments]> { + let summary = "Allocate an N-slot multi-buffer tile"; + let description = [{ + Produces a `!pto.multi_tile_buf` value owning N physical + slots of the inner tile shape `S`. The physical addresses of the N + slots are decided by PTOPlanMemory; an explicit address operand is + intentionally not supported. + + `valid_row` / `valid_col` operands follow the same semantics as + `pto.alloc_tile`: required when the inner tile's valid shape is `?`, + absent when both valid dims are static. + }]; + + let arguments = (ins + Optional:$valid_row, + Optional:$valid_col + ); + + let results = (outs MultiTileBufType:$result); + + let assemblyFormat = [{ + (`valid_row` `=` $valid_row^)? + (`valid_col` `=` $valid_col^)? + attr-dict `:` qualified(type($result)) + }]; + + let extraClassDeclaration = [{ + ::mlir::LogicalResult verify(); + }]; +} + +def MultiTileGetOp : PTO_Op<"multi_tile_get", [ + Pure, + ViewLikeOpInterface + ]> { + let summary = "Pick one physical slot of a multi_tile_buf"; + let description = [{ + Returns a single-slot view of a `multi_tile_buf`. The result `tile_buf` + must equal the source's per-slot type; the only thing this op selects + is which of the N physical slots is referred to. + + The slot index is an `index` value in `[0, count)`. If it is a constant, + the verifier checks the range. If it is a runtime SSA value, downstream + sync analysis treats this use as a slot-indexed dynamic access; ptoas + does NOT rewrite the user expression into `iv mod N`. 
+ + This op is metadata-only (no data movement); the lowering merely + annotates the underlying memref view with the slot index for PlanMemory + / sync / PTOResolveBufferSelect to consume. + }]; + + let arguments = (ins + MultiTileBufType:$source, + Index:$slot + ); + + let results = (outs TileBufType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $slot `]` attr-dict + `:` qualified(type($source)) `->` qualified(type($result)) + }]; + + let extraClassDeclaration = [{ + ::mlir::Value getViewSource() { return getSource(); } + }]; +} + // ============================================================================ // BindTileOp: binds Config and Valid Dims onto a MemRef @@ -1217,16 +1306,51 @@ def PointerCastOp : PTO_Op<"pointer_cast", [AttrSizedOperandSegments, Pure]> { if (vRow) $_state.addOperands(vRow); if (vCol) $_state.addOperands(vCol); if (config) $_state.addAttribute("config", config); - + int32_t addrsSize = addrs.size(); int32_t vRowSize = vRow ? 1 : 0; int32_t vColSize = vCol ? 1 : 0; - $_state.addAttribute("operandSegmentSizes", + $_state.addAttribute("operandSegmentSizes", $_builder.getDenseI32ArrayAttr({addrsSize, vRowSize, vColSize})); }]> ]; } +def SlotMarkerOp : PTO_Op<"slot_marker", [ + Pure, + ViewLikeOpInterface, + AllTypesMatch<["source", "result"]> + ]> { + let summary = "Tag a memref view as referring to one slot of a multi_tile_buf"; + let description = [{ + Internal op materialized by `PTOViewToMemref` while lowering + `pto.multi_tile_get`. It carries the slot SSA index through the memref + layer so that PlanMemory, sync analysis (InsertSync / GraphSyncSolver), + and the buffer-select lowering pass can identify which physical slot + this memref reference touches. + + The op is metadata-only (no data movement, no extra address arithmetic); + its result memref aliases the source memref byte-for-byte. Frontends do + not produce this op directly -- use `pto.multi_tile_get` instead.
+ }]; + + let arguments = (ins + AnyMemRef:$source, + Index:$slot + ); + + let results = (outs AnyMemRef:$result); + + let assemblyFormat = [{ + $source `[` $slot `]` attr-dict + `:` qualified(type($source)) `->` qualified(type($result)) + }]; + + let extraClassDeclaration = [{ + ::mlir::Value getViewSource() { return getSource(); } + }]; +} + // ============================================================================= // System/Runtime Query Ops // ============================================================================= diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index 6b8c0ee5c..134b1249a 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -221,6 +221,45 @@ def TileBufType : TypeDef { }]; } +// ============================================================================= +// MultiTileBufType +// ============================================================================= +// Corresponding IR: !pto.multi_tile_buf<, count=N> +// +// Represents a logical tile buffer made up of N physical slots. Each slot is +// structurally identical to `slotType` (same shape / dtype / loc / valid / +// config). This type can only be produced by `pto.alloc_multi_tile`; +// `pto.multi_tile_get [%k]` extracts the k-th slot and yields an ordinary +// `tile_buf`, so all downstream ops still see only `tile_buf`. +// ============================================================================= +def MultiTileBufType : TypeDef { + let mnemonic = "multi_tile_buf"; + let summary = "An array of N physical slots sharing one tile_buf shape"; + let description = [{ + A `multi_tile_buf` value owns `count` physically-distinct slots of an + inner `tile_buf` shape. The slot type is identical across all N slots; + only the underlying physical address differs. + + Multi-buffer is opaque to all existing `tile_buf` ops: extracting a slot + via `pto.multi_tile_get` yields a regular `tile_buf` that can be passed + to any DMA / compute / view op unchanged.
The slot index expression is + user-provided (constant or any SSA `index`); ptoas does not synthesize + `iv mod N` for users. + }]; + + let parameters = (ins + "mlir::pto::TileBufType":$slotType, + "uint32_t":$count + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + /// Return the inner per-slot tile_buf type. + mlir::pto::TileBufType getTileBufType() const { return getSlotType(); } + }]; +} + def EventIdArrayType : TypeDef { let mnemonic = "eventid_array"; let summary = "Manual-only local array for dynamic event ids"; diff --git a/include/PTO/Transforms/GraphSyncSolver/SyncSolver.h b/include/PTO/Transforms/GraphSyncSolver/SyncSolver.h index 257bca956..12ea9798e 100644 --- a/include/PTO/Transforms/GraphSyncSolver/SyncSolver.h +++ b/include/PTO/Transforms/GraphSyncSolver/SyncSolver.h @@ -245,6 +245,10 @@ class Solver { llvm::SmallVector> checkMemoryConflicts(RWOperation *rwOp1, RWOperation *rwOp2); + llvm::SmallVector> + checkMemoryConflictsForOcc(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2); + bool checkMemoryConflictBetweenOccExclusive( Occurrence *occ1, Occurrence *occ2, std::function filter = [](RWOperation *) { diff --git a/include/PTO/Transforms/GraphSyncSolver/SyncSolverIR.h b/include/PTO/Transforms/GraphSyncSolver/SyncSolverIR.h index ee780691e..0a6fcf43c 100644 --- a/include/PTO/Transforms/GraphSyncSolver/SyncSolverIR.h +++ b/include/PTO/Transforms/GraphSyncSolver/SyncSolverIR.h @@ -438,6 +438,13 @@ class SetWaitOp : public SyncOp { bool allAtOnce{false}; bool checkFirstIter{false}; bool checkLastIter{false}; + // Slot SSA at this access site. Populated when the sync corresponds to a + // multi-buffer back-edge dep produced by `pto.multi_tile_get` / lowered + // `pto.slot_marker`. 
When non-null and `eventIds.size() > 1`, codegen + // emits `pto.set_flag_dyn` / `pto.wait_flag_dyn` with a runtime event id + // selected by `slotSSAExpr % eventIds.size()` instead of fanning out + // into N static `set_flag` / `wait_flag` pairs per iteration. + mlir::Value slotSSAExpr; SetWaitOp(const OpType &opType, Operation *op, OperationBase *parentOp, const llvm::SmallVector &eventIds, pto::PIPE pipeSrc, diff --git a/include/PTO/Transforms/InsertSync/PTOIRTranslator.h b/include/PTO/Transforms/InsertSync/PTOIRTranslator.h index 9f329e517..185c4a8cb 100644 --- a/include/PTO/Transforms/InsertSync/PTOIRTranslator.h +++ b/include/PTO/Transforms/InsertSync/PTOIRTranslator.h @@ -79,6 +79,7 @@ class PTOIRTranslator { void UpdateConservativeAliasBufferInfo(Value result, Value source); void UpdateMemrefSubViewAliasBufferInfo(memref::SubViewOp op); void UpdateTileSubViewAliasBufferInfo(pto::SubViewOp op); + void UpdateSlotMarkerAliasBufferInfo(pto::SlotMarkerOp op); // --- Control-flow handling (SCF) --- void UpdateForOpInfo(scf::ForOp forOp); diff --git a/include/PTO/Transforms/InsertSync/SyncCommon.h b/include/PTO/Transforms/InsertSync/SyncCommon.h index 09aa4dc9d..0016e9051 100644 --- a/include/PTO/Transforms/InsertSync/SyncCommon.h +++ b/include/PTO/Transforms/InsertSync/SyncCommon.h @@ -159,6 +159,11 @@ class SyncOperation { SmallVector depRootBuffers; bool uselessSync{false}; int eventIdNum{1}; + // For multi-buffer dyn-event sync: the slot SSA expression at this access + // site. set_flag_dyn / wait_flag_dyn use `slotSSAExpr % slotCount` as the + // hardware event-id index. Empty when this sync is single-buffer.
+ Value slotSSAExpr; + uint32_t slotCount{1}; Value lowestCommonAncestorBuffer{nullptr}; int reuseCntForWiden{0}; bool reallocatedLoopHeadTailSync{false}; diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 06667fdeb..e0612b245 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -68,6 +68,7 @@ createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {}); std::unique_ptr createPTORemoveRedundantBarrierPass(); std::unique_ptr createPTOViewToMemrefPass(); std::unique_ptr createPTOMaterializeTileHandlesPass(); +std::unique_ptr createPTOResolveBufferSelectPass(); std::unique_ptr createInferPTOLayoutPass(); std::unique_ptr createPTOA5NormalizeTMovPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 444efe268..91f1770a6 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -316,4 +316,35 @@ def PTOMaterializeTileHandles : Pass<"pto-materialize-tile-handles", "ModuleOp"> ]; } +def PTOResolveBufferSelect : Pass<"pto-resolve-buffer-select", "ModuleOp"> { + let summary = "Lower pto.slot_marker to single-slot pointer_cast (multi-buffer)"; + let description = [{ + Consumes `pto.slot_marker %src[%k]` ops written by PTOViewToMemref to + thread multi-buffer slot selection through the memref layer. + + The op is replaced by a fresh single-address `pto.pointer_cast` that + refers to the chosen physical slot: + + - constant slot `k`: a single `pto.pointer_cast(addrK)` is emitted at + the use site, using slot k from the underlying multi-address + pointer_cast created by PTOPlanMemory / AllocToPointerCast. + - dynamic slot `%k`: per-slot single-address `pto.pointer_cast`s are + created and an N-way `arith.select` chain picks one according to the + user-supplied SSA. This pass does NOT rewrite `%k` into `iv mod N`; + the frontend expression is the slot selector. 
+ + The original multi-address `pto.pointer_cast` is kept in IR as the + "alloc anchor" so future sync passes can recognize the multi-buffer + geometry. + }]; + + let constructor = "mlir::pto::createPTOResolveBufferSelectPass()"; + + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect" + ]; +} + #endif // MLIR_DIALECT_PTO_PASSES diff --git a/include/PTO/Transforms/SlotAffineAnalysis.h b/include/PTO/Transforms/SlotAffineAnalysis.h new file mode 100644 index 000000000..4e265e77a --- /dev/null +++ b/include/PTO/Transforms/SlotAffineAnalysis.h @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- SlotAffineAnalysis.h - Multi-buffer slot affine compare --*- C++ -*-===// +// +// Small affine helper used by the multi-buffer sync path. Both InsertSync +// and GraphSyncSolver consume it to decide, for two `pto.slot_marker` +// slot-index SSA expressions, whether they are provably equal modulo N, +// provably disjoint modulo N, or indeterminate. The result lets sync +// shrink event-id count or skip same-iter forward syncs entirely when +// producer and consumer touch different slots in every iteration. 
+// +//===----------------------------------------------------------------------===// + +#ifndef PTO_TRANSFORMS_SLOTAFFINEANALYSIS_H +#define PTO_TRANSFORMS_SLOTAFFINEANALYSIS_H + +#include "mlir/IR/Value.h" +#include + +namespace mlir { +namespace pto { + +/// Three-valued relation between two multi-buffer slot SSA expressions +/// taken modulo `N`. Anything we cannot prove statically degrades to +/// `kUnknown`, which the callers treat conservatively (i.e. fall back to +/// the existing all-slots-may-overlap path). +enum class SlotRelation { + kEqual, // a(iv) == b(iv) (mod N) for every iv + kDisjoint, // a(iv) != b(iv) (mod N) for every iv + kUnknown, // can neither prove equal nor disjoint +}; + +/// Walk back through metadata-only ops (`pto.bind_tile`) to the nearest +/// `pto.slot_marker` and return its slot SSA value. Returns a null Value +/// if the chain does not pass through a slot_marker. +mlir::Value findSlotMarkerExpr(mlir::Value v); + +/// Compare two slot SSA expressions modulo `N`. The analysis is +/// intentionally narrow: it accepts the forms commonly produced by +/// frontends and lowerings (`iv % N`, `(iv + c) % N`, `c`, and same-SSA +/// equality) and bails to `kUnknown` for anything else. 
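Reviewer sketch (not part of the patch): the modulo-N comparison documented for `compareSlotSSA` can be modeled without MLIR as below. `SlotExpr`, its `sym` / `offset` fields, and `compareSlots` are hypothetical stand-ins for the matched `(iv + c) % N` and constant forms; the real implementation in `SlotAffineAnalysis.cpp` is not shown in this diff and may differ. The sketch assumes mathematical (non-negative) modulo.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model of a slot expression `(sym + offset) % N`.
// `sym` identifies the induction variable; std::nullopt means a pure constant.
struct SlotExpr {
  std::optional<int> sym;
  int64_t offset;
};

enum class SlotRelation { kEqual, kDisjoint, kUnknown };

// Compare two slot expressions modulo n. Expressions over different symbols
// (or a symbol versus a constant) cannot be related and yield kUnknown.
SlotRelation compareSlots(const SlotExpr &a, const SlotExpr &b, uint32_t n) {
  if (n == 0 || a.sym != b.sym)
    return SlotRelation::kUnknown;
  const int64_t m = static_cast<int64_t>(n);
  // Normalize the offset difference into [0, m).
  const int64_t diff = (((a.offset - b.offset) % m) + m) % m;
  return diff == 0 ? SlotRelation::kEqual : SlotRelation::kDisjoint;
}
```

With the same symbol on both sides, the two slots coincide in every iteration exactly when the offsets differ by a multiple of N; otherwise they never coincide, which is what lets the sync passes drop the same-iteration dependency.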
+/// +/// Examples (all with N == 2): +/// compareSlotSSA(%iv % 2, %iv % 2) -> kEqual +/// compareSlotSSA((%iv + 1) % 2, %iv % 2) -> kDisjoint +/// compareSlotSSA((%iv + 3) % 2, %iv % 2) -> kDisjoint // 3 % 2 == 1 +/// compareSlotSSA((%iv + 2) % 2, %iv % 2) -> kEqual // 2 % 2 == 0 +/// compareSlotSSA(%iv % 2, %j % 2) -> kUnknown // diff symbols +/// compareSlotSSA(arith.constant 0, arith.constant 1) -> kDisjoint +SlotRelation compareSlotSSA(mlir::Value a, mlir::Value b, uint32_t N); + +} // namespace pto +} // namespace mlir + +#endif // PTO_TRANSFORMS_SLOTAFFINEANALYSIS_H diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 721314590..193f2c8d0 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -10,6 +10,7 @@ //===----------------------------------------------------------------------===// #include "PTO/IR/PTO.h" +#include "PTO/IR/PTOMultiBuffer.h" #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/PTOSyncUtils.h" @@ -2246,6 +2247,87 @@ LogicalResult AllocTileOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AllocMultiTileOp / MultiTileGetOp +//===----------------------------------------------------------------------===// + +LogicalResult AllocMultiTileOp::verify() { + auto mtbTy = getResult().getType(); + if (!mtbTy) + return emitOpError("result must be `!pto.multi_tile_buf`"); + + TileBufType slotTy = mtbTy.getSlotType(); + if (!slotTy) + return emitOpError("multi_tile_buf slot type must be non-null"); + + // Reuse the AllocTileOp valid_row/valid_col contract on the slot type. 
+ Type elemTy = slotTy.getElementType(); + if (isPTOLowPrecisionType(elemTy)) + return emitOpError() << "slot dtype " << elemTy + << " is not supported by pto.alloc_multi_tile yet"; + + if (failed(verifyTileBufLayoutConstraints(*this, slotTy, "slot"))) + return failure(); + + bool hasVR = getValidRow() != nullptr; + bool hasVC = getValidCol() != nullptr; + auto vs = slotTy.getValidShape(); + if (vs.size() != 2) + return emitOpError("slot tile_buf must have rank-2 validShape"); + + bool needVR = (vs[0] < 0); + bool needVC = (vs[1] < 0); + if (hasVR != needVR) + return emitOpError() << "valid_row operand " + << (needVR ? "is required" : "must be absent") + << " because slot v_row is " + << (needVR ? "?" : std::to_string(vs[0])); + if (hasVC != needVC) + return emitOpError() << "valid_col operand " + << (needVC ? "is required" : "must be absent") + << " because slot v_col is " + << (needVC ? "?" : std::to_string(vs[1])); + + // Count bounds are also enforced by MultiTileBufType::verify, but repeat + // here so the error points at the alloc op the user wrote. + uint32_t count = mtbTy.getCount(); + if (count < kPtoMultiBufferMinNum || count > kPtoMultiBufferMaxNum) { + return emitOpError() << "multi_tile_buf count must be in [" + << kPtoMultiBufferMinNum << ", " + << kPtoMultiBufferMaxNum << "] (got " << count << ")"; + } + + return success(); +} + +LogicalResult MultiTileGetOp::verify() { + auto srcTy = getSource().getType(); + auto resultTy = getResult().getType(); + if (!srcTy || !resultTy) + return emitOpError("source and result types must be non-null"); + + if (srcTy.getSlotType() != resultTy) { + return emitOpError() + << "result tile_buf must match the multi_tile_buf slot type: " + << "expected " << srcTy.getSlotType() << ", got " << resultTy; + } + + // If slot is an `arith.constant`, check it is in range. 
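Reviewer sketch (not part of the patch): the two range checks performed by the verifiers, stripped of MLIR types. `kMinSlots` / `kMaxSlots` are hypothetical names; their values 2 and 16 are taken from the `[2, 16]` range stated in the manual and are assumed to match `kPtoMultiBufferMinNum` / `kPtoMultiBufferMaxNum`. `checkMultiTileGet` is an illustrative helper, not a function in the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Assumed bounds, mirroring the documented `2 <= count <= 16`.
constexpr uint32_t kMinSlots = 2;
constexpr uint32_t kMaxSlots = 16;

// Returns an error message, or std::nullopt when the combination is valid.
// `constSlot` is set only when the slot index is a compile-time constant;
// runtime slot values are not range-checked, as in the patch.
std::optional<std::string> checkMultiTileGet(uint32_t count,
                                             std::optional<int64_t> constSlot) {
  if (count < kMinSlots || count > kMaxSlots)
    return std::string("multi_tile_buf count out of range");
  if (constSlot &&
      (*constSlot < 0 || *constSlot >= static_cast<int64_t>(count)))
    return std::string("constant slot out of range");
  return std::nullopt;
}
```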
+ if (auto slotDef = getSlot().getDefiningOp()) { + if (auto attr = llvm::dyn_cast(slotDef.getValue())) { + int64_t slotVal = attr.getValue().getSExtValue(); + int64_t count = static_cast(srcTy.getCount()); + if (slotVal < 0 || slotVal >= count) { + return emitOpError() + << "constant slot " << slotVal + << " is out of range for multi_tile_buf count=" << count; + } + } + } + + return success(); +} + LogicalResult MaterializeTileOp::verify() { auto sourceTy = cast(getSource().getType()); auto resultTy = cast(getResult().getType()); diff --git a/lib/PTO/IR/PTOTypeDefs.cpp b/lib/PTO/IR/PTOTypeDefs.cpp index a3d5ab596..7291d9e8f 100644 --- a/lib/PTO/IR/PTOTypeDefs.cpp +++ b/lib/PTO/IR/PTOTypeDefs.cpp @@ -8,6 +8,7 @@ //===- PTOTypeDefs.cpp --------------------------------------------*- C++ -*-===// #include "PTO/IR/PTO.h" +#include "PTO/IR/PTOMultiBuffer.h" #include "mlir/IR/DialectImplementation.h" #include #include @@ -294,9 +295,16 @@ static LogicalResult parseLegacyTileBufFields(AsmParser &parser, return success(); } +// When `outMultiCount` is non-null, the parser is willing to consume an +// optional trailing `, count=N` clause as belonging to a wrapping +// `multi_tile_buf` instead of treating it as an unknown tile_buf field. The +// extracted N is written into `*outMultiCount` and the loop exits without +// consuming additional fields. When `outMultiCount` is null, `count` is +// treated as an unknown key (preserving the original tile_buf semantics). static LogicalResult parseCompactTileBufFields(AsmParser &parser, StringRef firstToken, - ParsedTileBufFields &fields) { + ParsedTileBufFields &fields, + uint32_t *outMultiCount = nullptr) { fields.locStr = firstToken.str(); if (failed(parser.parseComma())) @@ -430,6 +438,14 @@ static LogicalResult parseCompactTileBufFields(AsmParser &parser, continue; } + if (outMultiCount && key == "count") { + // Tail field belonging to a wrapping multi_tile_buf<...>. 
Consume the + // integer and return success; the wrapper finishes the parse. + if (failed(parseTileBufUInt32Value(parser, key, *outMultiCount))) + return failure(); + return success(); + } + parser.emitError(parser.getCurrentLocation(), "unknown key in tile_buf compact syntax: ") << key; @@ -657,3 +673,112 @@ void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const { printer << ">"; } + +// ---- MultiTileBufType custom asm ---- +// +// Syntax: +// +// Verbose form: +// !pto.multi_tile_buf, count=N> +// +// Compact (sugar) form: +// !pto.multi_tile_buf +// +// In the compact form the per-slot tile_buf is built from the same compact +// syntax as `!pto.tile_buf`, followed by a mandatory `count=N`. + +LogicalResult MultiTileBufType::verify( + function_ref emitError, + mlir::pto::TileBufType slotType, uint32_t count) { + if (!slotType) { + return emitError() << "multi_tile_buf slot type must be non-null"; + } + if (count < kPtoMultiBufferMinNum) { + return emitError() << "multi_tile_buf count must be >= " + << kPtoMultiBufferMinNum << " (got " << count << ")"; + } + if (count > kPtoMultiBufferMaxNum) { + return emitError() << "multi_tile_buf count must be <= " + << kPtoMultiBufferMaxNum << " (got " << count << ")"; + } + return success(); +} + +namespace { +// Parse a trailing `, count = N` clause. The caller has already parsed the +// per-slot tile_buf description; we must now consume the count and the +// closing `>`. 
+static LogicalResult parseMultiTileBufCount(AsmParser &parser, + uint32_t &count) { + if (failed(parser.parseComma())) + return failure(); + if (failed(parser.parseKeyword("count"))) + return failure(); + if (failed(parser.parseEqual())) + return failure(); + uint32_t parsed = 0; + if (failed(parseTileBufUInt32Value(parser, "count", parsed))) + return failure(); + count = parsed; + return success(); +} +} // namespace + +Type MultiTileBufType::parse(AsmParser &parser) { + if (failed(parser.parseLess())) + return Type(); + + MLIRContext *ctx = parser.getContext(); + TileBufType slotType; + uint32_t count = 0; + bool countConsumedByCompact = false; + + // Verbose form: an explicit `!pto.tile_buf<...>` type token comes next. + // Compact form: a bare keyword (loc such as `vec`/`mat`/...) comes next. + Type maybeType; + OptionalParseResult typeRes = parser.parseOptionalType(maybeType); + if (typeRes.has_value()) { + if (failed(*typeRes)) + return Type(); + slotType = llvm::dyn_cast(maybeType); + if (!slotType) { + parser.emitError(parser.getCurrentLocation(), + "multi_tile_buf slot type must be `!pto.tile_buf<...>`"); + return Type(); + } + } else { + // Compact form: parse via the same compact path used by tile_buf, but + // tell it to consume the trailing `, count=N` on our behalf. 
+ std::string firstToken; + if (failed(parser.parseKeywordOrString(&firstToken))) + return Type(); + + ParsedTileBufFields fields; + if (failed(parseCompactTileBufFields(parser, firstToken, fields, &count))) + return Type(); + + Type built = buildTileBufType(parser, fields); + if (!built) + return Type(); + slotType = llvm::cast(built); + countConsumedByCompact = (count != 0); + } + + if (!countConsumedByCompact) { + if (failed(parseMultiTileBufCount(parser, count))) + return Type(); + } + + if (failed(parser.parseGreater())) + return Type(); + + return getChecked( + [&]() { return parser.emitError(parser.getNameLoc()); }, ctx, slotType, + count); +} + +void MultiTileBufType::print(AsmPrinter &printer) const { + printer << "<"; + printer.printType(getSlotType()); + printer << ", count=" << getCount() << ">"; +} diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 9c1f7d22c..72837bf42 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -31,7 +31,9 @@ add_mlir_dialect_library(PTOTransforms PTOAssignDefaultFrontendPipeIdPass.cpp PTOLowerFrontendPipeOpsPass.cpp PTOInferValidatePipeInitPass.cpp + PTOResolveBufferSelect.cpp PTOResolveReservedBuffersPass.cpp + SlotAffineAnalysis.cpp PTOWrapFunctionsInSectionsPass.cpp InsertSync/PTOIRTranslator.cpp InsertSync/SyncCommon.cpp diff --git a/lib/PTO/Transforms/GraphSyncSolver/MemInfo.cpp b/lib/PTO/Transforms/GraphSyncSolver/MemInfo.cpp index 50d0024cb..daf6971e9 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/MemInfo.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/MemInfo.cpp @@ -14,6 +14,7 @@ #include "../Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/Value.h" #include "llvm/Support/ErrorHandling.h" #include @@ -70,11 +71,77 @@ PointerLikeInfo getPointerLikeInfo(pto::PointerCastOp pointerCastOp) { return pointerLikeInfo; } +// Walk back through metadata-only view 
ops (`pto.bind_tile`) to the +// nearest `pto.pointer_cast`. Used to anchor slot_marker MemInfo on its +// underlying multi-address alloc cast. +static pto::PointerCastOp findUnderlyingPointerCast(Value v) { + int hops = 0; + while (v && hops++ < 32) { + Operation *op = v.getDefiningOp(); + if (!op) + return {}; + if (auto pc = llvm::dyn_cast(op)) + return pc; + if (auto bind = llvm::dyn_cast(op)) { + v = bind.getSource(); + continue; + } + return {}; + } + return {}; +} + +// Build a MemInfo for a `pto.slot_marker` use. For a constant slot K the +// MemInfo carries just slot K's physical address so two const-slot +// accesses on different slots come back as non-conflicting via the +// existing `PointerLikeInfo::checkConflict` byte-range overlap. For a +// dynamic slot the MemInfo carries all N physical addresses; downstream +// `checkMultiBufferEventIdInfo` then deduces N event ids using the +// `(i % N) == (j % N)` slot-skipping rule, which is exactly the +// multi-buffer prefetch pattern. +// +// Note on `parentLoop`: `getPointerLikeInfo` records the parent loop of +// the cast op, which is typically outside the multi-buffer scf.for (the +// alloc/cast lives at function scope). The multi-buffer geometry, though, +// is keyed by the loop that *uses* the slot. We override `parentLoop` +// with the slot_marker's enclosing LoopLikeOpInterface so +// `getMultiBufferLoop` finds the right anchor. 
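Reviewer sketch (not part of the patch): the address-narrowing rule described in the comment above, modeled on plain vectors. `narrowSlotAddresses` is a hypothetical helper name; the patch performs the same narrowing inline on `PointerLikeInfo::addresses`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// For a constant, in-range slot on a multi-address buffer, keep only that
// slot's physical address. For a dynamic slot (std::nullopt), an out-of-range
// constant, or a single-address buffer, keep the full address list.
std::vector<int64_t> narrowSlotAddresses(std::vector<int64_t> addresses,
                                         std::optional<int64_t> constSlot) {
  if (constSlot && addresses.size() > 1 && *constSlot >= 0 &&
      *constSlot < static_cast<int64_t>(addresses.size())) {
    return {addresses[static_cast<std::size_t>(*constSlot)]};
  }
  return addresses;
}
```

Narrowing to one address is what allows two constant-slot accesses on different slots to come back as non-conflicting from the byte-range overlap check, while dynamic accesses stay conservative.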
+static MemInfo getMemInfoForSlotMarker(pto::SlotMarkerOp slotMarker) { + pto::PointerCastOp castOp = findUnderlyingPointerCast(slotMarker.getSource()); + if (!castOp) { + return MemInfo(slotMarker.getResult(), + isWorkSpaceFuncArgument(slotMarker.getResult())); + } + + PointerLikeInfo info = getPointerLikeInfo(castOp); + + IntegerAttr constSlotAttr; + if (matchPattern(slotMarker.getSlot(), m_Constant(&constSlotAttr)) && + info.addresses.size() > 1) { + int64_t slotIdx = constSlotAttr.getValue().getSExtValue(); + if (slotIdx >= 0 && slotIdx < static_cast(info.addresses.size())) { + int64_t picked = info.addresses[static_cast(slotIdx)]; + info.addresses.clear(); + info.addresses.push_back(picked); + } + } + + if (auto useLoop = + slotMarker->template getParentOfType()) { + info.parentLoop = useLoop; + } + + return MemInfo(slotMarker.getResult(), info); +} + MemInfo getMemInfo(Value val) { if (auto *defOp = val.getDefiningOp()) { if (auto pointerCastOp = llvm::dyn_cast(defOp)) { return MemInfo(val, getPointerLikeInfo(pointerCastOp)); } + if (auto slotMarker = llvm::dyn_cast(defOp)) { + return getMemInfoForSlotMarker(slotMarker); + } } return MemInfo(val, isWorkSpaceFuncArgument(val)); } diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp index 23a4032a6..41935c1db 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp @@ -14,6 +14,7 @@ #include "PTO/Transforms/GraphSyncSolver/MemInfo.h" #include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" #include "PTO/Transforms/GraphSyncSolver/Utility.h" +#include "PTO/Transforms/SlotAffineAnalysis.h" #include "PTO/IR/PTO.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -30,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -40,6 +42,42 @@ using namespace mlir; using namespace pto::syncsolver; +namespace { +// Pick a slot SSA expression that this access touches, scanning the +// 
rwOp's read + write memrefs. Returns null if none of them reach a +// slot_marker -- in that case codegen falls back to the existing N-static +// `set_flag` / `wait_flag` fanout. +mlir::Value findSlotSSAExprForRWOp(RWOperation *rwOp) { + if (!rwOp) + return {}; + for (auto &v : rwOp->readMemVals) + if (auto slot = mlir::pto::findSlotMarkerExpr(v)) + return slot; + for (auto &v : rwOp->writeMemVals) + if (auto slot = mlir::pto::findSlotMarkerExpr(v)) + return slot; + return {}; +} + +mlir::pto::SlotRelation compareMemInfoSlotSSA(const MemInfo &memInfo1, + const MemInfo &memInfo2) { + size_t n = std::max(memInfo1.getSz(), memInfo2.getSz()); + if (n < 2 || n > std::numeric_limits::max()) + return mlir::pto::SlotRelation::kUnknown; + Value slot1 = mlir::pto::findSlotMarkerExpr(memInfo1.value); + Value slot2 = mlir::pto::findSlotMarkerExpr(memInfo2.value); + if (!slot1 || !slot2) + return mlir::pto::SlotRelation::kUnknown; + return mlir::pto::compareSlotSSA(slot1, slot2, static_cast(n)); +} + +bool isForwardDepDroppableBySlotAffine(const MemInfo &memInfo1, + const MemInfo &memInfo2) { + return compareMemInfoSlotSSA(memInfo1, memInfo2) == + mlir::pto::SlotRelation::kDisjoint; +} +} // namespace + // Reset per-pass bookkeeping to start fresh. void Solver::reset(bool resetEventIdRanOutOpts) { if (resetEventIdRanOutOpts) { @@ -294,6 +332,63 @@ Solver::checkMemoryConflicts(RWOperation *rwOp1, RWOperation *rwOp2) { return it->second = collectedConflicts; } +llvm::SmallVector> +Solver::checkMemoryConflictsForOcc(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2) { + assert(occ1 != nullptr && occ2 != nullptr); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + if (isBackwardSync(occ1, occ2)) { + return checkMemoryConflicts(rwOp1, rwOp2); + } + + auto coreSrc = rwOp1->coreType; + auto coreDst = rwOp2->coreType; + if (options.isCrossCoreMode()) { + if (coreDst == pto::TCoreType::CUBE_AND_VECTOR) { + coreDst = (coreSrc == pto::TCoreType::VECTOR) ? 
pto::TCoreType::CUBE + : pto::TCoreType::VECTOR; + } + assert(coreSrc == pto::TCoreType::VECTOR || + coreSrc == pto::TCoreType::CUBE); + assert(coreDst == pto::TCoreType::VECTOR || + coreDst == pto::TCoreType::CUBE); + } + + auto hasForwardConflict = + [&](const llvm::SmallVector &memInfoList1, + const llvm::SmallVector &memInfoList2) -> bool { + for (auto &memInfo1 : memInfoList1) { + for (auto &memInfo2 : memInfoList2) { + if (!checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { + continue; + } + if (isForwardDepDroppableBySlotAffine(memInfo1, memInfo2)) { + continue; + } + return true; + } + } + return false; + }; + + llvm::SetVector> collectedConflictsSet; + if (hasForwardConflict(rwOp1->readMemInfo, rwOp2->writeMemInfo)) { + collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeRead), + CorePipeInfo(coreDst, rwOp2->pipeWrite)}); + } + if (hasForwardConflict(rwOp1->writeMemInfo, rwOp2->readMemInfo)) { + collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), + CorePipeInfo(coreDst, rwOp2->pipeRead)}); + } + if (hasForwardConflict(rwOp1->writeMemInfo, rwOp2->writeMemInfo)) { + collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), + CorePipeInfo(coreDst, rwOp2->pipeWrite)}); + } + llvm::SmallVector> collectedConflicts( + collectedConflictsSet.begin(), collectedConflictsSet.end()); + return collectedConflicts; +} + bool Solver::checkMemoryConflictBetweenOccExclusive( Occurrence *occ1, Occurrence *occ2, std::function filter) { @@ -452,6 +547,12 @@ Solver::getMultiBufferEventIdInfo(Occurrence *occ1, Occurrence *occ2, minWriteSize = 1; return {}; } + // (Same-SSA equal-slot accesses used to early-bail here, falling back to + // a single static event id. That diverged from InsertSync, which still + // allocates N dyn event ids for same-SSA prefetch so different + // iterations touching different physical slots can pipeline. 
The bail + // was removed for consistency; the standard N-way deduction below now + // also runs for same-SSA pairs.) int64_t eventIdNum = minWriteSize; for (; eventIdNum >= 1; eventIdNum--) { @@ -2258,6 +2359,14 @@ SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { waitOp->eventIdInfo = conflictPair->eventIdInfo; setOp->checkLastIter = conflictPair->setOnLastIterOnly; waitOp->checkFirstIter = conflictPair->waitOnFirstIterOnly; + // For multi-buffer back-edge syncs, plumb each side's slot SSA. The + // set side sits at op1 (producer); the wait side sits at op2 + // (consumer). Codegen uses these to lower into `pto.set_flag_dyn` / + // `pto.wait_flag_dyn`. + if (conflictPair->eventIdInfo.eventIdNum > 1) { + setOp->slotSSAExpr = findSlotSSAExprForRWOp(conflictPair->op1); + waitOp->slotSSAExpr = findSlotSSAExprForRWOp(conflictPair->op2); + } LLVM_DEBUG({ setOp->debugId = conflictPair->id; waitOp->debugId = conflictPair->id; @@ -2311,7 +2420,8 @@ SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, RWOperation *rwOp1, RWOperation *rwOp2, bool isUseless) { - for (auto [corePipeSrc, corePipeDst] : checkMemoryConflicts(rwOp1, rwOp2)) { + for (auto [corePipeSrc, corePipeDst] : + checkMemoryConflictsForOcc(occ1, occ2, rwOp1, rwOp2)) { if (options.alwaysUsePipeSAsWaitingPipe) { corePipeDst.pipe = pto::PIPE::PIPE_S; } diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp index eaca68a5b..553a8adfd 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp @@ -12,6 +12,7 @@ #include "PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.h" #include "PTO/IR/PTO.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "llvm/Support/Casting.h" @@ -116,14 +117,53 @@ void CodeGenerator::emitSyncOp(IRRewriter &rewriter, SyncOp *syncOp) { 
assert(!setWait->checkLastIter && "checkLastIter wrapping not implemented in codegen"); - // One set/wait op per assigned event id. The current solver only assigns - // a single id per node, but the codegen handles multi-id assignments so a - // future multi-buffer pass can plug in without re-touching this layer. auto srcAttr = makePipe(rewriter.getContext(), setWait->pipeSrc); auto dstAttr = makePipe(rewriter.getContext(), setWait->pipeDst); Location loc = resolveSyncLoc(setWait); bool isSet = isa(setWait); bool isWait = isa(setWait); + + // Multi-buffer dyn-event path: when the sync was produced by a + // multi-buffer back-edge (eventIds.size() > 1 AND we have a slot SSA), + // collapse the per-slot fanout into a single `pto.set_flag_dyn` / + // `pto.wait_flag_dyn`. The hardware event id is selected at runtime by + // an N-way `arith.select` chain over the allocated event ids keyed off + // `slotSSAExpr % N`. The `allAtOnce` scopes (pre-loop prime / post-loop + // drain) keep their static fanout so each per-slot event gets primed / + // drained exactly once. 
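Reviewer sketch (not part of the patch): the runtime effect of the N-way select chain described above, written as ordinary C++ over integers. `selectEventId` is a hypothetical name; the sketch assumes a non-negative slot value, since which remainder op the codegen emits is not visible in this diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Start from eventIds[0] and override with eventIds[i] whenever
// slot % N == i, mirroring a chain of N-1 compare + select ops.
int64_t selectEventId(int64_t slot, const std::vector<int64_t> &eventIds) {
  assert(eventIds.size() > 1 && "dyn path is only used for multi-id syncs");
  assert(slot >= 0 && "sketch assumes a non-negative slot index");
  const int64_t n = static_cast<int64_t>(eventIds.size());
  const int64_t slotMod = slot % n;
  int64_t selected = eventIds[0];
  for (int64_t i = 1; i < n; ++i)
    selected = (slotMod == i) ? eventIds[static_cast<std::size_t>(i)] : selected;
  return selected;
}
```

A chain of selects keeps the emitted IR branch-free; the cost grows linearly with N, which stays small given the documented upper bound on the slot count.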
+ if (!setWait->allAtOnce && setWait->eventIds.size() > 1 && + setWait->slotSSAExpr) { + int64_t n = static_cast(setWait->eventIds.size()); + Value slot = setWait->slotSSAExpr; + if (slot.getType() != rewriter.getIndexType()) { + slot = rewriter.create( + loc, rewriter.getIndexType(), slot); + } + Value nConst = rewriter.create(loc, n); + Value slotMod = rewriter.create(loc, slot, nConst); + + Value selected = rewriter.create( + loc, setWait->eventIds[0]); + for (int64_t i = 1; i < n; ++i) { + Value iIdx = rewriter.create(loc, i); + Value isThis = rewriter.create( + loc, arith::CmpIPredicate::eq, slotMod, iIdx); + Value idI = rewriter.create( + loc, setWait->eventIds[static_cast(i)]); + selected = rewriter.create(loc, isThis, idI, selected); + } + + if (isSet) + rewriter.create(loc, srcAttr, dstAttr, selected); + else if (isWait) + rewriter.create(loc, srcAttr, dstAttr, selected); + return; + } + + // Fallback / scope-anchored path: one static set/wait per assigned event + // id. Always used for `allAtOnce` prime/drain pairs and for syncs that + // could not have their slot SSA recovered (in which case we degrade to + // the conservative N-static fanout rather than dropping the dep). for (int64_t eventId : setWait->eventIds) { auto eventAttr = makeEvent(rewriter.getContext(), eventId); if (isSet) diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp index a4cc5dd09..3818e6ace 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp @@ -72,6 +72,15 @@ llvm::SmallVector IRTranslator::tracebackMemValsStep(Value val) { out.push_back(whileOp.getYieldedValues()[resultNo]); } + // Stop the walk at `pto.slot_marker` so the multi-buffer slot index is + // preserved for `getMemInfo`. 
Without this special case, the generic + // `getOperationAliasInfo` path below would treat slot_marker as a + // transparent view and let the trace fall through to the underlying + // multi-address `pto.pointer_cast`, dropping the slot. + if (isa(defOp)) { + return out; + } + if (auto alias = pto::getOperationAliasInfo(defOp)) { if (alias->first == result) out.push_back(alias->second); @@ -113,8 +122,13 @@ llvm::SmallVector IRTranslator::tracebackMemVals(Value val) { if (!result) continue; Operation *defOp = result.getDefiningOp(); + // `pto.slot_marker` is a multi-buffer slot tag and stops traversal so + // `getMemInfo` can extract slot-narrowed addresses below. Without this + // stop, `getOperationAliasInfo` would let the walk slip past slot_marker + // and reach the underlying multi-address `pto.pointer_cast`, dropping + // the slot index. if (isa(defOp)) { + memref::AllocOp, pto::SlotMarkerOp>(defOp)) { leaves.insert(result); continue; } diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp index 0d93b1d61..0fbc35b66 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp @@ -11,9 +11,11 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
+#include "PTO/IR/PTO.h" #include "PTO/Transforms/InsertSync/InsertSyncAnalysis.h" #include "PTO/IR/PTOTypeUtils.h" #include "PTO/Transforms/InsertSync/SyncCommon.h" +#include "PTO/Transforms/SlotAffineAnalysis.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypes.h" @@ -436,6 +438,29 @@ void InsertSyncAnalysis::InsertSync( MemAnalyze(nowCompound, frontCompound, syncRecordList, forEndIndex); } +// Returns true if a *same-iter* multi-buffer dep pair can be dropped +// because the producer's and consumer's slot SSA expressions are provably +// disjoint modulo N. Only applied to forward (non-back-edge) deps -- the +// back-edge path still needs to sync per-slot via dyn event id (the +// prefetch idiom). When the analysis is inconclusive (kUnknown / kEqual) +// the dep is kept and the existing conservative path runs. +static bool isForwardDepDroppableBySlotAffine(const BaseMemInfo *a, + const BaseMemInfo *b) { + if (!a || !b) + return false; + size_t aN = a->baseAddresses.size(); + size_t bN = b->baseAddresses.size(); + size_t n = std::max(aN, bN); + if (n < 2) + return false; + Value slotA = findSlotMarkerExpr(a->baseBuffer); + Value slotB = findSlotMarkerExpr(b->baseBuffer); + if (!slotA || !slotB) + return false; + return compareSlotSSA(slotA, slotB, static_cast(n)) == + SlotRelation::kDisjoint; +} + void InsertSyncAnalysis::MemAnalyze( CompoundInstanceElement *nowCompound, CompoundInstanceElement *frontCompound, SyncRecordList &syncRecordList, @@ -449,6 +474,21 @@ void InsertSyncAnalysis::MemAnalyze( return; } + // Same-iter (forward) deps: drop pairs that the affine analysis proves + // touch disjoint slots in every iteration of the multi-buffer loop. + // Back-edge deps stay untouched -- they still need per-slot syncing + // through the dyn-event-id pipeline. 
+ if (!forEndIndex.has_value()) { + auto isDroppable = [](const std::pair &pair) { + return isForwardDepDroppableBySlotAffine(pair.first, pair.second); + }; + depVec.erase(std::remove_if(depVec.begin(), depVec.end(), isDroppable), + depVec.end()); + if (depVec.empty()) + return; + } + if (CanPrunePipeVBarrier(nowCompound, frontCompound, depVec, forEndIndex)) { return; } @@ -575,9 +615,38 @@ void InsertSyncAnalysis::InsertSyncOperation( setOp->SetDepSyncIRIndex(frontCompound->GetIndex()); waitOp->SetDepSyncIRIndex(frontCompound->GetIndex()); - // Back-edge dependencies may require multi-buffer event IDs. + // Back-edge dependencies may require multi-buffer event IDs. When N + // dyn event IDs are warranted, also plumb the per-side slot SSA so + // codegen can lower into `pto.set_flag_dyn` / `pto.wait_flag_dyn`. if (forEndIndex.has_value()) { int eventIdNum = GetEventIdNum(depBaseMemInfosVec); + if (eventIdNum > 1) { + // Each dep pair has (now=consumer, front=producer). The producer's + // slot SSA gates the `set_flag_dyn`; the consumer's gates the + // `wait_flag_dyn`. Walk the first viable dep pair to extract them. + Value producerSlot; + Value consumerSlot; + for (auto &pair : depBaseMemInfosVec) { + if (pair.second && pair.second->baseBuffer) + producerSlot = findSlotMarkerExpr(pair.second->baseBuffer); + if (pair.first && pair.first->baseBuffer) + consumerSlot = findSlotMarkerExpr(pair.first->baseBuffer); + if (producerSlot && consumerSlot) + break; + } + if (!producerSlot || !consumerSlot) { + // No slot SSA threaded through -- fall back to single event id. + // This keeps non-multi-buffer codepaths untouched even if their + // baseAddresses happen to have multiple entries for some other + // reason (e.g. memref subview). 
+ eventIdNum = 1; + } else { + setOp->slotSSAExpr = producerSlot; + setOp->slotCount = static_cast(eventIdNum); + waitOp->slotSSAExpr = consumerSlot; + waitOp->slotCount = static_cast(eventIdNum); + } + } setOp->eventIdNum = eventIdNum; waitOp->eventIdNum = eventIdNum; } @@ -740,6 +809,16 @@ SmallVector InsertSyncAnalysis::GetMemInfoBuffers( int InsertSyncAnalysis::GetEventIdNum( const DepBaseMemInfoPairVec &depBaseMemInfosVec) { + // A back-edge dependency benefits from N dynamic event IDs whenever at + // least one side is a multi-buffer access. We detect that from the + // BaseMemInfo's `baseAddresses` size, which `UpdateSlotMarkerAliasBufferInfo` + // populated: + // - kSingle / const-slot : size == 1 + // - dyn-slot (PTOIRTranslator default) : size == N (all slots, conservative) + // For the alias to even reach this point both sides share a root, so the + // slot count derived from either side's full address set should be the + // same N. We pick the max to be robust against accidental narrowing. + int eventIdNum = 1; for (const auto &pair : depBaseMemInfosVec) { bool isLocalA = pair.first && (pair.first->scope == pto::AddressSpace::MAT || @@ -747,9 +826,23 @@ int InsertSyncAnalysis::GetEventIdNum( bool isLocalB = pair.second && (pair.second->scope == pto::AddressSpace::MAT || pair.second->scope == pto::AddressSpace::VEC); - if (isLocalA || isLocalB) return 1; + if (!isLocalA && !isLocalB) + continue; + size_t aN = pair.first ? pair.first->baseAddresses.size() : 1; + size_t bN = pair.second ? pair.second->baseAddresses.size() : 1; + int pairN = static_cast(std::max(aN, bN)); + if (pairN <= 1) + continue; + if (eventIdNum == 1) { + eventIdNum = pairN; + } else if (eventIdNum != pairN) { + // Multiple dep pairs disagreeing on N: fall back to single event id + // for safety. With more work this could be relaxed by per-pair + // multi-buffer reasoning. 
+ return 1; + } } - return 1; + return eventIdNum; } bool InsertSyncAnalysis::IsGMHazard( diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 1aad79749..46dc8085b 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -335,6 +335,9 @@ void PTOIRTranslator::RecursionIR(Region *region) { else if (auto memrefSubView = dyn_cast(op)) { UpdateMemrefSubViewAliasBufferInfo(memrefSubView); } + else if (auto slotMarker = dyn_cast(op)) { + UpdateSlotMarkerAliasBufferInfo(slotMarker); + } else if (auto castOp = dyn_cast(op)) { UpdateAliasBufferInfo(castOp.getResult(), castOp.getSource()); } @@ -345,6 +348,21 @@ void PTOIRTranslator::RecursionIR(Region *region) { else if (auto expandOp = dyn_cast(op)) { UpdateAliasBufferInfo(expandOp.getResult(), expandOp.getSrc()); } + // arith.select on memref values: emitted by `PTOResolveBufferSelect` + // for the dynamic-slot path of `pto.multi_tile_get`. The result + // conservatively aliases both branches so sync analysis cannot miss a + // dependency that arises through either possible slot. This preserves + // correctness for runtime slot indices; finer-grained "slot + // expressions are disjoint" reasoning is a future affine-analysis + // upgrade and would let sync emit fewer flags in prefetch idioms. 
+ else if (auto selectOp = dyn_cast(op)) { + if (isa(selectOp.getResult().getType())) { + UpdateConservativeAliasBufferInfo(selectOp.getResult(), + selectOp.getTrueValue()); + UpdateConservativeAliasBufferInfo(selectOp.getResult(), + selectOp.getFalseValue()); + } + } // --- Case C: control flow (SCF) --- else if (auto forOp = dyn_cast(op)) { @@ -443,7 +461,6 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) if (op.getAddrs().empty()) { return op.emitError("PointerCast must have at least one address operand"); } - Value rootSrc = op.getAddrs().front(); uint64_t sizeInBytes = 0; if (memRefType.hasStaticShape()) { @@ -460,14 +477,46 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) } } - auto newMemInfo = std::make_unique( - res, - rootSrc, - space, - SmallVector{0}, - sizeInBytes - ); + if (op.getAddrs().size() == 1) { + // Single-address path: keep the historical semantics where `rootBuffer` + // is the i64 address SSA and `baseAddresses` starts at {0}. Downstream + // view ops accumulate a delta into baseAddresses[0]; MemAlias gates on + // rootBuffer SSA identity for single-buffer allocations. + Value rootSrc = op.getAddrs().front(); + auto newMemInfo = std::make_unique( + res, rootSrc, space, SmallVector{0}, sizeInBytes); + buffer2MemInfoMap_[res].emplace_back(newMemInfo->clone()); + return success(); + } + // Multi-address (multi-buffer) cast. Use the cast result as `rootBuffer` + // so every downstream `pto.slot_marker` from the same alloc shares one + // root, and populate `baseAddresses` with each slot's physical offset + // (extracted from the constant i64 operands emitted by + // AllocToPointerCast). `pto.slot_marker` then narrows or keeps these + // offsets according to its slot SSA; MemAlias's existing + // `isBufferAddressRangeOverlap` does the per-slot disambiguation.
+ SmallVector slotOffsets; + slotOffsets.reserve(op.getAddrs().size()); + for (Value a : op.getAddrs()) { + auto cst = a.getDefiningOp(); + IntegerAttr attr; + if (!cst || !(attr = dyn_cast(cst.getValue()))) { + // Non-constant slot address: fall back to single-address semantics + // with the first operand as rootBuffer so existing non-multi-buffer + // codepaths that happen to feed non-constant i64s keep their + // historical behavior. + Value rootSrc = op.getAddrs().front(); + auto newMemInfo = std::make_unique( + res, rootSrc, space, SmallVector{0}, sizeInBytes); + buffer2MemInfoMap_[res].emplace_back(newMemInfo->clone()); + return success(); + } + slotOffsets.push_back(attr.getValue().getZExtValue()); + } + + auto newMemInfo = std::make_unique( + res, res, space, std::move(slotOffsets), sizeInBytes); buffer2MemInfoMap_[res].emplace_back(newMemInfo->clone()); return success(); } @@ -756,6 +805,41 @@ void PTOIRTranslator::UpdateConservativeAliasBufferInfo(Value result, resultMemInfoVec.emplace_back(parentInfo->clone(result)); } +void PTOIRTranslator::UpdateSlotMarkerAliasBufferInfo(pto::SlotMarkerOp op) { + Value result = op.getResult(); + Value source = op.getSource(); + if (!result || !source) + return; + if (!buffer2MemInfoMap_.contains(source)) + return; + + Value slot = op.getSlot(); + IntegerAttr constAttr; + bool isConstSlot = matchPattern(slot, m_Constant(&constAttr)); + int64_t constSlotIdx = isConstSlot ? constAttr.getValue().getSExtValue() : -1; + + auto &resultMemInfoVec = buffer2MemInfoMap_[result]; + for (auto &parentInfo : buffer2MemInfoMap_[source]) { + auto newInfo = parentInfo->clone(result); + // Multi-buffer parent: `baseAddresses` lists every physical slot's + // offset (populated by `UpdatePointerCastOpMemInfo` for multi-address + // casts). For a constant slot index, narrow `baseAddresses` to just + // that one slot so MemAlias's range-overlap check returns false when + // two const-slot uses pick different slots. 
For a dynamic slot index, + // keep all addresses -- the runtime SSA could resolve to any slot, so + // sync must conservatively treat the use as touching all of them. + if (isConstSlot && constSlotIdx >= 0 && + constSlotIdx < static_cast(newInfo->baseAddresses.size()) && + newInfo->baseAddresses.size() > 1) { + uint64_t pickAddr = + newInfo->baseAddresses[static_cast(constSlotIdx)]; + newInfo->baseAddresses.clear(); + newInfo->baseAddresses.push_back(pickAddr); + } + resultMemInfoVec.emplace_back(std::move(newInfo)); + } +} + void PTOIRTranslator::UpdateMemrefSubViewAliasBufferInfo(memref::SubViewOp op) { Value result = op.getResult(); Value source = op.getSource(); diff --git a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp index a88c86b39..8ef06d6c8 100644 --- a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp @@ -350,15 +350,58 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter, Operation *op, SyncOperation *sync, bool beforeInsert) { - Value bufferSelected = GetBufferSelected(rewriter, op, sync); - (void)bufferSelected; - auto srcPipe = getPipeAttr(rewriter, sync->GetActualSrcPipe()); auto dstPipe = getPipeAttr(rewriter, sync->GetActualDstPipe()); - auto eventId = getEventAttr(rewriter, sync->eventIds[0]); setSyncInsertionPoint(rewriter, op, beforeInsert || op->hasTrait()); - createSetOrWaitFlagOp(rewriter, op, sync, srcPipe, dstPipe, eventId); + Location loc = op->getLoc(); + + // If the analysis did not plumb a slot SSA (e.g. multi-buffer alloc + // present but the access reads it through an unexpected view chain), + // fall back to a static set/wait on the first event id. This preserves + // correctness at the cost of forgoing per-slot pipelining. 
+ if (!sync->slotSSAExpr) { + auto eventId = getEventAttr(rewriter, sync->eventIds[0]); + createSetOrWaitFlagOp(rewriter, op, sync, srcPipe, dstPipe, eventId); + return; + } + + // Emit pto.set_flag_dyn / pto.wait_flag_dyn with a runtime event id chosen + // from the slot SSA. Specifically, build an N-way select chain over the + // allocated event ids so the hardware sees event id `eventIds[slot % N]`. + // For the common N == 2 case this collapses to a single arith.select. + uint32_t n = sync->slotCount; + assert(n >= 2 && "multi-buffer codegen requires slotCount >= 2"); + assert(sync->eventIds.size() == n && + "multi-buffer codegen expects N event ids"); + + // Compute `slot % N` once and reuse across the chain. + Value nConst = rewriter.create(loc, n); + Value slot = sync->slotSSAExpr; + if (slot.getType() != rewriter.getIndexType()) { + slot = rewriter.create(loc, rewriter.getIndexType(), + slot); + } + Value slotMod = rewriter.create(loc, slot, nConst); + + // N-way select: start from eventIds[0] and chain `eq slotMod, i` picks + // through 1..N-1.
+ Value selected = + rewriter.create(loc, sync->eventIds[0]); + for (uint32_t i = 1; i < n; ++i) { + Value iIdx = rewriter.create(loc, i); + Value isThis = rewriter.create( + loc, arith::CmpIPredicate::eq, slotMod, iIdx); + Value idI = + rewriter.create(loc, sync->eventIds[i]); + selected = rewriter.create(loc, isThis, idI, selected); + } + + if (sync->isSyncWaitType()) { + rewriter.create(loc, srcPipe, dstPipe, selected); + } else { + rewriter.create(loc, srcPipe, dstPipe, selected); + } } Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op, diff --git a/lib/PTO/Transforms/InsertSync/SyncCommon.cpp b/lib/PTO/Transforms/InsertSync/SyncCommon.cpp index b7a8f01a7..e1efcf7d6 100644 --- a/lib/PTO/Transforms/InsertSync/SyncCommon.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncCommon.cpp @@ -128,6 +128,11 @@ SyncOperation::GetMatchSync(unsigned index) const { res->isCompensation = this->isCompensation; res->autoSyncTailBarrier = this->autoSyncTailBarrier; res->SetDepSyncIRIndex(this->GetDepSyncIRIndex()); + // Slot info: propagate as a default; callers that know the matched side's + // slot SSA (e.g. wait-side gets the consumer slot rather than the + // producer slot) overwrite after GetMatchSync. 
+ res->slotSSAExpr = this->slotSSAExpr; + res->slotCount = this->slotCount; return res; } diff --git a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp index 76aabd282..a0e86126a 100644 --- a/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp +++ b/lib/PTO/Transforms/PTOMaterializeTileHandles.cpp @@ -511,6 +511,16 @@ static Value computeExplicitAddress(Value value, OpBuilder &builder, if (auto subview = value.getDefiningOp()) return computeSubviewAddress(subview, builder, loc); + if (auto select = value.getDefiningOp()) { + Value trueAddr = computeExplicitAddress(select.getTrueValue(), builder, loc); + Value falseAddr = + computeExplicitAddress(select.getFalseValue(), builder, loc); + if (!trueAddr || !falseAddr) + return Value(); + return builder.create(loc, select.getCondition(), + trueAddr, falseAddr); + } + if (auto cast = value.getDefiningOp()) return computeExplicitAddress(cast.getSource(), builder, loc); diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 36ec6bb22..51ab8d34d 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -11,6 +11,7 @@ #include "PTOPlanMemory.h" +#include "PTO/IR/PTOMultiBuffer.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -403,10 +404,36 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { // of the result as a use of the source in liveness analysis. UpdateBufferAlias(bindOp.getResult(), bindOp.getSource()); return WalkResult::advance(); + } else if (auto slotOp = dyn_cast(op)) { + // SlotMarker is metadata-only: it tags which physical slot of a + // multi-buffer alloc this view refers to. From the planner's point of + // view its result aliases the source; the slot index travels with the + // op and is consumed later by PTOResolveBufferSelect / sync. 
+ UpdateBufferAlias(slotOp.getResult(), slotOp.getSource()); + return WalkResult::advance(); } else if (isLocalMemPlan() && dyn_cast(op)) { + auto allocOp = cast(op); if (failed(CheckLocalBufferAllocOp(op))) { return WalkResult::interrupt(); } + // Pick up the multi-buffer slot count when present. Range checking + // mirrors the type-level verifier so a malformed memref alloc (e.g. + // from a hand-written test) still gets a clear diagnostic. The + // sibling expansion in `ExpandMultiBufferStorageEntry` supports any + // N in the legal range. + if (auto attr = allocOp->getAttrOfType( + mlir::pto::kPtoMultiBufferAttrName)) { + uint64_t n = attr.getValue().getZExtValue(); + if (n < mlir::pto::kPtoMultiBufferMinNum || + n > mlir::pto::kPtoMultiBufferMaxNum) { + allocOp.emitError() + << "pto.multi_buffer must be in [" + << mlir::pto::kPtoMultiBufferMinNum << ", " + << mlir::pto::kPtoMultiBufferMaxNum << "] (got " << n << ")"; + return WalkResult::interrupt(); + } + buffer2MultiNum[allocOp.getResult()] = static_cast(n); + } UpdateOpBufferInfo(op, op->getResults()); return WalkResult::advance(); } else if (auto loadOp = dyn_cast(op)) { @@ -632,7 +659,16 @@ MemLivenessAnalysis::CheckLocalBufferAllocOp(Operation *op) const { bool MemLivenessAnalysis::isSkippableOp(Operation *op) const { // Call-like ops are still modeled explicitly. Only pure terminators and // dim queries are skipped here. - return isa(op); + // + // `pto.slot_marker` is a metadata-only view added by PTOViewToMemref to + // thread multi-buffer slot selection through the memref layer. Until + // PlanMemory acquires first-class multi-buffer support (the design's + // §5.2 work), treat it as a passthrough so the rest of the pipeline can + // still be exercised. The N-way physical fan-out lives on the + // `pto.multi_buffer` attr of the underlying `memref.alloc` and is a + // follow-up. 
+ return isa(op); } LogicalResult @@ -1123,10 +1159,24 @@ void MemPlan::ValidateParameters(std::unique_ptr &e) const { void MemPlan::UpdateBuffer2Offsets() { for (auto &e : StorageEntryVec) { + // Skip sibling (slot >= 1) entries -- their offsets are written via the + // primary entry's `relationOtherBuffers` traversal below. Without this + // skip the sibling offsets would be appended in StorageEntryVec order + // rather than slot order, breaking the runtime contract that + // `buffer2Offsets[buffer][k]` is slot k's physical offset. + if (e->isMultiBufferSlot) + continue; for (Value &buffer : e->inplaceBuffers) { - // MultiBuffer can cause multiple addrs. buffer2Offsets[buffer].push_back( (e->bitsOffset + kBitsToByte - 1) / kBitsToByte); + // Multi-buffer primary: append sibling offsets in slot order so the + // final offsets list is [slot0, slot1, ..., slotN-1]. + for (auto *sibling : e->relationOtherBuffers) { + if (!sibling) + continue; + buffer2Offsets[buffer].push_back( + (sibling->bitsOffset + kBitsToByte - 1) / kBitsToByte); + } } } // In the MultiBuffer scenario, single reuse db will result in additional @@ -1255,20 +1305,31 @@ void MemPlan::GlobalWorkspaceNoReuse(StorageEntry *rootStorageEntry) { } void MemPlan::ExpandMultiBufferStorageEntry() { - // StorageEntry that needs to be expanded. + // For each multi-buffer primary entry, create (N - 1) sibling entries so + // the planner can lay out one physical slot per sibling. Siblings are + // pushed into `StorageEntryVec` and participate in normal Stage0/Stage2 + // address allocation. The primary keeps `relationOtherBuffers` pointing + // at the siblings in slot order (slot 1..N-1), and `relationPongEntry` + // aliases the first sibling so existing N == 2 codepaths keep working. 
size_t size = StorageEntryVec.size(); for (size_t i = 0; i < size; i++) { - if (StorageEntryVec[i]->multiBufferNum > 1) { - std::unique_ptr entry = std::make_unique(); - entry->bufInfo = StorageEntryVec[i]->bufInfo; - entry->bufferLifeVec = StorageEntryVec[i]->bufferLifeVec; - entry->alignedConstBits = StorageEntryVec[i]->alignedConstBits; - entry->inplaceBuffers = StorageEntryVec[i]->inplaceBuffers; - entry->multiBufferNum = StorageEntryVec[i]->multiBufferNum; - // Ping saves information related to Pong. - StorageEntryVec[i]->relationPongEntry = entry.get(); + auto *primary = StorageEntryVec[i].get(); + if (primary->multiBufferNum <= 1) + continue; + uint32_t n = primary->multiBufferNum; + for (uint32_t slot = 1; slot < n; ++slot) { + auto entry = std::make_unique(); + entry->bufInfo = primary->bufInfo; + entry->bufferLifeVec = primary->bufferLifeVec; + entry->alignedConstBits = primary->alignedConstBits; + entry->inplaceBuffers = primary->inplaceBuffers; + entry->multiBufferNum = n; + entry->isMultiBufferSlot = true; + primary->relationOtherBuffers.push_back(entry.get()); StorageEntryVec.push_back(std::move(entry)); } + if (!primary->relationOtherBuffers.empty()) + primary->relationPongEntry = primary->relationOtherBuffers.front(); } } @@ -1808,9 +1869,13 @@ void MemPlan::PlanRelationPongEntryAddress(uint64_t offset, StorageEntry *e) { pingEntry2RelationPongEntry[e] = std::move(entry); } else if (e->multiBufferNum == kDoubleBufferCount) { e->relationPongEntry->bitsOffset = offset; - } else { - llvm_unreachable("Does not support multi buffer num greater than 2 !"); } + // N > 2: the Stage1 "place ping next to a free pong slot" optimization is + // not modeled for the general N-way case in this release. Sibling entries + // get their own addresses via the normal Stage0/Stage2 paths in + // `PlanReusableLocalBuffer` / `PlanSingleLocalBuffer`. This branch is a + // no-op rather than an unreachable so the planner can keep making forward + // progress on N > 2 inputs. 
} bool MemPlan::VerifyConflictStage2(PlanRecHis &his, const StorageEntry *e, diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h index 1230b389f..fd8d0ad2a 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.h +++ b/lib/PTO/Transforms/PTOPlanMemory.h @@ -147,8 +147,22 @@ struct StorageEntry { SmallVector inplaceBuffers; /// multiBuffer relation StorageEntry. + /// For N >= 2 this aliases `relationOtherBuffers.front()` -- kept around so + /// existing N == 2 code paths can keep using the single-sibling field. StorageEntry *relationPongEntry{nullptr}; + /// Sibling slot entries for multi-buffer (N - 1 entries for slot 1..N-1). + /// The primary entry occupies slot 0; siblings own slot 1..N-1. Sibling + /// entries have `isMultiBufferSlot == true` and live in `StorageEntryVec` + /// independently of their primary -- the planner assigns each one its own + /// `bitsOffset` via the same Stage0/Stage2 logic used for normal allocs. + SmallVector relationOtherBuffers; + + /// True if this entry is a multi-buffer sibling (slot >= 1) that should + /// NOT independently write into `buffer2Offsets` -- the primary entry is + /// responsible for emitting all slot offsets in slot order. + bool isMultiBufferSlot{false}; + /// The number of multibuffer optimization. /// note: default 1 which means single buffer and does not do multibuffer /// optimization. diff --git a/lib/PTO/Transforms/PTOResolveBufferSelect.cpp b/lib/PTO/Transforms/PTOResolveBufferSelect.cpp new file mode 100644 index 000000000..858593e6d --- /dev/null +++ b/lib/PTO/Transforms/PTOResolveBufferSelect.cpp @@ -0,0 +1,196 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOResolveBufferSelect.cpp -----------------------------------------===// +// +// Lowering for multi-buffer slot selection. +// +// Consumes `pto.slot_marker %src[%k] : memref<...>` ops written by +// PTOViewToMemref while lowering `pto.multi_tile_get`. By the time this pass +// runs, PTOPlanMemory has already converted the underlying `memref.alloc` to +// a multi-address `pto.pointer_cast(addr0, ..., addrN-1)`. This pass picks +// the right per-slot address(es) for each slot_marker use: +// +// * Constant slot k: emit a single-address `pto.pointer_cast(addrK)` at +// the use site and replace the slot_marker. +// * Dynamic slot %k: emit N single-address per-slot pointer_casts and +// pick one via an N-way `arith.select` chain. The user's SSA selects +// the slot -- this pass does NOT synthesize `iv mod N`. +// +// The original multi-address `pto.pointer_cast` is left in IR as the +// "alloc anchor" so future sync extensions can still see the multi-buffer +// geometry (e.g. for `set_flag_dyn` / `wait_flag_dyn` derivation). 
+// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOMultiBuffer.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTORESOLVEBUFFERSELECT +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +#define DEBUG_TYPE "pto-resolve-buffer-select" + +using namespace mlir; + +namespace { + +/// Walk back through pure metadata ops (`pto.bind_tile`, `pto.slot_marker`) +/// to find the root multi-address `pto.pointer_cast` that this view ties +/// to. Returns nullptr if the chain does not terminate on a +/// `pto.pointer_cast` -- in which case this slot_marker is not a multi- +/// buffer reference and should not be touched. +static pto::PointerCastOp lookupRootPointerCast(Value v) { + while (Operation *def = v.getDefiningOp()) { + if (auto pc = dyn_cast(def)) + return pc; + if (auto bind = dyn_cast(def)) { + v = bind.getSource(); + continue; + } + if (auto sm = dyn_cast(def)) { + // Nested slot_marker should not happen (verifier disallows nested + // multi_tile_get), but follow the chain defensively. + v = sm.getSource(); + continue; + } + return {}; + } + return {}; +} + +/// Lookup tile-buf config from the existing PointerCastOp's optional attr. +/// Returns nullptr if not set. +static Attribute getCastConfigAttr(pto::PointerCastOp root) { + auto cfg = root.getConfig(); + if (cfg.has_value()) + return *cfg; + return Attribute(); +} + +/// Create a fresh single-address pointer_cast that aliases slot `slotIdx` +/// of `root`. The result type matches `targetType`. `vRow` / `vCol` and +/// `config` are forwarded from the root. 
+static Value emitSlotPointerCast(IRRewriter &rewriter, Location loc, + pto::PointerCastOp root, uint32_t slotIdx, + Type targetType) { + auto rootAddrs = root.getAddrs(); + assert(slotIdx < rootAddrs.size() && "slot index out of range"); + Value vRow = root.getValidRow(); + Value vCol = root.getValidCol(); + Attribute cfg = getCastConfigAttr(root); + auto pc = rewriter.create( + loc, targetType, ValueRange{rootAddrs[slotIdx]}, + vRow ? vRow : Value(), vCol ? vCol : Value(), cfg); + return pc.getResult(); +} + +struct PTOResolveBufferSelectPass + : public mlir::pto::impl::PTOResolveBufferSelectBase< + PTOResolveBufferSelectPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOResolveBufferSelectPass) + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + + SmallVector markers; + mod.walk([&](pto::SlotMarkerOp op) { markers.push_back(op); }); + if (markers.empty()) + return; + + for (auto op : markers) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + // Find the root multi-address pto.pointer_cast that this slot_marker + // refers to. If the chain does not land on one, the marker is not a + // multi-buffer reference; downgrade silently by forwarding source. + pto::PointerCastOp root = lookupRootPointerCast(op.getSource()); + if (!root) { + rewriter.replaceOp(op, op.getSource()); + continue; + } + + auto rootAddrs = root.getAddrs(); + uint32_t n = static_cast(rootAddrs.size()); + if (n < 2) { + // Single-address root: treat slot_marker as identity. + rewriter.replaceOp(op, op.getSource()); + continue; + } + if (n > mlir::pto::kPtoMultiBufferMaxNum) { + op.emitError() << "underlying pointer_cast has " << n + << " addresses, exceeds max " + << mlir::pto::kPtoMultiBufferMaxNum; + signalPassFailure(); + return; + } + + Type targetType = op.getResult().getType(); + + // Constant slot: emit a single-address pointer_cast for that slot. 
+ IntegerAttr constSlotAttr; + if (matchPattern(op.getSlot(), m_Constant(&constSlotAttr))) { + int64_t slotI = constSlotAttr.getValue().getSExtValue(); + if (slotI < 0 || slotI >= static_cast(n)) { + op.emitError() << "constant slot " << slotI + << " is out of range for " + << n << " physical buffers"; + signalPassFailure(); + return; + } + Value picked = emitSlotPointerCast(rewriter, loc, root, + static_cast(slotI), + targetType); + rewriter.replaceOp(op, picked); + continue; + } + + // Dynamic slot: emit per-slot single-addr casts + N-way arith.select. + // The select chain uses the user-supplied SSA verbatim -- ptoas does + // NOT replace it with `iv mod N`. + SmallVector slotMems; + slotMems.reserve(n); + for (uint32_t i = 0; i < n; ++i) + slotMems.push_back( + emitSlotPointerCast(rewriter, loc, root, i, targetType)); + + Value selected = slotMems[0]; + Value slot = op.getSlot(); + for (uint32_t i = 1; i < n; ++i) { + Value iIdx = rewriter.create(loc, i); + Value isThis = rewriter.create( + loc, arith::CmpIPredicate::eq, slot, iIdx); + selected = rewriter.create(loc, isThis, slotMems[i], + selected); + } + rewriter.replaceOp(op, selected); + } + } +}; +} // namespace + +std::unique_ptr mlir::pto::createPTOResolveBufferSelectPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 4db9c9ff3..25e61c644 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -13,6 +13,7 @@ // metadata through binding ops and SSA backtracking. 
#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOMultiBuffer.h" #include "PTO/IR/PTOTypeUtils.h" #include "PTO/Transforms/Passes.h" @@ -137,9 +138,12 @@ static mlir::pto::TileBufConfigAttr lookupConfig(Value v) { if (auto cast = v.getDefiningOp()) { return lookupConfig(cast.getSource()); } - + if (auto slot = v.getDefiningOp()) { + return lookupConfig(slot.getSource()); + } + // If the trace reaches a BlockArgument (function argument) or another op that cannot be traversed, return empty. - return {}; + return {}; } // ============================================================================= @@ -168,6 +172,10 @@ static void lookupValidDims(Value v, Value &vRow, Value &vCol) { lookupValidDims(cast.getSource(), vRow, vCol); return; } + if (auto slot = v.getDefiningOp()) { + lookupValidDims(slot.getSource(), vRow, vCol); + return; + } vRow = Value(); vCol = Value(); } @@ -578,10 +586,16 @@ static Type convertPTOTypeToMemRef(Type t) { return MemRefType::get({ShapedType::kDynamic}, pty.getElementType(), MemRefLayoutAttrInterface(), Attribute()); } - + // 2. Handle !pto.tile_buf<...> if (auto tbTy = dyn_cast(t)) return convertTileBufTypeToMemRef(tbTy); + // 3. !pto.multi_tile_buf: collapses to the slot memref shape; + // the N-slot fan-out lives on the `pto.multi_buffer` attr written by + // the alloc lowering, not in the type. This branch is defensive -- + // by design multi_tile_buf does not appear on function boundaries. + if (auto mtbTy = dyn_cast(t)) + return convertTileBufTypeToMemRef(mtbTy.getSlotType()); if (auto tvTy = dyn_cast(t)) return MemRefType::get(tvTy.getShape(), tvTy.getElementType(), MemRefLayoutAttrInterface(), Attribute()); @@ -1339,6 +1353,190 @@ struct PTOViewToMemrefPass rewriter.replaceOp(op, bindOp.getResult()); } + // ------------------------------------------------------------------ + // Stage 0.6: lower pto.alloc_multi_tile / pto.multi_tile_get + // + // alloc_multi_tile produces a `multi_tile_buf`. 
We lower + // it into: + // %a = memref.alloc() {pto.multi_buffer = N : i32} : memref + // %m = pto.bind_tile %a, ... : memref -> memref + // and replace all uses of the multi_tile_buf SSA with %m. The N-way + // physical fan-out lives on the `pto.multi_buffer` attr and is + // materialized later by PTOPlanMemory. + // + // multi_tile_get consumes that memref and wraps it in + // %slot = pto.slot_marker %m[%k] : memref -> memref + // %t = pto.bind_tile %slot, ... : memref -> memref + // The slot_marker carries the user-supplied slot SSA forward so that + // PlanMemory / sync / EnableBufferSelect can identify which physical + // slot this use refers to. + // + // Ordering: alloc_multi_tile must be lowered before multi_tile_get so + // that the get's source SSA has already been rewired to a memref. + // ------------------------------------------------------------------ + SmallVector allocMultiTiles; + func.walk( + [&](mlir::pto::AllocMultiTileOp op) { allocMultiTiles.push_back(op); }); + + for (auto op : allocMultiTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + auto mtbTy = op.getResult().getType(); + if (!mtbTy) + continue; + auto tbTy = mtbTy.getSlotType(); + if (!tbTy) + continue; + + SmallVector shape(tbTy.getShape().begin(), + tbTy.getShape().end()); + Type elemTy = tbTy.getElementType(); + + // Stride / layout reuse the same logic as alloc_tile. 
+ SmallVector strides; + TileLayoutInfo info; + if (computeTileLayoutInfo(tbTy.getConfigAttr(), elemTy, shape, info)) { + strides = {info.rowStride, info.colStride}; + } else { + strides.resize(shape.size()); + int64_t s = 1; + for (int i = (int)shape.size() - 1; i >= 0; --i) { + strides[i] = s; + if (shape[i] != ShapedType::kDynamic) + s *= shape[i]; + } + } + + auto targetLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); + auto targetType = + MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + ArrayRef validShape = tbTy.getValidShape(); + if (!tbTy.hasDynamicValid()) { + if (validShape.size() >= 1 && validShape[0] >= 0) { + vRow = rewriter + .create( + loc, rewriter.getIndexType(), + rewriter.getIndexAttr(validShape[0])) + .getResult(); + } + if (validShape.size() >= 2 && validShape[1] >= 0) { + vCol = rewriter + .create( + loc, rewriter.getIndexType(), + rewriter.getIndexAttr(validShape[1])) + .getResult(); + } + } + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + // memref.alloc with N-slot annotation. The actual N-way address + // expansion happens in PlanMemory. + auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); + auto allocType = + MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); + auto allocOp = rewriter.create(loc, allocType); + auto i32Ty = IntegerType::get(ctx, 32); + allocOp->setAttr( + mlir::pto::kPtoMultiBufferAttrName, + IntegerAttr::get(i32Ty, static_cast(mtbTy.getCount()))); + + auto bindOp = rewriter.create( + loc, targetType, allocOp.getResult(), vRow ? vRow : Value(), + vCol ? 
vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + + rewriter.replaceOp(op, bindOp.getResult()); + } + + SmallVector multiTileGets; + func.walk( + [&](mlir::pto::MultiTileGetOp op) { multiTileGets.push_back(op); }); + + for (auto op : multiTileGets) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + // After alloc_multi_tile has been replaced, the source SSA value + // is a memref (no longer multi_tile_buf). We accept the raw Value. + Value srcMem = op->getOperand(0); + if (!llvm::isa(srcMem.getType())) { + op.emitError( + "multi_tile_get expects its source to have been lowered to " + "memref by alloc_multi_tile rewriting; got ") + << srcMem.getType(); + signalPassFailure(); + return; + } + auto srcMemTy = llvm::cast(srcMem.getType()); + + // Recover per-slot tile_buf info from the get op's result type + // (which still describes the per-slot tile shape, valid dims, + // config, etc.). + auto tbTy = llvm::dyn_cast(op.getResult().getType()); + if (!tbTy) { + op.emitError("multi_tile_get result must be `!pto.tile_buf<...>`"); + signalPassFailure(); + return; + } + + // The slot view aliases the source memref byte-for-byte. + auto slotMarker = rewriter.create( + loc, srcMemTy, srcMem, op.getSlot()); + + // Recover valid_row / valid_col from the per-slot tile type so + // downstream BindTile carries the same metadata as the alloc. 
+ Value vRow; + Value vCol; + ArrayRef validShape = tbTy.getValidShape(); + if (!tbTy.hasDynamicValid()) { + if (validShape.size() >= 1 && validShape[0] >= 0) { + vRow = rewriter + .create( + loc, rewriter.getIndexType(), + rewriter.getIndexAttr(validShape[0])) + .getResult(); + } + if (validShape.size() >= 2 && validShape[1] >= 0) { + vCol = rewriter + .create( + loc, rewriter.getIndexType(), + rewriter.getIndexAttr(validShape[1])) + .getResult(); + } + } + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + // BindTile result type uses the same layout signature as in the + // alloc lowering, so subsequent subview / DMA ops see a consistent + // memref view. + SmallVector strides = buildTileMemRefStrides(tbTy); + auto targetLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); + auto targetType = + MemRefType::get(tbTy.getShape(), tbTy.getElementType(), targetLayout, + tbTy.getMemorySpace()); + + auto bindOp = rewriter.create( + loc, targetType, slotMarker.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + + rewriter.replaceOp(op, bindOp.getResult()); + } + // ------------------------------------------------------------------ // Stage 0.75: lower pto.declare_tile -> pto.declare_tile_memref + // pto.bind_tile diff --git a/lib/PTO/Transforms/SlotAffineAnalysis.cpp b/lib/PTO/Transforms/SlotAffineAnalysis.cpp new file mode 100644 index 000000000..ab37697c6 --- /dev/null +++ b/lib/PTO/Transforms/SlotAffineAnalysis.cpp @@ -0,0 +1,210 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- SlotAffineAnalysis.cpp ----------------------------------*- C++ -*-===// + +#include "PTO/Transforms/SlotAffineAnalysis.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; + +namespace mlir { +namespace pto { + +Value findSlotMarkerExpr(Value v) { + int hops = 0; + while (v && hops++ < 32) { + Operation *op = v.getDefiningOp(); + if (!op) + return {}; + if (auto sm = dyn_cast(op)) + return sm.getSlot(); + if (auto bind = dyn_cast(op)) { + v = bind.getSource(); + continue; + } + return {}; + } + return {}; +} + +namespace { + +// Canonical form `(innerSym + innerOffset) mod N`. `innerSym` may be null +// when the input is a pure constant -- then the canonical form is just +// `innerOffset mod N` with `innerSym == nullptr`. +struct SlotForm { + Value innerSym; // null = constant-only form + int64_t innerOffset{0}; + uint32_t N{0}; +}; + +static bool tryGetConstantInt(Value v, int64_t &out) { + IntegerAttr attr; + if (!matchPattern(v, m_Constant(&attr))) + return false; + out = attr.getValue().getSExtValue(); + return true; +} + +// Peel `arith.addi`/`arith.subi` with one constant side off `v` into +// `(remaining, offsetDelta)`. Returns false if `v` is not such an op or +// neither side is a constant. 
+static bool peelAddSubConst(Value v, Value &remaining, int64_t &offset) { + Operation *op = v.getDefiningOp(); + if (!op) + return false; + Value lhs, rhs; + bool isSub = false; + if (auto add = dyn_cast(op)) { + lhs = add.getLhs(); + rhs = add.getRhs(); + } else if (auto sub = dyn_cast(op)) { + lhs = sub.getLhs(); + rhs = sub.getRhs(); + isSub = true; + } else { + return false; + } + + int64_t c; + if (tryGetConstantInt(rhs, c)) { + remaining = lhs; + offset += isSub ? -c : c; + return true; + } + if (!isSub && tryGetConstantInt(lhs, c)) { + // commutativity only for add + remaining = rhs; + offset += c; + return true; + } + return false; +} + +// Express `slot` as `(innerSym + innerOffset) mod N` where N is taken from the +// `arith.remui inner, %const_N` that defines `slot`. If `slot` is a pure constant +// without a `remui`, treat N as the caller-supplied `expectN` and reduce. +// Returns false if the form is not representable. +static bool extractSlotForm(Value slot, uint32_t expectN, SlotForm &out) { + if (!slot) + return false; + + out.innerSym = Value(); + out.innerOffset = 0; + out.N = expectN; + + Operation *def = slot.getDefiningOp(); + + // Case 1: `arith.remui inner, %const_N`. + if (auto remOp = dyn_cast_if_present(def)) { + int64_t n; + if (!tryGetConstantInt(remOp.getRhs(), n) || n <= 0) + return false; + out.N = static_cast(n); + Value inner = remOp.getLhs(); + int64_t offset = 0; + // Peel up to four nested add/sub-of-constant layers, accumulating + // their net offset. + Value rem = inner; + int peeled = 0; + while (peeled++ < 4) { + Value next; + if (!peelAddSubConst(rem, next, offset)) + break; + rem = next; + } + int64_t cst; + if (tryGetConstantInt(rem, cst)) { + out.innerSym = Value(); + out.innerOffset = cst + offset; + } else { + out.innerSym = rem; + out.innerOffset = offset; + } + return true; + } + + // Case 2: pure constant (no remui wrapper). 
+ int64_t cst; + if (tryGetConstantInt(slot, cst)) { + out.innerSym = Value(); + out.innerOffset = cst; + return true; + } + + // Case 3: bare symbol (`iv` with no remui). Equality of bare symbols + // can still be decided (kEqual when both are the same SSA value), but + // we cannot guarantee the underlying value is in `[0, N)`, so + // disjointness is unsafe. + out.innerSym = slot; + out.innerOffset = 0; + return true; +} + +static int64_t pyMod(int64_t a, int64_t n) { + int64_t r = a % n; + if (r < 0) + r += n; + return r; +} + +} // namespace + +SlotRelation compareSlotSSA(Value a, Value b, uint32_t N) { + if (!a || !b || N == 0) + return SlotRelation::kUnknown; + + // Shortcut: same SSA value -> always equal regardless of N. + if (a == b) + return SlotRelation::kEqual; + + SlotForm fa, fb; + if (!extractSlotForm(a, N, fa) || !extractSlotForm(b, N, fb)) + return SlotRelation::kUnknown; + + // Only compare slot forms that share the canonical `mod N` window the + // caller asked about. The shared utility leaves slots that were not + // wrapped in `arith.remui` (or were wrapped with a different modulus) + // to fall through as kUnknown for symbolic forms. + if (fa.N != N || fb.N != N) { + // The pure-constant case still works here: project both sides mod N. + if (!fa.innerSym && !fb.innerSym) { + int64_t da = pyMod(fa.innerOffset, N); + int64_t db = pyMod(fb.innerOffset, N); + return da == db ? SlotRelation::kEqual : SlotRelation::kDisjoint; + } + return SlotRelation::kUnknown; + } + + // Pure constants on both sides: project mod N. + if (!fa.innerSym && !fb.innerSym) { + int64_t da = pyMod(fa.innerOffset, N); + int64_t db = pyMod(fb.innerOffset, N); + return da == db ? SlotRelation::kEqual : SlotRelation::kDisjoint; + } + + // One side const, other symbolic: cannot prove disjoint without + // assuming a value range on the symbol. Equality also unprovable. + if (!fa.innerSym || !fb.innerSym) + return SlotRelation::kUnknown; + + // Both symbolic. 
Need same symbol to reason about (a - b) mod N. + if (fa.innerSym != fb.innerSym) + return SlotRelation::kUnknown; + + int64_t diff = pyMod(fa.innerOffset - fb.innerOffset, N); + return diff == 0 ? SlotRelation::kEqual : SlotRelation::kDisjoint; +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/Utils.cpp b/lib/PTO/Transforms/Utils.cpp index 58e68c77e..e74b58933 100644 --- a/lib/PTO/Transforms/Utils.cpp +++ b/lib/PTO/Transforms/Utils.cpp @@ -85,6 +85,13 @@ std::optional GetBufferSpaceAttr(Value operand) { std::optional> getOperationAliasInfo(Operation *op) { if (auto subViewOp = dyn_cast(op)) { return std::make_pair(subViewOp.getResult(), subViewOp.getViewSource()); + } else if (auto bindTileOp = dyn_cast(op)) { + return std::make_pair(bindTileOp.getResult(), bindTileOp.getSource()); + } else if (auto slotMarkerOp = dyn_cast(op)) { + // `pto.slot_marker` is a metadata-only view that tags a memref with the + // physical slot of a multi-buffer alloc. From an alias-walking + // standpoint it behaves like any other view-like op. + return std::make_pair(slotMarkerOp.getResult(), slotMarkerOp.getSource()); } else if (auto extSliceOp = dyn_cast(op)) { return std::make_pair(extSliceOp.getResult(), extSliceOp.getSource()); } else if (auto collapseShapeOp = dyn_cast(op)) { diff --git a/test/lit/pto/multi_tile_affine_disjoint_slots.pto b/test/lit/pto/multi_tile_affine_disjoint_slots.pto new file mode 100644 index 000000000..03806f668 --- /dev/null +++ b/test/lit/pto/multi_tile_affine_disjoint_slots.pto @@ -0,0 +1,89 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --enable-insert-sync --mlir-print-ir-after=pto-insert-sync %s 2>&1 1>/dev/null | FileCheck %s +// RUN: ptoas --enable-graph-sync-solver --graph-sync-solver-event-id-max=8 --mlir-print-ir-after=pto-graph-sync-solver %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=GSS + +// Affine slot disjoint optimization: producer touches slot `(iv+1) % 2`, +// consumer touches slot `iv % 2`. The slot SSA expressions are provably +// different in every iteration, so the same-iter forward MTE2 -> MTE3 +// dependency is dropped entirely. The back-edge MTE3 -> MTE2 dep across +// iterations still gets the per-slot dyn flag pipelining. +// +// Without affine analysis, InsertSync would conservatively emit +// `set_flag[MTE2 -> MTE3, EVENT_ID0]` / `wait_flag[MTE2 -> MTE3, +// EVENT_ID0]` per iteration even though the producer/consumer slots +// cannot overlap same-iter. With affine the pair is provably disjoint +// and the same-iter sync is skipped. + +module { + func.func @prefetch_disjoint_slots( + %gm : memref<16x16xf16, #pto.address_space>, + %dst : memref<16x16xf16, #pto.address_space>, + %n : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + scf.for %i = %c0 to %n step %c1 { + %next = arith.addi %i, %c1 : index + %cur_idx = arith.remui %i, %c2 : index + %nxt_idx = arith.remui %next, %c2 : index + + // Producer (MTE2): tload writes slot (iv+1)%2. 
+ %prod = pto.multi_tile_get %mb[%nxt_idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%prod : !pto.tile_buf) + + // Consumer (MTE3): tstore reads slot iv%2. + %cons = pto.multi_tile_get %mb[%cur_idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tstore ins(%cons : !pto.tile_buf) + outs(%dst : memref<16x16xf16, #pto.address_space>) + } + return + } +} + +// Pre-loop primes 2 per-slot events for the back-edge sync. +// CHECK-LABEL: func.func @prefetch_disjoint_slots +// CHECK: pto.set_flag[, , ] +// CHECK: pto.set_flag[, , ] +// CHECK: scf.for +// Loop body: back-edge dyn wait, the tload, then dyn set after tstore. +// The same-iter MTE2 -> MTE3 forward sync has been dropped by the affine +// disjoint-slot proof. +// CHECK: pto.wait_flag_dyn[, , +// CHECK: pto.tload +// CHECK-NOT: pto.set_flag[, +// CHECK-NOT: pto.wait_flag[, +// CHECK: pto.tstore +// CHECK: pto.set_flag_dyn[, , +// Post-loop drains. +// CHECK: pto.wait_flag[, , ] +// CHECK: pto.wait_flag[, , ] + +// GSS-LABEL: func.func @prefetch_disjoint_slots +// GSS: scf.for +// GSS: pto.wait_flag_dyn[, , +// GSS: pto.tload +// GSS-NOT: pto.set_flag[, +// GSS-NOT: pto.wait_flag[, +// The MTE2->MTE3 loop-carried dependence remains dyn-slot keyed, but the +// same-iteration static forward sync is filtered out. +// GSS: pto.set_flag_dyn[, , +// GSS-NOT: pto.set_flag[, +// GSS-NOT: pto.wait_flag[, +// GSS: pto.tstore +// GSS: pto.set_flag_dyn[, , diff --git a/test/lit/pto/multi_tile_buf_n3_planmem_e2e.pto b/test/lit/pto/multi_tile_buf_n3_planmem_e2e.pto new file mode 100644 index 000000000..76b917133 --- /dev/null +++ b/test/lit/pto/multi_tile_buf_n3_planmem_e2e.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --mlir-print-ir-after=pto-plan-memory %s 2>&1 1>/dev/null | FileCheck %s + +// Sanity for N == 3 to confirm `ExpandMultiBufferStorageEntry` produces 3 +// physical slots and `UpdateBuffer2Offsets` emits them in slot order. + +module { + func.func @n3_three_slots(%gm : memref<16x16xf16, #pto.address_space>) { + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %s0 = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + %s1 = pto.multi_tile_get %mb[%c1] + : !pto.multi_tile_buf + -> !pto.tile_buf + %s2 = pto.multi_tile_get %mb[%c2] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s0 : !pto.tile_buf) + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s1 : !pto.tile_buf) + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s2 : !pto.tile_buf) + return + } +} + +// CHECK: IR Dump After PlanMemory +// CHECK-LABEL: func.func @n3_three_slots +// CHECK: pto.pointer_cast({{[^,]*}}, {{[^,]*}}, {{[^,]*}}) diff --git a/test/lit/pto/multi_tile_buf_type_parse_print.pto b/test/lit/pto/multi_tile_buf_type_parse_print.pto new file mode 100644 index 000000000..90b9f47e9 --- /dev/null +++ b/test/lit/pto/multi_tile_buf_type_parse_print.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --mlir-print-ir-before=pto-view-to-memref %s 2>&1 1>/dev/null | FileCheck %s + +// Verifies that both the compact and verbose forms of `!pto.multi_tile_buf` +// parse and round-trip through the printer. The compact form (preferred) +// reads as `!pto.multi_tile_buf` and is what most +// frontends will use. + +module { + func.func @mtb_compact_n2() { + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + return + } + + func.func @mtb_compact_n4() { + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + return + } + + func.func @mtb_verbose() { + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf, count=3> + return + } +} + +// CHECK-LABEL: func.func @mtb_compact_n2 +// CHECK: pto.alloc_multi_tile : !pto.multi_tile_buf, count=2> + +// CHECK-LABEL: func.func @mtb_compact_n4 +// CHECK: pto.alloc_multi_tile : !pto.multi_tile_buf, count=4> + +// CHECK-LABEL: func.func @mtb_verbose +// CHECK: pto.alloc_multi_tile : !pto.multi_tile_buf, count=3> diff --git a/test/lit/pto/multi_tile_buf_verify_count.pto b/test/lit/pto/multi_tile_buf_verify_count.pto new file mode 100644 index 000000000..69df98a20 --- /dev/null +++ b/test/lit/pto/multi_tile_buf_verify_count.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas %s 2>&1 | FileCheck %s + +// A count of 1 is below the multi-buffer minimum (2). The verifier should +// reject this during type construction. + +module { + func.func @bad_count_one() { + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + return + } +} + +// CHECK: multi_tile_buf count must be >= 2 diff --git a/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto b/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto new file mode 100644 index 000000000..8b6b147c8 --- /dev/null +++ b/test/lit/pto/multi_tile_const_preload_dyn_loop_select.pto @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --mlir-print-ir-after=pto-view-to-memref %s 2>&1 1>/dev/null | FileCheck %s +// RUN: ptoas --mlir-print-ir-after=pto-resolve-buffer-select %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=SELECT +// RUN: ptoas --pto-arch=a3 --mlir-print-ir-after=pto-materialize-tile-handles %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=MAT +// RUN: ptoas --pto-arch=a3 %s -o - 2>&1 | FileCheck %s --check-prefix=EMITC + +// Mixed slot-select example: loop-external preload uses a constant slot id, +// while the loop body uses a runtime slot id to select one physical buffer. + +module { + func.func @const_preload_dyn_loop_select( + %gm : memref<16x16xf16, #pto.address_space>, %n : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + // Preload slot 0 before the loop with a constant slot id. + %pre = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%pre : !pto.tile_buf) + + scf.for %i = %c0 to %n step %c1 { + %idx = arith.remui %i, %c2 : index + %slot = pto.multi_tile_get %mb[%idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%slot : !pto.tile_buf) + } + return + } +} + +// CHECK: IR Dump After PTOViewToMemref +// CHECK-LABEL: func.func @const_preload_dyn_loop_select +// CHECK: memref.alloc() {pto.multi_buffer = 2 : i32} +// CHECK: pto.slot_marker %{{.*}}[%c0 +// CHECK: scf.for +// CHECK: %[[IDX:.*]] = arith.remui %arg{{.*}}, %c2 +// CHECK: pto.slot_marker %{{.*}}[%[[IDX]] + +// After ResolveBufferSelect, the const preload slot is a direct single-address +// pointer_cast, while the loop slot is chosen with arith.select. 
+// SELECT: IR Dump After PTOResolveBufferSelect +// SELECT-LABEL: func.func @const_preload_dyn_loop_select +// SELECT: %[[ANCHOR:.*]] = pto.pointer_cast(%[[ADDR0:.*]], %[[ADDR1:.*]]) +// SELECT-NOT: pto.slot_marker +// SELECT: pto.pointer_cast(%[[ADDR0]]) +// SELECT: scf.for +// SELECT: arith.select + +// MaterializeTileHandles must carry the dynamic selected buffer address into +// alloc_tile, otherwise EmitC creates a tile without TASSIGN in the loop. +// MAT: IR Dump After PTOMaterializeTileHandles +// MAT-LABEL: func.func @const_preload_dyn_loop_select +// MAT-DAG: %[[ADDR0:[A-Za-z0-9_]+]] = arith.constant 0 : i64 +// MAT-DAG: %[[ADDR1:[A-Za-z0-9_]+]] = arith.constant 512 : i64 +// MAT: pto.alloc_tile addr = %[[ADDR0]] +// MAT: scf.for +// MAT: %[[COND:[A-Za-z0-9_]+]] = arith.cmpi eq +// MAT: %[[DYN_ADDR:[A-Za-z0-9_]+]] = arith.select %[[COND]], %[[ADDR1]], %[[ADDR0]] : i64 +// MAT: pto.alloc_tile addr = %[[DYN_ADDR]] + +// EMITC: __global__ AICORE void const_preload_dyn_loop_select +// EMITC: for ( +// EMITC-NEXT: Tile [[TILE:v[0-9]+]]; +// EMITC-NEXT: uint64_t [[ADDR:v[0-9]+]] = (uint64_t) +// EMITC-NEXT: TASSIGN([[TILE]], [[ADDR]]); +// EMITC: TLOAD([[TILE]], diff --git a/test/lit/pto/multi_tile_const_slot_disjoint_sync.pto b/test/lit/pto/multi_tile_const_slot_disjoint_sync.pto new file mode 100644 index 000000000..16412acec --- /dev/null +++ b/test/lit/pto/multi_tile_const_slot_disjoint_sync.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --enable-insert-sync --mlir-print-ir-after=pto-insert-sync %s 2>&1 1>/dev/null | FileCheck %s + +// Confirms that two constant-slot uses of a `multi_tile_buf` are recognised +// as disjoint by InsertSync: the two MTE2 loads land on distinct physical +// slots (different `pto.pointer_cast` rootBuffer i64 constants), so no +// inter-slot `set_flag` / `wait_flag` is needed between them. Only the +// pipeline tail barrier remains. Without slot-aware buffer planning this +// would otherwise emit redundant MTE2-MTE2 synchronization between the +// two tloads. + +module { + func.func @const_slot_no_inter_sync( + %gm0 : memref<16x16xf16, #pto.address_space>, + %gm1 : memref<16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + %s0 = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + %s1 = pto.multi_tile_get %mb[%c1] + : !pto.multi_tile_buf + -> !pto.tile_buf + + pto.tload ins(%gm0 : memref<16x16xf16, #pto.address_space>) + outs(%s0 : !pto.tile_buf) + pto.tload ins(%gm1 : memref<16x16xf16, #pto.address_space>) + outs(%s1 : !pto.tile_buf) + return + } +} + +// Both tloads run on PIPE_MTE2 but touch disjoint slots, so no +// MTE2-internal set/wait should be inserted between them. The only sync +// emitted is the auto pipeline-tail barrier just before `return`. 
+// CHECK-LABEL: func.func @const_slot_no_inter_sync +// CHECK-NOT: pto.set_flag +// CHECK-NOT: pto.wait_flag +// CHECK: pto.barrier +// CHECK: return diff --git a/test/lit/pto/multi_tile_const_slot_gss_compiles.pto b/test/lit/pto/multi_tile_const_slot_gss_compiles.pto new file mode 100644 index 000000000..1dbfb1708 --- /dev/null +++ b/test/lit/pto/multi_tile_const_slot_gss_compiles.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --enable-graph-sync-solver --graph-sync-solver-event-id-max=8 %s 2>/dev/null > /dev/null + +// Smoke test: a constant-slot multi-buffer program compiles cleanly through +// the GraphSyncSolver pipeline as well. GSS uses its own IR translator +// distinct from InsertSync's, so this guards against regressions when +// adding new view-like ops (here `pto.bind_tile` and `pto.slot_marker`) +// into the shared `getOperationAliasInfo` helper. We don't FileCheck +// flag-emission specifics: a follow-up will teach GSS multi-buffer +// event-id allocation, after which we can lock in optimization patterns. 
+ +module { + func.func @const_slot_gss_compiles( + %gm0 : memref<16x16xf16, #pto.address_space>, + %gm1 : memref<16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + %s0 = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + %s1 = pto.multi_tile_get %mb[%c1] + : !pto.multi_tile_buf + -> !pto.tile_buf + + pto.tload ins(%gm0 : memref<16x16xf16, #pto.address_space>) + outs(%s0 : !pto.tile_buf) + pto.tload ins(%gm1 : memref<16x16xf16, #pto.address_space>) + outs(%s1 : !pto.tile_buf) + return + } +} diff --git a/test/lit/pto/multi_tile_dyn_slot_conservative_sync.pto b/test/lit/pto/multi_tile_dyn_slot_conservative_sync.pto new file mode 100644 index 000000000..ea2020dca --- /dev/null +++ b/test/lit/pto/multi_tile_dyn_slot_conservative_sync.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --enable-insert-sync %s 2>/dev/null > /dev/null + +// Smoke test that the dynamic-slot path still compiles end-to-end through +// PTOResolveBufferSelect (arith.select chain) + InsertSync. 
The +// PTOIRTranslator now sees the arith.select on memref values and +// conservatively aliases both branches via +// `UpdateConservativeAliasBufferInfo`, ensuring sync analysis never +// silently drops a real dependency just because the slot index is a +// runtime SSA. We don't FileCheck specific flag patterns here -- only +// that the pipeline runs to a clean exit code. + +module { + func.func @dyn_slot_compiles( + %gm : memref<16x16xf16, #pto.address_space>, %n : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + scf.for %i = %c0 to %n step %c1 { + %idx = arith.remui %i, %c2 : index + %s = pto.multi_tile_get %mb[%idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s : !pto.tile_buf) + } + return + } +} diff --git a/test/lit/pto/multi_tile_get_const_slot_lowering.pto b/test/lit/pto/multi_tile_get_const_slot_lowering.pto new file mode 100644 index 000000000..e8f88edc6 --- /dev/null +++ b/test/lit/pto/multi_tile_get_const_slot_lowering.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --mlir-print-ir-after=pto-view-to-memref %s 2>&1 1>/dev/null | FileCheck %s +// RUN: ptoas --mlir-print-ir-after=pto-plan-memory %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=PLAN +// RUN: ptoas --mlir-print-ir-after=pto-resolve-buffer-select %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=SELECT + +// Verifies the full multi-buffer pipeline for constant slots: +// 1. PTOViewToMemref lowers `alloc_multi_tile` to `memref.alloc` with the +// `pto.multi_buffer = N : i32` attribute, and `multi_tile_get [%k]` to +// `pto.slot_marker` carrying the constant slot SSA. +// 2. PTOPlanMemory reserves N physical addresses and emits a multi- +// address `pto.pointer_cast(addr0, addr1)`. +// 3. PTOResolveBufferSelect lowers each `slot_marker` to a single- +// address `pto.pointer_cast(addrK)` -- one per slot use. + +module { + func.func @const_slot_two_buffers( + %gm0 : memref<16x16xf16, #pto.address_space>, + %gm1 : memref<16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + %s0 = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + %s1 = pto.multi_tile_get %mb[%c1] + : !pto.multi_tile_buf + -> !pto.tile_buf + + pto.tload ins(%gm0 : memref<16x16xf16, #pto.address_space>) + outs(%s0 : !pto.tile_buf) + pto.tload ins(%gm1 : memref<16x16xf16, #pto.address_space>) + outs(%s1 : !pto.tile_buf) + return + } +} + +// CHECK: IR Dump After PTOViewToMemref +// CHECK-LABEL: func.func @const_slot_two_buffers +// CHECK: memref.alloc() {pto.multi_buffer = 2 : i32} +// CHECK: pto.slot_marker %{{.*}}[%c0 +// CHECK: pto.slot_marker %{{.*}}[%c1 + +// Stage 2 -- PlanMemory: two physical slots reserved -> 2-addr pointer_cast. 
+// PLAN: IR Dump After PlanMemory +// PLAN-LABEL: func.func @const_slot_two_buffers +// PLAN: pto.pointer_cast({{.*}}, {{.*}}) +// PLAN: pto.slot_marker + +// Stage 3 -- ResolveBufferSelect: each slot_marker becomes a single-addr +// pointer_cast; original 2-addr pointer_cast remains as alloc anchor. +// SELECT: IR Dump After PTOResolveBufferSelect +// SELECT-LABEL: func.func @const_slot_two_buffers +// SELECT: %[[ANCHOR:.*]] = pto.pointer_cast(%{{.*}}, %{{.*}}) +// SELECT-NOT: pto.slot_marker +// SELECT: pto.pointer_cast(%[[ADDR0:.*]]) +// SELECT: pto.pointer_cast(%[[ADDR1:.*]]) diff --git a/test/lit/pto/multi_tile_get_dyn_slot_lowering.pto b/test/lit/pto/multi_tile_get_dyn_slot_lowering.pto new file mode 100644 index 000000000..31b4583b6 --- /dev/null +++ b/test/lit/pto/multi_tile_get_dyn_slot_lowering.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --mlir-print-ir-after=pto-view-to-memref %s 2>&1 1>/dev/null | FileCheck %s +// RUN: ptoas --mlir-print-ir-after=pto-resolve-buffer-select %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=SELECT + +// Verifies that `multi_tile_get` with a dynamic slot SSA preserves the +// user-supplied slot expression all the way through `pto.slot_marker` and +// the final `arith.select` chain. 
The lowering must NOT rewrite the slot +// expression into `iv mod N` -- the frontend SSA is the source of truth +// for which slot this use refers to. + +module { + func.func @dyn_slot_prefetch( + %gm : memref<16x16xf16, #pto.address_space>, %n : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + scf.for %i = %c0 to %n step %c1 { + %next = arith.addi %i, %c1 : index + %cur_idx = arith.remui %i, %c2 : index + %next_idx = arith.remui %next, %c2 : index + + %s_next = pto.multi_tile_get %mb[%next_idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s_next : !pto.tile_buf) + + %s_cur = pto.multi_tile_get %mb[%cur_idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s_cur : !pto.tile_buf) + } + return + } +} + +// CHECK: IR Dump After PTOViewToMemref +// CHECK-LABEL: func.func @dyn_slot_prefetch +// CHECK: memref.alloc() {pto.multi_buffer = 2 : i32} +// CHECK: %[[NEXT:.*]] = arith.addi %{{.*}}, %c1 +// CHECK-DAG: %[[CUR_IDX:.*]] = arith.remui %arg{{.*}}, %c2 +// CHECK-DAG: %[[NEXT_IDX:.*]] = arith.remui %[[NEXT]], %c2 +// CHECK: pto.slot_marker %{{.*}}[%[[NEXT_IDX]] +// CHECK: pto.slot_marker %{{.*}}[%[[CUR_IDX]] + +// After ResolveBufferSelect: slot_marker -> per-slot pointer_cast + +// arith.select chain driven by the user SSA (not iv mod N). +// SELECT: IR Dump After PTOResolveBufferSelect +// SELECT-LABEL: func.func @dyn_slot_prefetch +// SELECT-NOT: pto.slot_marker +// SELECT: pto.pointer_cast(%{{.*}}, %{{.*}}) +// SELECT: arith.select diff --git a/test/lit/pto/multi_tile_get_verify_slot.pto b/test/lit/pto/multi_tile_get_verify_slot.pto new file mode 100644 index 000000000..67397271a --- /dev/null +++ b/test/lit/pto/multi_tile_get_verify_slot.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas %s 2>&1 | FileCheck %s + +// multi_tile_get with a constant slot >= count must be rejected. + +module { + func.func @bad_const_slot_out_of_range() { + %c5 = arith.constant 5 : index + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + %s = pto.multi_tile_get %mb[%c5] + : !pto.multi_tile_buf + -> !pto.tile_buf + return + } +} + +// CHECK: constant slot 5 is out of range for multi_tile_buf count=2 diff --git a/test/lit/pto/multi_tile_n4_planmem_e2e.pto b/test/lit/pto/multi_tile_n4_planmem_e2e.pto new file mode 100644 index 000000000..361e06da7 --- /dev/null +++ b/test/lit/pto/multi_tile_n4_planmem_e2e.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --mlir-print-ir-after=pto-plan-memory %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=PLAN +// RUN: ptoas --mlir-print-ir-after=pto-resolve-buffer-select %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=SELECT + +// End-to-end test for the N == 4 multi-buffer path: PlanMemory must +// reserve 4 physical slots and emit a 4-address `pto.pointer_cast`. +// PTOResolveBufferSelect then lowers each constant-slot `pto.slot_marker` +// to a single-address pointer_cast picking the right slot. + +module { + func.func @n4_rotate( + %gm : memref<16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + %s0 = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + %s1 = pto.multi_tile_get %mb[%c1] + : !pto.multi_tile_buf + -> !pto.tile_buf + %s2 = pto.multi_tile_get %mb[%c2] + : !pto.multi_tile_buf + -> !pto.tile_buf + %s3 = pto.multi_tile_get %mb[%c3] + : !pto.multi_tile_buf + -> !pto.tile_buf + + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s0 : !pto.tile_buf) + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s1 : !pto.tile_buf) + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s2 : !pto.tile_buf) + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%s3 : !pto.tile_buf) + return + } +} + +// PlanMemory: 4 physical addresses, one `pto.pointer_cast` with 4 i64 args. +// PLAN: IR Dump After PlanMemory +// PLAN-LABEL: func.func @n4_rotate +// PLAN: pto.pointer_cast({{[^,]*}}, {{[^,]*}}, {{[^,]*}}, {{[^,]*}}) + +// PTOResolveBufferSelect: four single-addr casts (one per slot), no leftover +// slot_marker. The original 4-addr cast remains as the alloc anchor. 
+// SELECT: IR Dump After PTOResolveBufferSelect +// SELECT-LABEL: func.func @n4_rotate +// SELECT-NOT: pto.slot_marker +// SELECT: pto.pointer_cast({{[^,]*}}, {{[^,]*}}, {{[^,]*}}, {{[^,]*}}) +// SELECT: pto.pointer_cast(%{{[^,)]+}}) +// SELECT: pto.pointer_cast(%{{[^,)]+}}) +// SELECT: pto.pointer_cast(%{{[^,)]+}}) +// SELECT: pto.pointer_cast(%{{[^,)]+}}) diff --git a/test/lit/pto/multi_tile_no_loop_unroll.pto b/test/lit/pto/multi_tile_no_loop_unroll.pto new file mode 100644 index 000000000..e7b4373e9 --- /dev/null +++ b/test/lit/pto/multi_tile_no_loop_unroll.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --mlir-print-ir-after=pto-view-to-memref %s 2>&1 1>/dev/null | FileCheck %s + +// Verifies that multi-buffer expression and slot selection work WITHOUT an +// enclosing `scf.for`. Two constant slots are tagged via `pto.slot_marker` +// at the memref layer and live independently of any loop induction +// variable. This is something the PR615 fully-automatic `iv mod N` path +// could not express. 
+ +module { + func.func @unroll_two_slots( + %gm0 : memref<16x16xf16, #pto.address_space>, + %gm1 : memref<16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + %s0 = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm0 : memref<16x16xf16, #pto.address_space>) + outs(%s0 : !pto.tile_buf) + + %s1 = pto.multi_tile_get %mb[%c1] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm1 : memref<16x16xf16, #pto.address_space>) + outs(%s1 : !pto.tile_buf) + return + } +} + +// CHECK: IR Dump After PTOViewToMemref +// CHECK-LABEL: func.func @unroll_two_slots +// CHECK-NOT: scf.for +// CHECK: memref.alloc() {pto.multi_buffer = 2 : i32} +// CHECK: pto.slot_marker %{{.*}}[%c0 +// CHECK: pto.slot_marker %{{.*}}[%c1 diff --git a/test/lit/pto/multi_tile_prefetch_dyn_event_id.pto b/test/lit/pto/multi_tile_prefetch_dyn_event_id.pto new file mode 100644 index 000000000..4f00bb59b --- /dev/null +++ b/test/lit/pto/multi_tile_prefetch_dyn_event_id.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --enable-insert-sync --mlir-print-ir-after=pto-insert-sync %s 2>&1 1>/dev/null | FileCheck %s + +// End-to-end test for the multi-buffer dyn-event-id sync path: a +// double-buffer prefetch loop with one MTE2 producer and one V consumer, +// both routed through `pto.multi_tile_get` with a dynamic slot index. +// +// Without slot-aware sync the analysis would either drop the dependency +// (incorrect) or emit a conservative static `set_flag` per iteration, +// serializing the prefetch. With the slot-aware path the producer side +// gets `set_flag_dyn[MTE2, V, ...]` keyed off its slot SSA, and the +// consumer side gets a matching `wait_flag_dyn[MTE2, V, ...]` keyed off +// its own slot SSA. The hardware event id is selected by an arith.select +// chain over the N allocated event ids. + +module { + func.func @prefetch_dyn_eid( + %gm : memref<16x16xf16, #pto.address_space>, + %dst : memref<16x16xf16, #pto.address_space>, + %n : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + scf.for %i = %c0 to %n step %c1 { + %idx = arith.remui %i, %c2 : index + // Producer: MTE2 load on dynamic slot. + %prod = pto.multi_tile_get %mb[%idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%prod : !pto.tile_buf) + // Consumer: V op on dynamic slot. The MTE2 -> V dep on a + // multi-buffer alloc should lower to a dyn event-id pair. + pto.tstore ins(%prod : !pto.tile_buf) + outs(%dst : memref<16x16xf16, #pto.address_space>) + } + return + } +} + +// CHECK-LABEL: func.func @prefetch_dyn_eid +// Pre-loop primes both per-slot events (one set per allocated event id). +// CHECK: pto.set_flag[, , ] +// CHECK: pto.set_flag[, , ] +// CHECK: scf.for +// In the loop body the consumer waits first (back-edge into the upcoming +// store slot), then the producer issues its slot-keyed set after the load. 
+// CHECK: arith.select +// CHECK: pto.wait_flag_dyn[, , +// CHECK: pto.tload +// CHECK: pto.tstore +// CHECK: arith.select +// CHECK: pto.set_flag_dyn[, , +// Post-loop drains both per-slot events. +// CHECK: pto.wait_flag[, , ] +// CHECK: pto.wait_flag[, , ] diff --git a/test/lit/pto/multi_tile_prefetch_gss_event_id.pto b/test/lit/pto/multi_tile_prefetch_gss_event_id.pto new file mode 100644 index 000000000..62188120a --- /dev/null +++ b/test/lit/pto/multi_tile_prefetch_gss_event_id.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --enable-graph-sync-solver --graph-sync-solver-event-id-max=8 --mlir-print-ir-after=pto-graph-sync-solver %s 2>&1 1>/dev/null | FileCheck %s + +// Same-SSA dynamic slot expression: both tload and tstore use `%i % 2` for +// their slot index. GSS keeps two physical slots (count=2), allocates N=2 +// hardware event ids, and emits dyn `set_flag` / `wait_flag` keyed by the +// slot SSA -- matching the InsertSync path for the same input. Different +// iterations touch different physical slots, so the per-slot event id +// pipeline lets them overlap; collapsing to a single static event id (the +// earlier GSS behaviour) would have serialised every iteration. 
+ +module { + func.func @prefetch_gss( + %gm : memref<16x16xf16, #pto.address_space>, + %dst : memref<16x16xf16, #pto.address_space>, + %n : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + scf.for %i = %c0 to %n step %c1 { + %idx = arith.remui %i, %c2 : index + %prod = pto.multi_tile_get %mb[%idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%prod : !pto.tile_buf) + pto.tstore ins(%prod : !pto.tile_buf) + outs(%dst : memref<16x16xf16, #pto.address_space>) + } + return + } +} + +// CHECK-LABEL: func.func @prefetch_gss +// Pre-loop primes both per-slot events for the cross-iter back-edge. +// CHECK: pto.set_flag[, , ] +// CHECK: pto.set_flag[, , ] +// CHECK: scf.for +// Loop body uses dyn flag keyed on the slot SSA. +// CHECK: arith.select +// CHECK: pto.wait_flag_dyn[, , +// CHECK: pto.tload +// CHECK: pto.tstore +// CHECK: arith.select +// CHECK: pto.set_flag_dyn[, , +// Post-loop drains both per-slot events. +// CHECK: pto.wait_flag[, , ] +// CHECK: pto.wait_flag[, , ] diff --git a/test/lit/pto/multi_tile_preload_loop_set_wait.pto b/test/lit/pto/multi_tile_preload_loop_set_wait.pto new file mode 100644 index 000000000..cdb0ebbce --- /dev/null +++ b/test/lit/pto/multi_tile_preload_loop_set_wait.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a3 --enable-insert-sync %s -o - 2>&1 | FileCheck %s +// RUN: ptoas --pto-arch=a3 --enable-graph-sync-solver --graph-sync-solver-event-id-max=8 %s -o - 2>&1 | FileCheck %s + +// Preload outside the loop and the dynamic-slot loop body both perform +// tload -> tstore on the selected tile, so both regions should lower to a +// concrete set_flag/wait_flag pair. + +module { + func.func @preload_and_loop_set_wait( + %gm : memref<16x16xf16, #pto.address_space>, + %dst : memref<16x16xf16, #pto.address_space>, + %n : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + // Loop-external preload path uses a constant slot id. + %pre = pto.multi_tile_get %mb[%c0] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%pre : !pto.tile_buf) + pto.tstore ins(%pre : !pto.tile_buf) + outs(%dst : memref<16x16xf16, #pto.address_space>) + + scf.for %i = %c0 to %n step %c1 { + %idx = arith.remui %i, %c2 : index + %slot = pto.multi_tile_get %mb[%idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%slot : !pto.tile_buf) + pto.tstore ins(%slot : !pto.tile_buf) + outs(%dst : memref<16x16xf16, #pto.address_space>) + } + return + } +} + +// CHECK-LABEL: __global__ AICORE void preload_and_loop_set_wait +// CHECK: TLOAD([[PRE:v[0-9]+]], +// CHECK: set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID[[PRE_EID:[0-9]+]]); +// CHECK-NEXT: wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID[[PRE_EID]]); +// CHECK: TSTORE({{v[0-9]+}}, [[PRE]]); +// CHECK: for ( +// CHECK: TLOAD([[LOOP:v[0-9]+]], +// CHECK: set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID[[LOOP_EID:[0-9]+]]); +// CHECK-NEXT: wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID[[LOOP_EID]]); +// CHECK: TSTORE({{v[0-9]+}}, [[LOOP]]); 
diff --git a/test/lit/pto/multi_tile_unknown_slot_gss_dyn_event_id.pto b/test/lit/pto/multi_tile_unknown_slot_gss_dyn_event_id.pto new file mode 100644 index 000000000..7ea560bb4 --- /dev/null +++ b/test/lit/pto/multi_tile_unknown_slot_gss_dyn_event_id.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --enable-graph-sync-solver --graph-sync-solver-event-id-max=8 --mlir-print-ir-after=pto-graph-sync-solver %s 2>&1 1>/dev/null | FileCheck %s + +// GSS cannot prove the producer slot `%i % 2` equal or disjoint from the +// consumer slot `(%i + %n) % 2`, because the offset is runtime data. Keep the +// conservative multi-event dyn-flag path. 
+ +module { + func.func @unknown_slot_gss( + %gm : memref<16x16xf16, #pto.address_space>, + %dst : memref<16x16xf16, #pto.address_space>, + %n : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %mb = pto.alloc_multi_tile + : !pto.multi_tile_buf + + scf.for %i = %c0 to %n step %c1 { + %prod_idx = arith.remui %i, %c2 : index + %mixed = arith.addi %i, %n : index + %cons_idx = arith.remui %mixed, %c2 : index + + %prod = pto.multi_tile_get %mb[%prod_idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tload ins(%gm : memref<16x16xf16, #pto.address_space>) + outs(%prod : !pto.tile_buf) + + %cons = pto.multi_tile_get %mb[%cons_idx] + : !pto.multi_tile_buf + -> !pto.tile_buf + pto.tstore ins(%cons : !pto.tile_buf) + outs(%dst : memref<16x16xf16, #pto.address_space>) + } + return + } +} + +// CHECK-LABEL: func.func @unknown_slot_gss +// CHECK: pto.set_flag[, , ] +// CHECK: scf.for +// CHECK: pto.wait_flag_dyn[, , +// CHECK: pto.tload +// CHECK: pto.tstore +// CHECK: pto.set_flag_dyn[, , +// CHECK: pto.wait_flag[, , ] diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index e4fc4e71c..17c61924a 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1175,7 +1175,9 @@ int main(int argc, char **argv) { // Conditionally add one automatic synchronization mode. Barrier-all is a // conservative standalone pass; InsertSync and GraphSyncSolver are set/wait - // solvers. + // solvers. Sync runs BEFORE PTOResolveBufferSelect so it sees per-use + // `pto.slot_marker` ops and can keep multi-buffer slot identity (const slot + // K vs slot K' or dynamic slot) for the alias / event-id analysis. 
if (enableInsertSync) pm.addNestedPass(pto::createPTOInsertSyncPass()); else if (enableInjectBarrierAllSync) @@ -1188,7 +1190,10 @@ int main(int argc, char **argv) { pto::createPTOGraphSyncSolverPass(graphSyncOpts)); } - + // Materialize per-slot single-address `pto.pointer_cast` (constant slot) + // or an `arith.select` chain (dynamic slot). The multi-address cast + // produced by PlanMemory survives as the alloc anchor. + pm.addPass(pto::createPTOResolveBufferSelectPass()); std::unique_ptr outputFile; llvm::raw_ostream *outputOS = &llvm::outs();