|
22 | 22 | import tensorflow as tf
|
23 | 23 | from scipy import stats
|
24 | 24 | from sklearn.calibration import calibration_curve
|
| 25 | +from sklearn.model_selection import KFold, cross_val_score |
| 26 | +from sklearn.neural_network import MLPClassifier |
25 | 27 |
|
26 | 28 | from bayesflow.default_settings import MMD_BANDWIDTH_LIST
|
27 | 29 | from bayesflow.exceptions import ShapeError
|
@@ -517,3 +519,87 @@ def aggregated_rmse(x_true, x_pred):
|
517 | 519 | return aggregated_error(
|
518 | 520 | x_true=x_true, x_pred=x_pred, inner_error_fun=root_mean_squared_error, outer_aggregation_fun=np.mean
|
519 | 521 | )
|
| 522 | + |
| 523 | + |
def c2st(
    source_samples,
    target_samples,
    n_folds=5,
    scoring="accuracy",
    normalize=True,
    seed=123,
    hidden_units_per_dim=16,
    aggregate_output=True,
):
    """C2ST metric [1] using an sklearn neural network classifier (i.e., MLP).
    Code adapted from https://github.com/sbi-benchmark/sbibm/blob/main/sbibm/metrics/c2st.py

    A classifier is trained to distinguish source samples (label 0) from target
    samples (label 1) and evaluated with shuffled k-fold cross-validation.

    [1] Lopez-Paz, D., & Oquab, M. (2016). Revisiting classifier two-sample tests. arXiv:1610.06545.

    Parameters
    ----------
    source_samples : np.ndarray or tf.Tensor
        Source samples (e.g., approximate posterior samples), shape (num_source, num_dims)
    target_samples : np.ndarray or tf.Tensor
        Target samples (e.g., samples from a reference posterior), shape (num_target, num_dims)
    n_folds : int, optional, default: 5
        Number of folds in k-fold cross-validation for the classifier evaluation
    scoring : str, optional, default: "accuracy"
        Evaluation score of the sklearn MLP classifier
    normalize : bool, optional, default: True
        Whether the data shall be z-standardized relative to source_samples
    seed : int, optional, default: 123
        RNG seed for the MLP and k-fold CV
    hidden_units_per_dim : int, optional, default: 16
        Number of hidden units in the MLP, relative to the input dimensions.
        Example: source samples are 5D, hidden_units_per_dim=16 -> 80 hidden units per layer
    aggregate_output : bool, optional, default: True
        Whether to return a single value aggregated over all cross-validation runs
        or all values from all runs. If left at default, the empirical mean will be returned

    Returns
    -------
    c2st_score : np.float32 scalar array or np.ndarray of shape (n_folds,)
        The mean cross-validation score if ``aggregate_output`` is True,
        otherwise the individual scores of all cross-validation folds.

    Raises
    ------
    ShapeError
        If the inputs are not 2-dimensional or differ in their second dimension.
    """

    x = np.array(source_samples)
    y = np.array(target_samples)

    # Validate on the converted arrays so that list inputs produce a ShapeError
    # instead of an AttributeError/IndexError.
    if x.ndim != 2 or y.ndim != 2:
        raise ShapeError(
            f"source_samples and target_samples must be 2-dimensional (observations, dimensions), "
            f"found: source_samples with shape {x.shape}, target_samples with shape {y.shape}"
        )

    num_dims = x.shape[1]
    if num_dims != y.shape[1]:
        raise ShapeError(
            f"source_samples and target_samples can have different number of observations (1st dim) "
            f"but must have the same dimensionality (2nd dim), "
            f"found: source_samples {x.shape[1]}, target_samples {y.shape[1]}"
        )

    if normalize:
        x_mean = np.mean(x, axis=0)
        x_std = np.std(x, axis=0)
        # A constant source dimension has zero std; dividing by it would yield
        # NaN/inf. Leave such dimensions centered but unscaled.
        x_std = np.where(x_std == 0, 1.0, x_std)
        x = (x - x_mean) / x_std
        y = (y - x_mean) / x_std

    num_hidden = hidden_units_per_dim * num_dims
    clf = MLPClassifier(
        activation="relu",
        hidden_layer_sizes=(num_hidden, num_hidden),
        max_iter=10000,
        solver="adam",
        random_state=seed,
    )

    # Source samples are labelled 0, target samples are labelled 1.
    data = np.concatenate((x, y))
    target = np.concatenate(
        (
            np.zeros((x.shape[0],)),
            np.ones((y.shape[0],)),
        )
    )

    shuffle = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    scores = cross_val_score(clf, data, target, cv=shuffle, scoring=scoring)

    if aggregate_output:
        return np.asarray(np.mean(scores)).astype(np.float32)

    # Previously this path implicitly returned None.
    return scores
0 commit comments