From 666f1e9dc56f43c24a1b8a6f8f448b83cac42c9f Mon Sep 17 00:00:00 2001 From: Martin Valgur Date: Thu, 30 Jul 2020 01:31:59 +0300 Subject: [PATCH] make SuperGlue compatible with TorchScript --- jit.py | 5 ++-- models/superglue.py | 67 +++++++++++++++++++++++++++------------------ 2 files changed, 44 insertions(+), 28 deletions(-) diff --git a/jit.py b/jit.py index 2f28c850..5735d15a 100644 --- a/jit.py +++ b/jit.py @@ -1,5 +1,6 @@ from models.superpoint import SuperPoint +from models.superglue import SuperGlue import torch -superpoint = SuperPoint({}) -torch.jit.save(superpoint, 'SuperPoint.zip') +torch.jit.save(SuperPoint({}), 'SuperPoint.zip') +torch.jit.save(SuperGlue({'weights': 'outdoor'}), 'SuperGlue.zip') diff --git a/models/superglue.py b/models/superglue.py index e75c9b9c..5e08727b 100644 --- a/models/superglue.py +++ b/models/superglue.py @@ -42,6 +42,8 @@ from copy import deepcopy from pathlib import Path +from typing import List, Dict + import torch from torch import nn @@ -60,23 +62,23 @@ def MLP(channels: list, do_bn=True): return nn.Sequential(*layers) -def normalize_keypoints(kpts, image_shape): +def normalize_keypoints(kpts, image_shape: List[int]): """ Normalize keypoints locations based on image image_shape""" _, _, height, width = image_shape - one = kpts.new_tensor(1) - size = torch.stack([one*width, one*height])[None] + size = torch.tensor([[width, height]], dtype=torch.float, device=kpts.device) center = size / 2 scaling = size.max(1, keepdim=True).values * 0.7 return (kpts - center[:, None, :]) / scaling[:, None, :] -class KeypointEncoder(nn.Module): +class KeypointEncoder(torch.jit.ScriptModule): """ Joint encoding of visual appearance and location using MLPs""" def __init__(self, feature_dim, layers): super().__init__() self.encoder = MLP([3] + layers + [feature_dim]) nn.init.constant_(self.encoder[-1].bias, 0.0) + @torch.jit.script_method def forward(self, kpts, scores): inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] return self.encoder(torch.cat(inputs, dim=1)) @@ -89,8 +91,10 @@ def attention(query, key, value): return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob -class MultiHeadedAttention(nn.Module): +class MultiHeadedAttention(torch.jit.ScriptModule): """ Multi-head attention to increase model expressivitiy """ + prob: List[torch.Tensor] + def __init__(self, num_heads: int, d_model: int): super().__init__() assert d_model % num_heads == 0 @@ -98,7 +102,9 @@ def __init__(self, num_heads: int, d_model: int): self.num_heads = num_heads self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) + self.prob = [] + @torch.jit.script_method def forward(self, query, key, value): batch_dim = query.size(0) query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) @@ -108,19 +114,20 @@ def forward(self, query, key, value): return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1)) -class AttentionalPropagation(nn.Module): +class AttentionalPropagation(torch.jit.ScriptModule): def __init__(self, feature_dim: int, num_heads: int): super().__init__() self.attn = MultiHeadedAttention(num_heads, feature_dim) self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim]) nn.init.constant_(self.mlp[-1].bias, 0.0) + @torch.jit.script_method def forward(self, x, source): message = self.attn(x, source, source) return self.mlp(torch.cat([x, message], dim=1)) -class AttentionalGNN(nn.Module): +class AttentionalGNN(torch.jit.ScriptModule): def __init__(self, feature_dim: int, layer_names: list): super().__init__() self.layers = nn.ModuleList([ @@ -128,10 +135,11 @@ def __init__(self, feature_dim: int, layer_names: list): for _ in range(len(layer_names))]) self.names = layer_names + @torch.jit.script_method def forward(self, desc0, desc1): - for layer, name in zip(self.layers, self.names): + for i, layer in enumerate(self.layers): layer.attn.prob = [] - if name == 'cross': + if self.names[i] == 'cross': src0, src1 = desc1, desc0 else: # if name == 'self': src0, src1 = desc0, desc1 @@ -152,8 +160,7 @@ def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int): def log_optimal_transport(scores, alpha, iters: int): """ Perform Differentiable Optimal Transport in Log-space for stability""" b, m, n = scores.shape - one = scores.new_tensor(1) - ms, ns = (m*one).to(scores), (n*one).to(scores) + ms, ns = torch.tensor(m).to(scores), torch.tensor(n).to(scores) bins0 = alpha.expand(b, m, 1) bins1 = alpha.expand(b, 1, n) @@ -173,10 +180,10 @@ def log_optimal_transport(scores, alpha, iters: int): def arange_like(x, dim: int): - return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 + return torch.ones(x.shape[dim], dtype=x.dtype, device=x.device).cumsum(0) - 1 -class SuperGlue(nn.Module): +class SuperGlue(torch.jit.ScriptModule): """SuperGlue feature matching middle-end Given two sets of keypoints and locations, we determine the @@ -207,27 +214,35 @@ def __init__(self, config): super().__init__() self.config = {**self.default_config, **config} + self.descriptor_dim = self.config['descriptor_dim'] + self.weights = self.config['weights'] + self.keypoint_encoder = self.config['keypoint_encoder'] + self.GNN_layers = self.config['GNN_layers'] + self.sinkhorn_iterations = self.config['sinkhorn_iterations'] + self.match_threshold = self.config['match_threshold'] + self.kenc = KeypointEncoder( - self.config['descriptor_dim'], self.config['keypoint_encoder']) + self.descriptor_dim, self.keypoint_encoder) self.gnn = AttentionalGNN( - self.config['descriptor_dim'], self.config['GNN_layers']) + self.descriptor_dim, self.GNN_layers) self.final_proj = nn.Conv1d( - self.config['descriptor_dim'], self.config['descriptor_dim'], + self.descriptor_dim, self.descriptor_dim, kernel_size=1, bias=True) bin_score = torch.nn.Parameter(torch.tensor(1.)) self.register_parameter('bin_score', bin_score) - assert self.config['weights'] in ['indoor', 'outdoor'] + assert self.weights in ['indoor', 'outdoor'] path = Path(__file__).parent - path = path / 'weights/superglue_{}.pth'.format(self.config['weights']) + path = path / 'weights/superglue_{}.pth'.format(self.weights) self.load_state_dict(torch.load(path)) print('Loaded SuperGlue model (\"{}\" weights)'.format( - self.config['weights'])) + self.weights)) - def forward(self, data): + @torch.jit.script_method + def forward(self, data: Dict[str, torch.Tensor]): """Run SuperGlue on a pair of keypoints and descriptors""" desc0, desc1 = data['descriptors0'], data['descriptors1'] kpts0, kpts1 = data['keypoints0'], data['keypoints1'] @@ -257,25 +272,25 @@ def forward(self, data): # Compute matching descriptor distance. scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) - scores = scores / self.config['descriptor_dim']**.5 + scores = scores / self.descriptor_dim**.5 # Run the optimal transport. scores = log_optimal_transport( scores, self.bin_score, - iters=self.config['sinkhorn_iterations']) + iters=self.sinkhorn_iterations) # Get the matches with score above "match_threshold". max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) indices0, indices1 = max0.indices, max1.indices mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) - zero = scores.new_tensor(0) + zero = torch.tensor(0).to(scores) mscores0 = torch.where(mutual0, max0.values.exp(), zero) mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) - valid0 = mutual0 & (mscores0 > self.config['match_threshold']) + valid0 = mutual0 & (mscores0 > self.match_threshold) valid1 = mutual1 & valid0.gather(1, indices1) - indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) - indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + indices0 = torch.where(valid0, indices0, torch.tensor(-1).to(indices0)) + indices1 = torch.where(valid1, indices1, torch.tensor(-1).to(indices1)) return { 'matches0': indices0, # use -1 for invalid match