Commit 80cfbdb

Split methods into helpers
1 parent 2d40658 commit 80cfbdb

File tree

3 files changed: +94 -98 lines changed

  sdv/datasets/demo.py
  tests/unit/multi_table/test_dayz.py
  tests/unit/single_table/test_dayz.py

sdv/datasets/demo.py

Lines changed: 94 additions & 65 deletions
@@ -368,6 +368,46 @@ def is_metainfo_yaml(key):
         yield dataset_name, key
 
 
+def _get_info_from_yaml_key(yaml_key):
+    """Load and parse YAML metadata from an S3 key."""
+    raw = _get_data_from_bucket(yaml_key)
+    return yaml.safe_load(raw) or {}
+
+
+def _parse_size_mb(size_mb_val, dataset_name):
+    """Parse the size (MB) value into a float or NaN with logging on failures."""
+    try:
+        return float(size_mb_val) if size_mb_val is not None else np.nan
+    except (ValueError, TypeError):
+        LOGGER.info(
+            f'Invalid dataset-size-mb {size_mb_val} for dataset '
+            f'{dataset_name}; defaulting to NaN.'
+        )
+        return np.nan
+
+
+def _parse_num_tables(num_tables_val, dataset_name):
+    """Parse the num-tables value into an int or NaN with logging on failures."""
+    if isinstance(num_tables_val, str):
+        try:
+            num_tables_val = float(num_tables_val)
+        except (ValueError, TypeError):
+            LOGGER.info(
+                f'Could not cast num_tables_val {num_tables_val} to float for '
+                f'dataset {dataset_name}; defaulting to NaN.'
+            )
+            num_tables_val = np.nan
+
+    try:
+        return int(num_tables_val) if not pd.isna(num_tables_val) else np.nan
+    except (ValueError, TypeError):
+        LOGGER.info(
+            f'Invalid num-tables {num_tables_val} for '
+            f'dataset {dataset_name} when parsing as int.'
+        )
+        return np.nan
+
+
 def get_available_demos(modality):
     """Get demo datasets available for a ``modality``.
@@ -387,38 +427,10 @@ def get_available_demos(modality):
     tables_info = defaultdict(list)
     for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality):
         try:
-            raw = _get_data_from_bucket(yaml_key)
-            info = yaml.safe_load(raw) or {}
+            info = _get_info_from_yaml_key(yaml_key)
 
-            size_mb_val = info.get('dataset-size-mb')
-            try:
-                size_mb = float(size_mb_val) if size_mb_val is not None else np.nan
-            except (ValueError, TypeError):
-                LOGGER.info(
-                    f'Invalid dataset-size-mb {size_mb_val} for dataset '
-                    f'{dataset_name}; defaulting to NaN.'
-                )
-                size_mb = np.nan
-
-            num_tables_val = info.get('num-tables', np.nan)
-            if isinstance(num_tables_val, str):
-                try:
-                    num_tables_val = float(num_tables_val)
-                except (ValueError, TypeError):
-                    LOGGER.info(
-                        f'Could not cast num_tables_val {num_tables_val} to float for '
-                        f'dataset {dataset_name}; defaulting to NaN.'
-                    )
-                    num_tables_val = np.nan
-
-            try:
-                num_tables = int(num_tables_val) if not pd.isna(num_tables_val) else np.nan
-            except (ValueError, TypeError):
-                LOGGER.info(
-                    f'Invalid num-tables {num_tables_val} for '
-                    f'dataset {dataset_name} when parsing as int.'
-                )
-                num_tables = np.nan
+            size_mb = _parse_size_mb(info.get('dataset-size-mb'), dataset_name)
+            num_tables = _parse_num_tables(info.get('num-tables', np.nan), dataset_name)
 
             tables_info['dataset_name'].append(dataset_name)
             tables_info['size_MB'].append(size_mb)
@@ -456,6 +468,54 @@ def _find_text_key(contents, dataset_prefix, filename):
     return None
 
 
+def _validate_text_file_content(modality, output_filepath, filename):
+    """Validation for the text file content method."""
+    _validate_modalities(modality)
+    if output_filepath is not None and not str(output_filepath).endswith('.txt'):
+        fname = (filename or '').lower()
+        file_type = 'README' if 'readme' in fname else 'source'
+        raise ValueError(
+            f'The {file_type} can only be saved as a txt file. '
+            "Please provide a filepath ending in '.txt'"
+        )
+
+
+def _raise_warnings(filename, output_filepath):
+    """Warn about missing text resources for a dataset."""
+    if (filename or '').upper() == 'README.TXT':
+        msg = 'No README information is available for this dataset.'
+    elif (filename or '').upper() == 'SOURCE.TXT':
+        msg = 'No source information is available for this dataset.'
+    else:
+        msg = f'No {filename} information is available for this dataset.'
+
+    if output_filepath:
+        msg = f'{msg} The requested file ({output_filepath}) will not be created.'
+
+    warnings.warn(msg, DemoResourceNotFoundWarning)
+
+
+def _save_document(text, output_filepath, filename, dataset_name):
+    """Persist ``text`` to ``output_filepath`` if provided."""
+    if not output_filepath:
+        return
+
+    if os.path.exists(str(output_filepath)):
+        raise ValueError(
+            f"A file named '{output_filepath}' already exists. "
+            'Please specify a different filepath.'
+        )
+
+    try:
+        parent = os.path.dirname(str(output_filepath))
+        if parent:
+            os.makedirs(parent, exist_ok=True)
+        with open(output_filepath, 'w', encoding='utf-8') as f:
+            f.write(text)
+    except Exception:
+        LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')
+
+
 def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
     """Fetch text file content under the dataset prefix.
@@ -473,29 +533,13 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
         str or None:
             The decoded text contents if the file exists, otherwise ``None``.
     """
-    _validate_modalities(modality)
-    if output_filepath is not None and not str(output_filepath).endswith('.txt'):
-        fname = (filename or '').lower()
-        file_type = 'README' if 'readme' in fname else 'source'
-        raise ValueError(
-            f'The {file_type} can only be saved as a txt file. '
-            "Please provide a filepath ending in '.txt'"
-        )
+    _validate_text_file_content(modality, output_filepath, filename)
 
     dataset_prefix = f'{modality}/{dataset_name}/'
     contents = _list_objects(dataset_prefix)
-
     key = _find_text_key(contents, dataset_prefix, filename)
     if not key:
-        if file_type in ('README', 'SOURCE'):
-            msg = f'No {file_type} information is available for this dataset.'
-        else:
-            msg = f'No {filename} information is available for this dataset.'
-
-        if output_filepath:
-            msg = f'{msg} The requested file ({output_filepath}) will not be created.'
-
-        warnings.warn(msg, DemoResourceNotFoundWarning)
+        _raise_warnings(filename, output_filepath)
         return None
 
     try:
@@ -505,22 +549,7 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
         return None
 
     text = raw.decode('utf-8', errors='replace')
-    if output_filepath:
-        if os.path.exists(str(output_filepath)):
-            raise ValueError(
-                f"A file named '{output_filepath}' already exists. "
-                'Please specify a different filepath.'
-            )
-        try:
-            parent = os.path.dirname(str(output_filepath))
-            if parent:
-                os.makedirs(parent, exist_ok=True)
-            with open(output_filepath, 'w', encoding='utf-8') as f:
-                f.write(text)
-
-        except Exception:
-            LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')
-            pass
+    _save_document(text, output_filepath, filename, dataset_name)
 
     return text
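
For reference, a quick illustration of how the extracted parsing helpers behave (not part of the commit; it assumes a build containing this change so the private helpers are importable from sdv.datasets.demo, and 'demo' is a placeholder dataset name):

import numpy as np

from sdv.datasets.demo import _parse_num_tables, _parse_size_mb

# Valid values pass through; None and unparseable strings fall back to NaN.
assert _parse_size_mb('12.5', 'demo') == 12.5
assert np.isnan(_parse_size_mb(None, 'demo'))       # None -> NaN without logging
assert np.isnan(_parse_size_mb('abc', 'demo'))      # logged, then NaN

# Strings are cast through float first, so '3' comes back as the int 3.
assert _parse_num_tables('3', 'demo') == 3
assert np.isnan(_parse_num_tables('abc', 'demo'))   # logged, then NaN
assert np.isnan(_parse_num_tables(np.nan, 'demo'))  # NaN skips the int() cast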

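Likewise, a minimal sketch of the new _save_document helper's contract, as read from the diff (the output path here is hypothetical):

from sdv.datasets.demo import _save_document

# No output path: the helper returns silently without touching the filesystem.
_save_document('some text', None, 'README.txt', 'demo')

# With a path, parent directories are created as needed and the text is written.
# An existing file raises ValueError; other save errors are only logged.
_save_document('some text', 'out/readme.txt', 'README.txt', 'demo')
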
tests/unit/multi_table/test_dayz.py

Lines changed: 0 additions & 18 deletions
@@ -333,21 +333,3 @@ def test__validate_relationships_is_list_of_dicts(self, metadata):
 
         with pytest.raises(SynthesizerProcessingError, match=expected_msg):
             DayZSynthesizer.validate_parameters(metadata, {'relationships': ['a', 'b', 'c']})
-
-    def test__validate_min_cardinality_allows_zero(self):
-        """Test that min_cardinality=0 is allowed and does not raise."""
-        # Setup
-        data, metadata = download_demo('multi_table', 'financial_v1')
-        dayz_parameters = DayZSynthesizer.create_parameters(data, metadata)
-        dayz_parameters['relationships'] = [
-            {
-                'parent_table_name': 'district',
-                'parent_primary_key': 'district_id',
-                'child_table_name': 'account',
-                'child_foreign_key': 'district_id',
-                'min_cardinality': 0,
-            }
-        ]
-
-        # Run
-        DayZSynthesizer.validate_parameters(metadata, dayz_parameters)

tests/unit/single_table/test_dayz.py

Lines changed: 0 additions & 15 deletions
@@ -419,21 +419,6 @@ def test__validate_parameters_errors_with_multi_table_metadata(self):
         with pytest.raises(SynthesizerProcessingError, match=expected_error_msg):
             _validate_parameters(metadata, dayz_parameters)
 
-    def test__validate_parameters_errors_with_relationships(self):
-        """Test that single-table validation errors if relationships are provided."""
-        # Setup
-        data, metadata = download_demo('multi_table', 'financial_v1')
-        dayz_parameters = MultiTableDayZSynthesizer.create_parameters(data, metadata)
-        del dayz_parameters['relationships']
-
-        # Run and Assert
-        expected_error_msg = re.escape(
-            'Invalid metadata provided for single-table DayZSynthesizer. The metadata contains '
-            'multiple tables. Please use multi-table DayZSynthesizer instead.'
-        )
-        with pytest.raises(SynthesizerProcessingError, match=expected_error_msg):
-            DayZSynthesizer.validate_parameters(metadata, dayz_parameters)
-
     def test_create_parameters_returns_valid_defaults(self):
         """Test create_parameters returns valid defaults."""
         # Setup
