Skip to content

Commit 0f8d89e

Browse files
Lukasz Kaiser, copybara-github
Lukasz Kaiser
authored and committed
Make data.Serial return a generator function instead of a generator so it can compose, add default arg for easy calling.
PiperOrigin-RevId: 323283358
1 parent a2497cb commit 0f8d89e

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

trax/data/inputs.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@
8282

8383

8484
def Serial(*fns): # pylint: disable=invalid-name
85-
"""Creates an input pipeline by running all functions one after another."""
86-
generator = None
87-
for f in fastmath.tree_flatten(fns):
88-
generator = f(generator)
89-
return generator
85+
"""Combines generator functions into one that runs them in turn."""
86+
def composed_fns(generator=None):
87+
for f in fastmath.tree_flatten(fns):
88+
generator = f(generator)
89+
return generator
90+
return composed_fns
9091

9192

9293
def Log(n_steps_per_example=1, only_shapes=True): # pylint: disable=invalid-name

trax/data/inputs_test.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,16 @@ def test_batch_data(self):
7676
def test_serial(self):
7777
dataset = lambda _: ((i, i+1) for i in range(10))
7878
batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10))
79-
batch = next(batches)
79+
batch = next(batches())
80+
self.assertLen(batch, 2)
81+
self.assertEqual(batch[0].shape, (10,))
82+
83+
def test_serial_composes(self):
84+
"""Check that data.Serial works inside another data.Serial."""
85+
dataset = lambda _: ((i, i+1) for i in range(10))
86+
serial1 = data.Serial(dataset, data.Shuffle(3))
87+
batches = data.Serial(serial1, data.Batch(10))
88+
batch = next(batches())
8089
self.assertLen(batch, 2)
8190
self.assertEqual(batch[0].shape, (10,))
8291

@@ -88,7 +97,7 @@ def test_serial_with_python(self):
8897
lambda g: filter(lambda x: x[0] % 2 == 1, g),
8998
data.Batch(2)
9099
)
91-
batch = next(batches)
100+
batch = next(batches())
92101
self.assertLen(batch, 2)
93102
(xs, ys) = batch
94103
# First tuple after filtering is (1, 3) = (1, 2+1).

trax/data/tf_inputs.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ def select_from(example):
242242
dataset = dataset.map(select_from)
243243
dataset = dataset.repeat()
244244

245-
def gen(unused_arg):
245+
def gen(generator=None):
246+
del generator
246247
for example in fastmath.dataset_as_numpy(dataset):
247248
yield example
248249
return gen

0 commit comments

Comments
 (0)