40 changes: 40 additions & 0 deletions backends/aoti/slim/c10/cuda/Exception.h
@@ -0,0 +1,40 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#ifdef CUDA_AVAILABLE

#include <cuda.h>
#include <cuda_runtime.h>

#include <executorch/backends/aoti/slim/c10/macros/Macros.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

/// Checks a CUDA expression and aborts on error.
/// @param EXPR The CUDA expression to check.
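///
/// Example (illustrative):
///   ET_CUDA_CHECK(cudaMalloc(&ptr, nbytes));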
#define ET_CUDA_CHECK(EXPR) \
do { \
const cudaError_t __err = EXPR; \
ET_CHECK_MSG( \
__err == cudaSuccess, "CUDA error: %s", cudaGetErrorString(__err)); \
} while (0)

/// Checks a CUDA expression and logs a warning on error (non-fatal).
/// @param EXPR The CUDA expression to check.
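///
/// Example (illustrative):
///   ET_CUDA_LOG_WARN(cudaDeviceSynchronize());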
#define ET_CUDA_LOG_WARN(EXPR) \
do { \
const cudaError_t __err = EXPR; \
if (SLIMTENSOR_UNLIKELY(__err != cudaSuccess)) { \
[[maybe_unused]] auto error_unused = cudaGetLastError(); \
ET_LOG(Error, "CUDA warning: %s", cudaGetErrorString(__err)); \
} \
} while (0)

#endif // CUDA_AVAILABLE
6 changes: 6 additions & 0 deletions backends/aoti/slim/c10/cuda/TARGETS
@@ -0,0 +1,6 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
16 changes: 16 additions & 0 deletions backends/aoti/slim/c10/cuda/targets.bzl
@@ -0,0 +1,16 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Define targets for SlimTensor CUDA exception handling module."""

runtime.cxx_library(
name = "exception",
exported_headers = [
"Exception.h",
],
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
"//executorch/backends/aoti/slim/c10/macros:macros",
"//executorch/runtime/platform:platform",
],
)
78 changes: 43 additions & 35 deletions backends/aoti/slim/core/SlimTensor.h
@@ -11,6 +11,7 @@
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

#include <executorch/backends/aoti/slim/c10/core/Contiguity.h>
#include <executorch/backends/aoti/slim/c10/core/Device.h>
@@ -227,6 +228,13 @@ class SlimTensor {
return device().is_cpu();
}

/**
* Check if the tensor is on CUDA.
*/
bool is_cuda() const {
return device().is_cuda();
}

/**
* Check if the tensor is defined (has valid storage).
*/
@@ -270,69 +278,67 @@ class SlimTensor {
* Copy data from another tensor to this tensor.
*
* Both tensors must have the same numel and dtype.
* Currently only supports CPU-to-CPU copy (contiguous tensors only).
* Supports CPU-to-CPU and cross-device copies (CPU↔CUDA, CUDA↔CUDA).
*
* @param other The source tensor to copy from
* @return Reference to this tensor
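*
* Example (illustrative): dst.copy_(src) copies src's elements into dst,
* provided both tensors already share the same numel and dtype.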
*/
SlimTensor& copy_(const SlimTensor& other) {
ET_CHECK_MSG(
this->numel() == other.numel(),
"copy_: numel mismatch (dst=%zu, src=%zu)",
this->numel(),
other.numel());
ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype mismatch");
this->numel() == other.numel(), "copy_: numel of tensors must match");
ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype must match");

if (this->numel() == 0) {
return *this;
}

// Currently we only support CPU-only tensors
// TODO(gasoonjia): support other device types.
ET_CHECK_MSG(
this->is_cpu() && other.is_cpu(), "copy_: only CPU tensors supported");

// Case 1: Both tensors are contiguous. We can do a fast bulk copy.
if (this->is_contiguous() && other.is_contiguous()) {
// Fast path: both tensors are contiguous, use memcpy
std::memcpy(this->data_ptr(), other.data_ptr(), other.nbytes());
} else {
// Slow path: element-wise copy for non-contiguous tensors
copy_strided_(other);
storage_->copy_(
this->data_ptr(), other.data_ptr(), other.nbytes(), other.device());
return *this;
}

return *this;
}

private:
/**
* Element-wise copy for non-contiguous tensors.
*/
void copy_strided_(const SlimTensor& other) {
// Case 2: At least one tensor is non-contiguous, perform element-wise copy
// that respects both source and destination strides.
const size_t elem_size = c10::elementSize(dtype_);
char* dst_data = static_cast<char*>(this->data_ptr());
const char* src_data = static_cast<const char*>(other.data_ptr());

std::vector<int64_t> counter(this->dim(), 0);
for (size_t i = 0; i < this->numel(); i++) {
// Compute source offset
// Compute src offset in elements
int64_t src_offset = 0;
for (size_t d = 0; d < other.dim(); d++) {
src_offset += counter[d] * other.stride(static_cast<int64_t>(d));
src_offset += counter[d] * other.stride(d);
}

// Compute destination offset
// Compute dst offset in elements
int64_t dst_offset = 0;
for (size_t d = 0; d < this->dim(); d++) {
dst_offset += counter[d] * this->stride(static_cast<int64_t>(d));
dst_offset += counter[d] * this->stride(d);
}

// Copy single element
std::memcpy(
dst_data + dst_offset * static_cast<int64_t>(elem_size),
src_data + src_offset * static_cast<int64_t>(elem_size),
elem_size);

// Increment multi-dimensional counter
// Copy elem_size bytes from src to dst
if (this->device().is_cpu() && other.device().is_cpu()) {
std::memcpy(
dst_data + dst_offset * elem_size,
src_data + src_offset * elem_size,
elem_size);
} else if (this->device().is_cuda() || other.device().is_cuda()) {
#if defined(CUDA_AVAILABLE)
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
dst_data + dst_offset * elem_size,
src_data + src_offset * elem_size,
elem_size,
device(), // dst device
other.device() // src device
);
#else
ET_CHECK_MSG(false, "copy_: CUDA tensor copy requires CUDA support");
#endif
}
// Increment the multi-dimensional counter
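// (e.g. for sizes [2, 3]: (0,0) -> (0,1) -> (0,2) -> (1,0) -> ..., last dim fastest)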
for (int64_t d = static_cast<int64_t>(this->dim()) - 1; d >= 0; --d) {
counter[d]++;
if (counter[d] < this->size(d)) {
@@ -341,8 +347,10 @@ class SlimTensor {
counter[d] = 0;
}
}
return *this;
}

private:
void refresh_numel() {
numel_ = compute_numel(sizes_and_strides_.sizes_arrayref());
}