[bugfix]pto merge_sort undefined tmp buffer#988
fuhouyu-hw merged 1 commit into tile-ai:ascendc_pto
Code Review
This pull request updates the temporary buffer allocation logic in allocate_tmp_buffer.cc to specifically handle the ascend_merge_sort operation for the 'pto' target. The changes include logic for selecting appropriate buffers based on data type and calculating the required workspace size. Feedback identifies a potential null pointer dereference, an unused variable, and a correctness issue where the buffer size should be calculated using a maximum check to ensure compatibility with other operations.
    } else if (call->op.same_as(tl::ascend_merge_sort())) {
      const CallNode *dst_access_ptr = Downcast<Call>(call->args[2]).get();
      std::string dst_buffer_name =
          dst_access_ptr->args[1].as<VarNode>()->name_hint;
      const BufferNode *dst_buffer_node =
          GetBufferNodeByName_(alloc_buffers, dst_buffer_name);
      DataType dtype = dst_buffer_node->dtype;
      if (dtype == DataType::UInt(8)) {
        int64_t tmp_shape_size =
            Downcast<IntImm>(dst_access_ptr->args[3])->value * 4;
        Array<PrimExpr> tmp_shape;
        shape = {
            IntImm(DataType::Int(32), tmp_shape_size),
        };
        shape_size = tmp_shape_size;
      }
The logic for calculating the temporary buffer size for `merge_sort` in the pto target has a few issues:
- Correctness Issue: It always overwrites the `shape_size` and `shape` variables (line 897). If multiple operations require a temporary buffer (e.g., a `reduce` followed by a `merge_sort`), the allocated size might be incorrect if the `merge_sort` workspace is smaller than the previous operation's requirement. It should use a maximum check similar to other operations in this function (e.g., `bitwise_xor` at line 857).
- Safety Issue: `GetBufferNodeByName_` can return `nullptr`. A null check should be added before accessing `dst_buffer_node->dtype` to prevent a potential crash.
- Style/Efficiency: The `tmp_shape` variable at line 893 is declared but never used.
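To illustrate the correctness point in isolation, here is a minimal standalone sketch of the maximum check. It does not use the TVM/TileLang types from the patch: `TmpRequest` and `SharedTmpSize` are hypothetical names invented for this example, and the sizes are made up. The only idea taken from the review is that a shared temporary buffer should be sized to the largest requirement seen, not the most recent one.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one operation's scratch requirement.
// Not a TileLang type; it exists only for this sketch.
struct TmpRequest {
  const char *op_name;   // e.g. "reduce" or "merge_sort"
  int64_t num_elements;  // workspace size the op asks for
};

// Returns the size a single shared temporary buffer needs so that every
// request fits: the maximum over all requests, never just the last one.
int64_t SharedTmpSize(const std::vector<TmpRequest> &requests) {
  int64_t shape_size = 0;
  for (const TmpRequest &req : requests) {
    // Only grow the buffer; a smaller later request must not shrink it.
    shape_size = std::max(shape_size, req.num_elements);
  }
  return shape_size;
}
```

With a `reduce` request of 4096 elements followed by a `merge_sort` request of 1024, this returns 4096, whereas unconditional overwriting would leave 1024 and under-allocate for the earlier operation.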
Suggested change:

    } else if (call->op.same_as(tl::ascend_merge_sort())) {
      const CallNode *dst_access_ptr = Downcast<Call>(call->args[2]).get();
      std::string dst_buffer_name =
          dst_access_ptr->args[1].as<VarNode>()->name_hint;
      const BufferNode *dst_buffer_node =
          GetBufferNodeByName_(alloc_buffers, dst_buffer_name);
      if (dst_buffer_node && dst_buffer_node->dtype == DataType::UInt(8)) {
        int64_t tmp_shape_size =
            Downcast<IntImm>(dst_access_ptr->args[3])->value * 4;
        if (tmp_shape_size > shape_size) {
          shape = {
              IntImm(DataType::Int(32), tmp_shape_size),
          };
          shape_size = tmp_shape_size;
        }
      }
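The guarded control flow of the suggestion can also be tried outside the compiler pass. The sketch below is an assumption-laden model, not the project's code: `BufferInfo`, `FindBufferByName`, and `UpdateMergeSortTmpSize` are hypothetical names, and a plain struct stands in for `BufferNode`. It keeps the three behaviours of the suggested change: skip when the lookup returns `nullptr`, skip when the dtype is not 8-bit unsigned, and otherwise grow the size to at least four times the destination extent.

```cpp
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical simplified buffer record; stands in for BufferNode here.
struct BufferInfo {
  std::string name;
  int bits;       // element width in bits
  bool is_uint;   // true for unsigned integer element types
};

// Linear lookup by name. Returns nullptr when no buffer matches, which is
// the case the review asks callers to handle.
const BufferInfo *FindBufferByName(const std::vector<BufferInfo> &buffers,
                                   const std::string &name) {
  for (const BufferInfo &b : buffers) {
    if (b.name == name) return &b;
  }
  return nullptr;
}

// Returns the updated shared workspace size after considering a merge sort
// whose destination buffer is `dst_name` with `dst_extent` elements.
int64_t UpdateMergeSortTmpSize(const std::vector<BufferInfo> &buffers,
                               const std::string &dst_name,
                               int64_t dst_extent, int64_t shape_size) {
  const BufferInfo *dst = FindBufferByName(buffers, dst_name);
  if (dst == nullptr) return shape_size;  // unknown buffer: leave size as is
  if (!(dst->is_uint && dst->bits == 8)) return shape_size;  // not uint8
  // Factor of 4 mirrors the patch; never shrink an existing requirement.
  return std::max(shape_size, dst_extent * 4);
}
```

Returning the new size instead of mutating shared state is a choice made for testability in this sketch; the actual pass updates `shape` and `shape_size` in place.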
/approve
No description provided.