3 files changed, +26 −2 lines changed

luminoth/models/fasterrcnn/rcnn.py
@@ -5,6 +5,9 @@
 from luminoth.models.fasterrcnn.rcnn_target import RCNNTarget
 from luminoth.models.fasterrcnn.roi_pool import ROIPoolingLayer
 from luminoth.utils.losses import smooth_l1_loss
+from luminoth.utils.safe_wrappers import (
+    safe_softmax_cross_entropy_with_logits
+)
 from luminoth.utils.vars import (
     get_initializer, layer_summaries, variable_summaries,
     get_activation_function
@@ -304,7 +307,7 @@ def loss(self, prediction_dict):

             # We get cross entropy loss of each proposal.
             cross_entropy_per_proposal = (
-                tf.nn.softmax_cross_entropy_with_logits(
+                safe_softmax_cross_entropy_with_logits(
                     labels=cls_target_one_hot, logits=cls_score_labeled
                 )
             )
luminoth/models/fasterrcnn/rpn.py
@@ -10,6 +10,9 @@
 from .rpn_target import RPNTarget
 from .rpn_proposal import RPNProposal
 from luminoth.utils.losses import smooth_l1_loss
+from luminoth.utils.safe_wrappers import (
+    safe_softmax_cross_entropy_with_logits
+)
 from luminoth.utils.vars import (
     get_initializer, layer_summaries, variable_summaries,
     get_activation_function
@@ -257,7 +260,7 @@ def loss(self, prediction_dict):
             cls_target = tf.one_hot(labels, depth=2)

             # Equivalent to log loss
-            ce_per_anchor = tf.nn.softmax_cross_entropy_with_logits(
+            ce_per_anchor = safe_softmax_cross_entropy_with_logits(
                 labels=cls_target, logits=cls_score
             )
             prediction_dict['cross_entropy_per_anchor'] = ce_per_anchor
luminoth/utils/safe_wrappers.py (new file)
@@ -0,0 +1,18 @@
+import tensorflow as tf
+
+
+def safe_softmax_cross_entropy_with_logits(
+        labels, logits, name='safe_cross_entropy'):
+    with tf.name_scope(name):
+        safety_condition = tf.logical_and(
+            tf.greater(tf.shape(labels)[0], 0, name='labels_notzero'),
+            tf.greater(tf.shape(logits)[0], 0, name='logits_notzero'),
+            name='safety_condition'
+        )
+        return tf.cond(
+            safety_condition,
+            true_fn=lambda: tf.nn.softmax_cross_entropy_with_logits(
+                labels=labels, logits=logits
+            ),
+            false_fn=lambda: tf.constant([], dtype=logits.dtype)
+        )
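
For reference, here is a minimal usage sketch (not part of the diff) showing how the wrapper behaves on empty and non-empty batches. It assumes TensorFlow 1.x graph mode and the import path added above; the placeholder shapes and example values are illustrative only.

# Illustrative sketch, not part of the pull request; assumes TensorFlow 1.x graph mode.
import numpy as np
import tensorflow as tf

from luminoth.utils.safe_wrappers import safe_softmax_cross_entropy_with_logits

# Batch dimension left unknown, as in the RCNN/RPN loss tensors.
labels = tf.placeholder(tf.float32, shape=(None, 2))
logits = tf.placeholder(tf.float32, shape=(None, 2))
loss = safe_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

with tf.Session() as sess:
    # Non-empty batch: behaves like tf.nn.softmax_cross_entropy_with_logits,
    # returning one cross-entropy value per row.
    print(sess.run(loss, {
        labels: np.array([[1., 0.], [0., 1.]], dtype=np.float32),
        logits: np.array([[2., -1.], [0.5, 0.5]], dtype=np.float32),
    }))
    # Empty batch: safety_condition evaluates to False at run time, so
    # tf.cond returns an empty tensor instead of computing the cross entropy.
    print(sess.run(loss, {
        labels: np.zeros((0, 2), dtype=np.float32),
        logits: np.zeros((0, 2), dtype=np.float32),
    }))

Because the guard compares the dynamic tf.shape of labels and logits rather than their static shapes, the empty-batch branch is selected at run time, which is why the placeholders above can keep an unknown batch dimension.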