Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions include/tensorwrapper/buffer/buffer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <tensorwrapper/detail_/dsl_base.hpp>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/dsl/labeled.hpp>
#include <tensorwrapper/layout/layout_base.hpp>
Expand All @@ -25,30 +26,33 @@ namespace tensorwrapper::buffer {
*
* All classes which wrap existing tensor libraries derive from this class.
*/
class BufferBase : public detail_::PolymorphicBase<BufferBase> {
class BufferBase : public detail_::PolymorphicBase<BufferBase>,
public detail_::DSLBase<BufferBase> {
private:
/// Type of *this
using my_type = BufferBase;

protected:
/// Type *this inherits from
using my_base_type = detail_::PolymorphicBase<my_type>;
using polymorphic_base = detail_::PolymorphicBase<my_type>;

public:
/// Type all buffers inherit from
using buffer_base_type = typename my_base_type::base_type;
using buffer_base_type = typename polymorphic_base::base_type;

/// Type of a mutable reference to a buffer_base_type object
using buffer_base_reference = typename my_base_type::base_reference;
using buffer_base_reference = typename polymorphic_base::base_reference;

/// Type of a read-only reference to a buffer_base_type object
using const_buffer_base_reference =
typename my_base_type::const_base_reference;
typename polymorphic_base::const_base_reference;

/// Type of a pointer to an object of type buffer_base_type
using buffer_base_pointer = typename my_base_type::base_pointer;
using buffer_base_pointer = typename polymorphic_base::base_pointer;

/// Type of a pointer to a read-only object of type buffer_base_type
using const_buffer_base_pointer = typename my_base_type::const_base_pointer;
using const_buffer_base_pointer =
typename polymorphic_base::const_base_pointer;

/// Type of the class describing the physical layout of the buffer
using layout_type = layout::LayoutBase;
Expand All @@ -59,6 +63,9 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
/// Type of a pointer to the layout
using layout_pointer = typename layout_type::layout_pointer;

/// Type used to represent the tensor's rank
using rank_type = typename layout_type::size_type;

// -------------------------------------------------------------------------
// -- Accessors
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -90,6 +97,10 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
return *m_layout_;
}

rank_type rank() const noexcept {
return has_layout() ? layout().rank() : 0;
}

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -191,6 +202,21 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
return *this;
}

dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

dsl_reference subtraction_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

dsl_reference multiplication_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) override;

dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

private:
/// Throws std::runtime_error when there is no layout
void assert_layout_() const {
Expand Down
28 changes: 27 additions & 1 deletion include/tensorwrapper/buffer/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,36 @@ class Eigen : public Replicated {
return my_base_type::are_equal_impl_<my_type>(rhs);
}

/// Implements addition_assignment by calling addition_assignment on state
dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls subtraction_assignment on each member
dsl_reference subtraction_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls multiplication_assignment on each member
dsl_reference multiplication_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls permute_assignment on each member
dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

/// Implements to_string
typename my_base_type::string_type to_string_() const override;
typename polymorphic_base::string_type to_string_() const override;

private:
dsl_reference hadamard_(label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs);

dsl_reference contraction_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs);

/// The actual Eigen tensor
data_type m_tensor_;
};
Expand Down
94 changes: 94 additions & 0 deletions include/tensorwrapper/dsl/dummy_indices.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,51 @@ class DummyIndices
return true;
}

/** @brief Is a thruple of DummyIndices consistent with a pure element-wise
* product?
*
* In generalized Einstein notation a pure element-wise (also commonly
* termed Hadamard) product is denoted by *this, @p lhs, and @p rhs
* having the same ordered set of dummy indices, up to permutation.
* Additionally, the dummy indices associated with any given tensor may
* not include a repeated index.
*
* @param[in] lhs The dummy indices associated with the tensor to the
* left of the times operator.
* @param[in] rhs The dummy indices associated with the tensor to the
* right of the times operator.
*
* @return True If the dummy indices given by *this, @p lhs, and @p rhs
* are consistent with a purely element-wise product of the tensors
* that @p lhs and @p rhs label.
*
* @throw None No throw guarantee.
*/
bool is_hadamard_product(const DummyIndices& lhs,
const DummyIndices& rhs) const noexcept;

/** @brief Does a thruple of DummyIndices indicate a product is a pure
* contraction?
*
* In generalized Einstein notation a pure contraction is an operation
* where indices common to @p lhs and @p rhs are summed over and do NOT
* appear in the result, i.e., *this. Additionally, we stipulate that
* there must be at least one index summed over (if no index is summed over
* the operation is a pure direct-product).
*
* @param[in] lhs The dummy indices associated with the tensor to the
* left of the times operator.
* @param[in] rhs The dummy indices associated with the tensor to the
* right of the times operator.
*
* @return True if the indices associated with *this, @p lhs, and @p rhs
* are consistent with a contraction and false otherwise.
*
* @throw None No throw guarantee.
*/
bool is_contraction(const DummyIndices& lhs,
const DummyIndices& rhs) const noexcept;

/** @brief Computes the permutation needed to convert *this into @p other.
*
* Each DummyIndices object is viewed as an ordered set of objects. If
Expand Down Expand Up @@ -366,6 +411,31 @@ class DummyIndices
return rv;
}

/** @brief Returns the set difference of *this and @p other.
*
* The set difference of *this with @p other is the set of indices which
* appear in *this, but not in @p other. This method will return the set
* (indices which appear more than once in *this will only appear once
* in the result) which results from the set difference of *this with
* @p other.
*
* @param[in] other The set to remove from *this.
*
* @return The set difference of *this and @p rhs.
*
* @throw std::bad_alloc if there is a problem allocating the return.
* Strong throw guarantee.
*/
DummyIndices difference(const DummyIndices& other) const {
DummyIndices rv;
for(const auto& x : *this) {
if(other.count(x)) continue;
if(rv.count(x)) continue;
rv.m_dummy_indices_.push_back(x);
}
return rv;
}

protected:
/// Main ctor for setting the value, throws if any index is empty
explicit DummyIndices(split_string_type split_dummy_indices) :
Expand Down Expand Up @@ -401,4 +471,28 @@ class DummyIndices
split_string_type m_dummy_indices_;
};

template<typename StringType>
bool DummyIndices<StringType>::is_hadamard_product(
const DummyIndices& lhs, const DummyIndices& rhs) const noexcept {
if(has_repeated_indices()) return false;
if(lhs.has_repeated_indices()) return false;
if(rhs.has_repeated_indices()) return false;
if(!is_permutation(lhs)) return false;
if(!is_permutation(rhs)) return false;
return true;
}

template<typename StringType>
bool DummyIndices<StringType>::is_contraction(
const DummyIndices& lhs, const DummyIndices& rhs) const noexcept {
if(has_repeated_indices()) return false;
if(lhs.has_repeated_indices()) return false;
if(rhs.has_repeated_indices()) return false;
auto lhs_cap_rhs = lhs.intersection(rhs);
if(lhs_cap_rhs.empty()) return false; // No common indices
if(!intersection(lhs_cap_rhs).empty())
return false; // Common index not summed
return true;
}

} // namespace tensorwrapper::dsl
Loading
Loading