
Commit 45e50ee

csferng authored and tensorflow-copybara committed
Fix metric computation with Metric objects in AdversarialRegularization.
Cause of the problem: When a Keras model calls `self.add_metric()` with `aggregation='mean'`, a new `Metric` object is created to track the state of the metric. But if the metric is already computed by an existing `Metric` object, the state update happens only on that existing object, not on the newly created one, so the reported metric (from the new `Metric` object) is never updated.

Details of the fix: The fix sets `aggregation` to `None` when a metric is computed by a `Metric` object, so that the Keras model reports the numbers from that `Metric` object instead of creating a new one. This triggers another issue: the same `Metric` object is then compiled into both the base model and the adversarial-regularized model, which corrupts the computational graph when running with TensorFlow 1.x (graph mode). To resolve the computational graph issue, each `Metric` object is cloned before being passed to the base model. `Metric` objects are also cloned when they are to be applied to multiple outputs.

Fixes #8

PiperOrigin-RevId: 269422315
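As a rough sketch of the idea behind the fix (the `_clone_metrics` helper in the diff below applies the same logic over arbitrarily nested metric specifications), a stateful `Metric` object can be rebuilt from its config so that each model gets its own copy with independent state variables. The `MeanAbsoluteError` instance and the names here are only illustrative placeholders:

    import tensorflow as tf
    from tensorflow import keras

    def clone_metric(metric):
      # Strings like 'mse' and plain metric functions are stateless and can be
      # shared between models as-is.
      if not isinstance(metric, keras.metrics.Metric):
        return metric
      # A Metric object owns state variables, so rebuilding it from its config
      # yields an independent copy that can be compiled into a second model.
      with tf.init_scope():
        return metric.__class__.from_config(metric.get_config())

    mae = keras.metrics.MeanAbsoluteError()
    mae_for_base_model = clone_metric(mae)  # same configuration, separate state

Because a `Metric` object already returns an aggregated result, its value is reported with `aggregation=None` in `add_metric`, so Keras does not wrap it in a second, never-updated `Metric` object.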
1 parent 8f02c44 commit 45e50ee

File tree: 2 files changed (+107 / -15 lines)


neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 55 additions & 15 deletions
@@ -262,6 +262,32 @@ def _prepare_loss_weights(loss_weights, output_names):
                      'got {}'.format(str(loss_weights)))
 
 
+def _clone_metrics(metrics):
+  """Creates a copy of the maybe-nested metric specification.
+
+  Args:
+    metrics: A collection of metric specifications. Supports the same set of
+      formats as the `metrics` argument in `tf.keras.Model.compile`.
+
+  Returns:
+    The same format as the `metrics` argument, with all `tf.keras.metric.Metric`
+    objects replaced by their copies.
+  """
+
+  def clone(metric):
+    # A `Metric` object is stateful and can only be used in 1 model on 1 output.
+    # Cloning the object allows the same metric to be applied in both base and
+    # adversarial-regularized models, and also on multiple outputs in one model.
+    # The cloning logic is the same as the `clone_metric` function in
+    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/metrics.py
+    if not isinstance(metric, keras.metrics.Metric):
+      return metric
+    with tf.init_scope():
+      return metric.__class__.from_config(metric.get_config())
+
+  return tf.nest.map_structure(clone, metrics)
+
+
 def _prepare_metric_fns(metrics, output_names, loss_wrappers):
   """Converts `metrics` into a list of per-output list of metrics.
 
@@ -290,16 +316,16 @@ def _prepare_metric_fns(metrics, output_names, loss_wrappers):
   to_list = lambda x: x if isinstance(x, list) else [x]
 
   if isinstance(metrics, collections.Mapping):
-    # If `metrics` is a dictionary mapping output name to a list of metric fns,
-    # coverts it to a list of lists using the order in `output_names`.
+    # Converts `metrics` from a dictionary to a list of lists using the order
+    # specified in `output_names`.
     metrics = [to_list(metrics.get(name, [])) for name in output_names]
 
   if not any(isinstance(m, list) for m in metrics):
-    # If `metrics` is a list of metric fns, replicates them to be a list of
-    # lists so that all metric fns can be applied to each output.
-    metrics = [metrics for _ in output_names]
+    # Replicates `metrics` to be a list of lists if it is a plain list of
+    # metrics, so that all metrics can be applied to each output.
+    metrics = [metrics] + [_clone_metrics(metrics) for _ in output_names[1:]]
 
-  # Here `metrics` is a list of lists, each sub-list corresponds to metric fns
+  # Here `metrics` is a list of lists, and each sub-list corresponds to metrics
   # to be applied on an output.
   if len(metrics) != len(output_names):
     raise ValueError('The number of sub-lists in `metrics` should be the '
 
@@ -326,6 +352,7 @@ def _compute_loss_and_metrics(losses,
       outputs. Must have the same length as `labels` and `outputs`.
     metrics: List of list of (metric fn, metric name) pairs, for additional
       metrics to report for each output. Must have the same length as `outputs`.
+      If set to `None`, no additional metrics will be reported.
     labels: List of `Tensor` objects of ground truth targets. Must have the same
       length as `losses` and `outputs`.
     outputs: List of `Tensor` objects of predicted targets. Must have the same
 
@@ -334,17 +361,26 @@
 
   Returns:
     total_loss: Weighted sum of losses on all outputs.
-    metrics: List of (value, name) pairs for metric reporting.
+    metrics: List of (value, aggregation, name) tuples for metric reporting.
   """
   outputs = tf.nest.flatten(outputs)
   total_loss, output_metrics = [], []
+  if metrics is None:
+    metrics = [[]] * len(losses)
   for (label, output, loss, per_output_metrics) in zip(labels, outputs, losses,
                                                        metrics):
     loss_value = loss(label, output, sample_weights)
     total_loss.append(loss.weight * loss_value)
-    output_metrics.append((loss_value, loss.name))
+    output_metrics.append((loss_value, 'mean', loss.name))
     for metric_fn, metric_name in per_output_metrics:
-      output_metrics.append((metric_fn(label, output), metric_name))
+      value = metric_fn(label, output)
+      # Metric objects always return an aggregated result, and shouldn't be
+      # aggregated again.
+      if isinstance(metric_fn, keras.metrics.Metric):
+        aggregation = None
+      else:
+        aggregation = 'mean'
+      output_metrics.append((value, aggregation, metric_name))
   return tf.add_n(total_loss), output_metrics
 
 
@@ -451,7 +487,7 @@ def compile(self,
     self.base_model.compile(
         optimizer,
         loss=self._compile_arg_loss,
-        metrics=self._compile_arg_metrics,
+        metrics=_clone_metrics(self._compile_arg_metrics),
         loss_weights=self._compile_arg_loss_weights,
         **kwargs)
 
@@ -517,6 +553,9 @@ def _build_labeled_metrics(self, output_names, labeled_losses):
       per_output_metrics = []
       for metric_fn in metric_fns:
         metric_name = self._make_metric_name(metric_fn, label_key)
+        if isinstance(metric_fn, keras.metrics.Metric):
+          # Updates the name of the Metric object to make sure it is unique.
+          metric_fn._name = metric_name  # pylint: disable=protected-access
         per_output_metrics.append((metric_fn, metric_name))
       self._labeled_metrics.append(per_output_metrics)
 
@@ -526,9 +565,10 @@ def _get_or_create_base_output_names(self, outputs):
                 ['output_%d' % i for i in range(1, num_output + 1)])
 
   def _compute_total_loss(self, labels, outputs, sample_weights=None):
-    loss, _ = _compute_loss_and_metrics(self._labeled_losses,
-                                        self._labeled_metrics, labels, outputs,
-                                        sample_weights)
+    # `None` is passed instead of the actual metrics in order to skip computing
+    # metric values and updating metric states.
+    loss, _ = _compute_loss_and_metrics(self._labeled_losses, None, labels,
+                                        outputs, sample_weights)
     return loss
 
   def _split_inputs(self, inputs):
 
@@ -575,8 +615,8 @@ def call(self, inputs, **kwargs):
     outputs, labeled_loss, metrics, tape = self._forward_pass(
         inputs, labels, sample_weights, kwargs)
     self.add_loss(labeled_loss)
-    for value, name in metrics:
-      self.add_metric(value, aggregation='mean', name=name)
+    for value, aggregation, name in metrics:
+      self.add_metric(value, aggregation=aggregation, name=name)
 
     # Adversarial loss.
     base_model_fn = lambda inputs: self.base_model(inputs, **kwargs)

neural_structured_learning/keras/adversarial_regularization_test.py

Lines changed: 52 additions & 0 deletions
@@ -378,6 +378,58 @@ def test_train_with_duplicated_metrics(self):
     self.assertEqual(history.history['mean_squared_error'],
                      history.history['mean_squared_error_2'])
 
+  def test_train_with_metric_object(self):
+    w, x0, y0, lr, adv_config, _ = self._set_up_linear_regression()
+
+    inputs = {'feature': tf.constant(x0), 'label': tf.constant(y0)}
+    model = build_linear_keras_functional_model(input_shape=(2,), weights=w)
+    adv_model = adversarial_regularization.AdversarialRegularization(
+        model, label_keys=['label'], adv_config=adv_config)
+    adv_model.compile(
+        optimizer=keras.optimizers.SGD(lr),
+        loss='MSE',
+        metrics=[tf.keras.metrics.MeanAbsoluteError()])
+    history = adv_model.fit(x=inputs, batch_size=1, steps_per_epoch=1)
+
+    actual_metric = history.history['mean_absolute_error'][0]
+    expected_metric = np.abs(y0 - np.dot(x0, w)).mean()
+    self.assertAllClose(expected_metric, actual_metric)
+
+  def test_train_with_2_outputs(self):
+    w, x0, y0, lr, adv_config, _ = self._set_up_linear_regression()
+    inputs = {
+        'feature': tf.constant(x0),
+        'label1': tf.constant(y0),
+        'label2': tf.constant(-y0)
+    }
+
+    input_layer = keras.Input(shape=(2,), name='feature')
+    layer1 = keras.layers.Dense(
+        w.shape[-1],
+        use_bias=False,
+        kernel_initializer=keras.initializers.Constant(w))
+    layer2 = keras.layers.Dense(
+        w.shape[-1],
+        use_bias=False,
+        kernel_initializer=keras.initializers.Constant(-w))
+    model = keras.Model(
+        inputs={'feature': input_layer},
+        outputs=[layer1(input_layer), layer2(input_layer)])
+
+    adv_model = adversarial_regularization.AdversarialRegularization(
+        model, label_keys=['label1', 'label2'], adv_config=adv_config)
+    adv_model.compile(
+        optimizer=keras.optimizers.SGD(lr),
+        loss='MSE',
+        metrics=[tf.keras.metrics.MeanAbsoluteError()])
+    history = adv_model.fit(x=inputs, batch_size=1, steps_per_epoch=1)
+
+    expected_metric = np.abs(y0 - np.dot(x0, w)).mean()
+    self.assertAllClose(expected_metric,
+                        history.history['mean_absolute_error_label1'][0])
+    self.assertAllClose(expected_metric,
+                        history.history['mean_absolute_error_label2'][0])
+
   def test_evaluate_binary_classification_metrics(self):
     # multi-label binary classification model
     w = np.array([[4.0, 1.0, -5.0], [-3.0, 1.0, 2.0]])
