@@ -62,7 +62,7 @@ def plot_recovery(
62
62
https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html
63
63
64
64
Important: Posterior aggregates play no special role in Bayesian inference and should only
65
- be used heuristically. For instanec , in the case of multi-modal posteriors, common point
65
+ be used heuristically. For instance , in the case of multi-modal posteriors, common point
66
66
estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) mean nothing.
67
67
68
68
Parameters
@@ -71,7 +71,7 @@ def plot_recovery(
71
71
The posterior draws obtained from n_data_sets
72
72
prior_samples : np.ndarray of shape (n_data_sets, n_params)
73
73
The prior draws (true parameters) obtained for generating the n_data_sets
74
- point_agg : callable, optional, default: np.median
74
+ point_agg : callable, optional, default: `` np.median``
75
75
The function to apply to the posterior draws to get a point estimate for each marginal.
76
76
The default computes the marginal median for each marginal posterior as a robust
77
77
point estimate.
@@ -89,13 +89,13 @@ def plot_recovery(
89
89
metric_fontsize : int, optional, default: 16
90
90
The font size of the goodness-of-fit metric (if provided)
91
91
tick_fontsize : int, optional, default: 12
92
- The font size of the axis ticklabels
92
+ The font size of the axis tick labels
93
93
add_corr : bool, optional, default: True
94
94
A flag for adding correlation between true and estimates to the plot
95
95
add_r2 : bool, optional, default: True
96
96
A flag for adding R^2 between true and estimates to the plot
97
97
color : str, optional, default: '#8f2727'
98
- The color for the true vs. estimated scatter points and errobars
98
+ The color for the true vs. estimated scatter points and error bars
99
99
100
100
Returns
101
101
-------
@@ -144,7 +144,7 @@ def plot_recovery(
144
144
if i >= n_params :
145
145
break
146
146
147
- # Add scatter and errorbars
147
+ # Add scatter and error bars
148
148
if uncertainty_agg is not None :
149
149
_ = ax .errorbar (prior_samples [:, i ], est [:, i ], yerr = u [:, i ], fmt = "o" , alpha = 0.5 , color = color )
150
150
else :
@@ -242,7 +242,7 @@ def plot_z_score_contraction(
242
242
243
243
post_contraction = 1 - (posterior_variance / prior_variance)
244
244
245
- In other words, the posterior is a proxy for the reduction in ucnertainty gained by
245
+ In other words, the posterior is a proxy for the reduction in uncertainty gained by
246
246
replacing the prior with the posterior. The ideal posterior contraction tends to 1.
247
247
Contraction near zero indicates that the posterior variance is almost identical to
248
248
the prior variance for the particular marginal parameter distribution.
@@ -253,7 +253,7 @@ def plot_z_score_contraction(
253
253
Toward a principled Bayesian workflow in cognitive science.
254
254
Psychological methods, 26(1), 103.
255
255
256
- Also available at https://arxiv.org/abs/1904.12765
256
+ Paper also available at https://arxiv.org/abs/1904.12765
257
257
258
258
Parameters
259
259
----------
@@ -272,7 +272,7 @@ def plot_z_score_contraction(
272
272
tick_fontsize : int, optional, default: 12
273
273
The font size of the axis ticklabels
274
274
color : str, optional, default: '#8f2727'
275
- The color for the true vs. estimated scatter points and errobars
275
+ The color for the true vs. estimated scatter points and error bars
276
276
277
277
Returns
278
278
-------
@@ -887,21 +887,21 @@ def plot_losses(
887
887
lw = lw_val ,
888
888
label = "Validation" ,
889
889
)
890
- # Schmuck
890
+ # Schmuck
891
891
ax .set_xlabel ("Training step #" , fontsize = label_fontsize )
892
892
ax .set_ylabel ("Loss value" , fontsize = label_fontsize )
893
893
sns .despine (ax = ax )
894
894
ax .grid (alpha = grid_alpha )
895
895
ax .set_title (train_losses .columns [i ], fontsize = title_fontsize )
896
896
# Only add legend if there is a validation curve
897
- if val_losses is not None :
897
+ if val_losses is not None or moving_average :
898
898
ax .legend (fontsize = legend_fontsize )
899
899
f .tight_layout ()
900
900
return f
901
901
902
902
903
903
def plot_prior2d (prior , param_names = None , n_samples = 2000 , height = 2.5 , color = "#8f2727" , ** kwargs ):
904
- """Creates pairplots for a given joint prior.
904
+ """Creates pair-plots for a given joint prior.
905
905
906
906
Parameters
907
907
----------
@@ -913,7 +913,7 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
913
913
The number of random draws from the joint prior
914
914
height : float, optional, default: 2.5
915
915
The height of the pair plot
916
- color : str, optional, defailt : '#8f2727'
916
+ color : str, optional, default : '#8f2727'
917
917
The color of the plot
918
918
**kwargs : dict, optional
919
919
Additional keyword arguments passed to the sns.PairGrid constructor
@@ -943,14 +943,16 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
943
943
# Generate plots
944
944
g = sns .PairGrid (data_to_plot , height = height , ** kwargs )
945
945
g .map_diag (sns .histplot , fill = True , color = color , alpha = 0.9 , kde = True )
946
- # Kernel density estimation (KDE) may not always be possible (e.g. with parameters whose correlation is close to 1 or -1).
946
+
947
+ # Kernel density estimation (KDE) may not always be possible
948
+ # (e.g. with parameters whose correlation is close to 1 or -1).
947
949
# In this scenario, a scatter-plot is generated instead.
948
950
try :
949
951
g .map_lower (sns .kdeplot , fill = True , color = color , alpha = 0.9 )
950
952
except Exception as e :
951
- logging .warn ("KDE failed due to the following exception:\n " + repr (e ) + "\n Substituting scatter plot." )
952
- g .map_lower (plt . scatter , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
953
- g .map_upper (plt . scatter , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
953
+ logging .warning ("KDE failed due to the following exception:\n " + repr (e ) + "\n Substituting scatter plot." )
954
+ g .map_lower (sns . scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
955
+ g .map_upper (sns . scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
954
956
955
957
# Add grids
956
958
for i in range (dim ):
@@ -961,8 +963,8 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
961
963
962
964
963
965
def plot_latent_space_2d (z_samples , height = 2.5 , color = "#8f2727" , ** kwargs ):
964
- """Creates pairplots for the latent space learned by the inference network. Enables
965
- visual inspection of the the latent space and whether its structrue corresponds to the
966
+ """Creates pair plots for the latent space learned by the inference network. Enables
967
+ visual inspection of the latent space and whether its structure corresponds to the
966
968
one enforced by the optimization criterion.
967
969
968
970
Parameters
@@ -971,7 +973,7 @@ def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
971
973
The latent samples computed through a forward pass of the inference network.
972
974
height : float, optional, default: 2.5
973
975
The height of the pair plot.
974
- color : str, optional, defailt : '#8f2727'
976
+ color : str, optional, default : '#8f2727'
975
977
The color of the plot
976
978
**kwargs : dict, optional
977
979
Additional keyword arguments passed to the sns.PairGrid constructor
@@ -996,7 +998,7 @@ def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
996
998
g = sns .PairGrid (data_to_plot , height = height , ** kwargs )
997
999
g .map_diag (sns .histplot , fill = True , color = color , alpha = 0.9 , kde = True )
998
1000
g .map_lower (sns .kdeplot , fill = True , color = color , alpha = 0.9 )
999
- g .map_upper (plt . scatter , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
1001
+ g .map_upper (sns . scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color )
1000
1002
1001
1003
# Add grids
1002
1004
for i in range (z_dim ):
@@ -1060,6 +1062,8 @@ def plot_calibration_curves(
1060
1062
# Determine n_subplots dynamically
1061
1063
n_row = int (np .ceil (num_models / 6 ))
1062
1064
n_col = int (np .ceil (num_models / n_row ))
1065
+
1066
+ # Compute calibration
1063
1067
cal_errs , probs_true , probs_pred = expected_calibration_error (true_models , pred_models , num_bins )
1064
1068
1065
1069
# Initialize figure
@@ -1094,8 +1098,6 @@ def plot_calibration_curves(
1094
1098
ax [j ].spines ["top" ].set_visible (False )
1095
1099
ax [j ].set_xlim ([0 - epsilon , 1 + epsilon ])
1096
1100
ax [j ].set_ylim ([0 - epsilon , 1 + epsilon ])
1097
- ax [j ].set_xlabel ("Predicted probability" , fontsize = label_fontsize )
1098
- ax [j ].set_ylabel ("True probability" , fontsize = label_fontsize )
1099
1101
ax [j ].set_xticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
1100
1102
ax [j ].set_yticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
1101
1103
ax [j ].grid (alpha = 0.5 )
@@ -1111,6 +1113,18 @@ def plot_calibration_curves(
1111
1113
size = legend_fontsize ,
1112
1114
)
1113
1115
1116
+ # Only add x-labels to the bottom row
1117
+ bottom_row = axarr if n_row == 1 else axarr [0 ] if n_col == 1 else axarr [n_row - 1 , :]
1118
+ for _ax in bottom_row :
1119
+ _ax .set_xlabel ("Predicted probability" , fontsize = label_fontsize )
1120
+
1121
+ # Only add y-labels to left-most row
1122
+ if n_row == 1 : # if there is only one row, the ax array is 1D
1123
+ ax [0 ].set_ylabel ("True probability" , fontsize = label_fontsize )
1124
+ else : # if there is more than one row, the ax array is 2D
1125
+ for _ax in axarr [:, 0 ]:
1126
+ _ax .set_ylabel ("True probability" , fontsize = label_fontsize )
1127
+
1114
1128
fig .tight_layout ()
1115
1129
return fig
1116
1130
@@ -1223,32 +1237,31 @@ def plot_mmd_hypothesis_test(
1223
1237
1224
1238
Parameters
1225
1239
----------
1226
- mmd_null: np.ndarray
1227
- samples from the MMD sampling distribution under the null hypothesis "the model is well-specified"
1228
- mmd_observed: float
1229
- observed MMD value
1230
- alpha_level: float
1231
- rejection probability (type I error)
1232
- null_color: color
1233
- color for the H0 sampling distribution
1234
- observed_color: color
1235
- color for the observed MMD
1236
- alpha_color: color
1237
- color for the rejection area
1240
+ mmd_null : np.ndarray
1241
+ The samples from the MMD sampling distribution under the null hypothesis "the model is well-specified"
1242
+ mmd_observed : float
1243
+ The observed MMD value
1244
+ alpha_level : float
1245
+ The rejection probability (type I error)
1246
+ null_color : str or tuple
1247
+ The color of the H0 sampling distribution
1248
+ observed_color : str or tuple
1249
+ The color of the observed MMD
1250
+ alpha_color : str or tuple
1251
+ The color of the rejection area
1238
1252
truncate_vlines_at_kde: bool
1239
1253
true: cut off the vlines at the kde
1240
1254
false: continue kde lines across the plot
1241
- xmin: float
1242
- lower x axis limit
1243
- xmax: float
1244
- upper x axis limit
1245
- bw_factor: float, default: 1.5
1255
+ xmin : float
1256
+ The lower x- axis limit
1257
+ xmax : float
1258
+ The upper x- axis limit
1259
+ bw_factor : float, optional , default: 1.5
1246
1260
bandwidth (aka. smoothing parameter) of the kernel density estimate
1247
1261
1248
1262
Returns
1249
1263
-------
1250
1264
f : plt.Figure - the figure instance for optional saving
1251
-
1252
1265
"""
1253
1266
1254
1267
def draw_vline_to_kde (x , kde_object , color , label = None , ** kwargs ):
0 commit comments