From 28fcbfc357db99bfe9ec8577ce312f66ff94dc50 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 15 Oct 2025 15:23:16 -0700
Subject: [PATCH] Allow Input to be optional to take None inputs, similar to what keras3 has.

PiperOrigin-RevId: 819935785
---
 ...internal__.legacy.layers.-input-spec.pbtxt |   2 +-
 ...tensorflow.keras.layers.-input-layer.pbtxt |   2 +-
 .../tensorflow.keras.layers.-input-spec.pbtxt |   2 +-
 .../golden/v1/tensorflow.keras.layers.pbtxt   |   2 +-
 tf_keras/api/golden/v1/tensorflow.keras.pbtxt |   2 +-
 ...tensorflow.keras.layers.-input-layer.pbtxt |   2 +-
 .../tensorflow.keras.layers.-input-spec.pbtxt |   2 +-
 .../golden/v2/tensorflow.keras.layers.pbtxt   |   2 +-
 tf_keras/api/golden/v2/tensorflow.keras.pbtxt |   2 +-
 tf_keras/engine/data_adapter.py               |  47 +++++--
 tf_keras/engine/data_adapter_test.py          | 122 ++++++++++++++++++
 tf_keras/engine/functional.py                 |  51 +++++---
 tf_keras/engine/functional_test.py            |  27 ++++
 tf_keras/engine/input_layer.py                |   9 ++
 tf_keras/engine/input_spec.py                 |  12 +-
 tf_keras/engine/training_utils_v1.py          |   2 +-
 tf_keras/engine/training_v1.py                |  12 +-
 17 files changed, 257 insertions(+), 43 deletions(-)

diff --git a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.layers.-input-spec.pbtxt b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.layers.-input-spec.pbtxt
index 6421e356a..3c03f1079 100644
--- a/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.layers.-input-spec.pbtxt
+++ b/tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.layers.-input-spec.pbtxt
@@ -4,7 +4,7 @@ tf_class {
   is_instance: ""
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "from_config"
diff --git a/tf_keras/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt b/tf_keras/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt
index e5b3e97d3..543bd7c52 100644
--- a/tf_keras/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt
+++ b/tf_keras/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt
@@ -129,7 +129,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "add_loss"
diff --git a/tf_keras/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt b/tf_keras/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt
index 5aef9ca71..a20ca1d4f 100644
--- a/tf_keras/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt
+++ b/tf_keras/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt
@@ -4,7 +4,7 @@ tf_class {
   is_instance: ""
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "from_config"
diff --git a/tf_keras/api/golden/v1/tensorflow.keras.layers.pbtxt b/tf_keras/api/golden/v1/tensorflow.keras.layers.pbtxt
index 6ae37c06b..20ef13fff 100644
--- a/tf_keras/api/golden/v1/tensorflow.keras.layers.pbtxt
+++ b/tf_keras/api/golden/v1/tensorflow.keras.layers.pbtxt
@@ -482,7 +482,7 @@ tf_module {
   }
   member_method {
     name: "Input"
-    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "add"
diff --git a/tf_keras/api/golden/v1/tensorflow.keras.pbtxt b/tf_keras/api/golden/v1/tensorflow.keras.pbtxt
index a5592a0f0..f68f76053 100644
--- a/tf_keras/api/golden/v1/tensorflow.keras.pbtxt
+++ b/tf_keras/api/golden/v1/tensorflow.keras.pbtxt
@@ -90,6 +90,6 @@ tf_module {
   }
   member_method {
     name: "Input"
-    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
   }
 }
diff --git a/tf_keras/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt b/tf_keras/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt
index e5b3e97d3..543bd7c52 100644
--- a/tf_keras/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt
+++ b/tf_keras/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt
@@ -129,7 +129,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "add_loss"
diff --git a/tf_keras/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt b/tf_keras/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt
index 5aef9ca71..a20ca1d4f 100644
--- a/tf_keras/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt
+++ b/tf_keras/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt
@@ -4,7 +4,7 @@ tf_class {
   is_instance: ""
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "from_config"
diff --git a/tf_keras/api/golden/v2/tensorflow.keras.layers.pbtxt b/tf_keras/api/golden/v2/tensorflow.keras.layers.pbtxt
index a2b218a4c..1fcb35be4 100644
--- a/tf_keras/api/golden/v2/tensorflow.keras.layers.pbtxt
+++ b/tf_keras/api/golden/v2/tensorflow.keras.layers.pbtxt
@@ -538,7 +538,7 @@ tf_module {
   }
   member_method {
     name: "Input"
-    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "add"
diff --git a/tf_keras/api/golden/v2/tensorflow.keras.pbtxt b/tf_keras/api/golden/v2/tensorflow.keras.pbtxt
index c080bc275..a460f246d 100644
--- a/tf_keras/api/golden/v2/tensorflow.keras.pbtxt
+++ b/tf_keras/api/golden/v2/tensorflow.keras.pbtxt
@@ -95,6 +95,6 @@ tf_module {
   }
   member_method {
     name: "Input"
-    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
   }
 }
diff --git a/tf_keras/engine/data_adapter.py b/tf_keras/engine/data_adapter.py
index 179f097ef..01f47180b 100644
--- a/tf_keras/engine/data_adapter.py
+++ b/tf_keras/engine/data_adapter.py
@@ -231,7 +231,9 @@ def _is_tensor(v):
                 return True
             return False

-        return all(_is_tensor(v) for v in flat_inputs)
+        return all(_is_tensor(v) for v in flat_inputs if v is not None) and any(
+            _is_tensor(v) for v in flat_inputs
+        )

     def __init__(
         self,
@@ -259,7 +261,7 @@ def __init__(
         inputs = pack_x_y_sample_weight(x, y, sample_weights)

         num_samples = set(
-            int(i.shape[0]) for i in tf.nest.flatten(inputs)
+            int(i.shape[0]) for i in tf.nest.flatten(inputs) if i is not None
         ).pop()
         _check_data_cardinality(inputs)

@@ -386,7 +388,7 @@ def slice_inputs(self, indices_dataset, inputs):

         def grab_batch(i, data):
             return tf.nest.map_structure(
-                lambda d: tf.gather(d, i, axis=0), data
+                lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
             )

         dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
@@ -459,7 +461,9 @@ def _is_array_like(v):
         if not TensorLikeDataAdapter.can_handle(
             x, y
         ) and not CompositeTensorDataAdapter.can_handle(x, y):
-            return all(_is_array_like(v) for v in flat_inputs)
+            return all(
+                _is_array_like(v) for v in flat_inputs if v is not None
+            ) and any(v is not None for v in flat_inputs)
         else:
             return False

@@ -496,7 +500,7 @@ def dynamic_shape_like(t):
             shape[0] = None
             return tuple(shape)

-        flat_dtypes = [inp.dtype for inp in flat_inputs]
+        flat_dtypes = [inp.dtype for inp in flat_inputs if inp is not None]
         contiguous = True
         if self._shuffle and self._shuffle != "batch":
             contiguous = False
@@ -509,15 +513,26 @@ def grab_batch(indices):
             # to a Tensor may force it into memory..
             def py_method(ind):
                 def slice_array(data):
+                    if data is None:
+                        return None
                     return training_utils.slice_arrays(
                         data, ind.numpy(), contiguous=contiguous
                     )

-                return [slice_array(inp) for inp in flat_inputs]
+                return [
+                    slice_array(inp) for inp in flat_inputs if inp is not None
+                ]

-            flat_out = tf.py_function(py_method, [indices], flat_dtypes)
-            for v, original_inp in zip(flat_out, flat_inputs):
-                v.set_shape(dynamic_shape_like(original_inp))
+            results = tf.py_function(py_method, [indices], flat_dtypes)
+            results_it = iter(results)
+            flat_out = []
+            for original_inp in flat_inputs:
+                if original_inp is None:
+                    flat_out.append(None)
+                else:
+                    v = next(results_it)
+                    v.set_shape(dynamic_shape_like(original_inp))
+                    flat_out.append(v)
             return tf.nest.pack_sequence_as(inputs, flat_out)

         dataset = indices_dataset.map(
@@ -608,8 +623,10 @@ def _is_tensor_or_composite(v):
                 return True
             return _is_composite(v)

-        return any(_is_composite(v) for v in flat_inputs) and all(
-            _is_tensor_or_composite(v) for v in flat_inputs
+        return any(
+            _is_composite(v) for v in flat_inputs if v is not None
+        ) and all(
+            _is_tensor_or_composite(v) for v in flat_inputs if v is not None
         )

     def __init__(
@@ -1944,14 +1961,18 @@ def single_batch_iterator(


 def _check_data_cardinality(data):
-    num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
+    num_samples = set(
+        int(i.shape[0]) for i in tf.nest.flatten(data) if i is not None
+    )
     if len(num_samples) > 1:
         msg = "Data cardinality is ambiguous:\n"
         for label, single_data in zip(["x", "y", "sample_weight"], data):
             msg += "  {} sizes: {}\n".format(
                 label,
                 ", ".join(
-                    str(i.shape[0]) for i in tf.nest.flatten(single_data)
+                    str(i.shape[0])
+                    for i in tf.nest.flatten(single_data)
+                    if i is not None
                 ),
             )
         msg += "Make sure all arrays contain the same number of samples."
diff --git a/tf_keras/engine/data_adapter_test.py b/tf_keras/engine/data_adapter_test.py
index 9cc8ba071..ca350b8e3 100644
--- a/tf_keras/engine/data_adapter_test.py
+++ b/tf_keras/engine/data_adapter_test.py
@@ -25,6 +25,7 @@
 from tf_keras.testing_infra import test_combinations
 from tf_keras.testing_infra import test_utils
 from tf_keras.utils import data_utils
+from tf_keras.utils import dataset_creator

 # isort: off
 from tensorflow.python.eager import context
@@ -427,6 +428,26 @@ def _get_epoch(ds_iter):
         # Check that each elements appears, and only once.
         self.assertAllClose(x, np.sort(second_epoch_data))

+    def test_tensor_like_with_none_input(self):
+        x = [np.ones((10, 1), dtype=np.float32), None]
+        y = np.zeros((10, 1), dtype=np.float32)
+        self.assertTrue(data_adapter.TensorLikeDataAdapter.can_handle(x, y))
+        adapter = data_adapter.TensorLikeDataAdapter(
+            x, y, batch_size=2, shuffle=False
+        )
+        dataset = adapter.get_dataset()
+        self.assertEqual(adapter.get_size(), 5)
+        self.assertFalse(adapter.has_partial_batch())
+        self.assertIsNone(adapter.partial_batch_size())
+        for i, batch in enumerate(dataset):
+            x_batch, y_batch, _ = data_adapter.unpack_x_y_sample_weight(batch)
+            self.assertIsInstance(x_batch, tuple)
+            self.assertEqual(x_batch[0].shape, (2, 1))
+            self.assertIsNone(x_batch[1])
+            self.assertEqual(y_batch.shape, (2, 1))
+            if i >= 4:
+                break
+
     @test_combinations.run_all_keras_modes(always_skip_v1=True)
     def test_batch_shuffle_correctness(self):
         num_samples = 100
@@ -787,6 +808,28 @@ def _get_epoch(ds_iter):
         # Check that each elements appears, and only once.
         self.assertAllClose(x, np.sort(second_epoch_data))

+    def test_generic_array_like_with_none_input(self):
+        x = [DummyArrayLike(np.ones((10, 1), dtype=np.float32)), None]
+        y = DummyArrayLike(np.zeros((10, 1), dtype=np.float32))
+        self.assertTrue(
+            data_adapter.GenericArrayLikeDataAdapter.can_handle(x, y)
+        )
+        adapter = data_adapter.GenericArrayLikeDataAdapter(
+            x, y, batch_size=2, shuffle=False
+        )
+        dataset = adapter.get_dataset()
+        self.assertEqual(adapter.get_size(), 5)
+        self.assertFalse(adapter.has_partial_batch())
+        self.assertIsNone(adapter.partial_batch_size())
+        for i, batch in enumerate(dataset):
+            x_batch, y_batch, _ = data_adapter.unpack_x_y_sample_weight(batch)
+            self.assertIsInstance(x_batch, tuple)
+            self.assertEqual(x_batch[0].shape, (2, 1))
+            self.assertIsNone(x_batch[1])
+            self.assertEqual(y_batch.shape, (2, 1))
+            if i >= 4:
+                break
+
     @test_combinations.run_all_keras_modes(always_skip_v1=True)
     def test_batch_shuffle_correctness(self):
         num_samples = 100
@@ -885,6 +928,85 @@ def test_partial_batch(
         )


+class CompositeTensorDataAdapterTest(DataAdapterTestBase):
+    def setUp(self):
+        super().setUp()
+        self.adapter_cls = data_adapter.CompositeTensorDataAdapter
+
+    def test_composite_tensor_with_none_input(self):
+        x = [
+            tf.SparseTensor(
+                indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]
+            ),
+            None,
+        ]
+        y = np.zeros((3, 1), dtype=np.float32)
+        self.assertTrue(
+            data_adapter.CompositeTensorDataAdapter.can_handle(x, y)
+        )
+        adapter = data_adapter.CompositeTensorDataAdapter(
+            x, y, batch_size=2, shuffle=False
+        )
+        dataset = adapter.get_dataset()
+        self.assertEqual(adapter.get_size(), 2)  # 3 samples, batch_size=2 -> 2
+        self.assertTrue(adapter.has_partial_batch())
+        self.assertEqual(adapter.partial_batch_size(), 1)
+
+        data = list(dataset)
+        self.assertEqual(len(data), 2)
+
+        x_batch, y_batch, _ = data_adapter.unpack_x_y_sample_weight(data[0])
+        self.assertIsInstance(x_batch, tuple)
+        self.assertEqual(x_batch[0].dense_shape.numpy().tolist(), [2, 4])
+        self.assertIsNone(x_batch[1])
+        self.assertEqual(y_batch.shape, (2, 1))
+
+        x_batch, y_batch, _ = data_adapter.unpack_x_y_sample_weight(data[1])
+        self.assertIsInstance(x_batch, tuple)
+        self.assertEqual(x_batch[0].dense_shape.numpy().tolist(), [1, 4])
+        self.assertIsNone(x_batch[1])
+        self.assertEqual(y_batch.shape, (1, 1))
+
+
+class DatasetCreatorAdapterTest(DataAdapterTestBase):
+    def setUp(self):
+        super().setUp()
+        self.adapter_cls = data_adapter.DatasetCreatorAdapter
+
+    def test_with_none_input(self):
+        def dataset_fn(input_context=None):
+            del input_context
+            x_0 = np.ones((10, 1), dtype=np.float32)
+            y = np.zeros((10, 1), dtype=np.float32)
+            ds = tf.data.Dataset.from_tensor_slices((x_0, y))
+
+            def map_fn(x0, y):
+                return tf.data.Dataset.from_tensors(((x0, None), y))
+
+            ds = ds.flat_map(map_fn)
+            return ds.batch(2)
+
+        dc = dataset_creator.DatasetCreator(dataset_fn)
+        self.assertTrue(data_adapter.DatasetCreatorAdapter.can_handle(dc))
+        adapter = data_adapter.DatasetCreatorAdapter(
+            dc,
+            y=None,
+            steps=5,
+            distribution_strategy=tf.distribute.get_strategy(),
+        )
+        dataset = adapter.get_dataset()
+        self.assertIsNone(adapter.get_size())
+
+        for i, batch in enumerate(dataset):
+            x_batch, y_batch, _ = data_adapter.unpack_x_y_sample_weight(batch)
+            self.assertIsInstance(x_batch, tuple)
+            self.assertEqual(x_batch[0].shape, (2, 1))
+            self.assertIsNone(x_batch[1])
+            self.assertEqual(y_batch.shape, (2, 1))
+            if i >= 4:
+                break
+
+
 class DatasetAdapterTest(DataAdapterTestBase):
     def setUp(self):
         super().setUp()
diff --git a/tf_keras/engine/functional.py b/tf_keras/engine/functional.py
index 53fcb5392..01cf3d880 100644
--- a/tf_keras/engine/functional.py
+++ b/tf_keras/engine/functional.py
@@ -351,25 +351,45 @@ def input_spec(self):
         if isinstance(self._nested_inputs, dict):
             # Case where `_nested_inputs` is a plain dict of Inputs.
             names = sorted(self._nested_inputs.keys())
-            return [
-                input_spec.InputSpec(
-                    shape=shape_with_no_batch_size(self._nested_inputs[name]),
-                    allow_last_axis_squeeze=True,
-                    name=name,
+            specs = []
+            for name in names:
+                layer = self._nested_inputs[name]._keras_history.layer
+                optional = (
+                    layer.optional
+                    if isinstance(layer, input_layer_module.InputLayer)
+                    else False
                 )
-                for name in names
-            ]
+                specs.append(
+                    input_spec.InputSpec(
+                        shape=shape_with_no_batch_size(
+                            self._nested_inputs[name]
+                        ),
+                        allow_last_axis_squeeze=True,
+                        name=name,
+                        optional=optional,
+                    )
+                )
+            return specs
         else:
             # Single input, or list / tuple of inputs.
             # The data may be passed as a dict keyed by input name.
-            return [
-                input_spec.InputSpec(
-                    shape=shape_with_no_batch_size(x),
-                    allow_last_axis_squeeze=True,
-                    name=x._keras_history.layer.name,
+            specs = []
+            for x in self.inputs:
+                layer = x._keras_history.layer
+                optional = (
+                    layer.optional
+                    if isinstance(layer, input_layer_module.InputLayer)
+                    else False
                 )
-                for x in self.inputs
-            ]
+                specs.append(
+                    input_spec.InputSpec(
+                        shape=shape_with_no_batch_size(x),
+                        allow_last_axis_squeeze=True,
+                        name=x._keras_history.layer.name,
+                        optional=optional,
+                    )
+                )
+            return specs

     @input_spec.setter
     def input_spec(self, value):
@@ -644,7 +664,8 @@ def _run_internal_graph(self, inputs, training=None, mask=None):
         else:
             masks = self._flatten_to_reference_inputs(mask)
         for input_t, mask in zip(inputs, masks):
-            input_t._keras_mask = mask
+            if input_t is not None:
+                input_t._keras_mask = mask

         # Dictionary mapping reference tensors to computed tensors.
         tensor_dict = {}
diff --git a/tf_keras/engine/functional_test.py b/tf_keras/engine/functional_test.py
index 76c65fc98..7db5b12a1 100644
--- a/tf_keras/engine/functional_test.py
+++ b/tf_keras/engine/functional_test.py
@@ -2010,6 +2010,33 @@ def test_dict_inputs_tensors(self):
         self.assertEqual(model.inputs[0]._keras_history.layer.name, "b")
         self.assertEqual(model.inputs[1]._keras_history.layer.name, "a")

+    @test_combinations.generate(test_combinations.keras_mode_combinations())
+    def test_model_with_optional_input(self):
+        class CustomAdd(layers.Layer):
+            def call(self, input_a, input_b=None):
+                if input_b is None:
+                    return input_a
+                return input_a + input_b
+
+        input_a = input_layer_lib.Input(shape=(2,))
+        input_b = input_layer_lib.Input(shape=(2,), optional=True)
+        added = CustomAdd()(input_a, input_b)
+        outputs = layers.Dense(2, activation="relu")(added)
+        model = training_lib.Model(inputs=[input_a, input_b], outputs=outputs)
+
+        x1 = np.ones((100, 2))
+        x2 = None
+        y = np.ones((100, 2))
+
+        model.compile(
+            optimizer="sgd",
+            loss="mse",
+            run_eagerly=test_utils.should_run_eagerly(),
+        )
+        model.fit([x1, x2], y, batch_size=2, epochs=1)
+        model.evaluate([x1, x2], y)
+        model.predict([x1, x2])
+

 class GraphUtilsTest(tf.test.TestCase):
     def testGetReachableFromInputs(self):
diff --git a/tf_keras/engine/input_layer.py b/tf_keras/engine/input_layer.py
index 831e3a227..1ade25f5b 100644
--- a/tf_keras/engine/input_layer.py
+++ b/tf_keras/engine/input_layer.py
@@ -98,6 +98,8 @@ class InputLayer(base_layer.Layer):
            `tf.TypeSpec` represents the entire batch. When provided, all other
            args except name must be `None`.
        name: Optional name of the layer (string).
+       optional: Boolean, whether the input is optional or not.
+           An optional input can accept `None` values.
     """

     @traceback_utils.filter_traceback
@@ -111,6 +113,7 @@ def __init__(
         name=None,
         ragged=None,
         type_spec=None,
+        optional=False,
         **kwargs,
     ):
         self._init_input_shape = input_shape
@@ -180,6 +183,7 @@ def __init__(
         self.ragged = True if ragged else False
         self.batch_size = batch_size
         self.supports_masking = True
+        self.optional = optional

         if isinstance(input_shape, tf.TensorShape):
             input_shape = tuple(input_shape.as_list())
@@ -284,6 +288,7 @@ def get_config(self):
             "sparse": self.sparse,
             "ragged": self.ragged,
             "name": self.name,
+            "optional": self.optional,
         }
         return config

@@ -303,6 +308,7 @@ def Input(
     tensor=None,
     ragged=None,
     type_spec=None,
+    optional=False,
     **kwargs,
 ):
     """`Input()` is used to instantiate a TF-Keras tensor.
@@ -341,6 +347,8 @@ def Input(
            [this guide](https://www.tensorflow.org/guide/ragged_tensor).
        type_spec: A `tf.TypeSpec` object to create the input placeholder
            from. When provided, all other args except name must be None.
+       optional: Boolean, whether the input is optional or not.
+           An optional input can accept `None` values.
        **kwargs: deprecated arguments support. Supports `batch_shape` and
            `batch_input_shape`.

@@ -415,6 +423,7 @@ def Input(
         "ragged": ragged,
         "input_tensor": tensor,
         "type_spec": type_spec,
+        "optional": optional,
     }

     batch_input_shape = kwargs.pop(
diff --git a/tf_keras/engine/input_spec.py b/tf_keras/engine/input_spec.py
index ccab9165d..169ddd245 100644
--- a/tf_keras/engine/input_spec.py
+++ b/tf_keras/engine/input_spec.py
@@ -56,6 +56,8 @@ class InputSpec:
            as long as the last axis of the spec is 1.
        name: Expected key corresponding to this input when passing data as
            a dictionary.
+       optional: Boolean, whether the input is optional or not.
+           An optional input can accept `None` values.

     Example:

@@ -82,6 +84,7 @@ def __init__(
         axes=None,
         allow_last_axis_squeeze=False,
         name=None,
+        optional=False,
     ):
         self.dtype = tf.as_dtype(dtype).name if dtype is not None else None
         shape = tf.TensorShape(shape)
@@ -99,6 +102,7 @@ def __init__(
         self.min_ndim = min_ndim
         self.name = name
         self.allow_last_axis_squeeze = allow_last_axis_squeeze
+        self.optional = optional
         try:
             axes = axes or {}
             self.axes = {int(k): axes[k] for k in axes}
@@ -204,7 +208,11 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
         inputs = list_inputs

     inputs = tf.nest.flatten(inputs)
-    for x in inputs:
+    for _, (x, spec) in enumerate(zip(inputs, input_spec)):
+        if spec is None:
+            continue
+        if x is None and spec.optional:
+            continue
         # Having a shape/dtype is the only commonality of the various
         # tensor-like objects that may be passed. The most common kind of
         # invalid type we are guarding for is a Layer instance (Functional API),
@@ -224,6 +232,8 @@
     for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
         if spec is None:
             continue
+        if x is None and spec.optional:
+            continue

         shape = tf.TensorShape(x.shape)
         if shape.rank is None:
diff --git a/tf_keras/engine/training_utils_v1.py b/tf_keras/engine/training_utils_v1.py
index 0c4773c43..4017019cc 100644
--- a/tf_keras/engine/training_utils_v1.py
+++ b/tf_keras/engine/training_utils_v1.py
@@ -694,7 +694,7 @@ def standardize_input_data(
     # Check shapes compatibility.
     if shapes:
         for i in range(len(names)):
-            if shapes[i] is not None:
+            if shapes[i] is not None and data[i] is not None:
                 if tf.is_tensor(data[i]):
                     tensorshape = data[i].shape
                     if not tensorshape:
diff --git a/tf_keras/engine/training_v1.py b/tf_keras/engine/training_v1.py
index 087da4ab9..6d6f77e46 100644
--- a/tf_keras/engine/training_v1.py
+++ b/tf_keras/engine/training_v1.py
@@ -46,6 +46,9 @@
 from tf_keras.utils.mode_keys import ModeKeys

 # isort: off
+from tensorflow.python.framework.none_tensor import (
+    NoneTensorSpec,
+)
 from tensorflow.python.platform import tf_logging as logging

 try:
@@ -2220,9 +2223,9 @@ def _handle_metrics(
                             target,
                             output,
                             output_mask,
-                            weights=sample_weights[i]
-                            if sample_weights
-                            else None,
+                            weights=(
+                                sample_weights[i] if sample_weights else None
+                            ),
                         )
                     )
         return metric_results
@@ -2727,7 +2730,8 @@ def _standardize_tensors(
                 tf_utils.convert_variables_to_tensors(self.inputs)
             )
             for a, b in zip(flat_inputs, flat_expected_inputs):
-                tf.nest.assert_same_structure(a, b, expand_composites=True)
+                if type(a) is not NoneTensorSpec:
+                    tf.nest.assert_same_structure(a, b, expand_composites=True)

         if y is not None:
             # Prepare self._sample_weight_modes. List with the same length as
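
Example usage (not part of the patch): a minimal sketch of the new `optional` flag, modeled on the `test_model_with_optional_input` test added above. It assumes a TF-Keras build that already includes this change and that the package is importable as `tf_keras`.

import numpy as np
import tf_keras as keras  # assumes the standalone TF-Keras package


class CustomAdd(keras.layers.Layer):
    """Adds two tensors, tolerating a missing (None) second input."""

    def call(self, input_a, input_b=None):
        if input_b is None:
            return input_a
        return input_a + input_b


# The second input is declared optional, so callers may pass None for it.
input_a = keras.Input(shape=(2,))
input_b = keras.Input(shape=(2,), optional=True)
added = CustomAdd()(input_a, input_b)
outputs = keras.layers.Dense(2, activation="relu")(added)
model = keras.Model(inputs=[input_a, input_b], outputs=outputs)

model.compile(optimizer="sgd", loss="mse")
# With this change, the data adapters accept None for the optional input.
model.fit([np.ones((100, 2)), None], np.ones((100, 2)), batch_size=2, epochs=1)
model.predict([np.ones((100, 2)), None])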