Skip to content

Commit

Permalink
#12
Browse files Browse the repository at this point in the history
  • Loading branch information
lgarithm committed Nov 19, 2018
1 parent ad7753f commit 89d629c
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 63 deletions.
14 changes: 10 additions & 4 deletions include/nn/bits/layers/conv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ class conv_layer_trait : public ops::conv_trait<ops::hw>
{
}

/*!
  Build a layer trait from an already configured convolution trait.

  \param ksize      kernel size, stored in ksize_.
  \param n_filters  number of filters, stored in n_filters_.
  \param trait      convolution trait copied into the conv_trait base.
*/
conv_layer_trait(const ksize_t &ksize, size_t n_filters,
                 const conv_trait &trait)
    : conv_trait(trait),
      ksize_(ksize),
      n_filters_(n_filters)
{
}

template <typename image_order, typename filter_order>
shape<4> filter_shape(const shape<4> &x) const
{
Expand Down Expand Up @@ -63,8 +69,8 @@ class conv<image_order, filter_order, false, Act> : public conv_layer_trait
{
auto w = ops::new_parameter<ttl::tensor<R, 4>>(
filter_shape<image_order, filter_order>(x.shape()), w_init);
auto y = ops::new_result<ttl::tensor<R, 4>>(
conv_op(padding_, stride_, rate_), x, *w);
auto y = ops::new_result<ttl::tensor<R, 4>>(conv_op(h_trait_, w_trait_),
x, *w);

Act()(ref(*y), view(*y));
return make_layer(y, w);
Expand All @@ -86,8 +92,8 @@ class conv<image_order, filter_order, true, Act> : public conv_layer_trait
{
auto w = ops::new_parameter<ttl::tensor<R, 4>>(
filter_shape<image_order, filter_order>(x.shape()), w_init);
auto y = ops::new_result<ttl::tensor<R, 4>>(
conv_op(padding_, stride_, rate_), x, *w);
auto y = ops::new_result<ttl::tensor<R, 4>>(conv_op(h_trait_, w_trait_),
x, *w);

using add_bias = nn::ops::apply_bias<image_order, std::plus<R>>;
auto b = ops::new_parameter<ttl::tensor<R, 1>>(bias_shape(x.shape()),
Expand Down
99 changes: 72 additions & 27 deletions include/nn/bits/ops/conv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ template <typename dim_t> class linear_conv_trait
const dim_t stride_;

public:
using padding_t = typename linear_sample_trait<dim_t>::padding_t;

/*! Symmetric padding: both components of the result are p. */
static padding_t padding(int p)
{
    // Same value on both sides; reuse the two-argument overload.
    return padding(p, p);
}

/*! Padding with independent left and right amounts. */
static padding_t padding(int left, int right)
{
    return padding_t(left, right);
}

/*! Default construction: delegates to the padding-only constructor with
    default_pad_lr. */
linear_conv_trait()
    : linear_conv_trait(default_pad_lr)
{
}

linear_conv_trait(dim_t pad_lr) : linear_conv_trait(pad_lr, default_stride)
Expand All @@ -34,10 +43,24 @@ template <typename dim_t> class linear_conv_trait
}

/*!
  Construct with the same padding amount on both sides.

  Delegates to the padding_t overload. The padding_t is built directly from
  pad_lr (rather than through the int-taking padding() helper) so the value
  is not narrowed from dim_t to int on the way.

  \param pad_lr  padding used for both pad_l_ and pad_r_.
  \param stride  value stored in stride_.
  \param rate    value stored in rate_.
*/
linear_conv_trait(dim_t pad_lr, dim_t stride, dim_t rate)
    : linear_conv_trait(padding_t(pad_lr, pad_lr), stride, rate)
{
}

/*!
  Construct from an explicit padding pair.

  \param pad     component 0 of pad.dims initialises pad_l_, component 1
                 initialises pad_r_.
  \param stride  value stored in stride_.
  \param rate    value stored in rate_.
*/
linear_conv_trait(const padding_t &pad, dim_t stride, dim_t rate)
    : pad_l_(std::get<0>(pad.dims)),
      pad_r_(std::get<1>(pad.dims)),
      rate_(rate),
      stride_(stride)
{
}

/*! Returns the stored pad_l_ / pad_r_ pair repackaged as a padding_t. */
padding_t get_padding() const
{
    const padding_t left_right(pad_l_, pad_r_);
    return left_right;
}

/*! Returns the stored stride_. */
dim_t get_stride() const
{
    return stride_;
}

/*! Returns the stored rate_. */
dim_t get_rate() const
{
    return rate_;
}

dim_t operator()(dim_t n, dim_t k) const
{
return linear_sample_trait(k, stride_, rate_, pad_l_, pad_r_)(n);
Expand All @@ -48,39 +71,59 @@ template <typename image_order> class conv_trait;

template <> class conv_trait<hw>
{
using conv_trait_1d_t = linear_conv_trait<size_t>;
using dim_t = size_t;
using conv_trait_1d_t = linear_conv_trait<dim_t>;
using sample_t = linear_sample_trait<dim_t>;
using padding_1d_t = sample_t::padding_t;

protected:
struct padding_trait;
struct stride_trait;
struct rate_trait;

using padding_t = std::experimental::new_type<shape<2>, padding_trait>;
using padding_t = std::array<padding_1d_t, 2>;

using stride_t = std::experimental::new_type<shape<2>, stride_trait>;
using rate_t = std::experimental::new_type<shape<2>, rate_trait>;

static constexpr auto default_padding = padding_t(0, 0);
static constexpr auto default_stride = stride_t(1, 1);
static constexpr auto default_rate = rate_t(1, 1);

const padding_t padding_;
const stride_t stride_;
const rate_t rate_;

const conv_trait_1d_t h_trait_;
const conv_trait_1d_t w_trait_;

/*! Padding used when a constructor is not given one: padding(0, 0). */
static padding_t default_padding()
{
    return padding(0, 0);
}

public:
static padding_t padding(int r, int s) { return padding_t(r, s); };
/*! 1-D padding with the same amount p for both components. */
static padding_1d_t padding_1d(dim_t p)
{
    // Symmetric case expressed through the (left, right) overload.
    return padding_1d(p, p);
}

/*! 1-D padding with independent left and right amounts. */
static padding_1d_t padding_1d(dim_t left, dim_t right)
{
    return padding_1d_t(left, right);
}

/*!
  2-D padding from one amount per axis; each amount is expanded to a
  symmetric 1-D padding via padding_1d.
*/
static padding_t padding(dim_t r, dim_t s)
{
    const padding_1d_t first  = padding_1d(r);
    const padding_1d_t second = padding_1d(s);
    return padding(first, second);
}

/*! 2-D padding from two 1-D paddings, stored in the order (r, s). */
static padding_t padding(const padding_1d_t &r, const padding_1d_t &s)
{
    return padding_t{r, s};
}

/*! Wraps the pair (r, s) as a stride_t. */
static stride_t stride(int r, int s)
{
    return stride_t(r, s);
}

conv_trait() : conv_trait(default_padding) {}
/*! Wraps the pair (r, s) as a rate_t. */
static rate_t rate(int r, int s)
{
    return rate_t(r, s);
}

/*! Default construction: delegates to the padding-only constructor with
    default_padding(). */
conv_trait()
    : conv_trait(default_padding())
{
}

/*! Padding-only construction; the stride falls back to default_stride. */
conv_trait(const padding_t &padding)
    : conv_trait(padding, default_stride)
{
}

conv_trait(const stride_t &stride) : conv_trait(default_padding, stride) {}
/*! Stride-only construction; the padding falls back to default_padding(). */
conv_trait(const stride_t &stride)
    : conv_trait(default_padding(), stride)
{
}

conv_trait(const padding_t &padding, const stride_t &stride)
: conv_trait(padding, stride, default_rate)
Expand All @@ -89,11 +132,13 @@ template <> class conv_trait<hw>

conv_trait(const padding_t &padding, const stride_t &stride,
const rate_t &rate)
: padding_(padding),
stride_(stride),
rate_(rate),
h_trait_(padding.dims[0], stride.dims[0], rate.dims[0]),
w_trait_(padding.dims[1], stride.dims[1], rate.dims[1])
: h_trait_(padding[0], stride.dims[0], rate.dims[0]),
w_trait_(padding[1], stride.dims[1], rate.dims[1])
{
}

/*!
  Construct directly from two 1-D convolution traits.

  \param h_trait  copied into h_trait_.
  \param w_trait  copied into w_trait_.
*/
conv_trait(const conv_trait_1d_t &h_trait, const conv_trait_1d_t &w_trait)
    : h_trait_(h_trait),
      w_trait_(w_trait)
{
}

Expand Down Expand Up @@ -138,11 +183,11 @@ template <> class conv<nhwc, rscd> : public conv_trait<hw>
const auto [r, s] = filter_shape<rscd>(y.shape()).dims;

using upper_op = im2col<hwc, hwrsc>;
const auto upper = internal::make_batched(
upper_op(upper_op::ksize(r, s),
upper_op::padding(padding_.dims[0], padding_.dims[1]),
upper_op::stride(stride_.dims[0], stride_.dims[1]),
upper_op::rate(rate_.dims[0], rate_.dims[1])));
const auto upper = internal::make_batched(upper_op(
upper_op::ksize(r, s),
upper_op::padding(h_trait_.get_padding(), w_trait_.get_padding()),
upper_op::stride(h_trait_.get_stride(), w_trait_.get_stride()),
upper_op::rate(h_trait_.get_rate(), w_trait_.get_rate())));

ttl::tensor<R, 6> x_upper(upper(x.shape()));
upper(ref(x_upper), view(x));
Expand Down Expand Up @@ -171,11 +216,11 @@ template <> class conv<nchw, dcrs> : public conv_trait<hw>
{
using upper_op = im2col<hw, rshw>;
const auto [r, s] = filter_shape<dcrs>(y.shape()).dims;
const auto upper = internal::make_batched(
upper_op(upper_op::ksize(r, s),
upper_op::padding(padding_.dims[0], padding_.dims[1]),
upper_op::stride(stride_.dims[0], stride_.dims[1]),
upper_op::rate(rate_.dims[0], rate_.dims[1])));
const auto upper = internal::make_batched(upper_op(
upper_op::ksize(r, s),
upper_op::padding(h_trait_.get_padding(), w_trait_.get_padding()),
upper_op::stride(h_trait_.get_stride(), w_trait_.get_stride()),
upper_op::rate(h_trait_.get_rate(), w_trait_.get_rate())));
ttl::tensor<R, 5> x_upper(upper(x.shape().template subshape<1>()));
const auto n = batch_size<nchw>(z.shape());
for (auto l : range(n)) {
Expand Down
53 changes: 35 additions & 18 deletions include/nn/bits/ops/im2col.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
#pragma once
#include <experimental/contract>
#include <experimental/new_type>
#include <nn/bits/ops/linear_sample.hpp>
#include <nn/bits/ops/reshape.hpp>
#include <nn/bits/ops/traits.hpp>
#include <stdtensor>
#include <nn/common.hpp>

namespace nn::ops
{
Expand All @@ -13,22 +11,23 @@ template <typename image_order> class im2col_trait;
template <> class im2col_trait<hw>
{
protected:
using dim_t = size_t;
using sample_t = linear_sample_trait<dim_t>;
using padding_1d_t = sample_t::padding_t;

struct ksize_trait;
struct padding_trait;
struct stride_trait;
struct rate_trait;

using ksize_t = std::experimental::new_type<shape<2>, ksize_trait>;
using padding_t = std::experimental::new_type<shape<2>, padding_trait>;
using stride_t = std::experimental::new_type<shape<2>, stride_trait>;
using rate_t = std::experimental::new_type<shape<2>, rate_trait>;

static constexpr auto default_padding = padding_t(0, 0);
using padding_t = std::array<padding_1d_t, 2>;

static constexpr auto default_stride = stride_t(1, 1);
static constexpr auto default_rate = rate_t(1, 1);

using sample_t = linear_sample_trait<size_t>;

const sample_t h_sample_;
const sample_t w_sample_;

Expand All @@ -37,14 +36,34 @@ template <> class im2col_trait<hw>
return ksize_t(h_sample_.ksize_, w_sample_.ksize_);
}

/*! Padding used when a constructor is not given one: padding(0, 0). */
static padding_t default_padding()
{
    return padding(0, 0);
}

public:
static ksize_t ksize(int r, int s) { return ksize_t(r, s); };
static padding_t padding(int r, int s) { return padding_t(r, s); };
static stride_t stride(int r, int s) { return stride_t(r, s); };
static rate_t rate(int r, int s) { return rate_t(r, s); };
/*! Wraps the pair (r, s) as a ksize_t. */
static ksize_t ksize(dim_t r, dim_t s)
{
    return ksize_t(r, s);
}

/*! 1-D padding with the same amount p for both components. */
static padding_1d_t padding_1d(dim_t p)
{
    // Symmetric case expressed through the (left, right) overload.
    return padding_1d(p, p);
}

/*! 1-D padding with independent left and right amounts. */
static padding_1d_t padding_1d(dim_t left, dim_t right)
{
    return padding_1d_t(left, right);
}

/*!
  2-D padding from one amount per axis; each amount is expanded to a
  symmetric 1-D padding via padding_1d.
*/
static padding_t padding(dim_t r, dim_t s)
{
    const padding_1d_t first  = padding_1d(r);
    const padding_1d_t second = padding_1d(s);
    return padding(first, second);
}

/*! 2-D padding from two 1-D paddings, stored in the order (r, s). */
static padding_t padding(const padding_1d_t &r, const padding_1d_t &s)
{
    return padding_t{r, s};
}

/*! Wraps the pair (r, s) as a stride_t. */
static stride_t stride(dim_t r, dim_t s)
{
    return stride_t(r, s);
}

/*! Wraps the pair (r, s) as a rate_t. */
static rate_t rate(dim_t r, dim_t s)
{
    return rate_t(r, s);
}

im2col_trait(const ksize_t &ksize)
: im2col_trait(ksize, default_padding, default_stride)
: im2col_trait(ksize, default_padding(), default_stride)
{
}

Expand All @@ -54,7 +73,7 @@ template <> class im2col_trait<hw>
}

im2col_trait(const ksize_t &ksize, const stride_t &stride)
: im2col_trait(ksize, default_padding, stride)
: im2col_trait(ksize, default_padding(), stride)
{
}

Expand All @@ -66,10 +85,8 @@ template <> class im2col_trait<hw>

im2col_trait(const ksize_t &ksize, const padding_t &padding,
const stride_t &stride, const rate_t &rate)
: h_sample_(ksize.dims[0], stride.dims[0], rate.dims[0],
padding.dims[0]),
w_sample_(ksize.dims[1], stride.dims[1], rate.dims[1],
padding.dims[1])
: h_sample_(ksize.dims[0], stride.dims[0], rate.dims[0], padding[0]),
w_sample_(ksize.dims[1], stride.dims[1], rate.dims[1], padding[1])
{
}

Expand Down
32 changes: 25 additions & 7 deletions include/nn/bits/ops/linear_sample.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once
#include <experimental/contract>
#include <nn/common.hpp>

/*!
\begin{definition}
Expand Down Expand Up @@ -47,8 +47,19 @@ template <typename dim_t> class linear_sample_trait
const dim_t pad_l_; // TODO: make it template parameter
const dim_t pad_r_; // TODO: make it template parameter

struct padding_trait;

public:
// FIXME:
using padding_t = std::experimental::new_type<shape<2>, padding_trait>;

/*! Symmetric padding: both components of the result are p. */
static padding_t padding(dim_t p)
{
    // Same value on both sides; reuse the two-argument overload.
    return padding(p, p);
}

/*! Padding with independent left and right amounts. */
static padding_t padding(dim_t left, dim_t right)
{
    return padding_t(left, right);
}

// FIXME: make them private
const dim_t rate_;
const dim_t stride_;
const dim_t ksize_;
Expand All @@ -70,23 +81,30 @@ template <typename dim_t> class linear_sample_trait
}

linear_sample_trait(dim_t ksize, dim_t stride, dim_t rate, dim_t pad_lr)
: linear_sample_trait(ksize, stride, rate, pad_lr, pad_lr)
: linear_sample_trait(ksize, stride, rate, padding(pad_lr))
{
}

linear_sample_trait(dim_t ksize, dim_t stride, dim_t rate, dim_t pad_l,
dim_t pad_r)
: pad_l_(pad_l),
pad_r_(pad_r),
: linear_sample_trait(ksize, stride, rate, padding(pad_l, pad_r))
{
// TODO: deprecate it
}

linear_sample_trait(dim_t ksize, dim_t stride, dim_t rate,
const padding_t &pad)
: pad_l_(std::get<0>(pad.dims)),
pad_r_(std::get<1>(pad.dims)),
rate_(rate),
stride_(stride),
ksize_(ksize)
{
contract_assert(rate_ >= 1);
contract_assert(stride_ >= 1);
contract_assert(ksize_ >= 1);
contract_assert(pad_l >= 0);
contract_assert(pad_r >= 0);
contract_assert(pad_l_ >= 0);
contract_assert(pad_r_ >= 0);
}

/*! Compute the output size from input size. */
Expand Down
Loading

0 comments on commit 89d629c

Please sign in to comment.