diff --git a/include/tensorwrapper/sparsity/pattern.hpp b/include/tensorwrapper/sparsity/pattern.hpp index 138c4777..5ce28291 100644 --- a/include/tensorwrapper/sparsity/pattern.hpp +++ b/include/tensorwrapper/sparsity/pattern.hpp @@ -15,17 +15,47 @@ */ #pragma once +#include +#include namespace tensorwrapper::sparsity { /** @brief Base class for objects describing the sparsity of a tensor. */ -class Pattern { +class Pattern : public tensorwrapper::detail_::DSLBase, + public tensorwrapper::detail_::PolymorphicBase { +private: + /// Type defining the polymorphic API of *this + using polymorphic_base = tensorwrapper::detail_::PolymorphicBase; + public: + /// Type used for indexing and offsets + using size_type = std::size_t; + + /** @brief Creates a pattern for a rank @p rank tensor. + * + * This constructor creates a sparsity pattern for a dense tensor with + * @p rank modes. + * + * @param[in] rank The number of modes in the associated tensor. + * + * @throw None No throw guarantee. + */ + Pattern(size_type rank = 0) noexcept : m_rank_(rank) {} + + /** @brief Provides the rank of the tensor *this assumes. + * + * @return The rank of the tensor *this describes. + * + * @throw None No throw guarantee. + */ + size_type rank() const noexcept { return m_rank_; } + /** @brief Determines if *this and @p rhs describe the same sparsity * pattern. * - * At present the sparsity component of TensorWrapper is a stub so this - * method always returns true. + * At present the sparsity component only tracks the rank of the tensor so + * two Patterns are value equal if they describe tensors with the same + * rank. * * @param[in] rhs The object to compare against. * @@ -33,7 +63,9 @@ class Pattern { * * @throw None No throw guarantee. */ - bool operator==(const Pattern& rhs) const noexcept { return true; } + bool operator==(const Pattern& rhs) const noexcept { + return rank() == rhs.rank(); + } /** @brief Is *this different from @p rhs? * @@ -49,6 +81,40 @@ class Pattern { bool operator!=(const Pattern& rhs) const noexcept { return !((*this) == rhs); } + +protected: + /// Implements clone by calling copy constructor + typename polymorphic_base::base_pointer clone_() const override { + return std::make_unique(*this); + } + + /// Implements are_equal by calling implementation provided by the base + bool are_equal_(const_base_reference rhs) const noexcept override { + return are_equal_impl_(rhs); + } + + /// Implements addition_assignment via permute_assignment + dsl_reference addition_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Implements subtraction_assignment via permute_assignment + dsl_reference subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Implements multiplication_assignment via permute_assignment + dsl_reference multiplication_assignment_( + label_type this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Implements permute_assignment by permuting the extents in @p rhs. + dsl_reference permute_assignment_(label_type this_labels, + const_labeled_reference rhs) override; + +private: + /// The rank of the tensor associated with *this + size_type m_rank_; }; } // namespace tensorwrapper::sparsity diff --git a/include/tensorwrapper/symmetry/group.hpp b/include/tensorwrapper/symmetry/group.hpp index 6d66ff2d..591b4eae 100644 --- a/include/tensorwrapper/symmetry/group.hpp +++ b/include/tensorwrapper/symmetry/group.hpp @@ -16,6 +16,9 @@ #pragma once #include +#include +#include +#include #include #include @@ -34,13 +37,18 @@ namespace tensorwrapper::symmetry { * mathematically know that the permutation (0, 2, 1) is also a symmetry * operation because it is the inverse of (0, 1, 2). */ -class Group : public utilities::IndexableContainerBase { +class Group : public utilities::IndexableContainerBase, + public tensorwrapper::detail_::DSLBase, + public tensorwrapper::detail_::PolymorphicBase { private: /// Type of *this using my_type = Group; - /// Type *this derives from - using base_type = utilities::IndexableContainerBase; + /// Type *this derives from to become container-like + using container_type = utilities::IndexableContainerBase; + + /// Type *this derives from to behave like other polymorphic objects + using polymorphic_type = tensorwrapper::detail_::PolymorphicBase; public: /// The base type of each object in *this @@ -55,18 +63,30 @@ class Group : public utilities::IndexableContainerBase { /// Unsigned integral type used for indexing and offsets using size_type = std::size_t; + /// Type used for mode indices + using mode_index_type = typename value_type::mode_index_type; + // ------------------------------------------------------------------------- // -- Ctors and assignment // ------------------------------------------------------------------------- - /** @brief Initializes *this to an empty group. + /** @brief Initializes *this to be the identity group of a scalar. + * + * @throw None No throw guarantee + */ + Group() noexcept = default; + + /** @brief Initializes *this as the identity group of a rank @p rank tensor. * - * This ctor creates a group describing the symmetries of an empty set - * set. Such a set is also the symmetry group of a scalar. + * This ctor creates a group representing the identity group of a rank + * @p rank tensor. + * + * @param[in] rank The rank of the tensor the identity group describes. * * @throw None No throw guarantee. */ - Group() noexcept = default; + explicit Group(mode_index_type rank) noexcept : + m_relations_{}, m_rank_(rank) {} /** @brief Creates a Group from the provided symmetry operations. * @@ -81,12 +101,17 @@ class Group : public utilities::IndexableContainerBase { * @param[in] op A symmetry operation to initialize the group with. * @param[in] ops The remaining symmetry operations. * + * @throw std::runtime_error if @p op and @p ops do not all have the same + * rank. Strong throw guarantee. * @throw std::bad_alloc if there is a problem allocating the initial * state. Strong throw guarantee. */ template explicit Group(const_reference op, Args&&... ops) : Group(std::forward(ops)...) { + if(m_rank_ && rank() != op.rank()) + throw std::runtime_error("Ranks of operations are not consistent"); + if(!m_rank_) m_rank_.emplace(op.rank()); if(!count(op) && !op.is_identity()) m_relations_.emplace_front(op.clone()); } @@ -104,6 +129,7 @@ class Group : public utilities::IndexableContainerBase { Group(const Group& other) { for(const auto& x : other.m_relations_) m_relations_.push_back(x->clone()); + m_rank_ = other.m_rank_; } /** @brief Transfers the state of @p other in to *this. @@ -158,6 +184,17 @@ class Group : public utilities::IndexableContainerBase { */ bool count(const_reference op) const noexcept; + /** @brief The rank of the tensor these symmetries describe. + * + * This is not the rank of the group, but rather the rank of the tensor + * that the symmetries of the group describe. + * + * @return The rank of the tensor described by the symmetries in *this. + * + * @throw None No throw guarantee. + */ + mode_index_type rank() const noexcept { return m_rank_.value_or(0); } + // ------------------------------------------------------------------------- // -- Utility methods // ------------------------------------------------------------------------- @@ -170,12 +207,17 @@ class Group : public utilities::IndexableContainerBase { * * @throw None No throw guarantee. */ - void swap(Group& other) noexcept { m_relations_.swap(other.m_relations_); } + void swap(Group& other) noexcept { + m_relations_.swap(other.m_relations_); + m_rank_.swap(other.m_rank_); + } /** @brief Determines if *this is value equal to @p rhs. * * Two Group objects are value equal if they contain the same number of - * operations and if each operation found in *this is also found in @p rhs. + * operations, if each operation found in *this is also found in @p rhs, + * and if the rank of the associated tensor is the same for *this as + * @p rhs. * * @param[in] rhs The Group object we are comparing against. * @@ -198,9 +240,37 @@ class Group : public utilities::IndexableContainerBase { */ bool operator!=(const Group& rhs) const noexcept { return !(*this == rhs); } +protected: + typename polymorphic_type::base_pointer clone_() const override { + return std::make_unique(*this); + } + + bool are_equal_(const_base_reference rhs) const noexcept override { + return (*this) == rhs; + } + + /// Implements addition_assignment via permute_assignment + dsl_reference addition_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Implements subtraction_assignment via permute_assignment + dsl_reference subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Implements multiplication_assignment via permute_assignment + dsl_reference multiplication_assignment_( + label_type this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Implements permute_assignment by permuting the extents in @p rhs. + dsl_reference permute_assignment_(label_type this_labels, + const_labeled_reference rhs) override; + private: /// Allow base class to access implementations - friend base_type; + friend container_type; /// Base type common to all symmetry operations defining the API using value_pointer = value_type::base_pointer; @@ -219,6 +289,9 @@ class Group : public utilities::IndexableContainerBase { /// The symmetry operations of *this relation_container_type m_relations_; + + /// The rank of the tensor these symmetries apply to + std::optional m_rank_; }; // -- Out of line implementations @@ -231,6 +304,7 @@ inline bool Group::count(const_reference op) const noexcept { } inline bool Group::operator==(const Group& rhs) const noexcept { + if(rank() != rhs.rank()) return false; if(size() != rhs.size()) return false; for(const auto& x : *this) if(!rhs.count(x)) return false; diff --git a/include/tensorwrapper/symmetry/operation.hpp b/include/tensorwrapper/symmetry/operation.hpp index a1adc06c..30548939 100644 --- a/include/tensorwrapper/symmetry/operation.hpp +++ b/include/tensorwrapper/symmetry/operation.hpp @@ -56,6 +56,8 @@ class Operation : public detail_::PolymorphicBase { bool is_identity() const noexcept { return is_identity_(); } + mode_index_type rank() const noexcept { return rank_(); } + // ------------------------------------------------------------------------- // -- Utility methods // ------------------------------------------------------------------------- @@ -63,6 +65,9 @@ class Operation : public detail_::PolymorphicBase { protected: /// Derived class should overwrite to implement is_identity virtual bool is_identity_() const noexcept = 0; + + /// Derived class should overwrite to implement rank() + virtual mode_index_type rank_() const noexcept = 0; }; } // namespace tensorwrapper::symmetry diff --git a/include/tensorwrapper/symmetry/permutation.hpp b/include/tensorwrapper/symmetry/permutation.hpp index 2d4b40ae..e749402b 100644 --- a/include/tensorwrapper/symmetry/permutation.hpp +++ b/include/tensorwrapper/symmetry/permutation.hpp @@ -27,7 +27,7 @@ namespace tensorwrapper::symmetry { * cycles (those which actually swap modes of the tensor) are stored * explicitly all other cycles are stored implicitly. * - * @note Stored cycles will be canonicalized and sorted. "Canoicalized" means + * @note Stored cycles will be canonicalized and sorted. "Canonicalized" means * that each cycle will be cyclically permuted until the smallest * element is first), e.g., the cycle 231 will be stored as 123. Sorting * will be done lexicographically, e.g., the cycle 012 will come before @@ -53,44 +53,52 @@ class Permutation : public Operation { /** @brief Creates an identity permutation. * - * The default Permutation contains no explicit cycles making it equivalent - * to storing only implicit fixed points for an arbitrary rank tensor. - * Such a Permutation is equivalent to the identity permutation (i.e., - * do nothing). + * The identity permutation for a rank `r` tensor contains `r` fixed + * points. When used as a default ctor this will create the identity + * permutation for a scalar (rank 0 tensor). + * + * @param[in] rank The rank of the tensor this permutation represents. + * Default is 0. * * @throw No throw guarantee. */ - Permutation() = default; + explicit Permutation(mode_index_type rank = 0) : m_rank_(rank) {} - /** @brief Creates a Permutation containing a single cycle. + /** @brief Creates a Permutation from "one-line" notation. * - * Many permutations involve a single cycle. For convenience this ctor has - * been defined so that the user can construct the resulting Permutation - * object with by only providing a single initialization list, i.e., no - * need to do something like `Permutation p123{{1, 2, 3}};`. Ultimately, - * this ctor dispatches to `Permutation(cycle_set_initializer_list)`. + * One-line notation for a permutation of a rank `r` tensor is an ordered + * set of the numbers [0, r) such that i-th number in the set is the + * new mode offset of what was the `i`-th mode before the permutation, + * e.g., the permutation (1, 0, 3, 2) means that after the permutation + * mode 0 is now mode 1, mode 1 is now mode 0, mode 2 is now mode 3, and + * mode 3 is now mode 2. In other words, one-line notation shows the new + * mode order written in terms of the old mode offsets. * * @note If @p il is a trivial cycle it will NOT be explicitly stored. * - * @param[in] il The modes involved in the cycle. Modes need not be in - * canonical order. - * - * @throw std::runtime_error if a mode appears more than once in il. Strong - * throw guarantee. + * @throw std::runtime_error if @p il is not a valid one-line + * representation. Strong throw guarantee. * * @throw std::bad_alloc if there is a problem allocating the internal * state. Strong throw guarantee. */ explicit Permutation(cycle_initializer_list il) : - Permutation(cycle_container_type{cycle_type(il.begin(), il.end())}) {} + Permutation(il.size(), + parse_one_line_(cycle_type(il.begin(), il.end()))) {} /** @brief Creates a Permutation by explicitly specifying the cycles. + * + * @tparam Args the qualified types of @p args. Each type in @p Args is + * assumed to be implicitly convertible to cycle_type. * * Any arbitrary permutation can be specified by providing the cycles which - * comprise it. This ctor takes a list of cycles and creates the resulting - * Permutation. + * comprise it. This ctor takes the rank of the tensor, and a list of one + * or more cycles (zero cycles is handled by the identity constructor). + * Any mode not appearing in @p cycle0 or @p args is assumed to be a + * fixed point. * - * @param[in] cycles The cycles comprising the permutation. + * @param[in] cycle0 The first cycle in the permutation. + * @param[in] args The remaining sizeof...(args) cycles in the permutation. * * @throw std::runtime_error if a mode appears more than once in a cycle, * or if more than one cycle contains the same @@ -99,28 +107,15 @@ class Permutation : public Operation { * state. Strong throw guarantee. */ template - explicit Permutation(cycle_type cycle0, Args&&... args) : - Permutation(cycle_container_type{ - std::move(cycle0), cycle_type(std::forward(args))...}) {} + Permutation(mode_index_type rank, cycle_type cycle0, Args&&... args) : + Permutation( + rank, cycle_container_type{std::move(cycle0), + cycle_type(std::forward(args))...}) {} // ------------------------------------------------------------------------- // -- Getters // ------------------------------------------------------------------------- - /** @brief Determines the minimum rank a tensor must be to apply *this. - * - * Cycles stored in *this are expressed in terms of mode offsets. If for - * example a cycle swaps modes 3 and 4 we know that we can only apply - * such a permutation to a tensor with a minimum rank of 5 (otherwise it - * would not have a mode with offset 4). This method analyzes the cycles - * stored in *this and finds the largest mode offset. - * - * @return The maximum mode offset involved in any non-trivial cycle. - * - * @throw None No throw guarantee. - */ - mode_index_type minimum_rank() const noexcept; - /** @brief Obtains the @p i -th non-trivial cycle in *this. * * @param[in] i The offset of the requested cycles. Must be in the range @@ -178,7 +173,10 @@ class Permutation : public Operation { * * @throw None No throw guarantee. */ - void swap(Permutation& other) noexcept { m_cycles_.swap(other.m_cycles_); } + void swap(Permutation& other) noexcept { + m_cycles_.swap(other.m_cycles_); + std::swap(m_rank_, other.m_rank_); + } /** @brief Is *this value equal to @p rhs? * @@ -196,7 +194,8 @@ class Permutation : public Operation { * @throw None No throw guarantee. */ bool operator==(const Permutation& rhs) const noexcept { - return m_cycles_ == rhs.m_cycles_; + return std::tie(m_rank_, m_cycles_) == + std::tie(rhs.m_rank_, rhs.m_cycles_); } /** @brief Is *this different than @p rhs? @@ -229,10 +228,15 @@ class Permutation : public Operation { return are_equal_impl_(other); } + /// Implements rank by returning the stored rank + mode_index_type rank_() const noexcept override { return m_rank_; } + private: /// Type of container holding a set of cycles using cycle_container_type = std::set; + cycle_container_type parse_one_line_(const cycle_type& one_line) const; + void valid_offset_(mode_index_type i) const; /// Verifies that @p cycle does not contain repeat elements @@ -249,11 +253,20 @@ class Permutation : public Operation { cycle_container_type input); /// Primary ctor for the class. All others dispatch here - explicit Permutation(cycle_container_type cycles) : - m_cycles_(remove_trivial_cycles_(std::move(cycles))) {} + Permutation(mode_index_type rank, cycle_container_type cycles) : + m_cycles_(remove_trivial_cycles_(std::move(cycles))), m_rank_(rank) { + for(const auto& x : m_cycles_) + for(auto xi : x) + if(xi >= m_rank_) + throw std::runtime_error( + "Offset is inconsistent with rank"); + } /// The modes which can be freely permuted among each other cycle_container_type m_cycles_; + + /// The overall rank of the tensor + mode_index_type m_rank_; }; } // namespace tensorwrapper::symmetry diff --git a/src/tensorwrapper/sparsity/pattern.cpp b/src/tensorwrapper/sparsity/pattern.cpp new file mode 100644 index 00000000..c4bb9b4e --- /dev/null +++ b/src/tensorwrapper/sparsity/pattern.cpp @@ -0,0 +1,46 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace tensorwrapper::sparsity { + +using dsl_reference = typename Pattern::dsl_reference; + +dsl_reference Pattern::addition_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + return permute_assignment_(this_labels, lhs); +} + +dsl_reference Pattern::subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + return permute_assignment_(this_labels, lhs); +} + +dsl_reference Pattern::multiplication_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + return permute_assignment_(this_labels, lhs); +} + +dsl_reference Pattern::permute_assignment_(label_type this_labels, + const_labeled_reference rhs) { + return *this = rhs.object(); +} + +} // namespace tensorwrapper::sparsity \ No newline at end of file diff --git a/src/tensorwrapper/symmetry/group.cpp b/src/tensorwrapper/symmetry/group.cpp new file mode 100644 index 00000000..a7f2f0ed --- /dev/null +++ b/src/tensorwrapper/symmetry/group.cpp @@ -0,0 +1,61 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace tensorwrapper::symmetry { +namespace { + +template +void assert_non_trivial(const LHSType& lhs, const RHSType& rhs) { + if(lhs.object().size() != 0 || rhs.object().size() != 0) + throw std::runtime_error("Support for non-trivial symmetry NYI!"); +} + +} // namespace + +using dsl_reference = typename Group::dsl_reference; + +dsl_reference Group::addition_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + assert_non_trivial(lhs, rhs); + return permute_assignment_(this_labels, lhs); +} + +dsl_reference Group::subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + assert_non_trivial(lhs, rhs); + return permute_assignment_(this_labels, lhs); +} + +dsl_reference Group::multiplication_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + assert_non_trivial(lhs, rhs); + return permute_assignment_(this_labels, lhs); +} + +dsl_reference Group::permute_assignment_(label_type this_labels, + const_labeled_reference rhs) { + if(rhs.object().size() != 0) + throw std::runtime_error("Support for non-trivial symmetry NYI!"); + + return *this = rhs.object(); +} + +} // namespace tensorwrapper::symmetry \ No newline at end of file diff --git a/src/tensorwrapper/symmetry/permutation.cpp b/src/tensorwrapper/symmetry/permutation.cpp index 64f81441..22c2d182 100644 --- a/src/tensorwrapper/symmetry/permutation.cpp +++ b/src/tensorwrapper/symmetry/permutation.cpp @@ -20,24 +20,34 @@ #include namespace tensorwrapper::symmetry { +namespace { + +void assert_valid_one_line(const Permutation::cycle_type& one_line) { + auto n = one_line.size(); + std::set found; + + // Make sure each index from 0 to n-1 appears once (and only once) + for(std::size_t i = 0; i < n; ++i) { + auto xi = one_line[i]; + if(found.count(xi)) + throw std::runtime_error("Mode offset: " + std::to_string(xi) + + " appears multiple times"); + if(xi >= n) + throw std::runtime_error("Mode offset: " + std::to_string(xi) + + " is not in the range [0, " + + std::to_string(n) + + " ). Did you forget " + "one or more offsets?"); + found.insert(xi); + } +} + +} // namespace // ----------------------------------------------------------------------------- // -- Getters // ----------------------------------------------------------------------------- -Permutation::mode_index_type Permutation::minimum_rank() const noexcept { - if(m_cycles_.empty()) return mode_index_type(0); - - mode_index_type the_max(1); - for(const auto& cycle : m_cycles_) { - mode_index_type cycle_max = - *std::max_element(cycle.begin(), cycle.end()); - mode_index_type cycle_max_plus_one(cycle_max + mode_index_type(1)); - the_max = std::max(cycle_max_plus_one, the_max); - } - return the_max; -} - Permutation::cycle_type Permutation::operator[]( mode_index_type i) const noexcept { auto itr = m_cycles_.begin(); @@ -45,6 +55,33 @@ Permutation::cycle_type Permutation::operator[]( return *itr; } +Permutation::cycle_container_type Permutation::parse_one_line_( + const cycle_type& one_line) const { + assert_valid_one_line(one_line); + auto n = one_line.size(); + cycle_container_type rv; + std::set found; + + for(std::size_t i = 0; i < n; ++i) { + auto xi = one_line[i]; + if(found.count(xi)) continue; + found.insert(xi); + decltype(n) counter = 0; + auto xj = one_line[i]; + cycle_type cycle{xj}; + while(counter < n) { + if(xi == one_line[xj]) { + rv.insert(cycle); + break; + } + xj = one_line[xj]; + cycle.push_back(xj); + found.insert(xj); + } + } + return rv; +} + void Permutation::valid_offset_(mode_index_type i) const { if(i < size()) return; auto i_str = std::to_string(i); diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp index 68c5d9d4..8b72b415 100644 --- a/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp @@ -19,13 +19,16 @@ using namespace tensorwrapper; -using test_types = std::tuple; +using test_types = + std::tuple; TEMPLATE_LIST_TEST_CASE("DSL", "", test_types) { using object_type = TestType; - test_types scalar_values{test_tensorwrapper::smooth_scalar()}; - test_types matrix_values{test_tensorwrapper::smooth_matrix()}; + test_types scalar_values{test_tensorwrapper::smooth_scalar(), + symmetry::Group(0), sparsity::Pattern(0)}; + test_types matrix_values{test_tensorwrapper::smooth_matrix(), + symmetry::Group(2), sparsity::Pattern(2)}; auto value0 = std::get(scalar_values); auto value2 = std::get(matrix_values); diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp index 1c82fb60..59d34a13 100644 --- a/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp @@ -17,13 +17,16 @@ using namespace tensorwrapper; -using test_types = std::tuple; +using test_types = + std::tuple; TEMPLATE_LIST_TEST_CASE("PairwiseParser", "", test_types) { using object_type = TestType; - test_types scalar_values{test_tensorwrapper::smooth_scalar()}; - test_types matrix_values{test_tensorwrapper::smooth_matrix()}; + test_types scalar_values{test_tensorwrapper::smooth_scalar(), + symmetry::Group(0), sparsity::Pattern(0)}; + test_types matrix_values{test_tensorwrapper::smooth_matrix(), + symmetry::Group(2), sparsity::Pattern(2)}; auto value0 = std::get(scalar_values); auto value2 = std::get(matrix_values); diff --git a/tests/cxx/unit_tests/tensorwrapper/sparsity/pattern.cpp b/tests/cxx/unit_tests/tensorwrapper/sparsity/pattern.cpp index 57373535..bbf8fa95 100644 --- a/tests/cxx/unit_tests/tensorwrapper/sparsity/pattern.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/sparsity/pattern.cpp @@ -21,17 +21,87 @@ using namespace tensorwrapper::sparsity; TEST_CASE("Pattern") { Pattern defaulted; + Pattern p1(1); SECTION("Ctors, assignment") { - SECTION("Default") {} + SECTION("Default") { REQUIRE(defaulted.rank() == 0); } - test_copy_move_ctor_and_assignment(defaulted); + SECTION("rank ctor") { + REQUIRE(Pattern(0).rank() == 0); + REQUIRE(p1.rank() == 1); + REQUIRE(Pattern(2).rank() == 2); + } + + test_copy_move_ctor_and_assignment(defaulted, p1); + } + + SECTION("rank") { + REQUIRE(defaulted.rank() == 0); + REQUIRE(p1.rank() == 1); } - SECTION("operator==") { REQUIRE(defaulted == Pattern{}); } + SECTION("operator==") { + // Defaulted is same as another defaulted + REQUIRE(defaulted == Pattern{}); + + // Defaulted is same as scalar + REQUIRE(defaulted == Pattern(0)); + + // Defaulted is not same as vector + REQUIRE_FALSE(defaulted == p1); + + // Vector equals vector + REQUIRE(p1 == Pattern(1)); + + // Vector not same as matrix + REQUIRE_FALSE(p1 == Pattern(2)); + } SECTION("operator!=") { // Just spot check because it is implemented in terms of operator== REQUIRE_FALSE(defaulted != Pattern{}); + REQUIRE(defaulted != p1); + } + + SECTION("clone") { + auto pdefaulted = defaulted.clone(); + REQUIRE(pdefaulted->are_equal(defaulted)); + + auto pp1 = p1.clone(); + REQUIRE(pp1->are_equal(p1)); + } + + SECTION("are_equal") { + // Just calls operator== so spot check. + REQUIRE(defaulted.are_equal(Pattern{})); + REQUIRE_FALSE(defaulted.are_equal(p1)); + } + + SECTION("addition_assignment") { + Pattern rv; + auto prv = &(rv.addition_assignment("i", p1("i"), p1("i"))); + REQUIRE(prv == &rv); + REQUIRE(rv == p1); + } + + SECTION("subtraction_assignment") { + Pattern rv; + auto prv = &(rv.subtraction_assignment("i", p1("i"), p1("i"))); + REQUIRE(prv == &rv); + REQUIRE(rv == p1); + } + + SECTION("multiplication_assignment") { + Pattern rv; + auto prv = &(rv.multiplication_assignment("i", p1("i"), p1("i"))); + REQUIRE(prv == &rv); + REQUIRE(rv == p1); + } + + SECTION("permute_assignment") { + Pattern rv; + auto prv = &(rv.permute_assignment("i", p1("i"))); + REQUIRE(prv == &rv); + REQUIRE(rv == p1); } } diff --git a/tests/cxx/unit_tests/tensorwrapper/symmetry/group.cpp b/tests/cxx/unit_tests/tensorwrapper/symmetry/group.cpp index 2f07f6b0..14792008 100644 --- a/tests/cxx/unit_tests/tensorwrapper/symmetry/group.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/symmetry/group.cpp @@ -22,29 +22,52 @@ using namespace tensorwrapper::testing; using namespace tensorwrapper::symmetry; TEST_CASE("Group") { - Permutation p01{0, 1}; - Permutation p123{1, 2, 3}; + using cycle_type = typename Permutation::cycle_type; + using label_type = typename Group::label_type; + + Permutation p01(4, cycle_type{0, 1}); + Permutation p23(4, cycle_type{2, 3}); Group empty; - Group g(p01, p123); + Group g(p01, p23); SECTION("Ctors and assignment") { - SECTION("Default") { REQUIRE(empty.size() == 0); } + SECTION("Default") { + REQUIRE(empty.size() == 0); + REQUIRE(empty.rank() == 0); + } + + SECTION("Identity") { + Group i0(0); + REQUIRE(i0.size() == 0); + REQUIRE(i0.rank() == 0); + + Group i1(1); + REQUIRE(i1.size() == 0); + REQUIRE(i1.rank() == 1); + } SECTION("Value") { + REQUIRE(g.rank() == 4); REQUIRE(g.size() == 2); REQUIRE(g.at(0).are_equal(p01)); - REQUIRE(g.at(1).are_equal(p123)); + REQUIRE(g.at(1).are_equal(p23)); // Removes duplicates Group g2(p01, p01); + REQUIRE(g2.rank() == 4); REQUIRE(g2.size() == 1); REQUIRE(g2.at(0).are_equal(p01)); // Doesn't store identity operations - Group g3(p01, Permutation{}, Permutation{2}); + Group g3(p01, Permutation{0, 1, 2, 3}); + REQUIRE(g3.rank() == 4); REQUIRE(g3.size() == 1); REQUIRE(g3.at(0).are_equal(p01)); + + // Throws if operations have different ranks + using error_t = std::runtime_error; + REQUIRE_THROWS_AS(Group(p01, Permutation(2)), error_t); } test_copy_move_ctor_and_assignment(empty, g); } @@ -52,7 +75,12 @@ TEST_CASE("Group") { SECTION("count") { REQUIRE_FALSE(empty.count(p01)); REQUIRE(g.count(p01)); - REQUIRE(g.count(p123)); + REQUIRE(g.count(p23)); + } + + SECTION("rank") { + REQUIRE(empty.rank() == 0); + REQUIRE(g.rank() == 4); } SECTION("swap") { @@ -65,14 +93,37 @@ TEST_CASE("Group") { } SECTION("operator==") { + // Default constructed equals default constructed REQUIRE(empty == Group{}); + + // Default equal value constructio of scalar identity group + REQUIRE(empty == Group(0)); + REQUIRE(empty == Group{Permutation(0)}); + + // Default does not equal a general value construction REQUIRE_FALSE(empty == g); - REQUIRE(g == Group{p01, p123}); - REQUIRE(g == Group{p123, p01}); // Order doesn't matter + // Identity constructed with same rank + Group g1(1); + REQUIRE(g1 == Group(1)); + + // Identity with different ranks + REQUIRE_FALSE(g1 == Group(2)); + + // Identity with non-identity + REQUIRE_FALSE(Group(4) == g); + + // Value constructed equal value constructed with same value + REQUIRE(g == Group{p01, p23}); + REQUIRE(g == Group{p23, p01}); // Order doesn't matter + // Value constructed with different numbers of elements REQUIRE_FALSE(g == Group{p01}); - REQUIRE_FALSE(g == Group{p01, p123, Permutation{4, 5}}); + + // Value constructed with different elements + Permutation p0213{0, 2, 1, 3}; + Permutation p3120{3, 1, 2, 0}; + REQUIRE_FALSE(g == Group{p0213, p3120}); } SECTION("operator!=") { @@ -83,16 +134,90 @@ TEST_CASE("Group") { SECTION("at_()") { REQUIRE(g.at(0).are_equal(p01)); - REQUIRE(g.at(1).are_equal(p123)); + REQUIRE(g.at(1).are_equal(p23)); } SECTION("at_() const") { REQUIRE(std::as_const(g).at(0).are_equal(p01)); - REQUIRE(std::as_const(g).at(1).are_equal(p123)); + REQUIRE(std::as_const(g).at(1).are_equal(p23)); } SECTION("size_()") { REQUIRE(empty.size() == 0); REQUIRE(g.size() == 2); } + + SECTION("addition_assignment_") { + Group empty2; + + SECTION("Identity plus identity") { + Group g2(2); + auto g2ij = g2("i,j"); + auto pempty2 = &(empty2.addition_assignment("i,j", g2ij, g2ij)); + REQUIRE(pempty2 == &empty2); + REQUIRE(empty2 == g2); + } + + // Throws if non-trivial symmetry + using error_t = std::runtime_error; + label_type ijkl("i,j,k,l"); + auto lg = g("i,j,k,l"); + REQUIRE_THROWS_AS(empty2.addition_assignment(ijkl, lg, lg), error_t); + } + + SECTION("subtraction_assignment_") { + Group empty2; + + SECTION("Identity plus identity") { + Group g2(2); + auto g2ij = g2("i,j"); + auto pempty2 = &(empty2.subtraction_assignment("i,j", g2ij, g2ij)); + REQUIRE(pempty2 == &empty2); + REQUIRE(empty2 == g2); + } + + // Throws if non-trivial symmetry + using error_t = std::runtime_error; + label_type ijkl("i,j,k,l"); + auto lg = g("i,j,k,l"); + REQUIRE_THROWS_AS(empty2.subtraction_assignment(ijkl, lg, lg), error_t); + } + + SECTION("multiplication_assignment_") { + Group empty2; + + SECTION("Identity plus identity") { + Group g2(2); + auto g2ij = g2("i,j"); + auto pempty2 = + &(empty2.multiplication_assignment("i,j", g2ij, g2ij)); + REQUIRE(pempty2 == &empty2); + REQUIRE(empty2 == g2); + } + + // Throws if non-trivial symmetry + using error_t = std::runtime_error; + label_type ijkl("i,j,k,l"); + auto lg = g("i,j,k,l"); + REQUIRE_THROWS_AS(empty2.multiplication_assignment(ijkl, lg, lg), + error_t); + } + + SECTION("permute_assignment_") { + Group empty2; + + SECTION("Permute identity") { + Group g2(2); + auto g2ij = g2("i,j"); + auto pempty2 = &(empty2.permute_assignment("i,j", g2ij)); + REQUIRE(pempty2 == &empty2); + REQUIRE(empty2 == g2); + } + + // Throws if non-trivial symmetry + using error_t = std::runtime_error; + label_type ijkl("i,j,k,l"); + auto lg = g("i,j,k,l"); + REQUIRE_THROWS_AS(empty2.permute_assignment(ijkl, lg), error_t); + } } diff --git a/tests/cxx/unit_tests/tensorwrapper/symmetry/permutation.cpp b/tests/cxx/unit_tests/tensorwrapper/symmetry/permutation.cpp index 9a41406a..61cb8d0a 100644 --- a/tests/cxx/unit_tests/tensorwrapper/symmetry/permutation.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/symmetry/permutation.cpp @@ -24,43 +24,74 @@ using mode_index_type = Permutation::mode_index_type; using cycle_type = Permutation::cycle_type; TEST_CASE("Permutation") { - Permutation defaulted; - Permutation one_cycle{0, 1}; - Permutation two_cycles(cycle_type{2, 1, 3}, cycle_type{4, 5}); - + // Create some cycles to make permutations from + cycle_type c0{0}; + cycle_type c1{1}; cycle_type c01{0, 1}; cycle_type c132{1, 3, 2}; + cycle_type c213{2, 1, 3}; cycle_type c45{4, 5}; + Permutation defaulted; + Permutation one_cycle{1, 0}; + Permutation two_cycles(6, c213, c45); + SECTION("Ctors and assignment") { SECTION("Default") { REQUIRE(defaulted.size() == mode_index_type(0)); - REQUIRE(defaulted.minimum_rank() == mode_index_type(0)); + REQUIRE(defaulted.rank() == mode_index_type(0)); } - SECTION("Cycle") { + SECTION("Identity") { + Permutation p2(2); + REQUIRE(p2.size() == mode_index_type(0)); + REQUIRE(p2.rank() == mode_index_type(2)); + } + + SECTION("One-line") { REQUIRE(one_cycle.size() == mode_index_type(1)); - REQUIRE(one_cycle.minimum_rank() == mode_index_type(2)); + REQUIRE(one_cycle.rank() == mode_index_type(2)); REQUIRE(one_cycle.at(0) == c01); + // Identity permutation via one-line + Permutation p5{0, 1, 2, 3, 4}; + REQUIRE(p5.size() == mode_index_type(0)); + REQUIRE(p5.rank() == mode_index_type(5)); + + // Two cycles via one-line + Permutation p01_23{1, 0, 3, 2}; + REQUIRE(p01_23.size() == mode_index_type(2)); + REQUIRE(p01_23.rank() == mode_index_type(4)); + REQUIRE(p01_23.at(0) == c01); + REQUIRE(p01_23.at(1) == cycle_type{2, 3}); + + using error_t = std::runtime_error; + // Not all indices appear (or equivalently a mode index is too high) + REQUIRE_THROWS_AS(Permutation({0, 2}), error_t); + + // Index appears multiple times + REQUIRE_THROWS_AS(Permutation({0, 0}), error_t); + } + + SECTION("Cycle") { REQUIRE(two_cycles.size() == mode_index_type(2)); - REQUIRE(two_cycles.minimum_rank() == mode_index_type(6)); + REQUIRE(two_cycles.rank() == mode_index_type(6)); REQUIRE(two_cycles.at(0) == c132); // Canonicalization must work REQUIRE(two_cycles.at(1) == c45); SECTION("Removes trivial cycles") { - Permutation one_trivial_cycle{0}; + Permutation one_trivial_cycle(1, cycle_type{0}); REQUIRE(one_trivial_cycle.size() == 0); - REQUIRE(one_trivial_cycle.minimum_rank() == 0); + REQUIRE(one_trivial_cycle.rank() == 1); - Permutation two_trivial_cycles(cycle_type{0}, cycle_type{1}); + Permutation two_trivial_cycles(2, cycle_type{0}, cycle_type{1}); REQUIRE(two_trivial_cycles.size() == 0); - REQUIRE(two_trivial_cycles.minimum_rank() == 0); + REQUIRE(two_trivial_cycles.rank() == 2); - Permutation one_trivial_one_real(cycle_type{4}, + Permutation one_trivial_one_real(5, cycle_type{4}, cycle_type{0, 1}); REQUIRE(one_trivial_one_real.size() == 1); - REQUIRE(one_trivial_one_real.minimum_rank() == 2); + REQUIRE(one_trivial_one_real.rank() == 5); } using except = std::runtime_error; @@ -71,7 +102,7 @@ TEST_CASE("Permutation") { SECTION("Error if cycles overlap") { REQUIRE_THROWS_AS( - (Permutation(cycle_type{0, 1}, cycle_type{1, 2})), except); + (Permutation(3, cycle_type{0, 1}, cycle_type{1, 2})), except); } test_copy_move_ctor_and_assignment(defaulted, one_cycle, @@ -79,10 +110,10 @@ TEST_CASE("Permutation") { } } - SECTION("minimum_rank") { - REQUIRE(defaulted.minimum_rank() == 0); - REQUIRE(one_cycle.minimum_rank() == 2); - REQUIRE(two_cycles.minimum_rank() == 6); + SECTION("rank_") { + REQUIRE(defaulted.rank() == 0); + REQUIRE(one_cycle.rank() == 2); + REQUIRE(two_cycles.rank() == 6); } SECTION("operator[]") { @@ -122,27 +153,31 @@ TEST_CASE("Permutation") { // Defaulted equals another defaulted object REQUIRE(defaulted == Permutation{}); - // Defaulted equals an object with only trivial cycles - REQUIRE(defaulted == Permutation{1}); - REQUIRE(defaulted == Permutation(cycle_type{0}, cycle_type{1})); + // Defaulted does not equal an object with only trivial cycles + REQUIRE_FALSE(defaulted == Permutation(1)); + REQUIRE_FALSE(defaulted == Permutation(2, c0, c1)); // Defaulted does not equal an object with non-trivial cycles REQUIRE_FALSE(defaulted == one_cycle); + /// Identity equals same rank identity + REQUIRE(Permutation(1) == Permutation(1)); + + /// Identity doe not equal different rank identity + REQUIRE_FALSE(Permutation(1) == Permutation(2)); + // Values input in same order - REQUIRE(two_cycles == - Permutation(cycle_type{2, 1, 3}, cycle_type{4, 5})); + REQUIRE(two_cycles == Permutation(6, c213, c45)); // Values input in different order - REQUIRE(two_cycles == - Permutation(cycle_type{4, 5}, cycle_type{1, 3, 2})); + REQUIRE(two_cycles == Permutation(6, c45, c132)); // Different number of cycles REQUIRE_FALSE(one_cycle == two_cycles); // Different cycles REQUIRE_FALSE(two_cycles == - Permutation(cycle_type{1, 2}, cycle_type{3, 4, 5})); + Permutation(6, cycle_type{1, 2}, cycle_type{3, 4, 5})); } SECTION("operator!=") { @@ -178,7 +213,7 @@ TEST_CASE("Permutation") { const_base_reference one_base = one_cycle; const_base_reference two_base = two_cycles; REQUIRE_FALSE(one_base.are_equal(two_base)); - REQUIRE(Permutation{0, 1}.are_equal(one_base)); + REQUIRE(Permutation{1, 0}.are_equal(one_base)); } } }