16
16
"""Tests for tensorflow_datasets.scripts.cli.build."""
17
17
18
18
import contextlib
19
+ import dataclasses
20
+ import multiprocessing
19
21
import os
20
22
import pathlib
21
23
from typing import Dict , Iterator , List , Optional
22
24
from unittest import mock
23
25
24
- from absl .testing import parameterized
25
26
from etils import epath
26
27
import pytest
27
28
import tensorflow_datasets as tfds
@@ -89,12 +90,12 @@ def _build(cmd_flags: str, mock_download_and_prepare: bool = True) -> List[str]:
89
90
# to patch the function to record the generated_ds manually.
90
91
# See:
91
92
# https://stackoverflow.com/questions/64792295/how-to-get-self-instance-in-mock-mock-call-args
92
- generated_ds_names = []
93
+ queue = multiprocessing . Queue ()
93
94
94
95
def _download_and_prepare (self , * args , ** kwargs ):
95
96
# Remove version from generated name (as only last version can be generated)
96
97
full_name = '/' .join (self .info .full_name .split ('/' )[:- 1 ])
97
- generated_ds_names . append (full_name )
98
+ queue . put (full_name )
98
99
if mock_download_and_prepare :
99
100
return
100
101
else :
@@ -105,6 +106,12 @@ def _download_and_prepare(self, *args, **kwargs):
105
106
_download_and_prepare ,
106
107
):
107
108
main .main (args )
109
+ queue .put (None )
110
+
111
+ generated_ds_names = []
112
+ while full_name := queue .get ():
113
+ generated_ds_names .append (full_name )
114
+
108
115
return generated_ds_names
109
116
110
117
@@ -139,10 +146,10 @@ def test_build_multiple():
139
146
]
140
147
141
148
142
- @parameterized . parameters ( range (5 ))
149
+ @pytest . mark . parametrize ( 'num_processes' , range (1 , 4 ))
143
150
def test_build_parallel (num_processes ):
144
151
# Order is not guaranteed
145
- assert set (_build (f'trivia_qa --num-proccesses ={ num_processes } ' )) == set ([
152
+ assert set (_build (f'trivia_qa --num-processes ={ num_processes } ' )) == set ([
146
153
'trivia_qa/rc' ,
147
154
'trivia_qa/rc.nocontext' ,
148
155
'trivia_qa/unfiltered' ,
@@ -288,22 +295,29 @@ def test_download_only():
288
295
mock_download .assert_called_with ({'file0' : 'http://data.org/file1.zip' })
289
296
290
297
291
- @parameterized .parameters (
292
- ('--manual_dir=/a/b' , {'manual_dir' : '/a/b' }),
293
- ('--manual_dir=/a/b --add_name_to_manual_dir' , {'manual_dir' : '/a/b/x' }),
294
- ('--extract_dir=/a/b' , {'extract_dir' : '/a/b' }),
295
- ('--max_examples_per_split=42' , {'max_examples_per_split' : 42 }),
296
- ('--register_checksums' , {'register_checksums' : True }),
297
- ('--force_checksums_validation' , {'force_checksums_validation' : True }),
298
- ('--max_shard_size_mb=128' , {'max_shard_size' : 128 << 20 }),
299
- (
300
- '--download_config="{\' max_shard_size\' : 1234}"' ,
301
- {'max_shard_size' : 1234 },
302
- ),
298
+ @pytest .mark .parametrize (
299
+ 'args,download_config_kwargs' ,
300
+ [
301
+ ('--manual_dir=/a/b' , {'manual_dir' : epath .Path ('/a/b' )}),
302
+ (
303
+ '--manual_dir=/a/b --add_name_to_manual_dir' ,
304
+ {'manual_dir' : epath .Path ('/a/b/x' )},
305
+ ),
306
+ ('--extract_dir=/a/b' , {'extract_dir' : epath .Path ('/a/b' )}),
307
+ ('--max_examples_per_split=42' , {'max_examples_per_split' : 42 }),
308
+ ('--register_checksums' , {'register_checksums' : True }),
309
+ ('--force_checksums_validation' , {'force_checksums_validation' : True }),
310
+ ('--max_shard_size_mb=128' , {'max_shard_size' : 128 << 20 }),
311
+ (
312
+ '--download_config={"max_shard_size":1234}' ,
313
+ {'max_shard_size' : 1234 },
314
+ ),
315
+ ],
303
316
)
304
317
def test_make_download_config (args : str , download_config_kwargs ):
305
- args = main ._parse_flags (f'tfds build x { download_config_kwargs } ' .split ())
318
+ args = main ._parse_flags (f'tfds build x { args } ' .split ())
306
319
actual = build_lib ._make_download_config (args , dataset_name = 'x' )
307
320
# Ignore the beam runner
308
- actual .replace (beam_runner = None )
309
- assert actual == tfds .download .DownloadConfig (** download_config_kwargs )
321
+ actual = actual .replace (beam_runner = None )
322
+ expected = tfds .download .DownloadConfig (** download_config_kwargs )
323
+ assert dataclasses .asdict (actual ) == dataclasses .asdict (expected )
0 commit comments