-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathCapsnet.py
74 lines (54 loc) · 2.62 KB
/
Capsnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn.functional as func
import torch.nn as nn
from Capsule import CapsuleLayer
from exp_decoder import Decoder
import math
def conv_size(shape, k=9, s=1, p=False):
    """Return the (height, width) produced by a square-kernel 2-D convolution.

    Args:
        shape: ``(H, W)`` pair giving the input height and width.
        k: Kernel size, applied to both axes.
        s: Stride, applied to both axes.
        p: When true, assume symmetric padding of ``(k - 1) // 2`` on every
            side; otherwise assume no padding.

    Returns:
        Tuple ``(Ho, Wo)`` with the output height and width as integers.
    """
    pad = (k - 1) // 2 if p else 0

    def _out(n):
        # floor((n + 2*pad - (k - 1) - 1) / s) + 1, evaluated with floor
        # division so no float rounding is involved for integer sizes.
        return int((n + 2 * pad - k) // s) + 1

    h, w = shape
    return _out(h), _out(w)
class CapsuleNetwork(nn.Module):
    """Capsule network: conv stem, primary capsules, class capsules, decoder.

    The forward pass returns per-class scores (softmax over capsule lengths)
    together with the decoder output computed from the capsule vectors of a
    single selected class per sample.
    """

    def __init__(self, img_size, ic_channels, num_pcaps, num_classes, num_coc, num_doc, mode='mono', use_padding=False):
        """Build the layers.

        Args:
            img_size: ``(H, W)`` of the input images; used to size the routing.
            ic_channels: Output channels of the initial 9x9 convolution.
            num_pcaps: ``num_caps`` of the primary capsule layer and
                ``in_channels`` of the class capsule layer.
            num_classes: ``num_caps`` of the class capsule layer.
            num_coc: ``out_channels`` of the primary capsule layer.
            num_doc: ``out_channels`` of the class capsule layer.
            mode: ``'mono'`` selects 1 input channel; any other value selects 3.
            use_padding: Whether the primary-capsule output size is computed
                assuming ``(k - 1) // 2`` padding.
        """
        super().__init__()
        in_channels = 1 if mode == 'mono' else 3
        self.initial_conv = nn.Conv2d(in_channels=in_channels, out_channels=ic_channels,
                                      kernel_size=9, stride=1)
        # The stem convolution above is unpadded, hence p=False here.
        Ho, Wo = conv_size(img_size, k=9, s=1, p=False)

        self.p_caps = CapsuleLayer(num_caps=num_pcaps, num_routes=-1, in_channels=ic_channels,
                                   out_channels=num_coc, k_size=9, stride=2)
        # NOTE(review): use_padding only affects this size computation; it is
        # not forwarded to the CapsuleLayer above -- confirm the layer pads
        # the same way, otherwise num_routes below will not match.
        Ho, Wo = conv_size((Ho, Wo), k=9, s=2, p=use_padding)

        self.d_caps = CapsuleLayer(num_caps=num_classes, num_routes=num_coc * Ho * Wo,
                                   in_channels=num_pcaps, out_channels=num_doc)
        self.decoder = Decoder()

    def forward(self, x, y=None):
        """Classify ``x`` and reconstruct it from the selected class capsule.

        Args:
            x: Input batch fed to the initial convolution.
            y: Optional class selection used to mask the capsules before
                decoding. Either a 1-D tensor of class indices or a 2-D
                one-hot/mask tensor of shape ``(batch, num_classes)``. When
                omitted, the highest-scoring class of each sample is used.

        Returns:
            Tuple ``(classes, reconst)``: the softmax-normalised capsule
            lengths and the decoder output.
        """
        x = func.relu(self.initial_conv(x))
        x = self.p_caps(x)
        # NOTE(review): bare squeeze() removes every size-1 dimension, so it
        # would also drop a batch dimension of size 1 -- presumably d_caps
        # returns (classes, batch, 1, ..., dim); TODO confirm and squeeze
        # explicit dims.
        x = self.d_caps(x).squeeze().transpose(0, 1)

        # Capsule length per class, then normalised across classes.
        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = func.softmax(classes, dim=-1)

        if y is None:
            # No selection supplied: decode from the most active capsule.
            _, y = classes.max(dim=1)

        if y.dim() == 1:
            # Class indices -> one-hot rows, sized from the actual number of
            # classes and created on the same device/dtype as the capsules.
            identity = torch.eye(classes.size(-1), device=x.device, dtype=x.dtype)
            mask = identity.index_select(dim=0, index=y.to(x.device).long())
        else:
            mask = y.to(device=x.device, dtype=x.dtype)

        # reshape (not view): the masked product of a transposed tensor may be
        # non-contiguous.
        reconst = self.decoder((x * mask[:, :, None]).reshape(x.size(0), -1))
        return classes, reconst
class CapsuleLoss(nn.Module):
    """Margin loss on class scores plus a weighted sum-of-squares reconstruction term."""

    def __init__(self, m_pos=0.9, m_neg=0.1, down_weight=0.5, recon_weight=0.0005):
        """Configure the loss.

        Args:
            m_pos: Margin the true-class score should reach.
            m_neg: Margin the other-class scores should stay below.
            down_weight: Weight applied to the other-class penalty.
            recon_weight: Weight applied to the reconstruction error.
        """
        super().__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.down_weight = down_weight
        self.recon_weight = recon_weight
        # reduction='sum' is the supported spelling of the deprecated
        # size_average=False.
        self.reconst = nn.MSELoss(reduction='sum')

    def forward(self, img, label, classes, reconst):
        """Return the batch-averaged loss.

        Args:
            img: Original images; reshaped to ``reconst``'s shape when the two
                hold the same number of elements but differ in layout.
            label: 1-D tensor of true class indices.
            classes: ``(batch, num_classes)`` class scores.
            reconst: Reconstructions to compare against ``img``.

        Returns:
            Scalar tensor ``(margin + recon_weight * recon) / batch``.
        """
        batch = img.size(0)

        # One-hot targets sized from the scores and placed on their device.
        identity = torch.eye(classes.size(-1), device=classes.device, dtype=classes.dtype)
        label = identity.index_select(dim=0, index=label.to(classes.device).long())

        left = func.relu(self.m_pos - classes) ** 2
        right = func.relu(classes - self.m_neg) ** 2
        margin = (label * left + self.down_weight * (1 - label) * right).sum()

        # Avoid accidental broadcasting (or a shape error) between an image
        # tensor and a flattened reconstruction of the same size.
        if img.shape != reconst.shape and img.numel() == reconst.numel():
            img = img.reshape(reconst.shape)
        recon = self.reconst(reconst, img)

        return (margin + self.recon_weight * recon) / batch