[AscendNPU-IR][Expert] Fix FA reduce block_n_tail bug #851
Wamdus3 wants to merge 1 commit into tile-ai:npuir from
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Code Review
This pull request updates the T.reduce calls in examples/flash_attn_npuir.py to include a size parameter, intended to restrict the max and sum reductions to the valid (non-padding) tail data. However, the review feedback indicates that the current library infrastructure in tilelang (both the Python API and the C++ backend) does not support these keyword arguments or the size parameter, so the underlying implementation must be updated to accommodate these changes.
  T.vmul(cross_kernel_f32_N, acc_c_scale, cross_kernel_f32_N)
  T.reduce(
-     cross_kernel_f32_N, scores_max, dims=[1], reduce_mode="max"
+     cross_kernel_f32_N, scores_max, dims=[1], reduce_mode="max", size=[real_m, tail_size_n]
The addition of the size parameter is logically correct for fixing the tail bug by limiting the reduction to valid data. However, there is a significant discrepancy between this call and the library definition in tilelang/language/reduce.py. The current T.reduce implementation expects (buffer, out, reduce_type, dim, clear) and does not support keyword arguments like dims, reduce_mode, or size. Furthermore, the C++ implementation in src/op/reduce.cc (specifically the ReduceOp constructor and Lower function) lacks the logic to handle a dynamic size argument. Please ensure the library infrastructure is updated to support this extended API and to utilize the size bounds during lowering.
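To make the intent of the `size` bound concrete, here is a plain-NumPy sketch (not the TileLang implementation; `reduce_with_size` and its signature are hypothetical) showing why reducing over the full tile is wrong when only a `(real_m, tail_size_n)` region holds valid data: the padding area may contain stale values that corrupt a max reduction.

```python
import numpy as np

def reduce_with_size(buf, reduce_mode, dim, size):
    """Hypothetical sketch of a reduce op that honors per-dimension `size`
    bounds, so stale values outside the valid region are ignored."""
    valid = buf[tuple(slice(0, s) for s in size)]  # crop to the valid region
    if reduce_mode == "max":
        return valid.max(axis=dim)
    if reduce_mode == "sum":
        return valid.sum(axis=dim)
    raise ValueError(f"unsupported reduce_mode: {reduce_mode}")

# A (block_m, block_n) tile where only the first (real_m, tail_size_n)
# elements hold valid scores; the rest is uninitialized padding.
block_m, block_n = 4, 8
real_m, tail_size_n = 4, 5
tile = np.full((block_m, block_n), 1e30)  # simulate stale garbage in padding
tile[:real_m, :tail_size_n] = np.arange(real_m * tail_size_n).reshape(real_m, tail_size_n)

wrong = tile.max(axis=1)  # reduces over padding too: every row returns 1e30
right = reduce_with_size(tile, "max", 1, [real_m, tail_size_n])
assert (wrong >= 1e30).all() and (right < 1e30).all()
```

The same reasoning applies to the sum reduction below: without the bound, stale padding values are accumulated into `scores_sum`.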
  T.reduce(
-     cross_kernel_f32_N, scores_sum, dims=[1], reduce_mode="sum"
+     cross_kernel_f32_N, scores_sum, dims=[1], reduce_mode="sum", size=[real_m, tail_size_n]
Similar to the change at line 210, this T.reduce call uses an API that is not currently supported by the provided reduce.py and reduce.cc files. While passing size=[real_m, tail_size_n] correctly addresses the issue of stale data in the padding area during the sum reduction, the underlying operator must be updated to accept and implement these bounds.
  T.vmul(cross_kernel_f32_N, acc_c_scale, cross_kernel_f32_N)
  T.reduce(
-     cross_kernel_f32_N, scores_max, dims=[1], reduce_mode="max"
+     cross_kernel_f32_N, scores_max, dims=[1], reduce_mode="max", size=[real_m, tail_size_n]
We previously removed size because slice syntax was more precise; why is size being reinstated here?
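For the semantics the reviewer is contrasting, here is a hedged NumPy sketch (the `reduce_max` helper is hypothetical, not TileLang's API): a `size` bound applied inside the reduce op and a slice applied by the caller select the same valid region, so for this reduction they are computationally equivalent; the difference is purely which layer of the API expresses the bound.

```python
import numpy as np

def reduce_max(buf, dim, size=None):
    """Hypothetical reduce: with `size`, crop internally before reducing."""
    if size is not None:
        buf = buf[tuple(slice(0, s) for s in size)]
    return buf.max(axis=dim)

block_m, block_n = 4, 8
real_m, tail_size_n = 3, 5
tile = np.full((block_m, block_n), -1e30)  # padding holds stale values
rng = np.random.default_rng(0)
tile[:real_m, :tail_size_n] = rng.standard_normal((real_m, tail_size_n))

# Option A: the op crops via a size bound (as in this PR).
via_size = reduce_max(tile, 1, size=[real_m, tail_size_n])

# Option B: the caller crops via slice syntax before reducing.
via_slice = reduce_max(tile[:real_m, :tail_size_n], 1)

assert np.array_equal(via_size, via_slice)
```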
Fixed FA reduce block_n_tail bug