From f4aedcda9307f70597fb8a56ce23b6f5940a9e46 Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Fri, 17 May 2024 15:01:16 -0700 Subject: [PATCH 01/12] adding enzyme to serac, trying to bring it in to Functional kernels --- CMakeLists.txt | 5 + src/serac/numerics/functional/CMakeLists.txt | 1 + .../functional/domain_integral_kernels.hpp | 98 +++++++++++++------ .../functional/enzyme_declarations.hpp | 12 +++ .../numerics/functional/tests/CMakeLists.txt | 4 + .../numerics/functional/tests/enzyme_test.cpp | 17 ++++ 6 files changed, 106 insertions(+), 31 deletions(-) create mode 100644 src/serac/numerics/functional/enzyme_declarations.hpp create mode 100644 src/serac/numerics/functional/tests/enzyme_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6657266a5e..5cdd182e16 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,11 @@ if(ENABLE_CUDA AND ${CMAKE_VERSION} VERSION_LESS 3.18.0) message(FATAL_ERROR "Serac requires CMake version 3.18.0+ when CUDA is enabled.") endif() +# N.B. leave compilers unspecified when configuring CMake to ensure that +# the clang/clang++ binaries appropriate for enzyme are chosen +set(CMAKE_C_COMPILER "${LLVM_BIN_DIR}/clang") +set(CMAKE_CXX_COMPILER "${LLVM_BIN_DIR}/clang++") + project(serac LANGUAGES CXX C) # MPI is required in Serac. 
diff --git a/src/serac/numerics/functional/CMakeLists.txt b/src/serac/numerics/functional/CMakeLists.txt index b7edf37419..e1e0749181 100644 --- a/src/serac/numerics/functional/CMakeLists.txt +++ b/src/serac/numerics/functional/CMakeLists.txt @@ -71,6 +71,7 @@ blt_add_library( DEPENDS_ON ${functional_depends} ) +target_link_libraries(serac_functional PUBLIC ClangEnzymeFlags) install(FILES ${functional_headers} DESTINATION include/serac/numerics/functional ) install(FILES ${functional_detail_headers} DESTINATION include/serac/numerics/functional/detail ) diff --git a/src/serac/numerics/functional/domain_integral_kernels.hpp b/src/serac/numerics/functional/domain_integral_kernels.hpp index 9f13cd42fb..128f74982c 100644 --- a/src/serac/numerics/functional/domain_integral_kernels.hpp +++ b/src/serac/numerics/functional/domain_integral_kernels.hpp @@ -102,25 +102,76 @@ auto get_derivative_type(lambda qf, qpt_data_type&& qpt_data) make_dual_wrt(qf_arguments{}))); }; -template + +template +SERAC_HOST_DEVICE auto parent_to_physical(const tuple& qf_input, const tensor & invJ) { + return tuple{get<0>(qf_input), dot(get<1>(qf_input), invJ)}; +} + +template +SERAC_HOST_DEVICE auto physical_to_parent(const tuple& qf_output, const tensor & invJ, double detJ) { + // assumes family == Family::H1 for now + return tuple{get<0>(qf_output) * detJ, dot(get<1>(qf_output), transpose(invJ)) * detJ}; +} + +double scalar_type(double) { return 0.0; } + +template < typename T > +dual scalar_type(dual) { return dual{}; } + +template < typename T, int ... n > +T scalar_type(tensor) { return {}; } + +template < typename T1, typename T2 > +auto scalar_type(tuple< T1, T2 >) { return tuple{ scalar_type(T1{}), scalar_type(T2{}) }; } + +template < uint32_t differentiation_index, typename lambda, int dim, int n, typename... T> SERAC_HOST_DEVICE auto batch_apply_qf_no_qdata(lambda qf, double t, const tensor& x, const tensor& J, const T&... 
inputs) { - using position_t = serac::tuple, tensor >; - using return_type = decltype(qf(double{}, position_t{}, T{}[0]...)); - tensor outputs{}; - for (int i = 0; i < n; i++) { - tensor x_q; - tensor J_q; - for (int j = 0; j < dim; j++) { - for (int k = 0; k < dim; k++) { - J_q[j][k] = J(k, j, i); + if constexpr (differentiation_index == NO_DIFFERENTIATION) { + using position_t = serac::tuple, tensor >; + using return_type = decltype(qf(double{}, position_t{}, T{}[0]...)); + + tensor outputs{}; + for (int i = 0; i < n; i++) { + tensor x_q; + tensor J_q; + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + J_q[j][k] = J(k, j, i); + } + x_q[j] = x(j, i); } - x_q[j] = x(j, i); + double detJ_q = det(J_q); + tensor invJ_q = inv(J_q); + auto qf_output = qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...); + physical_to_parent(qf_output, invJ_q, detJ_q); + outputs[i] = qf_output; } - outputs[i] = qf(t, serac::tuple{x_q, J_q}, inputs[i]...); + return outputs; + + } else { + using position_t = serac::tuple, tensor >; + using return_value_type = decltype(qf(double{}, position_t{}, get_value(T{}[0])...)); + using gradient_type = decltype(get_gradient((scalar_type(inputs[0]) + ...))); + + tensor outputs{}; + for (int i = 0; i < n; i++) { + tensor x_q; + tensor J_q; + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + J_q[j][k] = J(k, j, i); + } + x_q[j] = x(j, i); + } + double detJ_q = det(J_q); + tensor invJ_q = inv(J_q); + outputs[i] = physical_to_parent(qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...), invJ_q, detJ_q); + } + return outputs; } - return outputs; } template @@ -179,36 +230,21 @@ void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, do //[[maybe_unused]] static constexpr trial_element_tuple trial_element_tuple{}; // batch-calculate values / derivatives of each trial space, at each quadrature point - [[maybe_unused]] tuple qf_inputs = 
{promote_each_to_dual_when( - get(trial_elements).interpolate(get(u)[elements[e]], rule))...}; - - // use J_e to transform values / derivatives on the parent element - // to the to the corresponding values / derivatives on the physical element - (parent_to_physical(trial_elements).family>(get(qf_inputs), J_e), ...); + tuple qf_inputs = {get(trial_elements).interpolate(get(u)[elements[e]], rule)...}; // (batch) evalute the q-function at each quadrature point // // note: the weird immediately-invoked lambda expression is // a workaround for a bug in GCC(<12.0) where it fails to // decide which function overload to use, and crashes - auto qf_outputs = [&]() { - if constexpr (std::is_same_v) { - return batch_apply_qf_no_qdata(qf, t, x_e, J_e, get(qf_inputs)...); - } else { - return batch_apply_qf(qf, t, x_e, J_e, &qf_state(e, 0), update_state, get(qf_inputs)...); - } - }(); - - // use J to transform sources / fluxes on the physical element - // back to the corresponding sources / fluxes on the parent element - physical_to_parent(qf_outputs, J_e); + auto qf_outputs = batch_apply_qf_no_qdata(qf, t, x_e, J_e, get(qf_inputs)...); // write out the q-function derivatives after applying the // physical_to_parent transformation, so that those transformations // won't need to be applied in the action_of_gradient and element_gradient kernels if constexpr (differentiation_index != serac::NO_DIFFERENTIATION) { for (int q = 0; q < leading_dimension(qf_outputs); q++) { - qf_derivatives[e * uint32_t(qpts_per_elem) + uint32_t(q)] = get_gradient(qf_outputs[q]); + //qf_derivatives[e * uint32_t(qpts_per_elem) + uint32_t(q)] = get_gradient(qf_outputs[q]); } } diff --git a/src/serac/numerics/functional/enzyme_declarations.hpp b/src/serac/numerics/functional/enzyme_declarations.hpp new file mode 100644 index 0000000000..ef6ba8cf67 --- /dev/null +++ b/src/serac/numerics/functional/enzyme_declarations.hpp @@ -0,0 +1,12 @@ +#pragma once + +int enzyme_dup; +int enzyme_dupnoneed; +int 
enzyme_out; +int enzyme_const; + +template < typename return_type, typename ... T > +return_type __enzyme_fwddiff(void*, T ... ); + +template < typename return_type, typename ... T > +return_type __enzyme_autodiff(void*, T ... ); diff --git a/src/serac/numerics/functional/tests/CMakeLists.txt b/src/serac/numerics/functional/tests/CMakeLists.txt index bb0f646336..9c7fdb8490 100644 --- a/src/serac/numerics/functional/tests/CMakeLists.txt +++ b/src/serac/numerics/functional/tests/CMakeLists.txt @@ -68,3 +68,7 @@ if(ENABLE_CUDA) DEPENDS_ON gtest serac_functional serac_state ${functional_depends}) endif() + +add_executable(enzyme_test enzyme_test.cpp) + +target_link_libraries(enzyme_test PUBLIC serac_functional) diff --git a/src/serac/numerics/functional/tests/enzyme_test.cpp b/src/serac/numerics/functional/tests/enzyme_test.cpp new file mode 100644 index 0000000000..0e551d4f20 --- /dev/null +++ b/src/serac/numerics/functional/tests/enzyme_test.cpp @@ -0,0 +1,17 @@ +#include + +#include "serac/numerics/functional/enzyme_declarations.hpp" + +double square(double x) { + return x * x; +} + +double dsquare(double x) { + return __enzyme_autodiff(reinterpret_cast(square), x); +} + +int main() { + for(double i=1; i<5; i++) { + std::cout << square(i) << " " << dsquare(i) << std::endl; + } +} From 59545612063faed45261665783f77ecbcd1e451c Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Fri, 17 May 2024 15:01:52 -0700 Subject: [PATCH 02/12] forgot to commit FetchContent for enzyme --- cmake/thirdparty/SetupSeracThirdParty.cmake | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cmake/thirdparty/SetupSeracThirdParty.cmake b/cmake/thirdparty/SetupSeracThirdParty.cmake index a0109a3df4..3f9816de6e 100644 --- a/cmake/thirdparty/SetupSeracThirdParty.cmake +++ b/cmake/thirdparty/SetupSeracThirdParty.cmake @@ -559,3 +559,16 @@ if (NOT SERAC_THIRD_PARTY_LIBRARIES_FOUND) endforeach() endif() endif() + + 
+#------------------------------------------------------------------------------ +# Enzyme +#------------------------------------------------------------------------------ +include(FetchContent) + +FetchContent_Declare( + enzyme + URL https://github.com/EnzymeAD/Enzyme/archive/refs/tags/v0.0.110.tar.gz + SOURCE_SUBDIR enzyme +) +FetchContent_MakeAvailable(enzyme) From 3fe5f579ff9d0c86520f9e16247051fe72f8e4ec Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Fri, 17 May 2024 15:40:15 -0700 Subject: [PATCH 03/12] working around OpenMP linkage issue by using the built-in cmake target --- src/serac/numerics/functional/CMakeLists.txt | 8 +- .../functional/domain_integral_kernels.hpp | 98 ++-- .../domain_integral_kernels_new.hpp | 438 ++++++++++++++++++ .../numerics/functional/tests/enzyme_test.cpp | 25 +- .../physics/state/finite_element_vector.hpp | 2 +- 5 files changed, 498 insertions(+), 73 deletions(-) create mode 100644 src/serac/numerics/functional/domain_integral_kernels_new.hpp diff --git a/src/serac/numerics/functional/CMakeLists.txt b/src/serac/numerics/functional/CMakeLists.txt index e1e0749181..113c002107 100644 --- a/src/serac/numerics/functional/CMakeLists.txt +++ b/src/serac/numerics/functional/CMakeLists.txt @@ -69,8 +69,14 @@ blt_add_library( HEADERS ${functional_headers} ${functional_detail_headers} SOURCES ${functional_sources} DEPENDS_ON ${functional_depends} - ) +) +# without this, I get +# "error while loading shared libraries: libomp.so: cannot open +# shared object file: No such file or directory" +# +# are we not using the native OpenMP cmake targets already? 
+target_link_libraries(serac_functional PUBLIC OpenMP::OpenMP_CXX) target_link_libraries(serac_functional PUBLIC ClangEnzymeFlags) install(FILES ${functional_headers} DESTINATION include/serac/numerics/functional ) diff --git a/src/serac/numerics/functional/domain_integral_kernels.hpp b/src/serac/numerics/functional/domain_integral_kernels.hpp index 128f74982c..9f13cd42fb 100644 --- a/src/serac/numerics/functional/domain_integral_kernels.hpp +++ b/src/serac/numerics/functional/domain_integral_kernels.hpp @@ -102,76 +102,25 @@ auto get_derivative_type(lambda qf, qpt_data_type&& qpt_data) make_dual_wrt(qf_arguments{}))); }; - -template -SERAC_HOST_DEVICE auto parent_to_physical(const tuple& qf_input, const tensor & invJ) { - return tuple{get<0>(qf_input), dot(get<1>(qf_input), invJ)}; -} - -template -SERAC_HOST_DEVICE auto physical_to_parent(const tuple& qf_output, const tensor & invJ, double detJ) { - // assumes family == Family::H1 for now - return tuple{get<0>(qf_output) * detJ, dot(get<1>(qf_output), transpose(invJ)) * detJ}; -} - -double scalar_type(double) { return 0.0; } - -template < typename T > -dual scalar_type(dual) { return dual{}; } - -template < typename T, int ... n > -T scalar_type(tensor) { return {}; } - -template < typename T1, typename T2 > -auto scalar_type(tuple< T1, T2 >) { return tuple{ scalar_type(T1{}), scalar_type(T2{}) }; } - -template < uint32_t differentiation_index, typename lambda, int dim, int n, typename... T> +template SERAC_HOST_DEVICE auto batch_apply_qf_no_qdata(lambda qf, double t, const tensor& x, const tensor& J, const T&... 
inputs) { - if constexpr (differentiation_index == NO_DIFFERENTIATION) { - using position_t = serac::tuple, tensor >; - using return_type = decltype(qf(double{}, position_t{}, T{}[0]...)); - - tensor outputs{}; - for (int i = 0; i < n; i++) { - tensor x_q; - tensor J_q; - for (int j = 0; j < dim; j++) { - for (int k = 0; k < dim; k++) { - J_q[j][k] = J(k, j, i); - } - x_q[j] = x(j, i); - } - double detJ_q = det(J_q); - tensor invJ_q = inv(J_q); - auto qf_output = qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...); - physical_to_parent(qf_output, invJ_q, detJ_q); - outputs[i] = qf_output; - } - return outputs; - - } else { - using position_t = serac::tuple, tensor >; - using return_value_type = decltype(qf(double{}, position_t{}, get_value(T{}[0])...)); - using gradient_type = decltype(get_gradient((scalar_type(inputs[0]) + ...))); - - tensor outputs{}; - for (int i = 0; i < n; i++) { - tensor x_q; - tensor J_q; - for (int j = 0; j < dim; j++) { - for (int k = 0; k < dim; k++) { - J_q[j][k] = J(k, j, i); - } - x_q[j] = x(j, i); + using position_t = serac::tuple, tensor >; + using return_type = decltype(qf(double{}, position_t{}, T{}[0]...)); + tensor outputs{}; + for (int i = 0; i < n; i++) { + tensor x_q; + tensor J_q; + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + J_q[j][k] = J(k, j, i); } - double detJ_q = det(J_q); - tensor invJ_q = inv(J_q); - outputs[i] = physical_to_parent(qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...), invJ_q, detJ_q); + x_q[j] = x(j, i); } - return outputs; + outputs[i] = qf(t, serac::tuple{x_q, J_q}, inputs[i]...); } + return outputs; } template @@ -230,21 +179,36 @@ void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, do //[[maybe_unused]] static constexpr trial_element_tuple trial_element_tuple{}; // batch-calculate values / derivatives of each trial space, at each quadrature point - tuple qf_inputs = 
{get(trial_elements).interpolate(get(u)[elements[e]], rule)...}; + [[maybe_unused]] tuple qf_inputs = {promote_each_to_dual_when( + get(trial_elements).interpolate(get(u)[elements[e]], rule))...}; + + // use J_e to transform values / derivatives on the parent element + // to the to the corresponding values / derivatives on the physical element + (parent_to_physical(trial_elements).family>(get(qf_inputs), J_e), ...); // (batch) evalute the q-function at each quadrature point // // note: the weird immediately-invoked lambda expression is // a workaround for a bug in GCC(<12.0) where it fails to // decide which function overload to use, and crashes - auto qf_outputs = batch_apply_qf_no_qdata(qf, t, x_e, J_e, get(qf_inputs)...); + auto qf_outputs = [&]() { + if constexpr (std::is_same_v) { + return batch_apply_qf_no_qdata(qf, t, x_e, J_e, get(qf_inputs)...); + } else { + return batch_apply_qf(qf, t, x_e, J_e, &qf_state(e, 0), update_state, get(qf_inputs)...); + } + }(); + + // use J to transform sources / fluxes on the physical element + // back to the corresponding sources / fluxes on the parent element + physical_to_parent(qf_outputs, J_e); // write out the q-function derivatives after applying the // physical_to_parent transformation, so that those transformations // won't need to be applied in the action_of_gradient and element_gradient kernels if constexpr (differentiation_index != serac::NO_DIFFERENTIATION) { for (int q = 0; q < leading_dimension(qf_outputs); q++) { - //qf_derivatives[e * uint32_t(qpts_per_elem) + uint32_t(q)] = get_gradient(qf_outputs[q]); + qf_derivatives[e * uint32_t(qpts_per_elem) + uint32_t(q)] = get_gradient(qf_outputs[q]); } } diff --git a/src/serac/numerics/functional/domain_integral_kernels_new.hpp b/src/serac/numerics/functional/domain_integral_kernels_new.hpp new file mode 100644 index 0000000000..128f74982c --- /dev/null +++ b/src/serac/numerics/functional/domain_integral_kernels_new.hpp @@ -0,0 +1,438 @@ +// Copyright (c) 2019-2024, 
Lawrence Livermore National Security, LLC and +// other Serac Project Developers. See the top-level LICENSE file for +// details. +// +// SPDX-License-Identifier: (BSD-3-Clause) +#pragma once + +#include "serac/infrastructure/accelerator.hpp" +#include "serac/numerics/functional/quadrature_data.hpp" +#include "serac/numerics/functional/function_signature.hpp" +#include "serac/numerics/functional/differentiate_wrt.hpp" + +#include +#include +#include +#include + +namespace serac { + +namespace domain_integral { + +/** + * @tparam space the user-specified trial space + * @tparam dimension describes whether the problem is 1D, 2D, or 3D + * + * @brief a struct used to encode what type of arguments will be passed to a domain integral q-function, for the given + * trial space + */ +template +SERAC_HOST_DEVICE struct QFunctionArgument; + +/// @overload +template +SERAC_HOST_DEVICE struct QFunctionArgument, Dimension > { + using type = tuple >; ///< what will be passed to the q-function +}; + +/// @overload +template +SERAC_HOST_DEVICE struct QFunctionArgument, Dimension > { + using type = tuple, tensor >; ///< what will be passed to the q-function +}; + +/// @overload +template +SERAC_HOST_DEVICE struct QFunctionArgument, Dimension > { + using type = tuple >; ///< what will be passed to the q-function +}; +/// @overload +template +SERAC_HOST_DEVICE struct QFunctionArgument, Dimension > { + using type = tuple, tensor >; ///< what will be passed to the q-function +}; + +/// @overload +template +SERAC_HOST_DEVICE struct QFunctionArgument, Dimension<2> > { + using type = tuple, double>; ///< what will be passed to the q-function +}; + +/// @overload +template +SERAC_HOST_DEVICE struct QFunctionArgument, Dimension<3> > { + using type = tuple, tensor >; ///< what will be passed to the q-function +}; + +/// @brief layer of indirection needed to unpack the entries of the argument tuple +SERAC_SUPPRESS_NVCC_HOSTDEVICE_WARNING +template +SERAC_HOST_DEVICE auto 
apply_qf_helper(lambda&& qf, double t, coords_type&& x_q, qpt_data_type&& qpt_data, + const T& arg_tuple, std::integer_sequence) +{ + if constexpr (std::is_same::type, Nothing>::value) { + return qf(t, x_q, serac::get(arg_tuple)...); + } else { + return qf(t, x_q, qpt_data, serac::get(arg_tuple)...); + } +} + +/** + * @brief Actually calls the q-function + * This is an indirection layer to provide a transparent call site usage regardless of whether + * quadrature point (state) information is required + * @param[in] qf The quadrature function functor object + * @param[in] x_q The physical coordinates of the quadrature point + * @param[in] arg_tuple The values and derivatives at the quadrature point, as a dual + * @param[inout] qpt_data The state information at the quadrature point + */ +template +SERAC_HOST_DEVICE auto apply_qf(lambda&& qf, double t, coords_type&& x_q, qpt_data_type&& qpt_data, + const serac::tuple& arg_tuple) +{ + return apply_qf_helper(qf, t, x_q, qpt_data, arg_tuple, + std::make_integer_sequence(sizeof...(T))>{}); +} + +template +auto get_derivative_type(lambda qf, qpt_data_type&& qpt_data) +{ + using qf_arguments = serac::tuple >::type...>; + return get_gradient(apply_qf(qf, double{}, serac::tuple, tensor >{}, qpt_data, + make_dual_wrt(qf_arguments{}))); +}; + + +template +SERAC_HOST_DEVICE auto parent_to_physical(const tuple& qf_input, const tensor & invJ) { + return tuple{get<0>(qf_input), dot(get<1>(qf_input), invJ)}; +} + +template +SERAC_HOST_DEVICE auto physical_to_parent(const tuple& qf_output, const tensor & invJ, double detJ) { + // assumes family == Family::H1 for now + return tuple{get<0>(qf_output) * detJ, dot(get<1>(qf_output), transpose(invJ)) * detJ}; +} + +double scalar_type(double) { return 0.0; } + +template < typename T > +dual scalar_type(dual) { return dual{}; } + +template < typename T, int ... 
n > +T scalar_type(tensor) { return {}; } + +template < typename T1, typename T2 > +auto scalar_type(tuple< T1, T2 >) { return tuple{ scalar_type(T1{}), scalar_type(T2{}) }; } + +template < uint32_t differentiation_index, typename lambda, int dim, int n, typename... T> +SERAC_HOST_DEVICE auto batch_apply_qf_no_qdata(lambda qf, double t, const tensor& x, + const tensor& J, const T&... inputs) +{ + if constexpr (differentiation_index == NO_DIFFERENTIATION) { + using position_t = serac::tuple, tensor >; + using return_type = decltype(qf(double{}, position_t{}, T{}[0]...)); + + tensor outputs{}; + for (int i = 0; i < n; i++) { + tensor x_q; + tensor J_q; + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + J_q[j][k] = J(k, j, i); + } + x_q[j] = x(j, i); + } + double detJ_q = det(J_q); + tensor invJ_q = inv(J_q); + auto qf_output = qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...); + physical_to_parent(qf_output, invJ_q, detJ_q); + outputs[i] = qf_output; + } + return outputs; + + } else { + using position_t = serac::tuple, tensor >; + using return_value_type = decltype(qf(double{}, position_t{}, get_value(T{}[0])...)); + using gradient_type = decltype(get_gradient((scalar_type(inputs[0]) + ...))); + + tensor outputs{}; + for (int i = 0; i < n; i++) { + tensor x_q; + tensor J_q; + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + J_q[j][k] = J(k, j, i); + } + x_q[j] = x(j, i); + } + double detJ_q = det(J_q); + tensor invJ_q = inv(J_q); + outputs[i] = physical_to_parent(qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...), invJ_q, detJ_q); + } + return outputs; + } +} + +template +SERAC_HOST_DEVICE auto batch_apply_qf(lambda qf, double t, const tensor& x, + const tensor& J, qpt_data_type* qpt_data, bool update_state, + const T&... 
inputs) +{ + using position_t = serac::tuple, tensor >; + using return_type = decltype(qf(double{}, position_t{}, qpt_data[0], T{}[0]...)); + tensor outputs{}; + for (int i = 0; i < n; i++) { + tensor x_q; + tensor J_q; + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + J_q[j][k] = J(k, j, i); + } + x_q[j] = x(j, i); + } + auto qdata = qpt_data[i]; + outputs[i] = qf(t, serac::tuple{x_q, J_q}, qdata, inputs[i]...); + if (update_state) { + qpt_data[i] = qdata; + } + } + return outputs; +} + +template +void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, double t, + const std::vector& inputs, double* outputs, const double* positions, + const double* jacobians, lambda_type qf, + [[maybe_unused]] axom::ArrayView qf_state, + [[maybe_unused]] derivative_type* qf_derivatives, const int* elements, + uint32_t num_elements, bool update_state, camp::int_seq) +{ + // mfem provides this information as opaque arrays of doubles, + // so we reinterpret the pointer with + auto r = reinterpret_cast(outputs); + auto x = reinterpret_cast::type*>(positions); + auto J = reinterpret_cast::type*>(jacobians); + TensorProductQuadratureRule rule{}; + + [[maybe_unused]] auto qpts_per_elem = num_quadrature_points(geom, Q); + + [[maybe_unused]] tuple u = { + reinterpret_cast(trial_elements))::dof_type*>(inputs[indices])...}; + + // for each element in the domain + for (uint32_t e = 0; e < num_elements; ++e) { + // load the jacobians and positions for each quadrature point in this element + auto J_e = J[e]; + auto x_e = x[e]; + + //[[maybe_unused]] static constexpr trial_element_tuple trial_element_tuple{}; + // batch-calculate values / derivatives of each trial space, at each quadrature point + tuple qf_inputs = {get(trial_elements).interpolate(get(u)[elements[e]], rule)...}; + + // (batch) evalute the q-function at each quadrature point + // + // note: the weird immediately-invoked lambda expression is + // a workaround for a bug in GCC(<12.0) where 
it fails to + // decide which function overload to use, and crashes + auto qf_outputs = batch_apply_qf_no_qdata(qf, t, x_e, J_e, get(qf_inputs)...); + + // write out the q-function derivatives after applying the + // physical_to_parent transformation, so that those transformations + // won't need to be applied in the action_of_gradient and element_gradient kernels + if constexpr (differentiation_index != serac::NO_DIFFERENTIATION) { + for (int q = 0; q < leading_dimension(qf_outputs); q++) { + //qf_derivatives[e * uint32_t(qpts_per_elem) + uint32_t(q)] = get_gradient(qf_outputs[q]); + } + } + + // (batch) integrate the material response against the test-space basis functions + test_element::integrate(get_value(qf_outputs), rule, &r[elements[e]]); + } + + return; +} + +//clang-format off +template +SERAC_HOST_DEVICE auto chain_rule(const S& dfdx, const T& dx) +{ + if constexpr (is_QOI) { + return serac::chain_rule(serac::get<0>(dfdx), serac::get<0>(dx)) + + serac::chain_rule(serac::get<1>(dfdx), serac::get<1>(dx)); + } + + if constexpr (!is_QOI) { + return serac::tuple{serac::chain_rule(serac::get<0>(serac::get<0>(dfdx)), serac::get<0>(dx)) + + serac::chain_rule(serac::get<1>(serac::get<0>(dfdx)), serac::get<1>(dx)), + serac::chain_rule(serac::get<0>(serac::get<1>(dfdx)), serac::get<0>(dx)) + + serac::chain_rule(serac::get<1>(serac::get<1>(dfdx)), serac::get<1>(dx))}; + } +} +//clang-format on + +template +SERAC_HOST_DEVICE auto batch_apply_chain_rule(derivative_type* qf_derivatives, const tensor& inputs) +{ + using return_type = decltype(chain_rule(derivative_type{}, T{})); + tensor outputs{}; + for (int i = 0; i < n; i++) { + outputs[i] = chain_rule(qf_derivatives[i], inputs[i]); + } + return outputs; +} + +/** + * @brief The base kernel template used to create create custom directional derivative + * kernels associated with finite element calculations + * + * @tparam test The type of the test function space + * @tparam trial The type of the trial function space + 
* The above spaces can be any combination of {H1, Hcurl, Hdiv (TODO), L2 (TODO), QOI} + * + * Template parameters other than the test and trial spaces are used for customization + optimization + * and are erased through the @p std::function members of @p DomainIntegral + * @tparam g The shape of the element (only quadrilateral and hexahedron are supported at present) + * @tparam Q parameter describing number of quadrature points (see num_quadrature_points() function for more details) + * @tparam derivatives_type Type representing the derivative of the q-function w.r.t. its input arguments + * + * @note lambda does not appear as a template argument, as the directional derivative is + * inherently just a linear transformation + * + * @param[in] dU The full set of per-element DOF values (primary input) + * @param[inout] dR The full set of per-element residuals (primary output) + * @param[in] derivatives_ptr The address at which derivatives of the q-function with + * respect to its arguments are stored + * @param[in] J_ The Jacobians of the element transformations at all quadrature points + * @see mfem::GeometricFactors + * @param[in] num_elements The number of elements in the mesh + */ + +template +void action_of_gradient_kernel(const double* dU, double* dR, derivatives_type* qf_derivatives, const int* elements, + std::size_t num_elements) +{ + using test_element = finite_element; + using trial_element = finite_element; + + constexpr bool is_QOI = (test::family == Family::QOI); + constexpr int num_qpts = num_quadrature_points(g, Q); + + // mfem provides this information in 1D arrays, so we reshape it + // into strided multidimensional arrays before using + auto du = reinterpret_cast(dU); + auto dr = reinterpret_cast(dR); + constexpr TensorProductQuadratureRule rule{}; + + // for each element in the domain + for (uint32_t e = 0; e < num_elements; e++) { + // (batch) interpolate each quadrature point's value + auto qf_inputs = trial_element::interpolate(du[elements[e]], 
rule); + + // (batch) evalute the q-function at each quadrature point + auto qf_outputs = batch_apply_chain_rule(qf_derivatives + e * num_qpts, qf_inputs); + + // (batch) integrate the material response against the test-space basis functions + test_element::integrate(qf_outputs, rule, &dr[elements[e]]); + } +} + +/** + * @brief The base kernel template used to compute tangent element entries that can be assembled + * into a tangent matrix + * + * @tparam test The type of the test function space + * @tparam trial The type of the trial function space + * The above spaces can be any combination of {H1, Hcurl, Hdiv (TODO), L2 (TODO), QOI} + * + * Template parameters other than the test and trial spaces are used for customization + optimization + * and are erased through the @p std::function members of @p Integral + * @tparam g The shape of the element (only quadrilateral and hexahedron are supported at present) + * @tparam Q parameter describing number of quadrature points (see num_quadrature_points() function for more details) + * @tparam derivatives_type Type representing the derivative of the q-function w.r.t. 
its input arguments + * + * + * @param[inout] dk 3-dimensional array storing the element gradient matrices + * @param[in] derivatives_ptr pointer to data describing the derivatives of the q-function with respect to its arguments + * @param[in] J_ The Jacobians of the element transformations at all quadrature points + * @see mfem::GeometricFactors + * @param[in] num_elements The number of elements in the mesh + */ +template +void element_gradient_kernel(ExecArrayView dK, derivatives_type* qf_derivatives, + const int* elements, std::size_t num_elements) +{ + // quantities of interest have no flux term, so we pad the derivative + // tuple with a "zero" type in the second position to treat it like the standard case + constexpr bool is_QOI = test::family == Family::QOI; + using padded_derivative_type = std::conditional_t, derivatives_type>; + + using test_element = finite_element; + using trial_element = finite_element; + + constexpr int nquad = num_quadrature_points(g, Q); + + static constexpr TensorProductQuadratureRule rule{}; + + // for each element in the domain + for (uint32_t e = 0; e < num_elements; e++) { + auto* output_ptr = reinterpret_cast(&dK(elements[e], 0, 0)); + + tensor derivatives{}; + for (int q = 0; q < nquad; q++) { + if constexpr (is_QOI) { + get<0>(derivatives(q)) = qf_derivatives[e * nquad + uint32_t(q)]; + } else { + derivatives(q) = qf_derivatives[e * nquad + uint32_t(q)]; + } + } + + for (int J = 0; J < trial_element::ndof; J++) { + auto source_and_flux = trial_element::batch_apply_shape_fn(J, derivatives, rule); + test_element::integrate(source_and_flux, rule, output_ptr + J, trial_element::ndof); + } + } +} + +template +auto evaluation_kernel(signature s, lambda_type qf, const double* positions, const double* jacobians, + std::shared_ptr > qf_state, + std::shared_ptr qf_derivatives, const int* elements, uint32_t num_elements) +{ + auto trial_elements = trial_elements_tuple(s); + auto test_element = get_test_element(s); + return [=](double 
time, const std::vector& inputs, double* outputs, bool update_state) { + domain_integral::evaluation_kernel_impl( + trial_elements, test_element, time, inputs, outputs, positions, jacobians, qf, (*qf_state)[geom], + qf_derivatives.get(), elements, num_elements, update_state, s.index_seq); + }; +} + +template +std::function jacobian_vector_product_kernel( + signature, std::shared_ptr qf_derivatives, const int* elements, uint32_t num_elements) +{ + return [=](const double* du, double* dr) { + using test_space = typename signature::return_type; + using trial_space = typename std::tuple_element::type; + action_of_gradient_kernel(du, dr, qf_derivatives.get(), elements, num_elements); + }; +} + +template +std::function)> element_gradient_kernel( + signature, std::shared_ptr qf_derivatives, const int* elements, uint32_t num_elements) +{ + return [=](ExecArrayView K_elem) { + using test_space = typename signature::return_type; + using trial_space = typename std::tuple_element::type; + element_gradient_kernel(K_elem, qf_derivatives.get(), elements, num_elements); + }; +} + +} // namespace domain_integral + +} // namespace serac diff --git a/src/serac/numerics/functional/tests/enzyme_test.cpp b/src/serac/numerics/functional/tests/enzyme_test.cpp index 0e551d4f20..313e38cc51 100644 --- a/src/serac/numerics/functional/tests/enzyme_test.cpp +++ b/src/serac/numerics/functional/tests/enzyme_test.cpp @@ -1,7 +1,10 @@ #include -#include "serac/numerics/functional/enzyme_declarations.hpp" +#include "../enzyme_declarations.hpp" +#include "../tuple.hpp" +#include "../tensor.hpp" +#if 0 double square(double x) { return x * x; } @@ -9,9 +12,23 @@ double square(double x) { double dsquare(double x) { return __enzyme_autodiff(reinterpret_cast(square), x); } +#endif + +template< typename T, typename ... arg_types > +auto wrapper(const T & f, arg_types ... 
args) { + return f(args...); +} int main() { - for(double i=1; i<5; i++) { - std::cout << square(i) << " " << dsquare(i) << std::endl; - } + + auto f = [](double x) { return x * x; }; + + auto df = [&](double x) { + return __enzyme_fwddiff((void*)(wrapper), + enzyme_const, (void*)&f, + enzyme_dup, x, 1.0); + }; + + std::cout << df(3) << std::endl; + } diff --git a/src/serac/physics/state/finite_element_vector.hpp b/src/serac/physics/state/finite_element_vector.hpp index f35d9c8981..ac82fac50b 100644 --- a/src/serac/physics/state/finite_element_vector.hpp +++ b/src/serac/physics/state/finite_element_vector.hpp @@ -18,7 +18,7 @@ #include "mfem.hpp" #include "serac/infrastructure/variant.hpp" -#include "serac/numerics/functional/functional.hpp" +#include "serac/numerics/functional/finite_element.hpp" namespace serac { From 0fc40d53b23af32c6f65b2d8b6f33de5a867a413 Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Thu, 23 May 2024 13:19:50 -0700 Subject: [PATCH 04/12] investigating wrong derivatives --- .../numerics/functional/tests/CMakeLists.txt | 7 +- .../functional/tests/enzyme_reproducer.cpp | 142 +++++++++++++ .../tests/enzyme_reproducer_rbr.cpp | 141 +++++++++++++ .../numerics/functional/tests/enzyme_test.cpp | 186 ++++++++++++++++-- 4 files changed, 462 insertions(+), 14 deletions(-) create mode 100644 src/serac/numerics/functional/tests/enzyme_reproducer.cpp create mode 100644 src/serac/numerics/functional/tests/enzyme_reproducer_rbr.cpp diff --git a/src/serac/numerics/functional/tests/CMakeLists.txt b/src/serac/numerics/functional/tests/CMakeLists.txt index 9c7fdb8490..06f3bba69e 100644 --- a/src/serac/numerics/functional/tests/CMakeLists.txt +++ b/src/serac/numerics/functional/tests/CMakeLists.txt @@ -70,5 +70,10 @@ if(ENABLE_CUDA) endif() add_executable(enzyme_test enzyme_test.cpp) - target_link_libraries(enzyme_test PUBLIC serac_functional) + +add_executable(enzyme_reproducer enzyme_reproducer.cpp) +target_link_libraries(enzyme_reproducer PUBLIC 
serac_functional) + +add_executable(enzyme_reproducer_rbr enzyme_reproducer_rbr.cpp) +target_link_libraries(enzyme_reproducer_rbr PUBLIC serac_functional) diff --git a/src/serac/numerics/functional/tests/enzyme_reproducer.cpp b/src/serac/numerics/functional/tests/enzyme_reproducer.cpp new file mode 100644 index 0000000000..ba8155a5cf --- /dev/null +++ b/src/serac/numerics/functional/tests/enzyme_reproducer.cpp @@ -0,0 +1,142 @@ +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +template < typename T, int m, int n > +struct tensor { + T & operator()(int i, int j) { return values[i][j]; } + const T & operator()(int i, int j) const { return values[i][j]; } + T values[m][n]; +}; + +template < int m, int n > +tensor operator*(double scale, const tensor & A) { + tensor scaled{}; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + scaled(i,j) = scale * A(i,j); + } + } + return scaled; +} + +template < typename T, int m, int n > +std::ostream& operator<<(std::ostream & out, const tensor & A) { + out << "{"; + for (int i = 0; i < m; i++) { + out << "{"; + for (int j = 0; j < n; j++) { + out << A(i,j); + if (j != n - 1) { out << ", "; } + } + out << "}"; + if (i != m - 1) { out << ", "; } + } + out << "}"; + return out; +} + +//////////////////////////////////////////////////////////////////////////////// + +namespace impl { + + constexpr int vsize(const double &) { return 1; } + + template < int m, int n > + constexpr int vsize(const tensor< double, m, n > &) { return m * n; } + + template < int m, int n, int p, int q > + constexpr int vsize(const tensor< tensor< double, p, q >, m, n > &) { return m * n * p * q; } + + template < typename T1, typename T2 > + struct outer_prod; + + template < int m, int n > + struct outer_prod< double, tensor< double, m, n > >{ + using type = tensor; + }; + + template < int m, int n > + struct outer_prod< tensor< double, m, n >, double >{ + using type = tensor; + }; + + 
template < int m, int n, int p, int q > + struct outer_prod< tensor< double, m, n >, tensor< double, p, q > >{ + using type = tensor, m, n>; + }; +} + +//////////////////////////////////////////////////////////////////////////////// + +template< typename output_type, typename function, typename ... arg_types > +void wrapper(output_type & output, const function & f, const arg_types & ... args) { + output = f(args...); +} + +template < typename function, typename input_type > +auto jacfwd(const function & f, const input_type & x) { + using output_type = decltype(f(x)); + using jac_type = typename impl::outer_prod::type; + void * func_ptr = reinterpret_cast(wrapper< output_type, function, input_type >); + + constexpr int m = impl::vsize(output_type{}); + jac_type J{}; + double * J_ptr = reinterpret_cast(&J); + + constexpr int n = impl::vsize(input_type{}); + input_type dx{}; + double * dx_ptr = reinterpret_cast(&dx); + + for (int j = 0; j < n; j++) { + dx_ptr[j] = 1.0; + + output_type unused{}; + output_type df_dxj{}; + __enzyme_fwddiff(func_ptr, + enzyme_dupnoneed, &unused, &df_dxj, + enzyme_const, reinterpret_cast(&f), + enzyme_dup, &x, &dx + ); + + double * df_dxj_ptr = reinterpret_cast(&df_dxj); + for (int i = 0; i < m; i++) { + J_ptr[i * n + j] = df_dxj_ptr[i]; + } + + dx_ptr[j] = 0.0; + } + + return J; +} + +template < int i, typename function, typename T0, typename T1 > +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1) { + if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1); }, arg0); } + if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x); }, arg1); } +} + +//////////////////////////////////////////////////////////////////////////////// + +int main() { + + auto f = [=](double z, const tensor< double, 3, 3 > & du_dx) { return z * du_dx; }; + + double z = 3.0; + tensor du_dx = {{{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}, {1.0, 0.5, 0.2}}}; + + std::cout << "f(x, du_dx): " << f(z, du_dx) << std::endl; + std::cout << 
"expected: {{3., 6., 9.}, {6., 9., 3.}, {3., 1.5, 0.6}}" << std::endl; + std::cout << std::endl; + + auto df_darg0 = jacfwd<0>(f, z, du_dx); + std::cout << "df_dz: " << df_darg0 << std::endl; + std::cout << "expected: {{1., 2., 3.}, {2., 3., 1.}, {1., 0.5, 0.2}}" << std::endl; + std::cout << std::endl; + + auto df_darg1 = jacfwd<1>(f, z, du_dx); + std::cout << "df_d(du_dx): " << df_darg1 << std::endl; + std::cout << "expected: {{{{3., 0, 0}, {0, 0, 0}, {0, 0, 0}}, {{0, 3., 0}, {0, 0, 0}, {0, 0, 0}}, {{0, 0, 3.}, {0, 0, 0}, {0, 0, 0}}}, {{{0, 0, 0}, {3., 0, 0}, {0, 0, 0}}, {{0, 0, 0}, {0, 3., 0}, {0, 0, 0}}, {{0, 0, 0}, {0, 0, 3.}, {0, 0, 0}}}, {{{0, 0, 0}, {0, 0, 0}, {3., 0, 0}}, {{0, 0, 0}, {0, 0, 0}, {0, 3., 0}}, {{0, 0, 0}, {0, 0, 0}, {0, 0, 3.}}}}" << std::endl; + +} diff --git a/src/serac/numerics/functional/tests/enzyme_reproducer_rbr.cpp b/src/serac/numerics/functional/tests/enzyme_reproducer_rbr.cpp new file mode 100644 index 0000000000..ac75d954ed --- /dev/null +++ b/src/serac/numerics/functional/tests/enzyme_reproducer_rbr.cpp @@ -0,0 +1,141 @@ +#include +#include + +extern int enzyme_active_return; + +//////////////////////////////////////////////////////////////////////////////// + +template < typename T, int m, int n > +struct tensor { + T & operator()(int i, int j) { return values[i][j]; } + const T & operator()(int i, int j) const { return values[i][j]; } + T values[m][n]; +}; + +template < int m, int n > +tensor operator*(double scale, const tensor & A) { + tensor scaled{}; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + scaled(i,j) = scale * A(i,j); + } + } + return scaled; +} + +template < typename T, int m, int n > +std::ostream& operator<<(std::ostream & out, const tensor & A) { + out << "{"; + for (int i = 0; i < m; i++) { + out << "{"; + for (int j = 0; j < n; j++) { + out << A(i,j); + if (j != n - 1) { out << ", "; } + } + out << "}"; + if (i != m - 1) { out << ", "; } + } + out << "}"; + return out; +} + 
+//////////////////////////////////////////////////////////////////////////////// + +namespace impl { + + constexpr int vsize(const double &) { return 1; } + + template < int m, int n > + constexpr int vsize(const tensor< double, m, n > &) { return m * n; } + + template < int m, int n, int p, int q > + constexpr int vsize(const tensor< tensor< double, p, q >, m, n > &) { return m * n * p * q; } + + template < typename T1, typename T2 > + struct outer_prod; + + template < int m, int n > + struct outer_prod< double, tensor< double, m, n > >{ + using type = tensor; + }; + + template < int m, int n > + struct outer_prod< tensor< double, m, n >, double >{ + using type = tensor; + }; + + template < int m, int n, int p, int q > + struct outer_prod< tensor< double, m, n >, tensor< double, p, q > >{ + using type = tensor, m, n>; + }; +} + +//////////////////////////////////////////////////////////////////////////////// + +template< typename T, typename ... arg_types > +auto wrapper(const T & f, const arg_types & ... 
args) { + return f(args...); +} + +template < typename function, typename input_type > +auto jacfwd(const function & f, const input_type & x) { + using output_type = decltype(f(x)); + using jac_type = typename impl::outer_prod::type; + void * func_ptr = reinterpret_cast(wrapper< function, input_type >); + + constexpr int m = impl::vsize(output_type{}); + jac_type J{}; + double * J_ptr = reinterpret_cast(&J); + + constexpr int n = impl::vsize(input_type{}); + input_type dx{}; + double * dx_ptr = reinterpret_cast(&dx); + + for (int j = 0; j < n; j++) { + dx_ptr[j] = 1.0; + + output_type df_dxj = __enzyme_fwddiff(func_ptr, + enzyme_const, reinterpret_cast(&f), + enzyme_dup, &x, &dx + ); + + double * df_dxj_ptr = reinterpret_cast(&df_dxj); + for (int i = 0; i < m; i++) { + J_ptr[i * n + j] = df_dxj_ptr[i]; + } + + dx_ptr[j] = 0.0; + } + + return J; +} + +template < int i, typename function, typename T0, typename T1 > +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1) { + if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1); }, arg0); } + if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x); }, arg1); } +} + +//////////////////////////////////////////////////////////////////////////////// + +int main() { + + auto f = [=](double z, const tensor< double, 3, 3 > & du_dx) { return z * du_dx; }; + + double z = 3.0; + tensor du_dx = {{{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}, {1.0, 0.5, 0.2}}}; + + std::cout << "f(x, du_dx): " << f(z, du_dx) << std::endl; + std::cout << "expected: {{3., 6., 9.}, {6., 9., 3.}, {3., 1.5, 0.6}}" << std::endl; + std::cout << std::endl; + + auto df_darg0 = jacfwd<0>(f, z, du_dx); + std::cout << "df_dz: " << df_darg0 << std::endl; + std::cout << "expected: {{1., 2., 3.}, {2., 3., 1.}, {1., 0.5, 0.2}}" << std::endl; + std::cout << std::endl; + + auto df_darg1 = jacfwd<1>(f, z, du_dx); + std::cout << "df_d(du_dx): " << df_darg1 << std::endl; + std::cout << "expected: {{{{3., 0, 0}, {0, 0, 0}, {0, 0, 0}}, {{0, 3., 
0}, {0, 0, 0}, {0, 0, 0}}, {{0, 0, 3.}, {0, 0, 0}, {0, 0, 0}}}, {{{0, 0, 0}, {3., 0, 0}, {0, 0, 0}}, {{0, 0, 0}, {0, 3., 0}, {0, 0, 0}}, {{0, 0, 0}, {0, 0, 3.}, {0, 0, 0}}}, {{{0, 0, 0}, {0, 0, 0}, {3., 0, 0}}, {{0, 0, 0}, {0, 0, 0}, {0, 3., 0}}, {{0, 0, 0}, {0, 0, 0}, {0, 0, 3.}}}}" << std::endl; + +} diff --git a/src/serac/numerics/functional/tests/enzyme_test.cpp b/src/serac/numerics/functional/tests/enzyme_test.cpp index 313e38cc51..1b6bea6a74 100644 --- a/src/serac/numerics/functional/tests/enzyme_test.cpp +++ b/src/serac/numerics/functional/tests/enzyme_test.cpp @@ -1,8 +1,9 @@ #include -#include "../enzyme_declarations.hpp" -#include "../tuple.hpp" -#include "../tensor.hpp" +#include + +#include "serac/numerics/functional/tuple.hpp" +#include "serac/numerics/functional/tensor.hpp" #if 0 double square(double x) { @@ -14,21 +15,180 @@ double dsquare(double x) { } #endif -template< typename T, typename ... arg_types > -auto wrapper(const T & f, arg_types ... args) { - return f(args...); +using namespace serac; + +namespace impl { + + constexpr int vsize(const double &) { return 1; } + + template < int n > + constexpr int vsize(const serac::tensor< double, n > &) { return n; } + + template < int m, int n > + constexpr int vsize(const serac::tensor< double, m, n > &) { return m * n; } + + template < typename T0, typename T1 > + constexpr int vsize(const serac::tuple< T0, T1 > &) { return vsize(T0{}) + vsize(T1{}); } + +} + +template < typename S, typename T0, typename T1 > +struct serac::detail::outer_prod< S, serac::tuple >{ + using type = serac::tuple < + typename outer_prod< S, T0 >::type, + typename outer_prod< S, T1 >::type + >; +}; + +template< typename output_type, typename function, typename ... arg_types > +void wrapper(output_type & output, const function & f, const arg_types & ... 
args) { + output = f(args...); +} + +template < typename function, typename input_type > +__attribute__((always_inline)) +auto jacfwd(const function & f, const input_type & x) { + using output_type = decltype(f(x)); + using jac_type = typename serac::detail::outer_prod::type; + void * func_ptr = reinterpret_cast(wrapper< output_type, function, input_type >); + + constexpr int m = impl::vsize(output_type{}); + jac_type J{}; + double * J_ptr = reinterpret_cast(&J); + + constexpr int n = impl::vsize(input_type{}); + input_type dx{}; + double * dx_ptr = reinterpret_cast(&dx); + + for (int j = 0; j < n; j++) { + dx_ptr[j] = 1.0; + + std::cout << dx << std::endl; + + output_type unused{}; + output_type df_dxj{}; + __enzyme_fwddiff(func_ptr, + enzyme_dupnoneed, &unused, &df_dxj, + enzyme_const, reinterpret_cast(&f), + enzyme_dup, &x, &dx + ); + std::cout << df_dxj << std::endl; + + double * df_dxj_ptr = reinterpret_cast(&df_dxj); + for (int i = 0; i < m; i++) { + J_ptr[i * n + j] = df_dxj_ptr[i]; + } + + std::cout << J << std::endl; + + dx_ptr[j] = 0.0; + } + + return J; +} + +template < int i, typename function, typename T0, typename T1 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1) { + if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1); }, arg0); } + if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x); }, arg1); } +} + +template < int i, typename function, typename T0, typename T1, typename T2 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2) { + if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1, arg2); }, arg0); } + if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x, arg2); }, arg1); } + if constexpr (i == 2) { return jacfwd([&](T2 x){ return f(arg0, arg1, x); }, arg2); } +} + +template < int i, typename function, typename T0, typename T1, typename T2, typename T3 > 
+__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3) { + if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1, arg2, arg3); }, arg0); } + if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x, arg2, arg3); }, arg1); } + if constexpr (i == 2) { return jacfwd([&](T2 x){ return f(arg0, arg1, x, arg3); }, arg2); } + if constexpr (i == 3) { return jacfwd([&](T3 x){ return f(arg0, arg1, arg2, x); }, arg3); } +} + +template < int i, typename function, typename T0, typename T1, typename T2, typename T3, typename T4 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3, const T4 & arg4) { + if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1, arg2, arg3, arg4); }, arg0); } + if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x, arg2, arg3, arg4); }, arg1); } + if constexpr (i == 2) { return jacfwd([&](T2 x){ return f(arg0, arg1, x, arg3, arg4); }, arg2); } + if constexpr (i == 3) { return jacfwd([&](T3 x){ return f(arg0, arg1, arg2, x, arg4); }, arg3); } + if constexpr (i == 4) { return jacfwd([&](T4 x){ return f(arg0, arg1, arg2, arg3, x); }, arg4); } +} + +template < int i, typename function, typename T0, typename T1, typename T2, typename T3, typename T4, typename T5 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3, const T4 & arg4, const T5 & arg5) { + if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1, arg2, arg3, arg4, arg5); }, arg0); } + if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x, arg2, arg3, arg4, arg5); }, arg1); } + if constexpr (i == 2) { return jacfwd([&](T2 x){ return f(arg0, arg1, x, arg3, arg4, arg5); }, arg2); } + if constexpr (i == 3) { return jacfwd([&](T3 x){ return f(arg0, arg1, arg2, x, arg4, arg5); }, arg3); } + if 
constexpr (i == 4) { return jacfwd([&](T4 x){ return f(arg0, arg1, arg2, arg3, x, arg5); }, arg4); } + if constexpr (i == 5) { return jacfwd([&](T5 x){ return f(arg0, arg1, arg2, arg3, arg4, x); }, arg5); } } int main() { - auto f = [](double x) { return x * x; }; + auto f = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return z * (du_dx + transpose(du_dx)) - outer(u, u); + }; + + auto dfdz = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return (du_dx + transpose(du_dx)); + }; + + auto dfdu = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto u = get<0>(displacement); + tensor output{}; + for (int k = 0; k < 3; k++) { + for (int j = 0; j < 3; j++) { + for (int i = 0; i < 3; i++) { + output(i,j,k) = - (u(i) * (j == k) + u(j) * (i == k)); + } + } + } + return output; + }; + + auto dfddudx = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > &) { + tensor output{}; + for (int l = 0; l < 3; l++) { + for (int k = 0; k < 3; k++) { + for (int j = 0; j < 3; j++) { + for (int i = 0; i < 3; i++) { + output(i,j,k,l) = z * ((i==k) * (j==l) + (j==k) * (i==l)); + } + } + } + } + return output; + }; + + double z = 3.0; + auto displacement = tuple { + tensor{{1.0, 1.0, 1.0}}, + tensor{{{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}, {1.0, 0.5, 0.2}}} + }; + + auto df_dz = jacfwd<0>(f, z, displacement); + std::cout << "df_dz: " << df_dz << std::endl; + std::cout << "expected: " << dfdz(z, displacement) << std::endl; + std::cout << std::endl; - auto df = [&](double x) { - return __enzyme_fwddiff((void*)(wrapper), - enzyme_const, (void*)&f, - enzyme_dup, x, 1.0); - }; + auto df_ddisp = jacfwd<1>(f, z, displacement); + std::cout << "df_du: " << get<0>(df_ddisp) << std::endl; + std::cout << "expected: " << dfdu(z, displacement) << std::endl; + std::cout << std::endl; - std::cout 
<< df(3) << std::endl; + std::cout << "df_d(du_dx): " << get<1>(df_ddisp) << std::endl; + std::cout << "expected: " << dfddudx(z, displacement) << std::endl; + std::cout << std::endl; } From 9be9b035d8778cafc0a440a65656551f9db569a9 Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Thu, 23 May 2024 16:10:19 -0700 Subject: [PATCH 05/12] fix tuple data layout for aliasing --- .../numerics/functional/tests/enzyme_test.cpp | 57 ++++++++++++++++--- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/src/serac/numerics/functional/tests/enzyme_test.cpp b/src/serac/numerics/functional/tests/enzyme_test.cpp index 1b6bea6a74..7204c69e89 100644 --- a/src/serac/numerics/functional/tests/enzyme_test.cpp +++ b/src/serac/numerics/functional/tests/enzyme_test.cpp @@ -30,6 +30,30 @@ namespace impl { template < typename T0, typename T1 > constexpr int vsize(const serac::tuple< T0, T1 > &) { return vsize(T0{}) + vsize(T1{}); } + template < typename T0, typename T1 > + struct nested; + + template <> + struct nested< double, double >{ using type = double; }; + + template < int ... n > + struct nested< double, tensor >{ using type = tensor; }; + + template < typename T0, typename T1 > + struct nested< double, tuple >{ using type = tuple; }; + +//////////////////////////////////////////////////////////////////////////////// + + template < int ... 
n, typename T > + struct nested< tensor, T >{ using type = tensor; }; + +//////////////////////////////////////////////////////////////////////////////// + + template < typename S0, typename S1, typename T > + struct nested< tuple< S0, S1 >, T >{ + using type = tuple< typename nested::type, typename nested::type >; + }; + } template < typename S, typename T0, typename T1 > @@ -49,7 +73,7 @@ template < typename function, typename input_type > __attribute__((always_inline)) auto jacfwd(const function & f, const input_type & x) { using output_type = decltype(f(x)); - using jac_type = typename serac::detail::outer_prod::type; + using jac_type = typename impl::nested::type; void * func_ptr = reinterpret_cast(wrapper< output_type, function, input_type >); constexpr int m = impl::vsize(output_type{}); @@ -63,8 +87,6 @@ auto jacfwd(const function & f, const input_type & x) { for (int j = 0; j < n; j++) { dx_ptr[j] = 1.0; - std::cout << dx << std::endl; - output_type unused{}; output_type df_dxj{}; __enzyme_fwddiff(func_ptr, @@ -72,15 +94,12 @@ auto jacfwd(const function & f, const input_type & x) { enzyme_const, reinterpret_cast(&f), enzyme_dup, &x, &dx ); - std::cout << df_dxj << std::endl; double * df_dxj_ptr = reinterpret_cast(&df_dxj); for (int i = 0; i < m; i++) { J_ptr[i * n + j] = df_dxj_ptr[i]; } - std::cout << J << std::endl; - dx_ptr[j] = 0.0; } @@ -183,11 +202,33 @@ int main() { std::cout << std::endl; auto df_ddisp = jacfwd<1>(f, z, displacement); - std::cout << "df_du: " << get<0>(df_ddisp) << std::endl; + std::cout << "df_du: "; + std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << "{"; + for (int j = 0; j < 3; j++) { + std::cout << get<0>(df_ddisp(i,j)); + if (j != 2) { std::cout << ","; } + } + std::cout << "}"; + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; std::cout << "expected: " << dfdu(z, displacement) << std::endl; std::cout << std::endl; - std::cout << "df_d(du_dx): " << get<1>(df_ddisp) << std::endl; + 
std::cout << "df_d(du_dx): "; + std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << "{"; + for (int j = 0; j < 3; j++) { + std::cout << get<1>(df_ddisp(i,j)); + if (j != 2) { std::cout << ","; } + } + std::cout << "}"; + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; std::cout << "expected: " << dfddudx(z, displacement) << std::endl; std::cout << std::endl; From 3ff138c4b33324316068a38d810709ef35b3a964 Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Fri, 24 May 2024 11:43:23 -0700 Subject: [PATCH 06/12] get more representative enzyme example working --- .../numerics/functional/tests/enzyme_test.cpp | 77 ++++++++++++++++--- 1 file changed, 66 insertions(+), 11 deletions(-) diff --git a/src/serac/numerics/functional/tests/enzyme_test.cpp b/src/serac/numerics/functional/tests/enzyme_test.cpp index 7204c69e89..fd84d59c32 100644 --- a/src/serac/numerics/functional/tests/enzyme_test.cpp +++ b/src/serac/numerics/functional/tests/enzyme_test.cpp @@ -153,17 +153,30 @@ auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg int main() { - auto f = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto fg = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { auto [u, du_dx] = displacement; - return z * (du_dx + transpose(du_dx)) - outer(u, u); + return tuple{dot(du_dx, z * u), z * (du_dx + transpose(du_dx)) - outer(u, u)}; }; +//////////////////////////////////////////////////////////////////////////////// + auto dfdz = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return dot(du_dx, u); + }; + + auto dgdz = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { auto [u, du_dx] = displacement; return (du_dx + transpose(du_dx)); }; - auto dfdu = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { 
+//////////////////////////////////////////////////////////////////////////////// + + auto dfdu = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + return get<1>(displacement) * z; + }; + + auto dgdu = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { auto u = get<0>(displacement); tensor output{}; for (int k = 0; k < 3; k++) { @@ -176,7 +189,22 @@ int main() { return output; }; - auto dfddudx = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > &) { +//////////////////////////////////////////////////////////////////////////////// + + auto dfddudx = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto u = get<0>(displacement); + tensor output{}; + for (int k = 0; k < 3; k++) { + for (int j = 0; j < 3; j++) { + for (int i = 0; i < 3; i++) { + output(i,j,k) = z * u[k] * (i == j); + } + } + } + return output; + }; + + auto dgddudx = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > &) { tensor output{}; for (int l = 0; l < 3; l++) { for (int k = 0; k < 3; k++) { @@ -190,46 +218,73 @@ int main() { return output; }; +//////////////////////////////////////////////////////////////////////////////// + double z = 3.0; auto displacement = tuple { tensor{{1.0, 1.0, 1.0}}, tensor{{{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}, {1.0, 0.5, 0.2}}} }; - auto df_dz = jacfwd<0>(f, z, displacement); + auto [df_dz, dg_dz] = jacfwd<0>(fg, z, displacement); std::cout << "df_dz: " << df_dz << std::endl; std::cout << "expected: " << dfdz(z, displacement) << std::endl; std::cout << std::endl; - auto df_ddisp = jacfwd<1>(f, z, displacement); + std::cout << "df_dz: " << dg_dz << std::endl; + std::cout << "expected: " << dgdz(z, displacement) << std::endl; + std::cout << std::endl; + + auto [df_ddisp, dg_ddisp] = jacfwd<1>(fg, z, displacement); std::cout << "df_du: "; std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << 
get<0>(df_ddisp(i)); + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; + std::cout << "expected: " << dfdu(z, displacement) << std::endl; + std::cout << std::endl; + + std::cout << "df_d(du_dx): "; + std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << get<1>(df_ddisp(i)); + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; + std::cout << "expected: " << dfddudx(z, displacement) << std::endl; + std::cout << std::endl; + + std::cout << "dg_du: "; + std::cout << "{"; for (int i = 0; i < 3; i++) { std::cout << "{"; for (int j = 0; j < 3; j++) { - std::cout << get<0>(df_ddisp(i,j)); + std::cout << get<0>(dg_ddisp(i,j)); if (j != 2) { std::cout << ","; } } std::cout << "}"; if (i != 2) { std::cout << ","; } } std::cout << "}" << std::endl; - std::cout << "expected: " << dfdu(z, displacement) << std::endl; + std::cout << "expected: " << dgdu(z, displacement) << std::endl; std::cout << std::endl; - std::cout << "df_d(du_dx): "; + + std::cout << "dg_d(du_dx): "; std::cout << "{"; for (int i = 0; i < 3; i++) { std::cout << "{"; for (int j = 0; j < 3; j++) { - std::cout << get<1>(df_ddisp(i,j)); + std::cout << get<1>(dg_ddisp(i,j)); if (j != 2) { std::cout << ","; } } std::cout << "}"; if (i != 2) { std::cout << ","; } } std::cout << "}" << std::endl; - std::cout << "expected: " << dfddudx(z, displacement) << std::endl; + std::cout << "expected: " << dgddudx(z, displacement) << std::endl; std::cout << std::endl; } From 581cfe97ce16e51f259fd638465a2de67044d87c Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Sat, 25 May 2024 13:39:42 -0700 Subject: [PATCH 07/12] add jvp implementation + test --- cmake/thirdparty/SetupSeracThirdParty.cmake | 4 + .../functional/domain_integral_kernels.hpp | 101 ++++++-- ...ew.hpp => domain_integral_kernels_old.hpp} | 98 +++---- .../numerics/functional/enzyme_wrapper.hpp | 241 ++++++++++++++++++ src/serac/numerics/functional/tensor.hpp | 14 + .../numerics/functional/tests/CMakeLists.txt | 5 + 
.../functional/tests/enzyme_jvp_test.cpp | 84 ++++++ .../numerics/functional/tests/enzyme_test.cpp | 237 +++-------------- 8 files changed, 484 insertions(+), 300 deletions(-) rename src/serac/numerics/functional/{domain_integral_kernels_new.hpp => domain_integral_kernels_old.hpp} (85%) create mode 100644 src/serac/numerics/functional/enzyme_wrapper.hpp create mode 100644 src/serac/numerics/functional/tests/enzyme_jvp_test.cpp diff --git a/cmake/thirdparty/SetupSeracThirdParty.cmake b/cmake/thirdparty/SetupSeracThirdParty.cmake index 3f9816de6e..532541a3dc 100644 --- a/cmake/thirdparty/SetupSeracThirdParty.cmake +++ b/cmake/thirdparty/SetupSeracThirdParty.cmake @@ -564,6 +564,7 @@ endif() #------------------------------------------------------------------------------ # Enzyme #------------------------------------------------------------------------------ +if (FALSE) include(FetchContent) FetchContent_Declare( @@ -572,3 +573,6 @@ FetchContent_Declare( SOURCE_SUBDIR enzyme ) FetchContent_MakeAvailable(enzyme) +else() +add_subdirectory("/home/sam/code/Enzyme/official/enzyme" ${PROJECT_BINARY_DIR}/tmp/enzyme) +endif () diff --git a/src/serac/numerics/functional/domain_integral_kernels.hpp b/src/serac/numerics/functional/domain_integral_kernels.hpp index 9f13cd42fb..ff8dc50737 100644 --- a/src/serac/numerics/functional/domain_integral_kernels.hpp +++ b/src/serac/numerics/functional/domain_integral_kernels.hpp @@ -9,6 +9,7 @@ #include "serac/numerics/functional/quadrature_data.hpp" #include "serac/numerics/functional/function_signature.hpp" #include "serac/numerics/functional/differentiate_wrt.hpp" +#include "serac/numerics/functional/enzyme_wrapper.hpp" #include #include @@ -90,24 +91,36 @@ template & arg_tuple) { - return apply_qf_helper(qf, t, x_q, qpt_data, arg_tuple, - std::make_integer_sequence(sizeof...(T))>{}); + return apply_qf_helper(qf, t, x_q, qpt_data, arg_tuple, std::make_integer_sequence(sizeof...(T))>{}); } template auto get_derivative_type(lambda qf, 
qpt_data_type&& qpt_data) { using qf_arguments = serac::tuple >::type...>; - return get_gradient(apply_qf(qf, double{}, serac::tuple, tensor >{}, qpt_data, - make_dual_wrt(qf_arguments{}))); + using output_type = decltype(apply_qf(qf, double{}, serac::tuple, tensor >{}, qpt_data, qf_arguments{})); + return typename impl::nested< output_type, decltype(type(qf_arguments{})) >::type{}; }; -template + +template +SERAC_HOST_DEVICE auto parent_to_physical(const tuple& qf_input, const tensor & invJ) { + return tuple{get<0>(qf_input), dot(get<1>(qf_input), invJ)}; +} + +template +SERAC_HOST_DEVICE auto physical_to_parent(const tuple& qf_output, const tensor & invJ, double detJ) { + // assumes family == Family::H1 for now + return tuple{get<0>(qf_output) * detJ, dot(get<1>(qf_output), transpose(invJ)) * detJ}; +} + +template < typename lambda, int dim, int n, typename... T> SERAC_HOST_DEVICE auto batch_apply_qf_no_qdata(lambda qf, double t, const tensor& x, const tensor& J, const T&... inputs) { using position_t = serac::tuple, tensor >; using return_type = decltype(qf(double{}, position_t{}, T{}[0]...)); + tensor outputs{}; for (int i = 0; i < n; i++) { tensor x_q; @@ -118,11 +131,55 @@ SERAC_HOST_DEVICE auto batch_apply_qf_no_qdata(lambda qf, double t, const tensor } x_q[j] = x(j, i); } - outputs[i] = qf(t, serac::tuple{x_q, J_q}, inputs[i]...); + double detJ_q = det(J_q); + tensor invJ_q = inv(J_q); + auto qf_output = qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...); + physical_to_parent(qf_output, invJ_q, detJ_q); + outputs[i] = qf_output; } return outputs; } +template < uint32_t differentiation_index, typename derivative_type, typename lambda, int dim, int n, typename... T> +SERAC_HOST_DEVICE auto batch_apply_qf_derivative(derivative_type * doutputs, lambda qf, double t, const tensor& x, + const tensor& J, const tensor &... 
inputs) +{ + for (int i = 0; i < n; i++) { + tensor x_q; + tensor J_q; + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + J_q[j][k] = J(k, j, i); + } + x_q[j] = x(j, i); + } + double detJ_q = det(J_q); + tensor invJ_q = inv(J_q); + auto func = [&](const auto & ... input){ + return physical_to_parent(qf(t, serac::tuple{x_q, J_q}, parent_to_physical(input, invJ_q) ...), invJ_q, detJ_q); + }; + + doutputs[i] = jacfwd(func, inputs[i]...); + + std::cout << doutputs[i] << std::endl; + std::cout << get_gradient(func(make_dual(inputs[i]...))) << std::endl; + + #if 0 + tuple{ + tuple{0.1309, {-0.0013125, 0.265717, 0.484116}}, + tuple{{-0.004375, 0.510588, 1.58955}, {{3.4324, 0.536657, 1.05022}, {2.15217, 4.00694, 1.96279}, {1.47173, 2.55073, 4.34394}}} + } + tuple{ + tuple{0.1309, {-0.0013125, 0.265717, 0.484116}}, + tuple{{-0.004375, 0.536657, 1.96279}, {{0.510588, 1.58955, 3.4324}, {1.05022, 2.15217, 4.00694}, {1.47173, 2.55073, 4.34394}}}} + } + #endif + + } + + return doutputs; +} + template SERAC_HOST_DEVICE auto batch_apply_qf(lambda qf, double t, const tensor& x, const tensor& J, qpt_data_type* qpt_data, bool update_state, @@ -179,37 +236,21 @@ void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, do //[[maybe_unused]] static constexpr trial_element_tuple trial_element_tuple{}; // batch-calculate values / derivatives of each trial space, at each quadrature point - [[maybe_unused]] tuple qf_inputs = {promote_each_to_dual_when( - get(trial_elements).interpolate(get(u)[elements[e]], rule))...}; - - // use J_e to transform values / derivatives on the parent element - // to the to the corresponding values / derivatives on the physical element - (parent_to_physical(trial_elements).family>(get(qf_inputs), J_e), ...); + tuple qf_inputs = {get(trial_elements).interpolate(get(u)[elements[e]], rule)...}; // (batch) evalute the q-function at each quadrature point // // note: the weird immediately-invoked lambda expression is // a 
workaround for a bug in GCC(<12.0) where it fails to // decide which function overload to use, and crashes - auto qf_outputs = [&]() { - if constexpr (std::is_same_v) { - return batch_apply_qf_no_qdata(qf, t, x_e, J_e, get(qf_inputs)...); - } else { - return batch_apply_qf(qf, t, x_e, J_e, &qf_state(e, 0), update_state, get(qf_inputs)...); - } - }(); - - // use J to transform sources / fluxes on the physical element - // back to the corresponding sources / fluxes on the parent element - physical_to_parent(qf_outputs, J_e); + // TODO: reenable internal_variables + auto qf_outputs = batch_apply_qf_no_qdata(qf, t, x_e, J_e, get(qf_inputs)...); // write out the q-function derivatives after applying the // physical_to_parent transformation, so that those transformations // won't need to be applied in the action_of_gradient and element_gradient kernels if constexpr (differentiation_index != serac::NO_DIFFERENTIATION) { - for (int q = 0; q < leading_dimension(qf_outputs); q++) { - qf_derivatives[e * uint32_t(qpts_per_elem) + uint32_t(q)] = get_gradient(qf_outputs[q]); - } + batch_apply_qf_derivative(qf_derivatives + e * uint32_t(qpts_per_elem), qf, t, x_e, J_e, get(qf_inputs)...); } // (batch) integrate the material response against the test-space basis functions @@ -229,6 +270,14 @@ SERAC_HOST_DEVICE auto chain_rule(const S& dfdx, const T& dx) } if constexpr (!is_QOI) { +#if 0 + serac::tuple< + serac::tuple< double, serac::tensor >, + serac::tensor< serac::tuple >, 2> >, 3, serac::tuple + +#endif + + return serac::tuple{serac::chain_rule(serac::get<0>(serac::get<0>(dfdx)), serac::get<0>(dx)) + serac::chain_rule(serac::get<1>(serac::get<0>(dfdx)), serac::get<1>(dx)), serac::chain_rule(serac::get<0>(serac::get<1>(dfdx)), serac::get<0>(dx)) + diff --git a/src/serac/numerics/functional/domain_integral_kernels_new.hpp b/src/serac/numerics/functional/domain_integral_kernels_old.hpp similarity index 85% rename from src/serac/numerics/functional/domain_integral_kernels_new.hpp 
rename to src/serac/numerics/functional/domain_integral_kernels_old.hpp index 128f74982c..9f13cd42fb 100644 --- a/src/serac/numerics/functional/domain_integral_kernels_new.hpp +++ b/src/serac/numerics/functional/domain_integral_kernels_old.hpp @@ -102,76 +102,25 @@ auto get_derivative_type(lambda qf, qpt_data_type&& qpt_data) make_dual_wrt(qf_arguments{}))); }; - -template -SERAC_HOST_DEVICE auto parent_to_physical(const tuple& qf_input, const tensor & invJ) { - return tuple{get<0>(qf_input), dot(get<1>(qf_input), invJ)}; -} - -template -SERAC_HOST_DEVICE auto physical_to_parent(const tuple& qf_output, const tensor & invJ, double detJ) { - // assumes family == Family::H1 for now - return tuple{get<0>(qf_output) * detJ, dot(get<1>(qf_output), transpose(invJ)) * detJ}; -} - -double scalar_type(double) { return 0.0; } - -template < typename T > -dual scalar_type(dual) { return dual{}; } - -template < typename T, int ... n > -T scalar_type(tensor) { return {}; } - -template < typename T1, typename T2 > -auto scalar_type(tuple< T1, T2 >) { return tuple{ scalar_type(T1{}), scalar_type(T2{}) }; } - -template < uint32_t differentiation_index, typename lambda, int dim, int n, typename... T> +template SERAC_HOST_DEVICE auto batch_apply_qf_no_qdata(lambda qf, double t, const tensor& x, const tensor& J, const T&... 
inputs) { - if constexpr (differentiation_index == NO_DIFFERENTIATION) { - using position_t = serac::tuple, tensor >; - using return_type = decltype(qf(double{}, position_t{}, T{}[0]...)); - - tensor outputs{}; - for (int i = 0; i < n; i++) { - tensor x_q; - tensor J_q; - for (int j = 0; j < dim; j++) { - for (int k = 0; k < dim; k++) { - J_q[j][k] = J(k, j, i); - } - x_q[j] = x(j, i); - } - double detJ_q = det(J_q); - tensor invJ_q = inv(J_q); - auto qf_output = qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...); - physical_to_parent(qf_output, invJ_q, detJ_q); - outputs[i] = qf_output; - } - return outputs; - - } else { - using position_t = serac::tuple, tensor >; - using return_value_type = decltype(qf(double{}, position_t{}, get_value(T{}[0])...)); - using gradient_type = decltype(get_gradient((scalar_type(inputs[0]) + ...))); - - tensor outputs{}; - for (int i = 0; i < n; i++) { - tensor x_q; - tensor J_q; - for (int j = 0; j < dim; j++) { - for (int k = 0; k < dim; k++) { - J_q[j][k] = J(k, j, i); - } - x_q[j] = x(j, i); + using position_t = serac::tuple, tensor >; + using return_type = decltype(qf(double{}, position_t{}, T{}[0]...)); + tensor outputs{}; + for (int i = 0; i < n; i++) { + tensor x_q; + tensor J_q; + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + J_q[j][k] = J(k, j, i); } - double detJ_q = det(J_q); - tensor invJ_q = inv(J_q); - outputs[i] = physical_to_parent(qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...), invJ_q, detJ_q); + x_q[j] = x(j, i); } - return outputs; + outputs[i] = qf(t, serac::tuple{x_q, J_q}, inputs[i]...); } + return outputs; } template @@ -230,21 +179,36 @@ void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, do //[[maybe_unused]] static constexpr trial_element_tuple trial_element_tuple{}; // batch-calculate values / derivatives of each trial space, at each quadrature point - tuple qf_inputs = 
{get(trial_elements).interpolate(get(u)[elements[e]], rule)...}; + [[maybe_unused]] tuple qf_inputs = {promote_each_to_dual_when( + get(trial_elements).interpolate(get(u)[elements[e]], rule))...}; + + // use J_e to transform values / derivatives on the parent element + // to the to the corresponding values / derivatives on the physical element + (parent_to_physical(trial_elements).family>(get(qf_inputs), J_e), ...); // (batch) evalute the q-function at each quadrature point // // note: the weird immediately-invoked lambda expression is // a workaround for a bug in GCC(<12.0) where it fails to // decide which function overload to use, and crashes - auto qf_outputs = batch_apply_qf_no_qdata(qf, t, x_e, J_e, get(qf_inputs)...); + auto qf_outputs = [&]() { + if constexpr (std::is_same_v) { + return batch_apply_qf_no_qdata(qf, t, x_e, J_e, get(qf_inputs)...); + } else { + return batch_apply_qf(qf, t, x_e, J_e, &qf_state(e, 0), update_state, get(qf_inputs)...); + } + }(); + + // use J to transform sources / fluxes on the physical element + // back to the corresponding sources / fluxes on the parent element + physical_to_parent(qf_outputs, J_e); // write out the q-function derivatives after applying the // physical_to_parent transformation, so that those transformations // won't need to be applied in the action_of_gradient and element_gradient kernels if constexpr (differentiation_index != serac::NO_DIFFERENTIATION) { for (int q = 0; q < leading_dimension(qf_outputs); q++) { - //qf_derivatives[e * uint32_t(qpts_per_elem) + uint32_t(q)] = get_gradient(qf_outputs[q]); + qf_derivatives[e * uint32_t(qpts_per_elem) + uint32_t(q)] = get_gradient(qf_outputs[q]); } } diff --git a/src/serac/numerics/functional/enzyme_wrapper.hpp b/src/serac/numerics/functional/enzyme_wrapper.hpp new file mode 100644 index 0000000000..3ca4317124 --- /dev/null +++ b/src/serac/numerics/functional/enzyme_wrapper.hpp @@ -0,0 +1,241 @@ +#pragma once + +#include + +#include 
"serac/numerics/functional/tuple.hpp" +#include "serac/numerics/functional/tensor.hpp" + +namespace serac { + +namespace impl { + + constexpr int vsize(const double &) { return 1; } + + template < int n > + constexpr int vsize(const tensor< double, n > &) { return n; } + + template < int m, int n > + constexpr int vsize(const tensor< double, m, n > &) { return m * n; } + + template < typename T0, typename T1 > + constexpr int vsize(const tuple< T0, T1 > &) { return vsize(T0{}) + vsize(T1{}); } + + template < typename T0, typename T1 > + struct nested; + + template <> + struct nested< double, double >{ using type = double; }; + + template < int ... n > + struct nested< double, tensor >{ using type = tensor; }; + + template < typename T0, typename T1 > + struct nested< double, tuple >{ using type = tuple; }; + +//////////////////////////////////////////////////////////////////////////////// + + template < int ... n, typename T > + struct nested< tensor, T >{ using type = tensor; }; + +//////////////////////////////////////////////////////////////////////////////// + + //template < int ... n, typename T > + //struct nested< tensor, T >{ using type = tensor; }; + +// sam: with tuples-of-tuples, figuring out which indices correspond to the directional +// derivatives gets pretty nasty, so I'm commenting these out to prevent jacfwd from +// compiling in that case +// +// template < typename S0, typename S1, typename T > +// struct nested< tuple< S0, S1 >, T >{ +// using type = tuple< typename nested::type, typename nested::type >; +// }; +// +// template < typename S0, typename S1, typename T0, typename T1 > +// struct nested< tuple< S0, S1 >, tuple< T0, T1 > >{ +// using type = tuple< +// tuple< typename nested::type, typename nested::type >, +// tuple< typename nested::type, typename nested::type > +// >; +// }; + +//////////////////////////////////////////////////////////////////////////////// + + template< typename output_type, typename function, typename ... 
arg_types > + void wrapper(output_type & output, const function & f, const arg_types & ... args) { + output = f(args...); + } + +//////////////////////////////////////////////////////////////////////////////// + + template < typename function, typename input_type > + __attribute__((always_inline)) + auto jvp(const function & f, const input_type & x) { + using output_type = decltype(f(x)); + void * func_ptr = reinterpret_cast(wrapper< output_type, function, input_type >); + return [=](const input_type & dx) { + output_type unused{}; + output_type df{}; + __enzyme_fwddiff(func_ptr, + enzyme_dupnoneed, &unused, &df, + enzyme_const, reinterpret_cast(&f), + enzyme_dup, &x, &dx + ); + return df; + }; + } + +//////////////////////////////////////////////////////////////////////////////// + + template < typename function, typename input_type > + __attribute__((always_inline)) + auto jacfwd(const function & f, const input_type & x) { + using output_type = decltype(f(x)); + using jac_type = typename impl::nested::type; + void * func_ptr = reinterpret_cast(wrapper< output_type, function, input_type >); + + constexpr int m = impl::vsize(output_type{}); + jac_type J{}; + double * J_ptr = reinterpret_cast(&J); + + constexpr int n = impl::vsize(input_type{}); + input_type dx{}; + double * dx_ptr = reinterpret_cast(&dx); + + for (int j = 0; j < n; j++) { + dx_ptr[j] = 1.0; + + std::cout << dx << std::endl; + + output_type unused{}; + output_type df_dxj{}; + __enzyme_fwddiff(func_ptr, + enzyme_dupnoneed, &unused, &df_dxj, + enzyme_const, reinterpret_cast(&f), + enzyme_dup, &x, &dx + ); + + std::cout << df_dxj << std::endl; + + double * df_dxj_ptr = reinterpret_cast(&df_dxj); + for (int i = 0; i < m; i++) { + J_ptr[i * n + j] = df_dxj_ptr[i]; + } + + std::cout << J << std::endl; + + dx_ptr[j] = 0.0; + } + + return J; + } + +} + +//////////////////////////////////////////////////////////////////////////////// + +template < int i, typename function, typename T0 > 
+__attribute__((always_inline)) +auto jvp(const function & f, const T0 & arg0) { + if constexpr (i == 0) { return impl::jvp(f, arg0); } +} + +template < int i, typename function, typename T0, typename T1 > +__attribute__((always_inline)) +auto jvp(const function & f, const T0 & arg0, const T1 & arg1) { + if constexpr (i == 0) { return impl::jvp([&](T0 x){ return f(x, arg1); }, arg0); } + if constexpr (i == 1) { return impl::jvp([&](T1 x){ return f(arg0, x); }, arg1); } +} + +template < int i, typename function, typename T0, typename T1, typename T2 > +__attribute__((always_inline)) +auto jvp(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2) { + if constexpr (i == 0) { return impl::jvp([&](const T0 & x){ return f(x, arg1, arg2); }, arg0); } + if constexpr (i == 1) { return impl::jvp([&](const T1 & x){ return f(arg0, x, arg2); }, arg1); } + if constexpr (i == 2) { return impl::jvp([&](const T2 & x){ return f(arg0, arg1, x); }, arg2); } +} + +template < int i, typename function, typename T0, typename T1, typename T2, typename T3 > +__attribute__((always_inline)) +auto jvp(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3) { + if constexpr (i == 0) { return impl::jvp([&](const T0 & x){ return f(x, arg1, arg2, arg3); }, arg0); } + if constexpr (i == 1) { return impl::jvp([&](const T1 & x){ return f(arg0, x, arg2, arg3); }, arg1); } + if constexpr (i == 2) { return impl::jvp([&](const T2 & x){ return f(arg0, arg1, x, arg3); }, arg2); } + if constexpr (i == 3) { return impl::jvp([&](const T3 & x){ return f(arg0, arg1, arg2, x); }, arg3); } +} + +template < int i, typename function, typename T0, typename T1, typename T2, typename T3, typename T4 > +__attribute__((always_inline)) +auto jvp(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3, const T4 & arg4) { + if constexpr (i == 0) { return impl::jvp([&](const T0 & x){ return f(x, arg1, arg2, arg3, arg4); }, arg0); } + if 
constexpr (i == 1) { return impl::jvp([&](const T1 & x){ return f(arg0, x, arg2, arg3, arg4); }, arg1); } + if constexpr (i == 2) { return impl::jvp([&](const T2 & x){ return f(arg0, arg1, x, arg3, arg4); }, arg2); } + if constexpr (i == 3) { return impl::jvp([&](const T3 & x){ return f(arg0, arg1, arg2, x, arg4); }, arg3); } + if constexpr (i == 4) { return impl::jvp([&](const T4 & x){ return f(arg0, arg1, arg2, arg3, x); }, arg4); } +} + +template < int i, typename function, typename T0, typename T1, typename T2, typename T3, typename T4, typename T5 > +__attribute__((always_inline)) +auto jvp(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3, const T4 & arg4, const T5 & arg5) { + if constexpr (i == 0) { return impl::jvp([&](const T0 & x){ return f(x, arg1, arg2, arg3, arg4, arg5); }, arg0); } + if constexpr (i == 1) { return impl::jvp([&](const T1 & x){ return f(arg0, x, arg2, arg3, arg4, arg5); }, arg1); } + if constexpr (i == 2) { return impl::jvp([&](const T2 & x){ return f(arg0, arg1, x, arg3, arg4, arg5); }, arg2); } + if constexpr (i == 3) { return impl::jvp([&](const T3 & x){ return f(arg0, arg1, arg2, x, arg4, arg5); }, arg3); } + if constexpr (i == 4) { return impl::jvp([&](const T4 & x){ return f(arg0, arg1, arg2, arg3, x, arg5); }, arg4); } + if constexpr (i == 5) { return impl::jvp([&](const T5 & x){ return f(arg0, arg1, arg2, arg3, arg4, x); }, arg5); } +} + +//////////////////////////////////////////////////////////////////////////////// + +template < int i, typename function, typename T0 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0) { + if constexpr (i == 0) { return impl::jacfwd(f, arg0); } +} + +template < int i, typename function, typename T0, typename T1 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1) { + if constexpr (i == 0) { return impl::jacfwd([&](T0 x){ return f(x, arg1); }, arg0); } + if constexpr (i == 1) { 
return impl::jacfwd([&](T1 x){ return f(arg0, x); }, arg1); } +} + +template < int i, typename function, typename T0, typename T1, typename T2 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2) { + if constexpr (i == 0) { return impl::jacfwd([&](T0 x){ return f(x, arg1, arg2); }, arg0); } + if constexpr (i == 1) { return impl::jacfwd([&](T1 x){ return f(arg0, x, arg2); }, arg1); } + if constexpr (i == 2) { return impl::jacfwd([&](T2 x){ return f(arg0, arg1, x); }, arg2); } +} + +template < int i, typename function, typename T0, typename T1, typename T2, typename T3 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3) { + if constexpr (i == 0) { return impl::jacfwd([&](T0 x){ return f(x, arg1, arg2, arg3); }, arg0); } + if constexpr (i == 1) { return impl::jacfwd([&](T1 x){ return f(arg0, x, arg2, arg3); }, arg1); } + if constexpr (i == 2) { return impl::jacfwd([&](T2 x){ return f(arg0, arg1, x, arg3); }, arg2); } + if constexpr (i == 3) { return impl::jacfwd([&](T3 x){ return f(arg0, arg1, arg2, x); }, arg3); } +} + +template < int i, typename function, typename T0, typename T1, typename T2, typename T3, typename T4 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3, const T4 & arg4) { + if constexpr (i == 0) { return impl::jacfwd([&](T0 x){ return f(x, arg1, arg2, arg3, arg4); }, arg0); } + if constexpr (i == 1) { return impl::jacfwd([&](T1 x){ return f(arg0, x, arg2, arg3, arg4); }, arg1); } + if constexpr (i == 2) { return impl::jacfwd([&](T2 x){ return f(arg0, arg1, x, arg3, arg4); }, arg2); } + if constexpr (i == 3) { return impl::jacfwd([&](T3 x){ return f(arg0, arg1, arg2, x, arg4); }, arg3); } + if constexpr (i == 4) { return impl::jacfwd([&](T4 x){ return f(arg0, arg1, arg2, arg3, x); }, arg4); } +} + +template < int i, typename 
function, typename T0, typename T1, typename T2, typename T3, typename T4, typename T5 > +__attribute__((always_inline)) +auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3, const T4 & arg4, const T5 & arg5) { + if constexpr (i == 0) { return impl::jacfwd([&](T0 x){ return f(x, arg1, arg2, arg3, arg4, arg5); }, arg0); } + if constexpr (i == 1) { return impl::jacfwd([&](T1 x){ return f(arg0, x, arg2, arg3, arg4, arg5); }, arg1); } + if constexpr (i == 2) { return impl::jacfwd([&](T2 x){ return f(arg0, arg1, x, arg3, arg4, arg5); }, arg2); } + if constexpr (i == 3) { return impl::jacfwd([&](T3 x){ return f(arg0, arg1, arg2, x, arg4, arg5); }, arg3); } + if constexpr (i == 4) { return impl::jacfwd([&](T4 x){ return f(arg0, arg1, arg2, arg3, x, arg5); }, arg4); } + if constexpr (i == 5) { return impl::jacfwd([&](T5 x){ return f(arg0, arg1, arg2, arg3, arg4, x); }, arg5); } +} + +} diff --git a/src/serac/numerics/functional/tensor.hpp b/src/serac/numerics/functional/tensor.hpp index b3733d20e6..8d2c3b4b01 100644 --- a/src/serac/numerics/functional/tensor.hpp +++ b/src/serac/numerics/functional/tensor.hpp @@ -1810,6 +1810,20 @@ SERAC_HOST_DEVICE constexpr auto chain_rule(const tensor& df_dx, c return total; } +/** + * @overload + * @note for a vector-valued function of a tensor, the chain rule contracts over all indices of dx + */ +template +SERAC_HOST_DEVICE constexpr auto chain_rule(const tensor< tensor< double, n ... 
>, m >& df_dx, const tensor& dx) +{ + tensor total{}; + for (int i = 0; i < m; i++) { + total[i] = chain_rule(df_dx[i], dx); + } + return total; +} + /** * @overload * @note for a vector-valued function of a tensor, the chain rule contracts over all indices of dx diff --git a/src/serac/numerics/functional/tests/CMakeLists.txt b/src/serac/numerics/functional/tests/CMakeLists.txt index 06f3bba69e..81c3fd7797 100644 --- a/src/serac/numerics/functional/tests/CMakeLists.txt +++ b/src/serac/numerics/functional/tests/CMakeLists.txt @@ -69,9 +69,14 @@ if(ENABLE_CUDA) endif() +target_compile_options(functional_basic_h1_scalar PUBLIC -mllvm -enzyme-loose-types) + add_executable(enzyme_test enzyme_test.cpp) target_link_libraries(enzyme_test PUBLIC serac_functional) +add_executable(enzyme_jvp_test enzyme_jvp_test.cpp) +target_link_libraries(enzyme_jvp_test PUBLIC serac_functional) + add_executable(enzyme_reproducer enzyme_reproducer.cpp) target_link_libraries(enzyme_reproducer PUBLIC serac_functional) diff --git a/src/serac/numerics/functional/tests/enzyme_jvp_test.cpp b/src/serac/numerics/functional/tests/enzyme_jvp_test.cpp new file mode 100644 index 0000000000..81acd32352 --- /dev/null +++ b/src/serac/numerics/functional/tests/enzyme_jvp_test.cpp @@ -0,0 +1,84 @@ +#include + +#include "serac/numerics/functional/dual.hpp" +#include "serac/numerics/functional/tuple.hpp" +#include "serac/numerics/functional/tensor.hpp" +#include "serac/numerics/functional/enzyme_wrapper.hpp" + +namespace serac { + +namespace impl { + +template < typename function, typename input_type > +__attribute__((always_inline)) +auto jvp(const function & f, const input_type & x) { + using output_type = decltype(f(x)); + void * func_ptr = reinterpret_cast(wrapper< output_type, function, input_type >); + return [=](const input_type & dx) { + output_type unused{}; + output_type df{}; + __enzyme_fwddiff(func_ptr, + enzyme_dupnoneed, &unused, &df, + enzyme_const, reinterpret_cast(&f), + enzyme_dup, &x, &dx + 
); + return df; + }; +} + +} + + + +} + +using namespace serac; + +int main() { + + auto f = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return tuple{dot(du_dx, z * u), z * (du_dx + transpose(du_dx)) - outer(u, u)}; + }; + +//////////////////////////////////////////////////////////////////////////////// + + auto f_jvp0 = [](double z, + double dz, + const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return tuple{dot(du_dx, u) * dz, (du_dx + transpose(du_dx)) * dz}; + }; + + auto f_jvp1 = [](double z, + const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement, + const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & ddisplacement) { + auto [u, du_dx] = displacement; + auto [du, ddu_dx] = ddisplacement; + vec3 df1 = dot(du_dx, du) * z + dot(ddu_dx, u) * z; + mat3 df2 = outer(du, u) - outer(u, du) + (ddu_dx + transpose(ddu_dx)) * z; + return tuple{df1, df2}; + }; + +//////////////////////////////////////////////////////////////////////////////// + + double eps = 1.0e-6; + + double z = 3.0; + double dz = 1.4; + + auto displacement = tuple { + tensor{{1.0, 1.0, 1.0}}, + tensor{{{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}, {1.0, 0.5, 0.2}}} + }; + + auto ddisplacement = tuple { + tensor{{0.1, 0.8, -0.3}}, + tensor{{{0.2, -0.4, 0.2}, {0.1, 0.8, 0.5}, {0.3, 0.7, 1.8}}} + }; + + std::cout << "expected: " << f_jvp0(z, dz, displacement) << std::endl; + std::cout << "enzyme: " << jvp<0>(f, z, displacement)(dz) << std::endl; + std::cout << "finite_difference: " << ((f(z + eps * dz, displacement) - f(z, displacement)) / eps) << std::endl; + +} diff --git a/src/serac/numerics/functional/tests/enzyme_test.cpp b/src/serac/numerics/functional/tests/enzyme_test.cpp index fd84d59c32..94441ab7c6 100644 --- a/src/serac/numerics/functional/tests/enzyme_test.cpp +++ b/src/serac/numerics/functional/tests/enzyme_test.cpp @@ -1,182 +1,32 
@@ #include -#include - +#include "serac/numerics/functional/dual.hpp" #include "serac/numerics/functional/tuple.hpp" #include "serac/numerics/functional/tensor.hpp" +#include "serac/numerics/functional/enzyme_wrapper.hpp" -#if 0 -double square(double x) { - return x * x; -} - -double dsquare(double x) { - return __enzyme_autodiff(reinterpret_cast(square), x); -} -#endif +template < typename T > +struct undefined; using namespace serac; -namespace impl { - - constexpr int vsize(const double &) { return 1; } - - template < int n > - constexpr int vsize(const serac::tensor< double, n > &) { return n; } - - template < int m, int n > - constexpr int vsize(const serac::tensor< double, m, n > &) { return m * n; } - - template < typename T0, typename T1 > - constexpr int vsize(const serac::tuple< T0, T1 > &) { return vsize(T0{}) + vsize(T1{}); } - - template < typename T0, typename T1 > - struct nested; - - template <> - struct nested< double, double >{ using type = double; }; - - template < int ... n > - struct nested< double, tensor >{ using type = tensor; }; - - template < typename T0, typename T1 > - struct nested< double, tuple >{ using type = tuple; }; - -//////////////////////////////////////////////////////////////////////////////// - - template < int ... n, typename T > - struct nested< tensor, T >{ using type = tensor; }; - -//////////////////////////////////////////////////////////////////////////////// - - template < typename S0, typename S1, typename T > - struct nested< tuple< S0, S1 >, T >{ - using type = tuple< typename nested::type, typename nested::type >; - }; - -} - -template < typename S, typename T0, typename T1 > -struct serac::detail::outer_prod< S, serac::tuple >{ - using type = serac::tuple < - typename outer_prod< S, T0 >::type, - typename outer_prod< S, T1 >::type - >; -}; - -template< typename output_type, typename function, typename ... arg_types > -void wrapper(output_type & output, const function & f, const arg_types & ... 
args) { - output = f(args...); -} - -template < typename function, typename input_type > -__attribute__((always_inline)) -auto jacfwd(const function & f, const input_type & x) { - using output_type = decltype(f(x)); - using jac_type = typename impl::nested::type; - void * func_ptr = reinterpret_cast(wrapper< output_type, function, input_type >); - - constexpr int m = impl::vsize(output_type{}); - jac_type J{}; - double * J_ptr = reinterpret_cast(&J); - - constexpr int n = impl::vsize(input_type{}); - input_type dx{}; - double * dx_ptr = reinterpret_cast(&dx); - - for (int j = 0; j < n; j++) { - dx_ptr[j] = 1.0; - - output_type unused{}; - output_type df_dxj{}; - __enzyme_fwddiff(func_ptr, - enzyme_dupnoneed, &unused, &df_dxj, - enzyme_const, reinterpret_cast(&f), - enzyme_dup, &x, &dx - ); - - double * df_dxj_ptr = reinterpret_cast(&df_dxj); - for (int i = 0; i < m; i++) { - J_ptr[i * n + j] = df_dxj_ptr[i]; - } - - dx_ptr[j] = 0.0; - } - - return J; -} - -template < int i, typename function, typename T0, typename T1 > -__attribute__((always_inline)) -auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1) { - if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1); }, arg0); } - if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x); }, arg1); } -} - -template < int i, typename function, typename T0, typename T1, typename T2 > -__attribute__((always_inline)) -auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2) { - if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1, arg2); }, arg0); } - if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x, arg2); }, arg1); } - if constexpr (i == 2) { return jacfwd([&](T2 x){ return f(arg0, arg1, x); }, arg2); } -} - -template < int i, typename function, typename T0, typename T1, typename T2, typename T3 > -__attribute__((always_inline)) -auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3) { - 
if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1, arg2, arg3); }, arg0); } - if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x, arg2, arg3); }, arg1); } - if constexpr (i == 2) { return jacfwd([&](T2 x){ return f(arg0, arg1, x, arg3); }, arg2); } - if constexpr (i == 3) { return jacfwd([&](T3 x){ return f(arg0, arg1, arg2, x); }, arg3); } -} - -template < int i, typename function, typename T0, typename T1, typename T2, typename T3, typename T4 > -__attribute__((always_inline)) -auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3, const T4 & arg4) { - if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1, arg2, arg3, arg4); }, arg0); } - if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x, arg2, arg3, arg4); }, arg1); } - if constexpr (i == 2) { return jacfwd([&](T2 x){ return f(arg0, arg1, x, arg3, arg4); }, arg2); } - if constexpr (i == 3) { return jacfwd([&](T3 x){ return f(arg0, arg1, arg2, x, arg4); }, arg3); } - if constexpr (i == 4) { return jacfwd([&](T4 x){ return f(arg0, arg1, arg2, arg3, x); }, arg4); } -} - -template < int i, typename function, typename T0, typename T1, typename T2, typename T3, typename T4, typename T5 > -__attribute__((always_inline)) -auto jacfwd(const function & f, const T0 & arg0, const T1 & arg1, const T2 & arg2, const T3 & arg3, const T4 & arg4, const T5 & arg5) { - if constexpr (i == 0) { return jacfwd([&](T0 x){ return f(x, arg1, arg2, arg3, arg4, arg5); }, arg0); } - if constexpr (i == 1) { return jacfwd([&](T1 x){ return f(arg0, x, arg2, arg3, arg4, arg5); }, arg1); } - if constexpr (i == 2) { return jacfwd([&](T2 x){ return f(arg0, arg1, x, arg3, arg4, arg5); }, arg2); } - if constexpr (i == 3) { return jacfwd([&](T3 x){ return f(arg0, arg1, arg2, x, arg4, arg5); }, arg3); } - if constexpr (i == 4) { return jacfwd([&](T4 x){ return f(arg0, arg1, arg2, arg3, x, arg5); }, arg4); } - if constexpr (i == 5) { return 
jacfwd([&](T5 x){ return f(arg0, arg1, arg2, arg3, arg4, x); }, arg5); } -} - int main() { - auto fg = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto f = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { auto [u, du_dx] = displacement; - return tuple{dot(du_dx, z * u), z * (du_dx + transpose(du_dx)) - outer(u, u)}; + return z * (du_dx + transpose(du_dx)) - outer(u, u); }; //////////////////////////////////////////////////////////////////////////////// auto dfdz = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { - auto [u, du_dx] = displacement; - return dot(du_dx, u); - }; - - auto dgdz = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { auto [u, du_dx] = displacement; return (du_dx + transpose(du_dx)); }; //////////////////////////////////////////////////////////////////////////////// - auto dfdu = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { - return get<1>(displacement) * z; - }; - - auto dgdu = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto dfdu = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { auto u = get<0>(displacement); tensor output{}; for (int k = 0; k < 3; k++) { @@ -191,20 +41,7 @@ int main() { //////////////////////////////////////////////////////////////////////////////// - auto dfddudx = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { - auto u = get<0>(displacement); - tensor output{}; - for (int k = 0; k < 3; k++) { - for (int j = 0; j < 3; j++) { - for (int i = 0; i < 3; i++) { - output(i,j,k) = z * u[k] * (i == j); - } - } - } - return output; - }; - - auto dgddudx = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > &) { + auto dfddudx = [](double z, const tuple< tensor< double, 3 >, 
tensor< double, 3, 3 > > &) { tensor output{}; for (int l = 0; l < 3; l++) { for (int k = 0; k < 3; k++) { @@ -226,65 +63,51 @@ int main() { tensor{{{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}, {1.0, 0.5, 0.2}}} }; - auto [df_dz, dg_dz] = jacfwd<0>(fg, z, displacement); + auto df_dz = jacfwd<0>(f, z, displacement); std::cout << "df_dz: " << df_dz << std::endl; std::cout << "expected: " << dfdz(z, displacement) << std::endl; std::cout << std::endl; - std::cout << "df_dz: " << dg_dz << std::endl; - std::cout << "expected: " << dgdz(z, displacement) << std::endl; - std::cout << std::endl; - - auto [df_ddisp, dg_ddisp] = jacfwd<1>(fg, z, displacement); + auto df_ddisp = jacfwd<1>(f, z, displacement); std::cout << "df_du: "; std::cout << "{"; for (int i = 0; i < 3; i++) { - std::cout << get<0>(df_ddisp(i)); - if (i != 2) { std::cout << ","; } - } - std::cout << "}" << std::endl; - std::cout << "expected: " << dfdu(z, displacement) << std::endl; - std::cout << std::endl; - - std::cout << "df_d(du_dx): "; - std::cout << "{"; - for (int i = 0; i < 3; i++) { - std::cout << get<1>(df_ddisp(i)); - if (i != 2) { std::cout << ","; } - } - std::cout << "}" << std::endl; - std::cout << "expected: " << dfddudx(z, displacement) << std::endl; - std::cout << std::endl; - - std::cout << "dg_du: "; - std::cout << "{"; - for (int i = 0; i < 3; i++) { - std::cout << "{"; for (int j = 0; j < 3; j++) { - std::cout << get<0>(dg_ddisp(i,j)); + std::cout << get<0>(df_ddisp(i,j)); if (j != 2) { std::cout << ","; } } - std::cout << "}"; if (i != 2) { std::cout << ","; } } std::cout << "}" << std::endl; - std::cout << "expected: " << dgdu(z, displacement) << std::endl; + std::cout << "expected: " << dfdu(z, displacement) << std::endl; std::cout << std::endl; - - std::cout << "dg_d(du_dx): "; + std::cout << "df_d(du_dx): "; std::cout << "{"; for (int i = 0; i < 3; i++) { - std::cout << "{"; for (int j = 0; j < 3; j++) { - std::cout << get<1>(dg_ddisp(i,j)); + std::cout << get<1>(df_ddisp(i,j)); if (j != 2) 
{ std::cout << ","; } } - std::cout << "}"; if (i != 2) { std::cout << ","; } } std::cout << "}" << std::endl; - std::cout << "expected: " << dgddudx(z, displacement) << std::endl; + std::cout << "expected: " << dfddudx(z, displacement) << std::endl; std::cout << std::endl; + //tuple< double, vec3 > x; + //serac::tuple< + // serac::tuple, + // serac::tuple + //> + //auto x_dx = make_dual(x); + //undefined< decltype(get_gradient(x_dx)) > foo; + + //serac::tuple< + // serac::tuple>, + // serac::tuple, serac::tensor, 3> + //> + //using derivative_type = impl::nested< tuple< double, vec3 >, tuple< double, vec3 > >::type; + //undefined< derivative_type > foo; + } From 07162d9b0eecad292f2b1603e8f4dce5d8c2bf69 Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Tue, 28 May 2024 12:59:55 -0700 Subject: [PATCH 08/12] adding some more tests --- .../functional/detail/triangle_H1.inl | 2 +- .../functional/domain_integral_kernels.hpp | 26 ++- .../numerics/functional/enzyme_wrapper.hpp | 148 +++++++++++++----- src/serac/numerics/functional/tensor.hpp | 12 ++ .../numerics/functional/tests/CMakeLists.txt | 6 + .../functional/tests/enzyme_jvp_perf_test.cxx | 69 ++++++++ .../functional/tests/enzyme_jvp_test.cpp | 27 ---- .../functional/tests/enzyme_test_tuples.cpp | 92 +++++++++++ .../tests/enzyme_test_tuples_of_tuples.cpp | 147 +++++++++++++++++ 9 files changed, 444 insertions(+), 85 deletions(-) create mode 100644 src/serac/numerics/functional/tests/enzyme_jvp_perf_test.cxx create mode 100644 src/serac/numerics/functional/tests/enzyme_test_tuples.cpp create mode 100644 src/serac/numerics/functional/tests/enzyme_test_tuples_of_tuples.cpp diff --git a/src/serac/numerics/functional/detail/triangle_H1.inl b/src/serac/numerics/functional/detail/triangle_H1.inl index 37c1600fa0..fe3bda06d5 100644 --- a/src/serac/numerics/functional/detail/triangle_H1.inl +++ b/src/serac/numerics/functional/detail/triangle_H1.inl @@ -240,7 +240,7 @@ struct finite_element > { } template - static auto 
batch_apply_shape_fn(int j, tensor input, const TensorProductQuadratureRule&) + static auto batch_apply_shape_fn(int j, tensor input, const TensorProductQuadratureRule&) { using source_t = decltype(get<0>(get<0>(in_t{})) + dot(get<1>(get<0>(in_t{})), tensor{})); using flux_t = decltype(get<0>(get<1>(in_t{})) + dot(get<1>(get<1>(in_t{})), tensor{})); diff --git a/src/serac/numerics/functional/domain_integral_kernels.hpp b/src/serac/numerics/functional/domain_integral_kernels.hpp index ff8dc50737..89df6ab5be 100644 --- a/src/serac/numerics/functional/domain_integral_kernels.hpp +++ b/src/serac/numerics/functional/domain_integral_kernels.hpp @@ -164,16 +164,7 @@ SERAC_HOST_DEVICE auto batch_apply_qf_derivative(derivative_type * doutputs, lam std::cout << doutputs[i] << std::endl; std::cout << get_gradient(func(make_dual(inputs[i]...))) << std::endl; - #if 0 - tuple{ - tuple{0.1309, {-0.0013125, 0.265717, 0.484116}}, - tuple{{-0.004375, 0.510588, 1.58955}, {{3.4324, 0.536657, 1.05022}, {2.15217, 4.00694, 1.96279}, {1.47173, 2.55073, 4.34394}}} - } - tuple{ - tuple{0.1309, {-0.0013125, 0.265717, 0.484116}}, - tuple{{-0.004375, 0.536657, 1.96279}, {{0.510588, 1.58955, 3.4324}, {1.05022, 2.15217, 4.00694}, {1.47173, 2.55073, 4.34394}}}} - } - #endif + //doutputs[i] = get_gradient(func(make_dual(inputs[i]...))); } @@ -210,7 +201,9 @@ template void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, double t, - const std::vector& inputs, double* outputs, const double* positions, + const std::vector& inputs_e, + double* outputs, + const double* positions, const double* jacobians, lambda_type qf, [[maybe_unused]] axom::ArrayView qf_state, [[maybe_unused]] derivative_type* qf_derivatives, const int* elements, @@ -226,7 +219,7 @@ void evaluation_kernel_impl(trial_element_tuple trial_elements, test_element, do [[maybe_unused]] auto qpts_per_elem = num_quadrature_points(geom, Q); [[maybe_unused]] tuple u = { - 
reinterpret_cast(trial_elements))::dof_type*>(inputs[indices])...}; + reinterpret_cast(trial_elements))::dof_type*>(inputs_e[indices])...}; // for each element in the domain for (uint32_t e = 0; e < num_elements; ++e) { @@ -271,9 +264,12 @@ SERAC_HOST_DEVICE auto chain_rule(const S& dfdx, const T& dx) if constexpr (!is_QOI) { #if 0 - serac::tuple< - serac::tuple< double, serac::tensor >, - serac::tensor< serac::tuple >, 2> >, 3, serac::tuple + +serac::tuple< + serac::tuple< serac::tensor, 2>, serac::tensor, 2> >, + serac::tuple< serac::tensor, 2, 2>, serac::tensor, 2, 2>> +>, +serac::tuple< serac::tensor, serac::tensor > #endif diff --git a/src/serac/numerics/functional/enzyme_wrapper.hpp b/src/serac/numerics/functional/enzyme_wrapper.hpp index 3ca4317124..83ee107233 100644 --- a/src/serac/numerics/functional/enzyme_wrapper.hpp +++ b/src/serac/numerics/functional/enzyme_wrapper.hpp @@ -39,25 +39,22 @@ namespace impl { //////////////////////////////////////////////////////////////////////////////// - //template < int ... 
n, typename T > - //struct nested< tensor, T >{ using type = tensor; }; - -// sam: with tuples-of-tuples, figuring out which indices correspond to the directional -// derivatives gets pretty nasty, so I'm commenting these out to prevent jacfwd from -// compiling in that case -// -// template < typename S0, typename S1, typename T > -// struct nested< tuple< S0, S1 >, T >{ -// using type = tuple< typename nested::type, typename nested::type >; -// }; -// -// template < typename S0, typename S1, typename T0, typename T1 > -// struct nested< tuple< S0, S1 >, tuple< T0, T1 > >{ -// using type = tuple< -// tuple< typename nested::type, typename nested::type >, -// tuple< typename nested::type, typename nested::type > -// >; -// }; + // sam: with tuples-of-tuples, figuring out which indices correspond to the directional + // derivatives gets pretty nasty, so I'm commenting these out to prevent jacfwd from + // compiling in that case + // + template < typename S0, typename S1, typename T > + struct nested< tuple< S0, S1 >, T >{ + using type = tuple< typename nested::type, typename nested::type >; + }; + + template < typename S0, typename S1, typename T0, typename T1 > + struct nested< tuple< S0, S1 >, tuple< T0, T1 > >{ + using type = tuple< + tuple< typename nested::type, typename nested::type >, + tuple< typename nested::type, typename nested::type > + >; + }; //////////////////////////////////////////////////////////////////////////////// @@ -93,41 +90,108 @@ namespace impl { using output_type = decltype(f(x)); using jac_type = typename impl::nested::type; void * func_ptr = reinterpret_cast(wrapper< output_type, function, input_type >); - - constexpr int m = impl::vsize(output_type{}); + jac_type J{}; - double * J_ptr = reinterpret_cast(&J); - - constexpr int n = impl::vsize(input_type{}); input_type dx{}; + double * dx_ptr = reinterpret_cast(&dx); - - for (int j = 0; j < n; j++) { - dx_ptr[j] = 1.0; - std::cout << dx << std::endl; + // tuple input and tuple output + 
if constexpr (is_tuple{} && is_tuple{}) { + + static_assert((tuple_size{} == 2) && (tuple_size{} == 2), "error: jacfwd() currently only supports tuples of 2 values"); + + constexpr int m0 = impl::vsize(get<0>(output_type{})); + constexpr int m1 = impl::vsize(get<1>(output_type{})); + constexpr int n0 = impl::vsize(get<0>(input_type{})); + constexpr int n1 = impl::vsize(get<1>(input_type{})); + + for (int j = 0; j < (n0 + n1); j++) { + dx_ptr[j] = 1.0; + + std::cout << dx << std::endl; - output_type unused{}; - output_type df_dxj{}; - __enzyme_fwddiff(func_ptr, - enzyme_dupnoneed, &unused, &df_dxj, - enzyme_const, reinterpret_cast(&f), - enzyme_dup, &x, &dx - ); + output_type unused{}; + output_type df_dxj{}; + __enzyme_fwddiff(func_ptr, + enzyme_dupnoneed, &unused, &df_dxj, + enzyme_const, reinterpret_cast(&f), + enzyme_dup, &x, &dx + ); + + std::cout << df_dxj << std::endl; + + double * df0_dxj_ptr = reinterpret_cast(&get<0>(df_dxj)); + double * df1_dxj_ptr = reinterpret_cast(&get<1>(df_dxj)); + + if (j < n0) { + int j0 = j; + double * J00_ptr = reinterpret_cast(&get<0>(get<0>(J))); + double * J10_ptr = reinterpret_cast(&get<0>(get<1>(J))); + for (int i0 = 0; i0 < m0; i0++) { + J00_ptr[i0 * n0 + j0] = df0_dxj_ptr[i0]; + } + + for (int i1 = 0; i1 < m1; i1++) { + J10_ptr[i1 * n0 + j0] = df1_dxj_ptr[i1]; + } + } else { + int j1 = j - n0; + double * J01_ptr = reinterpret_cast(&get<1>(get<0>(J))); + double * J11_ptr = reinterpret_cast(&get<1>(get<1>(J))); + for (int i0 = 0; i0 < m0; i0++) { + J01_ptr[i0 * n1 + j1] = df0_dxj_ptr[i0]; + } + + for (int i1 = 0; i1 < m1; i1++) { + J11_ptr[i1 * n1 + j1] = df1_dxj_ptr[i1]; + } + } + + std::cout << J << std::endl; + + dx_ptr[j] = 0.0; + } + + return J; + + } else { + + constexpr int m = impl::vsize(output_type{}); + constexpr int n = impl::vsize(input_type{}); + + double * J_ptr = reinterpret_cast(&J); - std::cout << df_dxj << std::endl; + for (int j = 0; j < n; j++) { + dx_ptr[j] = 1.0; + + std::cout << dx << std::endl; + + 
output_type unused{}; + output_type df_dxj{}; + __enzyme_fwddiff(func_ptr, + enzyme_dupnoneed, &unused, &df_dxj, + enzyme_const, reinterpret_cast(&f), + enzyme_dup, &x, &dx + ); + + std::cout << df_dxj << std::endl; + + double * df_dxj_ptr = reinterpret_cast(&df_dxj); + for (int i = 0; i < m; i++) { + J_ptr[i * n + j] = df_dxj_ptr[i]; + } - double * df_dxj_ptr = reinterpret_cast(&df_dxj); - for (int i = 0; i < m; i++) { - J_ptr[i * n + j] = df_dxj_ptr[i]; + std::cout << J << std::endl; + + dx_ptr[j] = 0.0; } - std::cout << J << std::endl; + return J; - dx_ptr[j] = 0.0; } - return J; + } } diff --git a/src/serac/numerics/functional/tensor.hpp b/src/serac/numerics/functional/tensor.hpp index 8d2c3b4b01..db4e9687c0 100644 --- a/src/serac/numerics/functional/tensor.hpp +++ b/src/serac/numerics/functional/tensor.hpp @@ -1824,6 +1824,18 @@ SERAC_HOST_DEVICE constexpr auto chain_rule(const tensor< tensor< double, n ... return total; } +template +SERAC_HOST_DEVICE constexpr auto chain_rule(const tensor< tensor< double, n ... 
>, m1, m2 >& df_dx, const tensor& dx) +{ + tensor total{}; + for (int i1 = 0; i1 < m1; i1++) { + for (int i2 = 0; i2 < m2; i2++) { + total[i1][i2] = chain_rule(df_dx[i1][i2], dx); + } + } + return total; +} + /** * @overload * @note for a vector-valued function of a tensor, the chain rule contracts over all indices of dx diff --git a/src/serac/numerics/functional/tests/CMakeLists.txt b/src/serac/numerics/functional/tests/CMakeLists.txt index 81c3fd7797..aed5bf9367 100644 --- a/src/serac/numerics/functional/tests/CMakeLists.txt +++ b/src/serac/numerics/functional/tests/CMakeLists.txt @@ -77,6 +77,12 @@ target_link_libraries(enzyme_test PUBLIC serac_functional) add_executable(enzyme_jvp_test enzyme_jvp_test.cpp) target_link_libraries(enzyme_jvp_test PUBLIC serac_functional) +add_executable(enzyme_test_tuples enzyme_test_tuples.cpp) +target_link_libraries(enzyme_test_tuples PUBLIC serac_functional) + +add_executable(enzyme_test_tuples_of_tuples enzyme_test_tuples_of_tuples.cpp) +target_link_libraries(enzyme_test_tuples_of_tuples PUBLIC serac_functional) + add_executable(enzyme_reproducer enzyme_reproducer.cpp) target_link_libraries(enzyme_reproducer PUBLIC serac_functional) diff --git a/src/serac/numerics/functional/tests/enzyme_jvp_perf_test.cxx b/src/serac/numerics/functional/tests/enzyme_jvp_perf_test.cxx new file mode 100644 index 0000000000..eb77915c2c --- /dev/null +++ b/src/serac/numerics/functional/tests/enzyme_jvp_perf_test.cxx @@ -0,0 +1,69 @@ +#include + +#include "serac/numerics/functional/tuple.hpp" +#include "serac/numerics/functional/tensor.hpp" +#include "serac/physics/materials/solid_material.hpp" +#include "serac/numerics/functional/enzyme_wrapper.hpp" + +using namespace serac; + +template < int i, typename material_model, typename ... arg_types > +auto precompute_gradient(material_model mat, const std::vector< arg_types > & ... 
args) { + using output_type = decltype(mat(args[0] ...)); + uint32_t n = + +} + +template < int i, typename material_model, typename ... arg_types, typename darg_type > +double jvp_precomputed_gradient( std::vector< arg_types > & ) { + +} + +int main() { + + auto f = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return tuple{dot(du_dx, z * u), z * (du_dx + transpose(du_dx)) - outer(u, u)}; + }; + +//////////////////////////////////////////////////////////////////////////////// + + auto f_jvp0 = [](double z, + double dz, + const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return tuple{dot(du_dx, u) * dz, (du_dx + transpose(du_dx)) * dz}; + }; + + auto f_jvp1 = [](double z, + const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement, + const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & ddisplacement) { + auto [u, du_dx] = displacement; + auto [du, ddu_dx] = ddisplacement; + vec3 df1 = dot(du_dx, du) * z + dot(ddu_dx, u) * z; + mat3 df2 = outer(du, u) - outer(u, du) + (ddu_dx + transpose(ddu_dx)) * z; + return tuple{df1, df2}; + }; + +//////////////////////////////////////////////////////////////////////////////// + + double eps = 1.0e-6; + + double z = 3.0; + double dz = 1.4; + + auto displacement = tuple { + tensor{{1.0, 1.0, 1.0}}, + tensor{{{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}, {1.0, 0.5, 0.2}}} + }; + + auto ddisplacement = tuple { + tensor{{0.1, 0.8, -0.3}}, + tensor{{{0.2, -0.4, 0.2}, {0.1, 0.8, 0.5}, {0.3, 0.7, 1.8}}} + }; + + std::cout << "expected: " << f_jvp0(z, dz, displacement) << std::endl; + std::cout << "enzyme: " << jvp<0>(f, z, displacement)(dz) << std::endl; + std::cout << "finite_difference: " << ((f(z + eps * dz, displacement) - f(z, displacement)) / eps) << std::endl; + +} diff --git a/src/serac/numerics/functional/tests/enzyme_jvp_test.cpp 
b/src/serac/numerics/functional/tests/enzyme_jvp_test.cpp index 81acd32352..04204cfa40 100644 --- a/src/serac/numerics/functional/tests/enzyme_jvp_test.cpp +++ b/src/serac/numerics/functional/tests/enzyme_jvp_test.cpp @@ -5,33 +5,6 @@ #include "serac/numerics/functional/tensor.hpp" #include "serac/numerics/functional/enzyme_wrapper.hpp" -namespace serac { - -namespace impl { - -template < typename function, typename input_type > -__attribute__((always_inline)) -auto jvp(const function & f, const input_type & x) { - using output_type = decltype(f(x)); - void * func_ptr = reinterpret_cast(wrapper< output_type, function, input_type >); - return [=](const input_type & dx) { - output_type unused{}; - output_type df{}; - __enzyme_fwddiff(func_ptr, - enzyme_dupnoneed, &unused, &df, - enzyme_const, reinterpret_cast(&f), - enzyme_dup, &x, &dx - ); - return df; - }; -} - -} - - - -} - using namespace serac; int main() { diff --git a/src/serac/numerics/functional/tests/enzyme_test_tuples.cpp b/src/serac/numerics/functional/tests/enzyme_test_tuples.cpp new file mode 100644 index 0000000000..49f9e942c2 --- /dev/null +++ b/src/serac/numerics/functional/tests/enzyme_test_tuples.cpp @@ -0,0 +1,92 @@ +#include + +#include + +#include "serac/numerics/functional/tuple.hpp" +#include "serac/numerics/functional/tensor.hpp" +#include "serac/numerics/functional/enzyme_wrapper.hpp" + +using namespace serac; + +int main() { + + auto f = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return z * (du_dx + transpose(du_dx)) - outer(u, u); + }; + + auto dfdz = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return (du_dx + transpose(du_dx)); + }; + + auto dfdu = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto u = get<0>(displacement); + tensor output{}; + for (int k = 0; k < 3; k++) { + for (int 
j = 0; j < 3; j++) { + for (int i = 0; i < 3; i++) { + output(i,j,k) = - (u(i) * (j == k) + u(j) * (i == k)); + } + } + } + return output; + }; + + auto dfddudx = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > &) { + tensor output{}; + for (int l = 0; l < 3; l++) { + for (int k = 0; k < 3; k++) { + for (int j = 0; j < 3; j++) { + for (int i = 0; i < 3; i++) { + output(i,j,k,l) = z * ((i==k) * (j==l) + (j==k) * (i==l)); + } + } + } + } + return output; + }; + + double z = 3.0; + auto displacement = tuple { + tensor{{1.0, 1.0, 1.0}}, + tensor{{{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}, {1.0, 0.5, 0.2}}} + }; + + auto df_dz = jacfwd<0>(f, z, displacement); + std::cout << "df_dz: " << df_dz << std::endl; + std::cout << "expected: " << dfdz(z, displacement) << std::endl; + std::cout << std::endl; + + auto df_ddisp = jacfwd<1>(f, z, displacement); + std::cout << "df_du: "; + std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << "{"; + for (int j = 0; j < 3; j++) { + std::cout << get<0>(df_ddisp(i,j)); + if (j != 2) { std::cout << ","; } + } + std::cout << "}"; + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; + std::cout << "expected: " << dfdu(z, displacement) << std::endl; + std::cout << std::endl; + + std::cout << "df_d(du_dx): "; + std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << "{"; + for (int j = 0; j < 3; j++) { + std::cout << get<1>(df_ddisp(i,j)); + if (j != 2) { std::cout << ","; } + } + std::cout << "}"; + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; + std::cout << "expected: " << dfddudx(z, displacement) << std::endl; + std::cout << std::endl; + +} diff --git a/src/serac/numerics/functional/tests/enzyme_test_tuples_of_tuples.cpp b/src/serac/numerics/functional/tests/enzyme_test_tuples_of_tuples.cpp new file mode 100644 index 0000000000..4222dda9b2 --- /dev/null +++ b/src/serac/numerics/functional/tests/enzyme_test_tuples_of_tuples.cpp @@ -0,0 +1,147 @@ +#include + 
+#include + +#include "serac/numerics/functional/tuple.hpp" +#include "serac/numerics/functional/tensor.hpp" +#include "serac/numerics/functional/enzyme_wrapper.hpp" + +using namespace serac; + +int main() { + + auto fg = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return tuple{dot(du_dx, z * u), z * (du_dx + transpose(du_dx)) - outer(u, u)}; + }; + +//////////////////////////////////////////////////////////////////////////////// + + auto dfdz = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return dot(du_dx, u); + }; + + auto dgdz = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto [u, du_dx] = displacement; + return (du_dx + transpose(du_dx)); + }; + +//////////////////////////////////////////////////////////////////////////////// + + auto dfdu = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + return get<1>(displacement) * z; + }; + + auto dgdu = [](double, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto u = get<0>(displacement); + tensor output{}; + for (int k = 0; k < 3; k++) { + for (int j = 0; j < 3; j++) { + for (int i = 0; i < 3; i++) { + output(i,j,k) = - (u(i) * (j == k) + u(j) * (i == k)); + } + } + } + return output; + }; + +//////////////////////////////////////////////////////////////////////////////// + + auto dfddudx = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > & displacement) { + auto u = get<0>(displacement); + tensor output{}; + for (int k = 0; k < 3; k++) { + for (int j = 0; j < 3; j++) { + for (int i = 0; i < 3; i++) { + output(i,j,k) = z * u[k] * (i == j); + } + } + } + return output; + }; + + auto dgddudx = [](double z, const tuple< tensor< double, 3 >, tensor< double, 3, 3 > > &) { + tensor output{}; + for (int l = 0; l < 3; l++) { + 
for (int k = 0; k < 3; k++) { + for (int j = 0; j < 3; j++) { + for (int i = 0; i < 3; i++) { + output(i,j,k,l) = z * ((i==k) * (j==l) + (j==k) * (i==l)); + } + } + } + } + return output; + }; + +//////////////////////////////////////////////////////////////////////////////// + + double z = 3.0; + auto displacement = tuple { + tensor{{1.0, 1.0, 1.0}}, + tensor{{{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}, {1.0, 0.5, 0.2}}} + }; + + auto [df_dz, dg_dz] = jacfwd<0>(fg, z, displacement); + std::cout << "df_dz: " << df_dz << std::endl; + std::cout << "expected: " << dfdz(z, displacement) << std::endl; + std::cout << std::endl; + + std::cout << "df_dz: " << dg_dz << std::endl; + std::cout << "expected: " << dgdz(z, displacement) << std::endl; + std::cout << std::endl; + + auto [df_ddisp, dg_ddisp] = jacfwd<1>(fg, z, displacement); + std::cout << "df_du: "; + std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << get<0>(df_ddisp)(i); + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; + std::cout << "expected: " << dfdu(z, displacement) << std::endl; + std::cout << std::endl; + + std::cout << "df_d(du_dx): "; + std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << get<1>(df_ddisp)(i); + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; + std::cout << "expected: " << dfddudx(z, displacement) << std::endl; + std::cout << std::endl; + + std::cout << "dg_du: "; + std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << "{"; + for (int j = 0; j < 3; j++) { + std::cout << get<0>(dg_ddisp)(i,j); + if (j != 2) { std::cout << ","; } + } + std::cout << "}"; + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; + std::cout << "expected: " << dgdu(z, displacement) << std::endl; + std::cout << std::endl; + + + std::cout << "dg_d(du_dx): "; + std::cout << "{"; + for (int i = 0; i < 3; i++) { + std::cout << "{"; + for (int j = 0; j < 3; j++) { + std::cout << get<1>(dg_ddisp)(i,j); + if (j != 2) { std::cout 
<< ","; } + } + std::cout << "}"; + if (i != 2) { std::cout << ","; } + } + std::cout << "}" << std::endl; + std::cout << "expected: " << dgddudx(z, displacement) << std::endl; + std::cout << std::endl; + +} From e0d48e9ca7ae2d471fff18e3c6fd5aa508a712d3 Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Wed, 29 May 2024 12:49:15 -0700 Subject: [PATCH 09/12] change nested type for tensor of tensors --- src/serac/numerics/functional/enzyme_wrapper.hpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/serac/numerics/functional/enzyme_wrapper.hpp b/src/serac/numerics/functional/enzyme_wrapper.hpp index 83ee107233..bbcd9c8beb 100644 --- a/src/serac/numerics/functional/enzyme_wrapper.hpp +++ b/src/serac/numerics/functional/enzyme_wrapper.hpp @@ -34,8 +34,11 @@ namespace impl { //////////////////////////////////////////////////////////////////////////////// - template < int ... n, typename T > - struct nested< tensor, T >{ using type = tensor; }; + template < int ... n > + struct nested< tensor, double >{ using type = tensor; }; + + template < int ... m, int ... 
n > + struct nested< tensor, tensor >{ using type = tensor; }; //////////////////////////////////////////////////////////////////////////////// From 39bc95a28671029f6af4825be33c452170afadf2 Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Wed, 29 May 2024 13:24:58 -0700 Subject: [PATCH 10/12] fix mistake in qfunction evaluation --- .../functional/domain_integral_kernels.hpp | 5 ++--- .../numerics/functional/enzyme_wrapper.hpp | 17 +++++++++++------ .../tests/functional_basic_h1_scalar.cpp | 7 ++++--- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/serac/numerics/functional/domain_integral_kernels.hpp b/src/serac/numerics/functional/domain_integral_kernels.hpp index 89df6ab5be..700cc6bf27 100644 --- a/src/serac/numerics/functional/domain_integral_kernels.hpp +++ b/src/serac/numerics/functional/domain_integral_kernels.hpp @@ -134,8 +134,7 @@ SERAC_HOST_DEVICE auto batch_apply_qf_no_qdata(lambda qf, double t, const tensor double detJ_q = det(J_q); tensor invJ_q = inv(J_q); auto qf_output = qf(t, serac::tuple{x_q, J_q}, parent_to_physical(inputs[i], invJ_q) ...); - physical_to_parent(qf_output, invJ_q, detJ_q); - outputs[i] = qf_output; + outputs[i] = physical_to_parent(qf_output, invJ_q, detJ_q); } return outputs; } @@ -162,7 +161,7 @@ SERAC_HOST_DEVICE auto batch_apply_qf_derivative(derivative_type * doutputs, lam doutputs[i] = jacfwd(func, inputs[i]...); std::cout << doutputs[i] << std::endl; - std::cout << get_gradient(func(make_dual(inputs[i]...))) << std::endl; + //std::cout << get_gradient(func(make_dual(inputs[i]...))) << std::endl; //doutputs[i] = get_gradient(func(make_dual(inputs[i]...))); diff --git a/src/serac/numerics/functional/enzyme_wrapper.hpp b/src/serac/numerics/functional/enzyme_wrapper.hpp index bbcd9c8beb..8ea92c9b79 100644 --- a/src/serac/numerics/functional/enzyme_wrapper.hpp +++ b/src/serac/numerics/functional/enzyme_wrapper.hpp @@ -9,6 +9,8 @@ namespace serac { namespace impl { + constexpr int vsize(const zero &) { return 0; 
} + constexpr int vsize(const double &) { return 1; } template < int n > @@ -26,6 +28,9 @@ namespace impl { template <> struct nested< double, double >{ using type = double; }; + template < typename T > + struct nested< zero, T >{ using type = zero; }; + template < int ... n > struct nested< double, tensor >{ using type = tensor; }; @@ -112,7 +117,7 @@ namespace impl { for (int j = 0; j < (n0 + n1); j++) { dx_ptr[j] = 1.0; - std::cout << dx << std::endl; + //std::cout << dx << std::endl; output_type unused{}; output_type df_dxj{}; @@ -122,7 +127,7 @@ namespace impl { enzyme_dup, &x, &dx ); - std::cout << df_dxj << std::endl; + //std::cout << df_dxj << std::endl; double * df0_dxj_ptr = reinterpret_cast(&get<0>(df_dxj)); double * df1_dxj_ptr = reinterpret_cast(&get<1>(df_dxj)); @@ -151,7 +156,7 @@ namespace impl { } } - std::cout << J << std::endl; + //std::cout << J << std::endl; dx_ptr[j] = 0.0; } @@ -168,7 +173,7 @@ namespace impl { for (int j = 0; j < n; j++) { dx_ptr[j] = 1.0; - std::cout << dx << std::endl; + //std::cout << dx << std::endl; output_type unused{}; output_type df_dxj{}; @@ -178,14 +183,14 @@ namespace impl { enzyme_dup, &x, &dx ); - std::cout << df_dxj << std::endl; + //std::cout << df_dxj << std::endl; double * df_dxj_ptr = reinterpret_cast(&df_dxj); for (int i = 0; i < m; i++) { J_ptr[i * n + j] = df_dxj_ptr[i]; } - std::cout << J << std::endl; + //std::cout << J << std::endl; dx_ptr[j] = 0.0; } diff --git a/src/serac/numerics/functional/tests/functional_basic_h1_scalar.cpp b/src/serac/numerics/functional/tests/functional_basic_h1_scalar.cpp index 8238748b5a..6db8f20a22 100644 --- a/src/serac/numerics/functional/tests/functional_basic_h1_scalar.cpp +++ b/src/serac/numerics/functional/tests/functional_basic_h1_scalar.cpp @@ -55,11 +55,12 @@ void thermal_test_impl(std::unique_ptr& mesh) residual.AddDomainIntegral( Dimension{}, DependsOn<0>{}, - [=](double /*t*/, auto position, auto temperature) { + [=](double /*t*/, tuple< tensor, tensor > 
position, tuple< double, tensor > temperature) { auto [X, dX_dxi] = position; auto [u, du_dx] = temperature; - auto source = d00 * u + dot(d01, du_dx) - 0.0 * (100 * X[0] * X[1]); - auto flux = d10 * u + dot(d11, du_dx); + double source = d00 * u + dot(d01, du_dx) - 0.0 * (100 * X[0] * X[1]); + tensor flux = d10 * u + dot(d11, du_dx); + flux[0] += X[1]; return serac::tuple{source, flux}; }, *mesh); From a8c91760acc84b0ed623f6f271143417398765b2 Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Wed, 29 May 2024 13:30:58 -0700 Subject: [PATCH 11/12] edit comment --- src/serac/numerics/functional/CMakeLists.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/serac/numerics/functional/CMakeLists.txt b/src/serac/numerics/functional/CMakeLists.txt index 113c002107..fb0ecb96d4 100644 --- a/src/serac/numerics/functional/CMakeLists.txt +++ b/src/serac/numerics/functional/CMakeLists.txt @@ -74,8 +74,6 @@ blt_add_library( # without this, I get # "error while loading shared libraries: libomp.so: cannot open # shared object file: No such file or directory" -# -# are we not using the native OpenMP cmake targets already? 
target_link_libraries(serac_functional PUBLIC OpenMP::OpenMP_CXX) target_link_libraries(serac_functional PUBLIC ClangEnzymeFlags) From 2938e1851f8bb5345f0c3a321219e7ba9f96b66a Mon Sep 17 00:00:00 2001 From: Sam Mish Date: Thu, 30 May 2024 13:44:32 -0700 Subject: [PATCH 12/12] propagate `enzyme-loose-types` to everything downstream of serac::Functional --- src/serac/numerics/functional/CMakeLists.txt | 2 ++ src/serac/numerics/functional/domain_integral_kernels.hpp | 6 ------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/serac/numerics/functional/CMakeLists.txt b/src/serac/numerics/functional/CMakeLists.txt index fb0ecb96d4..d60bd3f41d 100644 --- a/src/serac/numerics/functional/CMakeLists.txt +++ b/src/serac/numerics/functional/CMakeLists.txt @@ -77,6 +77,8 @@ blt_add_library( target_link_libraries(serac_functional PUBLIC OpenMP::OpenMP_CXX) target_link_libraries(serac_functional PUBLIC ClangEnzymeFlags) +target_compile_options(serac_functional PUBLIC -mllvm -enzyme-loose-types) + install(FILES ${functional_headers} DESTINATION include/serac/numerics/functional ) install(FILES ${functional_detail_headers} DESTINATION include/serac/numerics/functional/detail ) diff --git a/src/serac/numerics/functional/domain_integral_kernels.hpp b/src/serac/numerics/functional/domain_integral_kernels.hpp index 700cc6bf27..ba7754b85c 100644 --- a/src/serac/numerics/functional/domain_integral_kernels.hpp +++ b/src/serac/numerics/functional/domain_integral_kernels.hpp @@ -159,12 +159,6 @@ SERAC_HOST_DEVICE auto batch_apply_qf_derivative(derivative_type * doutputs, lam }; doutputs[i] = jacfwd(func, inputs[i]...); - - std::cout << doutputs[i] << std::endl; - //std::cout << get_gradient(func(make_dual(inputs[i]...))) << std::endl; - - //doutputs[i] = get_gradient(func(make_dual(inputs[i]...))); - } return doutputs;