[Fix] Fix PTO row-reduce temporary buffer size and type mismatch for TROWSUM #1027
Code Review
This pull request introduces logic to calculate temporary buffer column sizes for row reduction operations, specifically updating CodegenRowReduce and the TmpBufferInjector transformation. It also fixes a type-checking bug in the Python gather interface and applies formatting updates. Feedback highlights a high-severity issue where a hardcoded "0" offset in tmp_cast could cause incorrect memory access, and a potential crash in GetRowReduceTmpCol due to unsupported 4-bit integer types in GetTypeLen.
    ShapeInfo tmp_cast =
        ShapeInfo{src.slice_valid_row, tmp_col, src.slice_valid_row, tmp_col,
                  src.slice_valid_row, tmp_col, tmp.extent, tmp.first_addr,
                  "0", src.type, tmp.ub_name, false};
The offset for the temporary buffer view tmp_cast is hardcoded to "0". While injected scratchpad buffers currently start at offset 0, this hardcoding is fragile and inconsistent with the implementation in CodegenColReduce (line 2414), which correctly uses tmp.offset. If the temporary buffer is ever a slice or part of a larger allocation, this will lead to incorrect memory access.
Suggested change:

      ShapeInfo tmp_cast =
          ShapeInfo{src.slice_valid_row, tmp_col, src.slice_valid_row, tmp_col,
                    src.slice_valid_row, tmp_col, tmp.extent, tmp.first_addr,
    -               "0", src.type, tmp.ub_name, false};
    +               tmp.offset, src.type, tmp.ub_name, false};
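The risk described in this comment can be illustrated with a minimal sketch. `TmpView`, `RetypeKeepingOffset`, `RetypeWithZeroOffset`, and `AddressExpr` are hypothetical, simplified stand-ins invented for illustration; the real ShapeInfo has more fields, and its exact layout is not shown in this thread.

```cpp
#include <string>

// Hypothetical, simplified stand-in for the backend's ShapeInfo.
struct TmpView {
  std::string first_addr;  // symbolic base address of the allocation
  std::string offset;      // symbolic offset of this view inside the allocation
  std::string type;        // element type used when the view is accessed
};

// Re-type a scratch view while keeping the offset of the original view.
TmpView RetypeKeepingOffset(const TmpView& tmp, const std::string& type) {
  return TmpView{tmp.first_addr, tmp.offset, type};
}

// Re-type a scratch view but reset the offset to "0", as the reviewed code does.
TmpView RetypeWithZeroOffset(const TmpView& tmp, const std::string& type) {
  return TmpView{tmp.first_addr, "0", type};
}

// Address expression that generated code would use for the view (assumed form).
std::string AddressExpr(const TmpView& view) {
  return view.first_addr + " + " + view.offset;
}
```

For a view that starts at offset 0 the two variants agree. For a view that is a slice of a larger allocation, only the first one still points at the slice.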
    int GetRowReduceTmpCol(int valid_col, const std::string &dtype) {
      constexpr int kVectorRepeatBytes = 256;
      int dtype_len = GetTypeLen(dtype);
      int elem_per_repeat = kVectorRepeatBytes / dtype_len;
      int tmp_col = valid_col <= elem_per_repeat
                        ? 1
                        : std::max(valid_col / 2, elem_per_repeat);
      return GetValidShape(tmp_col, dtype);
    }
The function GetRowReduceTmpCol uses GetTypeLen(dtype), which currently does not support int4b_t (it will trigger an ICHECK(false)). However, getType (line 67) does support 4-bit integers. While row reduction is primarily used for floating-point types, this inconsistency could lead to unexpected compiler crashes if 4-bit integer reduction is attempted.
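To make the column computation concrete, here is a standalone sketch of it. `TypeLenBytes` and `AlignColsTo32Bytes` are hypothetical stand-ins for GetTypeLen and GetValidShape, whose real implementations are not shown in this thread. The accepted dtype spellings and the rule "round columns up so that a row spans a multiple of 32 bytes" are assumptions. The stand-in throws a recoverable error for int4b_t instead of aborting, which is one possible way to handle the case raised above, not necessarily the project's choice.

```cpp
#include <algorithm>
#include <stdexcept>
#include <string>

// Stand-in for GetTypeLen: bytes per element (dtype spellings are assumed).
int TypeLenBytes(const std::string& dtype) {
  if (dtype == "float" || dtype == "int32_t") return 4;
  if (dtype == "half" || dtype == "bfloat16_t" || dtype == "int16_t") return 2;
  if (dtype == "int8_t" || dtype == "uint8_t") return 1;
  // Sub-byte types such as int4b_t have no whole-byte length here, so report
  // an error the caller can handle.
  throw std::invalid_argument("unsupported dtype for row reduce: " + dtype);
}

// Stand-in for GetValidShape: round cols up so that one row spans a multiple
// of 32 bytes (assumed behaviour).
int AlignColsTo32Bytes(int cols, int dtype_len) {
  constexpr int kAlignBytes = 32;
  const int elems_per_block = kAlignBytes / dtype_len;
  return (cols + elems_per_block - 1) / elems_per_block * elems_per_block;
}

// Same selection logic as GetRowReduceTmpCol in the reviewed diff.
int RowReduceTmpCol(int valid_col, const std::string& dtype) {
  constexpr int kVectorRepeatBytes = 256;
  const int dtype_len = TypeLenBytes(dtype);
  const int elem_per_repeat = kVectorRepeatBytes / dtype_len;
  const int tmp_col = valid_col <= elem_per_repeat
                          ? 1
                          : std::max(valid_col / 2, elem_per_repeat);
  return AlignColsTo32Bytes(tmp_col, dtype_len);
}
```

Under these assumptions, a float row with 256 valid columns yields 128 temporary columns, which agrees with the tmp_col value quoted in the PR description for head_size=256.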
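Problem
Kernels that use T.reduce_sum(dim=-1) on the PTO backend crash at runtime with an error.
Reproduced with the split_qkv_rmsnorm_mrope operator from xllm (head_size=256): the first compiled variant (q_heads=32, kv_heads=2) appears to pass, but later variants crash at npuSynchronizeDevice.
Root cause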
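The PTO backend allocates the temporary buffer for reduce operations as uint8 (raw bytes). The old code had two coupled bugs that prevented TROWSUM from using this buffer correctly:
Buffer too small (allocate_tmp_buffer.cc): TROWSUM and TCOLSUM shared the same formula, tmp_shape_size = valid_row × valid_col / 2 (counted in elements). Because the buffer type is uint8, the actual allocation is valid_row × valid_col / 2 bytes. TROWSUM, however, uses this scratch space with the source data type (for example float32, 4 bytes per element). The correct size is valid_row × tmp_col × 4 bytes, where tmp_col is derived from the NPU vector unit width (256 bytes per operation).
For head_size=256: valid_col / 2 = 128 and tmp_col = 128. Old allocation: 32 × 128 = 4096 bytes. Actually required: 32 × 128 × 4 = 16384 bytes (a 4× shortfall).
Buffer passed to the hardware with the wrong type (codegen_ascend_pto.cc): CodegenRowReduce passed the uint8 temporary buffer directly to the TROWSUM hardware instruction. The hardware accesses the buffer with a float32 stride (4 bytes), so its offsets run far past the bounds declared for the uint8 buffer, which triggers a CCU address check exception.
Changes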
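src/transform/allocate_tmp_buffer.cc
GetPtoRowReduceTmpCols(valid_col, dtype_bytes): computes the correct column count for the row-reduce temporary buffer from the NPU vector unit width (256 bytes per operation) and the 32-byte alignment requirement.
The new formula accounts for the source element size in bytes: tmp_shape_size = valid_row × tmp_col × dtype_bytes.
src/target/codegen_ascend_pto.cc
GetRowReduceTmpCol(valid_col, dtype): the corresponding codegen-layer logic, which computes the column count from the dtype string.
CodegenRowReduce now has ICHECK(dst.type == src.type), which catches dtype mismatches at compile time instead of silently generating incorrect code.
When src.type != tmp.type (that is, a uint8 buffer with float32 source data), CreateUbVariableND creates an ND view of the temporary buffer with the correct type and dimensions, so that TROWSUM accesses the buffer as float32 with the correct column count.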
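The sizing arithmetic from the root-cause analysis can be checked with a small sketch. `OldTmpBytes` and `NewTmpBytes` are hypothetical helper names introduced only for illustration; they restate the two formulas quoted in this description and are not the functions in allocate_tmp_buffer.cc. The column count tmp_col is taken as an input here rather than derived.

```cpp
#include <cstdint>

// Old sizing: the element count valid_row * valid_col / 2 was used directly
// as the length of a uint8 buffer, so it is also the size in bytes.
constexpr std::int64_t OldTmpBytes(std::int64_t valid_row,
                                   std::int64_t valid_col) {
  return valid_row * valid_col / 2;
}

// New sizing: scale the element count by the source element size in bytes.
constexpr std::int64_t NewTmpBytes(std::int64_t valid_row,
                                   std::int64_t tmp_col,
                                   std::int64_t dtype_bytes) {
  return valid_row * tmp_col * dtype_bytes;
}

// Numbers from the head_size=256 example with float32 source data.
static_assert(OldTmpBytes(32, 256) == 4096, "old allocation in bytes");
static_assert(NewTmpBytes(32, 128, 4) == 16384, "required allocation in bytes");
static_assert(NewTmpBytes(32, 128, 4) == 4 * OldTmpBytes(32, 256),
              "old allocation is 4x too small");
```

The checks are compile-time, so the block verifies the quoted figures as soon as it compiles.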
Verification
The split_qkv_rmsnorm_mrope kernel passes torch.testing.assert_close against the PyTorch reference implementation for every head configuration in the test suite (32×2, 24×4, 16×4, 16×2, and others).