// layer_norm.h (forked from pytorch/pytorch)
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/accumulate.h>
#include <sstream>

namespace at::native {

namespace {
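
// Validates layer_norm inputs: weight and bias (if defined) must have exactly
// the shape given by normalized_shape, and the trailing dimensions of input
// must equal normalized_shape. Returns the flattened sizes {M, N}, where M is
// the product of the leading (non-normalized) dimensions and N is the product
// of the normalized dimensions. For example, an input of shape [2, 3, 4] with
// normalized_shape = [3, 4] yields M = 2 and N = 12.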
C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */) {
  const int normalized_ndim = normalized_shape.size();
  TORCH_CHECK(
      normalized_ndim >= 1,
      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
      "containing at least one element, but got normalized_shape = ",
      normalized_shape);
  TORCH_CHECK(
      !weight.defined() || weight.sizes().equals(normalized_shape),
      "Expected weight to be of same shape as normalized_shape, but got ",
      "weight of shape ",
      weight.sizes(),
      " and normalized_shape = ",
      normalized_shape);
  TORCH_CHECK(
      !bias.defined() || bias.sizes().equals(normalized_shape),
      "Expected bias to be of same shape as normalized_shape, but got ",
      "bias of shape ",
      bias.sizes(),
      " and normalized_shape = ",
      normalized_shape);

  const auto input_shape = input.sizes();
  const auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim)
           .equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size " << input_shape;
    AT_ERROR(ss.str());
  }

  const int axis = input_ndim - normalized_ndim;
  const int64_t M =
      c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
  const int64_t N =
      c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());

  return std::make_pair(M, N);
}
} // namespace
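
// Computes layer norm on CPU into the preallocated `out` tensor. gamma and
// beta are the optional affine weight and bias; M and N are the flattened
// outer/inner sizes (see _check_layer_norm_inputs above).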
void layer_norm_cpu_out(
    at::Tensor& out,
    const at::Tensor& input,
    const Tensor& gamma,
    const Tensor& beta,
    double eps,
    int64_t M,
    int64_t N);
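
// Forward kernel signature: reads X (and the optional gamma/beta), writes the
// normalized output Y along with the per-row mean and rstd (reciprocal of the
// standard deviation) that are saved for the backward pass.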
using forward_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    int64_t /* M */,
    int64_t /* N */,
    double /* eps */,
    Tensor* /* Y */,
    Tensor* /* mean */,
    Tensor* /* rstd */);
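
// Backward kernel signature: given the upstream gradient dY and the saved
// mean/rstd, writes the input gradient dX and the parameter gradients
// dgamma/dbeta.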
using backward_fn = void (*)(
    const Tensor& /* dY */,
    const Tensor& /* X */,
    const Tensor& /* mean */,
    const Tensor& /* rstd */,
    const Tensor& /* gamma */,
    int64_t /* M */,
    int64_t /* N */,
    Tensor* /* dX */,
    Tensor* /* dgamma */,
    Tensor* /* dbeta */);
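
// Per-backend kernel implementations register against these stubs (see
// DispatchStub.h). Caller sketch, assumed rather than defined here: after
// _check_layer_norm_inputs returns {M, N}, the CPU path would invoke
// LayerNormKernel(kCPU, X, gamma, beta, M, N, eps, &Y, &mean, &rstd).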
DECLARE_DISPATCH(forward_fn, LayerNormKernel);
DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel);
} // namespace at::native