@@ -262,6 +262,32 @@ def _prepare_loss_weights(loss_weights, output_names):
262
262
'got {}' .format (str (loss_weights )))
263
263
264
264
265
def _clone_metrics(metrics):
  """Returns a copy of a (possibly nested) metric specification.

  Args:
    metrics: A collection of metric specifications, in any of the formats
      accepted by the `metrics` argument of `tf.keras.Model.compile`.

  Returns:
    A structure of the same format as `metrics`, in which every
    `tf.keras.metric.Metric` object has been replaced by a fresh copy.
  """

  def _copy_one(metric):
    # `Metric` objects carry state and may only be used in 1 model on 1
    # output, so each use site needs its own instance; copying lets the same
    # metric serve both the base and adversarial-regularized models, and
    # multiple outputs within one model. Non-`Metric` entries are returned
    # unchanged. The cloning logic mirrors the `clone_metric` function in
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/metrics.py
    if isinstance(metric, keras.metrics.Metric):
      with tf.init_scope():
        return metric.__class__.from_config(metric.get_config())
    return metric

  return tf.nest.map_structure(_copy_one, metrics)
289
+
290
+
265
291
def _prepare_metric_fns (metrics , output_names , loss_wrappers ):
266
292
"""Converts `metrics` into a list of per-output list of metrics.
267
293
@@ -290,16 +316,16 @@ def _prepare_metric_fns(metrics, output_names, loss_wrappers):
290
316
to_list = lambda x : x if isinstance (x , list ) else [x ]
291
317
292
318
if isinstance (metrics , collections .Mapping ):
293
- # If `metrics` is a dictionary mapping output name to a list of metric fns,
294
- # coverts it to a list of lists using the order in `output_names`.
319
+ # Converts `metrics` from a dictionary to a list of lists using the order
320
+ # specified in `output_names`.
295
321
metrics = [to_list (metrics .get (name , [])) for name in output_names ]
296
322
297
323
if not any (isinstance (m , list ) for m in metrics ):
298
- # If `metrics` is a list of metric fns, replicates them to be a list of
299
- # lists so that all metric fns can be applied to each output.
300
- metrics = [metrics for _ in output_names ]
324
+ # Replicates `metrics` to be a list of lists if it is a plain list of
325
+ # metrics, so that all metrics can be applied to each output.
326
+ metrics = [metrics ] + [ _clone_metrics ( metrics ) for _ in output_names [ 1 :] ]
301
327
302
- # Here `metrics` is a list of lists, each sub-list corresponds to metric fns
328
+ # Here `metrics` is a list of lists, and each sub-list corresponds to metrics
303
329
# to be applied on an output.
304
330
if len (metrics ) != len (output_names ):
305
331
raise ValueError ('The number of sub-lists in `metrics` should be the '
@@ -326,6 +352,7 @@ def _compute_loss_and_metrics(losses,
326
352
outputs. Must have the same length as `labels` and `outputs`.
327
353
metrics: List of list of (metric fn, metric name) pairs, for additional
328
354
metrics to report for each output. Must have the same length as `outputs`.
355
+ If set to `None`, no additional metrics will be reported.
329
356
labels: List of `Tensor` objects of ground truth targets. Must have the same
330
357
length as `losses` and `outputs`.
331
358
outputs: List of `Tensor` objects of predicted targets. Must have the same
@@ -334,17 +361,26 @@ def _compute_loss_and_metrics(losses,
334
361
335
362
Returns:
336
363
total_loss: Weighted sum of losses on all outputs.
337
- metrics: List of (value, name) pairs for metric reporting.
364
+ metrics: List of (value, aggregation, name) tuples for metric reporting.
338
365
"""
339
366
outputs = tf .nest .flatten (outputs )
340
367
total_loss , output_metrics = [], []
368
+ if metrics is None :
369
+ metrics = [[]] * len (losses )
341
370
for (label , output , loss , per_output_metrics ) in zip (labels , outputs , losses ,
342
371
metrics ):
343
372
loss_value = loss (label , output , sample_weights )
344
373
total_loss .append (loss .weight * loss_value )
345
- output_metrics .append ((loss_value , loss .name ))
374
+ output_metrics .append ((loss_value , 'mean' , loss .name ))
346
375
for metric_fn , metric_name in per_output_metrics :
347
- output_metrics .append ((metric_fn (label , output ), metric_name ))
376
+ value = metric_fn (label , output )
377
+ # Metric objects always return an aggregated result, and shouldn't be
378
+ # aggregated again.
379
+ if isinstance (metric_fn , keras .metrics .Metric ):
380
+ aggregation = None
381
+ else :
382
+ aggregation = 'mean'
383
+ output_metrics .append ((value , aggregation , metric_name ))
348
384
return tf .add_n (total_loss ), output_metrics
349
385
350
386
@@ -451,7 +487,7 @@ def compile(self,
451
487
self .base_model .compile (
452
488
optimizer ,
453
489
loss = self ._compile_arg_loss ,
454
- metrics = self ._compile_arg_metrics ,
490
+ metrics = _clone_metrics ( self ._compile_arg_metrics ) ,
455
491
loss_weights = self ._compile_arg_loss_weights ,
456
492
** kwargs )
457
493
@@ -517,6 +553,9 @@ def _build_labeled_metrics(self, output_names, labeled_losses):
517
553
per_output_metrics = []
518
554
for metric_fn in metric_fns :
519
555
metric_name = self ._make_metric_name (metric_fn , label_key )
556
+ if isinstance (metric_fn , keras .metrics .Metric ):
557
+ # Updates the name of the Metric object to make sure it is unique.
558
+ metric_fn ._name = metric_name # pylint: disable=protected-access
520
559
per_output_metrics .append ((metric_fn , metric_name ))
521
560
self ._labeled_metrics .append (per_output_metrics )
522
561
@@ -526,9 +565,10 @@ def _get_or_create_base_output_names(self, outputs):
526
565
['output_%d' % i for i in range (1 , num_output + 1 )])
527
566
528
567
def _compute_total_loss (self , labels , outputs , sample_weights = None ):
529
- loss , _ = _compute_loss_and_metrics (self ._labeled_losses ,
530
- self ._labeled_metrics , labels , outputs ,
531
- sample_weights )
568
+ # `None` is passed instead of the actual metrics in order to skip computing
569
+ # metric values and updating metric states.
570
+ loss , _ = _compute_loss_and_metrics (self ._labeled_losses , None , labels ,
571
+ outputs , sample_weights )
532
572
return loss
533
573
534
574
def _split_inputs (self , inputs ):
@@ -575,8 +615,8 @@ def call(self, inputs, **kwargs):
575
615
outputs , labeled_loss , metrics , tape = self ._forward_pass (
576
616
inputs , labels , sample_weights , kwargs )
577
617
self .add_loss (labeled_loss )
578
- for value , name in metrics :
579
- self .add_metric (value , aggregation = 'mean' , name = name )
618
+ for value , aggregation , name in metrics :
619
+ self .add_metric (value , aggregation = aggregation , name = name )
580
620
581
621
# Adversarial loss.
582
622
base_model_fn = lambda inputs : self .base_model (inputs , ** kwargs )
0 commit comments