66 changes: 66 additions & 0 deletions docs/source/developer/adding_operations_to_contiguous.rst
@@ -0,0 +1,66 @@
.. Copyright 2023 NWChemEx-Project
..
.. Licensed under the Apache License, Version 2.0 (the "License");
.. you may not use this file except in compliance with the License.
.. You may obtain a copy of the License at
..
.. http://www.apache.org/licenses/LICENSE-2.0
..
.. Unless required by applicable law or agreed to in writing, software
.. distributed under the License is distributed on an "AS IS" BASIS,
.. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
.. See the License for the specific language governing permissions and
.. limitations under the License.

###############################
Adding Operations to Contiguous
###############################

The ``Contiguous`` class is the workhorse of most tensor operations because it
provides the kernels on which operations involving non-contiguous tensors are
ultimately built. As such, we may need to add operations to it from time to
time. This document describes how to do that.

**********************************
Understanding How Contiguous Works
**********************************

.. figure:: assets/how_contiguous_works.png
:align: center

Control flow for an operation resulting in a ``Contiguous`` buffer object.

For concreteness, we will trace how ``subtraction_assignment`` is implemented;
other binary operations are implemented nearly identically, and unary
operations follow a very similar pattern. A code sketch of this flow is given
after the list.

1. The input objects, ``lhs`` and ``rhs``, are converted to ``Contiguous``
objects. N.b., we should eventually use performance models to decide whether
the cost of converting to ``Contiguous`` objects is worth it, or whether we
should instead rely on algorithms that do not require contiguous data.
2. We work out the shape of the output tensor.
3. A visitor for the desired operation is created. For
``subtraction_assignment``, this is ``detail_::SubtractionVisitor``.

- Visitor definitions live in ``wtf/src/tensorwrapper/buffer/detail_/``.

4. Control enters ``wtf::buffer::visit_contiguous_buffer``, which restores the
floating-point types.
5. ``lhs`` and ``rhs`` are converted to ``std::span`` objects.
6. Control enters the visitor.
7. With the floating-point types known, the output tensor is initialized.
8. The visitor converts the ``std::span`` objects into the tensor backend's
tensor objects.

- Backend implementations live in ``wtf/src/tensorwrapper/backends/``.

9. The backend's implementation of the operation is invoked.

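The numbered steps above map more or less directly onto the body of the
corresponding assignment overload. The following is a simplified sketch, based
on the ``MDBuffer::multiplication_assignment_`` implementation in
``src/tensorwrapper/buffer/mdbuffer.cpp`` (error handling elided; member names
such as ``m_buffer_`` and ``m_shape_`` are internal details and may change):

.. code-block:: c++

   auto MDBuffer::multiplication_assignment_(label_type this_labels,
                                             const_labeled_reference lhs,
                                             const_labeled_reference rhs)
     -> dsl_reference {
       // Step 1: recover the contiguous pieces of the inputs.
       const auto& lhs_down = downcast(lhs.object());
       const auto& rhs_down = downcast(rhs.object());

       // Step 2: work out the shape of the output tensor.
       m_shape_.multiplication_assignment(this_labels,
                                          lhs_down.m_shape_(lhs.labels()),
                                          rhs_down.m_shape_(rhs.labels()));

       // Step 3: create the visitor for the operation.
       detail_::MultiplicationVisitor visitor(m_buffer_, this_labels, m_shape_,
                                              lhs.labels(), lhs_down.m_shape_,
                                              rhs.labels(), rhs_down.m_shape_);

       // Steps 4-9: restore the floating-point types, wrap the buffers in
       // std::span objects, and hand control to the visitor, which ultimately
       // calls into the backend.
       wtf::buffer::visit_contiguous_buffer<fp_types>(
         visitor, lhs_down.m_buffer_, rhs_down.m_buffer_);

       mark_for_rehash_();
       return *this;
   }
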
**********************
Adding a New Operation
**********************

1. Verify that one of the backends supports the desired operation. If not, add
it to a backend first.
2. Create a visitor for the operation (a minimal skeleton is sketched below).
3. Add the operation to ``wtf::buffer::Contiguous``.
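
For step 2, the existing visitors in
``src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp`` (e.g.,
``SubtractionVisitor`` and ``MultiplicationVisitor``) are good templates. The
skeleton below illustrates the pattern; ``AwesomeOpVisitor`` and
``awesome_op_assignment`` are placeholder names used purely for illustration:

.. code-block:: c++

   /// Hypothetical visitor; awesome_op_assignment is a placeholder for a
   /// backend operation that must already exist (see step 1).
   class AwesomeOpVisitor : public BinaryOperationVisitor {
   public:
       using BinaryOperationVisitor::BinaryOperationVisitor;
       using BinaryOperationVisitor::operator();

       template<typename FloatType>
       void operator()(std::span<FloatType> lhs, std::span<FloatType> rhs) {
           using clean_t = std::decay_t<FloatType>;
           // Initialize the output and wrap the inputs in backend tensors.
           auto pthis = this->make_this_eigen_tensor_<clean_t>();
           auto plhs  = this->make_lhs_eigen_tensor_(lhs);
           auto prhs  = this->make_rhs_eigen_tensor_(rhs);

           // Forward to the backend's implementation of the operation.
           pthis->awesome_op_assignment(this_labels(), lhs_labels(),
                                        rhs_labels(), *plhs, *prhs);
       }
   };

For step 3, the buffer's corresponding ``*_assignment_`` overload then
constructs this visitor and passes it to
``wtf::buffer::visit_contiguous_buffer``, as in the sketch in the previous
section.
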
1 change: 1 addition & 0 deletions docs/source/developer/index.rst
@@ -21,3 +21,4 @@ Developer Documentation
:caption: Contents:

design/index
adding_operations_to_contiguous
25 changes: 25 additions & 0 deletions src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp
@@ -135,4 +135,29 @@ class SubtractionVisitor : public BinaryOperationVisitor {
}
};

/// Visitor that calls hadamard_assignment or contraction_assignment
class MultiplicationVisitor : public BinaryOperationVisitor {
public:
using BinaryOperationVisitor::BinaryOperationVisitor;
using BinaryOperationVisitor::operator();

template<typename FloatType>
void operator()(std::span<FloatType> lhs, std::span<FloatType> rhs) {
using clean_t = std::decay_t<FloatType>;
auto pthis = this->make_this_eigen_tensor_<clean_t>();
auto plhs = this->make_lhs_eigen_tensor_(lhs);
auto prhs = this->make_rhs_eigen_tensor_(rhs);

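// Dispatch on the label pattern: an element-wise (Hadamard) product, a
// contraction, or (not yet implemented) a batched contraction.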
if(this_labels().is_hadamard_product(lhs_labels(), rhs_labels()))
pthis->hadamard_assignment(this_labels(), lhs_labels(),
rhs_labels(), *plhs, *prhs);
else if(this_labels().is_contraction(lhs_labels(), rhs_labels()))
pthis->contraction_assignment(this_labels(), lhs_labels(),
rhs_labels(), *plhs, *prhs);
else
throw std::runtime_error(
"MultiplicationVisitor: Batched contraction NYI");
}
};

} // namespace tensorwrapper::buffer::detail_
21 changes: 20 additions & 1 deletion src/tensorwrapper/buffer/mdbuffer.cpp
@@ -157,7 +157,26 @@ auto MDBuffer::multiplication_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs)
-> dsl_reference {
throw std::runtime_error("multiplication NYI");
const auto& lhs_down = downcast(lhs.object());
const auto& rhs_down = downcast(rhs.object());
const auto& lhs_shape = lhs_down.m_shape_;
const auto& rhs_shape = rhs_down.m_shape_;

auto labeled_lhs_shape = lhs_shape(lhs.labels());
auto labeled_rhs_shape = rhs_shape(rhs.labels());

m_shape_.multiplication_assignment(this_labels, labeled_lhs_shape,
labeled_rhs_shape);

detail_::MultiplicationVisitor visitor(m_buffer_, this_labels, m_shape_,
lhs.labels(), lhs_shape,
rhs.labels(), rhs_shape);

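// Restore the floating-point types, wrap the buffers in spans, and hand
// control to the visitor.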
wtf::buffer::visit_contiguous_buffer<fp_types>(visitor, lhs_down.m_buffer_,
rhs_down.m_buffer_);

mark_for_rehash_();
return *this;
}

auto MDBuffer::permute_assignment_(label_type this_labels,
@@ -146,3 +146,69 @@ TEMPLATE_LIST_TEST_CASE("SubtractionVisitor", "[buffer][detail_]",
REQUIRE(empty_buffer.at(3) == TestType(0.0));
}
}

TEMPLATE_LIST_TEST_CASE("MultiplicationVisitor", "[buffer][detail_]",
types::floating_point_types) {
using VisitorType = buffer::detail_::MultiplicationVisitor;
using buffer_type = typename VisitorType::buffer_type;
using label_type = typename VisitorType::label_type;
using shape_type = typename VisitorType::shape_type;

TestType one{1.0}, two{2.0}, three{3.0}, four{4.0};
std::vector<TestType> this_data{one, two, three, four};
std::vector<TestType> lhs_data{four, three, two, one};
std::vector<TestType> rhs_data{one, one, one, one};
shape_type shape({4});
label_type labels("i");

std::span<TestType> lhs_span(lhs_data.data(), lhs_data.size());
std::span<const TestType> clhs_span(lhs_data.data(), lhs_data.size());
std::span<TestType> rhs_span(rhs_data.data(), rhs_data.size());
std::span<const TestType> crhs_span(rhs_data.data(), rhs_data.size());

SECTION("existing buffer: Hadamard") {
buffer_type this_buffer(this_data);
VisitorType visitor(this_buffer, labels, shape, labels, shape, labels,
shape);

visitor(lhs_span, rhs_span);
REQUIRE(this_buffer.at(0) == TestType(4.0));
REQUIRE(this_buffer.at(1) == TestType(3.0));
REQUIRE(this_buffer.at(2) == TestType(2.0));
REQUIRE(this_buffer.at(3) == TestType(1.0));
}

SECTION("existing buffer: contraction") {
buffer_type this_buffer(this_data);
shape_type scalar_shape;
VisitorType visitor(this_buffer, label_type(""), scalar_shape, labels,
shape, labels, shape);

visitor(lhs_span, rhs_span);
REQUIRE(this_buffer.size() == 1);
REQUIRE(this_buffer.at(0) == TestType(10.0));
}

SECTION("existing buffer: batched contraction") {
buffer_type this_buffer(this_data);
shape_type out_shape({2});
label_type lhs_labels("a,i");
label_type rhs_labels("i,a");
VisitorType visitor(this_buffer, labels, out_shape, lhs_labels, shape,
rhs_labels, shape);

REQUIRE_THROWS_AS(visitor(lhs_span, rhs_span), std::runtime_error);
}

SECTION("non-existing buffer") {
buffer_type empty_buffer;
VisitorType visitor(empty_buffer, labels, shape, labels, shape, labels,
shape);

visitor(clhs_span, crhs_span);
REQUIRE(empty_buffer.at(0) == TestType(4.0));
REQUIRE(empty_buffer.at(1) == TestType(3.0));
REQUIRE(empty_buffer.at(2) == TestType(2.0));
REQUIRE(empty_buffer.at(3) == TestType(1.0));
}
}
38 changes: 38 additions & 0 deletions tests/cxx/unit_tests/tensorwrapper/buffer/mdbuffer.cpp
@@ -317,6 +317,44 @@ TEMPLATE_LIST_TEST_CASE("MDBuffer", "", types::floating_point_types) {
}
}

SECTION("multiplication_assignment_") {
// N.b., dispatching among hadamard, contraction, etc. is the visitor's
// responsibility and happens there. Here we just test hadamard.

SECTION("scalar") {
label_type labels("");
MDBuffer result;
result.multiplication_assignment(labels, scalar(labels),
scalar(labels));
REQUIRE(result.shape() == scalar_shape);
REQUIRE(result.get_elem({}) == TestType(1.0));
}

SECTION("vector") {
label_type labels("i");
MDBuffer result;
result.multiplication_assignment(labels, vector(labels),
vector(labels));
REQUIRE(result.shape() == vector_shape);
REQUIRE(result.get_elem({0}) == TestType(1.0));
REQUIRE(result.get_elem({1}) == TestType(4.0));
REQUIRE(result.get_elem({2}) == TestType(9.0));
REQUIRE(result.get_elem({3}) == TestType(16.0));
}

SECTION("matrix") {
label_type labels("i,j");
MDBuffer result;
result.multiplication_assignment(labels, matrix(labels),
matrix(labels));
REQUIRE(result.shape() == matrix_shape);
REQUIRE(result.get_elem({0, 0}) == TestType(1.0));
REQUIRE(result.get_elem({0, 1}) == TestType(4.0));
REQUIRE(result.get_elem({1, 0}) == TestType(9.0));
REQUIRE(result.get_elem({1, 1}) == TestType(16.0));
}
}

SECTION("scalar_multiplication_") {
// TODO: Test with other scalar types when public API supports it
using scalar_type = double;