conv_layers.py
import torch
import torch.nn as nn
import numpy as np


class CroppedConv2d(nn.Conv2d):
    """Conv2d whose output is cropped for the vertical stack of a gated PixelCNN:
    returns the output without the top row and the bottom `kernel_height` rows,
    plus the same crop shifted up by one row for the horizontal stack."""

    def __init__(self, *args, **kwargs):
        super(CroppedConv2d, self).__init__(*args, **kwargs)

    def forward(self, x):
        x = super(CroppedConv2d, self).forward(x)

        kernel_height, _ = self.kernel_size
        # Crop rows; the second tensor is the same window shifted up by one row.
        res = x[:, :, 1:-kernel_height, :]
        shifted_up_res = x[:, :, :-kernel_height - 1, :]

        return res, shifted_up_res


class MaskedConv2d(nn.Conv2d):
    """Masked convolution for PixelCNN: mask 'A' blocks each channel group's
    connection to itself at the centre pixel, mask 'B' allows it."""

    def __init__(self, *args, mask_type, data_channels, **kwargs):
        super(MaskedConv2d, self).__init__(*args, **kwargs)

        assert mask_type in ['A', 'B'], 'Invalid mask type.'

        out_channels, in_channels, height, width = self.weight.size()
        yc, xc = height // 2, width // 2

        mask = np.zeros(self.weight.size(), dtype=np.float32)
        # Allow every row above the centre and, in the centre row,
        # every column up to and including the centre column.
        mask[:, :, :yc, :] = 1
        mask[:, :, yc, :xc + 1] = 1

        def cmask(out_c, in_c):
            # Boolean (out_channels, in_channels) grid selecting the weights
            # that connect input channel group `in_c` to output group `out_c`.
            a = (np.arange(out_channels) % data_channels == out_c)[:, None]
            b = (np.arange(in_channels) % data_channels == in_c)[None, :]
            return a * b

        # At the centre pixel, an output channel group must not see
        # "later" input channel groups (e.g. R must not see G or B).
        for o in range(data_channels):
            for i in range(o + 1, data_channels):
                mask[cmask(o, i), yc, xc] = 0

        # Mask 'A' (first layer) also blocks each group's connection to itself.
        if mask_type == 'A':
            for c in range(data_channels):
                mask[cmask(c, c), yc, xc] = 0

        mask = torch.from_numpy(mask).float()
        self.register_buffer('mask', mask)

    def forward(self, x):
        # Zero out the masked weights before every convolution.
        self.weight.data *= self.mask
        x = super(MaskedConv2d, self).forward(x)

        return x
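

# Minimal usage sketch: a mask-'A' masked convolution over RGB inputs.
# The layer width (96), kernel size, and input resolution below are
# illustrative assumptions, not values prescribed by this module.
if __name__ == '__main__':
    conv = MaskedConv2d(3, 96, kernel_size=7, padding=3,
                        mask_type='A', data_channels=3)
    x = torch.randn(1, 3, 32, 32)  # batch of one 32x32 RGB image
    out = conv(x)
    print(out.shape)  # torch.Size([1, 96, 32, 32])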