From f494b0e915a4937d4d42c0b4044c658e71e28657 Mon Sep 17 00:00:00 2001 From: myeongjun Date: Sun, 14 Mar 2021 15:44:38 +0900 Subject: [PATCH] update multi-head attention --- model.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/model.py b/model.py index cd244bb..2457b2b 100644 --- a/model.py +++ b/model.py @@ -14,33 +14,34 @@ def get_n_params(model): class MHSA(nn.Module): - def __init__(self, n_dims, width=14, height=14): + def __init__(self, n_dims, width=14, height=14, heads=4): super(MHSA, self).__init__() + self.heads = heads self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1) self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1) self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1) - self.rel_h = nn.Parameter(torch.randn([1, n_dims, 1, height]), requires_grad=True) - self.rel_w = nn.Parameter(torch.randn([1, n_dims, width, 1]), requires_grad=True) + self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True) + self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True) self.softmax = nn.Softmax(dim=-1) def forward(self, x): n_batch, C, width, height = x.size() - q = self.query(x).view(n_batch, C, -1) - k = self.key(x).view(n_batch, C, -1) - v = self.value(x).view(n_batch, C, -1) + q = self.query(x).view(n_batch, self.heads, C // self.heads, -1) + k = self.key(x).view(n_batch, self.heads, C // self.heads, -1) + v = self.value(x).view(n_batch, self.heads, C // self.heads, -1) - content_content = torch.bmm(q.permute(0, 2, 1), k) + content_content = torch.matmul(q.permute(0, 1, 3, 2), k) - content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1) + content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2) content_position = torch.matmul(content_position, q) energy = content_content + content_position attention = self.softmax(energy) - out = torch.bmm(v, attention.permute(0, 2, 1)) + out = torch.matmul(v, attention.permute(0, 1, 3, 2)) out = out.view(n_batch, C, width, height) return out @@ -49,7 +50,7 @@ def forward(self, x): class Bottleneck(nn.Module): expansion = 4 - def __init__(self, in_planes, planes, stride=1, mhsa=False, resolution=None): + def __init__(self, in_planes, planes, stride=1, heads=4, mhsa=False, resolution=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) @@ -58,7 +59,7 @@ def __init__(self, in_planes, planes, stride=1, mhsa=False, resolution=None): self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False) else: self.conv2 = nn.ModuleList() - self.conv2.append(MHSA(planes, width=int(resolution[0]), height=int(resolution[1]))) + self.conv2.append(MHSA(planes, width=int(resolution[0]), height=int(resolution[1]), heads=heads)) if stride == 2: self.conv2.append(nn.AvgPool2d(2, 2)) self.conv2 = nn.Sequential(*self.conv2) @@ -85,7 +86,7 @@ def forward(self, x): # reference # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py class ResNet(nn.Module): - def __init__(self, block, num_blocks, num_classes=1000, resolution=(224, 224)): + def __init__(self, block, num_blocks, num_classes=1000, resolution=(224, 224), heads=4): super(ResNet, self).__init__() self.in_planes = 64 self.resolution = list(resolution) @@ -106,7 +107,7 @@ def __init__(self, block, num_blocks, num_classes=1000, resolution=(224, 224)): self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, mhsa=True) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, heads=heads, mhsa=True) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Sequential( @@ -114,11 +115,11 @@ def __init__(self, block, num_blocks, num_classes=1000, resolution=(224, 224)): nn.Linear(512 * block.expansion, num_classes) ) - def _make_layer(self, block, planes, num_blocks, stride=1, mhsa=False): + def _make_layer(self, block, planes, num_blocks, stride=1, heads=4, mhsa=False): strides = [stride] + [1]*(num_blocks-1) layers = [] for idx, stride in enumerate(strides): - layers.append(block(self.in_planes, planes, stride, mhsa, self.resolution)) + layers.append(block(self.in_planes, planes, stride, heads, mhsa, self.resolution)) if stride == 2: self.resolution[0] /= 2 self.resolution[1] /= 2 @@ -140,16 +141,16 @@ def forward(self, x): return out -def ResNet50(num_classes=1000, resolution=(224, 224)): - return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, resolution=resolution) +def ResNet50(num_classes=1000, resolution=(224, 224), heads=4): + return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, resolution=resolution, heads=heads) def main(): - model = ResNet50() x = torch.randn([2, 3, 224, 224]) + model = ResNet50(resolution=tuple(x.shape[2:]), heads=8) print(model(x).size()) print(get_n_params(model)) # if __name__ == '__main__': -# main() +# main() \ No newline at end of file