 from etils import epath
 from tensorflow_datasets.core import example_parser
 from tensorflow_datasets.core import example_serializer
+from tensorflow_datasets.core import features as features_lib
 from tensorflow_datasets.core import file_adapters
 from tensorflow_datasets.core import hashing
 from tensorflow_datasets.core import naming

@@ -264,27 +265,34 @@ class ShardWriter:
 
   def __init__(
       self,
+      features: features_lib.FeatureConnector,
       serializer: example_serializer.Serializer,
       example_writer: ExampleWriter,
   ):
     """Initializes Writer.
 
     Args:
+      features: the features of the dataset.
       serializer: class that can serialize examples.
       example_writer: class that writes examples to disk or elsewhere.
     """
+    self._features = features
     self._serializer = serializer
     self._example_writer = example_writer
 
+  def _serialize_example(self, example: Example) -> Any:
+    """Encodes and serializes an example."""
+    return self._serializer.serialize_example(
+        self._features.encode_example(example)
+    )
+
   def write(
       self,
       examples: Iterable[KeyExample],
       path: epath.Path,
   ) -> int:
     """Returns the number of examples written to the given path."""
-    serialized_examples = [
-        (k, self._serializer.serialize_example(v)) for k, v in examples
-    ]
+    serialized_examples = [(k, self._serialize_example(v)) for k, v in examples]
     self._example_writer.write(path=path, examples=serialized_examples)
 
     return len(serialized_examples)
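
Note: below is a minimal sketch (not part of this change) of the encode-then-serialize flow that the new _serialize_example helper wraps. The FeaturesDict/Text spec and the example dict are illustrative assumptions; only encode_example, get_serialized_info, and ExampleSerializer.serialize_example are existing tensorflow_datasets APIs that the diff relies on.

    from tensorflow_datasets.core import example_serializer
    from tensorflow_datasets.core import features as features_lib

    # Illustrative feature spec; any FeatureConnector works the same way.
    features = features_lib.FeaturesDict({'text': features_lib.Text()})
    serializer = example_serializer.ExampleSerializer(
        features.get_serialized_info()
    )

    # Mirrors ShardWriter._serialize_example: encode the raw example with the
    # dataset features, then serialize the encoded example to tf.Example bytes.
    example = {'text': 'hello world'}
    serialized = serializer.serialize_example(features.encode_example(example))

As the diff suggests, the encoding step now happens inside the writer rather than before it, which is why the constructor gains the features argument.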