From b6b4ad7ec6bbf5b57bab2cd376dba59595be0ecb Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Wed, 6 May 2026 21:07:43 -0700
Subject: [PATCH] Add cuBLAS mm_out shim to eliminate libtorch runtime
 dependency

Implements aoti_torch_cuda_mm_out as a thin cuBLAS wrapper in the
ExecuTorch AOTI CUDA shims. When Inductor picks cuBLAS over Triton
templates for aten::mm (F.linear), the compiled .so requires this
symbol at runtime. Without this shim, it resolves from
libtorch_cuda.so, pulling in the full libtorch runtime.

In practice, Inductor's autotune on A100 picks Triton templates for
the Qwen3.5 MoE dense projections (bf16 [M,2048]x[2048,N]), so the
shim is not exercised for this model. It serves as a safety net for
models or shapes where cuBLAS wins the autotune, ensuring fully
libtorch-free AOTI CUDA deployment in all cases.

Co-authored-by: Claude
---
 backends/cuda/CMakeLists.txt                  |  13 +-
 backends/cuda/runtime/shims/mm.cu             | 175 ++++++++++
 backends/cuda/runtime/shims/mm.h              |  41 +++
 .../cuda/runtime/shims/tests/CMakeLists.txt   |  18 +
 backends/cuda/runtime/shims/tests/targets.bzl |   1 +
 .../tests/test_aoti_torch_cuda_mm_out.cpp     | 315 ++++++++++++++++++
 6 files changed, 559 insertions(+), 4 deletions(-)
 create mode 100644 backends/cuda/runtime/shims/mm.cu
 create mode 100644 backends/cuda/runtime/shims/mm.h
 create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_mm_out.cpp

diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt
index 157cc05a54f..cf3a509e069 100644
--- a/backends/cuda/CMakeLists.txt
+++ b/backends/cuda/CMakeLists.txt
@@ -110,7 +110,7 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
 
 # Only build CUDA shims when CUDA language/toolchain is available.
 if(CMAKE_CUDA_COMPILER)
   list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
-    runtime/shims/sort.cu runtime/shims/rand.cu
+    runtime/shims/sort.cu runtime/shims/rand.cu runtime/shims/mm.cu
   )
 endif()
@@ -153,7 +153,7 @@ endif()
 if(_cuda_is_msvc_toolchain)
   target_link_libraries(
     aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
-                            ${CMAKE_DL_LIBS}
+                            CUDA::cublas ${CMAKE_DL_LIBS}
   )
   # Link object library directly so symbols are pulled exactly once while
   # avoiding duplicate static/object inclusion and interface leakage.
@@ -162,8 +162,13 @@ else()
   target_link_libraries(
     aoti_cuda_shims
     PRIVATE cuda_platform
-    PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
-           CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
+    PUBLIC -Wl,--whole-archive
+           aoti_common_shims_slim
+           -Wl,--no-whole-archive
+           CUDA::cudart
+           CUDA::curand
+           CUDA::cublas
+           ${CMAKE_DL_LIBS}
   )
 endif()
diff --git a/backends/cuda/runtime/shims/mm.cu b/backends/cuda/runtime/shims/mm.cu
new file mode 100644
index 00000000000..f555aabbf32
--- /dev/null
+++ b/backends/cuda/runtime/shims/mm.cu
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
+#include <executorch/backends/aoti/utils.h>
+#include <executorch/backends/cuda/runtime/guard.h>
+#include <executorch/backends/cuda/runtime/shims/mm.h>
+#include <executorch/backends/cuda/runtime/utils.h>
+
+#include <cinttypes>
+#include <mutex>
+
+namespace executorch::backends::cuda {
+
+namespace c10_slim = executorch::backends::aoti::slim::c10;
+
+namespace {
+
+constexpr int kMaxDevices = 16;
+
+struct CuBLASHandles {
+  std::mutex mutex;
+  cublasHandle_t handles[kMaxDevices] = {};
+  bool initialized[kMaxDevices] = {};
+
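+  // Lazily creates the handle for `device` on first use. The caller must
+  // already hold `mutex`: locking again here would self-deadlock, because
+  // mm_out keeps the mutex held across set-stream + GEMM. Handles are never
+  // destroyed; releasing them during static destruction can race with CUDA
+  // driver teardown, so they are simply leaked at process exit.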
+  cublasHandle_t get_locked(int device) {
+    if (!initialized[device]) {
+      cudaSetDevice(device);
+      cublasCreate(&handles[device]);
+      cublasSetMathMode(handles[device], CUBLAS_DEFAULT_MATH);
+      initialized[device] = true;
+    }
+    return handles[device];
+  }
+};
+
+CuBLASHandles& cublas_handles() {
+  static CuBLASHandles instance;
+  return instance;
+}
+
+} // namespace
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+AOTITorchError
+aoti_torch_cuda_mm_out(Tensor* out, Tensor* self, Tensor* mat2) {
+  ET_CHECK_OR_RETURN_ERROR(
+      out != nullptr, InvalidArgument, "mm_out: out is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      self != nullptr, InvalidArgument, "mm_out: self is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      mat2 != nullptr, InvalidArgument, "mm_out: mat2 is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      self->dim() == 2 && mat2->dim() == 2 && out->dim() == 2,
+      InvalidArgument,
+      "mm_out: all tensors must be 2D");
+  ET_CHECK_OR_RETURN_ERROR(
+      self->is_contiguous() && mat2->is_contiguous() && out->is_contiguous(),
+      InvalidArgument,
+      "mm_out: all tensors must be contiguous");
+
+  int64_t M = self->size(0);
+  int64_t K = self->size(1);
+  int64_t N = mat2->size(1);
+
+  ET_CHECK_OR_RETURN_ERROR(
+      mat2->size(0) == K,
+      InvalidArgument,
+      "mm_out: self [%" PRId64 ",%" PRId64 "] x mat2 [%" PRId64 ",%" PRId64
+      "] inner dims mismatch",
+      M,
+      K,
+      mat2->size(0),
+      N);
+  ET_CHECK_OR_RETURN_ERROR(
+      out->size(0) == M && out->size(1) == N,
+      InvalidArgument,
+      "mm_out: out shape mismatch");
+
+  auto dtype = self->dtype();
+  ET_CHECK_OR_RETURN_ERROR(
+      mat2->dtype() == dtype && out->dtype() == dtype,
+      InvalidArgument,
+      "mm_out: dtype mismatch");
+
+  cudaDataType_t cuda_dtype;
+  cublasComputeType_t compute_type;
+  if (dtype == c10_slim::ScalarType::BFloat16) {
+    cuda_dtype = CUDA_R_16BF;
+    compute_type = CUBLAS_COMPUTE_32F;
+  } else if (dtype == c10_slim::ScalarType::Half) {
+    cuda_dtype = CUDA_R_16F;
+    compute_type = CUBLAS_COMPUTE_32F;
+  } else if (dtype == c10_slim::ScalarType::Float) {
+    cuda_dtype = CUDA_R_32F;
+    compute_type = CUBLAS_COMPUTE_32F;
+  } else {
+    ET_CHECK_OR_RETURN_ERROR(
+        false, InvalidArgument, "mm_out: unsupported dtype");
+  }
+
+  int device = self->device_index();
+  ET_CHECK_OR_RETURN_ERROR(
+      device >= 0 && device < kMaxDevices,
+      InvalidArgument,
+      "mm_out: device index %d out of range",
+      device);
+
+  auto stream_result = getCurrentCUDAStream(device);
+  ET_CHECK_OR_RETURN_ERROR(
+      stream_result.ok(), Internal, "mm_out: failed to get CUDA stream");
+
+  // Per-device handle; a single mutex serializes lazy creation,
+  // cublasSetStream, and cublasGemmEx, so threads sharing a device cannot
+  // race between setting the stream and launching the GEMM.
+  auto& handles = cublas_handles();
+  std::lock_guard<std::mutex> lock(handles.mutex);
+  cublasHandle_t handle = handles.get_locked(device);
+  cublasSetStream(handle, stream_result.get());
+
+  // cuBLAS is column-major. For row-major C = A @ B:
+  //   C^T = B^T @ A^T
+  // With column-major interpretation of row-major data:
+  //   A_row[M,K] looks like A^T_col[K,M] with lda=K
+  //   B_row[K,N] looks like B^T_col[N,K] with ldb=N
+  //   C_row[M,N] looks like C^T_col[N,M] with ldc=N
+  // So: C^T = B^T @ A^T → gemm(N, N, N, M, K, B, N, A, K, C, N)
+  float alpha = 1.0f;
+  float beta = 0.0f;
+
+  auto status = cublasGemmEx(
+      handle,
+      CUBLAS_OP_N,
+      CUBLAS_OP_N,
+      N, // m (columns of C^T)
+      M, // n (rows of C^T)
+      K, // k
+      &alpha,
+      mat2->data_ptr(), // B^T in col-major = B in row-major
+      cuda_dtype,
+      N, // ldb (row-major stride of mat2)
+      self->data_ptr(), // A^T in col-major = A in row-major
+      cuda_dtype,
+      K, // lda (row-major stride of self)
+      &beta,
+      out->data_ptr(),
+      cuda_dtype,
+      N, // ldc (row-major stride of out)
+      compute_type,
+      CUBLAS_GEMM_DEFAULT);
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == CUBLAS_STATUS_SUCCESS,
+      Internal,
+      "mm_out: cublasGemmEx failed with status %d",
+      static_cast<int>(status));
+
+  return Error::Ok;
+}
+
+#ifdef __cplusplus
+}
+#endif
+
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/mm.h b/backends/cuda/runtime/shims/mm.h
new file mode 100644
index 00000000000..cd36868b411
--- /dev/null
+++ b/backends/cuda/runtime/shims/mm.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/aoti/common_shims_slim.h>
+#include <executorch/backends/aoti/export.h>
+#include <executorch/backends/aoti/utils.h>
+
+namespace executorch::backends::cuda {
+
+using executorch::backends::aoti::AOTITorchError;
+using executorch::backends::aoti::Tensor;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * Matrix multiplication via cuBLAS: out = self @ mat2.
+ *
+ * Replaces libtorch's aoti_torch_cuda_mm_out so the AOTI CUDA backend
+ * can run without libtorch_cuda.so. Calls cublasGemmEx directly.
+ *
+ * @param out  Pre-allocated output [M, N], same dtype as the inputs.
+ * @param self Input matrix [M, K]. Must be bf16, fp16, or fp32; 2D;
+ *             contiguous.
+ * @param mat2 Input matrix [K, N]. Same dtype and layout constraints.
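+ *
+ * Minimal call sketch (tensors created through the existing
+ * aoti_torch_empty_strided shim; error handling elided):
+ *
+ *   // out must be pre-allocated [M, N] on the same CUDA device.
+ *   AOTITorchError err = aoti_torch_cuda_mm_out(out, self, mat2);
+ *   // Error::Ok on success; the GEMM runs on the device's current stream.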
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_cuda_mm_out(Tensor* out, Tensor* self, Tensor* mat2);
+
+#ifdef __cplusplus
+}
+#endif
+
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt
index aec5219d680..a2cbae5d6d8 100644
--- a/backends/cuda/runtime/shims/tests/CMakeLists.txt
+++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt
@@ -67,3 +67,21 @@ foreach(test_name ${CUDA_SHIM_TESTS})
 
   add_test(NAME ${test_name} COMMAND ${test_name})
 endforeach()
+
+# mm_out test — cuBLAS is already linked into aoti_cuda_shims
+add_executable(test_aoti_torch_cuda_mm_out test_aoti_torch_cuda_mm_out.cpp)
+
+target_include_directories(
+  test_aoti_torch_cuda_mm_out PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}
+                                      ${CUDAToolkit_INCLUDE_DIRS}
+)
+
+target_compile_definitions(test_aoti_torch_cuda_mm_out PRIVATE CUDA_AVAILABLE=1)
+
+target_link_libraries(
+  test_aoti_torch_cuda_mm_out
+  PRIVATE GTest::gtest GTest::gtest_main aoti_cuda_shims executorch_core
+          CUDA::cudart
+)
+
+add_test(NAME test_aoti_torch_cuda_mm_out COMMAND test_aoti_torch_cuda_mm_out)
diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl
index b68043f7feb..5fec9c8fdcc 100644
--- a/backends/cuda/runtime/shims/tests/targets.bzl
+++ b/backends/cuda/runtime/shims/tests/targets.bzl
@@ -42,3 +42,4 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
     cuda_shim_cpp_unittest("aoti_torch_item_bool")
     cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
+    cuda_shim_cpp_unittest("aoti_torch_cuda_mm_out")
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_mm_out.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_mm_out.cpp
new file mode 100644
index 00000000000..ad6ac644de2
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_mm_out.cpp
@@ -0,0 +1,315 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+#include <cstdint>
+#include <vector>
+
+#include <executorch/backends/aoti/common_shims_slim.h>
+#include <executorch/backends/aoti/utils.h>
+#include <executorch/backends/cuda/runtime/shims/memory.h>
+#include <executorch/backends/cuda/runtime/shims/mm.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/platform.h>
+
+using executorch::backends::cuda::aoti_torch_cuda_mm_out;
+using executorch::backends::cuda::aoti_torch_delete_tensor_object;
+using executorch::backends::cuda::aoti_torch_empty_strided;
+using executorch::backends::cuda::AOTITorchError;
+using executorch::runtime::Error;
+namespace slim_c10 = executorch::backends::aoti::slim::c10;
+
+using Tensor = executorch::backends::aoti::slim::SlimTensor;
+
+// -- Dtype traits for templated tests ----------------------------------------
+
+template <typename T>
+struct DtypeTraits;
+
+template <>
+struct DtypeTraits<__nv_bfloat16> {
+  static constexpr slim_c10::ScalarType scalar_type =
+      slim_c10::ScalarType::BFloat16;
+  static __nv_bfloat16 from_float(float v) {
+    return __float2bfloat16(v);
+  }
+  static float to_float(__nv_bfloat16 v) {
+    return __bfloat162float(v);
+  }
+};
+
+template <>
+struct DtypeTraits<__half> {
+  static constexpr slim_c10::ScalarType scalar_type =
+      slim_c10::ScalarType::Half;
+  static __half from_float(float v) {
+    return __float2half(v);
+  }
+  static float to_float(__half v) {
+    return __half2float(v);
+  }
+};
+
+template <>
+struct DtypeTraits<float> {
+  static constexpr slim_c10::ScalarType scalar_type =
+      slim_c10::ScalarType::Float;
+  static float from_float(float v) {
+    return v;
+  }
+  static float to_float(float v) {
+    return v;
+  }
+};
+
+// -- Test fixture ------------------------------------------------------------
+
+template <typename T>
+class AOTITorchMmOutTypedTest : public ::testing::Test {
+ protected:
+  using Traits = DtypeTraits<T>;
+
+  void SetUp() override {
+    et_pal_init();
+    int device_count = 0;
+    if (cudaGetDeviceCount(&device_count) != cudaSuccess ||
+        device_count == 0) {
+      GTEST_SKIP() << "CUDA not available";
+    }
+  }
+
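+  // Allocates an uninitialized CUDA tensor on device 0 through the
+  // aoti_torch_empty_strided shim; passing null strides is taken to mean
+  // contiguous strides derived from sizes.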
+  Tensor* createTensor(const std::vector<int64_t>& sizes) {
+    Tensor* tensor = nullptr;
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        nullptr,
+        static_cast<int32_t>(Traits::scalar_type),
+        static_cast<int32_t>(slim_c10::DeviceType::CUDA),
+        0,
+        &tensor);
+    return (error == Error::Ok) ? tensor : nullptr;
+  }
+
+  // Small-integer reference: inputs are exactly representable in all dtypes,
+  // so cuBLAS output must match the serial f32 reference exactly.
+  void runExactTest(
+      int64_t M,
+      int64_t K,
+      int64_t N,
+      const std::vector<float>& h_A,
+      const std::vector<float>& h_B) {
+    Tensor* self = createTensor({M, K});
+    ASSERT_NE(self, nullptr);
+    Tensor* mat2 = createTensor({K, N});
+    ASSERT_NE(mat2, nullptr);
+    Tensor* out = createTensor({M, N});
+    ASSERT_NE(out, nullptr);
+
+    std::vector<T> d_A(M * K), d_B(K * N);
+    for (int64_t i = 0; i < M * K; i++)
+      d_A[i] = Traits::from_float(h_A[i]);
+    for (int64_t i = 0; i < K * N; i++)
+      d_B[i] = Traits::from_float(h_B[i]);
+
+    cudaMemcpy(
+        self->data_ptr(),
+        d_A.data(),
+        M * K * sizeof(T),
+        cudaMemcpyHostToDevice);
+    cudaMemcpy(
+        mat2->data_ptr(),
+        d_B.data(),
+        K * N * sizeof(T),
+        cudaMemcpyHostToDevice);
+
+    AOTITorchError error = aoti_torch_cuda_mm_out(out, self, mat2);
+    EXPECT_EQ(error, Error::Ok);
+    cudaDeviceSynchronize();
+
+    std::vector<T> h_out(M * N);
+    cudaMemcpy(
+        h_out.data(),
+        out->data_ptr(),
+        M * N * sizeof(T),
+        cudaMemcpyDeviceToHost);
+
+    // Serial f32 reference
+    for (int64_t i = 0; i < M; i++) {
+      for (int64_t j = 0; j < N; j++) {
+        float expected = 0.0f;
+        for (int64_t p = 0; p < K; p++) {
+          expected += h_A[i * K + p] * h_B[p * N + j];
+        }
+        float actual = Traits::to_float(h_out[i * N + j]);
+        EXPECT_EQ(actual, expected) << "Mismatch at [" << i << "," << j << "]";
+      }
+    }
+
+    EXPECT_EQ(aoti_torch_delete_tensor_object(self), Error::Ok);
+    EXPECT_EQ(aoti_torch_delete_tensor_object(mat2), Error::Ok);
+    EXPECT_EQ(aoti_torch_delete_tensor_object(out), Error::Ok);
+  }
+};
+
+using MmOutTestTypes = ::testing::Types<__nv_bfloat16, __half, float>;
+TYPED_TEST_SUITE(AOTITorchMmOutTypedTest, MmOutTestTypes);
+
+// -- Typed correctness tests (run for bf16, fp16, fp32) ----------------------
+// Use small integers so results are exact in all dtypes.
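+// Why exact equality holds: bf16 keeps an 8-bit significand (integers up to
+// 2^8 are exact) and fp16 an 11-bit one (up to 2^11), while cuBLAS
+// accumulates in fp32 (CUBLAS_COMPUTE_32F). Every expected value below is
+// either a small integer inside those bounds or a power of two (AllOnes:
+// 2048 = 2^11), so the single rounding back to T is lossless.
+
+TYPED_TEST(AOTITorchMmOutTypedTest, OddShapes) {
+  // Sketch of an extra case: non-power-of-two dims exercise unaligned
+  // leading dimensions; values stay small so the exact check still holds.
+  int64_t M = 3, K = 7, N = 5;
+  std::vector<float> h_A(M * K), h_B(K * N);
+  for (int64_t i = 0; i < M * K; i++)
+    h_A[i] = static_cast<float>((i % 3) + 1);
+  for (int64_t i = 0; i < K * N; i++)
+    h_B[i] = static_cast<float>((i % 2) + 1);
+  this->runExactTest(M, K, N, h_A, h_B);
+}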
+
+TYPED_TEST(AOTITorchMmOutTypedTest, SmallSquare) {
+  int64_t M = 4, K = 8, N = 6;
+  std::vector<float> h_A(M * K), h_B(K * N);
+  for (int64_t i = 0; i < M * K; i++)
+    h_A[i] = static_cast<float>((i % 5) + 1);
+  for (int64_t i = 0; i < K * N; i++)
+    h_B[i] = static_cast<float>((i % 3) + 1);
+  this->runExactTest(M, K, N, h_A, h_B);
+}
+
+TYPED_TEST(AOTITorchMmOutTypedTest, SingleRow) {
+  int64_t M = 1, K = 16, N = 8;
+  std::vector<float> h_A(M * K), h_B(K * N);
+  for (int64_t i = 0; i < M * K; i++)
+    h_A[i] = static_cast<float>((i % 4) + 1);
+  for (int64_t i = 0; i < K * N; i++)
+    h_B[i] = static_cast<float>((i % 3) + 1);
+  this->runExactTest(M, K, N, h_A, h_B);
+}
+
+TYPED_TEST(AOTITorchMmOutTypedTest, AllOnes) {
+  int64_t M = 1, K = 2048, N = 256;
+  std::vector<float> h_A(M * K, 1.0f), h_B(K * N, 1.0f);
+  this->runExactTest(M, K, N, h_A, h_B);
+}
+
+TYPED_TEST(AOTITorchMmOutTypedTest, Identity) {
+  int64_t N = 32;
+  std::vector<float> h_A(N * N, 0.0f), h_B(N * N, 0.0f);
+  for (int64_t i = 0; i < N; i++) {
+    h_A[i * N + i] = 1.0f;
+    h_B[i * N + i] = static_cast<float>(i + 1);
+  }
+  this->runExactTest(N, N, N, h_A, h_B);
+}
+
+// -- Non-typed tests (contract validation) -----------------------------------
+
+class AOTITorchMmOutTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    et_pal_init();
+    int device_count = 0;
+    if (cudaGetDeviceCount(&device_count) != cudaSuccess ||
+        device_count == 0) {
+      GTEST_SKIP() << "CUDA not available";
+    }
+  }
+
+  Tensor* createTensor(
+      const std::vector<int64_t>& sizes,
+      slim_c10::ScalarType dtype) {
+    Tensor* tensor = nullptr;
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        nullptr,
+        static_cast<int32_t>(dtype),
+        static_cast<int32_t>(slim_c10::DeviceType::CUDA),
+        0,
+        &tensor);
+    return (error == Error::Ok) ? tensor : nullptr;
+  }
+};
+
+TEST_F(AOTITorchMmOutTest, InnerDimensionMismatch) {
+  Tensor* self = createTensor({4, 8}, slim_c10::ScalarType::Float);
+  Tensor* mat2 = createTensor({6, 6}, slim_c10::ScalarType::Float);
+  Tensor* out = createTensor({4, 6}, slim_c10::ScalarType::Float);
+  EXPECT_EQ(aoti_torch_cuda_mm_out(out, self, mat2), Error::InvalidArgument);
+  aoti_torch_delete_tensor_object(self);
+  aoti_torch_delete_tensor_object(mat2);
+  aoti_torch_delete_tensor_object(out);
+}
+
+TEST_F(AOTITorchMmOutTest, NullOut) {
+  Tensor* self = createTensor({4, 8}, slim_c10::ScalarType::Float);
+  Tensor* mat2 = createTensor({8, 6}, slim_c10::ScalarType::Float);
+  EXPECT_EQ(
+      aoti_torch_cuda_mm_out(nullptr, self, mat2), Error::InvalidArgument);
+  aoti_torch_delete_tensor_object(self);
+  aoti_torch_delete_tensor_object(mat2);
+}
+
+TEST_F(AOTITorchMmOutTest, NullSelf) {
+  Tensor* mat2 = createTensor({8, 6}, slim_c10::ScalarType::Float);
+  Tensor* out = createTensor({4, 6}, slim_c10::ScalarType::Float);
+  EXPECT_EQ(aoti_torch_cuda_mm_out(out, nullptr, mat2), Error::InvalidArgument);
+  aoti_torch_delete_tensor_object(mat2);
+  aoti_torch_delete_tensor_object(out);
+}
+
+TEST_F(AOTITorchMmOutTest, NullMat2) {
+  Tensor* self = createTensor({4, 8}, slim_c10::ScalarType::Float);
+  Tensor* out = createTensor({4, 6}, slim_c10::ScalarType::Float);
+  EXPECT_EQ(aoti_torch_cuda_mm_out(out, self, nullptr), Error::InvalidArgument);
+  aoti_torch_delete_tensor_object(self);
+  aoti_torch_delete_tensor_object(out);
+}
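+
+TEST_F(AOTITorchMmOutTest, NonTwoDRejected) {
+  // Supplementary sketch: the shim's contract is strictly 2D, so a 1D
+  // `self` must be rejected before any cuBLAS call is made.
+  Tensor* self = createTensor({8}, slim_c10::ScalarType::Float);
+  Tensor* mat2 = createTensor({8, 6}, slim_c10::ScalarType::Float);
+  Tensor* out = createTensor({1, 6}, slim_c10::ScalarType::Float);
+  EXPECT_EQ(aoti_torch_cuda_mm_out(out, self, mat2), Error::InvalidArgument);
+  aoti_torch_delete_tensor_object(self);
+  aoti_torch_delete_tensor_object(mat2);
+  aoti_torch_delete_tensor_object(out);
+}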
+
+TEST_F(AOTITorchMmOutTest, DtypeMismatch) {
+  Tensor* self = createTensor({4, 8}, slim_c10::ScalarType::Float);
+  Tensor* mat2 = createTensor({8, 6}, slim_c10::ScalarType::BFloat16);
+  Tensor* out = createTensor({4, 6}, slim_c10::ScalarType::Float);
+  EXPECT_EQ(aoti_torch_cuda_mm_out(out, self, mat2), Error::InvalidArgument);
+  aoti_torch_delete_tensor_object(self);
+  aoti_torch_delete_tensor_object(mat2);
+  aoti_torch_delete_tensor_object(out);
+}
+
+TEST_F(AOTITorchMmOutTest, NonContiguousRejected) {
+  // Build a non-contiguous [4, 8] tensor directly: stride(0) = 16 while
+  // size(1) = 8, so each row is followed by a gap and the layout is not
+  // contiguous.
+  int64_t nc_sizes[] = {4, 8};
+  int64_t nc_strides[] = {16, 1};
+  Tensor* nc = nullptr;
+  aoti_torch_empty_strided(
+      2,
+      nc_sizes,
+      nc_strides,
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      static_cast<int32_t>(slim_c10::DeviceType::CUDA),
+      0,
+      &nc);
+  ASSERT_NE(nc, nullptr);
+
+  Tensor* mat2 = createTensor({8, 6}, slim_c10::ScalarType::Float);
+  Tensor* out = createTensor({4, 6}, slim_c10::ScalarType::Float);
+
+  EXPECT_EQ(aoti_torch_cuda_mm_out(out, nc, mat2), Error::InvalidArgument);
+
+  aoti_torch_delete_tensor_object(nc);
+  aoti_torch_delete_tensor_object(mat2);
+  aoti_torch_delete_tensor_object(out);
+}