Skip to content

Commit 2d40658

Browse files
authored
Merge branch 'main' into feature-branch-download-demo
2 parents d4aa5da + 7a720b2 commit 2d40658

File tree

11 files changed

+353 -42 lines changed

sdv/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def check_sdv_versions_and_warn(synthesizer):
273273
"""
274274
current_community_version = getattr(version, 'community', None)
275275
current_enterprise_version = getattr(version, 'enterprise', None)
276-
if synthesizer._fitted:
276+
if getattr(synthesizer, '_fitted', False):
277277
fitted_community_version = getattr(synthesizer, '_fitted_sdv_version', None)
278278
fitted_enterprise_version = getattr(synthesizer, '_fitted_sdv_enterprise_version', None)
279279
community_mismatch = current_community_version != fitted_community_version

sdv/multi_table/_dayz_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def detect_relationship_parameters(data, metadata):
3434
'child_table_name': rel_tuple[1],
3535
'parent_primary_key': rel_tuple[2],
3636
'child_foreign_key': rel_tuple[3],
37-
'min_cardinality': cardinality_table['cardinality'].min(),
38-
'max_cardinality': cardinality_table['cardinality'].max(),
37+
'min_cardinality': int(cardinality_table['cardinality'].min()),
38+
'max_cardinality': int(cardinality_table['cardinality'].max()),
3939
})
4040

4141
return relationship_parameters

sdv/multi_table/dayz.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Multi-Table DayZ parameter detection and creation."""
22

3+
from sdv.cag._utils import _is_list_of_type
34
from sdv.errors import SynthesizerInputError, SynthesizerProcessingError
45
from sdv.multi_table._dayz_utils import create_parameters_multi_table
56
from sdv.single_table.dayz import _validate_parameter_structure, _validate_tables_parameter
@@ -48,8 +49,10 @@ def _validate_cardinality_bounds(relationship):
4849

4950

5051
def _validate_relationship_structure(dayz_parameters):
51-
if not isinstance(dayz_parameters.get('relationships', []), list):
52-
raise SynthesizerProcessingError("The 'relationships' parameter value must be a list.")
52+
if not _is_list_of_type(dayz_parameters.get('relationships', []), dict):
53+
raise SynthesizerProcessingError(
54+
"The 'relationships' parameter value must be a list of dictionaries."
55+
)
5356

5457
for relationship in dayz_parameters.get('relationships', []):
5558
unknown_relationship_parameters = relationship.keys() - set(RELATIONSHIP_PARAMETER_KEYS)
@@ -160,18 +163,18 @@ def __init__(self, metadata, locales=['en_US']):
160163
)
161164

162165
@classmethod
163-
def create_parameters(cls, data, metadata, output_filename=None):
166+
def create_parameters(cls, data, metadata, filepath=None):
164167
"""Create parameters for the DayZSynthesizer.
165168
166169
Args:
167170
data (dict[str, pd.DataFrame]): The input data.
168171
metadata (Metadata): The metadata object.
169-
output_filename (str, optional): The output filename for the parameters.
172+
filepath (str, optional): The output filename for the parameters.
170173
171174
Returns:
172175
dict: The created parameters.
173176
"""
174-
return create_parameters_multi_table(data, metadata, output_filename)
177+
return create_parameters_multi_table(data, metadata, filepath)
175178

176179
@staticmethod
177180
def validate_parameters(metadata, parameters):

sdv/single_table/_dayz_utils.py

Lines changed: 84 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,79 @@ def detect_table_parameters(data):
1818
return {'num_rows': len(data)}
1919

2020

21+
def _compute_missing_values_proportion(series):
22+
"""Compute missing value proportion with a safe fallback for empty series."""
23+
if len(series) == 0:
24+
return 0.0
25+
26+
value = float(series.isna().mean())
27+
return 0.0 if pd.isna(value) else value
28+
29+
30+
def _detect_numerical_column_parameters(series):
31+
"""Detect numerical-specific parameters with fallbacks when undetectable.
32+
33+
Returns only keys that can be reliably detected (no None values).
34+
"""
35+
params = {}
36+
non_null = series.dropna()
37+
if non_null.empty:
38+
return params
39+
40+
try:
41+
num_decimal_digits = learn_rounding_digits(series)
42+
if isinstance(num_decimal_digits, int) and num_decimal_digits >= 0:
43+
params['num_decimal_digits'] = num_decimal_digits
44+
except Exception:
45+
pass
46+
47+
min_value = non_null.min()
48+
max_value = non_null.max()
49+
if not pd.isna(min_value):
50+
params['min_value'] = min_value.item() if hasattr(min_value, 'item') else float(min_value)
51+
if not pd.isna(max_value):
52+
params['max_value'] = max_value.item() if hasattr(max_value, 'item') else float(max_value)
53+
54+
return params
55+
56+
57+
def _detect_datetime_column_parameters(series, column_metadata):
58+
"""Detect datetime-specific parameters with fallbacks when undetectable.
59+
60+
Returns only keys that can be reliably detected (no None values).
61+
"""
62+
params = {}
63+
datetime_format = column_metadata.get('datetime_format', None)
64+
if datetime_format:
65+
datetime_column = pd.to_datetime(series, format=datetime_format, errors='coerce')
66+
else:
67+
datetime_column = pd.to_datetime(series, errors='coerce')
68+
69+
non_na = datetime_column[~pd.isna(datetime_column)]
70+
if non_na.empty:
71+
return params
72+
73+
start_dt = non_na.min()
74+
end_dt = non_na.max()
75+
if datetime_format:
76+
params['start_timestamp'] = start_dt.strftime(datetime_format)
77+
params['end_timestamp'] = end_dt.strftime(datetime_format)
78+
else:
79+
params['start_timestamp'] = start_dt.strftime('%Y-%m-%d %H:%M:%S')
80+
params['end_timestamp'] = end_dt.strftime('%Y-%m-%d %H:%M:%S')
81+
82+
return params
83+
84+
85+
def _detect_categorical_column_parameters(series):
86+
"""Detect categorical/boolean parameters."""
87+
categorical_values = series.dropna().unique()
88+
if len(categorical_values) == 0:
89+
return {}
90+
91+
return {'category_values': categorical_values.tolist()}
92+
93+
2194
def detect_column_parameters(data, metadata, table_name):
2295
"""Detect all column-level Dayz parameters.
2396
@@ -37,46 +110,28 @@ def detect_column_parameters(data, metadata, table_name):
37110
table_metadata = metadata.tables[table_name]
38111
column_parameters = {}
39112
for column_name, column_metadata in table_metadata.columns.items():
40-
column_parameters[column_name] = {}
41113
sdtype = column_metadata['sdtype']
114+
params = {}
42115
if sdtype == 'numerical':
43-
column_parameters[column_name] = {
44-
'num_decimal_digits': learn_rounding_digits(data[column_name]),
45-
'min_value': data[column_name].min(),
46-
'max_value': data[column_name].max(),
47-
}
116+
params.update(_detect_numerical_column_parameters(data[column_name]))
48117
elif sdtype == 'datetime':
49-
datetime_format = column_metadata.get('datetime_format', None)
50-
if datetime_format:
51-
datetime_column = pd.to_datetime(
52-
data[column_name], format=datetime_format, errors='coerce'
53-
)
54-
start_timestamp = datetime_column.min().strftime(datetime_format)
55-
end_timestamp = datetime_column.max().strftime(datetime_format)
56-
57-
else:
58-
datetime_column = pd.to_datetime(data[column_name], errors='coerce')
59-
start_timestamp = str(datetime_column.min())
60-
end_timestamp = str(datetime_column.max())
61-
62-
column_parameters[column_name] = {
63-
'start_timestamp': start_timestamp,
64-
'end_timestamp': end_timestamp,
65-
}
118+
params.update(_detect_datetime_column_parameters(data[column_name], column_metadata))
66119
elif sdtype == 'categorical':
67-
column_parameters[column_name] = {
68-
'category_values': data[column_name].dropna().unique().tolist()
69-
}
120+
params.update(_detect_categorical_column_parameters(data[column_name]))
70121

71-
column_parameters[column_name]['missing_values_proportion'] = float(
72-
data[column_name].isna().mean()
73-
)
122+
params['missing_values_proportion'] = _compute_missing_values_proportion(data[column_name])
123+
column_parameters[column_name] = params
74124

75125
return {'columns': column_parameters}
76126

77127

78128
def create_parameters(data, metadata, output_filename):
79129
"""Detect and create a parameter dict for the DayZ model."""
130+
if len(data) == 0:
131+
raise ValueError('Data is empty')
132+
if len(metadata.tables) == 0:
133+
raise ValueError('Metadata is empty')
134+
80135
metadata.validate()
81136
datas = data if isinstance(data, dict) else {metadata._get_single_table_name(): data}
82137
metadata.validate_data(datas)

sdv/single_table/dayz.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,13 @@ def _validate_parameters(metadata, parameters):
227227
"""
228228
metadata.validate()
229229
_validate_parameter_structure(parameters)
230+
231+
if len(metadata.tables) > 1:
232+
raise SynthesizerProcessingError(
233+
'Invalid metadata provided for single-table DayZSynthesizer. The metadata contains '
234+
'multiple tables. Please use multi-table DayZSynthesizer instead.'
235+
)
236+
230237
if 'relationships' in parameters:
231238
msg = (
232239
"Invalid DayZ parameter 'relationships' for single-table DayZSynthesizer. "
@@ -248,18 +255,18 @@ def __init__(self, metadata, locales=['en_US']):
248255
)
249256

250257
@classmethod
251-
def create_parameters(cls, data, metadata, output_filename=None):
258+
def create_parameters(cls, data, metadata, filepath=None):
252259
"""Create parameters for the DayZ synthesizer.
253260
254261
Args:
255262
data (pd.DataFrame): The input data.
256263
metadata (Metadata): The metadata object.
257-
output_filename (str, optional): The output filename for the parameters.
264+
filepath (str, optional): The output filename for the parameters.
258265
259266
Returns:
260267
dict: The created parameters.
261268
"""
262-
return create_parameters(data, metadata, output_filename)
269+
return create_parameters(data, metadata, filepath)
263270

264271
@staticmethod
265272
def validate_parameters(metadata, parameters):

tests/integration/multi_table/test_dayz.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Integration tests for DayZ parameter detection."""
22

3+
import pandas as pd
4+
import pytest
5+
36
from sdv.datasets.demo import download_demo
47
from sdv.metadata import Metadata
58
from sdv.multi_table import DayZSynthesizer
@@ -173,3 +176,23 @@ def test_validate_parameters(self):
173176

174177
# Run and Assert
175178
DayZSynthesizer.validate_parameters(metadata, dayz_parameters)
179+
180+
def test_create_parameters_empty_data(self):
181+
"""Test creating parameters with empty data."""
182+
# Setup
183+
data = {}
184+
metadata = Metadata()
185+
186+
# Run and Assert
187+
with pytest.raises(ValueError, match='Data is empty'):
188+
DayZSynthesizer.create_parameters(data, metadata)
189+
190+
def test_create_parameters_empty_metadata(self):
191+
"""Test creating parameters with empty metadata."""
192+
# Setup
193+
data = {'table': pd.DataFrame({'col1': [1, 2, 3]})}
194+
metadata = Metadata()
195+
196+
# Run and Assert
197+
with pytest.raises(ValueError, match='Metadata is empty'):
198+
DayZSynthesizer.create_parameters(data, metadata)

tests/integration/single_table/test_dayz.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Integration tests for DayZ parameter detection."""
22

3+
import pandas as pd
4+
import pytest
5+
36
from sdv.datasets.demo import download_demo
47
from sdv.metadata import Metadata
58
from sdv.single_table import DayZSynthesizer
@@ -100,3 +103,23 @@ def test_validate_parameters(self):
100103

101104
# Run and Assert
102105
DayZSynthesizer.validate_parameters(metadata, dayz_parameters)
106+
107+
def test_create_parameters_empty_data(self):
108+
"""Test creating parameters with empty data."""
109+
# Setup
110+
data = pd.DataFrame()
111+
metadata = Metadata()
112+
113+
# Run and Assert
114+
with pytest.raises(ValueError, match='Data is empty'):
115+
DayZSynthesizer.create_parameters(data, metadata)
116+
117+
def test_create_parameters_empty_metadata(self):
118+
"""Test creating parameters with empty metadata."""
119+
# Setup
120+
data = pd.DataFrame({'col1': [1, 2, 3]})
121+
metadata = Metadata()
122+
123+
# Run and Assert
124+
with pytest.raises(ValueError, match='Metadata is empty'):
125+
DayZSynthesizer.create_parameters(data, metadata)

tests/unit/multi_table/test_dayz.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import pytest
66

7+
from sdv.datasets.demo import download_demo
78
from sdv.errors import SynthesizerInputError, SynthesizerProcessingError
89
from sdv.metadata import Metadata
910
from sdv.multi_table.dayz import (
@@ -74,7 +75,7 @@ def test__validate_relationship_structure():
7475

7576
# Run and Assert
7677
expected_bad_relationships_value_msg = re.escape(
77-
"The 'relationships' parameter value must be a list."
78+
"The 'relationships' parameter value must be a list of dictionaries."
7879
)
7980
with pytest.raises(SynthesizerProcessingError, match=expected_bad_relationships_value_msg):
8081
_validate_relationship_structure(bad_relationships_value)
@@ -320,3 +321,33 @@ def test_validate_parameters(self, mock__validate_parameters, metadata):
320321

321322
# Assert
322323
mock__validate_parameters.assert_called_once_with(metadata, dayz_parameters)
324+
325+
def test__validate_relationships_is_list_of_dicts(self, metadata):
326+
"""Test that 'relationships' must be a list of dicts."""
327+
# Run and Assert
328+
expected_msg = re.escape(
329+
"The 'relationships' parameter value must be a list of dictionaries."
330+
)
331+
with pytest.raises(SynthesizerProcessingError, match=expected_msg):
332+
DayZSynthesizer.validate_parameters(metadata, {'relationships': {'a', 'b', 'c'}})
333+
334+
with pytest.raises(SynthesizerProcessingError, match=expected_msg):
335+
DayZSynthesizer.validate_parameters(metadata, {'relationships': ['a', 'b', 'c']})
336+
337+
def test__validate_min_cardinality_allows_zero(self):
338+
"""Test that min_cardinality=0 is allowed and does not raise."""
339+
# Setup
340+
data, metadata = download_demo('multi_table', 'financial_v1')
341+
dayz_parameters = DayZSynthesizer.create_parameters(data, metadata)
342+
dayz_parameters['relationships'] = [
343+
{
344+
'parent_table_name': 'district',
345+
'parent_primary_key': 'district_id',
346+
'child_table_name': 'account',
347+
'child_foreign_key': 'district_id',
348+
'min_cardinality': 0,
349+
}
350+
]
351+
352+
# Run
353+
DayZSynthesizer.validate_parameters(metadata, dayz_parameters)

tests/unit/single_table/test__dayz_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,10 @@ def test_create_parameters(mock_detect_table, mock_detect_column, tmp_path):
9696
}
9797
}
9898

99-
data = pd.DataFrame()
99+
data = pd.DataFrame({'col1': [1, 2, 3], 'col2': [4, 5, 6]})
100100
metadata = Mock()
101101
metadata._get_single_table_name.return_value = 'table_name'
102+
metadata.tables = {'table_name': Mock()}
102103

103104
# Run
104105
result = create_parameters(data, metadata, output_filename=output_filename)

0 commit comments

Comments
 (0)