# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Activation functions."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class SiLU(nn.Module):
    """Applies the SiLU activation function to the input tensor as described in https://arxiv.org/pdf/1606.08415.pdf."""

    @staticmethod
    def forward(x):
        """Applies the SiLU activation function, as detailed in https://arxiv.org/pdf/1606.08415.pdf, on input tensor
        `x`.
        """
        return x * torch.sigmoid(x)
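

# Illustrative sketch (not part of the original module): SiLU(x) = x * sigmoid(x), so it should
# match PyTorch's built-in nn.SiLU. `_silu_sanity_check` is a hypothetical helper.
def _silu_sanity_check():
    """Compare this SiLU against torch.nn.SiLU on random data (illustrative only)."""
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(SiLU()(x), nn.SiLU()(x), atol=1e-6)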


class Hardswish(nn.Module):
    """Applies the Hardswish activation function to the input tensor `x`."""

    @staticmethod
    def forward(x):
        """Applies the Hardswish (Hard-SiLU) activation to input `x` in a form that exports cleanly to TorchScript,
        CoreML and ONNX.
        """
        return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0  # for TorchScript, CoreML and ONNX
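

# Illustrative sketch (assumption, not in the original file): the export-friendly formulation
# above should agree with PyTorch's built-in nn.Hardswish. `_hardswish_sanity_check` is hypothetical.
def _hardswish_sanity_check():
    """Compare Hardswish against torch.nn.Hardswish on random data (illustrative only)."""
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(Hardswish()(x), nn.Hardswish()(x), atol=1e-6)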


class Mish(nn.Module):
    """Applies the Mish activation function to improve model performance; see https://github.com/digantamisra98/Mish."""

    @staticmethod
    def forward(x):
        """
        Applies the Mish activation function, enhancing model performance and convergence.

        Reference: https://github.com/digantamisra98/Mish
        """
        return x * F.softplus(x).tanh()
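

# Illustrative sketch (assumption): Mish(x) = x * tanh(softplus(x)), so on PyTorch >= 1.9 it
# should match the built-in nn.Mish. `_mish_sanity_check` is a hypothetical helper.
def _mish_sanity_check():
    """Compare Mish against torch.nn.Mish on random data (illustrative only)."""
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(Mish()(x), nn.Mish()(x), atol=1e-6)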


class MemoryEfficientMish(nn.Module):
    """Applies the memory-efficient Mish activation function for improved model performance and reduced memory usage."""

    class F(torch.autograd.Function):
        """Memory-efficient implementation of the Mish activation function for enhanced model performance."""

        @staticmethod
        def forward(ctx, x):
            """Applies the Mish activation function in a memory-efficient manner, useful for enhancing model
            performance.
            """
            ctx.save_for_backward(x)
            # note: inside these methods `F` still resolves to torch.nn.functional, not this nested class
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        @staticmethod
        def backward(ctx, grad_output):
            """Computes gradient of the Mish activation function for backpropagation, returning the derivative with
            respect to the input.
            """
            x = ctx.saved_tensors[0]
            sx = torch.sigmoid(x)
            fx = F.softplus(x).tanh()
            return grad_output * (fx + x * sx * (1 - fx * fx))

    def forward(self, x):
        """Applies Mish activation function, useful in neural networks for nonlinear transformation of inputs."""
        return self.F.apply(x)
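

# Illustrative sketch (assumption): the custom autograd Function recomputes the activation in
# backward instead of storing intermediate results, so it should match Mish exactly, and the
# hand-written gradient can be verified with torch.autograd.gradcheck in double precision.
# `_memory_efficient_mish_check` is a hypothetical helper.
def _memory_efficient_mish_check():
    """Check forward equivalence with Mish and gradient correctness (illustrative only)."""
    x = torch.randn(2, 8, 4, 4, dtype=torch.double, requires_grad=True)
    assert torch.allclose(MemoryEfficientMish()(x), Mish()(x), atol=1e-6)
    assert torch.autograd.gradcheck(MemoryEfficientMish.F.apply, (x,))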


class FReLU(nn.Module):
    """Implements the FReLU activation, combining ReLU and convolution from https://arxiv.org/abs/2007.11824."""

    def __init__(self, c1, k=3):  # ch_in, kernel
        """Initializes FReLU with specified channel size and kernel, implementing activation from
        https://arxiv.org/abs/2007.11824.
        """
        super().__init__()
        self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False)  # padding=1 assumes the default k=3
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        """Performs FReLU activation on input, returning the elementwise max of the input and its depthwise convolution."""
        return torch.max(x, self.bn(self.conv(x)))
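

# Illustrative sketch (assumption): FReLU is channel-aware, so it must be constructed with the
# incoming channel count; with the default 3x3 kernel the spatial dimensions are preserved.
# `_frelu_example` is a hypothetical usage helper.
def _frelu_example():
    """Run FReLU on a 16-channel feature map and check the shape is unchanged (illustrative only)."""
    act = FReLU(16)
    x = torch.randn(2, 16, 32, 32)
    assert act(x).shape == x.shape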


class AconC(nn.Module):
    r"""ACON activation (activate or not)

    AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
    according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
    """

    def __init__(self, c1):
        """Initializes ACON activation with learnable parameters p1, p2, and beta as per
        https://arxiv.org/pdf/2009.04759.pdf.
        """
        super().__init__()
        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))

    def forward(self, x):
        """Applies a parametric activation function to tensor x; see https://arxiv.org/pdf/2009.04759.pdf for
        details.
        """
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x
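

# Illustrative sketch (assumption): with p1=1, p2=0 and beta=1 the AconC formula reduces to
# x * sigmoid(x), i.e. SiLU/Swish, which gives a quick sanity check of the implementation.
# `_aconc_example` is a hypothetical helper.
def _aconc_example():
    """Check that AconC reduces to SiLU when p1=1, p2=0, beta=1 (illustrative only)."""
    act = AconC(8)
    with torch.no_grad():
        act.p1.fill_(1.0)
        act.p2.fill_(0.0)
        act.beta.fill_(1.0)
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(act(x), SiLU()(x), atol=1e-6)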


class MetaAconC(nn.Module):
    r"""ACON activation (activate or not)

    MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
    according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
    """

    def __init__(self, c1, k=1, s=1, r=16):  # ch_in, kernel, stride, r
        """Initializes MetaAconC activation with params c1, optional k (kernel=1), s (stride=1), r (16), defining
        activation dynamics.
        """
        super().__init__()
        c2 = max(r, c1 // r)
        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)
        self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)
        # self.bn1 = nn.BatchNorm2d(c2)
        # self.bn2 = nn.BatchNorm2d(c1)

    def forward(self, x):
        """Applies a forward pass transforming input `x` using parametric operations and returns the modified tensor."""
        y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
        # batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891
        # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))  # bug/unstable
        beta = torch.sigmoid(self.fc2(self.fc1(y)))  # bug patch BN layers removed
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(beta * dpx) + self.p2 * x
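

# Illustrative sketch (assumption, not part of the original module): a minimal smoke test that
# runs every activation defined above on a small random feature map and checks output shapes.
if __name__ == "__main__":
    _x = torch.randn(2, 16, 32, 32)
    for _act in (SiLU(), Hardswish(), Mish(), MemoryEfficientMish(), FReLU(16), AconC(16), MetaAconC(16)):
        _y = _act(_x)
        assert _y.shape == _x.shape, f"{type(_act).__name__} changed the tensor shape"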