1
- # Scene Graph Generation by Iterative Message Passing
1
+ # Graph R-CNN for scene graph generation
2
2
# Reimnplemetned by Jianwei Yang ([email protected] )
3
3
4
4
import numpy as np
@@ -17,7 +17,7 @@ class GRCNN(nn.Module):
17
17
def __init__ (self , cfg , in_channels ):
18
18
super (GRCNN , self ).__init__ ()
19
19
self .cfg = cfg
20
- self .dim = 512
20
+ self .dim = 1024
21
21
self .update_step = cfg .MODEL .ROI_RELATION_HEAD .GRCNN_FEATURE_UPDATE_STEP
22
22
num_classes_obj = cfg .MODEL .ROI_BOX_HEAD .NUM_CLASSES
23
23
num_classes_pred = cfg .MODEL .ROI_RELATION_HEAD .NUM_CLASSES
@@ -42,30 +42,35 @@ def __init__(self, cfg, in_channels):
42
42
self .gcn_collect_score = _GraphConvolutionLayer_Collect (num_classes_obj , num_classes_pred )
43
43
self .gcn_update_score = _GraphConvolutionLayer_Update (num_classes_obj , num_classes_pred )
44
44
45
- self .obj_predictor = make_roi_relation_box_predictor (cfg , 512 )
46
- self .pred_predictor = make_roi_relation_predictor (cfg , 512 )
45
+ self .obj_predictor = make_roi_relation_box_predictor (cfg , self . dim )
46
+ self .pred_predictor = make_roi_relation_predictor (cfg , self . dim )
47
47
48
48
def _get_map_idxs (self , proposals , proposal_pairs ):
49
49
rel_inds = []
50
50
offset = 0
51
+ obj_num = sum ([len (proposal ) for proposal in proposals ])
52
+ obj_obj_map = torch .FloatTensor (obj_num , obj_num ).fill_ (0 )
51
53
for proposal , proposal_pair in zip (proposals , proposal_pairs ):
52
54
rel_ind_i = proposal_pair .get_field ("idx_pairs" ).detach ()
55
+ obj_obj_map_i = (1 - torch .eye (len (proposal ))).float ()
56
+ obj_obj_map [offset :offset + len (proposal ), offset :offset + len (proposal )] = obj_obj_map_i
53
57
rel_ind_i += offset
54
58
offset += len (proposal )
55
59
rel_inds .append (rel_ind_i )
56
60
57
61
rel_inds = torch .cat (rel_inds , 0 )
58
62
59
- subj_pred_map = rel_inds .new (sum ([ len ( proposal ) for proposal in proposals ]) , rel_inds .shape [0 ]).fill_ (0 ).float ().detach ()
60
- obj_pred_map = rel_inds .new (sum ([ len ( proposal ) for proposal in proposals ]) , rel_inds .shape [0 ]).fill_ (0 ).float ().detach ()
63
+ subj_pred_map = rel_inds .new (obj_num , rel_inds .shape [0 ]).fill_ (0 ).float ().detach ()
64
+ obj_pred_map = rel_inds .new (obj_num , rel_inds .shape [0 ]).fill_ (0 ).float ().detach ()
61
65
62
66
subj_pred_map .scatter_ (0 , (rel_inds [:, 0 ].contiguous ().view (1 , - 1 )), 1 )
63
67
obj_pred_map .scatter_ (0 , (rel_inds [:, 1 ].contiguous ().view (1 , - 1 )), 1 )
68
+ obj_obj_map = obj_obj_map .type_as (obj_pred_map )
64
69
65
- return rel_inds , subj_pred_map , obj_pred_map
70
+ return rel_inds , obj_obj_map , subj_pred_map , obj_pred_map
66
71
67
72
def forward (self , features , proposals , proposal_pairs ):
68
- rel_inds , subj_pred_map , obj_pred_map = self ._get_map_idxs (proposals , proposal_pairs )
73
+ rel_inds , obj_obj_map , subj_pred_map , obj_pred_map = self ._get_map_idxs (proposals , proposal_pairs )
69
74
x_obj = torch .cat ([proposal .get_field ("features" ).detach () for proposal in proposals ], 0 )
70
75
obj_class_logits = torch .cat ([proposal .get_field ("logits" ).detach () for proposal in proposals ], 0 )
71
76
# x_obj = self.avgpool(self.obj_feature_extractor(features, proposals))
@@ -78,9 +83,12 @@ def forward(self, features, proposals, proposal_pairs):
78
83
pred_feats = [x_pred ]
79
84
80
85
for t in range (self .update_step ):
86
+ # message from other objects
87
+ source_obj = self .gcn_collect_feat (obj_feats [t ], obj_feats [t ], obj_obj_map , 4 )
88
+
81
89
source_rel_sub = self .gcn_collect_feat (obj_feats [t ], pred_feats [t ], subj_pred_map , 0 )
82
90
source_rel_obj = self .gcn_collect_feat (obj_feats [t ], pred_feats [t ], obj_pred_map , 1 )
83
- source2obj_all = (source_rel_sub + source_rel_obj ) / 2
91
+ source2obj_all = (source_obj + source_rel_sub + source_rel_obj ) / 3
84
92
obj_feats .append (self .gcn_update_feat (obj_feats [t ], source2obj_all , 0 ))
85
93
86
94
'''update predicate logits'''
@@ -100,12 +108,12 @@ def forward(self, features, proposals, proposal_pairs):
100
108
for t in range (self .update_step ):
101
109
'''update object logits'''
102
110
# message from other objects
103
- # source_obj = self.gcn_collect(obj_class_logits, obj_class_logits, map_obj_obj, cfg.COLLECT_OBJ_FROM_OBJ )
111
+ source_obj = self .gcn_collect_score ( obj_scores [ t ], obj_scores [ t ], obj_obj_map , 4 )
104
112
105
113
#essage from predicate
106
114
source_rel_sub = self .gcn_collect_score (obj_scores [t ], pred_scores [t ], subj_pred_map , 0 )
107
115
source_rel_obj = self .gcn_collect_score (obj_scores [t ], pred_scores [t ], obj_pred_map , 1 )
108
- source2obj_all = (source_rel_sub + source_rel_obj ) / 2
116
+ source2obj_all = (source_obj + source_rel_sub + source_rel_obj ) / 3
109
117
obj_scores .append (self .gcn_update_score (obj_scores [t ], source2obj_all , 0 ))
110
118
111
119
'''update predicate logits'''
0 commit comments