Skip to content

Commit 0a32020

Browse files
committed
add dummy relpn
1 parent 5841929 commit 0a32020

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

lib/scene_parser/rcnn/modeling/relation_heads/relpn/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class RelPN(nn.Module):
    """Relation proposal network: pairs up object proposals and samples them.

    For every image, all ordered (subject, object) pairs of distinct
    proposals are built, matched against ordered pairs of ground-truth boxes,
    labelled, and then sub-sampled into positives/negatives.

    NOTE(review): the methods below read ``self.proposal_pair_matcher``,
    ``self.fg_bg_pair_sampler``, ``self.use_matched_pairs_only`` and
    ``self.minimal_matched_pairs``, none of which is assigned in
    ``__init__``. They also use ``boxlist_iou``, ``BoxPairList`` and
    ``Matcher``, which are not imported in the visible part of this file.
    These must be provided elsewhere before ``forward`` can run -- confirm.
    """

    def __init__(self, cfg, in_dim=None):
        """Create the module.

        Arguments:
            cfg: configuration object (currently unused here).
            in_dim: input feature dimension (currently unused here). It is
                optional so that ``RelPN(cfg)`` is a valid call.
        """
        super(RelPN, self).__init__()

    def match_targets_to_proposals(self, proposal, target):
        """Match every ordered proposal pair to an ordered ground-truth pair.

        Arguments:
            proposal: boxes of one image; ``proposal.bbox`` holds 4 values
                per box (presumably a BoxList -- confirm).
            target: ground-truth boxes of the same image, carrying a
                ``"pred_labels"`` field that is flattened into one label per
                ordered target pair.

        Returns:
            (matched_targets, proposal_pairs): the target pair selected for
            each kept proposal pair (with a ``"matched_idxs"`` field), and
            the proposal pairs themselves (with an ``"idx_pairs"`` field
            holding the subject/object proposal indices).
        """
        match_quality_matrix = boxlist_iou(target, proposal)
        num_targets = match_quality_matrix.shape[0]
        num_proposals = match_quality_matrix.shape[1]

        # Flat indices of the off-diagonal entries of a
        # (num_proposals x num_proposals) grid, i.e. every proposal pair
        # whose subject and object differ. It depends only on the number of
        # proposals, so it is computed once instead of once per target pair.
        non_duplicate_idx = (
            (torch.eye(num_proposals).view(-1) == 0)
            .nonzero()
            .view(-1)
            .to(match_quality_matrix.device)
        )

        pair_qualities = []
        target_box_pairs = []
        for i in range(num_targets):
            for j in range(num_targets):
                # Quality of (proposal s, proposal o) against (target i,
                # target j) is the mean of the two per-box match qualities.
                match_i = match_quality_matrix[i].view(-1, 1)
                match_j = match_quality_matrix[j].view(1, -1)
                match_ij = ((match_i + match_j) / 2).view(-1)
                # remove duplicate index (subject == object)
                pair_qualities.append(match_ij[non_duplicate_idx])
                box_pair = torch.cat((target.bbox[i], target.bbox[j]), 0)
                target_box_pairs.append(box_pair)

        # NOTE(review): torch.stack raises on an empty list, so an image
        # without any target box is not handled here -- confirm callers
        # never pass one.
        match_pair_quality_matrix = torch.stack(pair_qualities, 0).view(
            len(pair_qualities), -1
        )
        target_box_pairs = torch.stack(target_box_pairs, 0)
        target_pair = BoxPairList(target_box_pairs, target.size, target.mode)
        target_pair.add_field("labels", target.get_field("pred_labels").view(-1))

        # Build all (subject, object) proposal box pairs as an N x N grid.
        box_subj = proposal.bbox
        box_obj = proposal.bbox
        box_subj = box_subj.unsqueeze(1).repeat(1, box_subj.shape[0], 1)
        box_obj = box_obj.unsqueeze(0).repeat(box_obj.shape[0], 1, 1)
        proposal_box_pairs = torch.cat(
            (box_subj.view(-1, 4), box_obj.view(-1, 4)), 1
        )

        # Same grid for the proposal indices, so that every pair remembers
        # which two proposals it was built from.
        idx_subj = (
            torch.arange(box_subj.shape[0])
            .view(-1, 1, 1)
            .repeat(1, box_obj.shape[0], 1)
            .to(proposal.bbox.device)
        )
        idx_obj = (
            torch.arange(box_obj.shape[0])
            .view(1, -1, 1)
            .repeat(box_subj.shape[0], 1, 1)
            .to(proposal.bbox.device)
        )
        proposal_idx_pairs = torch.cat(
            (idx_subj.view(-1, 1), idx_obj.view(-1, 1)), 1
        )

        # Drop the pairs that use the same proposal as subject and object.
        keep = (proposal_idx_pairs[:, 0] != proposal_idx_pairs[:, 1]).nonzero()
        proposal_box_pairs = proposal_box_pairs[keep.view(-1)]
        proposal_idx_pairs = proposal_idx_pairs[keep.view(-1)]
        proposal_pairs = BoxPairList(
            proposal_box_pairs, proposal.size, proposal.mode
        )
        proposal_pairs.add_field("idx_pairs", proposal_idx_pairs)

        matched_idxs = self.proposal_pair_matcher(match_pair_quality_matrix)

        if (
            self.use_matched_pairs_only
            and (matched_idxs >= 0).sum() > self.minimal_matched_pairs
        ):
            # filter all matched_idxs < 0
            matched_mask = matched_idxs >= 0
            proposal_pairs = proposal_pairs[matched_mask]
            matched_idxs = matched_idxs[matched_mask]

        # get the targets corresponding GT for each proposal
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        matched_targets = target_pair[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets, proposal_pairs

    def prepare_targets(self, proposals, targets):
        """Compute per-image relation labels for all proposal pairs.

        Arguments:
            proposals: per-image proposals (iterated together with targets).
            targets: per-image ground truth.

        Returns:
            (labels, proposal_pairs): two lists with one entry per image.
            Labels are int64; 0 marks background and -1 marks pairs to be
            ignored.
        """
        labels = []
        proposal_pairs = []
        for proposals_per_image, targets_per_image in zip(proposals, targets):
            matched_targets, proposal_pairs_per_image = (
                self.match_targets_to_proposals(
                    proposals_per_image, targets_per_image
                )
            )

            matched_idxs = matched_targets.get_field("matched_idxs")

            labels_per_image = matched_targets.get_field("labels")
            labels_per_image = labels_per_image.to(dtype=torch.int64)

            # Label background (below the low threshold)
            bg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[bg_inds] = 0

            # Label ignore proposals (between low and high thresholds)
            ignore_inds = matched_idxs == Matcher.BETWEEN_THRESHOLDS
            labels_per_image[ignore_inds] = -1  # -1 is ignored by sampler

            labels.append(labels_per_image)
            proposal_pairs.append(proposal_pairs_per_image)

        return labels, proposal_pairs

    def forward(self, proposals, targets):
        """
        This method performs the positive/negative sampling, and return
        the sampled proposal pairs.
        Note: this function keeps a state (the sampled pairs are stored in
        ``self._proposal_pairs``).

        Arguments:
            proposals (list[BoxList])
            targets (list[BoxList])

        Returns:
            list with, for each image, the sampled proposal pairs carrying
            a "labels" field.
        """
        labels, proposal_pairs = self.prepare_targets(proposals, targets)
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_pair_sampler(labels)

        proposal_pairs = list(proposal_pairs)
        # add the corresponding label information to the box pairs
        for labels_per_image, proposal_pairs_per_image in zip(
            labels, proposal_pairs
        ):
            proposal_pairs_per_image.add_field("labels", labels_per_image)

        # keep, for each image, only the pairs picked by the sampler
        # (either as positive or as negative)
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
            zip(sampled_pos_inds, sampled_neg_inds)
        ):
            img_sampled_inds = torch.nonzero(
                pos_inds_img | neg_inds_img
            ).squeeze(1)
            proposal_pairs[img_idx] = proposal_pairs[img_idx][img_sampled_inds]

        self._proposal_pairs = proposal_pairs
        return proposal_pairs
137+
138+
def make_relation_proposal_network(cfg, in_dim=None):
    """Build the relation proposal network.

    Arguments:
        cfg: configuration object forwarded to ``RelPN``.
        in_dim: optional input feature dimension forwarded to ``RelPN``.
            ``RelPN.__init__`` takes ``(cfg, in_dim)``, so the argument has
            to be passed explicitly; calling ``RelPN(cfg)`` alone raises a
            ``TypeError``.

    Returns:
        A ``RelPN`` instance.
    """
    return RelPN(cfg, in_dim)

0 commit comments

Comments
 (0)