Merged
Changes from 3 commits
173 changes: 173 additions & 0 deletions backends/aoti/common_shims_slim.cpp
@@ -0,0 +1,173 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/aoti/common_shims_slim.h>

namespace executorch {
namespace backends {
namespace aoti {

extern "C" {

// ============================================================
// Basic Property Getters - Implementations
// ============================================================

AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) {
if (tensor == nullptr || ret_data_ptr == nullptr) {
return Error::InvalidArgument;
}
*ret_data_ptr = tensor->data_ptr();
return Error::Ok;
}

AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
if (tensor == nullptr || ret_sizes == nullptr) {
return Error::InvalidArgument;
}
*ret_sizes = const_cast<int64_t*>(tensor->sizes().data());
return Error::Ok;
}

AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
if (tensor == nullptr || ret_strides == nullptr) {
return Error::InvalidArgument;
}
*ret_strides = const_cast<int64_t*>(tensor->strides().data());
return Error::Ok;
}

AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
if (tensor == nullptr || ret_dtype == nullptr) {
return Error::InvalidArgument;
}
*ret_dtype = static_cast<int32_t>(tensor->dtype());
return Error::Ok;
}

AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
if (tensor == nullptr || ret_dim == nullptr) {
return Error::InvalidArgument;
}
*ret_dim = static_cast<int64_t>(tensor->dim());
return Error::Ok;
}

int32_t aoti_torch_layout_strided() {
// SlimTensor only supports the strided layout, so the return value is
// always 0, i.e. at::Layout::Strided.
return 0;
}

// ============================================================
// Storage & Device Property Getters - Implementations
// ============================================================

AOTITorchError aoti_torch_get_storage_offset(
Tensor* tensor,
int64_t* ret_storage_offset) {
if (tensor == nullptr || ret_storage_offset == nullptr) {
return Error::InvalidArgument;
}
*ret_storage_offset = tensor->storage_offset();
return Error::Ok;
}

AOTITorchError aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) {
if (tensor == nullptr || ret_size == nullptr) {
return Error::InvalidArgument;
}
*ret_size = static_cast<int64_t>(tensor->storage()->nbytes());
return Error::Ok;
}

AOTITorchError aoti_torch_get_device_type(
Tensor* tensor,
int32_t* ret_device_type) {
if (tensor == nullptr || ret_device_type == nullptr) {
return Error::InvalidArgument;
}
*ret_device_type = static_cast<int32_t>(tensor->device_type());
return Error::Ok;
}

AOTITorchError aoti_torch_get_device_index(
Tensor* tensor,
int32_t* ret_device_index) {
if (tensor == nullptr || ret_device_index == nullptr) {
return Error::InvalidArgument;
}
*ret_device_index = static_cast<int32_t>(tensor->device_index());
return Error::Ok;
}

// ============================================================
// DType Constants - Implementations
// ============================================================

int32_t aoti_torch_dtype_float32() {
return 6; // ScalarType::Float
}

int32_t aoti_torch_dtype_bfloat16() {
return 15; // ScalarType::BFloat16
}

int32_t aoti_torch_dtype_int64() {
return 4; // ScalarType::Long
}

int32_t aoti_torch_dtype_int32() {
return 3; // ScalarType::Int
}

int32_t aoti_torch_dtype_int16() {
return 2; // ScalarType::Short
}

int32_t aoti_torch_dtype_int8() {
return 1; // ScalarType::Char
}

int32_t aoti_torch_dtype_bool() {
return 11; // ScalarType::Bool
}

// ============================================================
// Device Type Constants - Implementations
// ============================================================

int32_t aoti_torch_device_type_cpu() {
return 0; // DeviceType::CPU
}

int32_t aoti_torch_device_type_cuda() {
return 1; // DeviceType::CUDA
}

// ============================================================
// Grad Mode Functions - Implementations
// ============================================================

bool aoti_torch_grad_mode_is_enabled() {
// ExecuTorch doesn't support autograd
return false;
}

AOTITorchError aoti_torch_grad_mode_set_enabled(bool enabled) {
if (enabled) {
// ExecuTorch doesn't support autograd
return Error::NotSupported;
}
return Error::Ok;
}

} // extern "C"
} // namespace aoti
} // namespace backends
} // namespace executorch
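
Taken together, these shims give AOTI-generated code a C ABI for querying SlimTensor properties. As an illustrative sketch (not part of this diff), a caller holding a valid Tensor* might use them as follows; how the tensor itself is constructed is assumed to happen elsewhere (e.g. via the slim factory):

// Illustrative sketch only; assumes a valid SlimTensor* obtained elsewhere.
#include <executorch/backends/aoti/common_shims_slim.h>
#include <cstdio>

namespace aoti = executorch::backends::aoti;

void describe(aoti::Tensor* t) {
  int64_t dim = 0;
  int64_t* sizes = nullptr;
  int32_t dtype = 0;
  int32_t device = 0;
  // Each getter returns Error::Ok on success and Error::InvalidArgument on
  // null inputs, so results are only read after checking the codes.
  if (aoti::aoti_torch_get_dim(t, &dim) != aoti::AOTITorchError::Ok ||
      aoti::aoti_torch_get_sizes(t, &sizes) != aoti::AOTITorchError::Ok ||
      aoti::aoti_torch_get_dtype(t, &dtype) != aoti::AOTITorchError::Ok ||
      aoti::aoti_torch_get_device_type(t, &device) != aoti::AOTITorchError::Ok) {
    std::printf("failed to query tensor\n");
    return;
  }
  std::printf("dim=%lld dtype=%d device=%d\n", (long long)dim, dtype, device);
  for (int64_t i = 0; i < dim; ++i) {
    std::printf("  size[%lld]=%lld\n", (long long)i, (long long)sizes[i]);
  }
  if (dtype == aoti::aoti_torch_dtype_float32()) {
    std::printf("  tensor is float32\n");
  }
}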
98 changes: 98 additions & 0 deletions backends/aoti/common_shims_slim.h
@@ -0,0 +1,98 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/aoti/export.h>
#include <executorch/backends/aoti/slim/core/SlimTensor.h>
#include <executorch/runtime/core/error.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace aoti {

extern "C" {

// Common using declarations for ExecuTorch types
using executorch::runtime::Error;

// Tensor type definition using SlimTensor
using Tensor = executorch::backends::aoti::slim::SlimTensor;

// Common AOTI type aliases
using AOTIRuntimeError = Error;
using AOTITorchError = Error;

// ============================================================
// Basic Property Getters - Declarations
// ============================================================

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);

AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();

// ============================================================
// Storage & Device Property Getters - Declarations
// ============================================================

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_storage_offset(Tensor* tensor, int64_t* ret_storage_offset);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_device_type(Tensor* tensor, int32_t* ret_device_type);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index);

// ============================================================
// DType Constants - Declarations
// ============================================================

AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();

// ============================================================
// Device Type Constants - Declarations
// ============================================================

AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cuda();

// ============================================================
// Grad Mode Functions - Declarations
// ============================================================

AOTI_SHIM_EXPORT bool aoti_torch_grad_mode_is_enabled();
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_grad_mode_set_enabled(bool enabled);

} // extern "C"
} // namespace aoti
} // namespace backends
} // namespace executorch
18 changes: 18 additions & 0 deletions backends/aoti/targets.bzl
@@ -86,3 +86,21 @@ def define_common_targets():
":delegate_handle",
],
)

# SlimTensor-based common shims library
# Uses SlimTensor for all tensor operations
runtime.cxx_library(
name = "common_shims_slim",
srcs = [
"common_shims_slim.cpp",
],
headers = [
"common_shims_slim.h",
"export.h",
],
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
"//executorch/runtime/core:core",
"//executorch/backends/aoti/slim/core:slimtensor",
],
)
25 changes: 25 additions & 0 deletions backends/aoti/tests/TARGETS
@@ -1,4 +1,5 @@
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")

oncall("executorch")

@@ -20,3 +21,27 @@ cpp_unittest(
"//executorch/extension/tensor:tensor",
],
)

cpp_unittest(
name = "test_common_shims_slim",
srcs = [
"test_common_shims_slim.cpp",
],
deps = [
"//executorch/backends/aoti:common_shims_slim",
"//executorch/backends/aoti/slim/core:slimtensor",
"//executorch/backends/aoti/slim/factory:empty",
"//executorch/runtime/core:core",
"//executorch/runtime/platform:platform",
],
external_deps = [
("cuda", None, "cuda-lazy"),
],
preprocessor_flags = [
"-DCUDA_AVAILABLE=1",
],
keep_gpu_sections = True,
remote_execution = re_test_utils.remote_execution(
platform = "gpu-remote-execution",
),
)
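
The test source test_common_shims_slim.cpp is not included in this view; as a minimal sketch (illustrative, not the actual test), the kind of checks such a unit test could make, using only the shims declared in this PR, might look like:

// Illustrative sketch only; the real test file is not shown in this diff.
#include <executorch/backends/aoti/common_shims_slim.h>
#include <gtest/gtest.h>

namespace aoti = executorch::backends::aoti;

TEST(CommonShimsSlimSketch, NullArgumentsAreRejected) {
  void* data = nullptr;
  // Every getter guards against null tensor / null out-pointer inputs.
  EXPECT_EQ(
      aoti::aoti_torch_get_data_ptr(nullptr, &data),
      aoti::AOTITorchError::InvalidArgument);
}

TEST(CommonShimsSlimSketch, ConstantsAndGradMode) {
  // Constants mirror the ScalarType / DeviceType values documented above.
  EXPECT_EQ(aoti::aoti_torch_dtype_float32(), 6);
  EXPECT_EQ(aoti::aoti_torch_dtype_bool(), 11);
  EXPECT_EQ(aoti::aoti_torch_device_type_cuda(), 1);
  // Autograd is unsupported: enabling it must fail, disabling is a no-op.
  EXPECT_FALSE(aoti::aoti_torch_grad_mode_is_enabled());
  EXPECT_EQ(
      aoti::aoti_torch_grad_mode_set_enabled(true),
      aoti::AOTITorchError::NotSupported);
  EXPECT_EQ(
      aoti::aoti_torch_grad_mode_set_enabled(false),
      aoti::AOTITorchError::Ok);
}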