diff --git a/demo_superglue.py b/demo_superglue.py
index 3cdbb88b..fd760761 100755
--- a/demo_superglue.py
+++ b/demo_superglue.py
@@ -153,7 +153,7 @@
     assert ret, 'Error when reading the first frame (try different --input?)'
     frame_tensor = frame2tensor(frame, device)
-    last_data = matching.superpoint({'image': frame_tensor})
+    last_data = matching.superpoint(frame_tensor)
     last_data = {k+'0': last_data[k] for k in keys}
     last_data['image0'] = frame_tensor
     last_frame = frame
diff --git a/jit.py b/jit.py
new file mode 100644
index 00000000..2f28c850
--- /dev/null
+++ b/jit.py
@@ -0,0 +1,5 @@
+from models.superpoint import SuperPoint
+import torch
+
+superpoint = SuperPoint({})
+torch.jit.save(superpoint, 'SuperPoint.zip')
diff --git a/models/matching.py b/models/matching.py
index 5d174208..13c91ea6 100644
--- a/models/matching.py
+++ b/models/matching.py
@@ -63,10 +63,10 @@ def forward(self, data):
 
         # Extract SuperPoint (keypoints, scores, descriptors) if not provided
        if 'keypoints0' not in data:
-            pred0 = self.superpoint({'image': data['image0']})
+            pred0 = self.superpoint(data['image0'])
             pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
         if 'keypoints1' not in data:
-            pred1 = self.superpoint({'image': data['image1']})
+            pred1 = self.superpoint(data['image1'])
             pred = {**pred, **{k+'1': v for k, v in pred1.items()}}
 
         # Batch all features
diff --git a/models/superpoint.py b/models/superpoint.py
index 8e411925..79b0dde5 100644
--- a/models/superpoint.py
+++ b/models/superpoint.py
@@ -44,20 +44,21 @@
 import torch
 from torch import nn
 
+
+def max_pool(x, nms_radius: int):
+    return torch.nn.functional.max_pool2d(
+        x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
+
 
 def simple_nms(scores, nms_radius: int):
     """ Fast Non-maximum suppression to remove nearby points """
     assert(nms_radius >= 0)
 
-    def max_pool(x):
-        return torch.nn.functional.max_pool2d(
-            x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
-
     zeros = torch.zeros_like(scores)
-    max_mask = scores == max_pool(scores)
+    max_mask = scores == max_pool(scores, nms_radius)
     for _ in range(2):
-        supp_mask = max_pool(max_mask.float()) > 0
+        supp_mask = max_pool(max_mask.float(), nms_radius) > 0
         supp_scores = torch.where(supp_mask, zeros, scores)
-        new_max_mask = supp_scores == max_pool(supp_scores)
+        new_max_mask = supp_scores == max_pool(supp_scores, nms_radius)
         max_mask = max_mask | (new_max_mask & (~supp_mask))
     return torch.where(max_mask, scores, zeros)
 
@@ -81,18 +82,16 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
     """ Interpolate descriptors at keypoint locations """
     b, c, h, w = descriptors.shape
     keypoints = keypoints - s / 2 + 0.5
-    keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
-                              ).to(keypoints)[None]
+    keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)]).to(keypoints).unsqueeze(0)
     keypoints = keypoints*2 - 1  # normalize to (-1, 1)
-    args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
     descriptors = torch.nn.functional.grid_sample(
-        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', align_corners=True)
     descriptors = torch.nn.functional.normalize(
-        descriptors.reshape(b, c, -1), p=2, dim=1)
+        descriptors.reshape(b, c, -1), p=2., dim=1)
     return descriptors
 
 
-class SuperPoint(nn.Module):
+class SuperPoint(torch.jit.ScriptModule):
     """SuperPoint Convolutional Detector and Descriptor
 
     SuperPoint: Self-Supervised Interest Point Detection and
@@ -112,6 +111,12 @@
     def __init__(self, config):
         super().__init__()
         self.config = {**self.default_config, **config}
+        self.descriptor_dim = self.config['descriptor_dim']
+        self.nms_radius = self.config['nms_radius']
+        self.keypoint_threshold = self.config['keypoint_threshold']
+        self.max_keypoints = self.config['max_keypoints']
+        self.remove_borders = self.config['remove_borders']
+
         self.relu = nn.ReLU(inplace=True)
         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
         c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
@@ -130,22 +135,23 @@ def __init__(self, config):
 
         self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
         self.convDb = nn.Conv2d(
-            c5, self.config['descriptor_dim'],
+            c5, self.descriptor_dim,
             kernel_size=1, stride=1, padding=0)
 
         path = Path(__file__).parent / 'weights/superpoint_v1.pth'
         self.load_state_dict(torch.load(str(path)))
 
-        mk = self.config['max_keypoints']
+        mk = self.max_keypoints
         if mk == 0 or mk < -1:
-            raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
+            raise ValueError('"max_keypoints" must be positive or "-1"')
 
         print('Loaded SuperPoint model')
 
-    def forward(self, data):
+    @torch.jit.script_method
+    def forward(self, image):
         """ Compute keypoints, scores, descriptors for image """
         # Shared Encoder
-        x = self.relu(self.conv1a(data['image']))
+        x = self.relu(self.conv1a(image))
         x = self.relu(self.conv1b(x))
         x = self.pool(x)
         x = self.relu(self.conv2a(x))
@@ -164,39 +170,39 @@ def forward(self, data):
         b, _, h, w = scores.shape
         scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
         scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
-        scores = simple_nms(scores, self.config['nms_radius'])
-
-        # Extract keypoints
-        keypoints = [
-            torch.nonzero(s > self.config['keypoint_threshold'])
-            for s in scores]
-        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
-
-        # Discard keypoints near the image borders
-        keypoints, scores = list(zip(*[
-            remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
-            for k, s in zip(keypoints, scores)]))
-
-        # Keep the k keypoints with highest score
-        if self.config['max_keypoints'] >= 0:
-            keypoints, scores = list(zip(*[
-                top_k_keypoints(k, s, self.config['max_keypoints'])
-                for k, s in zip(keypoints, scores)]))
-
-        # Convert (h, w) to (x, y)
-        keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+        scores = simple_nms(scores, self.nms_radius)
 
         # Compute the dense descriptors
         cDa = self.relu(self.convDa(x))
         descriptors = self.convDb(cDa)
-        descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
-
-        # Extract descriptors
-        descriptors = [sample_descriptors(k[None], d[None], 8)[0]
-                       for k, d in zip(keypoints, descriptors)]
+        descriptors = torch.nn.functional.normalize(descriptors, p=2., dim=1)
+
+        keypoints = []
+        scores_out = []
+        descriptors_out = []
+        for i in range(b):
+            # Extract keypoints
+            s = scores[i]
+            k = torch.nonzero(s > self.keypoint_threshold)
+            s = s[s > self.keypoint_threshold]
+
+            # Discard keypoints near the image borders
+            k, s = remove_borders(k, s, self.remove_borders, h*8, w*8)
+
+            # Keep the k keypoints with highest score
+            if self.max_keypoints >= 0:
+                k, s = top_k_keypoints(k, s, self.max_keypoints)
+
+            # Convert (h, w) to (x, y)
+            k = torch.flip(k, [1]).float()
+
+            # Extract descriptors
+            descriptors_out.append(sample_descriptors(k.unsqueeze(0), descriptors[i].unsqueeze(0), 8)[0])
+            keypoints.append(k)
+            scores_out.append(s)
 
         return {
             'keypoints': keypoints,
-            'scores': scores,
-            'descriptors': descriptors,
+            'scores': scores_out,
+            'descriptors': descriptors_out,
        }
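
A minimal usage sketch (not part of the diff above): after running jit.py from the repository root so that SuperPoint.zip exists, the exported TorchScript module can be reloaded with torch.jit.load and called directly on an image tensor, without importing models/superpoint.py. The file name SuperPoint.zip comes from jit.py; the load_superpoint.py name, the (1, 1, H, W) grayscale input with values in [0, 1], and the dummy image size are assumptions for illustration only.

# load_superpoint.py -- illustrative sketch, assumes SuperPoint.zip was produced by jit.py
import torch

# Reload the serialized ScriptModule; the Python class definition is not needed.
superpoint = torch.jit.load('SuperPoint.zip')

# Dummy grayscale batch of shape (1, 1, H, W), values in [0, 1] (assumed input format).
image = torch.rand(1, 1, 480, 640)

with torch.no_grad():
    pred = superpoint(image)

# The scripted forward returns a dict of per-image lists, as in the diff above:
# keypoints[i] is (N, 2), scores[i] is (N,), descriptors[i] is (descriptor_dim, N).
print(pred['keypoints'][0].shape, pred['descriptors'][0].shape)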