
Commit 9de8f92

Feedback
1 parent 7a70825 commit 9de8f92

File tree

4 files changed

+384 -153 lines changed

sdmetrics/single_table/equalized_odds.py

Lines changed: 69 additions & 82 deletions
@@ -85,9 +85,12 @@ def _preprocess_data(
         ).astype(int)

         # Convert sensitive column to binary
-        data[sensitive_column_name] = (
-            data[sensitive_column_name] == sensitive_column_value
-        ).astype(int)
+        if pd.isna(sensitive_column_value):
+            data[sensitive_column_name] = data[sensitive_column_name].isna().astype(int)
+        else:
+            data[sensitive_column_name] = (
+                data[sensitive_column_name] == sensitive_column_value
+            ).astype(int)

         # Handle categorical columns for XGBoost
         for column, column_meta in metadata['columns'].items():
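
The pd.isna branch exists because equality comparisons never match NaN (NaN == NaN is False in pandas and numpy), so the old one-liner turned a NaN sensitive value into an all-zero group. A minimal standalone sketch of the conversion; the series and values are illustrative, not from this repo:

    import numpy as np
    import pandas as pd

    def to_binary(series, sensitive_value):
        # NaN needs its own branch: `series == np.nan` is False everywhere,
        # so the equality path can never select the NaN group.
        if pd.isna(sensitive_value):
            return series.isna().astype(int)
        return (series == sensitive_value).astype(int)

    gender = pd.Series(['F', 'M', np.nan, 'F'])
    print(to_binary(gender, 'F').tolist())     # [1, 0, 0, 1]
    print(to_binary(gender, np.nan).tolist())  # [0, 0, 1, 0]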
@@ -162,32 +165,32 @@ def _compute_equalized_odds_score(cls, prediction_counts):
         true_group = prediction_counts['True']
         false_group = prediction_counts['False']

-        # Compute TPR for each group
-        tpr_true = true_group['true_positive'] / max(
-            1, true_group['true_positive'] + true_group['false_negative']
-        )
-        tpr_false = false_group['true_positive'] / max(
-            1, false_group['true_positive'] + false_group['false_negative']
-        )
-
-        # Compute FPR for each group
-        fpr_true = true_group['false_positive'] / max(
-            1, true_group['false_positive'] + true_group['true_negative']
-        )
-        fpr_false = false_group['false_positive'] / max(
-            1, false_group['false_positive'] + false_group['true_negative']
-        )
+        # Compute TPR and FPR for each group using a loop
+        tpr = {}
+        fpr = {}
+        for group_name, group in [('True', true_group), ('False', false_group)]:
+            tpr[group_name] = group['true_positive'] / max(
+                1, group['true_positive'] + group['false_negative']
+            )
+            fpr[group_name] = group['false_positive'] / max(
+                1, group['false_positive'] + group['true_negative']
+            )

         # Compute fairness scores
-        tpr_fairness = 1 - abs(tpr_true - tpr_false)
-        fpr_fairness = 1 - abs(fpr_true - fpr_false)
+        tpr_fairness = 1 - abs(tpr['True'] - tpr['False'])
+        fpr_fairness = 1 - abs(fpr['True'] - fpr['False'])

         # Final equalized odds score is minimum of the two fairness scores
         return min(tpr_fairness, fpr_fairness)

     @classmethod
     def _evaluate_dataset(
-        cls, train_data, validation_data, prediction_column_name, sensitive_column_name
+        cls,
+        train_data,
+        validation_data,
+        prediction_column_name,
+        sensitive_column_name,
+        sensitive_column_value,
     ):
         """Evaluate equalized odds for a single dataset."""
         # Train classifier
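
The loop computes the same TPR/FPR quantities as the unrolled version it replaces; only the structure changed. A small worked example of the scoring logic, using hypothetical prediction counts shaped like the dict this method receives:

    prediction_counts = {
        'True': {'true_positive': 30, 'false_negative': 10,
                 'false_positive': 10, 'true_negative': 30},
        'False': {'true_positive': 10, 'false_negative': 30,
                  'false_positive': 20, 'true_negative': 20},
    }

    tpr = {}
    fpr = {}
    for name, group in prediction_counts.items():
        # max(1, ...) guards against division by zero for an empty group.
        tpr[name] = group['true_positive'] / max(
            1, group['true_positive'] + group['false_negative'])
        fpr[name] = group['false_positive'] / max(
            1, group['false_positive'] + group['true_negative'])

    tpr_fairness = 1 - abs(tpr['True'] - tpr['False'])  # 1 - |0.75 - 0.25| = 0.5
    fpr_fairness = 1 - abs(fpr['True'] - fpr['False'])  # 1 - |0.25 - 0.50| = 0.75
    print(min(tpr_fairness, fpr_fairness))              # 0.5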
@@ -202,12 +205,21 @@ def _evaluate_dataset
         # Compute prediction counts
         prediction_counts = cls._compute_prediction_counts(predictions, actuals, sensitive_values)

+        # Format the keys to include sensitive column value as in the spec
+        formatted_counts = {}
+        for key, counts in prediction_counts.items():
+            if key == 'True':
+                formatted_key = f'{sensitive_column_value}=True'
+            else:
+                formatted_key = f'{sensitive_column_value}=False'
+            formatted_counts[formatted_key] = counts
+
         # Compute equalized odds score
         equalized_odds_score = cls._compute_equalized_odds_score(prediction_counts)

         return {
             'equalized_odds': equalized_odds_score,
-            'prediction_counts_validation': prediction_counts,
+            'prediction_counts_validation': formatted_counts,
         }

     @classmethod
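
After this change, callers see counts keyed by group label instead of bare 'True'/'False'. For a hypothetical sensitive_column_value of 'F', the loop above is equivalent to this dict comprehension (assuming, as the loop implies, that the only keys are 'True' and 'False'):

    sensitive_column_value = 'F'  # illustrative value, not from the diff
    prediction_counts = {'True': {'true_positive': 3}, 'False': {'true_positive': 7}}

    formatted_counts = {
        f'{sensitive_column_value}={key}': counts
        for key, counts in prediction_counts.items()
    }
    print(formatted_counts)
    # {'F=True': {'true_positive': 3}, 'F=False': {'true_positive': 7}}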
@@ -341,74 +353,49 @@ def compute_breakdown(
             )
         )

-        real_training_processed = cls._preprocess_data(
-            real_training_data,
-            prediction_column_name,
-            positive_class_label,
-            sensitive_column_name,
-            sensitive_column_value,
-            metadata,
-        )
-
-        synthetic_processed = cls._preprocess_data(
-            synthetic_data,
-            prediction_column_name,
-            positive_class_label,
-            sensitive_column_name,
-            sensitive_column_value,
-            metadata,
-        )
-
-        real_validation_processed = cls._preprocess_data(
-            real_validation_data,
-            prediction_column_name,
-            positive_class_label,
-            sensitive_column_name,
-            sensitive_column_value,
-            metadata,
-        )
-
-        # Validate data sufficiency for training sets
-        cls._validate_data_sufficiency(
-            real_training_processed,
-            prediction_column_name,
-            sensitive_column_name,
-            1,
-            1,  # Using 1 since we converted to binary
-        )
-
-        cls._validate_data_sufficiency(
-            synthetic_processed,
-            prediction_column_name,
-            sensitive_column_name,
-            1,
-            1,  # Using 1 since we converted to binary
-        )
+        processed_data = []
+        for data in [real_training_data, synthetic_data, real_validation_data]:
+            processed_data.append(
+                cls._preprocess_data(
+                    data,
+                    prediction_column_name,
+                    positive_class_label,
+                    sensitive_column_name,
+                    sensitive_column_value,
+                    metadata,
+                )
+            )

-        # Evaluate both datasets
-        real_results = cls._evaluate_dataset(
-            real_training_processed,
-            real_validation_processed,
-            prediction_column_name,
-            sensitive_column_name,
-        )
+        real_training_processed, synthetic_processed, real_validation_processed = processed_data
+        results = []
+        for data in [real_training_processed, synthetic_processed]:
+            cls._validate_data_sufficiency(
+                data,
+                prediction_column_name,
+                sensitive_column_name,
+                1,
+                1,  # Using 1 since we converted to binary
+            )

-        synthetic_results = cls._evaluate_dataset(
-            synthetic_processed,
-            real_validation_processed,
-            prediction_column_name,
-            sensitive_column_name,
-        )
+            results.append(
+                cls._evaluate_dataset(
+                    data,
+                    real_validation_processed,
+                    prediction_column_name,
+                    sensitive_column_name,
+                    sensitive_column_value,
+                )
+            )

         # Compute final improvement score
-        real_score = real_results['equalized_odds']
-        synthetic_score = synthetic_results['equalized_odds']
+        real_score = results[0]['equalized_odds']
+        synthetic_score = results[1]['equalized_odds']
         improvement_score = (synthetic_score - real_score) / 2 + 0.5

         return {
             'score': improvement_score,
-            'real_training_data': real_score,
-            'synthetic_data': synthetic_score,
+            'real_training_data': results[0],
+            'synthetic_data': results[1],
         }

     @classmethod
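
The improvement formula maps synthetic_score - real_score from [-1, 1] onto [0, 1], so 0.5 means the synthetic data neither helped nor hurt fairness. A quick numeric check of the mapping as written in this diff:

    def improvement(real_score, synthetic_score):
        # (synthetic - real) is in [-1, 1]; halving and shifting lands in [0, 1].
        return (synthetic_score - real_score) / 2 + 0.5

    print(improvement(0.5, 0.5))    # 0.5  -> no change
    print(improvement(0.25, 0.75))  # 0.75 -> synthetic improved fairness
    print(improvement(0.75, 0.25))  # 0.25 -> synthetic hurt fairness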

sdmetrics/single_table/utils.py

Lines changed: 5 additions & 1 deletion
@@ -64,7 +64,11 @@ def _validate_column_values_exist(dataframes_dict, column_value_pairs):
     """
     for df_name, df in dataframes_dict.items():
         for column_name, value in column_value_pairs:
-            if value not in df[column_name].to_numpy():
+            column_values = df[column_name]
+            value_exists = (pd.isna(value) and column_values.isna().any()) or (
+                value in column_values.to_numpy()
+            )
+            if not value_exists:
                 raise ValueError(f"Value '{value}' not found in {df_name}['{column_name}']")
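
The rewritten check is needed because membership tests against a numpy array use elementwise equality, and NaN is never equal to itself, so `value in column.to_numpy()` is always False for NaN even when the column does contain NaNs. A standalone demonstration (the series is illustrative):

    import numpy as np
    import pandas as pd

    column_values = pd.Series([1.0, np.nan, 3.0])

    # Plain membership never matches NaN:
    print(np.nan in column_values.to_numpy())  # False, despite the NaN above

    # The NaN-aware check from this commit:
    value = np.nan
    value_exists = (pd.isna(value) and column_values.isna().any()) or (
        value in column_values.to_numpy()
    )
    print(value_exists)  # True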
