forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathWeightNorm.cpp
More file actions
164 lines (139 loc) · 6.32 KB
/
WeightNorm.cpp
File metadata and controls
164 lines (139 loc) · 6.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/cpu/WeightNormKernel.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_weight_norm_differentiable_backward_native.h>
#include <ATen/ops/_weight_norm_interface.h>
#include <ATen/ops/_weight_norm_interface_backward_native.h>
#include <ATen/ops/_weight_norm_interface_native.h>
#include <ATen/ops/_weight_norm_native.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/norm_except_dim.h>
#include <ATen/ops/norm_except_dim_native.h>
#endif
#include <vector>
namespace at {
namespace native {
DEFINE_DISPATCH(weight_norm_stub);
DEFINE_DISPATCH(weight_norm_backward_stub);
// Staying faithful to the Python for now for clarity, look for optimizations later
// (e.g., single return statement for RVO)
Tensor norm_except_dim(const Tensor & v, int64_t pow, int64_t dim)
{
// I assume tensor.contiguous(), view(), norm(), etc. here will dispatch through VariableType.
if (dim == -1) {
return v.norm(pow);
} else if (dim == 0) {
std::vector<int64_t> output_size(v.dim(), 1);
output_size[0] = v.size(0);
return v.contiguous().view({v.size(0), -1}).norm(pow, 1).view(output_size);
} else if (dim == v.dim() - 1) {
std::vector<int64_t> output_size(v.dim(), 1);
output_size[v.dim() - 1] = v.size(v.dim() - 1);
return v.contiguous().view({-1, v.size(v.dim() - 1)}).norm(pow, 0).view(output_size);
} else {
// To consider: at::native::norm_except_dim is probably fine as well,
// and would avoid an additional dynamic dispatch.
return at::norm_except_dim(v.transpose(0, dim), pow, 0).transpose(0, dim); // optimize?
}
}
std::tuple<Tensor,Tensor> weight_norm_cpu(
const Tensor& v,
const Tensor& g,
int64_t dim) {
auto w = at::empty_like(v, at::MemoryFormat::Contiguous);
// align with cuda behavior, keep norm in 'Float' when g is 'BFloat16'
const auto dtype = g.scalar_type() == at::ScalarType::BFloat16 ?
at::ScalarType::Float : g.scalar_type();
auto norm = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(dtype));
weight_norm_stub(kCPU, w, norm, v, g, dim);
return std::tuple<Tensor, Tensor>{w, norm};
}
std::tuple<Tensor, Tensor> weight_norm_backward_cpu(
const Tensor& grad_w,
const Tensor& saved_v,
const Tensor& saved_g,
const Tensor& saved_norm,
int64_t dim) {
TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
TORCH_CHECK(saved_norm.is_contiguous(), "saved_norm must be contiguous");
auto grad_v = at::empty_like(saved_v, at::MemoryFormat::Contiguous);
auto grad_g = at::empty_like(saved_g, at::MemoryFormat::Contiguous);
weight_norm_backward_stub(kCPU, grad_v, grad_g, grad_w, saved_v, saved_g, saved_norm, dim);
return std::tuple<Tensor, Tensor>{grad_v, grad_g};
}
Tensor _weight_norm
(const Tensor & v_in,
const Tensor & g_in,
int64_t dim)
{
TORCH_CHECK(
v_in.device() == g_in.device(),
"weight_norm: expected v_in and g_in to be on the same device, but v_in is "
"on ", v_in.device(), " and g_in is on ", g_in.device());
auto v = v_in.contiguous();
auto g = g_in.contiguous();
auto has_half_dtype = v.scalar_type() == at::ScalarType::Half
|| g.scalar_type() == at::ScalarType::Half;
bool can_use_fused = !has_half_dtype && ((dim == 0) || (dim == v.dim() - 1));
if (can_use_fused) {
// weight_norm does not have a derivative defined for it, so this will route back through
// VariableType.cpp, and construct a WeightNormFusedBackward object in the autograd graph.
return std::get<0>(at::_weight_norm_interface(v, g, dim));
} else {
// Double-differentiable primitive ops
// at::native::norm_except_dim would probably be fine as well.
return v*(g/at::norm_except_dim(v, 2, dim));
}
}
// Differentiable backward path, an alternative to weight_norm_backward, to be used
// when backward is itself creating a graph.
// The GradMode::is_enabled() check must be performed within Functions.cpp; that's why we
// define a separate function here, instead of inlining it in weight_norm_cuda_backward.
std::tuple<Tensor, Tensor> _weight_norm_differentiable_backward
(const Tensor & grad_w,
const Tensor & saved_v,
const Tensor & saved_g,
const Tensor & saved_norms,
int64_t dim)
{
// In Functions.cpp, the HardshrinkBackward object supplies "grad.contiguous()"
// as the first argument, so grad_w should be contiguous here.
// All these checks should succeed:
TORCH_CHECK(grad_w.is_contiguous(), "grad_w must be contiguous");
TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");
int64_t last_dim = saved_v.dim() - 1;
int64_t last_size = saved_v.size(last_dim);
// Like weight_norm_fused_backward, weight_norm_differentiable_backward should only ever be called
// through a WeightNormFusedBackward object, so we expect that dim == 0 || dim == saved_v.size(-1)
TORCH_CHECK(dim == 0 || dim == last_dim, "Expected dim to be the first or last dimension");
// saved_g and saved_norms are already shaped to broadcast over the correct dimensions
// ...but saved_norms might be Float when saved_g and saved_v are half.
// To consider: saved_norms.to(..., True /*non_blocking*/);
auto norms = saved_norms.to(saved_g.scalar_type());
std::vector<int64_t> bcast_size(saved_v.dim(), 1);
// Analytic backward path using differentiable primitive ops
if (dim == 0) {
bcast_size[0] = saved_v.size(0);
auto per_dim_sums = (grad_w*saved_v).view({saved_v.size(0), -1}).sum(1).view(bcast_size);
auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
auto grad_g = per_dim_sums/norms;
return std::tuple<Tensor, Tensor>{grad_v, grad_g};
} else { // dim == last_dim
bcast_size[last_dim] = last_size;
auto per_dim_sums = (grad_w*saved_v).view({-1, last_size}).sum(0).view(bcast_size);
auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
auto grad_g = per_dim_sums/norms;
return std::tuple<Tensor, Tensor>{grad_v, grad_g};
}
}
} // namespace native
} // namespace at