diff --git a/src/transform/allocate_tmp_buffer.cc b/src/transform/allocate_tmp_buffer.cc index aa8f9b608..4c923f9c8 100644 --- a/src/transform/allocate_tmp_buffer.cc +++ b/src/transform/allocate_tmp_buffer.cc @@ -172,9 +172,7 @@ class CallNodeModifier : public StmtExprMutator { } } } - } else if ("pto" == target_ && - (op->op.same_as(tl::ascend_bitwise_xor()) || - op->op.same_as(tl::ascend_merge_sort())) && + } else if ("pto" == target_ && op->op.same_as(tl::ascend_bitwise_xor()) && tmp_bufs_.size() > 0) { const CallNode *src_access_ptr = Downcast(op->args[1]).get(); DataType dtype = src_access_ptr->args[0].as()->dtype; @@ -188,6 +186,20 @@ class CallNodeModifier : public StmtExprMutator { } } } + } else if ("pto" == target_ && op->op.same_as(tl::ascend_merge_sort()) && + tmp_bufs_.size() > 0) { + const CallNode *dst_access_ptr = Downcast(op->args[2]).get(); + DataType dtype = dst_access_ptr->args[0].as()->dtype; + if (dtype == DataType::UInt(8)) { + tmp_buffer = tmp_buf_; + } else { + for (const Buffer &merge_sort_tmp_buffer : tmp_bufs_) { + if (merge_sort_tmp_buffer.get()->dtype == dtype) { + tmp_buffer = merge_sort_tmp_buffer; + break; + } + } + } } else if ("pto" == target_ && op->op.same_as(tl::ascend_gather_mask()) && tmp_bufs_.size() > 0) { const CallNode *src_access_ptr = Downcast(op->args[3]).get(); @@ -868,6 +880,22 @@ class TmpBufferInjector : public StmtExprMutator { }; shape_size = tmp_shape_size; } + } else if (call->op.same_as(tl::ascend_merge_sort())) { + const CallNode *dst_access_ptr = Downcast(call->args[2]).get(); + std::string dst_buffer_name = + dst_access_ptr->args[1].as()->name_hint; + const BufferNode *dst_buffer_node = + GetBufferNodeByName_(alloc_buffers, dst_buffer_name); + DataType dtype = dst_buffer_node->dtype; + if (dtype == DataType::UInt(8)) { + int64_t tmp_shape_size = + Downcast(dst_access_ptr->args[3])->value * 4; + Array tmp_shape; + shape = { + IntImm(DataType::Int(32), tmp_shape_size), + }; + shape_size = tmp_shape_size; + } } else if (call->op.same_as(tl::ascend_clamp()) || call->op.same_as(tl::ascend_clamp_max()) || call->op.same_as(tl::ascend_clamp_min())) { @@ -1027,4 +1055,4 @@ TVM_REGISTER_GLOBAL("tl.transform.InjectTmpBuffer") .set_body_typed(InjectTmpBuffer); } // namespace tl -} // namespace tvm +} // namespace tvm \ No newline at end of file