Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions include/tensorwrapper/sparsity/pattern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,57 @@
*/

#pragma once
#include <tensorwrapper/detail_/dsl_base.hpp>
#include <tensorwrapper/detail_/polymorphic_base.hpp>

namespace tensorwrapper::sparsity {

/** @brief Base class for objects describing the sparsity of a tensor. */
class Pattern {
class Pattern : public tensorwrapper::detail_::DSLBase<Pattern>,
public tensorwrapper::detail_::PolymorphicBase<Pattern> {
private:
/// Type defining the polymorphic API of *this
using polymorphic_base = tensorwrapper::detail_::PolymorphicBase<Pattern>;

public:
/// Type used for indexing and offsets
using size_type = std::size_t;

/** @brief Creates a pattern for a rank @p rank tensor.
*
* This constructor creates a sparsity pattern for a dense tensor with
* @p rank modes.
*
* @param[in] rank The number of modes in the associated tensor.
*
* @throw None No throw guarantee.
*/
Pattern(size_type rank = 0) noexcept : m_rank_(rank) {}

/** @brief Provides the rank of the tensor *this assumes.
*
* @return The rank of the tensor *this describes.
*
* @throw None No throw guarantee.
*/
size_type rank() const noexcept { return m_rank_; }

/** @brief Determines if *this and @p rhs describe the same sparsity
* pattern.
*
* At present the sparsity component of TensorWrapper is a stub so this
* method always returns true.
* At present the sparsity component only tracks the rank of the tensor so
* two Patterns are value equal if they describe tensors with the same
* rank.
*
* @param[in] rhs The object to compare against.
*
* @return True if *this is value equal to @p rhs and false otherwise.
*
* @throw None No throw guarantee.
*/
bool operator==(const Pattern& rhs) const noexcept { return true; }
bool operator==(const Pattern& rhs) const noexcept {
return rank() == rhs.rank();
}

/** @brief Is *this different from @p rhs?
*
Expand All @@ -49,6 +81,40 @@ class Pattern {
bool operator!=(const Pattern& rhs) const noexcept {
return !((*this) == rhs);
}

protected:
/// Implements clone by calling copy constructor
typename polymorphic_base::base_pointer clone_() const override {
return std::make_unique<Pattern>(*this);
}

/// Implements are_equal by calling implementation provided by the base
bool are_equal_(const_base_reference rhs) const noexcept override {
return are_equal_impl_<Pattern>(rhs);
}

/// Implements addition_assignment via permute_assignment
dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Implements subtraction_assignment via permute_assignment
dsl_reference subtraction_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Implements multiplication_assignment via permute_assignment
dsl_reference multiplication_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Implements permute_assignment by permuting the extents in @p rhs.
dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

private:
/// The rank of the tensor associated with *this
size_type m_rank_;
};

} // namespace tensorwrapper::sparsity
94 changes: 84 additions & 10 deletions include/tensorwrapper/symmetry/group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

#pragma once
#include <deque>
#include <optional>
#include <tensorwrapper/detail_/dsl_base.hpp>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/symmetry/operation.hpp>
#include <utilities/containers/indexable_container_base.hpp>

Expand All @@ -34,13 +37,18 @@ namespace tensorwrapper::symmetry {
* mathematically know that the permutation (0, 2, 1) is also a symmetry
* operation because it is the inverse of (0, 1, 2).
*/
class Group : public utilities::IndexableContainerBase<Group> {
class Group : public utilities::IndexableContainerBase<Group>,
public tensorwrapper::detail_::DSLBase<Group>,
public tensorwrapper::detail_::PolymorphicBase<Group> {
private:
/// Type of *this
using my_type = Group;

/// Type *this derives from
using base_type = utilities::IndexableContainerBase<my_type>;
/// Type *this derives from to become container-like
using container_type = utilities::IndexableContainerBase<my_type>;

/// Type *this derives from to behave like other polymorphic objects
using polymorphic_type = tensorwrapper::detail_::PolymorphicBase<Group>;

public:
/// The base type of each object in *this
Expand All @@ -55,18 +63,30 @@ class Group : public utilities::IndexableContainerBase<Group> {
/// Unsigned integral type used for indexing and offsets
using size_type = std::size_t;

/// Type used for mode indices
using mode_index_type = typename value_type::mode_index_type;

// -------------------------------------------------------------------------
// -- Ctors and assignment
// -------------------------------------------------------------------------

/** @brief Initializes *this to an empty group.
/** @brief Initializes *this to be the identity group of a scalar.
*
* @throw None No throw guarantee
*/
Group() noexcept = default;

/** @brief Initializes *this as the identity group of a rank @p rank tensor.
*
* This ctor creates a group describing the symmetries of an empty set
* set. Such a set is also the symmetry group of a scalar.
* This ctor creates a group representing the identity group of a rank
* @p rank tensor.
*
* @param[in] rank The rank of the tensor the identity group describes.
*
* @throw None No throw guarantee.
*/
Group() noexcept = default;
explicit Group(mode_index_type rank) noexcept :
m_relations_{}, m_rank_(rank) {}

/** @brief Creates a Group from the provided symmetry operations.
*
Expand All @@ -81,12 +101,17 @@ class Group : public utilities::IndexableContainerBase<Group> {
* @param[in] op A symmetry operation to initialize the group with.
* @param[in] ops The remaining symmetry operations.
*
* @throw std::runtime_error if @p op and @p ops do not all have the same
* rank. Strong throw guarantee.
* @throw std::bad_alloc if there is a problem allocating the initial
* state. Strong throw guarantee.
*/
template<typename... Args>
explicit Group(const_reference op, Args&&... ops) :
Group(std::forward<Args>(ops)...) {
if(m_rank_ && rank() != op.rank())
throw std::runtime_error("Ranks of operations are not consistent");
if(!m_rank_) m_rank_.emplace(op.rank());
if(!count(op) && !op.is_identity())
m_relations_.emplace_front(op.clone());
}
Expand All @@ -104,6 +129,7 @@ class Group : public utilities::IndexableContainerBase<Group> {
Group(const Group& other) {
for(const auto& x : other.m_relations_)
m_relations_.push_back(x->clone());
m_rank_ = other.m_rank_;
}

/** @brief Transfers the state of @p other in to *this.
Expand Down Expand Up @@ -158,6 +184,17 @@ class Group : public utilities::IndexableContainerBase<Group> {
*/
bool count(const_reference op) const noexcept;

/** @brief The rank of the tensor these symmetries describe.
*
* This is not the rank of the group, but rather the rank of the tensor
* that the symmetries of the group describe.
*
* @return The rank of the tensor described by the symmetries in *this.
*
* @throw None No throw guarantee.
*/
mode_index_type rank() const noexcept { return m_rank_.value_or(0); }

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------
Expand All @@ -170,12 +207,17 @@ class Group : public utilities::IndexableContainerBase<Group> {
*
* @throw None No throw guarantee.
*/
void swap(Group& other) noexcept { m_relations_.swap(other.m_relations_); }
void swap(Group& other) noexcept {
m_relations_.swap(other.m_relations_);
m_rank_.swap(other.m_rank_);
}

/** @brief Determines if *this is value equal to @p rhs.
*
* Two Group objects are value equal if they contain the same number of
* operations and if each operation found in *this is also found in @p rhs.
* operations, if each operation found in *this is also found in @p rhs,
* and if the rank of the associated tensor is the same for *this as
* @p rhs.
*
* @param[in] rhs The Group object we are comparing against.
*
Expand All @@ -198,9 +240,37 @@ class Group : public utilities::IndexableContainerBase<Group> {
*/
bool operator!=(const Group& rhs) const noexcept { return !(*this == rhs); }

protected:
typename polymorphic_type::base_pointer clone_() const override {
return std::make_unique<Group>(*this);
}

bool are_equal_(const_base_reference rhs) const noexcept override {
return (*this) == rhs;
}

/// Implements addition_assignment via permute_assignment
dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Implements subtraction_assignment via permute_assignment
dsl_reference subtraction_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Implements multiplication_assignment via permute_assignment
dsl_reference multiplication_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Implements permute_assignment by permuting the extents in @p rhs.
dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

private:
/// Allow base class to access implementations
friend base_type;
friend container_type;

/// Base type common to all symmetry operations defining the API
using value_pointer = value_type::base_pointer;
Expand All @@ -219,6 +289,9 @@ class Group : public utilities::IndexableContainerBase<Group> {

/// The symmetry operations of *this
relation_container_type m_relations_;

/// The rank of the tensor these symmetries apply to
std::optional<mode_index_type> m_rank_;
};

// -- Out of line implementations
Expand All @@ -231,6 +304,7 @@ inline bool Group::count(const_reference op) const noexcept {
}

inline bool Group::operator==(const Group& rhs) const noexcept {
if(rank() != rhs.rank()) return false;
if(size() != rhs.size()) return false;
for(const auto& x : *this)
if(!rhs.count(x)) return false;
Expand Down
5 changes: 5 additions & 0 deletions include/tensorwrapper/symmetry/operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,18 @@ class Operation : public detail_::PolymorphicBase<Operation> {

bool is_identity() const noexcept { return is_identity_(); }

mode_index_type rank() const noexcept { return rank_(); }

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------

protected:
/// Derived class should overwrite to implement is_identity
virtual bool is_identity_() const noexcept = 0;

/// Derived class should overwrite to implement rank()
virtual mode_index_type rank_() const noexcept = 0;
};

} // namespace tensorwrapper::symmetry
Loading
Loading