71 changes: 63 additions & 8 deletions docs/en/user/01-language_guide.md
@@ -48,19 +48,37 @@ idx: pl.Scalar[pl.INDEX]  # index scalar

### Tensor Layouts

Tensors describe row-major memory; a tensor's layout is normally **derived** from the
ops that produce a view. The user-facing options:

| Pattern | Use when | Result |
| ------- | -------- | ------ |
| `pl.Tensor[[..], pl.FP32]` (no layout marker) | Default — your source tensor is plain row-major memory | ND |
| `pl.transpose(t, -2, -1)` | Need a transposed view at use site (e.g. matmul B^T pattern) | DN |
| `pl.slice(view, ...)` / `pl.reshape(view, ...)` | Sub-view that should inherit a parent's layout | Same family as parent |

```python
# ✅ Recommended — write source tensor shape, derive DN at use site:
b: pl.Tensor[[N, K], pl.FP32]
b_t = pl.transpose(b, -2, -1) # ND → DN view, same physical buffer
tile_b = pl.load(b_t, [0, 0], [K, N], target_memory=pl.MemorySpace.Mat)
```

```python
# ⚠️ Deprecated (RFC #1300 supplementary 1):
b: pl.Tensor[[K, N], pl.FP32, pl.DN] # → DeprecationWarning at parse time
```

> **Why `pl.Tensor[..., pl.DN]` is deprecated.** Writing the DN
> layout-only shorthand forces you to mentally hold two coordinate systems
> at once (the IR-logical post-view shape and the runtime row-major shape).
> Compose `pl.transpose` instead — the source tensor always uses its
> runtime shape, and the DN view appears explicitly in the program.

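The "same physical buffer" note above can be made concrete with a plain-Python sketch of strided views (illustration only; `flat_offset` and the stride tuples below are not part of the `pl` API):

```python
# Illustration only: a transposed "view" sharing one row-major buffer.
# Plain-Python stand-in for the ND -> DN relationship; not the pl API.

def flat_offset(indices, strides):
    """Map logical indices to a flat buffer offset via strides."""
    return sum(i * s for i, s in zip(indices, strides))

rows, cols = 2, 3
buf = list(range(rows * cols))  # one physical row-major buffer
nd_strides = (cols, 1)          # ND: row-major strides for a [2, 3] tensor
dn_strides = (1, cols)          # transposed view: just swap the strides

# Element [i][j] of the DN view is element [j][i] of the ND source,
# and both index the same buffer: no data is copied.
assert buf[flat_offset((2, 1), dn_strides)] == buf[flat_offset((1, 2), nd_strides)]
```

Swapping the stride tuple is the same kind of pure metadata change the text attributes to `pl.transpose`: the bytes never move.
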
For NZ (hardware-specific tile layout), use `pl.Tile[..., pl.NZ]` — NZ is
tile-only, never a TensorType annotation. The `pl.NZ` constant remains
available for tile annotations and IR-internal use.
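
Tiled layouts like NZ rearrange a matrix tile-by-tile rather than row-by-row. The exact NZ tiling parameters are hardware-specific; the sketch below uses made-up 2×2 tiles purely to illustrate the index arithmetic of a generic tiled layout (it is not the real NZ formula):

```python
# Generic tiled-layout offset arithmetic (illustration; not the NZ spec).
T0, T1 = 2, 2  # hypothetical tile height/width; real NZ sizes differ

def tiled_offset(i, j, n_cols):
    """Flat offset when a matrix is stored tile-by-tile: row-major across
    tiles, and row-major again within each tile."""
    tiles_per_row = n_cols // T1
    tile_idx = (i // T0) * tiles_per_row + (j // T1)  # which tile
    within = (i % T0) * T1 + (j % T1)                 # position inside it
    return tile_idx * (T0 * T1) + within

# In a 4x4 matrix, element (2, 3) falls in tile (1, 1) at local (0, 1):
assert tiled_offset(2, 3, n_cols=4) == 3 * (T0 * T1) + 1  # tile 3, slot 1
```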

### Dynamic Shapes

Use `pl.dynamic()` for dimensions determined at runtime:
@@ -515,6 +533,43 @@ c_acc = pl.matmul(a_l0a, b_l0b)  # result → Acc
out = pl.store(c_acc, [0, 0], output) # Acc → DDR
```

### Pattern: Matmul B^T via Transposed View

For `c = a @ b^T`, derive the transposed view via `pl.transpose` **before**
the load — do not pass `transpose=True` to `pl.load`:

```python
# ✅ Recommended (RFC #1300 supplementary 2):
a_l1 = pl.load(a, [0, 0], [M, K], target_memory=pl.MemorySpace.Mat)
b_t = pl.transpose(b, -2, -1)  # b is [N, K] ND; b_t is [K, N] DN view
b_l1 = pl.load(b_t, [0, 0], [K, N], target_memory=pl.MemorySpace.Mat)
a_l0a = pl.move(a_l1, target_memory=pl.MemorySpace.Left)
b_l0b = pl.move(b_l1, target_memory=pl.MemorySpace.Right)
c_acc = pl.matmul(a_l0a, b_l0b)
```

```python
# ⚠️ Deprecated (kept for back-compat; emits DeprecationWarning):
b_l1 = pl.load(b, [0, 0], [N, K], target_memory=pl.MemorySpace.Mat, transpose=True)
```

Reusing a transposed view across multiple loads (e.g. K-tiled matmuls)
is more compact in the new form:

```python
b_t = pl.transpose(b, -2, -1) # one-time view
for ki in pl.range(0, K, K_TILE):
    b_i = pl.load(b_t, [ki, 0], [K_TILE, N], target_memory=pl.MemorySpace.Mat)
    ...
```

> **Why `transpose=True` is deprecated.** It mixes a view-reinterpret
> ("treat this tensor as transposed") into a memory-copy op, which breaks
> the orthogonality of `pl.slice` / `pl.reshape` / `pl.transpose`.
> `pl.transpose` is itself a pure metadata reinterpret (no allocation, no
> compute) and produces the same end-to-end semantics — `pl.load` then
> handles only memory movement.

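The equivalence claimed above can be sketched in plain Python (a toy model, not the pl implementation: `make_view` and `load_tile` are invented names standing in for view metadata and the copy step):

```python
# Toy model: a "view" is metadata (shape + index mapping) over a shared
# buffer; "load" only copies memory and never knows about transposition.

def make_view(data, n_rows, n_cols, transposed=False):
    if transposed:  # swap the shape and flip the index mapping; zero copy
        return {"shape": (n_cols, n_rows), "at": lambda i, j: data[j * n_cols + i]}
    return {"shape": (n_rows, n_cols), "at": lambda i, j: data[i * n_cols + j]}

def load_tile(view, offsets, shape):
    """Pure memory movement: copy a rectangular region out of a view."""
    r0, c0 = offsets
    return [[view["at"](r0 + i, c0 + j) for j in range(shape[1])]
            for i in range(shape[0])]

N, K = 2, 3
b = list(range(N * K))                     # b is [N, K], row-major
b_t = make_view(b, N, K, transposed=True)  # [K, N] view of the same buffer

tile = load_tile(b_t, (0, 0), (K, N))      # transpose-view + plain load
# ...equals eagerly materialising b's transpose and loading that:
eager_bt = [[b[i * K + j] for i in range(N)] for j in range(K)]
assert tile == eager_bt
```

Splitting the two concerns this way is what keeps view ops composable: each op either reinterprets metadata or moves bytes, never both.
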
## Compilation

### `ir.compile()`
57 changes: 49 additions & 8 deletions docs/zh-cn/user/01-language_guide.md
@@ -48,19 +48,30 @@ idx: pl.Scalar[pl.INDEX]  # index scalar

### Tensor Layouts

Tensors describe row-major memory; a tensor's layout is normally **derived** from the
ops that produce a view. The user-facing options:

| Pattern | Use when | Result |
| ------- | -------- | ------ |
| `pl.Tensor[[..], pl.FP32]` (no layout marker) | Default: the source tensor is plain row-major memory | ND |
| `pl.transpose(t, -2, -1)` | Derive a transposed view at the use site (e.g. matmul B^T) | DN |
| `pl.slice(view, ...)` / `pl.reshape(view, ...)` | Sub-view that should inherit the parent layout | Same layout family as parent |

```python
# ✅ Recommended: write the source tensor shape and derive DN at the use site:
b: pl.Tensor[[N, K], pl.FP32]
b_t = pl.transpose(b, -2, -1)  # ND → DN view, same physical buffer
tile_b = pl.load(b_t, [0, 0], [K, N], target_memory=pl.MemorySpace.Mat)
```

```python
# ⚠️ Deprecated (RFC #1300 supplementary 1):
b: pl.Tensor[[K, N], pl.FP32, pl.DN]  # → DeprecationWarning at parse time
```

> **Why `pl.Tensor[..., pl.DN]` is deprecated.** The layout-only shorthand forces you to hold two coordinate systems in mind at once (the IR-logical post-view shape and the runtime row-major shape), which is exactly the ambiguity RFC #1300 aims to eliminate. Compose `pl.transpose` instead: the source tensor always uses its runtime shape, and the DN view appears explicitly in the program.

For NZ (the hardware tile layout), write `pl.Tile[..., pl.NZ]`; NZ is tile-only and not allowed as a TensorType annotation. The `pl.NZ` constant remains available for tile annotations and IR-internal use.

### Dynamic Shapes

Use `pl.dynamic()` to declare dimensions determined at runtime:
@@ -515,6 +526,36 @@ c_acc = pl.matmul(a_l0a, b_l0b)  # result → Acc
out = pl.store(c_acc, [0, 0], output) # Acc → DDR
```

### Pattern: Matmul B^T via Transposed View (instead of `transpose=True`)

For `c = a @ b^T`, derive the transposed view with `pl.transpose` **before** the load; do not pass `transpose=True` to `pl.load`:

```python
# ✅ Recommended (RFC #1300 supplementary 2):
a_l1 = pl.load(a, [0, 0], [M, K], target_memory=pl.MemorySpace.Mat)
b_t = pl.transpose(b, -2, -1)  # b is [N, K] ND; b_t is [K, N] DN view
b_l1 = pl.load(b_t, [0, 0], [K, N], target_memory=pl.MemorySpace.Mat)
a_l0a = pl.move(a_l1, target_memory=pl.MemorySpace.Left)
b_l0b = pl.move(b_l1, target_memory=pl.MemorySpace.Right)
c_acc = pl.matmul(a_l0a, b_l0b)
```

```python
# ⚠️ Deprecated (kept for back-compat; emits DeprecationWarning):
b_l1 = pl.load(b, [0, 0], [N, K], target_memory=pl.MemorySpace.Mat, transpose=True)
```

A transposed view can be reused across multiple loads (a common pattern in K-tiled matmuls); the new form is more compact:

```python
b_t = pl.transpose(b, -2, -1)  # one-time view
for ki in pl.range(0, K, K_TILE):
    b_i = pl.load(b_t, [ki, 0], [K_TILE, N], target_memory=pl.MemorySpace.Mat)
    ...
```

> **Why `transpose=True` is deprecated.** It mixes a view reinterpret ("treat this tensor as transposed") into a memory-copy op, which breaks the orthogonality of `pl.slice` / `pl.reshape` / `pl.transpose`. `pl.transpose` itself is a pure metadata reinterpret (no allocation, no compute) and produces the same end-to-end semantics; after the split, `pl.load` handles only memory movement.

## Compilation

### `ir.compile()`
7 changes: 5 additions & 2 deletions examples/models/04_paged_attention.py
@@ -65,9 +65,12 @@ def kernel_qk_matmul(
    kj: pl.Tensor[[128, 128], pl.BF16],
    output: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
) -> pl.Tensor[[16, 128], pl.FP32]:
    """QK matmul: sij = qi @ kj.T (CUBE). kj is reinterpreted as a DN view via
    ``pl.transpose`` before load (RFC #1300 supplementary 2 — replaces the
    deprecated ``pl.load(..., transpose=True)`` shorthand)."""
    qi_l1 = pl.load(qi, [0, 0], [16, 128], target_memory=pl.MemorySpace.Mat)
    kj_t = pl.transpose(kj, -2, -1)  # [128, 128] ND → [128, 128] DN view
    kj_l1 = pl.load(kj_t, [0, 0], [128, 128], target_memory=pl.MemorySpace.Mat)
    qi_l0a = pl.move(qi_l1, target_memory=pl.MemorySpace.Left)
    kj_l0b = pl.move(kj_l1, target_memory=pl.MemorySpace.Right)
    sij_l0c = pl.matmul(qi_l0a, kj_l0b)
9 changes: 5 additions & 4 deletions examples/models/06_paged_attention_dynamic.py
@@ -96,11 +96,12 @@ def dyn_kernel_qk_matmul(
    kj: pl.Tensor[[BLOCK_SIZE_DYN, HEAD_DIM_DYN], pl.BF16],
    output: pl.Out[pl.Tensor[[Q_HEADS, BLOCK_SIZE_DYN], pl.FP32]],
) -> pl.Tensor[[Q_HEADS, BLOCK_SIZE_DYN], pl.FP32]:
    """QK matmul: output = qi @ kj.T (CUBE). kj is reinterpreted as DN via
    ``pl.transpose`` before load (RFC #1300 supplementary 2 — replaces the
    deprecated ``pl.load(..., transpose=True)`` shorthand)."""
    qi_l1 = pl.load(qi, [0, 0], [_Q_TILE, _HEAD_DIM], target_memory=pl.MemorySpace.Mat)
    kj_t = pl.transpose(kj, -2, -1)  # ND → DN view; offsets/shapes flip too
    kj_l1 = pl.load(kj_t, [0, 0], [_HEAD_DIM, _BLOCK_SIZE], target_memory=pl.MemorySpace.Mat)
    qi_l0a = pl.move(qi_l1, target_memory=pl.MemorySpace.Left)
    kj_l0b = pl.move(kj_l1, target_memory=pl.MemorySpace.Right)
    sij_l0c = pl.matmul(qi_l0a, kj_l0b)
34 changes: 30 additions & 4 deletions python/pypto/language/op/tile_ops.py
@@ -327,16 +327,42 @@ def load(
        transpose: Whether to transpose the tile during load (default: False).
            Only supported when target_memory is MemorySpace.Mat (L1).

            .. deprecated:: RFC #1300 §3.3
                ``transpose=True`` mixes a view-reinterpret into a memory-copy
                op and breaks the orthogonality of ``pl.slice`` / ``pl.reshape``
                / ``pl.transpose``. Compose ``pl.transpose(tensor, -2, -1)``
                followed by ``pl.load(..., target_memory=pl.MemorySpace.Mat)``
                on the DN view instead — ``pl.transpose`` is a pure metadata
                reinterpret and the resulting ``tile.load`` carries the same
                semantics without the kwarg.

    Returns:
        Tile wrapping the load operation

    Example:
        >>> # 2D load
        >>> tile = load(tensor, offsets=[0, 0], shapes=[32, 32])
        >>> # Migrating away from transpose=True (B^T-style load to L1):
        >>> # ❌ deprecated:
        >>> # tile = load(tensor, offsets=[0, 0], shapes=[N, K],
        >>> #             target_memory=pl.MemorySpace.Mat, transpose=True)
        >>> # ✅ new pattern:
        >>> tensor_t = transpose(tensor, -2, -1)  # ND → DN view
        >>> tile = load(tensor_t, offsets=[0, 0], shapes=[K, N],
        ...             target_memory=pl.MemorySpace.Mat)
    """
    if transpose:
        warnings.warn(
            "pl.load(..., transpose=True) is deprecated (RFC #1300 supplementary 2). "
            "Mixing a view-reinterpret into a memory-copy op breaks the orthogonality "
            "of pl.slice / pl.reshape / pl.transpose. Migrate to "
            "`pl.transpose(tensor, -2, -1)` followed by `pl.load(...)` on the resulting "
            "DN view (no `transpose=True` kwarg). The new form is equivalent end-to-end: "
            "pl.transpose is a pure metadata reinterpret, and the DN-tagged tile.load "
            "produces the same TileType per RFC §4.2 canonical pair.",
            DeprecationWarning,
            stacklevel=2,
        )
    if valid_shapes is None:
        valid_shapes = shapes
    call_expr = _ir_ops.load(
32 changes: 32 additions & 0 deletions python/pypto/language/parser/type_resolver.py
@@ -10,6 +10,7 @@
"""Type annotation resolution for IR parsing."""

import ast
import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast

@@ -441,6 +442,7 @@ def _resolve_subscript_type(self, subscript_node: ast.Subscript) -> ir.Type: #
            tensor_view = self._resolve_tensorview(third)
            return tensor_ctor(shape, dtype, None, tensor_view)
        layout = self.resolve_layout(third)
        self._warn_on_user_facing_dn_layout(layout, type_name)
        tensor_view = ir.TensorView([], layout)
        return tensor_ctor(shape, dtype, None, tensor_view)

@@ -450,6 +452,7 @@ def _resolve_subscript_type(self, subscript_node: ast.Subscript) -> ir.Type: #
            tensor_view = self._resolve_tensorview(third)
        else:
            layout = self.resolve_layout(third)
            self._warn_on_user_facing_dn_layout(layout, type_name)
            tensor_view = ir.TensorView([], layout)
        memref_node = slice_value.elts[3]
        if not self._is_memref_node(memref_node):
@@ -986,6 +989,35 @@ def resolve_dtype(self, dtype_node: ast.expr) -> DataType:
            hint="Use pl.FP32, pl.INT32, or other supported dtype constants",
        )

    def _warn_on_user_facing_dn_layout(self, layout: "ir.TensorLayout", type_name: str) -> None:
        """Emit a ``DeprecationWarning`` when the user writes the layout-only DN
        shorthand on a tensor type annotation (RFC #1300 supplementary 1).

        Suppressed for ``ir.TensorLayout.ND`` (default, no-op marker) and for
        explicit ``pl.TensorView(stride=..., layout=DN)`` forms (which carry
        their own stride and don't rely on the shorthand's implicit coordinate
        flip). Tile-side layouts are never seen here — Tile annotations route
        through ``_resolve_tile_annotation_args``.
        """
        if layout != ir.TensorLayout.DN:
            return
        warnings.warn(
            f"pl.{type_name}[..., pl.DN] is deprecated (RFC #1300 supplementary 1). "
            "Writing the DN layout-only shorthand requires the user to mentally hold "
            "two coordinate systems at once (IR-logical post-view vs. runtime "
            "row-major), which is exactly the ambiguity RFC #1300 aims to eliminate. "
            "Three migration patterns cover every DN scenario without writing pl.DN:\n"
            " * source tensor shape, no layout marker: pl.Tensor[[N, K], pl.FP32]\n"
            " * derive DN at use site: xt = pl.transpose(x, -2, -1)  # ND -> DN\n"
            " * inherit DN through slice/reshape from a DN-producing op\n"
            "If you must express a strided-DN view (e.g. canonical pretty-print "
            "round-trip), use pl.TensorView(stride=[...], layout=pl.TensorLayout.DN) "
            "instead — it forces explicit stride and avoids the implicit-coord-flip "
            "hazard.",
            DeprecationWarning,
            stacklevel=4,
        )

    def resolve_layout(self, layout_node: ast.expr) -> "ir.TensorLayout":
        """Resolve layout annotation to ir.TensorLayout.
