Skip to content
Open
19 changes: 19 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,25 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
return Error::Ok;
}

AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
ET_CHECK_OR_RETURN_ERROR(
src != nullptr,
InvalidArgument,
"aoti_torch_assign_tensors_out: src is null");

ET_CHECK_OR_RETURN_ERROR(
ret_dst != nullptr,
InvalidArgument,
"aoti_torch_assign_tensors_out: ret_dst is null");

// Move the source tensor into the destination. After this operation,
// the source tensor will be left in an undefined state (reset).
// This differs from aoti_torch_new_tensor_handle which copies the tensor.
*ret_dst = new Tensor(std::move(*src));

return Error::Ok;
}

} // extern "C"

} // namespace executorch::backends::cuda
14 changes: 14 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,20 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);

/**
* Moves a tensor into a new handle and assigns it to the output parameter.
*
* Unlike aoti_torch_new_tensor_handle which copies, this function moves the
* source tensor into the destination. After this operation, the source tensor
* is left in an undefined/reset state and should not be used.
*
* @param src Source tensor to move from (must not be null, will be reset)
* @param ret_dst Output parameter for the new tensor handle
* @return AOTITorchError error code (Error::Ok on success)
*/
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);

} // extern "C"

} // namespace executorch::backends::cuda
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,4 @@ def define_common_targets():
cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor")
cuda_shim_slim_cpp_unittest("aoti_torch_copy_")
cuda_shim_slim_cpp_unittest("aoti_torch_assign_tensors_out")
Loading
Loading