From 3b0f659a1ad84e99ff6161e596195b489ce946a5 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 27 Jan 2026 23:46:03 -0800 Subject: [PATCH 1/2] [slimtensor] Add aoti_torch__reinterpret_tensor for SlimTensor Pull Request resolved: https://github.com/pytorch/executorch/pull/16450 Add SlimTensor-based `aoti_torch__reinterpret_tensor()` - Creates a reinterpreted view of a tensor with new sizes, strides, and storage offset using SlimTensor's `as_strided()` method. The view shares the same underlying storage. ghstack-source-id: 336360654 @exported-using-ghexport Differential Revision: [D90126249](https://our.internmc.facebook.com/intern/diff/D90126249/) --- backends/cuda/runtime/shims/memory_slim.cpp | 41 ++ backends/cuda/runtime/shims/memory_slim.h | 22 + backends/cuda/runtime/shims/tests/targets.bzl | 1 + ...st_aoti_torch__reinterpret_tensor_slim.cpp | 692 ++++++++++++++++++ 4 files changed, 756 insertions(+) create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor_slim.cpp diff --git a/backends/cuda/runtime/shims/memory_slim.cpp b/backends/cuda/runtime/shims/memory_slim.cpp index fcc5acfafb9..580b9a4530c 100644 --- a/backends/cuda/runtime/shims/memory_slim.cpp +++ b/backends/cuda/runtime/shims/memory_slim.cpp @@ -158,6 +158,47 @@ AOTITorchError aoti_torch_new_tensor_handle( return Error::Ok; } +AOTITorchError aoti_torch__reinterpret_tensor( + Tensor* self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + Tensor** ret_new_tensor) { + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor: self is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret_new_tensor != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor: ret_new_tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + ndim >= 0, + InvalidArgument, + "aoti_torch__reinterpret_tensor: ndim must be non-negative, got %lld", + static_cast(ndim)); + + ET_CHECK_OR_RETURN_ERROR( + !(sizes_ptr == nullptr && ndim > 0), + InvalidArgument, + "aoti_torch__reinterpret_tensor: sizes_ptr is null but ndim > 0"); + + IntArrayRef sizes(sizes_ptr, static_cast(ndim)); + IntArrayRef strides(strides_ptr, static_cast(ndim)); + + // Create a new tensor view using as_strided. This creates a tensor that + // shares the same underlying storage but with different sizes, strides, + // and storage offset. SlimTensor::as_strided() handles this via copy + // constructor which shares the SharedPtr. + *ret_new_tensor = + new Tensor(self->as_strided(sizes, strides, storage_offset)); + + return Error::Ok; +} + } // extern "C" } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/memory_slim.h b/backends/cuda/runtime/shims/memory_slim.h index b65bb9e709b..64a7a561141 100644 --- a/backends/cuda/runtime/shims/memory_slim.h +++ b/backends/cuda/runtime/shims/memory_slim.h @@ -106,6 +106,28 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor); AOTI_SHIM_EXPORT AOTITorchError aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle); +/** + * Creates a reinterpreted view of a tensor with new sizes, strides, and offset. + * + * This is equivalent to torch.as_strided() - it creates a new tensor that + * shares the same underlying storage but with different view parameters. 
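+ *
+ * Minimal usage sketch (illustrative only; `orig` stands for any tensor
+ * handle previously obtained from this shim layer, here assumed to be a
+ * contiguous float tensor with at least 24 elements):
+ *
+ *   int64_t view_sizes[] = {6, 4};
+ *   int64_t view_strides[] = {4, 1};
+ *   Tensor* view = nullptr;
+ *   AOTITorchError err = aoti_torch__reinterpret_tensor(
+ *       orig, 2, view_sizes, view_strides, /*storage_offset=*/0, &view);
+ *   // On success, view->data_ptr() == orig->data_ptr(); release both
+ *   // handles with aoti_torch_delete_tensor_object() when done.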
+ * + * @param self Original tensor to reinterpret (must not be null) + * @param ndim Number of dimensions for the new view + * @param sizes_ptr Pointer to array of dimension sizes + * @param strides_ptr Pointer to array of strides for each dimension + * @param storage_offset Storage offset in number of elements + * @param ret_new_tensor Output parameter for the reinterpreted tensor view + * @return AOTITorchError error code (Error::Ok on success) + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor( + Tensor* self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + Tensor** ret_new_tensor); + } // extern "C" } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index 67ca286dacf..ce9f8fcc647 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -75,3 +75,4 @@ def define_common_targets(): cuda_shim_slim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2") cuda_shim_slim_cpp_unittest("aoti_torch_delete_tensor_object") cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle") + cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor_slim.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor_slim.cpp new file mode 100644 index 00000000000..d2ad645136e --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor_slim.cpp @@ -0,0 +1,692 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using executorch::runtime::Error; + +namespace slim_c10 = executorch::backends::aoti::slim::c10; + +namespace { + +bool isCudaAvailable() { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + return (err == cudaSuccess && device_count > 0); +} + +std::vector calculateContiguousStrides( + const std::vector& sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) { + return strides; + } + strides[sizes.size() - 1] = 1; + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + return strides; +} + +} // namespace + +class AOTITorchReinterpretTensorSlimTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + } + + Tensor* createTestTensor( + const std::vector& sizes, + const std::vector& strides = {}, + int32_t dtype = static_cast(slim_c10::ScalarType::Float), + int32_t device_type = static_cast(slim_c10::DeviceType::CPU), + int32_t device_index = 0) { + Tensor* tensor = nullptr; + + std::vector effective_strides = strides; + if (strides.empty()) { + effective_strides = calculateContiguousStrides(sizes); + } + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + effective_strides.data(), + dtype, + device_type, + device_index, + &tensor); + + return (error == Error::Ok) ? 
tensor : nullptr; + } +}; + +// ============================================================================ +// Basic Functionality Tests +// ============================================================================ + +TEST_F(AOTITorchReinterpretTensorSlimTest, BasicView_CPU) { + std::vector sizes = {2, 3, 4}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + std::vector new_sizes = {6, 4}; + std::vector new_strides = {4, 1}; + int64_t storage_offset = 0; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + storage_offset, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + + EXPECT_EQ(view_tensor->dim(), 2); + EXPECT_EQ(view_tensor->size(0), 6); + EXPECT_EQ(view_tensor->size(1), 4); + EXPECT_EQ(view_tensor->stride(0), 4); + EXPECT_EQ(view_tensor->stride(1), 1); + + EXPECT_EQ(view_tensor->data_ptr(), orig_tensor->data_ptr()); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, NullSelf) { + std::vector sizes = {2, 3}; + std::vector strides = {3, 1}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + nullptr, sizes.size(), sizes.data(), strides.data(), 0, &view_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, NullReturnPointer) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + std::vector new_sizes = {6}; + std::vector new_strides = {1}; + + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + nullptr); + + EXPECT_EQ(error, Error::InvalidArgument); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, NegativeNdim) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + std::vector new_sizes = {6}; + std::vector new_strides = {1}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, -1, new_sizes.data(), new_strides.data(), 0, &view_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); +} + +// ============================================================================ +// Storage Offset Tests +// ============================================================================ + +TEST_F(AOTITorchReinterpretTensorSlimTest, WithStorageOffset_CPU) { + std::vector sizes = {4, 4}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + std::vector new_sizes = {2, 4}; + std::vector new_strides = {4, 1}; + int64_t storage_offset = 4; // Skip first row + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + 
orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + storage_offset, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + + EXPECT_EQ(view_tensor->dim(), 2); + EXPECT_EQ(view_tensor->size(0), 2); + EXPECT_EQ(view_tensor->size(1), 4); + + char* orig_ptr = static_cast(orig_tensor->data_ptr()); + char* view_ptr = static_cast(view_tensor->data_ptr()); + EXPECT_EQ(view_ptr, orig_ptr + storage_offset * sizeof(float)); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +// ============================================================================ +// Memory Sharing Tests +// ============================================================================ + +TEST_F(AOTITorchReinterpretTensorSlimTest, MemorySharing_CPU) { + std::vector sizes = {6}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->data_ptr(); + + std::vector new_sizes = {2, 3}; + std::vector new_strides = {3, 1}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + + EXPECT_EQ(view_tensor->data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + EXPECT_EQ(view_tensor->data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, MultipleViews_CPU) { + std::vector sizes = {24}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->data_ptr(); + + std::vector sizes1 = {2, 12}; + std::vector strides1 = {12, 1}; + + std::vector sizes2 = {4, 6}; + std::vector strides2 = {6, 1}; + + std::vector sizes3 = {2, 3, 4}; + std::vector strides3 = {12, 4, 1}; + + Tensor* view1 = nullptr; + Tensor* view2 = nullptr; + Tensor* view3 = nullptr; + + EXPECT_EQ( + aoti_torch__reinterpret_tensor( + orig_tensor, + sizes1.size(), + sizes1.data(), + strides1.data(), + 0, + &view1), + Error::Ok); + EXPECT_EQ( + aoti_torch__reinterpret_tensor( + orig_tensor, + sizes2.size(), + sizes2.data(), + strides2.data(), + 0, + &view2), + Error::Ok); + EXPECT_EQ( + aoti_torch__reinterpret_tensor( + orig_tensor, + sizes3.size(), + sizes3.data(), + strides3.data(), + 0, + &view3), + Error::Ok); + + EXPECT_EQ(view1->data_ptr(), orig_ptr); + EXPECT_EQ(view2->data_ptr(), orig_ptr); + EXPECT_EQ(view3->data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + EXPECT_EQ(view1->data_ptr(), orig_ptr); + EXPECT_EQ(view2->data_ptr(), orig_ptr); + EXPECT_EQ(view3->data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(view1), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view2), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view3), Error::Ok); +} + +// ============================================================================ +// Dimension Change Tests +// ============================================================================ + +TEST_F(AOTITorchReinterpretTensorSlimTest, ExpandDimensions_CPU) { + 
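+  // Reinterpret a 1-D tensor of 6 elements as a 2x3 view: the rank grows
+  // from 1 to 2 while the underlying storage stays the same.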
std::vector sizes = {6}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + EXPECT_EQ(orig_tensor->dim(), 1); + + std::vector new_sizes = {2, 3}; + std::vector new_strides = {3, 1}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_EQ(view_tensor->dim(), 2); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, CollapseDimensions_CPU) { + std::vector sizes = {2, 3, 4}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + EXPECT_EQ(orig_tensor->dim(), 3); + + std::vector new_sizes = {24}; + std::vector new_strides = {1}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_EQ(view_tensor->dim(), 1); + EXPECT_EQ(view_tensor->numel(), 24); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, ScalarTensorView_CPU) { + std::vector sizes = {1}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + std::vector new_sizes = {}; + std::vector new_strides = {}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, 0, new_sizes.data(), new_strides.data(), 0, &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_EQ(view_tensor->dim(), 0); + EXPECT_EQ(view_tensor->numel(), 1); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +// ============================================================================ +// Stride Tests +// ============================================================================ + +TEST_F(AOTITorchReinterpretTensorSlimTest, TransposeViaStrides_CPU) { + std::vector sizes = {3, 4}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + std::vector new_sizes = {4, 3}; + std::vector new_strides = {1, 4}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_EQ(view_tensor->size(0), 4); + EXPECT_EQ(view_tensor->size(1), 3); + EXPECT_EQ(view_tensor->stride(0), 1); + EXPECT_EQ(view_tensor->stride(1), 4); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +// 
============================================================================ +// Different Dtype Tests +// ============================================================================ + +TEST_F(AOTITorchReinterpretTensorSlimTest, Int64Tensor_CPU) { + std::vector sizes = {6}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Long), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + std::vector new_sizes = {2, 3}; + std::vector new_strides = {3, 1}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_EQ(view_tensor->itemsize(), 8); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, BFloat16Tensor_CPU) { + std::vector sizes = {6}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::BFloat16), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(orig_tensor, nullptr); + + std::vector new_sizes = {2, 3}; + std::vector new_strides = {3, 1}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_EQ(view_tensor->itemsize(), 2); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +// ============================================================================ +// CUDA Tests +// ============================================================================ + +TEST_F(AOTITorchReinterpretTensorSlimTest, BasicView_CUDA) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {2, 3, 4}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(orig_tensor, nullptr); + EXPECT_TRUE(orig_tensor->is_cuda()); + + std::vector new_sizes = {6, 4}; + std::vector new_strides = {4, 1}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_TRUE(view_tensor->is_cuda()); + + EXPECT_EQ(view_tensor->dim(), 2); + EXPECT_EQ(view_tensor->size(0), 6); + EXPECT_EQ(view_tensor->size(1), 4); + + EXPECT_EQ(view_tensor->data_ptr(), orig_tensor->data_ptr()); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, WithStorageOffset_CUDA) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {4, 4}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(orig_tensor, nullptr); + + std::vector new_sizes = {2, 4}; + std::vector new_strides = {4, 1}; + int64_t storage_offset = 8; + + Tensor* view_tensor = nullptr; + 
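+  // A storage_offset of 8 floats should move the CUDA view's base pointer
+  // 32 bytes past the start of the original allocation.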
AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + storage_offset, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_TRUE(view_tensor->is_cuda()); + + char* orig_ptr = static_cast(orig_tensor->data_ptr()); + char* view_ptr = static_cast(view_tensor->data_ptr()); + EXPECT_EQ(view_ptr, orig_ptr + storage_offset * sizeof(float)); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, MemorySharing_CUDA) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {6}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->data_ptr(); + + std::vector new_sizes = {2, 3}; + std::vector new_strides = {3, 1}; + + Tensor* view_tensor = nullptr; + AOTITorchError error = aoti_torch__reinterpret_tensor( + orig_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &view_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + + EXPECT_EQ(view_tensor->data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(view_tensor->data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(view_tensor), Error::Ok); +} + +TEST_F(AOTITorchReinterpretTensorSlimTest, ChainedViews_CUDA) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {24}; + Tensor* orig_tensor = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->data_ptr(); + + std::vector sizes1 = {4, 6}; + std::vector strides1 = {6, 1}; + + Tensor* view1 = nullptr; + EXPECT_EQ( + aoti_torch__reinterpret_tensor( + orig_tensor, + sizes1.size(), + sizes1.data(), + strides1.data(), + 0, + &view1), + Error::Ok); + + std::vector sizes2 = {2, 2, 6}; + std::vector strides2 = {12, 6, 1}; + + Tensor* view2 = nullptr; + EXPECT_EQ( + aoti_torch__reinterpret_tensor( + view1, sizes2.size(), sizes2.data(), strides2.data(), 0, &view2), + Error::Ok); + + EXPECT_EQ(view1->data_ptr(), orig_ptr); + EXPECT_EQ(view2->data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view1), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(view2), Error::Ok); +} From 2b4b60c9695e3e1e7c857cd890bd72d7e55ccfe6 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 27 Jan 2026 23:46:05 -0800 Subject: [PATCH 2/2] [slimtensor] Add aoti_torch_copy_ for SlimTensor Pull Request resolved: https://github.com/pytorch/executorch/pull/16451 Add SlimTensor-based `aoti_torch_copy_()` - Copies data from source tensor to destination tensor. Delegates to SlimTensor's `copy_()` which handles all device combinations (CPU-CPU, CPU-CUDA, CUDA-CPU, CUDA-CUDA). 
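Illustrative call sequence (a sketch, assuming `dst` and `src` are SlimTensor
handles with matching dtype and element count, e.g. created via
`aoti_torch_empty_strided()`):

    // Copy src into dst; the non_blocking flag is currently ignored and the
    // copy runs synchronously.
    AOTITorchError err = aoti_torch_copy_(dst, src, /*non_blocking=*/0);
    // err == Error::Ok on success; dtype and numel agreement is validated
    // inside SlimTensor::copy_().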
ghstack-source-id: 336360663 @exported-using-ghexport Differential Revision: [D90126246](https://our.internmc.facebook.com/intern/diff/D90126246/) --- backends/cuda/runtime/shims/memory_slim.cpp | 20 + backends/cuda/runtime/shims/memory_slim.h | 15 + backends/cuda/runtime/shims/tests/targets.bzl | 1 + .../tests/test_aoti_torch_copy__slim.cpp | 487 ++++++++++++++++++ 4 files changed, 523 insertions(+) create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_copy__slim.cpp diff --git a/backends/cuda/runtime/shims/memory_slim.cpp b/backends/cuda/runtime/shims/memory_slim.cpp index 580b9a4530c..45f0d1bc913 100644 --- a/backends/cuda/runtime/shims/memory_slim.cpp +++ b/backends/cuda/runtime/shims/memory_slim.cpp @@ -199,6 +199,26 @@ AOTITorchError aoti_torch__reinterpret_tensor( return Error::Ok; } +AOTITorchError +aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) { + (void)non_blocking; // SlimTensor::copy_() is always synchronous for now + + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, InvalidArgument, "aoti_torch_copy_: self is null"); + + ET_CHECK_OR_RETURN_ERROR( + src != nullptr, InvalidArgument, "aoti_torch_copy_: src is null"); + + // SlimTensor::copy_() handles: + // - Same numel validation + // - Same dtype validation + // - CPU-CPU, CPU-CUDA, CUDA-CPU, CUDA-CUDA copies + // - Contiguous fast path and non-contiguous element-wise copy + self->copy_(*src); + + return Error::Ok; +} + } // extern "C" } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/memory_slim.h b/backends/cuda/runtime/shims/memory_slim.h index 64a7a561141..3c6a58fb783 100644 --- a/backends/cuda/runtime/shims/memory_slim.h +++ b/backends/cuda/runtime/shims/memory_slim.h @@ -128,6 +128,21 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor( int64_t storage_offset, Tensor** ret_new_tensor); +/** + * Copies data from source tensor to destination tensor. + * + * Handles all device combinations (CPU-CPU, CPU-CUDA, CUDA-CPU, CUDA-CUDA) + * and supports tensors with different strides. The destination tensor must + * already be allocated with sufficient storage. + * + * @param self Destination tensor (must not be null) + * @param src Source tensor to copy from (must not be null) + * @param non_blocking If true, the copy may be asynchronous (currently ignored) + * @return AOTITorchError error code (Error::Ok on success) + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking); + } // extern "C" } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index ce9f8fcc647..099759d0649 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -76,3 +76,4 @@ def define_common_targets(): cuda_shim_slim_cpp_unittest("aoti_torch_delete_tensor_object") cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle") cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor") + cuda_shim_slim_cpp_unittest("aoti_torch_copy_") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_copy__slim.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy__slim.cpp new file mode 100644 index 00000000000..c2e67732b41 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy__slim.cpp @@ -0,0 +1,487 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using executorch::runtime::Error; + +namespace slim_c10 = executorch::backends::aoti::slim::c10; + +namespace { + +bool isCudaAvailable() { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + return (err == cudaSuccess && device_count > 0); +} + +std::vector calculateContiguousStrides( + const std::vector& sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) { + return strides; + } + strides[sizes.size() - 1] = 1; + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + return strides; +} + +} // namespace + +class AOTITorchCopySlimTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + } + + Tensor* createTestTensor( + const std::vector& sizes, + const std::vector& strides = {}, + int32_t dtype = static_cast(slim_c10::ScalarType::Float), + int32_t device_type = static_cast(slim_c10::DeviceType::CPU), + int32_t device_index = 0) { + Tensor* tensor = nullptr; + + std::vector effective_strides = strides; + if (strides.empty()) { + effective_strides = calculateContiguousStrides(sizes); + } + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + effective_strides.data(), + dtype, + device_type, + device_index, + &tensor); + + return (error == Error::Ok) ? tensor : nullptr; + } +}; + +// ============================================================================ +// Basic Functionality Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, BasicCopy_CPU) { + std::vector sizes = {3, 4}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + float* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = static_cast(i + 1); + } + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_FLOAT_EQ(dst_data[i], static_cast(i + 1)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, NullSelf) { + std::vector sizes = {2, 3}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + AOTITorchError error = aoti_torch_copy_(nullptr, src, 0); + EXPECT_EQ(error, Error::InvalidArgument); + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, NullSrc) { + std::vector sizes = {2, 3}; + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, nullptr, 0); + EXPECT_EQ(error, Error::InvalidArgument); + + 
EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +// ============================================================================ +// Different Dtype Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, Int64Copy_CPU) { + std::vector sizes = {2, 3}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Long), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + int64_t* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = i * 100; + } + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Long), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + int64_t* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_EQ(dst_data[i], i * 100); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, BoolCopy_CPU) { + std::vector sizes = {4}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Bool), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + bool* src_data = static_cast(src->data_ptr()); + src_data[0] = true; + src_data[1] = false; + src_data[2] = true; + src_data[3] = false; + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Bool), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + bool* dst_data = static_cast(dst->data_ptr()); + EXPECT_EQ(dst_data[0], true); + EXPECT_EQ(dst_data[1], false); + EXPECT_EQ(dst_data[2], true); + EXPECT_EQ(dst_data[3], false); + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +// ============================================================================ +// Tensor Shape Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, ScalarTensorCopy_CPU) { + std::vector sizes = {}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + EXPECT_EQ(src->dim(), 0); + EXPECT_EQ(src->numel(), 1); + + float* src_data = static_cast(src->data_ptr()); + *src_data = 42.0f; + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + EXPECT_FLOAT_EQ(*dst_data, 42.0f); + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, LargeTensorCopy_CPU) { + std::vector sizes = {100, 100}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + float* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = static_cast(i); + } + + 
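+  // Allocate an identically shaped 100x100 destination, then verify every
+  // element survives the copy.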
Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_FLOAT_EQ(dst_data[i], static_cast(i)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +// ============================================================================ +// CUDA Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, CudaToCuda) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {3, 4}; + + std::vector host_src_data(12); + for (size_t i = 0; i < host_src_data.size(); i++) { + host_src_data[i] = static_cast(i + 1); + } + + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(src, nullptr); + EXPECT_TRUE(src->is_cuda()); + + cudaMemcpy( + src->data_ptr(), + host_src_data.data(), + host_src_data.size() * sizeof(float), + cudaMemcpyHostToDevice); + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(dst, nullptr); + EXPECT_TRUE(dst->is_cuda()); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + std::vector host_dst_data(12); + cudaMemcpy( + host_dst_data.data(), + dst->data_ptr(), + host_dst_data.size() * sizeof(float), + cudaMemcpyDeviceToHost); + + for (size_t i = 0; i < host_dst_data.size(); i++) { + EXPECT_FLOAT_EQ(host_dst_data[i], static_cast(i + 1)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, CpuToCuda) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {2, 3}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + EXPECT_TRUE(src->is_cpu()); + + float* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = static_cast(i * 10); + } + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(dst, nullptr); + EXPECT_TRUE(dst->is_cuda()); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + std::vector host_dst_data(6); + cudaMemcpy( + host_dst_data.data(), + dst->data_ptr(), + host_dst_data.size() * sizeof(float), + cudaMemcpyDeviceToHost); + + for (size_t i = 0; i < host_dst_data.size(); i++) { + EXPECT_FLOAT_EQ(host_dst_data[i], static_cast(i * 10)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, CudaToCpu) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {2, 3}; + + std::vector host_src_data(6); + for (size_t i = 0; i < host_src_data.size(); i++) { + host_src_data[i] = static_cast(i * 5); + } + + Tensor* src = createTestTensor( + sizes, + {}, + 
static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(src, nullptr); + + cudaMemcpy( + src->data_ptr(), + host_src_data.data(), + host_src_data.size() * sizeof(float), + cudaMemcpyHostToDevice); + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + EXPECT_TRUE(dst->is_cpu()); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_FLOAT_EQ(dst_data[i], static_cast(i * 5)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +// ============================================================================ +// Non-blocking Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, NonBlockingFlag_CPU) { + std::vector sizes = {2, 3}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + float* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = static_cast(i); + } + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 1); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_FLOAT_EQ(dst_data[i], static_cast(i)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +}