Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions src/transform/allocate_tmp_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,7 @@ class CallNodeModifier : public StmtExprMutator {
}
}
}
} else if ("pto" == target_ &&
(op->op.same_as(tl::ascend_bitwise_xor()) ||
op->op.same_as(tl::ascend_merge_sort())) &&
} else if ("pto" == target_ && op->op.same_as(tl::ascend_bitwise_xor()) &&
tmp_bufs_.size() > 0) {
const CallNode *src_access_ptr = Downcast<Call>(op->args[1]).get();
DataType dtype = src_access_ptr->args[0].as<CallNode>()->dtype;
Expand All @@ -188,6 +186,20 @@ class CallNodeModifier : public StmtExprMutator {
}
}
}
} else if ("pto" == target_ && op->op.same_as(tl::ascend_merge_sort()) &&
tmp_bufs_.size() > 0) {
const CallNode *dst_access_ptr = Downcast<Call>(op->args[2]).get();
DataType dtype = dst_access_ptr->args[0].as<CallNode>()->dtype;
if (dtype == DataType::UInt(8)) {
tmp_buffer = tmp_buf_;
} else {
for (const Buffer &merge_sort_tmp_buffer : tmp_bufs_) {
if (merge_sort_tmp_buffer.get()->dtype == dtype) {
tmp_buffer = merge_sort_tmp_buffer;
break;
}
}
}
} else if ("pto" == target_ && op->op.same_as(tl::ascend_gather_mask()) &&
tmp_bufs_.size() > 0) {
const CallNode *src_access_ptr = Downcast<Call>(op->args[3]).get();
Expand Down Expand Up @@ -868,6 +880,22 @@ class TmpBufferInjector : public StmtExprMutator {
};
shape_size = tmp_shape_size;
}
} else if (call->op.same_as(tl::ascend_merge_sort())) {
const CallNode *dst_access_ptr = Downcast<Call>(call->args[2]).get();
std::string dst_buffer_name =
dst_access_ptr->args[1].as<VarNode>()->name_hint;
const BufferNode *dst_buffer_node =
GetBufferNodeByName_(alloc_buffers, dst_buffer_name);
DataType dtype = dst_buffer_node->dtype;
if (dtype == DataType::UInt(8)) {
int64_t tmp_shape_size =
Downcast<IntImm>(dst_access_ptr->args[3])->value * 4;
Array<PrimExpr> tmp_shape;
shape = {
IntImm(DataType::Int(32), tmp_shape_size),
};
shape_size = tmp_shape_size;
}
Comment on lines +883 to +898
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for calculating the temporary buffer size for merge_sort in the pto target has a few issues:

  1. Correctness Issue: When the destination dtype is UInt(8), it unconditionally overwrites the shape and shape_size variables (line 897). If multiple operations require a temporary buffer (e.g., a reduce followed by a merge_sort), the allocated buffer will be too small whenever the merge_sort workspace is smaller than the previous operation's requirement. It should update them only when the new size is larger than the current shape_size, similar to other operations in this function (e.g., bitwise_xor at line 857).
  2. Safety Issue: GetBufferNodeByName_ can return nullptr. A null check should be added before accessing dst_buffer_node->dtype to prevent a potential crash.
  3. Style/Efficiency: The tmp_shape variable at line 893 is declared but never used and should be removed.
Suggested change
} else if (call->op.same_as(tl::ascend_merge_sort())) {
const CallNode *dst_access_ptr = Downcast<Call>(call->args[2]).get();
std::string dst_buffer_name =
dst_access_ptr->args[1].as<VarNode>()->name_hint;
const BufferNode *dst_buffer_node =
GetBufferNodeByName_(alloc_buffers, dst_buffer_name);
DataType dtype = dst_buffer_node->dtype;
if (dtype == DataType::UInt(8)) {
int64_t tmp_shape_size =
Downcast<IntImm>(dst_access_ptr->args[3])->value * 4;
Array<PrimExpr> tmp_shape;
shape = {
IntImm(DataType::Int(32), tmp_shape_size),
};
shape_size = tmp_shape_size;
}
} else if (call->op.same_as(tl::ascend_merge_sort())) {
const CallNode *dst_access_ptr = Downcast<Call>(call->args[2]).get();
std::string dst_buffer_name =
dst_access_ptr->args[1].as<VarNode>()->name_hint;
const BufferNode *dst_buffer_node =
GetBufferNodeByName_(alloc_buffers, dst_buffer_name);
if (dst_buffer_node && dst_buffer_node->dtype == DataType::UInt(8)) {
int64_t tmp_shape_size =
Downcast<IntImm>(dst_access_ptr->args[3])->value * 4;
if (tmp_shape_size > shape_size) {
shape = {
IntImm(DataType::Int(32), tmp_shape_size),
};
shape_size = tmp_shape_size;
}
}

} else if (call->op.same_as(tl::ascend_clamp()) ||
call->op.same_as(tl::ascend_clamp_max()) ||
call->op.same_as(tl::ascend_clamp_min())) {
Expand Down Expand Up @@ -1027,4 +1055,4 @@ TVM_REGISTER_GLOBAL("tl.transform.InjectTmpBuffer")
.set_body_typed(InjectTmpBuffer);

} // namespace tl
} // namespace tvm
} // namespace tvm
Loading