@@ -231,7 +231,7 @@ def _is_tensor(v):
                 return True
             return False
 
-        return all(_is_tensor(v) for v in flat_inputs)
+        return all(_is_tensor(v) for v in flat_inputs if v is not None)
 
     def __init__(
         self,
@@ -259,7 +259,7 @@ def __init__(
         inputs = pack_x_y_sample_weight(x, y, sample_weights)
 
         num_samples = set(
-            int(i.shape[0]) for i in tf.nest.flatten(inputs)
+            int(i.shape[0]) for i in tf.nest.flatten(inputs) if i is not None
        ).pop()
         _check_data_cardinality(inputs)
 
@@ -386,7 +386,7 @@ def slice_inputs(self, indices_dataset, inputs):
 
         def grab_batch(i, data):
             return tf.nest.map_structure(
-                lambda d: tf.gather(d, i, axis=0), data
+                lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
             )
 
         dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
@@ -459,7 +459,7 @@ def _is_array_like(v):
         if not TensorLikeDataAdapter.can_handle(
             x, y
         ) and not CompositeTensorDataAdapter.can_handle(x, y):
-            return all(_is_array_like(v) for v in flat_inputs)
+            return all(_is_array_like(v) for v in flat_inputs if v is not None)
         else:
             return False
 
@@ -496,7 +496,7 @@ def dynamic_shape_like(t):
             shape[0] = None
             return tuple(shape)
 
-        flat_dtypes = [inp.dtype for inp in flat_inputs]
+        flat_dtypes = [inp.dtype for inp in flat_inputs if inp is not None]
         contiguous = True
         if self._shuffle and self._shuffle != "batch":
             contiguous = False
@@ -509,15 +509,26 @@ def grab_batch(indices):
             # to a Tensor may force it into memory..
             def py_method(ind):
                 def slice_array(data):
+                    if data is None:
+                        return None
                     return training_utils.slice_arrays(
                         data, ind.numpy(), contiguous=contiguous
                     )
 
-                return [slice_array(inp) for inp in flat_inputs]
+                return [
+                    slice_array(inp) for inp in flat_inputs if inp is not None
+                ]
 
-            flat_out = tf.py_function(py_method, [indices], flat_dtypes)
-            for v, original_inp in zip(flat_out, flat_inputs):
-                v.set_shape(dynamic_shape_like(original_inp))
+            results = tf.py_function(py_method, [indices], flat_dtypes)
+            results_it = iter(results)
+            flat_out = []
+            for original_inp in flat_inputs:
+                if original_inp is None:
+                    flat_out.append(None)
+                else:
+                    v = next(results_it)
+                    v.set_shape(dynamic_shape_like(original_inp))
+                    flat_out.append(v)
             return tf.nest.pack_sequence_as(inputs, flat_out)
 
         dataset = indices_dataset.map(
@@ -608,8 +619,10 @@ def _is_tensor_or_composite(v):
                 return True
             return _is_composite(v)
 
-        return any(_is_composite(v) for v in flat_inputs) and all(
-            _is_tensor_or_composite(v) for v in flat_inputs
+        return any(
+            _is_composite(v) for v in flat_inputs if v is not None
+        ) and all(
+            _is_tensor_or_composite(v) for v in flat_inputs if v is not None
         )
 
     def __init__(
@@ -1944,14 +1957,18 @@ def single_batch_iterator(
 
 
 def _check_data_cardinality(data):
-    num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
+    num_samples = set(
+        int(i.shape[0]) for i in tf.nest.flatten(data) if i is not None
+    )
     if len(num_samples) > 1:
         msg = "Data cardinality is ambiguous:\n"
         for label, single_data in zip(["x", "y", "sample_weight"], data):
             msg += "  {} sizes: {}\n".format(
                 label,
                 ", ".join(
-                    str(i.shape[0]) for i in tf.nest.flatten(single_data)
+                    str(i.shape[0])
+                    for i in tf.nest.flatten(single_data)
+                    if i is not None
                 ),
             )
         msg += "Make sure all arrays contain the same number of samples."
0 commit comments