import torch
import torch.nn as nn

# NOTE: the import paths below are assumptions. Matcher and boxlist_iou exist
# in maskrcnn-benchmark under these paths; BoxPairList and
# BalancedPositiveNegativePairSampler are project-specific helpers whose
# actual location depends on the surrounding codebase.
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from .bounding_box_pair import BoxPairList  # assumed project-local module
from .sampling import BalancedPositiveNegativePairSampler  # assumed project-local

class RelPN(nn.Module):
    def __init__(self, cfg, in_dim):
        super(RelPN, self).__init__()
        # The methods below rely on these attributes; the config keys follow
        # maskrcnn-benchmark conventions and are assumptions about `cfg`.
        self.proposal_pair_matcher = Matcher(
            cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD,
            cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD,
            allow_low_quality_matches=False,
        )
        self.fg_bg_pair_sampler = BalancedPositiveNegativePairSampler(
            cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE,
            cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION,
        )
        self.use_matched_pairs_only = False  # assumption: keep all pairs by default
        self.minimal_matched_pairs = 0  # assumption

    def match_targets_to_proposals(self, proposal, target):
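        # Score every ordered (subject, object) proposal pair against every
        # ordered pair of GT boxes: the pair quality is the mean of the
        # subject IoU and the object IoU, and self-pairs (subject == object)
        # are dropped on both sides.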
        match_quality_matrix = boxlist_iou(target, proposal)
        temp = []
        target_box_pairs = []
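        # For each ordered GT pair (i, j), build a (P x P) quality map over
        # all proposal pairs, flatten it, and drop the diagonal entries
        # (pairs whose subject and object are the same proposal).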
        for i in range(match_quality_matrix.shape[0]):
            for j in range(match_quality_matrix.shape[0]):
                match_i = match_quality_matrix[i].view(-1, 1)
                match_j = match_quality_matrix[j].view(1, -1)
                match_ij = (match_i + match_j) / 2
                # remove duplicate indices (subject == object pairs)
                non_duplicate_idx = (torch.eye(match_ij.shape[0]).view(-1) == 0).nonzero().view(-1).to(match_ij.device)
                match_ij = match_ij.view(-1)
                match_ij = match_ij[non_duplicate_idx]
                temp.append(match_ij)
                boxi = target.bbox[i]
                boxj = target.bbox[j]
                box_pair = torch.cat((boxi, boxj), 0)
                target_box_pairs.append(box_pair)
        match_pair_quality_matrix = torch.stack(temp, 0).view(len(temp), -1)
        target_box_pairs = torch.stack(target_box_pairs, 0)
        target_pair = BoxPairList(target_box_pairs, target.size, target.mode)
        target_pair.add_field("labels", target.get_field("pred_labels").view(-1))

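        # Build all ordered (subject, object) proposal pairs by broadcasting;
        # the result has shape (P * P, 8), one row per concatenated box pair.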
        box_subj = proposal.bbox
        box_obj = proposal.bbox
        box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
        box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
        proposal_box_pairs = torch.cat((box_subj.view(-1, 4), box_obj.view(-1, 4)), 1)

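        # Record the (subject, object) proposal indices of each pair,
        # presumably so per-box features can be looked up downstream.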
        idx_subj = torch.arange(box_subj.shape[0]).view(-1, 1, 1).repeat(1, box_obj.shape[0], 1).to(proposal.bbox.device)
        idx_obj = torch.arange(box_obj.shape[0]).view(1, -1, 1).repeat(box_subj.shape[0], 1, 1).to(proposal.bbox.device)
        proposal_idx_pairs = torch.cat((idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1)

        non_duplicate_idx = (proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero()
        proposal_box_pairs = proposal_box_pairs[non_duplicate_idx.view(-1)]
        proposal_idx_pairs = proposal_idx_pairs[non_duplicate_idx.view(-1)]
        proposal_pairs = BoxPairList(proposal_box_pairs, proposal.size, proposal.mode)
        proposal_pairs.add_field("idx_pairs", proposal_idx_pairs)

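        # Match each proposal pair to a GT pair via the pair quality matrix,
        # with the same threshold semantics as maskrcnn-benchmark's Matcher.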
        matched_idxs = self.proposal_pair_matcher(match_pair_quality_matrix)

        # Fast R-CNN only needs the "labels" field for selecting the targets,
        # i.e. the GT pair corresponding to each proposal pair.
        # NB: the indices need to be clamped because the image can contain a
        # single GT, and matched_idxs can be -2, which would go out of bounds.

        if self.use_matched_pairs_only and (matched_idxs >= 0).sum() > self.minimal_matched_pairs:
            # keep only the pairs that matched a GT pair (matched_idxs >= 0)
            proposal_pairs = proposal_pairs[matched_idxs >= 0]
            matched_idxs = matched_idxs[matched_idxs >= 0]

        matched_targets = target_pair[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets, proposal_pairs

    def prepare_targets(self, proposals, targets):
        labels = []
        proposal_pairs = []
        for proposals_per_image, targets_per_image in zip(proposals, targets):
            matched_targets, proposal_pairs_per_image = self.match_targets_to_proposals(
                proposals_per_image, targets_per_image
            )

            matched_idxs = matched_targets.get_field("matched_idxs")

            labels_per_image = matched_targets.get_field("labels")
            labels_per_image = labels_per_image.to(dtype=torch.int64)

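            # In maskrcnn-benchmark's Matcher, BELOW_LOW_THRESHOLD == -1 and
            # BETWEEN_THRESHOLDS == -2.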
            # Label background (below the low threshold)
            bg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[bg_inds] = 0

            # Label ignore proposals (between low and high thresholds)
            ignore_inds = matched_idxs == Matcher.BETWEEN_THRESHOLDS
            labels_per_image[ignore_inds] = -1  # -1 is ignored by the sampler
            labels.append(labels_per_image)
            proposal_pairs.append(proposal_pairs_per_image)
        return labels, proposal_pairs

    def forward(self, proposals, targets):
        """
        Perform the positive/negative sampling of proposal pairs and return
        the sampled pairs.
        Note: this method keeps state (self._proposal_pairs).

        Arguments:
            proposals (list[BoxList])
            targets (list[BoxList])
        """

        labels, proposal_pairs = self.prepare_targets(proposals, targets)
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_pair_sampler(labels)
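        # fg_bg_pair_sampler is assumed to behave like maskrcnn-benchmark's
        # BalancedPositiveNegativeSampler: one boolean mask per image for
        # positives and one for negatives.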

        proposal_pairs = list(proposal_pairs)
        # add the corresponding label to each proposal pair
        for labels_per_image, proposal_pairs_per_image in zip(
            labels, proposal_pairs
        ):
            proposal_pairs_per_image.add_field("labels", labels_per_image)

        # keep only the sampled (positive + negative) pairs for each image
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
            zip(sampled_pos_inds, sampled_neg_inds)
        ):
            img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
            proposal_pairs_per_image = proposal_pairs[img_idx][img_sampled_inds]
            proposal_pairs[img_idx] = proposal_pairs_per_image

        self._proposal_pairs = proposal_pairs
        return proposal_pairs

def make_relation_proposal_network(cfg, in_dim):
    return RelPN(cfg, in_dim)
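
# Example usage (a sketch; assumes a yacs-style `cfg` with maskrcnn-benchmark
# keys, and `proposals`/`targets` given as lists of BoxList; `in_dim` is the
# feature dimension expected by the surrounding head and is hypothetical here):
#
#   relpn = make_relation_proposal_network(cfg, in_dim=1024)
#   proposal_pairs = relpn(proposals, targets)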