Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/tensorwrapper/buffer/buffer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ namespace tensorwrapper::buffer {
*
* All classes which wrap existing tensor libraries derive from this class.
*/
class BufferBase : public detail_::PolymorphicBase<BufferBase>,
public detail_::DSLBase<BufferBase> {
class BufferBase : public tensorwrapper::detail_::PolymorphicBase<BufferBase>,
public tensorwrapper::detail_::DSLBase<BufferBase> {
private:
/// Type of *this
using my_type = BufferBase;
Expand All @@ -39,7 +39,7 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase>,

protected:
/// Type *this inherits from
using polymorphic_base = detail_::PolymorphicBase<my_type>;
using polymorphic_base = tensorwrapper::detail_::PolymorphicBase<my_type>;

public:
/// Type all buffers inherit from
Expand Down
144 changes: 104 additions & 40 deletions include/tensorwrapper/buffer/contiguous.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class Contiguous : public Replicated {
/// Type of a read-only reference to an object of type element_type
using const_reference = const element_type&;

using element_vector = std::vector<element_type>;

/// Type of a pointer to a mutable element_type object
using pointer = element_type*;

Expand All @@ -65,77 +67,139 @@ class Contiguous : public Replicated {
/// Returns the number of elements in contiguous memory
size_type size() const noexcept { return size_(); }

/// Returns a mutable pointer to the first element in contiguous memory
pointer data() noexcept { return data_(); }
/** @brief Returns a mutable pointer to the first element in contiguous
* memory
*
* @warning Returning a mutable pointer to the underlying data makes it
* no longer possible for *this to reliably track changes to that
* data. Calling this method may have performance implications, so
* use only when strictly required.
*
* @return A read/write pointer to the data.
*
* @throw None No throw guarantee.
*/
pointer get_mutable_data() noexcept { return get_mutable_data_(); }

/// Returns a read-only pointer to the first element in contiguous memory
const_pointer data() const noexcept { return data_(); }
/** @brief Returns an immutable pointer to the first element in contiguous
* memory
*
* @return A read-only pointer to the data.
*
* @throw None No throw guarantee.
*/
const_pointer get_immutable_data() const noexcept {
return get_immutable_data_();
}

/** @brief Retrieves a tensor element by offset.
*
* @tparam Args The types of each offset. Must decay to integral types.
* This method is used to access the element in an immutable way.
*
* @param[in] args The offsets such that the i-th value in @p args is the
* offset of the element along the i-th mode of the tensor.
* @param[in] index The offset of the element being retrieved.
*
* @return A mutable reference to the element.
* @return A read-only reference to the element.
*
* @throw std::runtime_error if the number of indices does not match the
* rank of the tensor. Strong throw guarantee.
*/
template<typename... Args>
reference at(Args&&... args) {
static_assert(
std::conjunction_v<std::is_integral<std::decay_t<Args>>...>,
"Offsets must be integral types");
if(sizeof...(Args) != this->rank())
const_reference get_elem(index_vector index) const {
if(index.size() != this->rank())
throw std::runtime_error("Number of offsets must match rank");
return get_elem_(
index_vector{detail_::to_size_t(std::forward<Args>(args))...});
return get_elem_(index);
}

/** @brief Retrieves a tensor element by offset.
/** @brief Sets a tensor element by offset.
*
* @tparam Args The types of each offset. Must decay to integral types.
* This method is used to change the value of an element.
*
* This method is the same as the non-const version except that the result
* is read-only. See the documentation for the mutable version for more
* details.
*
* @param[in] args The offsets such that the i-th value in @p args is the
* offset of the element along the i-th mode of the tensor.
*
* @return A read-only reference to the element.
* @param[in] index The offset of the element being updated.
* @param[in] new_value The new value of the element.
*
* @throw std::runtime_error if the number of indices does not match the
* rank of the tensor. Strong throw guarantee.
*/
template<typename... Args>
const_reference at(Args&&... args) const {
static_assert(
std::conjunction_v<std::is_integral<std::decay_t<Args>>...>,
"Offsets must be integral types");
if(sizeof...(Args) != this->rank())
void set_elem(index_vector index, element_type new_value) {
if(index.size() != this->rank())
throw std::runtime_error("Number of offsets must match rank");
return get_elem_(
index_vector{detail_::to_size_t(std::forward<Args>(args))...});
return set_elem_(index, new_value);
}

/** @brief Retrieves a tensor element by ordinal offset.
*
* This method is used to access the element in an immutable way.
*
* @param[in] index The ordinal offset of the element being retrieved.
*
* @return A read-only reference to the element.
*
* @throw std::runtime_error if the index is greater than the number of
* elements. Strong throw guarantee.
*/
const_reference get_data(size_type index) const {
if(index >= this->size())
throw std::runtime_error("Index greater than number of elements");
return get_data_(std::move(index));
}

/** @brief Sets a tensor element by ordinal offset.
*
* This method is used to change the value of an element.
*
* @param[in] index The ordinal offset of the element being updated.
* @param[in] new_value The new value of the element.
*
* @throw std::runtime_error if the index is greater than the number of
* elements. Strong throw guarantee.
*/
void set_data(size_type index, element_type new_value) {
if(index >= this->size())
throw std::runtime_error("Index greater than number of elements");
set_data_(index, new_value);
}

/** @brief Sets all elements to a value.
*
* @param[in] value The new value of all elements.
*
* @throw None No throw guarantee.
*/
void fill(element_type value) { fill_(std::move(value)); }

/** @brief Sets elements using a list of values.
*
* @param[in] values The new values of all elements.
*
* @throw None No throw guarantee.
*/
void copy(const element_vector& values) { copy_(values); }

protected:
/// Derived class can override if it likes
virtual size_type size_() const noexcept { return layout().shape().size(); }

/// Derived class should implement according to data() description
virtual pointer data_() noexcept = 0;
virtual pointer get_mutable_data_() noexcept = 0;

/// Derived class should implement according to data() const description
virtual const_pointer data_() const noexcept = 0;

/// Derived class should implement according to operator()()
virtual reference get_elem_(index_vector index) = 0;
virtual const_pointer get_immutable_data_() const noexcept = 0;

/// Derived class should implement according to operator()()const
/// Derived class should implement according to get_elem()
virtual const_reference get_elem_(index_vector index) const = 0;

/// Derived class should implement according to set_elem()
virtual void set_elem_(index_vector index, element_type new_value) = 0;

/// Derived class should implement according to get_data()
virtual const_reference get_data_(size_type index) const = 0;

/// Derived class should implement according to set_data()
virtual void set_data_(size_type index, element_type new_value) = 0;

/// Derived class should implement according to fill()
virtual void fill_(element_type) = 0;

virtual void copy_(const element_vector& values) = 0;
};

#define DECLARE_CONTIG_BUFFER(TYPE) extern template class Contiguous<TYPE>
Expand Down
24 changes: 19 additions & 5 deletions include/tensorwrapper/buffer/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class Eigen : public Contiguous<FloatType> {
using typename my_base_type::const_pointer;
using typename my_base_type::const_reference;
using typename my_base_type::dsl_reference;
using typename my_base_type::element_type;
using typename my_base_type::element_vector;
using typename my_base_type::index_vector;
using typename my_base_type::label_type;
using typename my_base_type::layout_pointer;
Expand Down Expand Up @@ -210,17 +212,29 @@ class Eigen : public Contiguous<FloatType> {
const_labeled_reference rhs) override;

/// Implements getting the raw pointer
pointer data_() noexcept override;
pointer get_mutable_data_() noexcept override;

/// Implements getting the raw pointer (read-only)
const_pointer data_() const noexcept override;

/// Implements mutable element access
reference get_elem_(index_vector index) override;
const_pointer get_immutable_data_() const noexcept override;

/// Implements read-only element access
const_reference get_elem_(index_vector index) const override;

// Implements element updating
void set_elem_(index_vector index, element_type new_value) override;

/// Implements read-only element access by ordinal index
const_reference get_data_(size_type index) const override;

// Implements element updating by ordinal index
void set_data_(size_type index, element_type new_value) override;

/// Implements filling the tensor
void fill_(element_type value) override;

/// Implements copying new values into the tensor
void copy_(const element_vector& values) override;

/// Implements to_string
typename polymorphic_base::string_type to_string_() const override;

Expand Down
8 changes: 4 additions & 4 deletions src/python/tensor/export_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ auto make_buffer_info(buffer::Contiguous<FloatType>& buffer) {
stride_i *= smooth_shape.extent(mode_i);
strides[rank_i] = stride_i * nbytes;
}
return pybind11::buffer_info(buffer.data(), nbytes, desc, rank, shape,
strides);
return pybind11::buffer_info(buffer.get_mutable_data(), nbytes, desc, rank,
shape, strides);
}

auto make_tensor(pybind11::buffer b) {
Expand All @@ -60,8 +60,8 @@ auto make_tensor(pybind11::buffer b) {

auto n_elements = std::accumulate(dims.begin(), dims.end(), 1,
std::multiplies<std::size_t>());
for(auto i = 0; i < n_elements; ++i)
pBuffer->data()[i] = static_cast<double*>(info.ptr)[i];
auto pData = static_cast<double*>(info.ptr);
for(auto i = 0; i < n_elements; ++i) pBuffer->set_data(i, pData[i]);

return Tensor(matrix_shape, std::move(pBuffer));
}
Expand Down
5 changes: 2 additions & 3 deletions src/tensorwrapper/allocator/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ typename EIGEN::contiguous_pointer EIGEN::construct_(layout_pointer playout,
element_type value) {
auto pbuffer = this->allocate(std::move(playout));
auto& contig_buffer = static_cast<buffer::Contiguous<FloatType>&>(*pbuffer);
auto* pdata = contig_buffer.data();
std::fill(pdata, pdata + contig_buffer.size(), value);
contig_buffer.fill(value);
return pbuffer;
}

Expand All @@ -115,7 +114,7 @@ typename EIGEN::contiguous_pointer EIGEN::il_construct_(ILType il) {
auto playout = std::make_unique<layout::Physical>(std::move(shape));
auto pbuffer = this->allocate(std::move(playout));
auto& buffer_down = rebind(*pbuffer);
std::copy(data.begin(), data.end(), buffer_down.data());
buffer_down.copy(data);
return pbuffer;
}

Expand Down
Loading