11"""Utils method for data augmentation metrics."""
22
3- import pandas as pd
4-
5- from sdmetrics ._utils_metadata import _process_data_with_metadata , _validate_single_table_metadata
6-
7-
def _validate_tables(real_training_data, synthetic_data, real_validation_data):
    """Check that all three input tables are pandas DataFrames.

    Args:
        real_training_data: Object passed as the real training table.
        synthetic_data: Object passed as the synthetic table.
        real_validation_data: Object passed as the real validation table.

    Raises:
        ValueError: If at least one of the three arguments is not a
            ``pd.DataFrame``.
    """
    candidates = (real_training_data, synthetic_data, real_validation_data)
    if all(isinstance(candidate, pd.DataFrame) for candidate in candidates):
        return

    raise ValueError(
        '`real_training_data`, `synthetic_data` and `real_validation_data` must be '
        'pandas DataFrames.'
    )
16-
17-
def _validate_prediction_column_name(prediction_column_name):
    """Check that the prediction column name is a string.

    Args:
        prediction_column_name: Value supplied as the name of the column to predict.

    Raises:
        TypeError: If ``prediction_column_name`` is not a ``str``.
    """
    if isinstance(prediction_column_name, str):
        return

    raise TypeError('`prediction_column_name` must be a string.')
22-
23-
def _validate_classifier(classifier):
    """Check that the requested classifier is one this module accepts.

    Args:
        classifier: Value supplied as the classifier name.

    Raises:
        TypeError: If ``classifier`` is neither ``None`` nor a ``str``.
        ValueError: If ``classifier`` is anything other than ``'XGBoost'``.
    """
    has_valid_type = classifier is None or isinstance(classifier, str)
    if not has_valid_type:
        raise TypeError('`classifier` must be a string or None.')

    # NOTE(review): ``None`` passes the type check above but is still rejected
    # here, because it does not compare equal to 'XGBoost'.
    if classifier != 'XGBoost':
        raise ValueError('Currently only `XGBoost` is supported as classifier.')
3+ from sdmetrics ._utils_metadata import _validate_single_table_metadata
4+ from sdmetrics .single_table .utils import (
5+ _validate_classifier ,
6+ _validate_data_and_metadata ,
7+ _validate_prediction_column_name ,
8+ _validate_tables ,
9+ )
3110
3211
3312def _validate_fixed_recall_value (fixed_recall_value ):
@@ -53,51 +32,6 @@ def _validate_parameters(
5332 _validate_fixed_recall_value (fixed_recall_value )
5433
5534
def _validate_data_and_metadata(
    real_training_data,
    synthetic_data,
    real_validation_data,
    metadata,
    prediction_column_name,
    minority_class_label,
):
    """Validate the data and metadata of the Data Augmentation metrics.

    Args:
        real_training_data: Table indexed by ``prediction_column_name``; its
            labels define the set of allowed labels.
        synthetic_data: Table whose ``prediction_column_name`` labels must all
            appear in the real training data.
        real_validation_data: Table whose ``prediction_column_name`` column must
            contain ``minority_class_label``.
        metadata: Mapping with a ``'columns'`` entry that maps column names to
            mappings holding an ``'sdtype'`` key.
        prediction_column_name: Name of the column to predict.
        minority_class_label: Label that must be present in both real tables.

    Raises:
        ValueError: If the prediction column is missing from the metadata, is
            not ``categorical`` or ``boolean``, if the minority label is absent
            from the real training or validation data, or if the synthetic data
            contains labels that the real training data does not.
    """
    if prediction_column_name not in metadata['columns']:
        raise ValueError(
            f'The column `{prediction_column_name}` is not described in the metadata.'
            ' Please update your metadata.'
        )

    if metadata['columns'][prediction_column_name]['sdtype'] not in ('categorical', 'boolean'):
        raise ValueError(
            f'The column `{prediction_column_name}` must be either categorical or boolean.'
            ' Please update your metadata.'
        )

    if minority_class_label not in real_training_data[prediction_column_name].unique():
        raise ValueError(
            f'The value `{minority_class_label}` is not present in the column '
            f'`{prediction_column_name}` for the real training data.'
        )

    if minority_class_label not in real_validation_data[prediction_column_name].unique():
        raise ValueError(
            f"The metric can't be computed because the value `{minority_class_label}` "
            f'is not present in the column `{prediction_column_name}` for the real validation data.'
            ' The `precision` and `recall` are undefined for this case.'
        )

    synthetic_labels = set(synthetic_data[prediction_column_name].unique())
    real_labels = set(real_training_data[prediction_column_name].unique())
    unexpected_labels = synthetic_labels - real_labels
    if unexpected_labels:
        # Labels may be booleans or numbers (boolean sdtype is allowed above), and
        # ``str.join`` only accepts strings, so stringify before sorting and joining.
        to_print = "', '".join(sorted(str(label) for label in unexpected_labels))
        raise ValueError(
            f'The `{prediction_column_name}` column must have the same values in the real '
            'and synthetic data. The following values are present in the synthetic data and'
            f" not the real data: '{to_print}'"
        )
99-
100-
10135def _validate_inputs (
10236 real_training_data ,
10337 synthetic_data ,
@@ -127,13 +61,12 @@ def _validate_inputs(
12761 minority_class_label ,
12862 )
12963
130-
def _process_data_with_metadata_ml_efficacy_metrics(
    real_training_data, synthetic_data, real_validation_data, metadata
):
    """Run each of the three tables through ``_process_data_with_metadata``.

    Every table is processed with the same ``metadata`` and with ``True`` as
    the third positional argument, in the order training, synthetic, validation.

    Args:
        real_training_data: Real training table.
        synthetic_data: Synthetic table.
        real_validation_data: Real validation table.
        metadata: Metadata forwarded unchanged to ``_process_data_with_metadata``.

    Returns:
        A 3-tuple ``(real_training_data, synthetic_data, real_validation_data)``
        holding the processed results in that order.
    """
    tables = (real_training_data, synthetic_data, real_validation_data)
    processed_training, processed_synthetic, processed_validation = (
        _process_data_with_metadata(table, metadata, True) for table in tables
    )
    return processed_training, processed_synthetic, processed_validation
64+ synthetic_labels = set (synthetic_data [prediction_column_name ].unique ())
65+ real_labels = set (real_training_data [prediction_column_name ].unique ())
66+ if not synthetic_labels .issubset (real_labels ):
67+ to_print = "', '" .join (sorted (synthetic_labels - real_labels ))
68+ raise ValueError (
69+ f'The `{ prediction_column_name } ` column must have the same values in the real '
70+ 'and synthetic data. The following values are present in the synthetic data and'
71+ f" not the real data: '{ to_print } '"
72+ )
0 commit comments