diff --git a/docs/source/developer/adding_operations_to_contiguous.rst b/docs/source/developer/adding_operations_to_contiguous.rst
new file mode 100644
index 00000000..6a685cbc
--- /dev/null
+++ b/docs/source/developer/adding_operations_to_contiguous.rst
@@ -0,0 +1,66 @@
+.. Copyright 2023 NWChemEx-Project
+..
+.. Licensed under the Apache License, Version 2.0 (the "License");
+.. you may not use this file except in compliance with the License.
+.. You may obtain a copy of the License at
+..
+..    http://www.apache.org/licenses/LICENSE-2.0
+..
+.. Unless required by applicable law or agreed to in writing, software
+.. distributed under the License is distributed on an "AS IS" BASIS,
+.. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+.. See the License for the specific language governing permissions and
+.. limitations under the License.
+
+###############################
+Adding Operations to Contiguous
+###############################
+
+The ``Contiguous`` class is the workhorse of most tensor operations because it
+provides the kernels that non-contiguous tensors are built on. As such, we may
+need to add operations to it from time to time. This document describes how to
+do that.
+
+**********************************
+Understanding How Contiguous Works
+**********************************
+
+.. figure:: assets/how_contiguous_works.png
+   :align: center
+
+   Control flow for an operation resulting in a ``Contiguous`` buffer object.
+
+For concreteness, we'll trace how ``subtraction_assignment`` is implemented.
+Other binary operations are implemented nearly identically, and the
+implementation of unary operations is very similar.
+
+1. The input objects, ``lhs`` and ``rhs``, are converted to ``Contiguous``
+   objects. N.b., we should eventually use performance models to decide
+   whether the time spent converting to ``Contiguous`` objects is worth it,
+   or whether we should rely on algorithms which do not require contiguous
+   data.
+2. We work out the shape of the output tensor.
+3. A visitor for the desired operation is created. For
+   ``subtraction_assignment``, this is ``detail_::SubtractionVisitor``.
+
+   - Visitor definitions live in ``wtf/src/tensorwrapper/buffer/detail_/``.
+
+4. Control enters ``wtf::buffer::visit_contiguous_buffer``, which restores the
+   floating-point types.
+5. ``lhs`` and ``rhs`` are converted to ``std::span`` objects.
+6. Control enters the visitor.
+7. With the types known, the output tensor can be initialized (and is).
+8. The visitor converts the ``std::span`` objects into the tensor backend's
+   tensor objects.
+
+   - Backend implementations live in ``wtf/src/tensorwrapper/backends/``.
+
+9. The backend's implementation of the operation is invoked.
+
+**********************
+Adding a New Operation
+**********************
+
+1. Verify that one of the backends supports the desired operation. If not, add
+   it to a backend first.
+2. Create a visitor for it (a sketch follows this list).
+3. Add the operation to ``wtf::buffer::Contiguous``.
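+
+To make step 2 concrete, the skeleton below is modeled on the visitors in
+``wtf/src/tensorwrapper/buffer/detail_/`` (e.g., ``SubtractionVisitor``). It
+is a sketch, not an existing class: the name ``MyOpVisitor`` and the backend
+method ``my_op_assignment`` are placeholders for your new operation.
+
+.. code-block:: cpp
+
+   /// Hypothetical visitor for a new binary operation "my_op".
+   class MyOpVisitor : public BinaryOperationVisitor {
+   public:
+       using BinaryOperationVisitor::BinaryOperationVisitor;
+       using BinaryOperationVisitor::operator();
+
+       template<typename LHSType, typename RHSType>
+       void operator()(std::span<LHSType> lhs, std::span<RHSType> rhs) {
+           using clean_t = std::decay_t<LHSType>;
+           // Wrap the output, lhs, and rhs in the backend's tensor type.
+           auto pthis = this->make_this_eigen_tensor_<clean_t>();
+           auto plhs  = this->make_lhs_eigen_tensor_(lhs);
+           auto prhs  = this->make_rhs_eigen_tensor_(rhs);
+
+           // Placeholder call: forward to the backend kernel for "my_op".
+           pthis->my_op_assignment(this_labels(), lhs_labels(),
+                                   rhs_labels(), *plhs, *prhs);
+       }
+   };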
diff --git a/docs/source/developer/assets/how_contiguous_works.png b/docs/source/developer/assets/how_contiguous_works.png
new file mode 100644
index 00000000..9602d632
Binary files /dev/null and b/docs/source/developer/assets/how_contiguous_works.png differ
diff --git a/docs/source/developer/index.rst b/docs/source/developer/index.rst
index 9fb2285b..826a9ffd 100644
--- a/docs/source/developer/index.rst
+++ b/docs/source/developer/index.rst
@@ -21,3 +21,4 @@ Developer Documentation
    :caption: Contents:
 
    design/index
+   adding_operations_to_contiguous
diff --git a/src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp b/src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp
index 25363fd1..76432135 100644
--- a/src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp
+++ b/src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp
@@ -135,4 +135,29 @@ class SubtractionVisitor : public BinaryOperationVisitor {
     }
 };
 
+/// Visitor that calls hadamard_assignment or contraction_assignment
+class MultiplicationVisitor : public BinaryOperationVisitor {
+public:
+    using BinaryOperationVisitor::BinaryOperationVisitor;
+    using BinaryOperationVisitor::operator();
+
+    template<typename LHSType, typename RHSType>
+    void operator()(std::span<LHSType> lhs, std::span<RHSType> rhs) {
+        using clean_t = std::decay_t<LHSType>;
+        auto pthis    = this->make_this_eigen_tensor_<clean_t>();
+        auto plhs     = this->make_lhs_eigen_tensor_(lhs);
+        auto prhs     = this->make_rhs_eigen_tensor_(rhs);
+
+        if(this_labels().is_hadamard_product(lhs_labels(), rhs_labels()))
+            pthis->hadamard_assignment(this_labels(), lhs_labels(),
+                                       rhs_labels(), *plhs, *prhs);
+        else if(this_labels().is_contraction(lhs_labels(), rhs_labels()))
+            pthis->contraction_assignment(this_labels(), lhs_labels(),
+                                          rhs_labels(), *plhs, *prhs);
+        else
+            throw std::runtime_error(
+              "MultiplicationVisitor: Batched contraction NYI");
+    }
+};
+
 } // namespace tensorwrapper::buffer::detail_
diff --git a/src/tensorwrapper/buffer/mdbuffer.cpp b/src/tensorwrapper/buffer/mdbuffer.cpp
index dc016b47..79400829 100644
--- a/src/tensorwrapper/buffer/mdbuffer.cpp
+++ b/src/tensorwrapper/buffer/mdbuffer.cpp
@@ -157,7 +157,26 @@ auto MDBuffer::multiplication_assignment_(label_type this_labels,
                                           const_labeled_reference lhs,
                                           const_labeled_reference rhs)
   -> dsl_reference {
-    throw std::runtime_error("multiplication NYI");
+    const auto& lhs_down  = downcast(lhs.object());
+    const auto& rhs_down  = downcast(rhs.object());
+    const auto& lhs_shape = lhs_down.m_shape_;
+    const auto& rhs_shape = rhs_down.m_shape_;
+
+    auto labeled_lhs_shape = lhs_shape(lhs.labels());
+    auto labeled_rhs_shape = rhs_shape(rhs.labels());
+
+    m_shape_.multiplication_assignment(this_labels, labeled_lhs_shape,
+                                       labeled_rhs_shape);
+
+    detail_::MultiplicationVisitor visitor(m_buffer_, this_labels, m_shape_,
+                                           lhs.labels(), lhs_shape,
+                                           rhs.labels(), rhs_shape);
+
+    wtf::buffer::visit_contiguous_buffer(visitor, lhs_down.m_buffer_,
+                                         rhs_down.m_buffer_);
+
+    mark_for_rehash_();
+    return *this;
 }
 
 auto MDBuffer::permute_assignment_(label_type this_labels,
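
Annotation (reviewer note, not part of the patch): a quick reference for the
label patterns ``MultiplicationVisitor`` dispatches on, worked out for small
cases. The label strings are illustrative; the matrix-multiply row is our
reading of ``is_contraction`` and is not exercised by the tests below.

    // out   = lhs   * rhs      -> path taken by the visitor
    // "i"   = "i"   * "i"      -> hadamard_assignment    (element-wise)
    // ""    = "i"   * "i"      -> contraction_assignment (full dot product)
    // "i,k" = "i,j" * "j,k"    -> contraction_assignment (matrix multiply)
    // "i"   = "a,i" * "i,a"    -> neither predicate matches; throws
    //                             std::runtime_error (batched contraction NYI)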
+TEMPLATE_LIST_TEST_CASE("MultiplicationVisitor", "[buffer][detail_]", + types::floating_point_types) { + using VisitorType = buffer::detail_::MultiplicationVisitor; + using buffer_type = typename VisitorType::buffer_type; + using label_type = typename VisitorType::label_type; + using shape_type = typename VisitorType::shape_type; + + TestType one{1.0}, two{2.0}, three{3.0}, four{4.0}; + std::vector this_data{one, two, three, four}; + std::vector lhs_data{four, three, two, one}; + std::vector rhs_data{one, one, one, one}; + shape_type shape({4}); + label_type labels("i"); + + std::span lhs_span(lhs_data.data(), lhs_data.size()); + std::span clhs_span(lhs_data.data(), lhs_data.size()); + std::span rhs_span(rhs_data.data(), rhs_data.size()); + std::span crhs_span(rhs_data.data(), rhs_data.size()); + + SECTION("existing buffer: Hadamard") { + buffer_type this_buffer(this_data); + VisitorType visitor(this_buffer, labels, shape, labels, shape, labels, + shape); + + visitor(lhs_span, rhs_span); + REQUIRE(this_buffer.at(0) == TestType(4.0)); + REQUIRE(this_buffer.at(1) == TestType(3.0)); + REQUIRE(this_buffer.at(2) == TestType(2.0)); + REQUIRE(this_buffer.at(3) == TestType(1.0)); + } + + SECTION("existing buffer: contraction") { + buffer_type this_buffer(this_data); + shape_type scalar_shape; + VisitorType visitor(this_buffer, label_type(""), scalar_shape, labels, + shape, labels, shape); + + visitor(lhs_span, rhs_span); + REQUIRE(this_buffer.size() == 1); + REQUIRE(this_buffer.at(0) == TestType(10.0)); + } + + SECTION("existing buffer: batched contraction") { + buffer_type this_buffer(this_data); + shape_type out_shape({2}); + label_type lhs_labels("a,i"); + label_type rhs_labels("i,a"); + VisitorType visitor(this_buffer, labels, out_shape, lhs_labels, shape, + rhs_labels, shape); + + REQUIRE_THROWS_AS(visitor(lhs_span, rhs_span), std::runtime_error); + } + + SECTION("non-existing buffer") { + buffer_type empty_buffer; + VisitorType visitor(empty_buffer, labels, shape, labels, shape, labels, + shape); + + visitor(clhs_span, crhs_span); + REQUIRE(empty_buffer.at(0) == TestType(4.0)); + REQUIRE(empty_buffer.at(1) == TestType(3.0)); + REQUIRE(empty_buffer.at(2) == TestType(2.0)); + REQUIRE(empty_buffer.at(3) == TestType(1.0)); + } +} diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/mdbuffer.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/mdbuffer.cpp index 11dd8081..33c13421 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/mdbuffer.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/mdbuffer.cpp @@ -317,6 +317,44 @@ TEMPLATE_LIST_TEST_CASE("MDBuffer", "", types::floating_point_types) { } } + SECTION("multiplication_assignment_") { + // N.b., dispatching among hadamard, contraction, etc. is the visitor's + // responsibility and happens there. Here we just test hadamard. 
+ + SECTION("scalar") { + label_type labels(""); + MDBuffer result; + result.multiplication_assignment(labels, scalar(labels), + scalar(labels)); + REQUIRE(result.shape() == scalar_shape); + REQUIRE(result.get_elem({}) == TestType(1.0)); + } + + SECTION("vector") { + label_type labels("i"); + MDBuffer result; + result.multiplication_assignment(labels, vector(labels), + vector(labels)); + REQUIRE(result.shape() == vector_shape); + REQUIRE(result.get_elem({0}) == TestType(1.0)); + REQUIRE(result.get_elem({1}) == TestType(4.0)); + REQUIRE(result.get_elem({2}) == TestType(9.0)); + REQUIRE(result.get_elem({3}) == TestType(16.0)); + } + + SECTION("matrix") { + label_type labels("i,j"); + MDBuffer result; + result.multiplication_assignment(labels, matrix(labels), + matrix(labels)); + REQUIRE(result.shape() == matrix_shape); + REQUIRE(result.get_elem({0, 0}) == TestType(1.0)); + REQUIRE(result.get_elem({0, 1}) == TestType(4.0)); + REQUIRE(result.get_elem({1, 0}) == TestType(9.0)); + REQUIRE(result.get_elem({1, 1}) == TestType(16.0)); + } + } + SECTION("scalar_multiplication_") { // TODO: Test with other scalar types when public API supports it using scalar_type = double;