Commit f494b0e
update multi-head attention
leaderj1001 committed Mar 14, 2021
1 parent 4850d9d · commit f494b0e
Showing 1 changed file with 20 additions and 19 deletions.
model.py (39 changes: 20 additions & 19 deletions)
@@ -14,33 +14,34 @@ def get_n_params(model):
 
 
 class MHSA(nn.Module):
-    def __init__(self, n_dims, width=14, height=14):
+    def __init__(self, n_dims, width=14, height=14, heads=4):
         super(MHSA, self).__init__()
+        self.heads = heads
 
         self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
         self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
         self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
 
-        self.rel_h = nn.Parameter(torch.randn([1, n_dims, 1, height]), requires_grad=True)
-        self.rel_w = nn.Parameter(torch.randn([1, n_dims, width, 1]), requires_grad=True)
+        self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)
+        self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)
 
         self.softmax = nn.Softmax(dim=-1)
 
     def forward(self, x):
         n_batch, C, width, height = x.size()
-        q = self.query(x).view(n_batch, C, -1)
-        k = self.key(x).view(n_batch, C, -1)
-        v = self.value(x).view(n_batch, C, -1)
+        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
+        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
+        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
 
-        content_content = torch.bmm(q.permute(0, 2, 1), k)
+        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)
 
-        content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
+        content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)
         content_position = torch.matmul(content_position, q)
 
         energy = content_content + content_position
         attention = self.softmax(energy)
 
-        out = torch.bmm(v, attention.permute(0, 2, 1))
+        out = torch.matmul(v, attention.permute(0, 1, 3, 2))
         out = out.view(n_batch, C, width, height)
 
         return out
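Not part of the commit, but a useful companion when reading the new forward(): a self-contained sketch of the per-head shape algebra, using the same tensor names as above and hypothetical sizes (batch 2, C=512, a 14x14 feature map, 4 heads).

import torch

n_batch, C, width, height, heads = 2, 512, 14, 14, 4
d = C // heads                                   # channels per head

# q, k, v as produced by the 1x1 convs after the new view(): (B, heads, d, HW)
q = torch.randn(n_batch, heads, d, width * height)
k = torch.randn(n_batch, heads, d, width * height)
v = torch.randn(n_batch, heads, d, width * height)

# content-content term: (B, heads, HW, d) @ (B, heads, d, HW) -> (B, heads, HW, HW)
content_content = torch.matmul(q.permute(0, 1, 3, 2), k)

# rel_h + rel_w broadcasts to (1, heads, d, width, height): a per-head
# relative embedding for every spatial position
rel_h = torch.randn(1, heads, d, 1, height)
rel_w = torch.randn(1, heads, d, width, 1)
content_position = (rel_h + rel_w).view(1, heads, d, -1).permute(0, 1, 3, 2)
content_position = torch.matmul(content_position, q)   # (B, heads, HW, HW)

energy = content_content + content_position
attention = energy.softmax(dim=-1)

out = torch.matmul(v, attention.permute(0, 1, 3, 2))   # (B, heads, d, HW)
print(out.reshape(n_batch, C, width, height).shape)    # torch.Size([2, 512, 14, 14])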
@@ -49,7 +50,7 @@ def forward(self, x):
 class Bottleneck(nn.Module):
     expansion = 4
 
-    def __init__(self, in_planes, planes, stride=1, mhsa=False, resolution=None):
+    def __init__(self, in_planes, planes, stride=1, heads=4, mhsa=False, resolution=None):
         super(Bottleneck, self).__init__()
 
         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
@@ -58,7 +59,7 @@ def __init__(self, in_planes, planes, stride=1, mhsa=False, resolution=None):
             self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
         else:
             self.conv2 = nn.ModuleList()
-            self.conv2.append(MHSA(planes, width=int(resolution[0]), height=int(resolution[1])))
+            self.conv2.append(MHSA(planes, width=int(resolution[0]), height=int(resolution[1]), heads=heads))
             if stride == 2:
                 self.conv2.append(nn.AvgPool2d(2, 2))
             self.conv2 = nn.Sequential(*self.conv2)
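Also for illustration only, reusing the MHSA class from the first hunk: with mhsa=True and stride=2, conv2 becomes MHSA followed by a 2x2 average pool, so the block still halves the spatial resolution even though MHSA itself preserves it. A minimal sketch with values mirroring layer4 of ResNet-50:

import torch
import torch.nn as nn

# Hypothetical values: 512 planes at a 14x14 feature map, as in layer4.
planes, stride, heads, resolution = 512, 2, 4, (14, 14)

conv2 = nn.ModuleList()
conv2.append(MHSA(planes, width=int(resolution[0]), height=int(resolution[1]), heads=heads))
if stride == 2:
    conv2.append(nn.AvgPool2d(2, 2))
conv2 = nn.Sequential(*conv2)

x = torch.randn(2, planes, 14, 14)
print(conv2(x).shape)   # torch.Size([2, 512, 7, 7]): MHSA keeps 14x14, the pool halves it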
@@ -85,7 +86,7 @@ def forward(self, x):
 # reference
 # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
 class ResNet(nn.Module):
-    def __init__(self, block, num_blocks, num_classes=1000, resolution=(224, 224)):
+    def __init__(self, block, num_blocks, num_classes=1000, resolution=(224, 224), heads=4):
         super(ResNet, self).__init__()
         self.in_planes = 64
         self.resolution = list(resolution)
@@ -106,19 +107,19 @@ def __init__(self, block, num_blocks, num_classes=1000, resolution=(224, 224)):
         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
-        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, mhsa=True)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, heads=heads, mhsa=True)
 
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.fc = nn.Sequential(
             nn.Dropout(0.3), # All architecture deeper than ResNet-200 dropout_rate: 0.2
             nn.Linear(512 * block.expansion, num_classes)
         )
 
-    def _make_layer(self, block, planes, num_blocks, stride=1, mhsa=False):
+    def _make_layer(self, block, planes, num_blocks, stride=1, heads=4, mhsa=False):
         strides = [stride] + [1]*(num_blocks-1)
         layers = []
         for idx, stride in enumerate(strides):
-            layers.append(block(self.in_planes, planes, stride, mhsa, self.resolution))
+            layers.append(block(self.in_planes, planes, stride, heads, mhsa, self.resolution))
             if stride == 2:
                 self.resolution[0] /= 2
                 self.resolution[1] /= 2
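A side note on the resolution bookkeeping, which determines the width/height that reach MHSA: each block is constructed with the current resolution, and a stride-2 block halves it for the blocks that follow. A standalone trace for ResNet-50, assuming the stem (outside this diff) has already reduced a 224x224 input to 56x56 before layer1:

# Illustrative arithmetic only; mirrors the loop in _make_layer.
resolution = [56.0, 56.0]
for name, blocks, stride in [("layer1", 3, 1), ("layer2", 4, 2),
                             ("layer3", 6, 2), ("layer4", 3, 2)]:
    for idx, s in enumerate([stride] + [1] * (blocks - 1)):
        print(f"{name}[{idx}] built at resolution {resolution}")
        if s == 2:
            resolution = [r / 2 for r in resolution]
# layer4[0] is built at [14.0, 14.0], so its MHSA gets width=14, height=14;
# layer4[1] and layer4[2] are built at [7.0, 7.0] after the average pool.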
@@ -140,16 +141,16 @@ def forward(self, x):
         return out
 
 
-def ResNet50(num_classes=1000, resolution=(224, 224)):
-    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, resolution=resolution)
+def ResNet50(num_classes=1000, resolution=(224, 224), heads=4):
+    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, resolution=resolution, heads=heads)
 
 
 def main():
-    model = ResNet50()
     x = torch.randn([2, 3, 224, 224])
+    model = ResNet50(resolution=tuple(x.shape[2:]), heads=8)
     print(model(x).size())
     print(get_n_params(model))
 
 
 # if __name__ == '__main__':
-#     main()
+#     main()
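One caveat the new signature implies: every reshape in forward() uses C // self.heads, and rel_h/rel_w are sized with n_dims // heads, so the channel count of the attention stage (512 planes in layer4) should be divisible by heads; otherwise the view() calls will generally fail at runtime. A hypothetical guard (not in the repo) makes the constraint explicit:

def check_heads(n_dims: int, heads: int) -> None:
    # Hypothetical helper, not part of model.py: MHSA reshapes with
    # C // heads, so heads must divide the channel count it attends over.
    if n_dims % heads != 0:
        raise ValueError(f"heads={heads} must divide n_dims={n_dims}")

check_heads(512, 4)   # the default used throughout this commit
check_heads(512, 8)   # the value main() now passes; 64 channels per head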
