Skip to content

Commit 0428a9c

Browse files
tomvdwThe TensorFlow Datasets Authors
authored andcommitted
Encode examples before writing them in the ShardBasedBuilder
PiperOrigin-RevId: 807804529
1 parent a5d9682 commit 0428a9c

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

tensorflow_datasets/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tensorflow_datasets.core.dataset_builder import BuilderConfig
2424
from tensorflow_datasets.core.dataset_builder import DatasetBuilder
2525
from tensorflow_datasets.core.dataset_builder import GeneratorBasedBuilder
26+
from tensorflow_datasets.core.dataset_builder import ShardBasedBuilder
2627
from tensorflow_datasets.core.dataset_info import BeamMetadataDict
2728
from tensorflow_datasets.core.dataset_info import DatasetIdentity
2829
from tensorflow_datasets.core.dataset_info import DatasetInfo

tensorflow_datasets/core/split_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def submit_shard_based_generation(
184184
serializer = example_serializer.ExampleSerializer(serialized_info)
185185

186186
shard_writer = writer_lib.ShardWriter(
187+
features=self._features,
187188
serializer=serializer,
188189
example_writer=self._example_writer,
189190
)

tensorflow_datasets/core/writer.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from etils import epath
3737
from tensorflow_datasets.core import example_parser
3838
from tensorflow_datasets.core import example_serializer
39+
from tensorflow_datasets.core import features as features_lib
3940
from tensorflow_datasets.core import file_adapters
4041
from tensorflow_datasets.core import hashing
4142
from tensorflow_datasets.core import naming
@@ -264,27 +265,34 @@ class ShardWriter:
264265

265266
def __init__(
266267
self,
268+
features: features_lib.FeatureConnector,
267269
serializer: example_serializer.Serializer,
268270
example_writer: ExampleWriter,
269271
):
270272
"""Initializes Writer.
271273
272274
Args:
275+
features: the features of the dataset.
273276
serializer: class that can serialize examples.
274277
example_writer: class that writes examples to disk or elsewhere.
275278
"""
279+
self._features = features
276280
self._serializer = serializer
277281
self._example_writer = example_writer
278282

283+
def _serialize_example(self, example: Example) -> Any:
284+
"""Encodes and serializes an example."""
285+
return self._serializer.serialize_example(
286+
self._features.encode_example(example)
287+
)
288+
279289
def write(
280290
self,
281291
examples: Iterable[KeyExample],
282292
path: epath.Path,
283293
) -> int:
284294
"""Returns the number of examples written to the given path."""
285-
serialized_examples = [
286-
(k, self._serializer.serialize_example(v)) for k, v in examples
287-
]
295+
serialized_examples = [(k, self._serialize_example(v)) for k, v in examples]
288296
self._example_writer.write(path=path, examples=serialized_examples)
289297

290298
return len(serialized_examples)

0 commit comments

Comments
 (0)