Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ description: TileLang Ascend API 使用最佳实践。提供内存分配、数
| **Softmax/LayerNorm** | [api-compute](references/api-compute.md) | T.reduce_max/sum、T.tile.exp/sub/div |
| **逐元素计算** | [api-compute](references/api-compute.md) | T.Parallel + 符号 API 或 T.tile.xxx 两种范式 |
| **多 block/core 累加到 GM** | [api-compute](references/api-compute.md) | T.tile.atomic_add(dst_gm, src_local),调用前显式清零 GM |
| **CV 融合算子** | [api-kernel-memory](references/api-kernel-memory.md), [api-schedule-sync](references/api-schedule-sync.md) | workspace 索引一致性、AUTO_CV_COMBINE、vid 并行化 |
| **流水线优化** | [api-schedule-sync](references/api-schedule-sync.md) | T.Pipelined num_stages、核间/核内流水线 |
| **多核负载均衡** | [api-schedule-sync](references/api-schedule-sync.md) | T.Persistent 缓存友好调度 |
| **排序** | [api-compute](references/api-compute.md) | T.tile.sort → T.tile.merge_sort → T.tile.topk |
Expand Down Expand Up @@ -90,4 +91,4 @@ description: TileLang Ascend API 使用最佳实践。提供内存分配、数
| `TL_ASCEND_AUTO_SYNC: True` | 自动同步插入 |
| `TL_ASCEND_MEMORY_PLANNING: True` | 自动内存规划 |
| `TL_ASCEND_AUTO_CV_COMBINE: True` | 自动 CV 分离(核间流水线) |
| `tl.ascend_auto_cross_core_sync: True` | 自动核间同步(核间流水线) |
| `TL_ASCEND_AUTO_CV_SYNC: True` | 自动核间同步(核间流水线) |
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,34 @@ T.copy(C_L0, C[bx * block_M, by * block_N])
T.gemm_v0(q_l1, k_l1, acc_s_l0c, transpose_B=True, init=True)
```

**⚠️ 重要:矩阵 Buffer 分形限制**

使用 `T.gemm_v0` 时,矩阵 Buffer 必须满足最小分形限制。分形大小固定为 512 Byte(L0A/L0B)或 256 元素(L0C),shape 与 dtype 相关:

**分形公式**:
- **L0A**:`16 × (32B / sizeof(AType))`,固定 512 Byte
- **L0B**:`(32B / sizeof(BType)) × 16`,固定 512 Byte
- **L0C**:`16 × 16`,固定 256 元素(不随 dtype 变化)

**不同 dtype 的最小维度限制**:

| dtype | sizeof | L0A 分形 | L0B 分形 | 最小限制 |
|-------|--------|----------|----------|---------|
| int8 / uint8 | 1 Byte | 16 × 32 | 32 × 16 | M ≥ 16, K ≥ 32, N ≥ 16 |
| float16 / bfloat16 | 2 Byte | 16 × 16 | 16 × 16 | M ≥ 16, K ≥ 16, N ≥ 16 |
| int32 / float32 | 4 Byte | 16 × 8 | 8 × 16 | M ≥ 16, K ≥ 8, N ≥ 16 |

**L0C 分形固定为 16 × 16,不随 dtype 变化**,因此 M 和 N 的最小值始终为 16。

**常见错误**:`block_N = 8` 不满足 L0C 分形限制(N ≥ 16),会导致计算结果错误。

**示例**:int8 GEMM 的正确 block size 选择
```python
block_M = 64 # ≥ 16 ✓
block_N = 16 # ≥ 16 ✓(满足 L0B/L0C 分形限制)
block_K = 256 # ≥ 32 ✓(int8 的 L0A/L0B K 维度限制)
```

### T.mma(A, B, C, init=False)

NPU 级别的矩阵乘累加指令,比 `gemm_v0` 更底层。不支持 `transpose_A`/`transpose_B`。通常配合 `T.alloc_L0A`/`T.alloc_L0B` 和 `T.annotate_layout` 使用。
Expand Down Expand Up @@ -161,9 +189,76 @@ for i in range(block_M // VEC_NUM): # 行顺序
c_ub[i, j] = a_ub[i, j] * b_ub[i, j]
```

### 3.1 T.Parallel 在 TileLang-Ascend 上的限制

> **核心原理**:`T.Parallel` 在 TileLang-Ascend 上会被编译器 lowering 为 `T.tile.xxx` Buffer 级 SIMD 指令。因此,T.Parallel 的能力边界受限于 AscendC Vector 指令的能力。

#### 支持的循环维度

- ✅ **1D 并行**:`for j in T.Parallel(N)`
- ✅ **2D 并行**:`for i, j in T.Parallel(M, N)`
- ✅ **serial + parallel 组合**:`for i in range(M): for j in T.Parallel(N)`
- ❌ **3D 或更高维并行**:不支持,会触发编译错误
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Support the 3D double buffer scenario. Actually, the computation is still performed only on 2D tiles; the highest dimension merely indicates which stream to operate on. Such as https://github.com/tile-ai/tilelang-ascend/blob/ascendc_pto/examples/elementwise/elementwise_add_pipeline.py#L71

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume the "3D 或更高维并行" in the document means for i, j, k ... in T.Parallel(M, N, K, ...).
It seems that the use case in elementwise_add_pipeline.py#L71 still fits the "2D 并行" case.
Perhaps the description is OK here I think 🤔


#### 支持的表达式类型

`T.Parallel` 内的表达式会被自动分解并翻译为 Vector 指令。**仅支持以下模式**:

| 类型 | 支持的表达式 | 备注 |
|------|-------------|------|
| 简单赋值 | `a[i] = b[i]` | 等价于 `T.copy` |
| 简单运算 | `c[i] = a[i] + b[i]` | 等价于 `T.tile.add` |
| 标量运算 | `c[i] = a[i] + scalar` | 等价于 `T.tile.add` |
| 广播运算 | `c[i,j] = a[i,j] * b[j]` | 自动广播处理(仅支持 1D→2D,索引必须是简单变量) |
| 复合表达式 | `c[i] = a[i] * b[i] + d[i]` | 自动分解为多步操作 |
| 离散索引 | 非简单变量索引,如 `a[idx[i]]` | 编译器退回到 `T.serial` 循环 |

#### 不支持的表达式

以下表达式**无法在 T.Parallel 中使用**,需要改用其他方案:

| 不支持的表达式 | 错误类型 | 替代方案 |
|---------------|---------|---------|
| `if-else` 条件分支 | 编译错误(SIMD 架构不支持元素级条件判断) | 使用 `T.tile.compare` + `T.tile.select` |
| `T.if_then_else(...)` | 编译错误 ("undefined Variable v_thread") | 使用 `T.tile.compare` + `T.tile.select` |
| `tir.reinterpret("int8", ...)` | 运行时错误 | 使用 `T.reinterpretcast`(整个 buffer) |
| `T.int8(expr)` 或 `.astype("int8")` | 编译错误或数据异常 | 使用 `T.tile.cast`(整个 buffer) |
| 非线性索引 `a[i*i]` | 未实现 | 使用 `T.tile.xxx` + 手动索引计算 |
| 动态 shift `a[i] >> shift[i]` | 不支持(shift 必须是 scalar) | 使用固定 scalar shift |

#### 循环范围要求

`T.Parallel` 的循环范围必须是编译期可确定的常量值(IntImm),不支持动态变量作为循环边界。

#### 从 CUDA TileLang 迁移注意事项

TileLang-Ascend 的 T.Parallel 语法与 CUDA 版本对齐,但底层执行模型不同:

- **CUDA (SIMT)**:每个元素独立执行,支持复杂控制流
- **Ascend (SIMD)**:所有元素并行执行相同指令,不支持条件分支

CUDA 代码中的以下模式在 Ascend 上需要改写:

```python
# CUDA 版本(SIMT,逐元素条件判断)
for i in T.Parallel(N):
if a[i] > threshold: # ❌ Ascend 不支持
b[i] = a[i] * scale
else:
b[i] = a[i]

# Ascend 版本(SIMD,用 compare + select 替代)
T.tile.compare(mask_ub, a_ub, threshold, "GT")
T.tile.select(b_ub, mask_ub, a_scaled_ub, a_ub, "VSEL_CMPMASK_SPR")
```

详细用法参考 `docs/tutorials/t_parallel.md`。

Pass 设计详见 `.agents/skills/tilelang-pass-analyzer/references/pass-designs/ascend_lower_parallel_to_vector_design.md`。

---

## 4. Tile 扩展原语(Expert / 混合模式 T.tile.xxx)
## 4. Tile 扩展原语(T.tile.xxx Buffer 级 SIMD 操作

`T.tile.xxx` 系列接口直接触发 Tile 级的 Ascend 操作。它们既可用于全手动 Expert 模式,也可在 Developer pass_configs 下作为混合模式原语使用。

Expand Down Expand Up @@ -406,7 +501,7 @@ T.tile.topk(topk_global, sort_result, K, actual_num)
### 4.12 两种编程范式对比

```python
# 方式一:T.Parallel + 符号 API(推荐,跨平台兼容)
# 方式一:T.Parallel + 符号 API(Developer 模式,跨平台兼容)
for i, j in T.Parallel(block_M // VEC_NUM, block_N):
b_ub[i, j] = T.exp(a_ub[i, j])

Expand Down
Loading