Skip to content

[bugfix]pto merge_sort undefined tmp buffer#988

Merged
fuhouyu-hw merged 1 commit intotile-ai:ascendc_ptofrom
pbbb205:ascendc_pto
May 9, 2026
Merged

[bugfix]pto merge_sort undefined tmp buffer#988
fuhouyu-hw merged 1 commit intotile-ai:ascendc_ptofrom
pbbb205:ascendc_pto

Conversation

@pbbb205
Copy link
Copy Markdown
Contributor

@pbbb205 pbbb205 commented May 8, 2026

No description provided.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 8, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the temporary buffer allocation logic in allocate_tmp_buffer.cc to specifically handle the ascend_merge_sort operation for the 'pto' target. The changes include logic for selecting appropriate buffers based on data type and calculating the required workspace size. Feedback identifies a potential null pointer dereference, an unused variable, and a correctness issue where the buffer size should be calculated using a maximum check to ensure compatibility with other operations.

Comment on lines +883 to +898
} else if (call->op.same_as(tl::ascend_merge_sort())) {
const CallNode *dst_access_ptr = Downcast<Call>(call->args[2]).get();
std::string dst_buffer_name =
dst_access_ptr->args[1].as<VarNode>()->name_hint;
const BufferNode *dst_buffer_node =
GetBufferNodeByName_(alloc_buffers, dst_buffer_name);
DataType dtype = dst_buffer_node->dtype;
if (dtype == DataType::UInt(8)) {
int64_t tmp_shape_size =
Downcast<IntImm>(dst_access_ptr->args[3])->value * 4;
Array<PrimExpr> tmp_shape;
shape = {
IntImm(DataType::Int(32), tmp_shape_size),
};
shape_size = tmp_shape_size;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for calculating the temporary buffer size for merge_sort in the pto target has a few issues:

  1. Correctness Issue: It always overwrites the shape_size and shape variables (line 897). If multiple operations require a temporary buffer (e.g., a reduce followed by a merge_sort), the allocated size might be incorrect if the merge_sort workspace is smaller than the previous operation's requirement. It should use a maximum check similar to other operations in this function (e.g., bitwise_xor at line 857).
  2. Safety Issue: GetBufferNodeByName_ can return nullptr. A null check should be added before accessing dst_buffer_node->dtype to prevent a potential crash.
  3. Style/Efficiency: The tmp_shape variable at line 893 is declared but never used.
Suggested change
} else if (call->op.same_as(tl::ascend_merge_sort())) {
const CallNode *dst_access_ptr = Downcast<Call>(call->args[2]).get();
std::string dst_buffer_name =
dst_access_ptr->args[1].as<VarNode>()->name_hint;
const BufferNode *dst_buffer_node =
GetBufferNodeByName_(alloc_buffers, dst_buffer_name);
DataType dtype = dst_buffer_node->dtype;
if (dtype == DataType::UInt(8)) {
int64_t tmp_shape_size =
Downcast<IntImm>(dst_access_ptr->args[3])->value * 4;
Array<PrimExpr> tmp_shape;
shape = {
IntImm(DataType::Int(32), tmp_shape_size),
};
shape_size = tmp_shape_size;
}
} else if (call->op.same_as(tl::ascend_merge_sort())) {
const CallNode *dst_access_ptr = Downcast<Call>(call->args[2]).get();
std::string dst_buffer_name =
dst_access_ptr->args[1].as<VarNode>()->name_hint;
const BufferNode *dst_buffer_node =
GetBufferNodeByName_(alloc_buffers, dst_buffer_name);
if (dst_buffer_node && dst_buffer_node->dtype == DataType::UInt(8)) {
int64_t tmp_shape_size =
Downcast<IntImm>(dst_access_ptr->args[3])->value * 4;
if (tmp_shape_size > shape_size) {
shape = {
IntImm(DataType::Int(32), tmp_shape_size),
};
shape_size = tmp_shape_size;
}
}

Copy link
Copy Markdown
Collaborator

@LLMZhangYC LLMZhangYC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@fuhouyu-hw
Copy link
Copy Markdown
Collaborator

/approve

@fuhouyu-hw fuhouyu-hw merged commit d8c7d29 into tile-ai:ascendc_pto May 9, 2026
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants