diff --git a/.github/.licenserc.yaml b/.github/.licenserc.yaml index 7cae0725..af8f72e3 100644 --- a/.github/.licenserc.yaml +++ b/.github/.licenserc.yaml @@ -22,6 +22,7 @@ header: - docs/Makefile - LICENSE - cmake/config.hpp.in + - cmake/FindcuTENSOR.cmake - docs/requirements.txt - docs/source/bibliography/*.bib - version.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 66157314..9f7f4ddd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,16 +29,34 @@ include(get_cmaize) set(project_inc_dir "${CMAKE_CURRENT_LIST_DIR}/include/${PROJECT_NAME}") set(project_src_dir "${CMAKE_CURRENT_LIST_DIR}/src/${PROJECT_NAME}") +# Documentation include(nwx_cxx_api_docs) nwx_cxx_api_docs("${project_inc_dir}" "${project_src_dir}") +## Extensions ## +set(SOURCE_EXTS "cpp") +set(INCLUDE_EXTS "hpp") + ### Options ### cmaize_option_list( BUILD_TESTING OFF "Should we build the tests?" BUILD_PYBIND11_PYBINDINGS ON "Should we build Python3 bindings?" ENABLE_SIGMA OFF "Should we enable Sigma for uncertainty tracking?" + ENABLE_CUTENSOR OFF "Should we enable cuTENSOR?" ) +if("${ENABLE_CUTENSOR}") + if("${ENABLE_SIGMA}") + set(MSG "Sigma is not compatible with cuTENSOR. Turning Sigma OFF.") + message(WARNING ${MSG}) + set(ENABLE_SIGMA OFF) + endif() + enable_language(CUDA) + set(SOURCE_EXTS ${SOURCE_EXTS} cu) + set(INCLUDE_EXTS ${INCLUDE_EXTS} cuh hu) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") +endif() + ### Dependendencies ### include(get_utilities) @@ -64,15 +82,27 @@ cmaize_find_or_build_optional_dependency( CMAKE_ARGS BUILD_TESTING=OFF ENABLE_EIGEN_SUPPORT=ON ) +set(DEPENDENCIES utilities parallelzone Boost::boost eigen sigma) + +if("${ENABLE_CUTENSOR}") + include(cmake/FindcuTENSOR.cmake) + list(APPEND DEPENDENCIES cuTENSOR::cuTENSOR) +endif() cmaize_add_library( ${PROJECT_NAME} SOURCE_DIR "${project_src_dir}" + SOURCE_EXTS "${SOURCE_EXTS}" INCLUDE_DIRS "${project_inc_dir}" - DEPENDS utilities parallelzone Boost::boost eigen sigma + INCLUDE_EXTS "${INCLUDE_EXTS}" + DEPENDS "${DEPENDENCIES}" ) target_include_directories(${PROJECT_NAME} PUBLIC "${CMAKE_CURRENT_BINARY_DIR}") +if("${ENABLE_CUTENSOR}") + target_compile_definitions("${PROJECT_NAME}" PUBLIC ENABLE_CUTENSOR) +endif() + include(nwx_pybind11) nwx_add_pybind11_module( ${PROJECT_NAME} diff --git a/cmake/FindcuTENSOR.cmake b/cmake/FindcuTENSOR.cmake new file mode 100644 index 00000000..17b14b48 --- /dev/null +++ b/cmake/FindcuTENSOR.cmake @@ -0,0 +1,135 @@ +#============================================================================= +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#============================================================================= + +#[=======================================================================[.rst: +FindcuTENSOR +-------- + +Find cuTENSOR + +Imported targets +^^^^^^^^^^^^^^^^ + +This module defines the following :prop_tgt:`IMPORTED` target(s): + +``cuTENSOR::cuTENSOR`` + The cuTENSOR library, if found. + +Result variables +^^^^^^^^^^^^^^^^ + +This module will set the following variables in your project: + +``cuTENSOR_FOUND`` + True if cuTENSOR is found. +``cuTENSOR_INCLUDE_DIRS`` + The include directories needed to use cuTENSOR. +``cuTENSOR_LIBRARIES`` + The libraries needed to usecuTENSOR. +``cuTENSOR_VERSION_STRING`` + The version of the cuTENSOR library found. [OPTIONAL] + +#]=======================================================================] + +# Prefer using a Config module if it exists for this project +set(cuTENSOR_NO_CONFIG FALSE) +if(NOT cuTENSOR_NO_CONFIG) + find_package(cuTENSOR CONFIG QUIET HINTS ${cutensor_DIR}) + if(cuTENSOR_FOUND) + find_package_handle_standard_args(cuTENSOR DEFAULT_MSG cuTENSOR_CONFIG) + return() + endif() +endif() + +find_path(cuTENSOR_INCLUDE_DIR NAMES cutensor.h ) + +set(cuTENSOR_IS_HEADER_ONLY FALSE) +if(NOT cuTENSOR_LIBRARY AND NOT cuTENSOR_IS_HEADER_ONLY) + find_library(cuTENSOR_LIBRARY_RELEASE NAMES libcutensor.so NAMES_PER_DIR ) + find_library(cuTENSOR_LIBRARY_DEBUG NAMES libcutensor.sod NAMES_PER_DIR ) + + include(${CMAKE_ROOT}/Modules/SelectLibraryConfigurations.cmake) + select_library_configurations(cuTENSOR) + unset(cuTENSOR_FOUND) #incorrectly set by select_library_configurations +endif() + +include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) + +if(cuTENSOR_IS_HEADER_ONLY) + find_package_handle_standard_args(cuTENSOR + REQUIRED_VARS cuTENSOR_INCLUDE_DIR + VERSION_VAR ) +else() + find_package_handle_standard_args(cuTENSOR + REQUIRED_VARS cuTENSOR_LIBRARY cuTENSOR_INCLUDE_DIR + VERSION_VAR ) +endif() + +if(NOT cuTENSOR_FOUND) + set(CUTENSOR_FILENAME libcutensor-linux-x86_64-${CUTENSOR_VERSION}-archive) + + message(STATUS "cuTENSOR not found. Downloading library. By continuing this download you accept to the license terms of cuTENSOR") + + CPMAddPackage( + NAME cutensor + VERSION ${CUTENSOR_VERSION} + URL https://developer.download.nvidia.com/compute/cutensor/redist/libcutensor/linux-x86_64/libcutensor-linux-x86_64-${CUTENSOR_VERSION}-archive.tar.xz + # Eigen's CMakelists are not intended for library use + DOWNLOAD_ONLY YES + ) + + set(cuTENSOR_LIBRARY ${cutensor_SOURCE_DIR}/lib/${CUDAToolkit_VERSION_MAJOR}/libcutensor.so) + set(cuTENSOR_INCLUDE_DIR ${cutensor_SOURCE_DIR}/include) + + + set(cuTENSOR_FOUND TRUE) +endif() + +if(cuTENSOR_FOUND) + set(cuTENSOR_INCLUDE_DIRS ${cuTENSOR_INCLUDE_DIR}) + + if(NOT cuTENSOR_LIBRARIES) + set(cuTENSOR_LIBRARIES ${cuTENSOR_LIBRARY}) + endif() + + if(NOT TARGET cuTENSOR::cuTENSOR) + add_library(cuTENSOR::cuTENSOR UNKNOWN IMPORTED) + set_target_properties(cuTENSOR::cuTENSOR PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${cuTENSOR_INCLUDE_DIRS}") + + if(cuTENSOR_LIBRARY_RELEASE) + set_property(TARGET cuTENSOR::cuTENSOR APPEND PROPERTY + IMPORTED_CONFIGURATIONS RELEASE) + set_target_properties(cuTENSOR::cuTENSOR PROPERTIES + IMPORTED_LOCATION_RELEASE "${cuTENSOR_LIBRARY_RELEASE}") + endif() + + if(cuTENSOR_LIBRARY_DEBUG) + set_property(TARGET cuTENSOR::cuTENSOR APPEND PROPERTY + IMPORTED_CONFIGURATIONS DEBUG) + set_target_properties(cuTENSOR::cuTENSOR PROPERTIES + IMPORTED_LOCATION_DEBUG "${cuTENSOR_LIBRARY_DEBUG}") + endif() + + if(NOT cuTENSOR_LIBRARY_RELEASE AND NOT cuTENSOR_LIBRARY_DEBUG) + set_property(TARGET cuTENSOR::cuTENSOR APPEND PROPERTY + IMPORTED_LOCATION "${cuTENSOR_LIBRARY}") + endif() + endif() +endif() + +unset(cuTENSOR_NO_CONFIG) +unset(cuTENSOR_IS_HEADER_ONLY) \ No newline at end of file diff --git a/src/tensorwrapper/buffer/detail_/cutensor_traits.cuh b/src/tensorwrapper/buffer/detail_/cutensor_traits.cuh new file mode 100644 index 00000000..005039bd --- /dev/null +++ b/src/tensorwrapper/buffer/detail_/cutensor_traits.cuh @@ -0,0 +1,41 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#ifdef ENABLE_CUTENSOR +#include +#include + +namespace tensorwrapper::buffer::detail_ { + +// Traits for cuTENSOR based on the floating point type +template +struct cutensor_traits {}; + +template<> +struct cutensor_traits { + cutensorDataType_t cutensorDataType = CUTENSOR_R_32F; + cutensorComputeDescriptor_t descCompute = CUTENSOR_COMPUTE_DESC_32F; +}; + +template<> +struct cutensor_traits { + cutensorDataType_t cutensorDataType = CUTENSOR_R_64F; + cutensorComputeDescriptor_t descCompute = CUTENSOR_COMPUTE_DESC_64F; +}; + +} // namespace tensorwrapper::buffer::detail_ + +#endif \ No newline at end of file diff --git a/src/tensorwrapper/buffer/detail_/eigen_tensor.cpp b/src/tensorwrapper/buffer/detail_/eigen_tensor.cpp index 0d66206f..7c08a5ee 100644 --- a/src/tensorwrapper/buffer/detail_/eigen_tensor.cpp +++ b/src/tensorwrapper/buffer/detail_/eigen_tensor.cpp @@ -18,6 +18,10 @@ #include "../contraction_planner.hpp" #include "eigen_tensor.hpp" +#ifdef ENABLE_CUTENSOR +#include "eigen_tensor.cuh" +#endif + namespace tensorwrapper::buffer::detail_ { #define TPARAMS template @@ -96,6 +100,16 @@ void EIGEN_TENSOR::contraction_assignment_(label_type olabels, const_pimpl_reference rhs) { ContractionPlanner plan(olabels, llabels, rlabels); +#ifdef ENABLE_CUTENSOR + // Prepare m_tensor_ + m_tensor_ = allocate_from_shape_(result_shape.as_smooth(), + std::make_index_sequence()); + m_tensor_.setZero(); + + // Dispatch to cuTENSOR + cutensor_contraction(olabels, llabels, rlabels, result_shape, lhs, + rhs, m_tensor_); +#else auto lt = lhs.clone(); auto rt = rhs.clone(); lt->permute_assignment(plan.lhs_permutation(), llabels, lhs); @@ -140,6 +154,7 @@ void EIGEN_TENSOR::contraction_assignment_(label_type olabels, } else { m_tensor_ = tensor; } +#endif mark_for_rehash_(); } diff --git a/src/tensorwrapper/buffer/detail_/eigen_tensor.cu b/src/tensorwrapper/buffer/detail_/eigen_tensor.cu new file mode 100644 index 00000000..516abff1 --- /dev/null +++ b/src/tensorwrapper/buffer/detail_/eigen_tensor.cu @@ -0,0 +1,265 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_CUTENSOR +#include "cutensor_traits.cuh" +#include "eigen_tensor.cuh" +#include +#include + +namespace tensorwrapper::buffer::detail_ { + +// Handle cuda errors +#define HANDLE_CUDA_ERROR(x) \ + { \ + const auto err = x; \ + if(err != cudaSuccess) { \ + printf("Error: %s\n", cudaGetErrorString(err)); \ + exit(-1); \ + } \ + }; + +// Handle cuTENSOR errors +#define HANDLE_CUTENSOR_ERROR(x) \ + { \ + const auto err = x; \ + if(err != CUTENSOR_STATUS_SUCCESS) { \ + printf("Error: %s\n", cutensorGetErrorString(err)); \ + exit(-1); \ + } \ + }; + +// Some common typedefs +using mode_vector_t = std::vector; +using int64_vector_t = std::vector; + +// Convert a label into a vector of modes +template +mode_vector_t label_to_modes(const LabelType& label) { + mode_vector_t mode; + for(const auto& i : label) { mode.push_back(i.data()[0]); } + return mode; +} + +// Query extent information from an input +template +int64_vector_t get_extents(const InfoType& info) { + int64_vector_t extent; + for(std::size_t i = 0; i < info.rank(); ++i) { + extent.push_back((int64_t)info.extent(i)); + } + return extent; +} + +// Compute strides in row major +int64_vector_t get_strides(std::size_t N, const int64_vector_t& extent) { + int64_vector_t strides; + for(std::size_t i = 0; i < N; ++i) { + int64_t product = 1; + for(std::size_t j = N - 1; j > i; --j) product *= extent[j]; + strides.push_back(product); + } + return strides; +} + +// Perform tensor contraction with cuTENSOR +template +void cutensor_contraction(typename TensorType::label_type c_label, + typename TensorType::label_type a_label, + typename TensorType::label_type b_label, + typename TensorType::const_shape_reference c_shape, + typename TensorType::const_pimpl_reference A, + typename TensorType::const_pimpl_reference B, + typename TensorType::eigen_reference C) { + using element_t = typename TensorType::element_type; + using eigen_data_t = typename TensorType::eigen_data_type; + + // GEMM alpha and beta (hardcoded for now) + element_t alpha = 1.0; + element_t beta = 0.0; + + // The modes of the tensors + mode_vector_t a_modes = label_to_modes(a_label); + mode_vector_t b_modes = label_to_modes(b_label); + mode_vector_t c_modes = label_to_modes(c_label); + + // The extents of each tensor + int64_vector_t a_extents = get_extents(A); + int64_vector_t b_extents = get_extents(B); + int64_vector_t c_extents = get_extents(c_shape.as_smooth()); + + // The strides of each tensor + int64_vector_t a_strides = get_strides(A.rank(), a_extents); + int64_vector_t b_strides = get_strides(B.rank(), b_extents); + int64_vector_t c_strides = get_strides(c_shape.rank(), c_extents); + + // The size of each tensor + std::size_t a_size = sizeof(element_t) * A.size(); + std::size_t b_size = sizeof(element_t) * B.size(); + std::size_t c_size = sizeof(element_t) * c_shape.size(); + + // Allocate on device + void *A_d, *B_d, *C_d; + cudaMalloc((void**)&A_d, a_size); + cudaMalloc((void**)&B_d, b_size); + cudaMalloc((void**)&C_d, c_size); + + // Copy to data to device + HANDLE_CUDA_ERROR( + cudaMemcpy(A_d, A.get_immutable_data(), a_size, cudaMemcpyHostToDevice)); + HANDLE_CUDA_ERROR( + cudaMemcpy(B_d, B.get_immutable_data(), b_size, cudaMemcpyHostToDevice)); + HANDLE_CUDA_ERROR( + cudaMemcpy(C_d, C.data(), c_size, cudaMemcpyHostToDevice)); + + // Assert alignment + const uint32_t kAlignment = + 128; // Alignment of the global-memory device pointers (bytes) + assert(uintptr_t(A_d) % kAlignment == 0); + assert(uintptr_t(B_d) % kAlignment == 0); + assert(uintptr_t(C_d) % kAlignment == 0); + + // cuTENSOR traits + cutensor_traits traits; + + // cuTENSOR handle + cutensorHandle_t handle; + HANDLE_CUTENSOR_ERROR(cutensorCreate(&handle)); + + // Create Tensor Descriptors + cutensorTensorDescriptor_t descA; + HANDLE_CUTENSOR_ERROR(cutensorCreateTensorDescriptor( + handle, &descA, A.rank(), a_extents.data(), a_strides.data(), + traits.cutensorDataType, kAlignment)); + + cutensorTensorDescriptor_t descB; + HANDLE_CUTENSOR_ERROR(cutensorCreateTensorDescriptor( + handle, &descB, B.rank(), b_extents.data(), b_strides.data(), + traits.cutensorDataType, kAlignment)); + + cutensorTensorDescriptor_t descC; + HANDLE_CUTENSOR_ERROR(cutensorCreateTensorDescriptor( + handle, &descC, c_shape.rank(), c_extents.data(), c_strides.data(), + traits.cutensorDataType, kAlignment)); + + // Create Contraction Descriptor + cutensorOperationDescriptor_t desc; + HANDLE_CUTENSOR_ERROR(cutensorCreateContraction( + handle, &desc, // Base + descA, a_modes.data(), CUTENSOR_OP_IDENTITY, // A + descB, b_modes.data(), CUTENSOR_OP_IDENTITY, // B + descC, c_modes.data(), CUTENSOR_OP_IDENTITY, // C + descC, c_modes.data(), traits.descCompute // Result + )); + + // Ensure that the scalar type is correct. + cutensorDataType_t scalarType; + HANDLE_CUTENSOR_ERROR(cutensorOperationDescriptorGetAttribute( + handle, desc, CUTENSOR_OPERATION_DESCRIPTOR_SCALAR_TYPE, + (void*)&scalarType, sizeof(scalarType))); + assert(scalarType == traits.cutensorDataType); + + // Set the algorithm to use + const cutensorAlgo_t algo = CUTENSOR_ALGO_DEFAULT; + cutensorPlanPreference_t planPref; + HANDLE_CUTENSOR_ERROR(cutensorCreatePlanPreference(handle, &planPref, algo, + CUTENSOR_JIT_MODE_NONE)); + + // Query workspace estimate + uint64_t workspaceSizeEstimate = 0; + const cutensorWorksizePreference_t workspacePref = + CUTENSOR_WORKSPACE_DEFAULT; + HANDLE_CUTENSOR_ERROR(cutensorEstimateWorkspaceSize( + handle, desc, planPref, workspacePref, &workspaceSizeEstimate)); + + // Create Contraction Plan + cutensorPlan_t plan; + HANDLE_CUTENSOR_ERROR( + cutensorCreatePlan(handle, &plan, desc, planPref, workspaceSizeEstimate)); + + // Determine workspace size and allocate + uint64_t actualWorkspaceSize = 0; + HANDLE_CUTENSOR_ERROR(cutensorPlanGetAttribute( + handle, plan, CUTENSOR_PLAN_REQUIRED_WORKSPACE, &actualWorkspaceSize, + sizeof(actualWorkspaceSize))); + assert(actualWorkspaceSize <= workspaceSizeEstimate); + + void* work = nullptr; + if(actualWorkspaceSize > 0) { + HANDLE_CUDA_ERROR(cudaMalloc(&work, actualWorkspaceSize)); + assert(uintptr_t(work) % 128 == + 0); // workspace must be aligned to 128 byte-boundary + } + + // Execute + cudaStream_t stream; + HANDLE_CUDA_ERROR(cudaStreamCreate(&stream)); + HANDLE_CUTENSOR_ERROR(cutensorContract(handle, plan, (void*)&alpha, A_d, + B_d, (void*)&beta, C_d, C_d, work, + actualWorkspaceSize, stream)); + + // Copy Results from Device + HANDLE_CUDA_ERROR( + cudaMemcpy(C.data(), C_d, c_size, cudaMemcpyDeviceToHost)); + + // Free allocated memory + HANDLE_CUTENSOR_ERROR(cutensorDestroy(handle)); + HANDLE_CUTENSOR_ERROR(cutensorDestroyPlan(plan)); + HANDLE_CUTENSOR_ERROR(cutensorDestroyOperationDescriptor(desc)); + HANDLE_CUTENSOR_ERROR(cutensorDestroyTensorDescriptor(descA)); + HANDLE_CUTENSOR_ERROR(cutensorDestroyTensorDescriptor(descB)); + HANDLE_CUTENSOR_ERROR(cutensorDestroyTensorDescriptor(descC)); + HANDLE_CUDA_ERROR(cudaStreamDestroy(stream)); + if(A_d) cudaFree(A_d); + if(B_d) cudaFree(B_d); + if(C_d) cudaFree(C_d); + if(work) cudaFree(work); +} + +#undef HANDLE_CUTENSOR_ERROR +#undef HANDLE_CUDA_ERROR + +// Template instantiations +#define FUNCTION_INSTANTIATE(TYPE, RANK) \ + template void cutensor_contraction>( \ + typename EigenTensor::label_type, \ + typename EigenTensor::label_type, \ + typename EigenTensor::label_type, \ + typename EigenTensor::const_shape_reference, \ + typename EigenTensor::const_pimpl_reference, \ + typename EigenTensor::const_pimpl_reference, \ + typename EigenTensor::eigen_reference) + +#define DEFINE_CUTENSOR_CONTRACTION(TYPE) \ + FUNCTION_INSTANTIATE(TYPE, 0); \ + FUNCTION_INSTANTIATE(TYPE, 1); \ + FUNCTION_INSTANTIATE(TYPE, 2); \ + FUNCTION_INSTANTIATE(TYPE, 3); \ + FUNCTION_INSTANTIATE(TYPE, 4); \ + FUNCTION_INSTANTIATE(TYPE, 5); \ + FUNCTION_INSTANTIATE(TYPE, 6); \ + FUNCTION_INSTANTIATE(TYPE, 7); \ + FUNCTION_INSTANTIATE(TYPE, 8); \ + FUNCTION_INSTANTIATE(TYPE, 9); \ + FUNCTION_INSTANTIATE(TYPE, 10) + +TW_APPLY_FLOATING_POINT_TYPES(DEFINE_CUTENSOR_CONTRACTION); + +#undef DEFINE_CUTENSOR_CONTRACTION +#undef FUNCTION_INSTANTIATE + +} // namespace tensorwrapper::buffer::detail_ + +#endif \ No newline at end of file diff --git a/src/tensorwrapper/buffer/detail_/eigen_tensor.cuh b/src/tensorwrapper/buffer/detail_/eigen_tensor.cuh new file mode 100644 index 00000000..6e026667 --- /dev/null +++ b/src/tensorwrapper/buffer/detail_/eigen_tensor.cuh @@ -0,0 +1,46 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#ifdef ENABLE_CUTENSOR +#include "eigen_tensor.hpp" + +namespace tensorwrapper::buffer::detail_ { + +/** @brief Performs a tensor contraction on GPU + * + * @param[in] olabel The labels for the modes of the output. + * @param[in] llabel The labels for the modes of the left hand tensor. + * @param[in] rlabel The labels for the modes of the right hand tensor. + * @param[in] result_shape The intended shape of the result. + * @param[in] lhs The left hand tensor. + * @param[in] rhs The right hand tensor. + * @param[in, out] result The eigen tensor where the results are stored. + * + * @throw std::bad_alloc if there is a problem allocating the copy of + * @p layout. Strong throw guarantee. + */ +template +void cutensor_contraction(typename TensorType::label_type c_label, + typename TensorType::label_type a_label, + typename TensorType::label_type b_label, + typename TensorType::const_shape_reference c_shape, + typename TensorType::const_pimpl_reference A, + typename TensorType::const_pimpl_reference B, + typename TensorType::eigen_reference C); + +} // namespace tensorwrapper::buffer::detail_ + +#endif \ No newline at end of file