From c649bae835609621b8ea19373916a860dd4f7d08 Mon Sep 17 00:00:00 2001 From: Nils Wentzell Date: Tue, 2 Dec 2025 17:28:23 -0500 Subject: [PATCH] Add C++23 multi-argument operator[] as primary indexing interface - Require C++23 and remove deprecated comma-subscript warnings - Add variadic operator[] to basic_array and basic_array_view - Rename call<>() helper to subscript<>() for clarity - Make operator[] the primary implementation with operator() delegating to it - Update expression templates (expr, expr_unary) consistently - Add variadic operator[] to expr_call and array_adapter - Add comprehensive tests for the new subscript operator Co-Authored-By: Claude --- CMakeLists.txt | 2 - c++/nda/CMakeLists.txt | 2 +- c++/nda/_impl_basic_array_view_common.hpp | 87 ++++++++-------- c++/nda/arithmetic.hpp | 79 +++++++++------ c++/nda/array_adapter.hpp | 16 ++- c++/nda/map.hpp | 43 ++++---- test/c++/nda_basic_array_and_view.cpp | 116 ++++++++++++++++++++++ 7 files changed, 233 insertions(+), 112 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ba4197f31..c3a88da1e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,12 +105,10 @@ target_compile_options(${PROJECT_NAME}_warnings -Wfloat-conversion -Wpedantic -Wno-sign-compare - $<$:-Wno-comma-subscript> $<$:-Wno-psabi> # Disable notes about ABI changes $<$:-Wshadow=local> $<$:-Wno-attributes> $<$:-Wno-deprecated-declarations> - $<$:-Wno-deprecated-comma-subscript> $<$:-Wno-unknown-warning-option> $<$:-Wshadow> $<$:-Wno-gcc-compat> diff --git a/c++/nda/CMakeLists.txt b/c++/nda/CMakeLists.txt index 7e16dfc3a..fd5e1e099 100644 --- a/c++/nda/CMakeLists.txt +++ b/c++/nda/CMakeLists.txt @@ -9,7 +9,7 @@ add_library(${PROJECT_NAME}::${PROJECT_NAME}_c ALIAS ${PROJECT_NAME}_c) target_link_libraries(${PROJECT_NAME}_c PRIVATE $) # Configure target and compilation -target_compile_features(${PROJECT_NAME}_c PUBLIC cxx_std_20) +target_compile_features(${PROJECT_NAME}_c PUBLIC cxx_std_23) set_target_properties(${PROJECT_NAME}_c PROPERTIES POSITION_INDEPENDENT_CODE ON VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR} diff --git a/c++/nda/_impl_basic_array_view_common.hpp b/c++/nda/_impl_basic_array_view_common.hpp index cde88d5e3..59bd2eff6 100644 --- a/c++/nda/_impl_basic_array_view_common.hpp +++ b/c++/nda/_impl_basic_array_view_common.hpp @@ -159,23 +159,23 @@ static constexpr bool has_no_boundcheck = true; public: /** - * @brief Implementation of the function call operator. + * @brief Implementation of the subscript operator. * - * @details This function is an implementation detail an should be private. Since the Green's function library in + * @details This function is an implementation detail and should be private. Since the Green's function library in * TRIQS uses this function, it is kept public (for now). * * @tparam ResultAlgebra Algebra of the resulting view/array. * @tparam SelfIsRvalue True if the view/array is an rvalue. * @tparam Self Type of the calling view/array. - * @tparam T Types of the arguments. + * @tparam Ts Types of the arguments. * - * @param self Calling view. + * @param self Calling view/array. * @param idxs Multi-dimensional index consisting of `long`, `nda::range`, `nda::range::all_t`, nda::ellipsis or lazy * arguments. - * @return Result of the function call depending on the given arguments and type of the view/array. + * @return Result of the subscript operation depending on the given arguments and type of the view/array. */ template -FORCEINLINE static decltype(auto) call(Self &&self, Ts const &...idxs) noexcept(has_no_boundcheck) { +FORCEINLINE static decltype(auto) subscript(Self &&self, Ts const &...idxs) noexcept(has_no_boundcheck) { // resulting value type using r_v_t = std::conditional_t>, ValueType const, ValueType>; @@ -219,9 +219,10 @@ FORCEINLINE static decltype(auto) call(Self &&self, Ts const &...idxs) noexcept( public: /** - * @brief Function call operator to access the view/array. + * @brief Subscript operator to access the view/array. * - * @details Depending on the type of the calling object and the given arguments, this function call does the following: + * @details Depending on the type of the calling object and the given arguments, this subscript operation does the + * following: * - If any of the arguments is lazy, an nda::clef::expr with the nda::clef::tags::function tag is returned. * - If no arguments are given, a full view of the calling object is returned: * - If the calling object itself or its value type is const, a view with a const value type is returned. @@ -234,69 +235,59 @@ FORCEINLINE static decltype(auto) call(Self &&self, Ts const &...idxs) noexcept( * the calling object. The algebra of the slice is the same as well, except if a 1-dimensional slice of a matrix is * taken. In this case, the algebra is changed to 'V'. * - * @tparam Ts Types of the function arguments. + * @tparam Ts Types of the arguments. * @param idxs Multi-dimensional index consisting of `long`, `nda::range`, `nda::range::all_t`, nda::ellipsis or lazy * arguments. - * @return Result of the function call depending on the given arguments and type of the view/array. + * @return Result of the subscript operation depending on the given arguments and type of the view/array. */ template -FORCEINLINE decltype(auto) operator()(Ts const &...idxs) const & noexcept(has_no_boundcheck) { +FORCEINLINE decltype(auto) operator[](Ts const &...idxs) const & noexcept(has_no_boundcheck) { static_assert((rank == -1) or (sizeof...(Ts) == rank) or (sizeof...(Ts) == 0) or (ellipsis_is_present and (sizeof...(Ts) <= rank + 1)), - "Error in array/view: Incorrect number of parameters in call operator"); - return call(*this, idxs...); + "Error in array/view: Incorrect number of parameters in subscript operator"); + return subscript(*this, idxs...); } -/// Non-const overload of `nda::basic_array_view::operator()(Ts const &...) const &`. +/// Non-const overload of `nda::basic_array_view::operator[](Ts const &...) const &`. template -FORCEINLINE decltype(auto) operator()(Ts const &...idxs) & noexcept(has_no_boundcheck) { +FORCEINLINE decltype(auto) operator[](Ts const &...idxs) & noexcept(has_no_boundcheck) { static_assert((rank == -1) or (sizeof...(Ts) == rank) or (sizeof...(Ts) == 0) or (ellipsis_is_present and (sizeof...(Ts) <= rank + 1)), - "Error in array/view: Incorrect number of parameters in call operator"); - return call(*this, idxs...); + "Error in array/view: Incorrect number of parameters in subscript operator"); + return subscript(*this, idxs...); } -/// Rvalue overload of `nda::basic_array_view::operator()(Ts const &...) const &`. +/// Rvalue overload of `nda::basic_array_view::operator[](Ts const &...) const &`. template -FORCEINLINE decltype(auto) operator()(Ts const &...idxs) && noexcept(has_no_boundcheck) { +FORCEINLINE decltype(auto) operator[](Ts const &...idxs) && noexcept(has_no_boundcheck) { static_assert((rank == -1) or (sizeof...(Ts) == rank) or (sizeof...(Ts) == 0) or (ellipsis_is_present and (sizeof...(Ts) <= rank + 1)), - "Error in array/view: Incorrect number of parameters in call operator"); - return call(*this, idxs...); + "Error in array/view: Incorrect number of parameters in subscript operator"); + return subscript(*this, idxs...); } /** - * @brief Subscript operator to access the 1-dimensional view/array. + * @brief Function call operator to access the view/array. * - * @details Depending on the type of the calling object and the given argument, this subscript operation does the - * following: - * - If the argument is lazy, an nda::clef::expr with the nda::clef::tags::function tag is returned. - * - If the argument is convertible to `long`, a single element is accessed: - * - If the calling object is a view or an lvalue, a (const) reference to the element is returned. - * - Otherwise, a copy of the element is returned. - * - Otherwise a slice of the calling object is returned with the same value type, algebra and accessor and owning - * policies as the calling object. + * @details Forwards to the subscript operator. See `nda::basic_array_view::operator[]` for details. * - * @tparam T Type of the argument. - * @param idx 1-dimensional index that is either a `long`, `nda::range`, `nda::range::all_t`, nda::ellipsis or a lazy - * argument. - * @return Result of the subscript operation depending on the given argument and type of the view/array. + * @tparam Ts Types of the function arguments. + * @param idxs Multi-dimensional index consisting of `long`, `nda::range`, `nda::range::all_t`, nda::ellipsis or lazy + * arguments. + * @return Result of the function call depending on the given arguments and type of the view/array. */ -template -decltype(auto) operator[](T const &idx) const & noexcept(has_no_boundcheck) { - static_assert((rank == 1), "Error in array/view: Subscript operator is only available for rank 1 views/arrays in C++17/20"); - return call(*this, idx); +template +FORCEINLINE decltype(auto) operator()(Ts const &...idxs) const & noexcept(has_no_boundcheck) { + return (*this)[idxs...]; } -/// Non-const overload of `nda::basic_array_view::operator[](T const &) const &`. -template -decltype(auto) operator[](T const &x) & noexcept(has_no_boundcheck) { - static_assert((rank == 1), "Error in array/view: Subscript operator is only available for rank 1 views/arrays in C++17/20"); - return call(*this, x); +/// Non-const overload of `nda::basic_array_view::operator()(Ts const &...) const &`. +template +FORCEINLINE decltype(auto) operator()(Ts const &...idxs) & noexcept(has_no_boundcheck) { + return (*this)[idxs...]; } -/// Rvalue overload of `nda::basic_array_view::operator[](T const &) const &`. -template -decltype(auto) operator[](T const &x) && noexcept(has_no_boundcheck) { - static_assert((rank == 1), "Error in array/view: Subscript operator is only available for rank 1 views/arrays in C++17/20"); - return call(*this, x); +/// Rvalue overload of `nda::basic_array_view::operator()(Ts const &...) const &`. +template +FORCEINLINE decltype(auto) operator()(Ts const &...idxs) && noexcept(has_no_boundcheck) { + return std::move(*this)[idxs...]; } /// Rank of the nda::array_iterator for the view/array. diff --git a/c++/nda/arithmetic.hpp b/c++/nda/arithmetic.hpp index 3b76167b7..62738883f 100644 --- a/c++/nda/arithmetic.hpp +++ b/c++/nda/arithmetic.hpp @@ -53,18 +53,32 @@ namespace nda { A a; /** - * @brief Function call operator. + * @brief Subscript operator. * * @details Forwards the arguments to the nda::Array operand and negates the result. * * @tparam Args Types of the arguments. - * @param args Function call arguments. - * @return If the result of the forwarded function call is another nda::Array, a new lazy expression is returned. + * @param args Subscript arguments. + * @return If the result of the forwarded subscript is another nda::Array, a new lazy expression is returned. * Otherwise the result is negated and returned. */ template + auto operator[](Args &&...args) const { + return -a[std::forward(args)...]; + } + + /** + * @brief Function call operator. + * + * @details Forwards to the subscript operator. + * + * @tparam Args Types of the arguments. + * @param args Function call arguments. + * @return Result of the corresponding subscript operation. + */ + template auto operator()(Args &&...args) const { - return -a(std::forward(args)...); + return (*this)[std::forward(args)...]; } /** @@ -157,38 +171,38 @@ namespace nda { } /** - * @brief Function call operator. + * @brief Subscript operator. * * @details Forwards the arguments to the nda::Array operands and performs the binary operation. * * @tparam Args Types of the arguments. - * @param args Function call arguments. - * @return If the result of the forwarded function calls contains another nda::Array, a new lazy expression is + * @param args Subscript arguments. + * @return If the result of the forwarded subscript operations contains another nda::Array, a new lazy expression is * returned. Otherwise the result of the binary operation is returned. */ template - auto operator()(Args const &...args) const { + auto operator[](Args const &...args) const { // addition if constexpr (OP == '+') { if constexpr (l_is_scalar) { // lhs is a scalar if constexpr (algebra == 'M') // rhs is a matrix - return (std::equal_to{}(args...) ? l + r(args...) : r(args...)); + return (std::equal_to{}(args...) ? l + r[args...] : r[args...]); else // rhs is an array - return l + r(args...); + return l + r[args...]; } else if constexpr (r_is_scalar) { // rhs is a scalar if constexpr (algebra == 'M') // lhs is a matrix - return (std::equal_to{}(args...) ? l(args...) + r : l(args...)); + return (std::equal_to{}(args...) ? l[args...] + r : l[args...]); else // lhs is an array - return l(args...) + r; + return l[args...] + r; } else // both are arrays or matrices - return l(args...) + r(args...); + return l[args...] + r[args...]; } // subtraction @@ -197,35 +211,35 @@ namespace nda { // lhs is a scalar if constexpr (algebra == 'M') // rhs is a matrix - return (std::equal_to{}(args...) ? l - r(args...) : -r(args...)); + return (std::equal_to{}(args...) ? l - r[args...] : -r[args...]); else // rhs is an array - return l - r(args...); + return l - r[args...]; } else if constexpr (r_is_scalar) { // rhs is a scalar if constexpr (algebra == 'M') // lhs is a matrix - return (std::equal_to{}(args...) ? l(args...) - r : l(args...)); + return (std::equal_to{}(args...) ? l[args...] - r : l[args...]); else // lhs is an array - return l(args...) - r; + return l[args...] - r; } else // both are arrays or matrices - return l(args...) - r(args...); + return l[args...] - r[args...]; } // multiplication if constexpr (OP == '*') { if constexpr (l_is_scalar) // lhs is a scalar - return l * r(args...); + return l * r[args...]; else if constexpr (r_is_scalar) // rhs is a scalar - return l(args...) * r; + return l[args...] * r; else { // both are arrays (matrix product is not supported here) static_assert(algebra != 'M', "Error in nda::expr: Matrix algebra not supported"); - return l(args...) * r(args...); + return l[args...] * r[args...]; } } @@ -234,31 +248,30 @@ namespace nda { if constexpr (l_is_scalar) { // lhs is a scalar static_assert(algebra != 'M', "Error in nda::expr: Matrix algebra not supported"); - return l / r(args...); + return l / r[args...]; } else if constexpr (r_is_scalar) // rhs is a scalar - return l(args...) / r; + return l[args...] / r; else { // both are arrays (matrix division is not supported here) static_assert(algebra != 'M', "Error in nda::expr: Matrix algebra not supported"); - return l(args...) / r(args...); + return l[args...] / r[args...]; } } } /** - * @brief Subscript operator. + * @brief Function call operator. * - * @details Simply forwards the argument to the function call operator. + * @details Forwards to the subscript operator. * - * @tparam Arg Type of the argument. - * @param arg Subscript argument. - * @return Result of the corresponding function call. + * @tparam Args Types of the arguments. + * @param args Function call arguments. + * @return Result of the corresponding subscript operation. */ - template - auto operator[](Arg &&arg) const { - static_assert(get_rank == 1, "Error in nda::expr: Subscript operator only available for expressions of rank 1"); - return operator()(std::forward(arg)); + template + auto operator()(Args const &...args) const { + return (*this)[args...]; } }; diff --git a/c++/nda/array_adapter.hpp b/c++/nda/array_adapter.hpp index c46472c18..1fe3d3638 100644 --- a/c++/nda/array_adapter.hpp +++ b/c++/nda/array_adapter.hpp @@ -64,17 +64,29 @@ namespace nda { [[nodiscard]] long size() const { return stdutil::product(myshape); } /** - * @brief Function call operator simply forwards the arguments to the callable object. + * @brief Subscript operator simply forwards the arguments to the callable object. * * @tparam Ints Integer types (convertible to long). * @param i0 First argument. * @param is Rest of the arguments. */ template - auto operator()(long i0, Ints... is) const { + auto operator[](long i0, Ints... is) const { static_assert((std::is_convertible_v and ...), "Error in nda::array_adapter: Arguments must be convertible to long"); return f(i0, is...); } + + /** + * @brief Function call operator simply forwards the arguments to the subscript operator. + * + * @tparam Ints Integer types (convertible to long). + * @param i0 First argument. + * @param is Rest of the arguments. + */ + template + auto operator()(long i0, Ints... is) const { + return (*this)[i0, is...]; + } }; // Class template argument deduction guides. diff --git a/c++/nda/map.hpp b/c++/nda/map.hpp index 10ee2c3f0..c21ec87ed 100644 --- a/c++/nda/map.hpp +++ b/c++/nda/map.hpp @@ -84,26 +84,20 @@ namespace nda { std::tuple a; private: - // Implementation of the function call operator. + // Implementation of the subscript operator. template - [[gnu::always_inline]] [[nodiscard]] auto _call(std::index_sequence, Args const &...args) const { + [[gnu::always_inline]] [[nodiscard]] auto _subscript(std::index_sequence, Args const &...args) const { // if args contains a range, we need to return an expr_call on the resulting slice if constexpr ((is_range_or_ellipsis or ... or false)) { - return mapped{f}(std::get(a)(args...)...); + return mapped{f}(std::get(a)[args...]...); } else { - return f(std::get(a)(args...)...); + return f(std::get(a)[args...]...); } } - // Implementation of the subscript operator. - template - [[gnu::always_inline]] auto _call_bra(std::index_sequence, Arg const &arg) const { - return f(std::get(a)[arg]...); - } - public: /** - * @brief Function call operator. + * @brief Subscript operator. * * @details The arguments (usually multi-dimensional indices) are passed to all the nda::Array objects stored in the * tuple and the results are then passed to the callable object. @@ -111,29 +105,26 @@ namespace nda { * If the arguments contain a range, a new lazy function call expression is returned. * * @tparam Args Argument types. - * @param args Function call arguments. - * @return The result of the function call (depends on the callable and the arguments). + * @param args Subscript arguments. + * @return The result of the subscript operation (depends on the callable and the arguments). */ template - auto operator()(Args const &...args) const { - return _call(std::make_index_sequence{}, args...); + auto operator[](Args const &...args) const { + return _subscript(std::make_index_sequence{}, args...); } /** - * @brief Subscript operator. - * - * @details The argument (usually a 1-dimensional index) is passed to all the nda::Array objects stored in the tuple - * and the results are then passed to the callable object. + * @brief Function call operator. * - * If the argument is a range, a new lazy function call expression is returned. + * @details Equivalent to the subscript operator. Provided for backwards compatibility. * - * @tparam Arg Argument types. - * @param arg Subscript argument. - * @return The result of the subscript operation (depends on the callable and the arguments). + * @tparam Args Argument types. + * @param args Function call arguments. + * @return The result of the function call (depends on the callable and the arguments). */ - template - auto operator[](Arg const &arg) const { - return _call_bra(std::make_index_sequence{}, arg); + template + auto operator()(Args const &...args) const { + return (*this)[args...]; } // FIXME copy needed for the && case only. Overload ? diff --git a/test/c++/nda_basic_array_and_view.cpp b/test/c++/nda_basic_array_and_view.cpp index a1e0461b5..89a31a924 100644 --- a/test/c++/nda_basic_array_and_view.cpp +++ b/test/c++/nda_basic_array_and_view.cpp @@ -843,6 +843,122 @@ TEST_F(NDAArrayAndView, AccessViaSubscriptOperator) { EXPECT_EQ(B, (nda::array{42, 1, 42, 3, 42})); } +TEST_F(NDAArrayAndView, MultiDimensionalSubscriptOperator) { + using namespace nda::clef::literals; + + // multi-dimensional single element access via operator[] + for (long i = 0; i < shape_3d[0]; ++i) { + for (long j = 0; j < shape_3d[1]; ++j) { + for (long k = 0; k < shape_3d[2]; ++k) { + EXPECT_EQ((A_3d[i, j, k]), A_3d(i, j, k)); + EXPECT_EQ((A_3d_v[i, j, k]), A_3d_v(i, j, k)); + EXPECT_EQ((A_3d_cv[i, j, k]), A_3d_cv(i, j, k)); + } + } + } + + // verify return types + static_assert(std::is_reference_v); + static_assert(std::is_reference_v); + static_assert(!std::is_reference_v); + + // modify via operator[] + auto C = A_3d; + C[0, 1, 2] = 42; + EXPECT_EQ((C[0, 1, 2]), 42); + EXPECT_NE(C, A_3d); + + // multi-dimensional lazy access + auto A_3d_lazy = A_3d[i_, j_, k_]; + static_assert(nda::clef::is_lazy); + EXPECT_EQ(nda::clef::eval(A_3d_lazy, i_ = 0, j_ = 1, k_ = 2), A_3d(0, 1, 2)); + + // CLEF auto-assign with operator[] (the << operator) + nda::array E(3, 4); + E[i_, j_] << i_ + 0.1 * j_; + for (int ii = 0; ii < 3; ++ii) { + for (int jj = 0; jj < 4; ++jj) { EXPECT_DOUBLE_EQ((E[ii, jj]), ii + 0.1 * jj); } + } + + // CLEF auto-assign with operator[] for 3D array + nda::array F(2, 3, 4); + F[i_, j_, k_] << 100 * i_ + 10 * j_ + k_; + for (int ii = 0; ii < 2; ++ii) { + for (int jj = 0; jj < 3; ++jj) { + for (int kk = 0; kk < 4; ++kk) { EXPECT_EQ((F[ii, jj, kk]), 100 * ii + 10 * jj + kk); } + } + } + + // 2D slice via operator[] (fix first index) + auto D = A_3d; + auto D_slice = D[0, nda::range::all, nda::range::all]; + EXPECT_EQ(D_slice.shape(), (std::array{3, 4})); + for (long j = 0; j < shape_3d[1]; ++j) { + for (long k = 0; k < shape_3d[2]; ++k) { EXPECT_EQ((D_slice[j, k]), A_3d(0, j, k)); } + } + + // 1D slice via operator[] (fix first two indices) + auto D_1d = D[1, 2, nda::range::all]; + EXPECT_EQ(D_1d.size(), shape_3d[2]); + for (long k = 0; k < shape_3d[2]; ++k) { EXPECT_EQ(D_1d[k], A_3d(1, 2, k)); } + + // slicing with ellipsis via operator[] + auto D_ellipsis = D[0, nda::ellipsis{}]; + EXPECT_EQ(D_ellipsis.shape(), (std::array{3, 4})); + for (long j = 0; j < shape_3d[1]; ++j) { + for (long k = 0; k < shape_3d[2]; ++k) { EXPECT_EQ((D_ellipsis[j, k]), A_3d(0, j, k)); } + } + + // assign to slice via operator[] + D[0, nda::range::all, 0] = 99; + for (long j = 0; j < shape_3d[1]; ++j) { EXPECT_EQ((D[0, j, 0]), 99); } + + // full view via operator[] with no arguments + auto D_full = D[]; + EXPECT_EQ(D_full.shape(), D.shape()); + EXPECT_EQ(D_full, D); + + // matrix tests + nda::matrix M(3, 4); + for (int i = 0; i < 3; ++i) + for (int j = 0; j < 4; ++j) M(i, j) = i * 4 + j; + + EXPECT_EQ((M[1, 2]), M(1, 2)); + + // matrix row slice + auto M_row = M[1, nda::range::all]; + auto M_row2 = M(1, nda::range::all); + EXPECT_EQ(M_row, M_row2); + + // matrix column slice + auto M_col = M[nda::range::all, 2]; + auto M_col2 = M(nda::range::all, 2); + EXPECT_EQ(M_col, M_col2); + + // expression templates with operator[] + nda::array G(3, 4), H(3, 4); + G[i_, j_] << i_ + j_; + H[i_, j_] << 2.0 * i_ - j_; + + // binary expression with operator[] + auto expr_add = G + H; + for (int ii = 0; ii < 3; ++ii) { + for (int jj = 0; jj < 4; ++jj) { EXPECT_DOUBLE_EQ((expr_add[ii, jj]), (G[ii, jj] + H[ii, jj])); } + } + + // unary expression with operator[] + auto expr_neg = -G; + for (int ii = 0; ii < 3; ++ii) { + for (int jj = 0; jj < 4; ++jj) { EXPECT_DOUBLE_EQ((expr_neg[ii, jj]), -(G[ii, jj])); } + } + + // scalar-array expression with operator[] + auto expr_scale = 3.0 * G; + for (int ii = 0; ii < 3; ++ii) { + for (int jj = 0; jj < 4; ++jj) { EXPECT_DOUBLE_EQ((expr_scale[ii, jj]), 3.0 * (G[ii, jj])); } + } +} + TEST_F(NDAArrayAndView, Indices) { auto indices = A_3d_v.indices(); auto it = indices.begin();