Skip to content

Commit 3a1d0b0

Browse files
Allow Input to be optional to take None inputs, similar to what keras3 has.
PiperOrigin-RevId: 819935785
1 parent 2890094 commit 3a1d0b0

15 files changed

+254
-44
lines changed

tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.layers.-input-spec.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_class {
44
is_instance: "<type \'object\'>"
55
member_method {
66
name: "__init__"
7-
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
7+
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
88
}
99
member_method {
1010
name: "from_config"

tf_keras/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tf_class {
129129
}
130130
member_method {
131131
name: "__init__"
132-
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
132+
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
133133
}
134134
member_method {
135135
name: "add_loss"

tf_keras/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_class {
44
is_instance: "<type \'object\'>"
55
member_method {
66
name: "__init__"
7-
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
7+
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
88
}
99
member_method {
1010
name: "from_config"

tf_keras/api/golden/v1/tensorflow.keras.layers.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ tf_module {
482482
}
483483
member_method {
484484
name: "Input"
485-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
485+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
486486
}
487487
member_method {
488488
name: "add"

tf_keras/api/golden/v1/tensorflow.keras.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,6 @@ tf_module {
9090
}
9191
member_method {
9292
name: "Input"
93-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
93+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
9494
}
9595
}

tf_keras/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tf_class {
129129
}
130130
member_method {
131131
name: "__init__"
132-
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
132+
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
133133
}
134134
member_method {
135135
name: "add_loss"

tf_keras/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_class {
44
is_instance: "<type \'object\'>"
55
member_method {
66
name: "__init__"
7-
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
7+
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
88
}
99
member_method {
1010
name: "from_config"

tf_keras/api/golden/v2/tensorflow.keras.layers.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ tf_module {
538538
}
539539
member_method {
540540
name: "Input"
541-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
541+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
542542
}
543543
member_method {
544544
name: "add"

tf_keras/api/golden/v2/tensorflow.keras.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,6 @@ tf_module {
9595
}
9696
member_method {
9797
name: "Input"
98-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
98+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
9999
}
100100
}

tf_keras/engine/data_adapter.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def _is_tensor(v):
231231
return True
232232
return False
233233

234-
return all(_is_tensor(v) for v in flat_inputs)
234+
return all(_is_tensor(v) for v in flat_inputs if v is not None)
235235

236236
def __init__(
237237
self,
@@ -259,7 +259,7 @@ def __init__(
259259
inputs = pack_x_y_sample_weight(x, y, sample_weights)
260260

261261
num_samples = set(
262-
int(i.shape[0]) for i in tf.nest.flatten(inputs)
262+
int(i.shape[0]) for i in tf.nest.flatten(inputs) if i is not None
263263
).pop()
264264
_check_data_cardinality(inputs)
265265

@@ -386,7 +386,7 @@ def slice_inputs(self, indices_dataset, inputs):
386386

387387
def grab_batch(i, data):
388388
return tf.nest.map_structure(
389-
lambda d: tf.gather(d, i, axis=0), data
389+
lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
390390
)
391391

392392
dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
@@ -459,7 +459,7 @@ def _is_array_like(v):
459459
if not TensorLikeDataAdapter.can_handle(
460460
x, y
461461
) and not CompositeTensorDataAdapter.can_handle(x, y):
462-
return all(_is_array_like(v) for v in flat_inputs)
462+
return all(_is_array_like(v) for v in flat_inputs if v is not None)
463463
else:
464464
return False
465465

@@ -496,7 +496,7 @@ def dynamic_shape_like(t):
496496
shape[0] = None
497497
return tuple(shape)
498498

499-
flat_dtypes = [inp.dtype for inp in flat_inputs]
499+
flat_dtypes = [inp.dtype for inp in flat_inputs if inp is not None]
500500
contiguous = True
501501
if self._shuffle and self._shuffle != "batch":
502502
contiguous = False
@@ -509,15 +509,26 @@ def grab_batch(indices):
509509
# to a Tensor may force it into memory..
510510
def py_method(ind):
511511
def slice_array(data):
512+
if data is None:
513+
return None
512514
return training_utils.slice_arrays(
513515
data, ind.numpy(), contiguous=contiguous
514516
)
515517

516-
return [slice_array(inp) for inp in flat_inputs]
518+
return [
519+
slice_array(inp) for inp in flat_inputs if inp is not None
520+
]
517521

518-
flat_out = tf.py_function(py_method, [indices], flat_dtypes)
519-
for v, original_inp in zip(flat_out, flat_inputs):
520-
v.set_shape(dynamic_shape_like(original_inp))
522+
results = tf.py_function(py_method, [indices], flat_dtypes)
523+
results_it = iter(results)
524+
flat_out = []
525+
for original_inp in flat_inputs:
526+
if original_inp is None:
527+
flat_out.append(None)
528+
else:
529+
v = next(results_it)
530+
v.set_shape(dynamic_shape_like(original_inp))
531+
flat_out.append(v)
521532
return tf.nest.pack_sequence_as(inputs, flat_out)
522533

523534
dataset = indices_dataset.map(
@@ -608,8 +619,10 @@ def _is_tensor_or_composite(v):
608619
return True
609620
return _is_composite(v)
610621

611-
return any(_is_composite(v) for v in flat_inputs) and all(
612-
_is_tensor_or_composite(v) for v in flat_inputs
622+
return any(
623+
_is_composite(v) for v in flat_inputs if v is not None
624+
) and all(
625+
_is_tensor_or_composite(v) for v in flat_inputs if v is not None
613626
)
614627

615628
def __init__(
@@ -1944,14 +1957,18 @@ def single_batch_iterator(
19441957

19451958

19461959
def _check_data_cardinality(data):
1947-
num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
1960+
num_samples = set(
1961+
int(i.shape[0]) for i in tf.nest.flatten(data) if i is not None
1962+
)
19481963
if len(num_samples) > 1:
19491964
msg = "Data cardinality is ambiguous:\n"
19501965
for label, single_data in zip(["x", "y", "sample_weight"], data):
19511966
msg += " {} sizes: {}\n".format(
19521967
label,
19531968
", ".join(
1954-
str(i.shape[0]) for i in tf.nest.flatten(single_data)
1969+
str(i.shape[0])
1970+
for i in tf.nest.flatten(single_data)
1971+
if i is not None
19551972
),
19561973
)
19571974
msg += "Make sure all arrays contain the same number of samples."

0 commit comments

Comments
 (0)