4 changes: 2 additions & 2 deletions examples/flash_attn_npuir.py
@@ -207,7 +207,7 @@ def FlashAttnExp(

T.vmul(cross_kernel_f32_N, acc_c_scale, cross_kernel_f32_N)
T.reduce(
- cross_kernel_f32_N, scores_max, dims=[1], reduce_mode="max"
+ cross_kernel_f32_N, scores_max, dims=[1], reduce_mode="max", size=[real_m, tail_size_n]
Contributor

high

Adding the size parameter is logically correct: it fixes the tail bug by limiting the reduction to valid data. However, this call does not match the library definition in tilelang/language/reduce.py. The current T.reduce implementation expects (buffer, out, reduce_type, dim, clear) and does not accept the keyword arguments dims, reduce_mode, or size. In addition, the C++ implementation in src/op/reduce.cc (specifically the ReduceOp constructor and the Lower function) has no logic for handling a dynamic size argument. Please update the library infrastructure to support this extended API and to apply the size bounds during lowering.
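To make the tail bug concrete, here is a minimal NumPy sketch of the effect being described. It is an analogy only and does not use the TileLang API; the tile shape, the stale value of 100.0, and the random scores are all assumptions chosen for illustration.

```python
import numpy as np

# Hypothetical tile geometry: a 4x8 buffer in which only the first 3 rows
# and 5 columns hold valid scores for the tail block.
block_m, block_n = 4, 8
real_m, tail_size_n = 3, 5

rng = np.random.default_rng(0)
# Fill the whole buffer with a stale value, then write the valid region.
tile = np.full((block_m, block_n), 100.0, dtype=np.float32)
tile[:real_m, :tail_size_n] = rng.uniform(-1.0, 1.0, size=(real_m, tail_size_n))

# Reducing over the full tile picks up the stale padding columns.
unbounded_max = tile.max(axis=1)

# Restricting the reduction to the valid region ignores the padding.
bounded_max = tile[:real_m, :tail_size_n].max(axis=1)
```

In this sketch every entry of `unbounded_max` is the stale value, while `bounded_max` reflects only the real scores, which is the behavior the `size` argument is meant to restore.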

Collaborator

We previously removed size because slice syntax was more precise; why is size being reinstated here?
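For comparison, the two styles can be sketched in NumPy as follows. `reduce_with_size` is a hypothetical stand-in written for this example, not a TileLang function, and the buffer contents are made up.

```python
import numpy as np

def reduce_with_size(buf, axis, mode, size):
    """Hypothetical size-bounded reduce: reduces over buf[0:size[0], 0:size[1], ...]."""
    region = buf[tuple(slice(0, s) for s in size)]
    reducer = {"max": np.max, "sum": np.sum}[mode]
    return reducer(region, axis=axis)

buf = np.arange(32, dtype=np.float32).reshape(4, 8)
real_m, tail_size_n = 3, 5

# Size-style call: the valid region is implied to start at the origin.
via_size = reduce_with_size(buf, axis=1, mode="sum", size=[real_m, tail_size_n])

# Slice-style call: the valid region is spelled out explicitly.
via_slice = buf[:real_m, :tail_size_n].sum(axis=1)

# A slice can also express a region that does not start at the origin,
# which a size list alone cannot describe.
offset_region_sum = buf[1:3, 2:5].sum(axis=1)
```

Under these assumptions both forms give the same result for an origin-anchored region. One possible reading of "more precise" is that a slice also carries the start offset, as the last line shows.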

)
if i != 0:
T.vmax(scores_max_prev, scores_max, scores_max)
@@ -235,7 +235,7 @@ def FlashAttnExp(
T.sync_block_set(i)

T.reduce(
- cross_kernel_f32_N, scores_sum, dims=[1], reduce_mode="sum"
+ cross_kernel_f32_N, scores_sum, dims=[1], reduce_mode="sum", size=[real_m, tail_size_n]
Contributor

high

As with the change at line 210, this T.reduce call uses an API that the provided reduce.py and reduce.cc files do not currently support. Passing size=[real_m, tail_size_n] does correctly keep stale data in the padding area out of the sum reduction, but the underlying operator must be updated to accept and implement these bounds.

)
T.vmul(logsum, scores_scale, logsum)
T.vadd(logsum, scores_sum, logsum)
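The logsum update in this hunk follows the usual online-softmax recurrence, so an unbounded row sum on the tail block corrupts the final denominator. A rough NumPy sketch of that recurrence is below; it assumes a standard online softmax and is not taken from the kernel itself. The function name, block width, score values, and stale value are all hypothetical.

```python
import numpy as np

def online_logsum(blocks, valid_cols):
    """Accumulate softmax denominators block by block (hypothetical helper)."""
    rows = blocks[0].shape[0]
    running_max = np.full(rows, -np.inf)
    logsum = np.zeros(rows)
    for block, n in zip(blocks, valid_cols):
        valid = block[:, :n]                      # bound the reduction to n columns
        new_max = np.maximum(running_max, valid.max(axis=1))
        scores_scale = np.exp(running_max - new_max)
        scores_sum = np.exp(valid - new_max[:, None]).sum(axis=1)
        logsum = logsum * scores_scale + scores_sum   # mirrors vmul then vadd
        running_max = new_max
    return logsum

# Six valid score columns split into blocks of width 4: one full, one tail.
scores = np.array([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
                   [1.0, 0.0, 1.0, 0.0, 1.0, 0.0]])
full_block = scores[:, :4]
tail_block = np.full((2, 4), 5.0)   # stale data in the padding columns
tail_block[:, :2] = scores[:, 4:]   # only 2 columns of the tail are valid

reference = np.exp(scores - scores.max(axis=1, keepdims=True)).sum(axis=1)
bounded = online_logsum([full_block, tail_block], valid_cols=[4, 2])
unbounded = online_logsum([full_block, tail_block], valid_cols=[4, 4])
```

With the tail reduction bounded to its two valid columns, the accumulated value matches the directly computed denominator. Reducing over all four columns folds the stale entries into both the running maximum and the sum.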