Skip to content

Commit

Permalink
make SuperGlue compatible with TorchScript
Browse files Browse the repository at this point in the history
  • Loading branch information
valgur committed Jul 29, 2020
1 parent 2e80db8 commit 666f1e9
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 28 deletions.
5 changes: 3 additions & 2 deletions jit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from models.superpoint import SuperPoint
from models.superglue import SuperGlue
import torch

superpoint = SuperPoint({})
torch.jit.save(superpoint, 'SuperPoint.zip')
torch.jit.save(SuperPoint({}), 'SuperPoint.zip')
torch.jit.save(SuperGlue({'weights': 'outdoor'}), 'SuperGlue.zip')
67 changes: 41 additions & 26 deletions models/superglue.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

from copy import deepcopy
from pathlib import Path
from typing import List, Dict

import torch
from torch import nn

Expand All @@ -60,23 +62,23 @@ def MLP(channels: list, do_bn=True):
return nn.Sequential(*layers)


def normalize_keypoints(kpts, image_shape):
def normalize_keypoints(kpts, image_shape: List[int]):
""" Normalize keypoints locations based on image image_shape"""
_, _, height, width = image_shape
one = kpts.new_tensor(1)
size = torch.stack([one*width, one*height])[None]
size = torch.tensor([[width, height]], dtype=torch.float, device=kpts.device)
center = size / 2
scaling = size.max(1, keepdim=True).values * 0.7
return (kpts - center[:, None, :]) / scaling[:, None, :]


class KeypointEncoder(nn.Module):
class KeypointEncoder(torch.jit.ScriptModule):
""" Joint encoding of visual appearance and location using MLPs"""
def __init__(self, feature_dim, layers):
super().__init__()
self.encoder = MLP([3] + layers + [feature_dim])
nn.init.constant_(self.encoder[-1].bias, 0.0)

@torch.jit.script_method
def forward(self, kpts, scores):
inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
return self.encoder(torch.cat(inputs, dim=1))
Expand All @@ -89,16 +91,20 @@ def attention(query, key, value):
return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob


class MultiHeadedAttention(nn.Module):
class MultiHeadedAttention(torch.jit.ScriptModule):
""" Multi-head attention to increase model expressivitiy """
prob: List[torch.Tensor]

def __init__(self, num_heads: int, d_model: int):
super().__init__()
assert d_model % num_heads == 0
self.dim = d_model // num_heads
self.num_heads = num_heads
self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
self.prob = []

@torch.jit.script_method
def forward(self, query, key, value):
batch_dim = query.size(0)
query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
Expand All @@ -108,30 +114,32 @@ def forward(self, query, key, value):
return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))


class AttentionalPropagation(nn.Module):
class AttentionalPropagation(torch.jit.ScriptModule):
def __init__(self, feature_dim: int, num_heads: int):
super().__init__()
self.attn = MultiHeadedAttention(num_heads, feature_dim)
self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
nn.init.constant_(self.mlp[-1].bias, 0.0)

@torch.jit.script_method
def forward(self, x, source):
message = self.attn(x, source, source)
return self.mlp(torch.cat([x, message], dim=1))


class AttentionalGNN(nn.Module):
class AttentionalGNN(torch.jit.ScriptModule):
def __init__(self, feature_dim: int, layer_names: list):
super().__init__()
self.layers = nn.ModuleList([
AttentionalPropagation(feature_dim, 4)
for _ in range(len(layer_names))])
self.names = layer_names

@torch.jit.script_method
def forward(self, desc0, desc1):
for layer, name in zip(self.layers, self.names):
for i, layer in enumerate(self.layers):
layer.attn.prob = []
if name == 'cross':
if self.names[i] == 'cross':
src0, src1 = desc1, desc0
else: # if name == 'self':
src0, src1 = desc0, desc1
Expand All @@ -152,8 +160,7 @@ def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
def log_optimal_transport(scores, alpha, iters: int):
""" Perform Differentiable Optimal Transport in Log-space for stability"""
b, m, n = scores.shape
one = scores.new_tensor(1)
ms, ns = (m*one).to(scores), (n*one).to(scores)
ms, ns = torch.tensor(m).to(scores), torch.tensor(n).to(scores)

bins0 = alpha.expand(b, m, 1)
bins1 = alpha.expand(b, 1, n)
Expand All @@ -173,10 +180,10 @@ def log_optimal_transport(scores, alpha, iters: int):


def arange_like(x, dim: int):
return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1
return torch.ones(x.shape[dim], dtype=x.dtype, device=x.device).cumsum(0) - 1


class SuperGlue(nn.Module):
class SuperGlue(torch.jit.ScriptModule):
"""SuperGlue feature matching middle-end
Given two sets of keypoints and locations, we determine the
Expand Down Expand Up @@ -207,27 +214,35 @@ def __init__(self, config):
super().__init__()
self.config = {**self.default_config, **config}

self.descriptor_dim = self.config['descriptor_dim']
self.weights = self.config['weights']
self.keypoint_encoder = self.config['keypoint_encoder']
self.GNN_layers = self.config['GNN_layers']
self.sinkhorn_iterations = self.config['sinkhorn_iterations']
self.match_threshold = self.config['match_threshold']

self.kenc = KeypointEncoder(
self.config['descriptor_dim'], self.config['keypoint_encoder'])
self.descriptor_dim, self.keypoint_encoder)

self.gnn = AttentionalGNN(
self.config['descriptor_dim'], self.config['GNN_layers'])
self.descriptor_dim, self.GNN_layers)

self.final_proj = nn.Conv1d(
self.config['descriptor_dim'], self.config['descriptor_dim'],
self.descriptor_dim, self.descriptor_dim,
kernel_size=1, bias=True)

bin_score = torch.nn.Parameter(torch.tensor(1.))
self.register_parameter('bin_score', bin_score)

assert self.config['weights'] in ['indoor', 'outdoor']
assert self.weights in ['indoor', 'outdoor']
path = Path(__file__).parent
path = path / 'weights/superglue_{}.pth'.format(self.config['weights'])
path = path / 'weights/superglue_{}.pth'.format(self.weights)
self.load_state_dict(torch.load(path))
print('Loaded SuperGlue model (\"{}\" weights)'.format(
self.config['weights']))
self.weights))

def forward(self, data):
@torch.jit.script_method
def forward(self, data: Dict[str, torch.Tensor]):
"""Run SuperGlue on a pair of keypoints and descriptors"""
desc0, desc1 = data['descriptors0'], data['descriptors1']
kpts0, kpts1 = data['keypoints0'], data['keypoints1']
Expand Down Expand Up @@ -257,25 +272,25 @@ def forward(self, data):

# Compute matching descriptor distance.
scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
scores = scores / self.config['descriptor_dim']**.5
scores = scores / self.descriptor_dim**.5

# Run the optimal transport.
scores = log_optimal_transport(
scores, self.bin_score,
iters=self.config['sinkhorn_iterations'])
iters=self.sinkhorn_iterations)

# Get the matches with score above "match_threshold".
max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
indices0, indices1 = max0.indices, max1.indices
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
zero = scores.new_tensor(0)
zero = torch.tensor(0).to(scores)
mscores0 = torch.where(mutual0, max0.values.exp(), zero)
mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
valid0 = mutual0 & (mscores0 > self.match_threshold)
valid1 = mutual1 & valid0.gather(1, indices1)
indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
indices0 = torch.where(valid0, indices0, torch.tensor(-1).to(indices0))
indices1 = torch.where(valid1, indices1, torch.tensor(-1).to(indices1))

return {
'matches0': indices0, # use -1 for invalid match
Expand Down

0 comments on commit 666f1e9

Please sign in to comment.