-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathReshape.h
62 lines (49 loc) · 2.11 KB
/
Reshape.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#ifndef LIBDL_RESHAPE_H
#define LIBDL_RESHAPE_H
#include <array>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <stdexcept>
#include "CNode.h"
#include "../Utils.h"
/*
 * \brief computational-graph node for reshaping a tensor
 *
 * Forward (`reshape`): produces a rank-RB tensor from the data of a rank-RA tensor.
 * Backward (`computeGradients`): the gradient of the rank-RB result is reshaped back
 * to the input's original dimensions and handed to the input's gradient function.
 *
 * \tparam D  element type
 * \tparam RA rank of the input tensor
 * \tparam RB rank of the reshaped (output) tensor
 * */
template <typename D, std::int64_t RA, std::int64_t RB>
class Reshape : public CNode<D, RB> {
public:
    /*
     * \brief creates the backward node that links `result` to the reshaped input `x`
     *
     * \param x the tensor that was reshaped; its dimensions are stored so the gradient can
     *        be reshaped back, and its gradFn (if any) is remembered for backpropagation
     * \param result the reshaped tensor produced by `reshape`
     * */
    Reshape(
            const std::shared_ptr<Tensor<D, RA>> &x,
            const std::shared_ptr<Tensor<D, RB>> &result)
            // NOTE(review): Utils::removeOption presumably drops x->gradFn when it is empty,
            // so the base only receives existing nodes — verify in Utils.h
            : CNode<D, RB>(Utils::removeOption<std::shared_ptr<CNodeBase>>({x->gradFn}), result),
              oldShape(x->data->dimensions()),
              cx(x->gradFn) {}

    /*
     * \brief reshapes the given tensor to the given shape
     *
     * \param x a tensor of any shape that should be reshaped
     * \param newShape the shape to which x should be reshaped;
     *        at most one element can be -1, its value is then inferred from the size of x
     *        and the remaining dimensions. All other elements must be non-negative.
     *
     * \throws std::invalid_argument if newShape contains more than one -1, a negative
     *         dimension other than -1, or if its total size differs from the size of x
     *
     * \return a new tensor with the new shape; a Reshape gradient function is attached
     *         when x needs a gradient and CNodeBase::noGrad is not set
     * */
    static std::shared_ptr<Tensor<D, RB>> reshape(
            const std::shared_ptr<Tensor<D, RA>> &x,
            std::array<std::int64_t, RB> newShape) {
        const std::int64_t size = x->data->size();
        newShape = resolveShape(newShape, size);

        // still needed when no -1 was given: the explicit dimensions must match the size of x
        const std::int64_t newSize = std::accumulate(std::begin(newShape), std::end(newShape),
                                                     std::int64_t{1}, std::multiplies<>());
        if (newSize != size)
            throw std::invalid_argument("x can't be reshaped to the given shape");

        auto result = std::make_shared<Tensor<D, RB>>(x->data->reshape(newShape), newShape);
        if (x->needsGradient() && !CNodeBase::noGrad)
            result->setGradFn(std::make_shared<Reshape<D, RA, RB>>(x, result));
        return result;
    }

    /*
     * \brief backpropagates through the reshape: the gradient of the result is reshaped to
     *        the input's original dimensions and added to the input's gradient function (if any)
     * */
    void computeGradients() override {
        if (cx.has_value())
            cx.value()->addGrad(CNode<D, RB>::grad->reshape(oldShape));
        CNode<D, RB>::finishComputeGradient();
    }

private:
    /*
     * \brief validates a requested shape and replaces a single -1 entry by the inferred dimension
     *
     * \param shape the requested shape, possibly containing one -1
     * \param size the number of elements of the tensor being reshaped
     *
     * \throws std::invalid_argument on more than one -1, on any other negative entry, or when the
     *         -1 entry can't be inferred (remaining dimensions multiply to 0 or don't divide size)
     *
     * \return the shape with the -1 entry (if any) replaced
     * */
    static std::array<std::int64_t, RB> resolveShape(std::array<std::int64_t, RB> shape, std::int64_t size) {
        std::int64_t knownProduct = 1;
        std::optional<std::size_t> inferIndex;
        for (std::size_t i = 0; i < shape.size(); ++i) {
            if (shape[i] == -1) {
                if (inferIndex.has_value())
                    throw std::invalid_argument("only one dimension of the new shape can be -1");
                inferIndex = i;
            } else if (shape[i] < 0) {
                throw std::invalid_argument("the new shape contains an invalid negative dimension");
            } else {
                knownProduct *= shape[i];
            }
        }
        if (inferIndex.has_value()) {
            // a zero product would make the division below undefined behaviour,
            // and a non-divisible size can't yield an integral dimension
            if (knownProduct == 0 || size % knownProduct != 0)
                throw std::invalid_argument("x can't be reshaped to the given shape");
            shape[*inferIndex] = size / knownProduct;
        }
        return shape;
    }

    // dimensions of the input tensor x, used to reshape the gradient back in computeGradients
    std::array<std::int64_t, RA> oldShape;
    // gradient function of the input x; empty when x->gradFn was empty
    std::optional<std::shared_ptr<CNode<D, RA>>> cx;
};
#endif //LIBDL_RESHAPE_H