-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathadjust_layer.py
41 lines (34 loc) · 1.29 KB
/
adjust_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch.nn as nn
class AdjustLayer(nn.Module):
    """Project a backbone feature map to the head's channel count.

    A 1x1 convolution (no bias) followed by BatchNorm maps ``in_channels``
    to ``out_channels`` without changing spatial size.  Small feature maps
    (spatial side < 20) are then center-cropped to a fixed
    ``center_size`` x ``center_size`` grid.

    Args:
        in_channels (int): channels of the incoming feature map.
        out_channels (int): channels expected by the downstream head.
        center_size (int): side length of the center crop applied to
            small maps.  Defaults to 7, matching the original behavior
            ("the author only picks 7 grids").
    """

    def __init__(self, in_channels, out_channels, center_size=7):
        super(AdjustLayer, self).__init__()
        # bias=False because BatchNorm's affine shift makes a conv bias redundant.
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.center_size = center_size

    def forward(self, x):
        """Project ``x``; center-crop when the spatial side is < 20.

        Bug fix / generalization: the crop start was hard-coded to 4,
        which centers a 7-grid only for a 15x15 map.  It is now computed
        from the actual size, so any sub-threshold map is cropped
        symmetrically.  For the original 15x15 / center_size=7 case the
        result is identical ((15 - 7) // 2 == 4).
        """
        x = self.downsample(x)
        if x.size(3) < 20:
            l = (x.size(3) - self.center_size) // 2
            r = l + self.center_size
            x = x[:, :, l:r, l:r]
        return x
class AdjustAllLayer(nn.Module):
    """Apply one :class:`AdjustLayer` per backbone output.

    With a single channel pair the sole layer is stored as
    ``self.downsample`` and ``forward`` maps one tensor to one tensor.
    With several pairs, layers are registered as ``downsample2``,
    ``downsample3``, ... (numbering starts at 2 to mirror backbone stage
    names) and ``forward`` maps a sequence of tensors to a list.

    Args:
        in_channels (sequence of int): input channels, one per feature map.
        out_channels (sequence of int): output channels, one per feature map.
    """

    def __init__(self, in_channels, out_channels):
        super(AdjustAllLayer, self).__init__()
        self.num = len(out_channels)
        if self.num == 1:
            self.downsample = AdjustLayer(in_channels[0], out_channels[0])
        else:
            for idx, (c_in, c_out) in enumerate(zip(in_channels, out_channels)):
                self.add_module('downsample' + str(idx + 2),
                                AdjustLayer(c_in, c_out))

    def forward(self, features):
        """Adjust one feature map (num == 1) or each map in ``features``."""
        if self.num == 1:
            return self.downsample(features)
        return [
            getattr(self, 'downsample' + str(k + 2))(features[k]).contiguous()
            for k in range(self.num)
        ]