
Commit 18dfe4b

Author: Irhum Shafkat
Code for module and network
1 parent 5b4814a commit 18dfe4b

File tree

2 files changed, +214 -0 lines changed

module.py

+65
@@ -0,0 +1,65 @@
import math

import torch.nn as nn
from torch.nn.modules.utils import _triple


class SpatioTemporalConv(nn.Module):
    r"""Applies a factored 3D convolution over an input signal composed of several input
    planes with distinct spatial and time axes, by performing a 2D convolution over the
    spatial axes to an intermediate subspace, followed by a 1D convolution over the time
    axis to produce the final output.

    Args:
        in_channels (int): Number of channels in the input tensor
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to the sides of the input during their respective convolutions. Default: 0
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(SpatioTemporalConv, self).__init__()

        # if ints are entered, convert them to iterables, 1 -> [1, 1, 1]
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)

        # decomposing the parameters into spatial and temporal components by
        # masking out the values with the defaults on the axis that
        # won't be convolved over. This is necessary to avoid unintentional
        # behavior such as padding being added twice
        spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
        spatial_stride = [1, stride[1], stride[2]]
        spatial_padding = [0, padding[1], padding[2]]

        temporal_kernel_size = [kernel_size[0], 1, 1]
        temporal_stride = [stride[0], 1, 1]
        temporal_padding = [padding[0], 0, 0]

        # compute the number of intermediary channels (M) using the formula
        # from section 3.5 of the paper
        intermed_channels = int(math.floor((kernel_size[0] * kernel_size[1] * kernel_size[2] * in_channels * out_channels) /
                                           (kernel_size[1] * kernel_size[2] * in_channels + kernel_size[0] * out_channels)))

        # the spatial conv is effectively a 2D conv due to the
        # spatial_kernel_size, followed by batch_norm and ReLU
        self.spatial_conv = nn.Conv3d(in_channels, intermed_channels, spatial_kernel_size,
                                      stride=spatial_stride, padding=spatial_padding, bias=bias)
        self.bn = nn.BatchNorm3d(intermed_channels)
        self.relu = nn.ReLU()

        # the temporal conv is effectively a 1D conv, but has batch norm
        # and ReLU added inside the model constructor, not here. This is an
        # intentional design choice, to allow this module to externally act
        # identical to a standard Conv3d, so it can be reused easily in any
        # other codebase
        self.temporal_conv = nn.Conv3d(intermed_channels, out_channels, temporal_kernel_size,
                                       stride=temporal_stride, padding=temporal_padding, bias=bias)

    def forward(self, x):
        x = self.relu(self.bn(self.spatial_conv(x)))
        x = self.temporal_conv(x)
        return x
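
Not part of the commit, but as a quick sanity check, here is a minimal usage sketch of SpatioTemporalConv (the dummy tensor shape and the printed size are illustrative assumptions, not taken from this repository). For this 3-channel-to-64-channel case with a 3x3x3 kernel, the section 3.5 formula gives floor((3*3*3*3*64) / (3*3*3 + 3*64)) = floor(5184 / 219) = 23 intermediate channels.

import torch

from module import SpatioTemporalConv

# hypothetical dummy clip: batch of 2, 3 channels, 8 frames, 112x112 spatial size
clip = torch.randn(2, 3, 8, 112, 112)

# kernel_size=3 expands to (3, 3, 3); padding=1 keeps every axis the same size
conv = SpatioTemporalConv(3, 64, kernel_size=3, padding=1)
out = conv(clip)
print(out.shape)  # expected: torch.Size([2, 64, 8, 112, 112])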

network.py

+149
@@ -0,0 +1,149 @@
import torch.nn as nn
from torch.nn.modules.utils import _triple

from module import SpatioTemporalConv


class SpatioTemporalResBlock(nn.Module):
    r"""Single block for the ResNet network. Uses SpatioTemporalConv in
    the standard ResNet block layout (conv->batchnorm->ReLU->conv->batchnorm->sum->ReLU)

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels in the output produced by the block.
        kernel_size (int or tuple): Size of the convolving kernels.
        downsample (bool, optional): If ``True``, the output size is to be smaller than the input. Default: ``False``
    """

    def __init__(self, in_channels, out_channels, kernel_size, downsample=False):
        super(SpatioTemporalResBlock, self).__init__()

        # If downsample == True, the first conv of the layer has stride = 2
        # to halve the residual output size, and the input x is passed
        # through a separate 1x1x1 conv with stride = 2 to also halve it.

        # no pooling layers are used inside ResNet
        self.downsample = downsample

        # to allow for SAME padding
        padding = kernel_size // 2

        if self.downsample:
            # downsample the input x with stride = 2
            self.downsampleconv = SpatioTemporalConv(in_channels, out_channels, 1, stride=2)
            self.downsamplebn = nn.BatchNorm3d(out_channels)

            # downsample with stride = 2 when producing the residual
            self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding, stride=2)
        else:
            self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding)

        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.ReLU()

        # standard conv->batchnorm->ReLU
        self.conv2 = SpatioTemporalConv(out_channels, out_channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.outrelu = nn.ReLU()

    def forward(self, x):
        res = self.relu1(self.bn1(self.conv1(x)))
        res = self.bn2(self.conv2(res))

        if self.downsample:
            x = self.downsamplebn(self.downsampleconv(x))

        return self.outrelu(x + res)


class SpatioTemporalResLayer(nn.Module):
    r"""Forms a single layer of the ResNet network, with a number of repeating
    blocks of the same output size stacked on top of each other

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels in the output produced by the layer.
        kernel_size (int or tuple): Size of the convolving kernels.
        layer_size (int): Number of blocks to be stacked to form the layer
        block_type (Module, optional): Type of block that is to be used to form the layer. Default: SpatioTemporalResBlock.
        downsample (bool, optional): If ``True``, the first block in the layer will implement downsampling. Default: ``False``
    """

    def __init__(self, in_channels, out_channels, kernel_size, layer_size, block_type=SpatioTemporalResBlock, downsample=False):
        super(SpatioTemporalResLayer, self).__init__()

        # implement the first block
        self.block1 = block_type(in_channels, out_channels, kernel_size, downsample)

        # prepare module list to hold all (layer_size - 1) blocks
        self.blocks = nn.ModuleList([])
        for i in range(layer_size - 1):
            # all these blocks are identical, and have downsample = False by default
            self.blocks += [block_type(out_channels, out_channels, kernel_size)]

    def forward(self, x):
        x = self.block1(x)
        for block in self.blocks:
            x = block(x)

        return x


class R2Plus1DNet(nn.Module):
    r"""Forms the overall ResNet feature extractor by initializing 5 layers, with the number of blocks in
    each layer set by layer_sizes, and by performing a global average pool at the end producing a
    512-dimensional vector for each element in the batch.

    Args:
        layer_sizes (tuple): An iterable containing the number of blocks in each layer
        block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock.
    """

    def __init__(self, layer_sizes, block_type=SpatioTemporalResBlock):
        super(R2Plus1DNet, self).__init__()

        # first conv, with stride 1x2x2 and kernel size 3x7x7
        self.conv1 = SpatioTemporalConv(3, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])
        # output of conv2 is the same size as that of conv1, no downsampling needed. kernel_size 3x3x3
        self.conv2 = SpatioTemporalResLayer(64, 64, 3, layer_sizes[0], block_type=block_type)
        # each of the final three layers doubles num_channels, while performing downsampling
        # inside the first block
        self.conv3 = SpatioTemporalResLayer(64, 128, 3, layer_sizes[1], block_type=block_type, downsample=True)
        self.conv4 = SpatioTemporalResLayer(128, 256, 3, layer_sizes[2], block_type=block_type, downsample=True)
        self.conv5 = SpatioTemporalResLayer(256, 512, 3, layer_sizes[3], block_type=block_type, downsample=True)

        # global average pooling of the output
        self.pool = nn.AdaptiveAvgPool3d(1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)

        x = self.pool(x)

        return x.view(-1, 512)


class R2Plus1DClassifier(nn.Module):
    r"""Forms a complete ResNet classifier producing vectors of size num_classes, by initializing 5 layers,
    with the number of blocks in each layer set by layer_sizes, and by performing a global average pool
    at the end producing a 512-dimensional vector for each element in the batch,
    and passing them through a Linear layer.

    Args:
        num_classes (int): Number of classes in the data
        layer_sizes (tuple): An iterable containing the number of blocks in each layer
        block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock.
    """

    def __init__(self, num_classes, layer_sizes, block_type=SpatioTemporalResBlock):
        super(R2Plus1DClassifier, self).__init__()

        self.res2plus1d = R2Plus1DNet(layer_sizes, block_type)
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.res2plus1d(x)
        x = self.linear(x)

        return x
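
Not part of the commit: a minimal sketch of wiring the classifier up end to end, assuming a configuration with two residual blocks per layer and a hypothetical 101-class problem; only the 3-channel input and the 512-dimensional pooled feature are fixed by the code above, everything else here is an illustrative assumption.

import torch

from network import R2Plus1DClassifier

# hypothetical configuration: 2 residual blocks per layer, 101 classes
model = R2Plus1DClassifier(num_classes=101, layer_sizes=[2, 2, 2, 2])

# dummy clip batch: 2 clips, 3 channels, 16 frames, 112x112 spatial size
clip = torch.randn(2, 3, 16, 112, 112)
logits = model(clip)
print(logits.shape)  # expected: torch.Size([2, 101])

With this input, conv1 halves the spatial size to 56x56, each of conv3 to conv5 halves both the time and space axes inside its first block, and the global average pool then reduces the output to a 512-dimensional vector per clip before the Linear layer.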
