[Fix] Fix PTO row-reduce temporary buffer size and type mismatch for TROWSUM#1027

Merged
LLMZhangYC merged 2 commits into tile-ai:ascendc_pto from ShareableXue:add_gather_pto_codegen on May 19, 2026

@ShareableXue ShareableXue commented May 15, 2026

Problem

Kernels that use T.reduce_sum(dim=-1) on the PTO backend crash at runtime with:

  VEC instruction error: the ub address out of bounds
  CCU instruction address check error

Reproduced with the split_qkv_rmsnorm_mrope operator from xllm (head_size=256): the first
compiled variant (q_heads=32, kv_heads=2) appears to pass, but later variants crash in
npuSynchronizeDevice.

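Root cause

The PTO backend allocates the temporary buffer for reduce operations as uint8 (raw bytes).
The old code had two coupled bugs that kept TROWSUM from using this buffer correctly:

  1. The buffer was too small (allocate_tmp_buffer.cc):
    TROWSUM shared one formula with TCOLSUM:
    tmp_shape_size = valid_row × valid_col / 2 (in elements).

    Because the buffer type is uint8, the actual size in bytes was
    valid_row × valid_col / 2.

    TROWSUM, however, uses this scratch space with the source element type
    (for example float32, 4 bytes per element). The correct size in bytes is
    valid_row × tmp_col × 4, where tmp_col is derived from the NPU vector unit
    width (256 bytes processed at a time).

    For head_size=256: valid_col / 2 = 128 and tmp_col = 128.
    Old allocation: 32 × 128 = 4096 bytes.
    Required: 32 × 128 × 4 = 16384 bytes (the old buffer was 4× too small).

  2. The buffer was passed to the hardware with the wrong type (codegen_ascend_pto.cc):
    CodegenRowReduce passed the uint8 temporary buffer directly to the TROWSUM
    hardware instruction. The hardware accesses the buffer with a float32 stride
    (4 bytes), so its offsets run far past the bounds declared for the uint8
    buffer and trigger the CCU address check error.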
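
The size mismatch in bug 1 can be reproduced with a few lines of arithmetic. This is a minimal sketch, not code from the PR: the helper names `OldTmpBytes` and `RequiredTmpBytes` are hypothetical, and only the two formulas and the head_size=256 numbers come from the description above.

```cpp
#include <cstdint>

// Hypothetical helpers for illustration; only the formulas are from the PR text.

// Old sizing: valid_row * valid_col / 2 elements of a uint8 buffer,
// so the element count is also the byte count.
constexpr std::int64_t OldTmpBytes(std::int64_t valid_row,
                                   std::int64_t valid_col) {
  return valid_row * valid_col / 2;
}

// Required sizing: TROWSUM reads the scratch space with the source
// element type, so the element size must be part of the byte count.
constexpr std::int64_t RequiredTmpBytes(std::int64_t valid_row,
                                        std::int64_t tmp_col,
                                        std::int64_t dtype_bytes) {
  return valid_row * tmp_col * dtype_bytes;
}
```

With valid_row=32, valid_col=256, tmp_col=128 and float32 (4 bytes), the first helper gives 4096 and the second 16384, matching the figures quoted above.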

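Changes

src/transform/allocate_tmp_buffer.cc

  • Added GetPtoRowReduceTmpCols(valid_col, dtype_bytes), which computes the correct
    column count for the row-reduce temporary buffer from the NPU vector unit width
    (256 bytes processed at a time) and the 32-byte alignment requirement.
  • Moved TROWSUM out of the TCOLSUM group and into the same group as TROWMAX/TROWMIN.
    The new formula accounts for the source element size in bytes:
    tmp_shape_size = valid_row × tmp_col × dtype_bytes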
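
The column computation described above could look roughly like the following. This is a sketch under assumptions, not the merged implementation: it mirrors the selection logic of GetRowReduceTmpCol quoted later in the review, and `AlignCols` is a hypothetical stand-in that assumes the 32-byte alignment means rounding each row up to a multiple of 32 bytes. The names `RowReduceTmpColsSketch` and `RowReduceTmpBytes` are likewise illustrative.

```cpp
#include <algorithm>
#include <cstdint>

constexpr int kVectorRepeatBytes = 256;  // bytes processed at a time (from the PR text)
constexpr int kAlignBytes = 32;          // alignment requirement (from the PR text)

// Assumption: round the column count up so one row is a multiple of 32 bytes.
constexpr int AlignCols(int cols, int dtype_bytes) {
  return (cols * dtype_bytes + kAlignBytes - 1) / kAlignBytes * kAlignBytes /
         dtype_bytes;
}

// Hypothetical sketch of the row-reduce temporary column count.
constexpr int RowReduceTmpColsSketch(int valid_col, int dtype_bytes) {
  const int elem_per_repeat = kVectorRepeatBytes / dtype_bytes;
  const int tmp_col = valid_col <= elem_per_repeat
                          ? 1
                          : std::max(valid_col / 2, elem_per_repeat);
  return AlignCols(tmp_col, dtype_bytes);
}

// New formula from the PR text: valid_row * tmp_col * dtype_bytes.
constexpr std::int64_t RowReduceTmpBytes(int valid_row, int valid_col,
                                         int dtype_bytes) {
  return static_cast<std::int64_t>(valid_row) *
         RowReduceTmpColsSketch(valid_col, dtype_bytes) * dtype_bytes;
}
```

Under these assumptions, valid_col=256 with float32 yields tmp_col=128 and a 16384-byte buffer for 32 rows, consistent with the root-cause analysis.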

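src/target/codegen_ascend_pto.cc

  • Added GetRowReduceTmpCol(valid_col, dtype), the codegen-side counterpart, which
    computes the column count from the dtype string.
  • Added ICHECK(dst.type == src.type) in CodegenRowReduce to catch dtype mismatches
    at compile time instead of silently generating wrong code.
  • When src.type != tmp.type (that is, a uint8 buffer versus float32 source data),
    an ND view with the correct type and dimensions is created over the temporary
    buffer via CreateUbVariableND, so that TROWSUM accesses the buffer as float32
    with the correct column count.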
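
The idea behind the typed view is that the view's footprint, measured in bytes, must fit inside the raw uint8 allocation. The sketch below illustrates that check only; `ByteBuffer`, `TypedView2D`, `ViewBytes` and `ViewFits` are hypothetical names and do not model CreateUbVariableND or ShapeInfo.

```cpp
#include <cstddef>

// Hypothetical stand-ins, not the backend's real types.
struct ByteBuffer {
  std::size_t size_bytes;  // extent of the raw uint8 allocation
};

struct TypedView2D {
  std::size_t rows;
  std::size_t cols;
  std::size_t elem_bytes;  // e.g. 4 for float32
};

// Bytes touched when the view is accessed with its element stride.
constexpr std::size_t ViewBytes(const TypedView2D &view) {
  return view.rows * view.cols * view.elem_bytes;
}

// A typed view is only safe if it stays inside the byte allocation.
constexpr bool ViewFits(const ByteBuffer &buf, const TypedView2D &view) {
  return ViewBytes(view) <= buf.size_bytes;
}
```

A 32 × 128 float32 view fits a 16384-byte buffer but not the 4096-byte buffer of the old allocation, which is why the type fix and the size fix have to land together.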

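Verification

  • The split_qkv_rmsnorm_mrope kernel passes torch.testing.assert_close against the
    PyTorch reference implementation for all head configurations in the test suite
    (32×2, 24×4, 16×4, 16×2, etc.).
  • The previously crashing cases now run normally, with no hardware exceptions.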

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces logic to calculate temporary buffer column sizes for row reduction operations, specifically updating CodegenRowReduce and the TmpBufferInjector transformation. It also fixes a type-checking bug in the Python gather interface and applies formatting updates. Feedback highlights a high-severity issue where a hardcoded "0" offset in tmp_cast could cause incorrect memory access, and a potential crash in GetRowReduceTmpCol due to unsupported 4-bit integer types in GetTypeLen.

Comment thread src/target/codegen_ascend_pto.cc Outdated
Comment on lines +2368 to +2371
ShapeInfo tmp_cast =
    ShapeInfo{src.slice_valid_row, tmp_col, src.slice_valid_row, tmp_col,
              src.slice_valid_row, tmp_col, tmp.extent, tmp.first_addr,
              "0", src.type, tmp.ub_name, false};

high

The offset for the temporary buffer view tmp_cast is hardcoded to "0". While injected scratchpad buffers currently start at offset 0, this hardcoding is fragile and inconsistent with the implementation in CodegenColReduce (line 2414), which correctly uses tmp.offset. If the temporary buffer is ever a slice or part of a larger allocation, this will lead to incorrect memory access.

Suggested change
- ShapeInfo tmp_cast =
-     ShapeInfo{src.slice_valid_row, tmp_col, src.slice_valid_row, tmp_col,
-               src.slice_valid_row, tmp_col, tmp.extent, tmp.first_addr,
-               "0", src.type, tmp.ub_name, false};
+ ShapeInfo tmp_cast =
+     ShapeInfo{src.slice_valid_row, tmp_col, src.slice_valid_row, tmp_col,
+               src.slice_valid_row, tmp_col, tmp.extent, tmp.first_addr,
+               tmp.offset, src.type, tmp.ub_name, false};

Comment on lines +165 to +173
int GetRowReduceTmpCol(int valid_col, const std::string &dtype) {
  constexpr int kVectorRepeatBytes = 256;
  int dtype_len = GetTypeLen(dtype);
  int elem_per_repeat = kVectorRepeatBytes / dtype_len;
  int tmp_col = valid_col <= elem_per_repeat
                    ? 1
                    : std::max(valid_col / 2, elem_per_repeat);
  return GetValidShape(tmp_col, dtype);
}

medium

The function GetRowReduceTmpCol uses GetTypeLen(dtype), which currently does not support int4b_t (it will trigger an ICHECK(false)). However, getType (line 67) does support 4-bit integers. While row reduction is primarily used for floating-point types, this inconsistency could lead to unexpected compiler crashes if 4-bit integer reduction is attempted.
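
One possible way to make such a helper tolerate sub-byte types is to work in bits rather than bytes, since a 4-bit element has no whole-byte length to divide by. The sketch below is an illustration of that idea only and is not part of this PR; `ElemsPerRepeatFromBits` is a hypothetical name, and how the backend would map dtype strings to bit widths is left open.

```cpp
// Hypothetical sketch: elements handled per 256-byte repeat, computed from
// the element width in bits so that 4-bit types are representable.
constexpr int kVectorRepeatBits = 256 * 8;

constexpr int ElemsPerRepeatFromBits(int dtype_bits) {
  return kVectorRepeatBits / dtype_bits;
}
```

For 32-bit elements this gives 64, the same value as 256 / 4 in the byte-based code, while a 4-bit element type gives 512 instead of hitting an unsupported-type check.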

@benyang0506 benyang0506 left a comment

/lgtm

@LLMZhangYC LLMZhangYC left a comment

approve

@LLMZhangYC LLMZhangYC merged commit 1e763f4 into tile-ai:ascendc_pto May 19, 2026
6 checks passed