File tree 3 files changed +19
-8
lines changed
3 files changed +19
-8
lines changed Original file line number Diff line number Diff line change 82
82
83
83
84
84
def Serial (* fns ): # pylint: disable=invalid-name
85
- """Creates an input pipeline by running all functions one after another."""
86
- generator = None
87
- for f in fastmath .tree_flatten (fns ):
88
- generator = f (generator )
89
- return generator
85
+ """Combines generator functions into one that runs them in turn."""
86
+ def composed_fns (generator = None ):
87
+ for f in fastmath .tree_flatten (fns ):
88
+ generator = f (generator )
89
+ return generator
90
+ return composed_fns
90
91
91
92
92
93
def Log (n_steps_per_example = 1 , only_shapes = True ): # pylint: disable=invalid-name
Original file line number Diff line number Diff line change @@ -76,7 +76,16 @@ def test_batch_data(self):
76
76
def test_serial (self ):
77
77
dataset = lambda _ : ((i , i + 1 ) for i in range (10 ))
78
78
batches = data .Serial (dataset , data .Shuffle (3 ), data .Batch (10 ))
79
- batch = next (batches )
79
+ batch = next (batches ())
80
+ self .assertLen (batch , 2 )
81
+ self .assertEqual (batch [0 ].shape , (10 ,))
82
+
83
+ def test_serial_composes (self ):
84
+ """Check that data.Serial works inside another data.Serial."""
85
+ dataset = lambda _ : ((i , i + 1 ) for i in range (10 ))
86
+ serial1 = data .Serial (dataset , data .Shuffle (3 ))
87
+ batches = data .Serial (serial1 , data .Batch (10 ))
88
+ batch = next (batches ())
80
89
self .assertLen (batch , 2 )
81
90
self .assertEqual (batch [0 ].shape , (10 ,))
82
91
@@ -88,7 +97,7 @@ def test_serial_with_python(self):
88
97
lambda g : filter (lambda x : x [0 ] % 2 == 1 , g ),
89
98
data .Batch (2 )
90
99
)
91
- batch = next (batches )
100
+ batch = next (batches () )
92
101
self .assertLen (batch , 2 )
93
102
(xs , ys ) = batch
94
103
# First tuple after filtering is (1, 3) = (1, 2+1).
Original file line number Diff line number Diff line change @@ -242,7 +242,8 @@ def select_from(example):
242
242
dataset = dataset .map (select_from )
243
243
dataset = dataset .repeat ()
244
244
245
- def gen (unused_arg ):
245
+ def gen (generator = None ):
246
+ del generator
246
247
for example in fastmath .dataset_as_numpy (dataset ):
247
248
yield example
248
249
return gen
You can’t perform that action at this time.
0 commit comments