Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tensorwrapper/buffer/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ class Eigen : public Replicated {
dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

dsl_reference scalar_multiplication_(label_type this_labels, double scalar,
const_labeled_reference rhs) override;

/// Implements to_string
typename polymorphic_base::string_type to_string_() const override;

Expand Down
18 changes: 11 additions & 7 deletions include/tensorwrapper/detail_/dsl_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,18 @@ class DSLBase {
* @tparam ScalarType The type of @p scalar. Assumed to be a floating-
* point type.
*
* This method is responsible for scaling @p *this by @p scalar.
* This method is responsible for scaling @p rhs by @p scalar and assigning
* it to *this.
*
* @note This method is templated on the scalar type to avoid limiting the
* API. That said, at present the backend converts @p scalar to
* double precision.
* double precision, but we could use a variant or something similar
* to avoid this
*/
template<typename ScalarType>
dsl_reference scalar_multiplication(ScalarType&& scalar) {
return scalar_multiplication_(std::forward<ScalarType>(scalar));
}
template<typename LabelType, typename ScalarType>
dsl_reference scalar_multiplication(LabelType&& this_labels,
ScalarType&& scalar,
const_labeled_reference rhs);

protected:
/// Derived class should overwrite to implement addition_assignment
Expand Down Expand Up @@ -249,7 +251,9 @@ class DSLBase {
}

/// Derived class should overwrite to implement scalar_multiplication
dsl_reference scalar_multiplication_(double scalar) {
virtual dsl_reference scalar_multiplication_(label_type this_labels,
double scalar,
const_labeled_reference rhs) {
throw std::runtime_error("Scalar multiplication NYI");
}

Expand Down
12 changes: 12 additions & 0 deletions include/tensorwrapper/detail_/dsl_base.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ typename DSL_BASE::dsl_reference DSL_BASE::permute_assignment(
return permute_assignment_(std::move(lhs_labels), rhs);
}

TPARAMS
template<typename LabelType, typename FloatType>
typename DSL_BASE::dsl_reference DSL_BASE::scalar_multiplication(
LabelType&& this_labels, FloatType&& scalar, const_labeled_reference rhs) {
assert_indices_match_rank_(rhs);

label_type lhs_labels(std::forward<LabelType>(this_labels));
assert_is_subset_(lhs_labels, rhs.labels());

return scalar_multiplication_(std::move(lhs_labels), scalar, rhs);
}

#undef DSL_BASE
#undef TPARAMS

Expand Down
37 changes: 24 additions & 13 deletions include/tensorwrapper/dsl/pairwise_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@ class PairwiseParser {
*/
template<typename LHSType, typename RHSType>
void dispatch(LHSType&& lhs, const RHSType& rhs) {
if constexpr(std::is_floating_point_v<std::decay_t<RHSType>>) {
lhs.object().scalar_multiplication(rhs);
} else {
lhs.object().permute_assignment(lhs.labels(), rhs);
}
lhs.object().permute_assignment(lhs.labels(), rhs);
}

/** @brief Handles adding two expressions together.
Expand Down Expand Up @@ -130,14 +126,29 @@ class PairwiseParser {
*/
template<typename LHSType, typename T, typename U>
void dispatch(LHSType&& lhs, const utilities::dsl::Multiply<T, U>& rhs) {
auto pA = lhs.object().clone();
auto pB = lhs.object().clone();
auto labels = lhs.labels();
auto lA = (*pA)(labels);
auto lB = (*pB)(labels);
dispatch(lA, rhs.lhs());
dispatch(lB, rhs.rhs());
lhs.object().multiplication_assignment(labels, lA, lB);
constexpr bool t_is_float = std::is_floating_point_v<T>;
constexpr bool u_is_float = std::is_floating_point_v<U>;
static_assert(!(t_is_float && u_is_float), "Both can be float??");
if constexpr(t_is_float) {
auto pA = lhs.object().clone();
auto lA = (*pA)(lhs.labels());
dispatch(lA, rhs.rhs());
lhs.object().scalar_multiplication(lhs.labels(), rhs.lhs(), lA);
} else if constexpr(u_is_float) {
auto pA = lhs.object().clone();
auto lA = (*pA)(lhs.labels());
dispatch(lA, rhs.lhs());
lhs.object().scalar_multiplication(lhs.labels(), rhs.rhs(), lA);
} else {
auto pA = lhs.object().clone();
auto pB = lhs.object().clone();
auto labels = lhs.labels();
auto lA = (*pA)(labels);
auto lB = (*pB)(labels);
dispatch(lA, rhs.lhs());
dispatch(lB, rhs.rhs());
lhs.object().multiplication_assignment(labels, lA, lB);
}
}
};

Expand Down
61 changes: 59 additions & 2 deletions include/tensorwrapper/tensor/tensor_class.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
*/

#pragma once
#include <tensorwrapper/dsl/labeled.hpp>
#include <tensorwrapper/detail_/dsl_base.hpp>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/tensor/detail_/tensor_input.hpp>

namespace tensorwrapper {
Expand All @@ -34,7 +35,8 @@ struct IsTuple<std::tuple<Args...>> : std::true_type {};
* The Tensor class is envisioned as being the most user-facing class of
* TensorWrapper and forms the entry point into TensorWrapper's DSL.
*/
class Tensor {
class Tensor : public detail_::DSLBase<Tensor>,
public detail_::PolymorphicBase<Tensor> {
private:
/// Type of a helper class which collects the inputs needed to make a tensor
using input_type = detail_::TensorInput;
Expand All @@ -53,6 +55,8 @@ class Tensor {
using enable_if_no_tensors_t =
std::enable_if_t<!are_any_tensors_v<Args...>>;

using polymorphic_base = detail_::PolymorphicBase<Tensor>;

public:
/// Type of the object implementing *this
using pimpl_type = detail_::TensorPIMPL;
Expand Down Expand Up @@ -81,6 +85,9 @@ class Tensor {
/// Type of a pointer to a read-only buffer
using const_buffer_pointer = input_type::const_buffer_pointer;

/// Type used to convey rank
using rank_type = typename logical_layout_type::size_type;

/// Type of an initializer list if *this is a scalar
using scalar_il_type = double;

Expand Down Expand Up @@ -299,6 +306,19 @@ class Tensor {
*/
const_buffer_reference buffer() const;

/** @brief Returns the logical rank of the tensor.
*
* Most users interacting with a tensor will be thinking of it in terms of
* its logical rank. This function is a convenience function for calling
* `rank()` on the logical layout.
*
* @return The rank of the tensor, logically.
*
* @throw std::runtime_error if *this does not have a logical layout.
* Strong throw guarantee.
*/
rank_type rank() const;

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -344,6 +364,43 @@ class Tensor {
*/
bool operator!=(const Tensor& rhs) const noexcept;

protected:
/// Implements clone by calling copy ctor
polymorphic_base::base_pointer clone_() const override {
return std::make_unique<Tensor>(*this);
}

/// Implements are_equal by calling are_equal_impl_
bool are_equal_(const_base_reference rhs) const noexcept override {
return polymorphic_base::are_equal_impl_<Tensor>(rhs);
}

/// Implements addition_assignment by calling addition_assignment on state
dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls subtraction_assignment on each member
dsl_reference subtraction_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls multiplication_assignment on each member
dsl_reference multiplication_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls scalar_multiplication on each member
dsl_reference scalar_multiplication_(label_type this_labels, double scalar,
const_labeled_reference rhs) override;

/// Calls permute_assignment on each member
dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

/// Implements to_string
typename polymorphic_base::string_type to_string_() const override;

private:
/// All ctors ultimately dispatch to this ctor
Tensor(pimpl_pointer pimpl) noexcept;
Expand Down
24 changes: 24 additions & 0 deletions src/tensorwrapper/buffer/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,30 @@ typename EIGEN::dsl_reference EIGEN::permute_assignment_(
return *this;
}

TPARAMS
typename EIGEN::dsl_reference EIGEN::scalar_multiplication_(
label_type this_labels, double scalar, const_labeled_reference rhs) {
BufferBase::permute_assignment_(this_labels, rhs);

using allocator_type = allocator::Eigen<FloatType, Rank>;
const auto& rhs_downcasted = allocator_type::rebind(rhs.object());

const auto& rlabels = rhs.labels();

FloatType c(scalar);

if(this_labels != rlabels) { // We need to permute rhs before assignment
auto r_to_l = rhs.labels().permutation(this_labels);
// Eigen wants int objects
std::vector<int> r_to_l2(r_to_l.begin(), r_to_l.end());
m_tensor_ = rhs_downcasted.value().shuffle(r_to_l2) * c;
} else {
m_tensor_ = rhs_downcasted.value() * c;
}

return *this;
}

TPARAMS
typename detail_::PolymorphicBase<BufferBase>::string_type EIGEN::to_string_()
const {
Expand Down
131 changes: 131 additions & 0 deletions src/tensorwrapper/tensor/tensor_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ const_buffer_reference Tensor::buffer() const {
return m_pimpl_->buffer();
}

Tensor::rank_type Tensor::rank() const { return logical_layout().rank(); }

// -- Utility

void Tensor::swap(Tensor& other) noexcept { m_pimpl_.swap(other.m_pimpl_); }
Expand All @@ -89,6 +91,135 @@ bool Tensor::operator!=(const Tensor& rhs) const noexcept {
return !(*this == rhs);
}

// -- Protected methods

Tensor::dsl_reference Tensor::addition_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) {
const auto& lobject = lhs.object();
const auto& llabels = lhs.labels();
const auto& robject = rhs.object();
const auto& rlabels = rhs.labels();

auto llayout = lobject.logical_layout();
auto rlayout = robject.logical_layout();
auto pthis_layout = llayout.clone_as<logical_layout_type>();

pthis_layout->addition_assignment(this_labels, llayout(llabels),
rlayout(rlabels));

auto pthis_buffer = lobject.buffer().clone();
auto lbuffer = lobject.buffer()(llabels);
auto rbuffer = robject.buffer()(rlabels);
pthis_buffer->addition_assignment(this_labels, lbuffer, rbuffer);

auto new_pimpl = std::make_unique<pimpl_type>(std::move(pthis_layout),
std::move(pthis_buffer));
new_pimpl.swap(m_pimpl_);

return *this;
}

Tensor::dsl_reference Tensor::subtraction_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) {
const auto& lobject = lhs.object();
const auto& llabels = lhs.labels();
const auto& robject = rhs.object();
const auto& rlabels = rhs.labels();

auto llayout = lobject.logical_layout();
auto rlayout = robject.logical_layout();
auto pthis_layout = llayout.clone_as<logical_layout_type>();

pthis_layout->subtraction_assignment(this_labels, llayout(llabels),
rlayout(rlabels));

auto pthis_buffer = lobject.buffer().clone();
auto lbuffer = lobject.buffer()(llabels);
auto rbuffer = robject.buffer()(rlabels);
pthis_buffer->subtraction_assignment(this_labels, lbuffer, rbuffer);

auto new_pimpl = std::make_unique<pimpl_type>(std::move(pthis_layout),
std::move(pthis_buffer));
new_pimpl.swap(m_pimpl_);

return *this;
}

Tensor::dsl_reference Tensor::multiplication_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) {
const auto& lobject = lhs.object();
const auto& llabels = lhs.labels();
const auto& robject = rhs.object();
const auto& rlabels = rhs.labels();

auto llayout = lobject.logical_layout();
auto rlayout = robject.logical_layout();
auto pthis_layout = llayout.clone_as<logical_layout_type>();

pthis_layout->multiplication_assignment(this_labels, llayout(llabels),
rlayout(rlabels));

auto pthis_buffer = lobject.buffer().clone();
auto lbuffer = lobject.buffer()(llabels);
auto rbuffer = robject.buffer()(rlabels);
pthis_buffer->multiplication_assignment(this_labels, lbuffer, rbuffer);

auto new_pimpl = std::make_unique<pimpl_type>(std::move(pthis_layout),
std::move(pthis_buffer));
new_pimpl.swap(m_pimpl_);

return *this;
}

Tensor::dsl_reference Tensor::scalar_multiplication_(
label_type this_labels, double scalar, const_labeled_reference rhs) {
const auto& robject = rhs.object();
const auto& rlabels = rhs.labels();

auto rlayout = robject.logical_layout();
auto pthis_layout = rlayout.clone_as<logical_layout_type>();

pthis_layout->permute_assignment(this_labels, rlayout(rlabels));

auto pthis_buffer = robject.buffer().clone();
auto rbuffer = robject.buffer()(rlabels);
pthis_buffer->scalar_multiplication(this_labels, scalar, rbuffer);

auto new_pimpl = std::make_unique<pimpl_type>(std::move(pthis_layout),
std::move(pthis_buffer));
new_pimpl.swap(m_pimpl_);

return *this;
}

Tensor::dsl_reference Tensor::permute_assignment_(label_type this_labels,
const_labeled_reference rhs) {
const auto& robject = rhs.object();
const auto& rlabels = rhs.labels();

auto rlayout = robject.logical_layout();
auto pthis_layout = rlayout.clone_as<logical_layout_type>();

pthis_layout->permute_assignment(this_labels, rlayout(rlabels));

auto pthis_buffer = robject.buffer().clone();
auto rbuffer = robject.buffer()(rlabels);
pthis_buffer->permute_assignment(this_labels, rbuffer);

auto new_pimpl = std::make_unique<pimpl_type>(std::move(pthis_layout),
std::move(pthis_buffer));
new_pimpl.swap(m_pimpl_);

return *this;
}

typename Tensor::polymorphic_base::string_type Tensor::to_string_() const {
return has_pimpl_() ? buffer().to_string() : "";
}

// -- Private methods

Tensor::Tensor(pimpl_pointer pimpl) noexcept : m_pimpl_(std::move(pimpl)) {}
Expand Down
Loading
Loading