Skip to content

Commit 966afee

Browse files
committed
fix a minor typo bug in baseline
1 parent e6c8a89 commit 966afee

File tree

7 files changed

+58
-45
lines changed

7 files changed

+58
-45
lines changed

configs/sgg_res101_step.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ MODEL:
2828
GRCNN_FEATURE_UPDATE_STEP: 2
2929
SOLVER:
3030
BASE_LR: 5e-3
31-
MAX_ITER: 40000
32-
STEPS: (20000,30000)
31+
MAX_ITER: 15000
32+
STEPS: (8000,12000)
3333
CHECKPOINT_PERIOD: 1000

lib/config/defaults.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
""""======================================="""
3838
_C.MODEL = CN()
3939
_C.MODEL.ALGORITHM = "sg_baseline"
40+
_C.MODEL.USE_RELPN = False
4041
_C.MODEL.USE_FREQ_PRIOR = False
41-
4242
_C.MODEL.RPN_ONLY = False
4343
_C.MODEL.MASK_ON = False
4444
_C.MODEL.RETINANET_ON = False

lib/scene_parser/rcnn/modeling/relation_heads/baseline/baseline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self, cfg, in_channels):
1414
super(Baseline, self).__init__()
1515
self.cfg = cfg
1616
self.pred_feature_extractor = make_roi_relation_feature_extractor(cfg, in_channels)
17-
self.predictor = make_roi_relation_predictor(cfg, self.feature_extractor.out_channels)
17+
self.predictor = make_roi_relation_predictor(cfg, self.pred_feature_extractor.out_channels)
1818

1919
def forward(self, features, proposals, proposal_pairs):
2020
obj_class_logits = None # no need to predict object class again

lib/scene_parser/rcnn/modeling/relation_heads/grcnn/agcn/agcn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ class _GraphConvolutionLayer_Collect(nn.Module):
5555
def __init__(self, dim_obj, dim_rel):
5656
super(_GraphConvolutionLayer_Collect, self).__init__()
5757
self.collect_units = nn.ModuleList()
58-
# self.collect_units.append(_Collection_Unit(dim_obj, dim_obj)) # obj from obj
5958
self.collect_units.append(_Collection_Unit(dim_rel, dim_obj)) # obj (subject) from rel
6059
self.collect_units.append(_Collection_Unit(dim_rel, dim_obj)) # obj (object) from rel
6160
self.collect_units.append(_Collection_Unit(dim_obj, dim_rel)) # rel from obj (subject)
6261
self.collect_units.append(_Collection_Unit(dim_obj, dim_rel)) # rel from obj (object)
62+
self.collect_units.append(_Collection_Unit(dim_obj, dim_obj)) # obj from obj
6363

6464
def forward(self, target, source, attention, unit_id):
6565
collection = self.collect_units[unit_id](target, source, attention)

lib/scene_parser/rcnn/modeling/relation_heads/grcnn/grcnn.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Scene Graph Generation by Iterative Message Passing
1+
# Graph R-CNN for scene graph generation
22
# Reimnplemetned by Jianwei Yang ([email protected])
33

44
import numpy as np
@@ -17,7 +17,7 @@ class GRCNN(nn.Module):
1717
def __init__(self, cfg, in_channels):
1818
super(GRCNN, self).__init__()
1919
self.cfg = cfg
20-
self.dim = 512
20+
self.dim = 1024
2121
self.update_step = cfg.MODEL.ROI_RELATION_HEAD.GRCNN_FEATURE_UPDATE_STEP
2222
num_classes_obj = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES
2323
num_classes_pred = cfg.MODEL.ROI_RELATION_HEAD.NUM_CLASSES
@@ -42,30 +42,35 @@ def __init__(self, cfg, in_channels):
4242
self.gcn_collect_score = _GraphConvolutionLayer_Collect(num_classes_obj, num_classes_pred)
4343
self.gcn_update_score = _GraphConvolutionLayer_Update(num_classes_obj, num_classes_pred)
4444

45-
self.obj_predictor = make_roi_relation_box_predictor(cfg, 512)
46-
self.pred_predictor = make_roi_relation_predictor(cfg, 512)
45+
self.obj_predictor = make_roi_relation_box_predictor(cfg, self.dim)
46+
self.pred_predictor = make_roi_relation_predictor(cfg, self.dim)
4747

4848
def _get_map_idxs(self, proposals, proposal_pairs):
4949
rel_inds = []
5050
offset = 0
51+
obj_num = sum([len(proposal) for proposal in proposals])
52+
obj_obj_map = torch.FloatTensor(obj_num, obj_num).fill_(0)
5153
for proposal, proposal_pair in zip(proposals, proposal_pairs):
5254
rel_ind_i = proposal_pair.get_field("idx_pairs").detach()
55+
obj_obj_map_i = (1 - torch.eye(len(proposal))).float()
56+
obj_obj_map[offset:offset + len(proposal), offset:offset + len(proposal)] = obj_obj_map_i
5357
rel_ind_i += offset
5458
offset += len(proposal)
5559
rel_inds.append(rel_ind_i)
5660

5761
rel_inds = torch.cat(rel_inds, 0)
5862

59-
subj_pred_map = rel_inds.new(sum([len(proposal) for proposal in proposals]), rel_inds.shape[0]).fill_(0).float().detach()
60-
obj_pred_map = rel_inds.new(sum([len(proposal) for proposal in proposals]), rel_inds.shape[0]).fill_(0).float().detach()
63+
subj_pred_map = rel_inds.new(obj_num, rel_inds.shape[0]).fill_(0).float().detach()
64+
obj_pred_map = rel_inds.new(obj_num, rel_inds.shape[0]).fill_(0).float().detach()
6165

6266
subj_pred_map.scatter_(0, (rel_inds[:, 0].contiguous().view(1, -1)), 1)
6367
obj_pred_map.scatter_(0, (rel_inds[:, 1].contiguous().view(1, -1)), 1)
68+
obj_obj_map = obj_obj_map.type_as(obj_pred_map)
6469

65-
return rel_inds, subj_pred_map, obj_pred_map
70+
return rel_inds, obj_obj_map, subj_pred_map, obj_pred_map
6671

6772
def forward(self, features, proposals, proposal_pairs):
68-
rel_inds, subj_pred_map, obj_pred_map = self._get_map_idxs(proposals, proposal_pairs)
73+
rel_inds, obj_obj_map, subj_pred_map, obj_pred_map = self._get_map_idxs(proposals, proposal_pairs)
6974
x_obj = torch.cat([proposal.get_field("features").detach() for proposal in proposals], 0)
7075
obj_class_logits = torch.cat([proposal.get_field("logits").detach() for proposal in proposals], 0)
7176
# x_obj = self.avgpool(self.obj_feature_extractor(features, proposals))
@@ -78,9 +83,12 @@ def forward(self, features, proposals, proposal_pairs):
7883
pred_feats = [x_pred]
7984

8085
for t in range(self.update_step):
86+
# message from other objects
87+
source_obj = self.gcn_collect_feat(obj_feats[t], obj_feats[t], obj_obj_map, 4)
88+
8189
source_rel_sub = self.gcn_collect_feat(obj_feats[t], pred_feats[t], subj_pred_map, 0)
8290
source_rel_obj = self.gcn_collect_feat(obj_feats[t], pred_feats[t], obj_pred_map, 1)
83-
source2obj_all = (source_rel_sub + source_rel_obj) / 2
91+
source2obj_all = (source_obj + source_rel_sub + source_rel_obj) / 3
8492
obj_feats.append(self.gcn_update_feat(obj_feats[t], source2obj_all, 0))
8593

8694
'''update predicate logits'''
@@ -100,12 +108,12 @@ def forward(self, features, proposals, proposal_pairs):
100108
for t in range(self.update_step):
101109
'''update object logits'''
102110
# message from other objects
103-
# source_obj = self.gcn_collect(obj_class_logits, obj_class_logits, map_obj_obj, cfg.COLLECT_OBJ_FROM_OBJ)
111+
source_obj = self.gcn_collect_score(obj_scores[t], obj_scores[t], obj_obj_map, 4)
104112

105113
#essage from predicate
106114
source_rel_sub = self.gcn_collect_score(obj_scores[t], pred_scores[t], subj_pred_map, 0)
107115
source_rel_obj = self.gcn_collect_score(obj_scores[t], pred_scores[t], obj_pred_map, 1)
108-
source2obj_all = (source_rel_sub + source_rel_obj) / 2
116+
source2obj_all = (source_obj + source_rel_sub + source_rel_obj) / 3
109117
obj_scores.append(self.gcn_update_score(obj_scores[t], source2obj_all, 0))
110118

111119
'''update predicate logits'''

lib/scene_parser/rcnn/modeling/relation_heads/relation_heads.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
import numpy as np
55
import torch
66
from torch import nn
7-
8-
# from .roi_relation_feature_extractors import make_roi_relation_feature_extractor
9-
# from .roi_relation_predictors import make_roi_relation_predictor
7+
from lib.scene_parser.rcnn.structures.bounding_box_pair import BoxPairList
108
from .inference import make_roi_relation_post_processor
119
from .loss import make_roi_relation_loss_evaluator
12-
from lib.scene_parser.rcnn.structures.bounding_box_pair import BoxPairList
10+
from .relpn.relpn import make_relation_proposal_network
1311

1412
from .baseline.baseline import build_baseline_model
1513
from .imp.imp import build_imp_model
@@ -40,6 +38,9 @@ def __init__(self, cfg, in_channels):
4038
self.post_processor = make_roi_relation_post_processor(cfg)
4139
self.loss_evaluator = make_roi_relation_loss_evaluator(cfg)
4240

41+
if self.cfg.MODEL.USE_RELPN:
42+
self.relpn = make_relation_proposal_network(cfg)
43+
4344
self.freq_dist = None
4445
if self.cfg.MODEL.USE_FREQ_PRIOR:
4546
self.freq_dist = torch.from_numpy(np.load("freq_prior.npy"))
@@ -85,11 +86,16 @@ def forward(self, features, proposals, targets=None):
8586
if self.training:
8687
# Faster R-CNN subsamples during training the proposals with a fixed
8788
# positive / negative ratio
88-
with torch.no_grad():
89-
proposal_pairs = self.loss_evaluator.subsample(proposals, targets)
89+
if self.cfg.MODEL.USE_RELPN:
90+
proposal_pairs = self.relpn(proposals, targets)
91+
else:
92+
with torch.no_grad():
93+
proposal_pairs = self.loss_evaluator.subsample(proposals, targets)
9094
else:
91-
# proposals = [proposal[:32] for proposal in proposals]
92-
proposal_pairs = self._get_proposal_pairs(proposals)
95+
if self.cfg.MODEL.USE_RELPN:
96+
proposal_pairs = self.relpn(proposals, targets)
97+
else:
98+
proposal_pairs = self._get_proposal_pairs(proposals)
9399

94100
if self.cfg.MODEL.USE_FREQ_PRIOR:
95101
"""

lib/scene_parser/rcnn/modeling/relation_heads/reldn/reldn.py

+20-21
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,20 @@ def __init__(self, cfg, in_channels, eps=1e-10):
4040
nn.Linear(self.dim, self.dim),
4141
)
4242

43-
# self.rel_embedding = nn.Sequential(
44-
# nn.Linear(3 * self.dim, self.dim),
45-
# nn.ReLU(True),
46-
# nn.Linear(self.dim, self.dim),
47-
# nn.ReLU(True)
48-
# )
43+
self.rel_embedding = nn.Sequential(
44+
nn.Linear(3 * self.dim, self.dim),
45+
nn.ReLU(True),
46+
nn.Linear(self.dim, self.dim),
47+
nn.ReLU(True)
48+
)
4949

50-
# self.rel_spatial_feat = build_spatial_feature(cfg, self.dim)
50+
self.rel_spatial_feat = build_spatial_feature(cfg, self.dim)
5151

52-
# self.rel_subj_predictor = make_roi_relation_predictor(cfg, 512)
53-
# self.rel_obj_predictor = make_roi_relation_predictor(cfg, 512)
52+
self.rel_subj_predictor = make_roi_relation_predictor(cfg, 512)
53+
self.rel_obj_predictor = make_roi_relation_predictor(cfg, 512)
5454
self.rel_pred_predictor = make_roi_relation_predictor(cfg, 512)
5555

56-
# self.rel_spt_predictor = nn.Linear(64, num_classes)
56+
self.rel_spt_predictor = nn.Linear(64, num_classes)
5757

5858

5959
self.freq_dist = torch.from_numpy(np.load("freq_prior.npy"))
@@ -91,25 +91,24 @@ def forward(self, features, proposals, proposal_pairs):
9191
# x_obj = self.avgpool(self.obj_feature_extractor(features, proposals))
9292
x_pred = self.avgpool(self.pred_feature_extractor(features, proposal_pairs))
9393
x_obj = x_obj.view(x_obj.size(0), -1); x_pred = x_pred.view(x_pred.size(0), -1)
94-
x_obj = self.obj_embedding(x_obj);
95-
x_pred = self.pred_embedding(x_pred)
94+
x_obj = self.obj_embedding(x_obj); x_pred = self.pred_embedding(x_pred)
9695

9796
sub_vert = x_obj[rel_inds[:, 0]] #
9897
obj_vert = x_obj[rel_inds[:, 1]]
9998

10099
'''compute visual scores'''
101-
# rel_subj_class_logits = self.rel_subj_predictor(sub_vert.unsqueeze(2).unsqueeze(3))
102-
# rel_obj_class_logits = self.rel_obj_predictor(obj_vert.unsqueeze(2).unsqueeze(3))
100+
rel_subj_class_logits = self.rel_subj_predictor(sub_vert.unsqueeze(2).unsqueeze(3))
101+
rel_obj_class_logits = self.rel_obj_predictor(obj_vert.unsqueeze(2).unsqueeze(3))
103102

104-
x_rel = x_pred + sub_vert + obj_vert # torch.cat([sub_vert, obj_vert, x_pred], 1)
105-
# x_rel = self.rel_embedding(x_rel)
103+
x_rel = torch.cat([sub_vert, obj_vert, x_pred], 1)
104+
x_rel = self.rel_embedding(x_rel)
106105
rel_pred_class_logits = self.rel_pred_predictor(x_rel.unsqueeze(2).unsqueeze(3))
107-
# rel_vis_class_logits = rel_pred_class_logits + rel_subj_class_logits + rel_obj_class_logits
108-
rel_vis_class_logits = rel_pred_class_logits # + rel_subj_class_logits + rel_obj_class_logits
106+
rel_vis_class_logits = rel_pred_class_logits + rel_subj_class_logits + rel_obj_class_logits
107+
# rel_vis_class_logits = rel_pred_class_logits # + rel_subj_class_logits + rel_obj_class_logits
109108

110109
'''compute spatial scores'''
111-
# edge_spt_feats = self.rel_spatial_feat(proposal_pairs)
112-
# rel_spt_class_logits = self.rel_spt_predictor(edge_spt_feats)
110+
edge_spt_feats = self.rel_spatial_feat(proposal_pairs)
111+
rel_spt_class_logits = self.rel_spt_predictor(edge_spt_feats)
113112

114113
'''compute semantic scores'''
115114
rel_sem_class_logits = []
@@ -129,7 +128,7 @@ def forward(self, features, proposals, proposal_pairs):
129128
rel_sem_class_logits.append(class_logits_per_image)
130129
rel_sem_class_logits = torch.cat(rel_sem_class_logits, 0)
131130

132-
rel_class_logits = rel_vis_class_logits + rel_sem_class_logits # + rel_spt_class_logits #
131+
rel_class_logits = rel_vis_class_logits + rel_sem_class_logits + rel_spt_class_logits #
133132
return (x_obj, x_pred), obj_class_logits, rel_class_logits
134133

135134
def build_reldn_model(cfg, in_channels):

0 commit comments

Comments
 (0)