Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/tensorwrapper/buffer/buffer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ namespace tensorwrapper::buffer {
*
* All classes which wrap existing tensor libraries derive from this class.
*/
class BufferBase : public detail_::PolymorphicBase<BufferBase>,
public detail_::DSLBase<BufferBase> {
class BufferBase : public tensorwrapper::detail_::PolymorphicBase<BufferBase>,
public tensorwrapper::detail_::DSLBase<BufferBase> {
private:
/// Type of *this
using my_type = BufferBase;
Expand All @@ -39,7 +39,7 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase>,

protected:
/// Type *this inherits from
using polymorphic_base = detail_::PolymorphicBase<my_type>;
using polymorphic_base = tensorwrapper::detail_::PolymorphicBase<my_type>;

public:
/// Type all buffers inherit from
Expand Down
138 changes: 98 additions & 40 deletions include/tensorwrapper/buffer/contiguous.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class Contiguous : public Replicated {
/// Type of a read-only reference to an object of type element_type
using const_reference = const element_type&;

using elements_type = std::vector<element_type>;

/// Type of a pointer to a mutable element_type object
using pointer = element_type*;

Expand All @@ -65,77 +67,133 @@ class Contiguous : public Replicated {
/// Returns the number of elements in contiguous memory
size_type size() const noexcept { return size_(); }

/// Returns a mutable pointer to the first element in contiguous memory
pointer data() noexcept { return data_(); }
/** @brief Returns a mutable pointer to the first element in contiguous
* memory
*
* @warning Returning a mutable pointer to the underlying data makes it
* no longer possible for *this to reliably track changes to that
* data. Calling this method may have performance implications, so
* use only when strictly required.
*
* @return A read/write pointer to the data.
*
* @throw None No throw guarantee.
*/
pointer get_mutable_data() noexcept { return get_mutable_data_(); }

/// Returns a read-only pointer to the first element in contiguous memory
const_pointer data() const noexcept { return data_(); }
/** @brief Returns an immutable pointer to the first element in contiguous
* memory
*
* @return A read-only pointer to the data.
*
* @throw None No throw guarantee.
*/
const_pointer get_immutable_data() const noexcept {
return get_immutable_data_();
}

/** @brief Retrieves a tensor element by offset.
*
* @tparam Args The types of each offset. Must decay to integral types.
* This method is used to access the element in an immutable way.
*
* @param[in] args The offsets such that the i-th value in @p args is the
* offset of the element along the i-th mode of the tensor.
* @param[in] index The offset of the element being retrieved.
*
* @return A mutable reference to the element.
* @return A read-only reference to the element.
*
* @throw std::runtime_error if the number of indices does not match the
* rank of the tensor. Strong throw guarantee.
*/
template<typename... Args>
reference at(Args&&... args) {
static_assert(
std::conjunction_v<std::is_integral<std::decay_t<Args>>...>,
"Offsets must be integral types");
if(sizeof...(Args) != this->rank())
const_reference get_elem(index_vector index) const {
if(index.size() != this->rank())
throw std::runtime_error("Number of offsets must match rank");
return get_elem_(
index_vector{detail_::to_size_t(std::forward<Args>(args))...});
return get_elem_(index);
}

/** @brief Retrieves a tensor element by offset.
*
* @tparam Args The types of each offset. Must decay to integral types.
/** @brief Sets a tensor element by offset.
*
* This method is the same as the non-const version except that the result
* is read-only. See the documentation for the mutable version for more
* details.
* This method is used to change the value of an element.
*
* @param[in] args The offsets such that the i-th value in @p args is the
* offset of the element along the i-th mode of the tensor.
*
* @return A read-only reference to the element.
* @param[in] index The offset of the element being updated.
* @param[in] new_value The new value of the element.
*
* @throw std::runtime_error if the number of indices does not match the
* rank of the tensor. Strong throw guarantee.
*/
template<typename... Args>
const_reference at(Args&&... args) const {
static_assert(
std::conjunction_v<std::is_integral<std::decay_t<Args>>...>,
"Offsets must be integral types");
if(sizeof...(Args) != this->rank())
void set_elem(index_vector index, element_type new_value) {
if(index.size() != this->rank())
throw std::runtime_error("Number of offsets must match rank");
return get_elem_(
index_vector{detail_::to_size_t(std::forward<Args>(args))...});
return set_elem_(index, new_value);
}

/** @brief Retrieves a tensor element by ordinal offset.
*
* This method is used to access the element in an immutable way.
*
* @param[in] index The ordinal offset of the element being retrieved.
*
* @return A read-only reference to the element.
*
* @throw std::runtime_error if the index is greater than the number of
* elements. Strong throw guarantee.
*/
const_reference get_data(size_type index) const {
if(index >= this->size())
throw std::runtime_error("Index greater than number of elements");
return get_data_(std::move(index));
}

/** @brief Sets a tensor element by ordinal offset.
*
* This method is used to change the value of an element.
*
* @param[in] index The ordinal offset of the element being updated.
* @param[in] new_value The new value of the element.
*
* @throw std::runtime_error if the index is greater than the number of
* elements. Strong throw guarantee.
*/
void set_data(size_type index, element_type new_value) {
if(index >= this->size())
throw std::runtime_error("Index greater than number of elements");
set_data_(index, new_value);
}

/** @brief Sets all elements to a value.
*
* @param[in] value The new value of all elements.
*
* @throw None No throw guarantee.
*/
void fill(element_type value) { fill_(value); }

void copy(elements_type& values) { copy_(values); }

protected:
/// Derived class can override if it likes
virtual size_type size_() const noexcept { return layout().shape().size(); }

/// Derived class should implement according to data() description
virtual pointer data_() noexcept = 0;
virtual pointer get_mutable_data_() noexcept = 0;

/// Derived class should implement according to data() const description
virtual const_pointer data_() const noexcept = 0;
virtual const_pointer get_immutable_data_() const noexcept = 0;

/// Derived class should implement according to operator()()
virtual reference get_elem_(index_vector index) = 0;

/// Derived class should implement according to operator()()const
/// Derived class should implement according to get_elem()
virtual const_reference get_elem_(index_vector index) const = 0;

/// Derived class should implement according to set_elem()
virtual void set_elem_(index_vector index, element_type new_value) = 0;

/// Derived class should implement according to get_data()
virtual const_reference get_data_(size_type index) const = 0;

/// Derived class should implement according to set_data()
virtual void set_data_(size_type index, element_type new_value) = 0;

/// Derived class should implement according to fill()
virtual void fill_(element_type) = 0;

virtual void copy_(elements_type& values) = 0;
};

#define DECLARE_CONTIG_BUFFER(TYPE) extern template class Contiguous<TYPE>
Expand Down
23 changes: 18 additions & 5 deletions include/tensorwrapper/buffer/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class Eigen : public Contiguous<FloatType> {
using typename my_base_type::const_pointer;
using typename my_base_type::const_reference;
using typename my_base_type::dsl_reference;
using typename my_base_type::element_type;
using typename my_base_type::elements_type;
using typename my_base_type::index_vector;
using typename my_base_type::label_type;
using typename my_base_type::layout_pointer;
Expand Down Expand Up @@ -210,17 +212,28 @@ class Eigen : public Contiguous<FloatType> {
const_labeled_reference rhs) override;

/// Implements getting the raw pointer
pointer data_() noexcept override;
pointer get_mutable_data_() noexcept override;

/// Implements getting the raw pointer (read-only)
const_pointer data_() const noexcept override;

/// Implements mutable element access
reference get_elem_(index_vector index) override;
const_pointer get_immutable_data_() const noexcept override;

/// Implements read-only element access
const_reference get_elem_(index_vector index) const override;

// Implements element updating
void set_elem_(index_vector index, element_type new_value) override;

/// Implements read-only element access by ordinal index
const_reference get_data_(size_type index) const override;

// Implements element updating by ordinal index
void set_data_(size_type index, element_type new_value) override;

/// Implements filling the tensor
void fill_(element_type value) override;

void copy_(elements_type& values) override;

/// Implements to_string
typename polymorphic_base::string_type to_string_() const override;

Expand Down
8 changes: 4 additions & 4 deletions src/python/tensor/export_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ auto make_buffer_info(buffer::Contiguous<FloatType>& buffer) {
stride_i *= smooth_shape.extent(mode_i);
strides[rank_i] = stride_i * nbytes;
}
return pybind11::buffer_info(buffer.data(), nbytes, desc, rank, shape,
strides);
return pybind11::buffer_info(buffer.get_mutable_data(), nbytes, desc, rank,
shape, strides);
}

auto make_tensor(pybind11::buffer b) {
Expand All @@ -60,8 +60,8 @@ auto make_tensor(pybind11::buffer b) {

auto n_elements = std::accumulate(dims.begin(), dims.end(), 1,
std::multiplies<std::size_t>());
for(auto i = 0; i < n_elements; ++i)
pBuffer->data()[i] = static_cast<double*>(info.ptr)[i];
auto pData = static_cast<double*>(info.ptr);
for(auto i = 0; i < n_elements; ++i) pBuffer->set_data(i, pData[i]);

return Tensor(matrix_shape, std::move(pBuffer));
}
Expand Down
5 changes: 2 additions & 3 deletions src/tensorwrapper/allocator/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ typename EIGEN::contiguous_pointer EIGEN::construct_(layout_pointer playout,
element_type value) {
auto pbuffer = this->allocate(std::move(playout));
auto& contig_buffer = static_cast<buffer::Contiguous<FloatType>&>(*pbuffer);
auto* pdata = contig_buffer.data();
std::fill(pdata, pdata + contig_buffer.size(), value);
contig_buffer.fill(value);
return pbuffer;
}

Expand All @@ -115,7 +114,7 @@ typename EIGEN::contiguous_pointer EIGEN::il_construct_(ILType il) {
auto playout = std::make_unique<layout::Physical>(std::move(shape));
auto pbuffer = this->allocate(std::move(playout));
auto& buffer_down = rebind(*pbuffer);
std::copy(data.begin(), data.end(), buffer_down.data());
buffer_down.copy(data);
return pbuffer;
}

Expand Down
Loading
Loading