forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathElu.h
72 lines (67 loc) · 2.74 KB
/
Elu.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#pragma once
// On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to
// access constants such as M_SQRT2 and M_2_SQRTPI.
#ifdef _WIN32
#define _USE_MATH_DEFINES
#include <cmath>
#endif // _WIN32
#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h> // For c10::is_reduced_floating_point_v.
namespace at::native {
/**
 * Build a scalar ELU functor for the given parameters.
 *
 * The returned callable maps one ParamT element to one ParamT result:
 *   - inputs <= 0 yield expm1(x * input_scale) * (alpha * scale)
 *   - all other inputs yield x * scale
 *
 * ParamT is the element type seen by the caller. MathT (defaulting to
 * ParamT) is the type used for the intermediate arithmetic, which lets
 * reduced-precision element types be evaluated in e.g. float.
 */
template <typename ParamT, typename MathT=ParamT>
auto get_scalar_elu_elementwise_func(MathT alpha, MathT scale, MathT input_scale) {
  // The coefficients are computed once here and copied into the closure.
  return [neg_scale = alpha * scale, pos_scale = scale, in_scale = input_scale](ParamT a) -> ParamT {
    // Convert to the math type a single time and reuse it in both branches.
    const MathT x = static_cast<MathT>(a);
    return x <= MathT(0) ? std::expm1(x * in_scale) * neg_scale : x * pos_scale;
  };
}
/**
* Return a function object that calculates ELU with the given
* parameters on its input element. The function object takes and
* returns Vectorized<T>.
*/
template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
auto get_vectorized_elu_elementwise_func(T alpha, T scale, T input_scale) {
const vec::Vectorized<T> negcoef_vec(alpha * scale);
const vec::Vectorized<T> poscoef_vec(scale);
const vec::Vectorized<T> negiptcoef_vec(input_scale);
const vec::Vectorized<T> zero_vec(static_cast<T>(0));
return [negcoef_vec, poscoef_vec, negiptcoef_vec, zero_vec](vec::Vectorized<T> a) -> vec::Vectorized<T> {
const auto cmp = a > zero_vec;
if (!cmp.zero_mask()) {
return a * poscoef_vec;
} else {
return vec::Vectorized<T>::blendv((a * negiptcoef_vec).expm1() * negcoef_vec, a * poscoef_vec, cmp);
}
};
}
/**
 * Build a vectorized ELU functor for reduced-precision floating point
 * element types.
 *
 * The returned callable takes and returns Vectorized<T>. The parameters
 * are passed as float, and the intermediate calculations are done on
 * Vectorized<float>: the input is converted to two float vectors, the
 * float ELU functor is applied to each, and the results are converted
 * back to a single Vectorized<T>.
 */
template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
auto get_vectorized_elu_elementwise_func(float alpha, float scale, float input_scale) {
  // The float -> float vectorized ELU is built once and owned by the closure.
  return [fp32_elu = get_vectorized_elu_elementwise_func<float>(alpha, scale, input_scale)](vec::Vectorized<T> a) -> vec::Vectorized<T> {
    const auto [part0, part1] = vec::convert_to_float<T>(a);
    return vec::convert_from_float<T>(fp32_elu(part0), fp32_elu(part1));
  };
}
} // namespace at::native