Skip to content

Commit caa4485

Browse files
fineguyThe TensorFlow Datasets Authors
authored and
The TensorFlow Datasets Authors
committed
Fix build_test.py
PiperOrigin-RevId: 669282418
1 parent a4368fe commit caa4485

File tree

2 files changed

+35
-21
lines changed

2 files changed

+35
-21
lines changed

tensorflow_datasets/scripts/cli/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def _make_download_config(
367367
# Load the download config
368368
manual_dir = args.manual_dir
369369
if args.add_name_to_manual_dir:
370-
manual_dir = os.path.join(manual_dir, dataset_name)
370+
manual_dir = manual_dir / dataset_name
371371

372372
kwargs = {}
373373
if args.max_shard_size_mb:

tensorflow_datasets/scripts/cli/build_test.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
"""Tests for tensorflow_datasets.scripts.cli.build."""
1717

1818
import contextlib
19+
import dataclasses
20+
import multiprocessing
1921
import os
2022
import pathlib
2123
from typing import Dict, Iterator, List, Optional
2224
from unittest import mock
2325

24-
from absl.testing import parameterized
2526
from etils import epath
2627
import pytest
2728
import tensorflow_datasets as tfds
@@ -89,12 +90,12 @@ def _build(cmd_flags: str, mock_download_and_prepare: bool = True) -> List[str]:
8990
# to patch the function to record the generated_ds manually.
9091
# See:
9192
# https://stackoverflow.com/questions/64792295/how-to-get-self-instance-in-mock-mock-call-args
92-
generated_ds_names = []
93+
queue = multiprocessing.Queue()
9394

9495
def _download_and_prepare(self, *args, **kwargs):
9596
# Remove version from generated name (as only last version can be generated)
9697
full_name = '/'.join(self.info.full_name.split('/')[:-1])
97-
generated_ds_names.append(full_name)
98+
queue.put(full_name)
9899
if mock_download_and_prepare:
99100
return
100101
else:
@@ -105,6 +106,12 @@ def _download_and_prepare(self, *args, **kwargs):
105106
_download_and_prepare,
106107
):
107108
main.main(args)
109+
queue.put(None)
110+
111+
generated_ds_names = []
112+
while full_name := queue.get():
113+
generated_ds_names.append(full_name)
114+
108115
return generated_ds_names
109116

110117

@@ -139,10 +146,10 @@ def test_build_multiple():
139146
]
140147

141148

142-
@parameterized.parameters(range(5))
149+
@pytest.mark.parametrize('num_processes', range(1, 4))
143150
def test_build_parallel(num_processes):
144151
# Order is not guaranteed
145-
assert set(_build(f'trivia_qa --num-proccesses={num_processes}')) == set([
152+
assert set(_build(f'trivia_qa --num-processes={num_processes}')) == set([
146153
'trivia_qa/rc',
147154
'trivia_qa/rc.nocontext',
148155
'trivia_qa/unfiltered',
@@ -288,22 +295,29 @@ def test_download_only():
288295
mock_download.assert_called_with({'file0': 'http://data.org/file1.zip'})
289296

290297

291-
@parameterized.parameters(
292-
('--manual_dir=/a/b', {'manual_dir': '/a/b'}),
293-
('--manual_dir=/a/b --add_name_to_manual_dir', {'manual_dir': '/a/b/x'}),
294-
('--extract_dir=/a/b', {'extract_dir': '/a/b'}),
295-
('--max_examples_per_split=42', {'max_examples_per_split': 42}),
296-
('--register_checksums', {'register_checksums': True}),
297-
('--force_checksums_validation', {'force_checksums_validation': True}),
298-
('--max_shard_size_mb=128', {'max_shard_size': 128 << 20}),
299-
(
300-
'--download_config="{\'max_shard_size\': 1234}"',
301-
{'max_shard_size': 1234},
302-
),
298+
@pytest.mark.parametrize(
299+
'args,download_config_kwargs',
300+
[
301+
('--manual_dir=/a/b', {'manual_dir': epath.Path('/a/b')}),
302+
(
303+
'--manual_dir=/a/b --add_name_to_manual_dir',
304+
{'manual_dir': epath.Path('/a/b/x')},
305+
),
306+
('--extract_dir=/a/b', {'extract_dir': epath.Path('/a/b')}),
307+
('--max_examples_per_split=42', {'max_examples_per_split': 42}),
308+
('--register_checksums', {'register_checksums': True}),
309+
('--force_checksums_validation', {'force_checksums_validation': True}),
310+
('--max_shard_size_mb=128', {'max_shard_size': 128 << 20}),
311+
(
312+
'--download_config={"max_shard_size":1234}',
313+
{'max_shard_size': 1234},
314+
),
315+
],
303316
)
304317
def test_make_download_config(args: str, download_config_kwargs):
305-
args = main._parse_flags(f'tfds build x {download_config_kwargs}'.split())
318+
args = main._parse_flags(f'tfds build x {args}'.split())
306319
actual = build_lib._make_download_config(args, dataset_name='x')
307320
# Ignore the beam runner
308-
actual.replace(beam_runner=None)
309-
assert actual == tfds.download.DownloadConfig(**download_config_kwargs)
321+
actual = actual.replace(beam_runner=None)
322+
expected = tfds.download.DownloadConfig(**download_config_kwargs)
323+
assert dataclasses.asdict(actual) == dataclasses.asdict(expected)

0 commit comments

Comments
 (0)