Skip to content

Commit 60cb270

Browse files
Merge pull request #91 from elseml/Development
Fixes #90 (axis labeling)
2 parents d0c9f1e + 298d4b6 commit 60cb270

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

bayesflow/computational_utilities.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,8 @@ def expected_calibration_error(m_true, m_pred, num_bins=10):
279279
# Extract number of models and prepare containers
280280
n_models = m_true.shape[1]
281281
cal_errs = []
282-
probs = []
282+
probs_true = []
283+
probs_pred = []
283284

284285
# Loop for each model and compute calibration errs per bin
285286
for k in range(n_models):
@@ -295,8 +296,9 @@ def expected_calibration_error(m_true, m_pred, num_bins=10):
295296
cal_err = np.sum(np.abs(prob_true - prob_pred) * (bin_total[nonzero] / len(y_true)))
296297

297298
cal_errs.append(cal_err)
298-
probs.append((prob_true, prob_pred))
299-
return cal_errs, probs
299+
probs_true.append(prob_true)
300+
probs_pred.append(prob_pred)
301+
return cal_errs, probs_true, probs_pred
300302

301303

302304
def maximum_mean_discrepancy(source_samples, target_samples, kernel="gaussian", mmd_weight=1.0, minimum=0.0):

bayesflow/diagnostics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,7 @@ def plot_calibration_curves(
10571057
# Determine n_subplots dynamically
10581058
n_row = int(np.ceil(num_models / 6))
10591059
n_col = int(np.ceil(num_models / n_row))
1060-
cal_errs, cal_probs = expected_calibration_error(true_models, pred_models, num_bins)
1060+
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)
10611061

10621062
# Initialize figure
10631063
if fig_size is None:
@@ -1073,7 +1073,7 @@ def plot_calibration_curves(
10731073
ax = axarr
10741074
for j in range(num_models):
10751075
# Plot calibration curve
1076-
ax[j].plot(cal_probs[j][0], cal_probs[j][1], color=color)
1076+
ax[j].plot(probs_pred[j], probs_true[j], color=color)
10771077

10781078
# Plot AB line
10791079
ax[j].plot(ax[j].get_xlim(), ax[j].get_xlim(), "--", color="darkgrey")

0 commit comments

Comments
 (0)