3 changes: 3 additions & 0 deletions include/tensorwrapper/detail_/dsl_base.hpp
@@ -61,6 +61,9 @@ class DSLBase {
/// Type of parsed labels
using label_type = typename labeled_type::label_type;

/// Type of a mutable reference to a labeled_type object
using labeled_reference = labeled_type&;

/// Type of a read-only reference to a labeled_type object
using const_labeled_reference = const labeled_const_type&;

57 changes: 42 additions & 15 deletions include/tensorwrapper/dsl/dummy_indices.hpp
@@ -15,6 +15,7 @@
*/

#pragma once
#include <ostream>
#include <set>
#include <string>
#include <utilities/containers/indexable_container_base.hpp>
@@ -110,6 +111,15 @@ class DummyIndices
DummyIndices(
utilities::strings::split_string(remove_spaces_(dummy_indices), ",")) {}

/// Main ctor for setting the value, throws if any index is empty
explicit DummyIndices(split_string_type split_dummy_indices) :
m_dummy_indices_(std::move(split_dummy_indices)) {
for(const auto& x : m_dummy_indices_)
if(x.empty())
throw std::runtime_error(
"Dummy index is not allowed to be empty");
}

/** @brief Determines the number of unique indices in *this.
*
* A dummy index can be repeated if it is going to be summed over. This
@@ -212,12 +222,27 @@
*
* Each DummyIndices object is viewed as an ordered set of objects. If
* two DummyIndices objects contain the same objects, but in a different
* order, we can convert either object into the other by permuting it.
* This method computes the permutation needed to change *this into
* @p other. More specifically the result of this method is a vector
* of length `size()` such that the `i`-th element is the offset of
* `(*this)[i]` in @p other, i.e., if `x` is the return then
* `other[x[i]] == (*this)[i]`.
* order, we can convert either object into the other by permuting it. This
* method determines the permutation necessary to convert *this into
* @p other; the reverse permutation, i.e., the one that converts @p other
* into *this, is obtained from `other.permutation(*this)`.
*
* For concreteness assume we have a set A=(i, j, k) and we want to
* permute it into the set B=(j,k,i), then there are two subtly different
* ways of writing the necessary permutation:
* 1. Write the target set in terms of the current offsets, e.g., A goes
* to B is written as P1=(1, 2, 0)
* 2. Write the current set in terms of the target offsets, e.g., A goes
* to B is written as P2=(2, 0, 1)
*
* Option one maps offsets in B to offsets in A, e.g., A[P1[0]] == B[0],
* whereas option two maps offsets in A to offsets in B, e.g.,
* A[0] == B[P2[0]]. This method follows definition two.
*
* @note The definitions are inverses of each other, so if you want the
* other definition, flip *this and @p other: permuting B to A using
* definition 1 yields P1'=(2, 0, 1)=P2, and permuting B to A using
* definition 2 yields P2'=(1, 2, 0)=P1.
*
* @param[in] other The order we want to permute *this to.
*
@@ -437,15 +462,6 @@
}

protected:
/// Main ctor for setting the value, throws if any index is empty
explicit DummyIndices(split_string_type split_dummy_indices) :
m_dummy_indices_(std::move(split_dummy_indices)) {
for(const auto& x : m_dummy_indices_)
if(x.empty())
throw std::runtime_error(
"Dummy index is not allowed to be empty");
}

/// Lets the base class get at these implementations
friend base_type;

@@ -471,6 +487,17 @@
split_string_type m_dummy_indices_;
};

template<typename StringType>
std::ostream& operator<<(std::ostream& os, const DummyIndices<StringType>& i) {
if(i.size() == 0) return os;
os << i.at(0);
for(std::size_t j = 1; j < i.size(); ++j) {
os << ",";
os << i.at(j);
}
return os;
}

template<typename StringType>
bool DummyIndices<StringType>::is_hadamard_product(
const DummyIndices& lhs, const DummyIndices& rhs) const noexcept {
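Review note (not part of the diff): a minimal sketch of the documented permutation semantics on the A=(i,j,k), B=(j,k,i) example above, assuming the comma-separated string constructor, `at()`, and the new `operator<<` shown in this file; the expected values in the comments are illustrative, not verified output.

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <vector>
#include <tensorwrapper/dsl/dummy_indices.hpp>

int main() {
    using tensorwrapper::dsl::DummyIndices;
    DummyIndices<std::string> A("i,j,k"), B("j,k,i");

    // Definition two: p[m] is the offset of A[m] in B, i.e., A[m] == B[p[m]].
    auto p = A.permutation(B); // expected {2, 0, 1}
    std::vector<std::size_t> pv(p.begin(), p.end());
    for(std::size_t m = 0; m < A.size(); ++m) assert(A.at(m) == B.at(pv[m]));

    // Flipping the arguments gives the inverse permutation (definition one).
    auto q = B.permutation(A); // expected {1, 2, 0}
    std::vector<std::size_t> qv(q.begin(), q.end());
    for(std::size_t m = 0; m < B.size(); ++m) assert(B.at(m) == A.at(qv[m]));

    std::cout << A << " -> " << B << '\n'; // prints "i,j,k -> j,k,i"
    return 0;
}
```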
130 changes: 130 additions & 0 deletions src/tensorwrapper/buffer/contraction_planner.hpp
@@ -0,0 +1,130 @@
/*
* Copyright 2025 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <tensorwrapper/dsl/dummy_indices.hpp>

namespace tensorwrapper::buffer {

/** @brief Class for working out details pertaining to a tensor contraction.
*
* N.B. "Contraction" here also covers direct products, i.e., the special
* case with 0 dummy indices.
*/
class ContractionPlanner {
public:
/// String type users use to label modes
using string_type = std::string;

/// Type of the parsed labels
using label_type = dsl::DummyIndices<string_type>;

ContractionPlanner(string_type result, string_type lhs, string_type rhs) :
ContractionPlanner(label_type(result), label_type(lhs), label_type(rhs)) {
}

ContractionPlanner(label_type result, label_type lhs, label_type rhs) :
m_result_(std::move(result)),
m_lhs_(std::move(lhs)),
m_rhs_(std::move(rhs)) {
assert_no_repeated_indices_();
assert_dummy_indices_are_similar_();
assert_no_shared_free_();
}

/// Labels in LHS that are NOT summed over
label_type lhs_free() const { return m_lhs_.intersection(m_result_); }

/// Labels in RHS that are NOT summed over
label_type rhs_free() const { return m_rhs_.intersection(m_result_); }

/// Labels in LHS that ARE summed over
label_type lhs_dummy() const { return m_lhs_.difference(m_result_); }

/// Labels in RHS that ARE summed over
label_type rhs_dummy() const { return m_rhs_.difference(m_result_); }

/** @brief LHS permuted so free indices are followed by dummy indices. */
label_type lhs_permutation() const {
using split_string_type = typename label_type::split_string_type;
split_string_type rv;
auto lfree = lhs_free();
auto ldummy = lhs_dummy();
for(const auto& freei : m_result_) {
if(!lfree.count(freei)) continue;
rv.push_back(freei);
}
for(const auto& dummyi : ldummy) rv.push_back(dummyi);
return label_type(std::move(rv));
}

/** @brief RHS permuted so dummy indices are followed by free indices. */
label_type rhs_permutation() const {
typename label_type::split_string_type rv;
auto rfree = rhs_free();
auto rdummy = lhs_dummy(); // Use LHS dummy to get the same order!
for(const auto& dummyi : rdummy)
rv.push_back(dummyi); // Know it only appears 1x
for(const auto& freei : m_result_) {
if(!rfree.count(freei)) continue;
rv.push_back(freei); // Know it only appears 1x
}
return label_type(std::move(rv));
}

/** @brief Flattened result labels.
*
* After applying lhs_permutation to LHS to get A, and rhs_permutation to
* RHS to get B, A and B can be multiplied together with a gemm. The
* resulting matrix has indices given by concatenating the free indices of
* A with the free indices of B. This method returns those indices.
*
*/
label_type result_matrix_labels() const {
const auto lhs = lhs_permutation();
const auto rhs = rhs_permutation();
return lhs.concatenation(rhs).difference(lhs_dummy());
}

private:
/// Ensures no tensor contains a repeated label
void assert_no_repeated_indices_() const {
const bool result_good = !m_result_.has_repeated_indices();
const bool lhs_good = !m_lhs_.has_repeated_indices();
const bool rhs_good = !m_rhs_.has_repeated_indices();

if(result_good && lhs_good && rhs_good) return;
throw std::runtime_error("One or more terms contain repeated labels");
}

/// Ensures the dummy indices are permutations of each other
void assert_dummy_indices_are_similar_() const {
if(lhs_dummy().is_permutation(rhs_dummy())) return;
throw std::runtime_error("Dummy indices must appear in all terms");
}

/// Asserts LHS and RHS share no free indices (a shared free index signals a Hadamard product, not a contraction)
void assert_no_shared_free_() const {
if(!lhs_free().intersection(rhs_free()).size()) return;
throw std::runtime_error("Contraction must sum repeated indices");
}

label_type m_result_;
label_type m_lhs_;
label_type m_rhs_;
};

} // namespace tensorwrapper::buffer
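Review note (not part of the diff): a usage sketch for the planner above; the include path is an assumption (the header lives under src/), and the commented labels are the values the methods should produce for this example, not verified output. Printing relies on the `operator<<` added to dummy_indices.hpp in this PR.

```cpp
#include <iostream>
#include <stdexcept>
#include "contraction_planner.hpp" // assumed path relative to src/tensorwrapper/buffer

int main() {
    using tensorwrapper::buffer::ContractionPlanner;

    // Plan C("i,j") = A("k,i") * B("j,k"): "k" is summed, "i" and "j" are free.
    ContractionPlanner plan("i,j", "k,i", "j,k");
    std::cout << plan.lhs_free() << '\n';             // i
    std::cout << plan.lhs_dummy() << '\n';            // k
    std::cout << plan.lhs_permutation() << '\n';      // i,k (free, then dummy)
    std::cout << plan.rhs_permutation() << '\n';      // k,j (dummy, then free)
    std::cout << plan.result_matrix_labels() << '\n'; // i,j (labels of the gemm result)

    // A Hadamard product is rejected: "i" and "j" are free in both operands.
    try {
        ContractionPlanner bad("i,j", "i,j", "i,j");
    } catch(const std::runtime_error& e) {
        std::cout << e.what() << '\n'; // "Contraction must sum repeated indices"
    }
    return 0;
}
```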
27 changes: 4 additions & 23 deletions src/tensorwrapper/buffer/eigen.cpp
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "eigen_contraction.hpp"
#include <sstream>
#include <tensorwrapper/allocator/eigen.hpp>
@@ -139,7 +140,8 @@ typename EIGEN::dsl_reference EIGEN::permute_assignment_(
const auto& rlabels = rhs.labels();

if(this_labels != rlabels) { // We need to permute rhs before assignment
auto r_to_l = rhs.labels().permutation(this_labels);
// Eigen adopts the opposite definition of permutation from us.
auto r_to_l = this_labels.permutation(rlabels);
// Eigen wants int objects
std::vector<int> r_to_l2(r_to_l.begin(), r_to_l.end());
m_tensor_ = rhs_downcasted.value().shuffle(r_to_l2);
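Review note (not part of the diff): the flipped call works because, under definition two, `this_labels.permutation(rlabels)` lists for each result mode which rhs mode it comes from, and that is the convention Eigen's `shuffle()` uses (output mode m is input mode `shuffle[m]`). A small sketch with made-up labels and extents, assuming Eigen's unsupported Tensor module:

```cpp
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
    // rhs is labeled "j,k,i" with extents (j=2, k=3, i=4); we want it as "i,j,k".
    Eigen::Tensor<double, 3> rhs(2, 3, 4);
    rhs.setRandom();

    // Definition two: "i,j,k".permutation("j,k,i") == {2, 0, 1}; entry m is
    // where result label m lives in rhs, exactly what shuffle() expects.
    Eigen::array<int, 3> r_to_l{2, 0, 1};
    Eigen::Tensor<double, 3> permuted = rhs.shuffle(r_to_l);

    // Extents now follow the result labels: (i=4, j=2, k=3).
    bool ok = permuted.dimension(0) == 4 && permuted.dimension(1) == 2 &&
              permuted.dimension(2) == 3;
    return ok ? 0 : 1;
}
```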
@@ -226,28 +228,7 @@ typename EIGEN::dsl_reference EIGEN::hadamard_(label_type this_labels,
TPARAMS typename EIGEN::dsl_reference EIGEN::contraction_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) {
const auto& llabels = lhs.labels();
const auto& lobject = lhs.object();
const auto& rlabels = rhs.labels();
const auto& robject = rhs.object();

// N.b. is a pure contraction, so common indices are summed over
auto common = llabels.intersection(rlabels);

// -- This block converts string indices to mode offsets
using rank_type = unsigned short;
using pair_type = std::pair<rank_type, rank_type>;
std::vector<pair_type> modes;
auto rank = common.size();
for(decltype(rank) i = 0; i < rank; ++i) {
const auto& index_i = common.at(i);
// N.b., pure contraction so there's no repeats within a tensor's label
auto lindex = llabels.find(index_i)[0];
auto rindex = rlabels.find(index_i)[0];
modes.push_back(pair_type(lindex, rindex));
}

return eigen_contraction<FloatType>(*this, lobject, robject, modes);
return eigen_contraction(*this, this_labels, lhs, rhs);
}

#undef EIGEN