diff --git a/docs/en/dev/passes/18-lower_transpose_load_param_layout.md b/docs/en/dev/passes/18-lower_transpose_load_param_layout.md
index 3272bca78..114f63bff 100644
--- a/docs/en/dev/passes/18-lower_transpose_load_param_layout.md
+++ b/docs/en/dev/passes/18-lower_transpose_load_param_layout.md
@@ -1,32 +1,36 @@
 # LowerTransposeLoadParamLayout Pass
 
-Lowers ``tile.load(..., transpose=True)`` to canonical-form DN parameter layout (RFC #1300 P6).
+Lowers ``tile.load(..., transpose=True)`` by emitting an explicit
+``tensor.as_layout`` view inside the InCore body (RFC #1300 P6).
 
 ## Overview
 
-Before this pass, ``tile.load(transpose=True)`` is the user's way of saying "I want
-the column-major view of this source tensor at the load site". After this pass, that
-intent is encoded into the InCore parameter's TensorType itself — the source/load
-combo is rewritten to RFC #1300 §3.3 canonical form so codegen, verifier, and
-downstream passes consume a single, self-consistent ``(shape, stride, layout)`` triple.
+Before this pass, ``tile.load(transpose=True)`` is the user's way of saying "I
+want the column-major view of this source tensor at the load site". After this
+pass, that intent is encoded into a body-local ``tensor.as_layout`` view at the
+top of the InCore body so codegen, verifier, and downstream passes consume a
+single, self-consistent ``(shape, stride, layout)`` triple.
 
 For each InCore parameter ``p`` loaded via ``tile.load(p, ..., transpose=True)``:
 
-- ``p``'s TensorType is promoted from ``[..., a, b] ND`` to ``[..., b, a] DN`` —
-  the trailing-pair shape swap plus the DN layout tag. The new TensorView carries
-  an empty stride; ``MaterializeTensorStrides`` (which runs later in the default
-  pipeline, after ``CanonicalizeIOOrder``) fills it with the packed canonical
-  strides.
-- Every ``tile.load(p, offsets, shapes, valid_shapes, ..., transpose=True)`` whose
-  source is a promoted parameter is rewritten so the three tuples' trailing pair
-  is swapped to canonical coords and the ``transpose=True`` kwarg is dropped.
-  ``DeduceTileLoadType`` reads the source's DN layout to derive the Mat tile-view
+- The InCore body is **prepended** with ``p_dn = tensor.as_layout(p, layout=DN)``.
+  The new Var ``p_dn`` carries the canonical ``[..., b, a] DN`` view (trailing-pair
+  shape swap + DN layout tag with packed canonical strides set by
+  ``tensor.as_layout``'s type deduction).
+- Body uses of ``p`` are substituted with ``p_dn``. ``p``'s parameter
+  signature is left unchanged — the orch side keeps passing its original
+  row-major ND tensor (which matches the runtime torch tensor's layout).
+- Every ``tile.load(p, offsets, shapes, valid_shapes, ..., transpose=True)``
+  whose source is a promoted parameter is rewritten to ``tile.load(p_dn, ...)``,
+  with the three tuples' trailing pair swapped to canonical coords and
+  ``transpose=True`` flipped to ``transpose=False``.
+  ``DeduceTileLoadType`` reads ``p_dn``'s DN layout to derive the Mat tile-view
   layout that the legacy ``transpose=True`` swap produced — the two signals are
   equivalent (§4.2 canonical pair).
-- Every non-InCore call site that targets a promoted callee wraps the promoted
-  argument in ``tensor.as_layout(arg, DN)`` (RFC #1300 P4). The bridging op is
-  pure metadata — it emits no PTOAS instruction; ``make_tensor_view`` consumes
-  the new view directly.
+
+Non-InCore (orch) functions are not touched. The DN reinterpret is a
+single-function concern owned by the InCore body that needs it, which keeps the
+cross-function boundary trivial: orch always passes a row-major ND tensor.
 
 **Requirements**:
 
@@ -37,9 +41,7 @@ For each InCore parameter ``p`` loaded via ``tile.load(p, ..., transpose=True)``
 
 **When to use**: 18th pass in the ``Default`` strategy, after
 ``InferTileMemorySpace`` and before ``ResolveBackendOpLayouts``. The 2D shape
-produced by ``FlattenTileNdTo2D`` is a precondition. ``MaterializeTensorStrides``
-runs later in the pipeline (after ``CanonicalizeIOOrder``) to materialize the
-DN-packed canonical strides on the promoted parameters.
+produced by ``FlattenTileNdTo2D`` is a precondition.
 
 ## API
 
@@ -64,26 +66,26 @@ For each InCore function f:
     set P_nt = {param idx with tile.load(p, ..., transpose=False/absent)}
     reject P_t ∩ P_nt (mixed-use)
     for each idx in P_t:
-        promote f.params[idx].type: [..., a, b] ND → [..., b, a] DN (empty stride)
-        substitute old Var → new Var in body
-        rewrite each tile.load(promoted_param, off, shp, vs, transpose=True) in body:
+        let p = f.params[idx]
+        skip if p is already DN-tagged (the user-written / pre-canonical case)
+        build p_dn := tensor.as_layout(p, layout=DN) — type derived by op deducer
+        prepend (p_dn = ...) AssignStmt to body
+        record p → p_dn in substitution map
+        substitute body uses of every promoted p with p_dn
+        rewrite each tile.load(p_dn, off, shp, vs, transpose=True) in body:
             swap last two dims of off / shp / vs
             drop transpose=True kwarg
 
-For each non-InCore function:
-    walk body; for every Call whose op is a GlobalVar of a promoted callee:
-        wrap each promoted-slot arg with tensor.as_layout(arg, DN)
+(Non-InCore functions are untouched.)
 ```
 
-**Complexity:** O(N log N) — one body walk per function plus one program-wide call-site
-walk. Map lookups (``promotions_by_callee_name``) are ``log N`` per call.
+**Complexity:** O(N) — one body walk per InCore function.
 
 | Behavior | Trigger |
 | -------- | ------- |
-| Promote param to ``[..., b, a] DN`` | InCore param is source of ``tile.load(..., transpose=True)`` |
+| Prepend ``p_dn = tensor.as_layout(p, DN)`` and rewrite tile.load | InCore param is source of ``tile.load(..., transpose=True)`` |
 | Skip param | Already DN, or no transposed load |
 | Skip whole function | Function is Orchestration / Opaque / Group |
-| Wrap call-site arg in ``tensor.as_layout`` | Non-InCore call to a promoted callee |
 | Reject | Mixed transpose=True / transpose=False on same param |
 | Reject | DN + explicit physical stride source (would compose as double transpose) |
 
@@ -118,28 +120,28 @@ class Before:
     def matmul_incore(
         self,
         a: pl.Tensor[[64, 128], pl.FP32],
-        b: pl.Tensor[[128, 32], pl.FP32, pl.DN],  # ← shape swapped + DN tag
+        b: pl.Tensor[[32, 128], pl.FP32],  # ← param signature unchanged
        c: pl.Out[pl.Tensor[[64, 32], pl.FP32]],
     ) -> pl.Tensor[[64, 32], pl.FP32]:
+        b_dn = tensor.as_layout(b, layout=DN)  # ← prepended view
+        # type: [128, 32] DN
         tile_a = pl.load(a, [0, 0], [64, 128], target_memory=pl.MemorySpace.Mat)
-        tile_b = pl.load(b, [0, 0], [128, 32], target_memory=pl.MemorySpace.Mat)
-        # ↑ no transpose kwarg
-        # ↑ shapes swapped to canonical coords
+        tile_b = pl.load(b_dn, [0, 0], [128, 32], target_memory=pl.MemorySpace.Mat)
+        # ↑ source switched to b_dn
+        # ↑ shapes swapped to canonical coords
+        # ↑ no transpose kwarg
         ...
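+        # (illustrative) b_dn aliases b's buffer; assuming the RFC §2.4 packed
+        # canonical DN strides, its view is shape=[128, 32], stride=[1, 128],
+        # i.e. the transposed coordinates of b's row-major [32, 128] layout.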
 
     @pl.function(type=pl.FunctionType.Orchestration)
     def orchestrator(self, a, b):
         c = pl.create_tensor([64, 32], dtype=pl.FP32)
-        # b is wrapped in tensor.as_layout to bridge ND → DN at the call site:
-        bridged_b = tensor.as_layout(b, pl.DN)  # type: [128, 32] DN
-        return self.matmul_incore(a, bridged_b, c)
+        return self.matmul_incore(a, b, c)  # ← unchanged
 ```
 
-``a`` is loaded without transpose, so it is unchanged. ``b`` is promoted in the
-InCore signature, all body loads of ``b`` are rewritten to canonical coords with
-no transpose, and the orchestrator's call site wraps ``b`` in
-``tensor.as_layout`` to bridge ``[32, 128] ND`` → ``[128, 32] DN`` over the same
-physical buffer.
+``a`` is loaded without transpose, so it is unchanged. ``b``'s param signature
+is preserved; the kernel internally derives a DN view via ``tensor.as_layout``
+and references that view in its ``tile.load``. The orchestrator is not
+touched — it passes its own row-major ``b`` straight through.
 
 ## Implementation
 
@@ -163,29 +165,31 @@ physical buffer.
 
 | Function type | Action |
 | ------------- | ------ |
-| InCore (InCore, AIC, AIV) | Scanned, possibly promoted |
-| Orchestration / Group / Opaque | Scanned for call sites; promoted-arg wrapped in ``tensor.as_layout`` |
+| InCore (InCore, AIC, AIV) | Scanned, body prepended with ``tensor.as_layout`` views as needed |
+| Orchestration / Group / Opaque | Untouched |
 
 | Parameter state | Action |
 | --------------- | ------ |
-| Sourced by ``tile.load(..., transpose=True)``, layout != DN, rank ≥ 2 | Promoted (shape swap + DN tag) |
-| Sourced by ``tile.load(..., transpose=True)``, already DN | Idempotent — unchanged |
+| Sourced by ``tile.load(..., transpose=True)``, layout != DN, rank ≥ 2 | ``tensor.as_layout`` view prepended; body uses substituted |
+| Sourced by ``tile.load(..., transpose=True)``, already DN | Skipped — ``DeduceTileLoadType`` already handles DN-source XOR transpose |
 | Mixed transpose=True / transpose=False on same param | ``CHECK`` failure |
 | Not sourced by any transposed load | Unchanged |
 | Rank < 2 candidate | ``CHECK`` failure |
 
-## Interaction with ``tensor.as_layout`` (P4) and ``MaterializeTensorStrides`` (P3)
+## Interaction with ``tensor.as_layout`` (P4)
 
-This pass is the first real consumer of ``tensor.as_layout`` in the default
-pipeline. The bridging op is single-purpose: it flips the layout tag and derives
-the new shape from §4.2 canonical pair semantics — callers never write the
-target shape, so the call-site rewriter cannot get it wrong.
+This pass is the first consumer of ``tensor.as_layout`` in the default
+pipeline. The bridging op is single-purpose: it flips the layout tag and
+derives the new shape from §4.2 canonical pair semantics, then attaches the
+packed canonical strides via ``CanonicalizeView``. Codegen lowers
+``tensor.as_layout`` to a fresh ``pto.make_tensor_view`` bound to the input
+tensor's underlying SSA buffer with the LHS's ``(shape, stride, layout)``
+triple — no PTOAS instruction is emitted; the result is a pure metadata
+reinterpret.
 
-Downstream, ``MaterializeTensorStrides`` fills the empty stride slot on each
-promoted parameter with the packed canonical DN strides (RFC §2.4). The
-combination of P6 + P3 is what gives codegen a self-consistent
-``(shape, stride, layout)`` triple — no further ``dn_swap`` / ``get_shape_source_idx``
-fix-ups are needed in the codegen path for promoted parameters.
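+
+As a concrete sketch (SSA names illustrative; strides assume the RFC §2.4
+packed canonical DN form), the example's ``b_dn = tensor.as_layout(b, DN)``
+lowers to a single view rebind over ``b``'s buffer:
+
+```text
+%b_dn = pto.make_tensor_view %b, shape = [%c128, %c32],
+    strides = [%c1, %c128] {layout = #pto.layout<dn>}: !pto.tensor_view<?x?xf32>
+```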
+Per RFC §4.2, the InCore-side reinterpret does not violate the "InCore cannot
+create tensors" invariant: ``tensor.as_layout`` allocates nothing; it
+re-describes the input's existing physical buffer.
 
 ## Interaction with ``tensor.transpose`` at Orchestration
 
diff --git a/docs/en/dev/passes/26-materialize_tensor_strides.md b/docs/en/dev/passes/26-materialize_tensor_strides.md
index 71e6e3988..f963f124a 100644
--- a/docs/en/dev/passes/26-materialize_tensor_strides.md
+++ b/docs/en/dev/passes/26-materialize_tensor_strides.md
@@ -66,7 +66,7 @@ The pass is **idempotent**: re-running on already-materialized IR is a no-op, si
 
 ## Example
 
-**Before** — InCore param with empty-stride DN view (e.g. produced by a future `LowerTransposeLoadParamLayout` rewrite):
+**Before** — InCore param with empty-stride DN view (user-written `pl.Tensor[..., pl.DN]` without an explicit stride hint):
 
 ```python
 @pl.function(type=pl.FunctionType.InCore)
diff --git a/docs/zh-cn/dev/passes/18-lower_transpose_load_param_layout.md b/docs/zh-cn/dev/passes/18-lower_transpose_load_param_layout.md
index c3c65790a..afb76f64a 100644
--- a/docs/zh-cn/dev/passes/18-lower_transpose_load_param_layout.md
+++ b/docs/zh-cn/dev/passes/18-lower_transpose_load_param_layout.md
@@ -1,16 +1,18 @@
 # LowerTransposeLoadParamLayout Pass
 
-将 `tile.load(..., transpose=True)` 下沉为 canonical 形式的 DN 参数布局(RFC #1300 P6)。
+将 `tile.load(..., transpose=True)` 下沉为 InCore body 内显式的 `tensor.as_layout` 视图(RFC #1300 P6)。
 
 ## 概述
 
-本 Pass 之前,`tile.load(transpose=True)` 是用户表达"我希望在 load 站点看到源张量的列主序视图"的方式。Pass 之后,这一意图被编码进 InCore 参数的 TensorType 本身 —— 源张量/load 组合被改写为 RFC #1300 §3.3 的 canonical 形式,使 codegen、verifier、下游 Pass 看到一份自洽的 `(shape, stride, layout)` 三元组。
+本 Pass 之前,`tile.load(transpose=True)` 是用户表达"我希望在 load 站点看到源张量的列主序视图"的方式。Pass 之后,这一意图被编码进 InCore body 顶部的一条 `tensor.as_layout` 视图绑定 —— 使 codegen、verifier、下游 Pass 看到一份自洽的 `(shape, stride, layout)` 三元组。
 
 对每个被 `tile.load(p, ..., transpose=True)` 加载的 InCore 参数 `p`:
 
-- `p` 的 TensorType 从 `[..., a, b] ND` 提升为 `[..., b, a] DN` —— 末两维形状互换 + DN 布局标签。新 TensorView 的 stride 为空;`MaterializeTensorStrides`(在默认 pipeline 中位于 `CanonicalizeIOOrder` 之后运行)会把它填为 packed canonical 的 stride。
-- 每个 `tile.load(p, offsets, shapes, valid_shapes, ..., transpose=True)`(源是已提升的参数)被改写为:三个 tuple 的末两维互换以匹配 canonical 坐标,丢弃 `transpose=True` kwarg。`DeduceTileLoadType` 通过源张量的 DN 布局推导出 Mat tile-view 的 layout —— 这两种信号在 §4.2 canonical pair 下是等价的。
-- 每个目标是已提升 callee 的非 InCore 函数调用站点,会把对应实参用 `tensor.as_layout(arg, DN)` 包一层(RFC #1300 P4)。该桥接 op 是纯元数据 —— 不生成 PTOAS 指令;`make_tensor_view` 直接消费新视图。
+- **在 InCore body 顶部插入** `p_dn = tensor.as_layout(p, layout=DN)`。新 Var `p_dn` 携带 canonical `[..., b, a] DN` 视图(末两维 shape 互换 + DN layout 标签 + `tensor.as_layout` 的 deduce-type 填入的 packed canonical strides)。
+- body 中对 `p` 的引用被替换为 `p_dn`。`p` 的参数签名保持不变 —— orch 侧继续按原 row-major ND 形式传 tensor(与 runtime 的 torch tensor 一致)。
+- body 中每个 `tile.load(p, offsets, shapes, valid_shapes, ..., transpose=True)`(源是已提升参数)被改写为 `tile.load(p_dn, ...)`,三个 tuple 的末两维互换为 canonical 坐标,`transpose=True` 翻为 `transpose=False`。`DeduceTileLoadType` 通过 `p_dn` 的 DN 布局推出 Mat tile-view 的 layout —— 两种信号在 §4.2 canonical pair 下等价。
+
+非 InCore(orch)函数完全不动。DN 重解释是单函数(InCore)内部的关注点,由用到它的 body 自己拥有;跨函数边界保持简单:orch 永远传 row-major ND tensor。
 
 **前置条件**:
 
@@ -19,7 +21,7 @@
 
 - Tile op 已存在且为 2D(`IncoreTileOps`、`TileOps2D`)
 - 被提升的参数 rank ≥ 2
 
-**使用时机**:在 `Default` 策略中作为第 18 个 Pass 运行(文档编号 18 对应于 docs/passes/ 中的执行顺序槽位,与 pass_manager.py 中的相对顺序匹配),位于 `InferTileMemorySpace` 之后、`ResolveBackendOpLayouts` 之前。`FlattenTileNdTo2D` 产生的 2D 形状是前置条件。`MaterializeTensorStrides` 在 pipeline 后段运行(在 `CanonicalizeIOOrder` 之后)以物化 DN-packed canonical stride。
+**使用时机**:在 `Default` 策略中作为第 18 个 Pass 运行(文档编号 18 对应于 docs/passes/ 中的执行顺序槽位,与 pass_manager.py 中的相对顺序匹配),位于 `InferTileMemorySpace` 之后、`ResolveBackendOpLayouts` 之前。`FlattenTileNdTo2D` 产生的 2D 形状是前置条件。
 
 ## API
 
@@ -44,25 +46,26 @@ program_canonical = p(program)
 
     得到 P_nt = {tile.load(p, ..., transpose=False/缺省) 命中的 param 索引}
     拒绝 P_t ∩ P_nt (混用)
     对每个 idx in P_t:
-        提升 f.params[idx].type:[..., a, b] ND → [..., b, a] DN(stride 留空)
-        在 body 中以新 Var 替换旧 Var
-        改写 body 中每个 tile.load(promoted_param, off, shp, vs, transpose=True):
+        let p = f.params[idx]
+        若 p 已经是 DN(用户写的 / 预先 canonical 化的情形)则跳过
+        构造 p_dn := tensor.as_layout(p, layout=DN) —— 类型由 op deduce 推出
+        将 (p_dn = ...) AssignStmt 插入 body 顶部
+        记录 p → p_dn 的替换映射
+        按映射替换 body 中所有对已提升 p 的引用为 p_dn
+        改写 body 中每个 tile.load(p_dn, off, shp, vs, transpose=True):
            交换 off / shp / vs 末两维
            丢弃 transpose=True kwarg
 
-对每个非 InCore 函数:
-    遍历 body;对每个 op 为已提升 callee 的 GlobalVar 的 Call:
-        给每个已提升槽位的实参包一层 tensor.as_layout(arg, DN)
+(非 InCore 函数原样保留)
 ```
 
-**复杂度:** O(N log N) —— 每个函数一次 body 走查,加一次全程序级调用站点走查。Map 查找(`promotions_by_callee_name`)为每次调用 `log N`。
+**复杂度:** O(N) —— 每个 InCore 函数一次 body 走查。
 
 | 行为 | 触发条件 |
 | ---- | -------- |
-| 提升参数到 `[..., b, a] DN` | InCore 参数是 `tile.load(..., transpose=True)` 的源 |
+| 插入 `p_dn = tensor.as_layout(p, DN)` 并改写 tile.load | InCore 参数是 `tile.load(..., transpose=True)` 的源 |
 | 跳过参数 | 已经是 DN,或没有转置 load |
 | 整个函数跳过 | 函数为 Orchestration / Opaque / Group |
-| 调用站点 wrap `tensor.as_layout` | 非 InCore 函数调用已提升 callee |
 | 拒绝 | 同一参数既被 transpose=True 也被 transpose=False 加载 |
 | 拒绝 | DN + 显式物理 stride 源(与 tile.load 转置会叠成双重转置) |
 
@@ -97,24 +100,25 @@ class Before:
     def matmul_incore(
         self,
         a: pl.Tensor[[64, 128], pl.FP32],
-        b: pl.Tensor[[128, 32], pl.FP32, pl.DN],  # ← 形状互换 + DN 标签
+        b: pl.Tensor[[32, 128], pl.FP32],  # ← 参数签名保持不变
         c: pl.Out[pl.Tensor[[64, 32], pl.FP32]],
     ) -> pl.Tensor[[64, 32], pl.FP32]:
+        b_dn = tensor.as_layout(b, layout=DN)  # ← body 顶部插入的视图
+        # 类型:[128, 32] DN
         tile_a = pl.load(a, [0, 0], [64, 128], target_memory=pl.MemorySpace.Mat)
-        tile_b = pl.load(b, [0, 0], [128, 32], target_memory=pl.MemorySpace.Mat)
-        # ↑ 没有 transpose kwarg
-        # ↑ shapes 已互换到 canonical 坐标
+        tile_b = pl.load(b_dn, [0, 0], [128, 32], target_memory=pl.MemorySpace.Mat)
+        # ↑ 源切到 b_dn
+        # ↑ shapes 已互换到 canonical 坐标
+        # ↑ 无 transpose kwarg
         ...
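+        # (示意)b_dn 与 b 共享同一片物理 buffer;按 RFC §2.4 的 packed
+        # canonical DN stride,其视图为 shape=[128, 32]、stride=[1, 128],
+        # 即 b 的 row-major [32, 128] 布局的转置坐标。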
 
     @pl.function(type=pl.FunctionType.Orchestration)
     def orchestrator(self, a, b):
         c = pl.create_tensor([64, 32], dtype=pl.FP32)
-        # b 在调用站点被 tensor.as_layout 包一层做 ND → DN 桥接:
-        bridged_b = tensor.as_layout(b, pl.DN)  # type: [128, 32] DN
-        return self.matmul_incore(a, bridged_b, c)
+        return self.matmul_incore(a, b, c)  # ← 保持不变
 ```
 
-`a` 不转置加载,原样保留。`b` 在 InCore 签名被提升,body 中所有对 `b` 的加载改写到 canonical 坐标且无转置 kwarg,orchestrator 调用站点把 `b` 用 `tensor.as_layout` 包起来,把 `[32, 128] ND` 桥接到 `[128, 32] DN`(同一片物理内存)。
+`a` 不转置加载,原样保留。`b` 的参数签名保留;kernel 内部用 `tensor.as_layout` 派生 DN 视图,其 `tile.load` 引用该视图。orchestrator 完全不动 —— 它把自己的 row-major `b` 原样传下去。
 
 ## 实现
 
@@ -138,22 +142,22 @@ def orchestrator(self, a, b):
 
 | 函数类型 | 行为 |
 | -------- | ---- |
-| InCore(InCore、AIC、AIV) | 扫描,可能被提升 |
-| Orchestration / Group / Opaque | 扫描调用站点;已提升实参 wrap `tensor.as_layout` |
+| InCore(InCore、AIC、AIV) | 扫描,按需在 body 顶部插入 `tensor.as_layout` 视图 |
+| Orchestration / Group / Opaque | 不动 |
 
 | 参数状态 | 行为 |
 | -------- | ---- |
-| 是 `tile.load(..., transpose=True)` 的源,layout != DN,rank ≥ 2 | 提升(形状互换 + DN 标签) |
-| 是 `tile.load(..., transpose=True)` 的源,已是 DN | 幂等 —— 保持不变 |
+| 是 `tile.load(..., transpose=True)` 的源,layout != DN,rank ≥ 2 | 插入 `tensor.as_layout` 视图;body 中引用被替换 |
+| 是 `tile.load(..., transpose=True)` 的源,已是 DN | 跳过 —— `DeduceTileLoadType` 已经处理 DN-源 XOR transpose |
 | 同一参数既 transpose=True 又 transpose=False | `CHECK` 失败 |
 | 没有转置 load 引用 | 保持不变 |
 | Rank < 2 候选 | `CHECK` 失败 |
 
-## 与 `tensor.as_layout`(P4)和 `MaterializeTensorStrides`(P3)的交互
+## 与 `tensor.as_layout`(P4)的交互
 
-本 Pass 是默认 pipeline 中 `tensor.as_layout` 的第一个真实消费者。该桥接 op 单一职责:翻转 layout 标签,目标 shape 由 §4.2 canonical pair 机械导出 —— 调用方不传 target shape,所以调用站点改写器不会出错。
+本 Pass 是默认 pipeline 中 `tensor.as_layout` 的第一个消费者。该桥接 op 单一职责:翻转 layout 标签,目标 shape 由 §4.2 canonical pair 机械导出,并通过 `CanonicalizeView` 附加 packed canonical strides。codegen 把 `tensor.as_layout` 下沉为一条新的 `pto.make_tensor_view`,绑定到输入 tensor 的底层 SSA buffer 上,使用 LHS 的 `(shape, stride, layout)` 三元组 —— 不发射任何 PTOAS 指令,结果是纯元数据 reinterpret。
 
-下游的 `MaterializeTensorStrides` 把每个被提升的参数 TensorView 空 stride 填为 packed canonical DN strides(RFC §2.4)。P6 + P3 的组合让 codegen 看到自洽的 `(shape, stride, layout)` 三元组 —— 对被提升的参数,codegen 路径无需再做 `dn_swap` / `get_shape_source_idx` 修正。
+按 RFC §4.2,InCore 侧的 reinterpret 不违反"核内不能创建 tensor"约束:`tensor.as_layout` 不分配任何内存,它只是为输入的现有物理 buffer 换一份描述。
 
 ## 与 Orchestration 层 `tensor.transpose` 的交互
 
diff --git a/docs/zh-cn/dev/passes/26-materialize_tensor_strides.md b/docs/zh-cn/dev/passes/26-materialize_tensor_strides.md
index 9b25e770d..6e8def61a 100644
--- a/docs/zh-cn/dev/passes/26-materialize_tensor_strides.md
+++ b/docs/zh-cn/dev/passes/26-materialize_tensor_strides.md
@@ -66,7 +66,7 @@ Pass **幂等** —— 在已物化的 IR 上重跑等于无操作(类型比
 
 ## 示例
 
-**Before** —— InCore 形参带有空 stride 的 DN view(例如未来 `LowerTransposeLoadParamLayout` 改写产生的形态):
+**Before** —— InCore 形参带有空 stride 的 DN view(用户写的 `pl.Tensor[..., pl.DN]` 未给显式 stride 提示):
 
 ```python
 @pl.function(type=pl.FunctionType.InCore)
diff --git a/include/pypto/ir/transforms/passes.h b/include/pypto/ir/transforms/passes.h
index 186db1f7b..8e3032cfb 100644
--- a/include/pypto/ir/transforms/passes.h
+++ b/include/pypto/ir/transforms/passes.h
@@ -426,23 +426,25 @@ Pass AutoTileMatmulL0();
 Pass InferTileMemorySpace();
 
 /**
- * @brief Lower ``tile.load(transpose=True)`` to canonical-form parameter layout (RFC #1300 P6)
+ * @brief Lower ``tile.load(transpose=True)`` to a body-local DN view (RFC #1300 P6)
 *
 * For each InCore function, detects ``tile.load(..., transpose=True)`` whose source
- * is a function parameter and promotes that parameter to canonical-form DN
+ * is a function parameter ``p`` and rewrites the body so the transpose intent is
+ * encoded as an explicit ``tensor.as_layout`` view at the top of the body
 * (RFC #1300 §3.3 + §4.2):
 *
- * - Param TensorType: ``[..., a, b] ND`` → ``[..., b, a] DN`` (trailing-pair swap +
- *   DN layout tag with empty stride; ``MaterializeTensorStrides`` later fills the
- *   packed canonical strides).
- * - Each ``tile.load(p, offsets, shapes, valid_shapes, ..., transpose=True)`` whose
- *   source ``p`` is a promoted param is rewritten to: offsets / shapes /
- *   valid_shapes' trailing pair is swapped to canonical coords, and the
- *   ``transpose=True`` kwarg is dropped — the DN-source + Mat-target signal
- *   fully encodes the load's tile-view orientation.
- * - Every non-InCore call site that targets a promoted callee is wrapped with
- *   ``tensor.as_layout(arg, DN)`` so the orch-side ``[..., a, b] ND`` runtime
- *   tensor bridges to the InCore-side ``[..., b, a] DN`` param type.
+ * - Prepends ``p_dn = tensor.as_layout(p, layout=DN)`` to the InCore body.
+ *   ``p_dn`` carries the canonical ``[..., b, a] DN`` view; ``p``'s parameter
+ *   signature is left unchanged.
+ * - Substitutes body uses of ``p`` with ``p_dn``.
+ * - Rewrites each ``tile.load(p_dn, offsets, shapes, valid_shapes, ..., transpose=True)``
+ *   to swap the trailing pair of offsets / shapes / valid_shapes into canonical
+ *   coords and flip the ``transpose=True`` kwarg to ``transpose=False`` — the
+ *   DN-source + Mat-target signal on ``p_dn`` now fully encodes the load's
+ *   tile-view orientation.
+ *
+ * Non-InCore (orch) functions are left untouched: the orch caller continues to
+ * pass its original row-major ND tensor straight through to the kernel, which
+ * keeps the cross-function type boundary trivial.
 *
 * Mixed-use parameters (same param loaded with both ``transpose=True`` and
 * ``transpose=False``) are rejected with ``pypto::ValueError``.
diff --git a/python/bindings/modules/passes.cpp b/python/bindings/modules/passes.cpp
index ec05a8a4a..1c2b76b0d 100644
--- a/python/bindings/modules/passes.cpp
+++ b/python/bindings/modules/passes.cpp
@@ -421,12 +421,15 @@ void BindPass(nb::module_& m) {
   passes.def("lower_transpose_load_param_layout", &pass::LowerTransposeLoadParamLayout,
              "Create the LowerTransposeLoadParamLayout pass (RFC #1300 P6).\n\n"
             "For each InCore function, detects tile.load(..., transpose=True) whose source\n"
-            "is a function parameter and promotes the parameter to canonical-form DN:\n"
-            "shape trailing-pair is swapped, the DN layout tag is added, the tile.load\n"
-            "body call's offsets/shapes/valid_shapes are swapped and the transpose=True\n"
-            "kwarg dropped, and every non-InCore call site wraps the promoted argument\n"
-            "in tensor.as_layout(arg, DN) to bridge orch-side ND tensors to InCore DN\n"
-            "params. Mixed-use params (both transpose=True and transpose=False loads on\n"
+            "is a function parameter `p` and rewrites the body to encode the transpose\n"
+            "intent as an explicit `tensor.as_layout` view:\n"
+            " - prepends `p_dn = tensor.as_layout(p, layout=DN)` to the InCore body\n"
+            "   (`p_dn` carries the canonical `[..., b, a] DN` view);\n"
+            " - substitutes body uses of `p` with `p_dn`;\n"
+            " - swaps the trailing pair of offsets/shapes/valid_shapes on the matching\n"
+            "   tile.load calls and flips `transpose=True` to `transpose=False`.\n"
+            "Parameter signatures are left unchanged. Non-InCore (orch) functions are\n"
+            "untouched. Mixed-use params (both transpose=True and transpose=False loads on\n"
            "the same param) are rejected.");
   passes.def("materialize_tensor_strides", &pass::MaterializeTensorStrides,
              "Create the MaterializeTensorStrides pass (RFC #1300 §2.4).\n\n"
diff --git a/python/pypto/pypto_core/passes.pyi b/python/pypto/pypto_core/passes.pyi
index f76630fba..662fb6803 100644
--- a/python/pypto/pypto_core/passes.pyi
+++ b/python/pypto/pypto_core/passes.pyi
@@ -448,14 +448,18 @@ def lower_transpose_load_param_layout() -> Pass:
     """Create the LowerTransposeLoadParamLayout pass (RFC #1300 P6).
 
     For each InCore function, detects ``tile.load(..., transpose=True)`` whose
-    source is a function parameter and promotes the parameter to canonical-form
-    DN: shape trailing-pair is swapped, the DN layout tag is added, body
-    ``tile.load`` calls have offsets/shapes/valid_shapes' trailing pair swapped
-    and the ``transpose=True`` kwarg dropped, and every non-InCore call site
-    wraps the promoted argument in ``tensor.as_layout(arg, DN)`` to bridge
-    orch-side ND tensors to InCore DN params. Mixed-use parameters (both
-    ``transpose=True`` and ``transpose=False`` loads on the same param) are
-    rejected.
+    source is a function parameter ``p`` and rewrites the body to encode the
+    transpose intent as an explicit ``tensor.as_layout`` view:
+
+    - prepends ``p_dn = tensor.as_layout(p, layout=DN)`` to the InCore body
+      (``p_dn`` carries the canonical ``[..., b, a] DN`` view);
+    - substitutes body uses of ``p`` with ``p_dn``;
+    - swaps the trailing pair of offsets/shapes/valid_shapes on the matching
+      ``tile.load`` calls and flips ``transpose=True`` to ``transpose=False``.
+
+    Parameter signatures are left unchanged. Non-InCore (orch) functions are
+    untouched. Mixed-use parameters (both ``transpose=True`` and
+    ``transpose=False`` loads on the same param) are rejected.
     """
 
 def materialize_tensor_strides() -> Pass:
diff --git a/src/backend/common/pto_ops_common.cpp b/src/backend/common/pto_ops_common.cpp
index aa4217775..16cffc53e 100644
--- a/src/backend/common/pto_ops_common.cpp
+++ b/src/backend/common/pto_ops_common.cpp
@@ -2134,6 +2134,91 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set<std::string>& exc
   reg("tensor.write",
       [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { return MakeTensorWriteCodegenPTO(op, codegen); });
+  // ``tensor.as_layout`` (RFC #1300 §3.3): pure metadata reinterpret over the
+  // same physical buffer. In InCore code, ``LowerTransposeLoadParamLayout``
+  // prepends ``b_dn = tensor.as_layout(b, DN)`` at the top of the body for
+  // each ``tile.load(transpose=True)``-loaded param, then rewrites the body
+  // to use ``b_dn`` (a Var with TensorType ``[..., b, a] DN`` and explicit
+  // canonical strides) in place of the original param ``b``.
+  //
+  // Codegen lowers this to a fresh ``pto.make_tensor_view`` bound to the
+  // input's underlying buffer (the function parameter SSA), using the LHS's
+  // own ``(shape, stride, layout)`` from its TensorType. Downstream
+  // ``tile.load`` lookups via ``GetOrCreateTensorView`` find the LHS through
+  // the ``RegisterTensorView`` call below.
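+  //
+  // Illustrative emitted line (SSA names hypothetical; a [32, 128] f32 input
+  // viewed as [128, 32] DN with packed canonical strides [1, 128]):
+  //   %3 = pto.make_tensor_view %b, shape = [%c128, %c32], strides = [%c1, %c128]
+  //        {layout = #pto.layout<dn>}: !pto.tensor_view<?x?xf32>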
+ reg("tensor.as_layout", [](const ir::CallPtr& op, codegen::CodegenBase& codegen_base) { + auto& codegen = dynamic_cast(codegen_base); + CHECK(op->args_.size() == 1) << "tensor.as_layout requires 1 arg (input)"; + auto input_var = AsVarLike(op->args_[0]); + CHECK(input_var) << "tensor.as_layout input must be a Var/IterArg"; + + auto lhs_var = codegen.GetCurrentResultVar(); + INTERNAL_CHECK_SPAN(static_cast(lhs_var), op->span_) + << "Internal error: tensor.as_layout result var must be set by VisitStmt_(AssignStmt)"; + auto lhs_type = As(lhs_var->GetType()); + CHECK(lhs_type) << "tensor.as_layout output must be TensorType, got " << lhs_var->GetType()->TypeName(); + INTERNAL_CHECK_SPAN(lhs_type->tensor_view_.has_value(), op->span_) + << "Internal error: tensor.as_layout output must have an explicit TensorView " + "(set by DeduceTensorAsLayoutType + CanonicalizeView)"; + + const size_t rank = lhs_type->shape_.size(); + const auto& view = lhs_type->tensor_view_.value(); + INTERNAL_CHECK_SPAN(view.stride.size() == rank, op->span_) + << "Internal error: tensor.as_layout output stride rank " << view.stride.size() + << " does not match shape rank " << rank; + + // The result SSA name (auto-allocated by VisitStmt_(AssignStmt) for the + // backend-dispatched RHS Call) doubles as the tensor_view SSA name — + // register it in tensor_to_view so downstream tile.load lookups resolve. + std::string result_buf = codegen.GetCurrentResultTarget(); + INTERNAL_CHECK_SPAN(!result_buf.empty(), op->span_) << "Internal error: result buf must be set"; + codegen.RegisterTensorView(lhs_var, result_buf); + + // Materialize shape and stride SSA names. + auto emit_dim = [&](const ir::ExprPtr& dim) -> std::string { + if (auto c = As(dim)) { + return codegen.GetOrEmitConstant(c->value_, DataType::INDEX); + } + return codegen.EmitCastToIndex(dim, codegen.GetExprAsCode(dim)); + }; + std::vector shape_dim_names(rank); + for (size_t j = 0; j < rank; ++j) shape_dim_names[j] = emit_dim(lhs_type->shape_[j]); + std::vector stride_names(rank); + for (size_t j = 0; j < rank; ++j) stride_names[j] = emit_dim(view.stride[j]); + + std::string layout_str = "nd"; + switch (view.layout) { + case ir::TensorLayout::DN: + layout_str = "dn"; + break; + case ir::TensorLayout::NZ: + layout_str = "nz"; + break; + case ir::TensorLayout::ND: + break; + } + + std::ostringstream oss; + oss << result_buf << " = pto.make_tensor_view " << codegen.GetVarName(input_var) << ", shape = ["; + for (size_t j = 0; j < rank; ++j) { + if (j > 0) oss << ", "; + oss << shape_dim_names[j]; + } + oss << "], strides = ["; + for (size_t j = 0; j < rank; ++j) { + if (j > 0) oss << ", "; + oss << stride_names[j]; + } + oss << "] {layout = #pto.layout<" << layout_str << ">}"; + oss << ": !pto.tensor_view<"; + for (size_t j = 0; j < rank; ++j) { + if (j > 0) oss << "x"; + oss << "?"; + } + oss << "x" << codegen.GetTypeString(lhs_type->dtype_) << ">"; + return oss.str(); + }); + reg("tile.load", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { return MakeTileLoadCodegenPTO(op, codegen); }); diff --git a/src/codegen/orchestration/orchestration_codegen.cpp b/src/codegen/orchestration/orchestration_codegen.cpp index 88c084c54..74b095eb9 100644 --- a/src/codegen/orchestration/orchestration_codegen.cpp +++ b/src/codegen/orchestration/orchestration_codegen.cpp @@ -1174,55 +1174,6 @@ class OrchestrationStmtCodegen : public CodegenBase { std::move(info.inner_callee)}; } - /// Build a "wrapper-internal alias map" — for every AssignStmt in the - /// wrapper body whose RHS 
-  /// ``tensor.as_layout``), record LHS-var → upstream-var. This lets
-  /// ``BuildWrapperReorderedParams`` chase the inner-call's arg back to a
-  /// wrapper parameter through any orch-side ``tensor.as_layout`` bridge that
-  /// ``LowerTransposeLoadParamLayout`` may have injected.
-  std::unordered_map<const Var*, VarPtr> BuildWrapperAliasMap(const FunctionPtr& wrapper_func) {
-    std::unordered_map<const Var*, VarPtr> alias_map;
-    class AliasCollector : public IRVisitor {
-     public:
-      explicit AliasCollector(std::unordered_map<const Var*, VarPtr>* out) : out_(out) {}
-      void VisitStmt_(const AssignStmtPtr& op) override {
-        if (auto call = As<Call>(op->value_)) {
-          // ``tensor.as_layout`` is the canonical orch-side view alias that
-          // P6 (``LowerTransposeLoadParamLayout``) emits before the kernel
-          // call. Its runtime lowering is a plain ``Tensor x = src;`` alias,
-          // so for arg-routing purposes the LHS is interchangeable with the
-          // RHS's first arg.
-          if (call->op_ && call->op_->name_ == "tensor.as_layout" && !call->args_.empty()) {
-            if (auto src = AsVarLike(call->args_[0])) {
-              (*out_)[op->var_.get()] = src;
-            }
-          }
-        }
-        IRVisitor::VisitStmt_(op);
-      }
-
-     private:
-      std::unordered_map<const Var*, VarPtr>* out_;
-    };
-    if (wrapper_func->body_) {
-      AliasCollector(&alias_map).VisitStmt(wrapper_func->body_);
-    }
-    return alias_map;
-  }
-
-  /// Resolve ``var`` to its ultimate alias source within the wrapper body
-  /// (walking through any ``tensor.as_layout`` bindings). Returns ``var``
-  /// itself if no alias chain applies.
-  VarPtr ResolveAliasChain(VarPtr var, const std::unordered_map<const Var*, VarPtr>& alias_map) {
-    std::unordered_set<const Var*> seen;
-    while (true) {
-      auto it = alias_map.find(var.get());
-      if (it == alias_map.end()) return var;
-      if (!seen.insert(var.get()).second) return var;  // cycle guard
-      var = it->second;
-    }
-  }
-
   /// Build task params for a wrapper function call, reordered to match the
   /// inner callee's parameter order.
   ///
@@ -1237,7 +1188,6 @@ class OrchestrationStmtCodegen : public CodegenBase {
     for (size_t i = 0; i < wrapper_func->params_.size(); ++i) {
       wrapper_param_to_outer_idx[wrapper_func->params_[i].get()] = i;
     }
-    auto alias_map = BuildWrapperAliasMap(wrapper_func);
 
     // Phase-5 invariant: the outer Call must carry explicit arg_directions
     // (populated by DeriveCallDirections). The legacy ParamDirection fallback
@@ -1273,16 +1223,6 @@
       }
 
       auto it = wrapper_param_to_outer_idx.find(inner_arg_var.get());
-      if (it == wrapper_param_to_outer_idx.end()) {
-        // The inner-call arg may be a wrapper-local Var bound by a
-        // ``tensor.as_layout`` AssignStmt (injected by P6 to bridge
-        // orch-side ND tensors to the InCore-side DN param type). Chase
-        // the alias chain back to the upstream wrapper parameter.
-        auto upstream = ResolveAliasChain(inner_arg_var, alias_map);
-        if (upstream.get() != inner_arg_var.get()) {
-          it = wrapper_param_to_outer_idx.find(upstream.get());
-        }
-      }
       if (it == wrapper_param_to_outer_idx.end()) {
         // Some wrapper-expansion paths can leave inner-call scalar ivs that are
         // not part of the user-visible wrapper signature. They should not be
diff --git a/src/ir/transforms/lower_transpose_load_param_layout_pass.cpp b/src/ir/transforms/lower_transpose_load_param_layout_pass.cpp
index b2d3faada..a2d518ac5 100644
--- a/src/ir/transforms/lower_transpose_load_param_layout_pass.cpp
+++ b/src/ir/transforms/lower_transpose_load_param_layout_pass.cpp
@@ -85,24 +85,6 @@ class TransposeLoadScanner : public IRVisitor {
   std::unordered_set<size_t> non_transposed_uses_;
 };
 
-/// Build the canonical TensorType for an InCore parameter that is loaded via
-/// ``tile.load(transpose=True)`` (RFC #1300 §3.3 + §4.2):
-///   src ``[..., a, b] ND`` ≡ canonical ``[..., b, a] DN``
-///
-/// The new TensorView carries an empty stride; ``MaterializeTensorStrides``
-/// (P6-b) fills it with the packed canonical strides later in the pipeline.
-TensorTypePtr PromoteToCanonicalDN(const TensorTypePtr& src) {
-  CHECK(src->shape_.size() >= 2)
-      << "LowerTransposeLoadParamLayout: parameter must have rank >= 2 to apply DN "
-         "canonical form, got "
-      << src->shape_.size();
-  std::vector<ExprPtr> new_shape = src->shape_;
-  std::iter_swap(new_shape.end() - 2, new_shape.end() - 1);
-  TensorView dn_view(std::vector<ExprPtr>{}, TensorLayout::DN);
-  return std::make_shared<TensorType>(new_shape, src->dtype_, src->memref_,
-                                      std::make_optional(std::move(dn_view)));
-}
-
 /// Swap the last two elements of a ``MakeTuple`` (offsets / shapes /
 /// valid_shapes argument of ``tile.load``).
 MakeTuplePtr SwapTrailingPair(const MakeTuplePtr& tuple) {
@@ -116,19 +98,16 @@ MakeTuplePtr SwapTrailingPair(const MakeTuplePtr& tuple) {
   return std::make_shared<MakeTuple>(std::move(new_elements), tuple->span_);
 }
 
-/// Rewrite tile.load calls whose first arg is one of the promoted parameters
-/// so that:
+/// Rewrite tile.load calls whose first arg is one of the body-local
+/// ``b_dn = tensor.as_layout(b, DN)`` bindings (one per promoted param), so:
 /// - offsets / shapes / valid_shapes are swapped to canonical coords;
 /// - the ``transpose=True`` kwarg is dropped (DN source + Mat target now
 ///   drives the tile-view swap inside ``DeduceTileLoadType``).
 /// All other Calls are passed through unchanged.
 class TileLoadBodyRewriter : public IRMutator {
  public:
-  explicit TileLoadBodyRewriter(const std::unordered_map<const Var*, VarPtr>& param_subs) {
-    for (const auto& [old_ptr, new_var] : param_subs) {
-      promoted_param_set_.insert(new_var.get());
-    }
-  }
+  explicit TileLoadBodyRewriter(const std::unordered_set<const Var*>& dn_view_vars)
+      : dn_view_vars_(dn_view_vars) {}
 
   ExprPtr VisitExpr_(const CallPtr& op) override {
     auto base = IRMutator::VisitExpr_(op);
@@ -137,14 +116,14 @@ class TileLoadBodyRewriter : public IRMutator {
     if (call->args_.empty()) return base;
 
     auto src_var = As<Var>(call->args_[0]);
-    if (!src_var || promoted_param_set_.find(src_var.get()) == promoted_param_set_.end()) {
+    if (!src_var || dn_view_vars_.find(src_var.get()) == dn_view_vars_.end()) {
      return base;
    }
    if (!call->GetKwarg("transpose", false)) return base;
 
    // tile.load(tensor, offsets, shapes, valid_shapes, ...) — swap the trailing
    // pair of all three tuples so the load is expressed in canonical (DN
-    // logical) coordinates that match the promoted parameter's new shape.
+    // logical) coordinates that match the body-local ``b_dn`` view's shape.
     INTERNAL_CHECK_SPAN(call->args_.size() == 4, call->span_)
         << "LowerTransposeLoadParamLayout: expected tile.load to have 4 args, got " << call->args_.size();
     auto offsets = As<MakeTuple>(call->args_[1]);
@@ -179,238 +158,124 @@
   }
 
  private:
-  std::unordered_set<const Var*> promoted_param_set_;
-};
-
-/// Result of promoting a single InCore function.
-struct PromotionResult {
-  FunctionPtr func;
-  std::map<size_t, VarPtr> promoted_params;  // param index → new param Var
+  const std::unordered_set<const Var*>& dn_view_vars_;
 };
 
-/// Promote an InCore function. Returns the rewritten Function (or the
-/// original if no rewrite was needed) and the map of promoted param slots.
-/// Throws if any promoted parameter is also loaded without `transpose=True`
+/// Rewrite an InCore function: keep params unchanged; prepend
+/// ``b_dn = tensor.as_layout(b, layout=DN)`` AssignStmts at the top of the
+/// body for every param ``b`` loaded with ``transpose=True``; substitute body
+/// uses of ``b`` with ``b_dn``; rewrite each transposed tile.load to swap the
+/// trailing pair of offsets/shapes/valid_shapes and flip ``transpose=True``
+/// to ``transpose=False``.
+///
+/// Returns the rewritten Function (or the original if no rewrite was needed).
+/// Throws if any promoted parameter is also loaded without ``transpose=True``
 /// in the same body (mixed use would corrupt non-transpose loads).
-PromotionResult PromoteInCoreFunction(const FunctionPtr& func) {
+FunctionPtr LowerInCoreFunction(const FunctionPtr& func) {
   TransposeLoadScanner scanner(func->params_);
   scanner.VisitStmt(func->body_);
   const auto& promoted = scanner.GetPromoted();
   const auto& non_transposed = scanner.GetNonTransposedUses();
   if (promoted.empty()) {
-    return {func, {}};
+    return func;
   }
 
+  // Build, in deterministic param-index order:
+  //  - the prepend AssignStmts (one per promoted param), each of the form
+  //    ``b_dn = tensor.as_layout(b, layout=DN)``;
+  //  - the substitution map ``b -> b_dn`` used to rewrite body uses;
+  //  - the set of body-local ``b_dn`` Vars used by ``TileLoadBodyRewriter``
+  //    to recognize which tile.loads need the trailing-pair swap.
+  std::vector<size_t> sorted_promoted(promoted.begin(), promoted.end());
+  std::sort(sorted_promoted.begin(), sorted_promoted.end());
+
+  std::vector<StmtPtr> prepend;
   std::unordered_map<const Var*, VarPtr> substitutions;
-  std::vector<VarPtr> new_params = func->params_;
-  std::map<size_t, VarPtr> promoted_params;
 
-  for (size_t idx : promoted) {
-    // Mixed-use rejection: a param promoted from `[a, b]` ND → `[b, a]` DN
-    // would invalidate every non-transpose `tile.load(p, ...)` that still
-    // expects the original coordinate system.
+  std::unordered_set<const Var*> dn_view_vars;
+
+  for (size_t idx : sorted_promoted) {
+    // Mixed-use rejection: a body-local DN view derived from ``b`` only
+    // makes sense if every load of ``b`` agrees on ``transpose=True``.
     CHECK(non_transposed.find(idx) == non_transposed.end())
        << "LowerTransposeLoadParamLayout: parameter at index " << idx
        << " is loaded both with transpose=True and transpose=False — only one "
           "mode is supported per InCore parameter. Split the parameter or unify "
           "the load direction.";
 
-    const auto& old_param = func->params_[idx];
-    auto old_tensor_type = As<TensorType>(old_param->GetType());
-    CHECK(old_tensor_type) << "LowerTransposeLoadParamLayout: promoted parameter at index " << idx
-                           << " must be TensorType";
+    const auto& param = func->params_[idx];
+    auto param_tensor_type = As<TensorType>(param->GetType());
+    CHECK(param_tensor_type) << "LowerTransposeLoadParamLayout: promoted parameter at index " << idx
+                             << " must be TensorType";
 
     // Reject the (DN view + explicit physical stride) combination — these
     // came from `tensor.transpose` and would compose with the load-side
     // transpose to produce a double-encoded transpose.
-    if (old_tensor_type->tensor_view_.has_value()) {
-      const auto& view = old_tensor_type->tensor_view_.value();
+    if (param_tensor_type->tensor_view_.has_value()) {
+      const auto& view = param_tensor_type->tensor_view_.value();
       CHECK(!(view.layout == TensorLayout::DN && !view.stride.empty()))
           << "LowerTransposeLoadParamLayout: tile.load(transpose=True) on a "
             "tensor.transpose result is not supported (the DN tag and explicit "
            "physical strides would compose as a double transpose). Drop one of "
            "the two transpose layers in the source program.";
-      // Param already promoted in a prior round (idempotent): skip.
+      // Param already DN-tagged at the boundary (user-written
+      // ``pl.Tensor[..., pl.DN]``): the load-side ``transpose=True`` is the
+      // user-intended signal that the on-chip tile flips back to row-major
+      // Mat orientation. ``DeduceTileLoadType`` already handles this via
+      // the (source_is_dn XOR transpose) tile-view logic — adding a bridge
+      // and dropping ``transpose=True`` would shift the XOR result and
+      // produce the wrong TileType. Skip this param.
       if (view.layout == TensorLayout::DN) continue;
     }
 
-    auto new_tensor_type = PromoteToCanonicalDN(old_tensor_type);
-    auto new_var = std::make_shared<Var>(old_param->name_hint_, new_tensor_type, old_param->span_);
-    new_params[idx] = new_var;
-    substitutions[old_param.get()] = new_var;
-    promoted_params.emplace(idx, new_var);
+    // Build ``b_dn = tensor.as_layout(b, layout=DN)``. Routing through the
+    // OpRegistry::Create path makes ``DeduceTensorAsLayoutType`` compute
+    // the post-flip type and inherit ``b``'s MemRef.
+    std::vector<std::pair<std::string, std::any>> kwargs = {{"layout", std::any(TensorLayout::DN)}};
+    auto bridge_call = OpRegistry::GetInstance().Create("tensor.as_layout", {param}, kwargs, param->span_);
+    auto bridge_var =
+        std::make_shared<Var>(param->name_hint_ + "_dn_view", bridge_call->GetType(), param->span_);
+    prepend.push_back(std::make_shared<AssignStmt>(bridge_var, bridge_call, param->span_));
+    substitutions[param.get()] = bridge_var;
+    dn_view_vars.insert(bridge_var.get());
   }
 
-  if (substitutions.empty()) {
-    return {func, {}};
+  if (prepend.empty()) {
+    return func;
   }
 
-  // 1) Substitute param Vars in the body.
+  // Substitute body uses of each promoted param ``b`` with the body-local
+  // ``b_dn``. ``Substitute`` walks the entire body — the prepend stmts are
+  // built using the original ``param`` Vars *before* substitution, so they
+  // are not affected.
   auto subbed_body = Substitute(func->body_, substitutions);
 
-  // 2) Rewrite each `tile.load(promoted_param, ..., transpose=True)` in the
-  //    body — swap offsets / shapes / valid_shapes trailing pair, drop the
-  //    transpose kwarg.
-  TileLoadBodyRewriter body_rewriter(substitutions);
-  auto new_body = body_rewriter.VisitStmt(subbed_body);
+  // Rewrite each ``tile.load(b_dn, ..., transpose=True)`` to canonical
+  // (DN-coord) form: swap offsets/shapes/valid_shapes trailing pair, flip
+  // ``transpose=True`` to ``transpose=False``.
+  TileLoadBodyRewriter body_rewriter(dn_view_vars);
+  auto rewritten_body = body_rewriter.VisitStmt(subbed_body);
+
+  // Concatenate: new body = SeqStmts([prepend stmts..., rewritten original body]).
+  std::vector<StmtPtr> new_body_stmts = std::move(prepend);
+  new_body_stmts.push_back(rewritten_body);
+  auto new_body = SeqStmts::Flatten(std::move(new_body_stmts), func->body_->span_);
 
   auto new_func = MutableCopy(func);
-  new_func->params_ = new_params;
   new_func->body_ = new_body;
-  return {new_func, promoted_params};
+  return new_func;
 }
 
-/// Walks every non-InCore function in the program and, for each call site
-/// targeting a promoted InCore callee, emits an SSA-form binding for each
-/// promoted-slot arg:
-///
-///   bridged_<arg> = tensor.as_layout(<arg>, DN)
-///   <result> = <callee>(..., bridged_<arg>, ...)
-///
-/// The binding is emitted as a separate ``AssignStmt`` immediately before the
-/// call statement (instead of being inlined inside the call's args), which is
-/// what downstream orchestration codegen expects — it consumes a ``Var`` or a
-/// constant literal per call arg, not a nested ``Call``.
-class CallSiteAsLayoutInjector : public IRMutator {
- public:
-  explicit CallSiteAsLayoutInjector(const std::map<std::string, std::map<size_t, VarPtr>>& promotions)
-      : promotions_(promotions) {}
-
-  StmtPtr VisitStmt_(const SeqStmtsPtr& op) override {
-    std::vector<StmtPtr> new_stmts;
-    new_stmts.reserve(op->stmts_.size());
-    bool any_changed = false;
-    for (const auto& stmt : op->stmts_) {
-      // Recurse into nested SeqStmts / control-flow first so inner call sites
-      // get patched too.
-      auto recursed = IRMutator::VisitStmt(stmt);
-      bool inserted = false;
-      auto patched = MaybeInjectBindings(recursed, new_stmts, &inserted);
-      if (inserted || patched.get() != recursed.get() || recursed.get() != stmt.get()) {
-        any_changed = true;
-      }
-      new_stmts.push_back(patched);
-    }
-    if (!any_changed) return op;
-    return SeqStmts::Flatten(std::move(new_stmts), op->span_);
-  }
-
-  // Bare (non-SeqStmts) statement bodies — e.g. ``then_body`` of an ``IfStmt``
-  // that contains a single ``AssignStmt``. Wrap any injected bindings into
-  // a fresh SeqStmts so the resulting body stays a single Stmt.
-  StmtPtr VisitStmt_(const AssignStmtPtr& op) override {
-    auto recursed = IRMutator::VisitStmt_(op);
-    std::vector<StmtPtr> pre;
-    bool inserted = false;
-    auto patched = MaybeInjectBindings(recursed, pre, &inserted);
-    if (!inserted) return patched;
-    pre.push_back(patched);
-    return SeqStmts::Flatten(std::move(pre), op->span_);
-  }
-
-  StmtPtr VisitStmt_(const EvalStmtPtr& op) override {
-    auto recursed = IRMutator::VisitStmt_(op);
-    std::vector<StmtPtr> pre;
-    bool inserted = false;
-    auto patched = MaybeInjectBindings(recursed, pre, &inserted);
-    if (!inserted) return patched;
-    pre.push_back(patched);
-    return SeqStmts::Flatten(std::move(pre), op->span_);
-  }
-
-  StmtPtr VisitStmt_(const ReturnStmtPtr& op) override {
-    auto recursed = IRMutator::VisitStmt_(op);
-    std::vector<StmtPtr> pre;
-    bool inserted = false;
-    auto patched = MaybeInjectBindings(recursed, pre, &inserted);
-    if (!inserted) return patched;
-    pre.push_back(patched);
-    return SeqStmts::Flatten(std::move(pre), op->span_);
-  }
-
- private:
-  /// If ``stmt``'s RHS is a Call to a promoted callee, build the binding
-  /// AssignStmts (one per promoted slot) and emit them into ``pre``;
-  /// rewrite the Call to reference the bound Vars. Returns the (possibly
-  /// rewritten) statement and sets ``*inserted = true`` if any bindings
-  /// were added.
-  StmtPtr MaybeInjectBindings(const StmtPtr& stmt, std::vector<StmtPtr>& pre, bool* inserted) {
-    auto extract_call = [](const StmtPtr& s) -> std::pair<CallPtr, VarPtr> {
-      if (auto assign = As<AssignStmt>(s)) {
-        return {As<Call>(assign->value_), assign->var_};
-      }
-      if (auto eval = As<EvalStmt>(s)) {
-        return {As<Call>(eval->expr_), nullptr};
-      }
-      if (auto ret = As<ReturnStmt>(s)) {
-        if (ret->value_.size() == 1) {
-          return {As<Call>(ret->value_[0]), nullptr};
-        }
-      }
-      return {nullptr, nullptr};
-    };
-
-    auto [call, lhs_var] = extract_call(stmt);
-    if (!call) return stmt;
-    auto gv = As<GlobalVar>(call->op_);
-    if (!gv) return stmt;
-    auto it = promotions_.find(gv->name_);
-    if (it == promotions_.end() || it->second.empty()) return stmt;
-    const auto& slots = it->second;
-
-    std::vector<ExprPtr> new_args = call->args_;
-    bool changed = false;
-    for (const auto& [idx, new_param_var] : slots) {
-      INTERNAL_CHECK_SPAN(idx < new_args.size(), call->span_)
-          << "LowerTransposeLoadParamLayout: promoted param index " << idx << " out of range for call to "
-          << gv->name_;
-      auto arg = new_args[idx];
-      auto arg_tensor = As<TensorType>(arg->GetType());
-      if (!arg_tensor) continue;
-      // Idempotency: an arg already in DN form needs no bridge.
-      if (arg_tensor->tensor_view_.has_value() && arg_tensor->tensor_view_->layout == TensorLayout::DN) {
-        continue;
-      }
-      // Build the bridge: bridged = tensor.as_layout(arg, DN).
-      std::vector<std::pair<std::string, std::any>> kwargs = {{"layout", std::any(TensorLayout::DN)}};
-      auto bridge_call = OpRegistry::GetInstance().Create("tensor.as_layout", {arg}, kwargs, arg->span_);
-      auto bridge_var =
-          std::make_shared<Var>(new_param_var->name_hint_ + "_dn_view", bridge_call->GetType(), arg->span_);
-      pre.push_back(std::make_shared<AssignStmt>(bridge_var, bridge_call, arg->span_));
-      new_args[idx] = bridge_var;
-      changed = true;
-    }
-    if (!changed) return stmt;
-    *inserted = true;
-
-    auto new_call = std::make_shared<Call>(call->op_, std::move(new_args), call->kwargs_, call->attrs_,
-                                           call->GetType(), call->span_);
-    if (auto assign = As<AssignStmt>(stmt)) {
-      return std::make_shared<AssignStmt>(assign->var_, new_call, assign->span_);
-    }
-    if (auto eval = As<EvalStmt>(stmt)) {
-      return std::make_shared<EvalStmt>(new_call, eval->span_);
-    }
-    if (auto ret = As<ReturnStmt>(stmt)) {
-      return std::make_shared<ReturnStmt>(std::vector<ExprPtr>{new_call}, ret->span_);
-    }
-    return stmt;  // unreachable — extract_call only returns non-null for the three above
-  }
-
-  const std::map<std::string, std::map<size_t, VarPtr>>& promotions_;
-};
-
 }  // namespace
 
 namespace pass {
 
 Pass LowerTransposeLoadParamLayout() {
   auto pass_func = [](const ProgramPtr& program) -> ProgramPtr {
-    // Phase 1: rewrite InCore functions and collect promotion info keyed by
-    // callee name (callers reference InCore functions through Call->op_'s
-    // GlobalVar, which is matched on name_).
+    // Rewrite each InCore function: prepend ``b_dn = tensor.as_layout(b, DN)``
+    // for every ``transpose=True``-loaded param ``b`` and substitute body uses
+    // accordingly. Non-InCore functions (orch callers) are left untouched —
+    // they pass their original ND args straight through; the layout
+    // reinterpret is now owned by the InCore body it serves.
     std::map<GlobalVarPtr, FunctionPtr> new_functions;
-    std::map<std::string, std::map<size_t, VarPtr>> promotions_by_callee_name;
     bool modified = false;
 
     for (const auto& [gvar, func] : program->functions_) {
@@ -418,32 +283,9 @@ Pass LowerTransposeLoadParamLayout() {
         new_functions[gvar] = func;
         continue;
       }
-      auto result = PromoteInCoreFunction(func);
-      new_functions[gvar] = result.func;
-      if (result.func.get() != func.get()) modified = true;
-      if (!result.promoted_params.empty()) {
-        promotions_by_callee_name[gvar->name_] = std::move(result.promoted_params);
-      }
-    }
-
-    if (promotions_by_callee_name.empty()) {
-      return modified ? std::make_shared<Program>(std::move(new_functions), program->name_, program->span_)
-                      : program;
-    }
-
-    // Phase 2: walk non-InCore functions and inject `tensor.as_layout` at
-    // each call site that targets a promoted callee.
-    CallSiteAsLayoutInjector injector(promotions_by_callee_name);
-    for (auto& [gvar, func] : new_functions) {
-      if (IsInCoreType(func->func_type_)) continue;
-      if (!func->body_) continue;
-      auto new_body = injector.VisitStmt(func->body_);
-      if (new_body.get() != func->body_.get()) {
-        auto new_func = MutableCopy(func);
-        new_func->body_ = new_body;
-        new_functions[gvar] = new_func;
-        modified = true;
-      }
+      auto new_func = LowerInCoreFunction(func);
+      new_functions[gvar] = new_func;
+      if (new_func.get() != func.get()) modified = true;
     }
 
     if (!modified) return program;
diff --git a/tests/ut/ir/transforms/test_lower_transpose_load_param_layout_pass.py b/tests/ut/ir/transforms/test_lower_transpose_load_param_layout_pass.py
index 470ef6abb..3c140f575 100644
--- a/tests/ut/ir/transforms/test_lower_transpose_load_param_layout_pass.py
+++ b/tests/ut/ir/transforms/test_lower_transpose_load_param_layout_pass.py
@@ -9,12 +9,15 @@
 """Unit tests for LowerTransposeLoadParamLayout pass (RFC #1300 P6).
-The pass promotes each InCore parameter loaded via ``tile.load(transpose=True)``
-to canonical-form DN (RFC §3.3 + §4.2): the trailing shape pair is swapped,
-the layout tag becomes DN, the body's ``tile.load`` call swaps its
-``offsets`` / ``shapes`` / ``valid_shapes`` trailing pair and drops the
-``transpose=True`` kwarg, and every non-InCore call site bridges its arg
-through ``tensor.as_layout(arg, DN)``.
+The pass leaves InCore parameter signatures untouched and instead prepends a
+``b_dn = tensor.as_layout(b, layout=DN)`` AssignStmt at the top of the InCore
+body for each param ``b`` loaded via ``tile.load(transpose=True)``. Body uses
+of ``b`` are substituted with ``b_dn`` (which has the canonical
+``[..., b_dim, a_dim] DN`` view per RFC §3.3 + §4.2), and the matching
+``tile.load`` calls have their ``offsets`` / ``shapes`` / ``valid_shapes``
+trailing pair swapped while ``transpose=True`` is flipped to
+``transpose=False``. Non-InCore (orch) call sites are not touched — they pass
+their original ND args straight through to the kernel.
 
 ``tensor.as_layout`` is internal-only and not exposed via ``pypto.language``,
 so we cannot write the post-pass IR as ``@pl.program``. Instead we drive the
@@ -74,14 +77,6 @@ def _find_calls_to(func, callee_name):
     return calls
 
 
-def _find_assign_rhs(func, var):
-    """Return the RHS expression of the ``AssignStmt`` that defines ``var``."""
-    for stmt in _iter_stmts(func.body):
-        if isinstance(stmt, ir.AssignStmt) and stmt.var is var:
-            return stmt.value
-    raise AssertionError(f"no AssignStmt defines var {var.name_hint}")
-
-
 def _shape_dims(ty):
     """Return ConstInt shape dims as ints (rejects symbolic dims for test fixtures)."""
     tensor_type = _as_tensor_type(ty)
@@ -97,6 +92,49 @@ def _transpose_kwarg(call):
     return call.kwargs.get("transpose")
 
 
+def _find_as_layout_binding(func, input_var):
+    """Find the body-prepended ``b_dn = tensor.as_layout(input_var, ...)``
+    AssignStmt and return ``(lhs_var, rhs_call)``. Asserts that exactly one
+    such binding exists in the body.
+ """ + matches = [] + for stmt in _iter_stmts(func.body): + if not isinstance(stmt, ir.AssignStmt): + continue + rhs = stmt.value + if not isinstance(rhs, ir.Call) or rhs.op is None: + continue + if rhs.op.name != "tensor.as_layout": + continue + if not rhs.args or not isinstance(rhs.args[0], ir.Var): + continue + if rhs.args[0] is input_var: + matches.append((stmt.var, rhs)) + assert len(matches) == 1, ( + f"expected exactly one tensor.as_layout binding for {input_var.name_hint}, found {len(matches)}" + ) + return matches[0] + + +def _has_as_layout_for(func, input_var): + """Return True iff ``func.body`` contains a tensor.as_layout binding whose + first arg is ``input_var``.""" + for stmt in _iter_stmts(func.body): + if not isinstance(stmt, ir.AssignStmt): + continue + rhs = stmt.value + if ( + isinstance(rhs, ir.Call) + and rhs.op is not None + and rhs.op.name == "tensor.as_layout" + and rhs.args + and isinstance(rhs.args[0], ir.Var) + and rhs.args[0] is input_var + ): + return True + return False + + class TestBTransposePromotesParam: """``C = A @ B^T`` with B loaded via ``transpose=True`` — param promoted to DN.""" @@ -131,47 +169,61 @@ def orchestrator( After = passes.lower_transpose_load_param_layout()(Before) incore = _find_function(After, "matmul_incore") - b_type = _as_tensor_type(incore.params[1].type) - assert _shape_dims(b_type) == [K, N], f"b param shape: {_shape_dims(b_type)}" - assert b_type.tensor_view is not None - assert b_type.tensor_view.layout == ir.TensorLayout.DN + + # Param signatures are untouched — ``b`` is still ``[N, K] ND``. + b_param = incore.params[1] + b_type = _as_tensor_type(b_param.type) + assert _shape_dims(b_type) == [N, K], f"b param shape: {_shape_dims(b_type)}" + assert b_type.tensor_view is None a_type = _as_tensor_type(incore.params[0].type) assert _shape_dims(a_type) == [M, K] assert a_type.tensor_view is None + # The body has a prepended ``b_dn = tensor.as_layout(b, DN)`` binding + # whose result carries the canonical ``[K, N] DN`` view. + b_dn_var, b_dn_call = _find_as_layout_binding(incore, b_param) + b_dn_type = _as_tensor_type(b_dn_var.type) + assert _shape_dims(b_dn_type) == [K, N], f"b_dn shape: {_shape_dims(b_dn_type)}" + assert b_dn_type.tensor_view is not None + assert b_dn_type.tensor_view.layout == ir.TensorLayout.DN + # ``layout`` kwarg on the as_layout call should be DN. + assert b_dn_call.kwargs.get("layout") == ir.TensorLayout.DN + + # ``tile.load`` on the promoted slot now reads from ``b_dn``, has its + # trailing pair swapped, and carries ``transpose=False``. loads_by_src = {} for ld in _find_tile_loads(incore): assert isinstance(ld.args[0], ir.Var) loads_by_src[ld.args[0].name_hint] = ld - load_b = loads_by_src["b"] + # The tile.load(b, transpose=True) was rewritten to load from b_dn. + assert b_dn_var.name_hint in loads_by_src, ( + f"tile.load must read from the as_layout LHS, not raw param. 
" + f"loaded srcs: {list(loads_by_src.keys())}" + ) + load_b = loads_by_src[b_dn_var.name_hint] shapes_arg = load_b.args[2] assert isinstance(shapes_arg, ir.MakeTuple) shape_vals = [el.value for el in shapes_arg.elements if isinstance(el, ir.ConstInt)] - assert shape_vals == [K, N], f"tile.load(b) shapes: {shape_vals}" - assert _transpose_kwarg(load_b) is False, "tile.load(b) transpose kwarg must be False after P6" + assert shape_vals == [K, N], f"tile.load(b_dn) shapes: {shape_vals}" + assert _transpose_kwarg(load_b) is False, "tile.load(b_dn) transpose kwarg must be False after P6" load_a = loads_by_src["a"] shape_vals_a = [el.value for el in load_a.args[2].elements if isinstance(el, ir.ConstInt)] assert shape_vals_a == [M, K] + # Orch is untouched — its call site passes ``b`` (a wrapper param) + # directly without any tensor.as_layout bridge. orch = _find_function(After, "orchestrator") calls = _find_calls_to(orch, "matmul_incore") assert len(calls) == 1 - # `b` is bridged via an SSA AssignStmt: the call arg is a Var bound to - # a separately-emitted ``tensor.as_layout(orig_b, DN)`` Call. b_arg = calls[0].args[1] assert isinstance(b_arg, ir.Var) - b_def_rhs = _find_assign_rhs(orch, b_arg) - assert isinstance(b_def_rhs, ir.Call) and b_def_rhs.op is not None - assert b_def_rhs.op.name == "tensor.as_layout", ( - f"orch must wrap b in tensor.as_layout, got {b_def_rhs.op.name if b_def_rhs.op else None}" + assert b_arg is orch.params[1], "orch call arg must be the orch's own param (no bridge)" + assert not _has_as_layout_for(orch, orch.params[1]), ( + "no tensor.as_layout bridge should be emitted in orch under the InCore-side design" ) - bridged_t = _as_tensor_type(b_def_rhs.type) - assert _shape_dims(bridged_t) == [K, N] - assert bridged_t.tensor_view is not None - assert bridged_t.tensor_view.layout == ir.TensorLayout.DN def test_btranspose_non_square(self): M, K, N = 128, 64, 32 @@ -203,10 +255,17 @@ def orchestrator( After = passes.lower_transpose_load_param_layout()(Before) incore = _find_function(After, "matmul_incore") - b_type = _as_tensor_type(incore.params[1].type) - assert _shape_dims(b_type) == [K, N] - assert b_type.tensor_view is not None - assert b_type.tensor_view.layout == ir.TensorLayout.DN + b_param = incore.params[1] + b_type = _as_tensor_type(b_param.type) + # Param is untouched. + assert _shape_dims(b_type) == [N, K] + assert b_type.tensor_view is None + # Body prepends b_dn = as_layout(b, DN); LHS carries [K, N] DN. + b_dn_var, _ = _find_as_layout_binding(incore, b_param) + b_dn_type = _as_tensor_type(b_dn_var.type) + assert _shape_dims(b_dn_type) == [K, N] + assert b_dn_type.tensor_view is not None + assert b_dn_type.tensor_view.layout == ir.TensorLayout.DN class TestATransposePromotesParam: @@ -242,34 +301,40 @@ def orchestrator( After = passes.lower_transpose_load_param_layout()(Before) incore = _find_function(After, "matmul_incore") - a_t = _as_tensor_type(incore.params[0].type) - assert _shape_dims(a_t) == [M, K] - assert a_t.tensor_view is not None - assert a_t.tensor_view.layout == ir.TensorLayout.DN - - b_t = _as_tensor_type(incore.params[1].type) + a_param = incore.params[0] + a_t = _as_tensor_type(a_param.type) + # Param is untouched. + assert _shape_dims(a_t) == [K, M] + assert a_t.tensor_view is None + # Body prepends a_dn = as_layout(a, DN); LHS carries [M, K] DN. 
+        a_dn_var, _ = _find_as_layout_binding(incore, a_param)
+        a_dn_type = _as_tensor_type(a_dn_var.type)
+        assert _shape_dims(a_dn_type) == [M, K]
+        assert a_dn_type.tensor_view is not None
+        assert a_dn_type.tensor_view.layout == ir.TensorLayout.DN
+
+        # ``b`` is not promoted (no transpose=True load), so no binding for it.
+        b_param = incore.params[1]
+        b_t = _as_tensor_type(b_param.type)
         assert _shape_dims(b_t) == [K, N]
         assert b_t.tensor_view is None
+        assert not _has_as_layout_for(incore, b_param)
 
+        # ``tile.load`` reads from a_dn (the binding's LHS), not from raw ``a``.
         loads = {ld.args[0].name_hint: ld for ld in _find_tile_loads(incore)}
-        load_a = loads["a"]
+        assert a_dn_var.name_hint in loads, f"expected load from a_dn, got srcs: {list(loads)}"
+        load_a = loads[a_dn_var.name_hint]
         shape_vals = [el.value for el in load_a.args[2].elements if isinstance(el, ir.ConstInt)]
         assert shape_vals == [M, K]
         assert _transpose_kwarg(load_a) is False
 
+        # Orch is untouched — call args are direct refs to wrapper params,
+        # no tensor.as_layout bridges injected.
         orch = _find_function(After, "orchestrator")
         call = _find_calls_to(orch, "matmul_incore")[0]
-        # `a` is bridged via tensor.as_layout. After P6's SSA refactor (PR
-        # review fix), the bridge is bound to a fresh Var by a preceding
-        # AssignStmt, so the call arg is a Var, not the inline Call. Look up
-        # the binding's RHS.
-        a_arg = call.args[0]
-        assert isinstance(a_arg, ir.Var)
-        a_def_rhs = _find_assign_rhs(orch, a_arg)
-        assert isinstance(a_def_rhs, ir.Call) and a_def_rhs.op is not None
-        assert a_def_rhs.op.name == "tensor.as_layout"
-        # `b` is not promoted, so its arg is the raw Var (no bridge).
-        assert isinstance(call.args[1], ir.Var)
+        assert call.args[0] is orch.params[0]
+        assert call.args[1] is orch.params[1]
+        assert not _has_as_layout_for(orch, orch.params[0])
 
 
 class TestABTransposePromotesBothParams:
@@ -305,22 +370,31 @@ def orchestrator(
 
         After = passes.lower_transpose_load_param_layout()(Before)
         incore = _find_function(After, "matmul_incore")
-        a_t = _as_tensor_type(incore.params[0].type)
-        b_t = _as_tensor_type(incore.params[1].type)
-        assert _shape_dims(a_t) == [M, K]
-        assert a_t.tensor_view is not None and a_t.tensor_view.layout == ir.TensorLayout.DN
-        assert _shape_dims(b_t) == [K, N]
-        assert b_t.tensor_view is not None and b_t.tensor_view.layout == ir.TensorLayout.DN
-
+        a_param = incore.params[0]
+        b_param = incore.params[1]
+        # Both params are untouched.
+        a_t = _as_tensor_type(a_param.type)
+        b_t = _as_tensor_type(b_param.type)
+        assert _shape_dims(a_t) == [K, M]
+        assert a_t.tensor_view is None
+        assert _shape_dims(b_t) == [N, K]
+        assert b_t.tensor_view is None
+        # Body prepends one as_layout binding per promoted param.
+        a_dn_var, _ = _find_as_layout_binding(incore, a_param)
+        b_dn_var, _ = _find_as_layout_binding(incore, b_param)
+        a_dn_t = _as_tensor_type(a_dn_var.type)
+        b_dn_t = _as_tensor_type(b_dn_var.type)
+        assert _shape_dims(a_dn_t) == [M, K]
+        assert a_dn_t.tensor_view is not None and a_dn_t.tensor_view.layout == ir.TensorLayout.DN
+        assert _shape_dims(b_dn_t) == [K, N]
+        assert b_dn_t.tensor_view is not None and b_dn_t.tensor_view.layout == ir.TensorLayout.DN
+
+        # Orch is untouched — no bridges injected.
         orch = _find_function(After, "orchestrator")
         call = _find_calls_to(orch, "matmul_incore")[0]
-        # Both promoted args are bridged via SSA AssignStmts.
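+        # Per-slot check: each arg must be the orch param itself, with no
+        # as_layout binding targeting it anywhere in the orch body.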
         for slot in (0, 1):
-            arg = call.args[slot]
-            assert isinstance(arg, ir.Var)
-            rhs = _find_assign_rhs(orch, arg)
-            assert isinstance(rhs, ir.Call) and rhs.op is not None
-            assert rhs.op.name == "tensor.as_layout"
+            assert call.args[slot] is orch.params[slot]
+            assert not _has_as_layout_for(orch, orch.params[slot])
 
 
 class TestNoOpCases:
@@ -467,13 +541,25 @@ def orchestrator(
 
         After = passes.lower_transpose_load_param_layout()(Before)
         incore = _find_function(After, "kernel")
-        kc_t = _as_tensor_type(incore.params[1].type)
+        kc_param = incore.params[1]
+        kc_t = _as_tensor_type(kc_param.type)
+        # Param is untouched.
         assert _shape_dims(kc_t) == [128, 128]
-        assert kc_t.tensor_view is not None
-        assert kc_t.tensor_view.layout == ir.TensorLayout.DN
-
+        assert kc_t.tensor_view is None
+        # Body prepends key_cache_dn = as_layout(key_cache, DN). Shape stays
+        # [128, 128] (square) but layout is DN — the trailing-pair swap on the
+        # full TensorType shape happens to be identity here.
+        kc_dn_var, _ = _find_as_layout_binding(incore, kc_param)
+        kc_dn_t = _as_tensor_type(kc_dn_var.type)
+        assert _shape_dims(kc_dn_t) == [128, 128]
+        assert kc_dn_t.tensor_view is not None
+        assert kc_dn_t.tensor_view.layout == ir.TensorLayout.DN
+
+        # tile.load reads from the binding's LHS, with swapped load window
+        # ([64, 128] -> [128, 64]) and transpose=False.
         loads = {ld.args[0].name_hint: ld for ld in _find_tile_loads(incore)}
-        load_k = loads["key_cache"]
+        assert kc_dn_var.name_hint in loads
+        load_k = loads[kc_dn_var.name_hint]
         shape_vals = [el.value for el in load_k.args[2].elements if isinstance(el, ir.ConstInt)]
         assert shape_vals == [128, 64]
         assert _transpose_kwarg(load_k) is False
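+        # Expected post-pass kernel body, sketched (Var names illustrative):
+        #     key_cache_dn = tensor.as_layout(key_cache, layout=DN)  # [128, 128] DN
+        #     tile_k = tile.load(key_cache_dn, [0, 0], [128, 64], ...,
+        #                        transpose=False)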