@@ -706,6 +706,33 @@ def version_7(cls, ctx, node, **kwargs):
706706 _make_softmax_cross_entropy_with_logits (ctx , labels , logits , node )
707707
708708
def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
    """Rewrite a TF SparseSoftmaxCrossEntropyWithLogits node into primitive ONNX ops.

    Replaces ``tf_ori_node`` in ``ctx`` with an op subgraph computing
    ``-log(softmax(logit)[label])``, wiring the result to the original node's
    first output so downstream consumers are unaffected.

    Args:
        ctx: the ONNX graph being built; nodes are added to / removed from it.
        label: node producing the (already one-hot) label tensor; must have the
            same dtype as ``logit`` (enforced below).
        logit: node producing the logits tensor.
        tf_ori_node: the original TF node being replaced; supplies output name,
            shape and dtype for the replacement.
    """
    logit = logit.output[0]
    label = label.output[0]
    label_dtype = ctx.get_dtype(label)
    logit_dtype = ctx.get_dtype(logit)
    utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")

    # When label is one-hot, "tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))"
    # equals "-log(q_i)" where i is the index selected by label and q_i = logit_exp_i / sum.
    # Expansion: logit_exp = exp(logit) >> sum = reduce_sum(logit_exp, axis=-1),
    # masked_sum = reduce_sum(mul(logit_exp, label)) >> -log(masked_sum / sum)
    # NOTE(review): exp() here can overflow for large logits (no max-subtraction
    # stabilization) — presumably acceptable upstream; confirm before changing.
    logit_exp = ctx.make_node(op_type="Exp", inputs=[logit]).output[0]
    logit_exp_sum = ctx.make_node(op_type="ReduceSum", inputs=[logit_exp],
                                  attr={"axes": [-1], "keepdims": 0}).output[0]
    masked = ctx.make_node(op_type="Mul", inputs=[label, logit_exp]).output[0]
    masked_sum = ctx.make_node(op_type="ReduceSum", inputs=[masked],
                               attr={"axes": [-1], "keepdims": 0}).output[0]
    probability = ctx.make_node(op_type="Div", inputs=[masked_sum, logit_exp_sum]).output[0]
    log_prob = ctx.make_node(op_type="Log", inputs=[probability]).output[0]
    # Constant -1 with the logits' dtype, used to negate the log-probability.
    const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype])).output[0]

    # Capture the original node's output metadata before removing it, then emit
    # the final Mul directly onto the original output name so consumers rewire.
    shapes = tf_ori_node.output_shapes
    dtypes = tf_ori_node.output_dtypes
    ctx.remove_node(tf_ori_node.name)
    ctx.make_node(op_type="Mul", inputs=[log_prob, const_negative_one],
                  outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
735+
709736@tf_op ("SparseSoftmaxCrossEntropyWithLogits" )
710737class SparseSoftmaxCrossEntropyWithLogits :
711738 @classmethod
@@ -778,4 +805,4 @@ def version_9(cls, ctx, node, **kwargs):
778805 if logit_dtype != TensorProto .INT64 :
779806 label_node = ctx .make_node ("Cast" , label_node .output , attr = {"to" : logit_dtype }, dtypes = [logit_dtype ])
780807
781- _make_softmax_cross_entropy_with_logits (ctx , label_node , logit_node , node )
808+ _make_sparse_softmax_cross_entropy_with_logits (ctx , label_node , logit_node , node )
0 commit comments