Skip to content

Commit 3e81f1f

Browse files
committed
Add a safe wrapper for cross entropy.
This fixes an error we were getting when softmax_cross_entropy_with_logits received empty tensors.
1 parent 1de9dcd commit 3e81f1f

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

luminoth/models/fasterrcnn/rcnn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from luminoth.models.fasterrcnn.rcnn_target import RCNNTarget
66
from luminoth.models.fasterrcnn.roi_pool import ROIPoolingLayer
77
from luminoth.utils.losses import smooth_l1_loss
8+
from luminoth.utils.safe_wrappers import (
9+
safe_softmax_cross_entropy_with_logits
10+
)
811
from luminoth.utils.vars import (
912
get_initializer, layer_summaries, variable_summaries,
1013
get_activation_function
@@ -304,7 +307,7 @@ def loss(self, prediction_dict):
304307

305308
# We get cross entropy loss of each proposal.
306309
cross_entropy_per_proposal = (
307-
tf.nn.softmax_cross_entropy_with_logits(
310+
safe_softmax_cross_entropy_with_logits(
308311
labels=cls_target_one_hot, logits=cls_score_labeled
309312
)
310313
)

luminoth/models/fasterrcnn/rpn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from .rpn_target import RPNTarget
1111
from .rpn_proposal import RPNProposal
1212
from luminoth.utils.losses import smooth_l1_loss
13+
from luminoth.utils.safe_wrappers import (
14+
safe_softmax_cross_entropy_with_logits
15+
)
1316
from luminoth.utils.vars import (
1417
get_initializer, layer_summaries, variable_summaries,
1518
get_activation_function
@@ -257,7 +260,7 @@ def loss(self, prediction_dict):
257260
cls_target = tf.one_hot(labels, depth=2)
258261

259262
# Equivalent to log loss
260-
ce_per_anchor = tf.nn.softmax_cross_entropy_with_logits(
263+
ce_per_anchor = safe_softmax_cross_entropy_with_logits(
261264
labels=cls_target, logits=cls_score
262265
)
263266
prediction_dict['cross_entropy_per_anchor'] = ce_per_anchor

luminoth/utils/safe_wrappers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import tensorflow as tf
2+
3+
4+
def safe_softmax_cross_entropy_with_logits(
5+
labels, logits, name='safe_cross_entropy'):
6+
with tf.name_scope(name):
7+
safety_condition = tf.logical_and(
8+
tf.greater(tf.shape(labels)[0], 0, name='labels_notzero'),
9+
tf.greater(tf.shape(logits)[0], 0, name='logits_notzero'),
10+
name='safety_condition'
11+
)
12+
return tf.cond(
13+
safety_condition,
14+
true_fn=lambda: tf.nn.softmax_cross_entropy_with_logits(
15+
labels=labels, logits=logits
16+
),
17+
false_fn=lambda: tf.constant([], dtype=logits.dtype)
18+
)

0 commit comments

Comments
 (0)