
Commit c507fca

Merge pull request #99 from stefanradev93/Development
Development
2 parents f46374d + 0709a3c commit c507fca

8 files changed: +395 −685 lines

bayesflow/amortizers.py

Lines changed: 6 additions & 2 deletions
@@ -427,6 +427,10 @@ def _determine_latent_dist(self, latent_dist):
     def _determine_summary_loss(self, loss_fun):
         """Determines which summary loss to use if default `None` argument provided, otherwise return identity."""
 
+        # Throw, if summary loss without a summary network provided
+        if loss_fun is not None and self.summary_net is None:
+            raise ConfigurationError('You need to provide a summary_net if you want to use a summary_loss_fun.')
+
         # If callable, return provided loss
         if loss_fun is None or callable(loss_fun):
             return loss_fun
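
To see what the new guard buys: requesting a summary loss without a summary network now fails loudly at construction time rather than mid-training. Below is a minimal sketch of that behavior, not BayesFlow's actual amortizer API; the class is stripped down to the two attributes the check touches, and `ConfigurationError` is a stand-in for the library's exception.

class ConfigurationError(Exception):
    """Stand-in for BayesFlow's ConfigurationError."""

class AmortizerSketch:
    def __init__(self, summary_net=None, summary_loss_fun=None):
        self.summary_net = summary_net
        self.summary_loss = self._determine_summary_loss(summary_loss_fun)

    def _determine_summary_loss(self, loss_fun):
        # Throw, if summary loss without a summary network provided
        if loss_fun is not None and self.summary_net is None:
            raise ConfigurationError('You need to provide a summary_net if you want to use a summary_loss_fun.')
        # If callable (or the default None), return as provided
        if loss_fun is None or callable(loss_fun):
            return loss_fun

# AmortizerSketch(summary_loss_fun=lambda x: 0.0)  # raises ConfigurationError
# AmortizerSketch()                                # fine: no loss, no network
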
@@ -566,12 +570,12 @@ def sample(self, input_dict, n_samples, to_numpy=True, **kwargs):
         return lik_samples
 
     def sample_loop(self, input_list, n_samples, to_numpy=True, **kwargs):
-        """Generates random draws from the surrogate network given a list of dicts with conditonal variables.
+        """Generates random draws from the surrogate network given a list of dicts with conditional variables.
         Useful when GPU memory is limited or data sets have a different (non-Tensor) structure.
 
         Parameters
         ----------
-        input_list : list of dictionaries, each dictionary having the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
+        input_list : list of dictionaries, each dictionary having the following mandatory keys (default):
            ``conditions`` - the conditioning variables that the directly passed to the surrogate network
         n_samples : int
             The number of posterior draws (samples) to obtain from the approximate posterior
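
For orientation, a hypothetical `input_list` matching this docstring, assuming the default ``conditions`` key is unchanged; the shapes and the `amortizer` handle are illustrative, not taken from the commit:

import numpy as np

# Two data sets of different sizes -- exactly the situation sample_loop targets
input_list = [
    {"conditions": np.random.normal(size=(100, 4)).astype(np.float32)},
    {"conditions": np.random.normal(size=(250, 4)).astype(np.float32)},
]
# draws = amortizer.sample_loop(input_list, n_samples=500)  # one draw array per dict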

bayesflow/diagnostics.py

Lines changed: 53 additions & 40 deletions
@@ -62,7 +62,7 @@ def plot_recovery(
     https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html
 
     Important: Posterior aggregates play no special role in Bayesian inference and should only
-    be used heuristically. For instanec, in the case of multi-modal posteriors, common point
+    be used heuristically. For instance, in the case of multi-modal posteriors, common point
     estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) mean nothing.
 
     Parameters
@@ -71,7 +71,7 @@ def plot_recovery(
         The posterior draws obtained from n_data_sets
     prior_samples : np.ndarray of shape (n_data_sets, n_params)
         The prior draws (true parameters) obtained for generating the n_data_sets
-    point_agg : callable, optional, default: np.median
+    point_agg : callable, optional, default: ``np.median``
         The function to apply to the posterior draws to get a point estimate for each marginal.
         The default computes the marginal median for each marginal posterior as a robust
         point estimate.
@@ -89,13 +89,13 @@ def plot_recovery(
     metric_fontsize : int, optional, default: 16
         The font size of the goodness-of-fit metric (if provided)
     tick_fontsize : int, optional, default: 12
-        The font size of the axis ticklabels
+        The font size of the axis tick labels
     add_corr : bool, optional, default: True
         A flag for adding correlation between true and estimates to the plot
     add_r2 : bool, optional, default: True
         A flag for adding R^2 between true and estimates to the plot
     color : str, optional, default: '#8f2727'
-        The color for the true vs. estimated scatter points and errobars
+        The color for the true vs. estimated scatter points and error bars
 
     Returns
     -------
@@ -144,7 +144,7 @@ def plot_recovery(
         if i >= n_params:
             break
 
-        # Add scatter and errorbars
+        # Add scatter and error bars
         if uncertainty_agg is not None:
             _ = ax.errorbar(prior_samples[:, i], est[:, i], yerr=u[:, i], fmt="o", alpha=0.5, color=color)
         else:
@@ -242,7 +242,7 @@ def plot_z_score_contraction(
 
         post_contraction = 1 - (posterior_variance / prior_variance)
 
-    In other words, the posterior is a proxy for the reduction in ucnertainty gained by
+    In other words, the posterior is a proxy for the reduction in uncertainty gained by
     replacing the prior with the posterior. The ideal posterior contraction tends to 1.
     Contraction near zero indicates that the posterior variance is almost identical to
     the prior variance for the particular marginal parameter distribution.
@@ -253,7 +253,7 @@ def plot_z_score_contraction(
     Toward a principled Bayesian workflow in cognitive science.
     Psychological methods, 26(1), 103.
 
-    Also available at https://arxiv.org/abs/1904.12765
+    Paper also available at https://arxiv.org/abs/1904.12765
 
     Parameters
     ----------
@@ -272,7 +272,7 @@ def plot_z_score_contraction(
     tick_fontsize : int, optional, default: 12
         The font size of the axis ticklabels
     color : str, optional, default: '#8f2727'
-        The color for the true vs. estimated scatter points and errobars
+        The color for the true vs. estimated scatter points and error bars
 
     Returns
     -------
@@ -887,21 +887,21 @@ def plot_losses(
                 lw=lw_val,
                 label="Validation",
             )
-            # Schmuck
+        # Schmuck
         ax.set_xlabel("Training step #", fontsize=label_fontsize)
         ax.set_ylabel("Loss value", fontsize=label_fontsize)
         sns.despine(ax=ax)
         ax.grid(alpha=grid_alpha)
         ax.set_title(train_losses.columns[i], fontsize=title_fontsize)
         # Only add legend if there is a validation curve
-        if val_losses is not None:
+        if val_losses is not None or moving_average:
             ax.legend(fontsize=legend_fontsize)
     f.tight_layout()
     return f
 
 
 def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f2727", **kwargs):
-    """Creates pairplots for a given joint prior.
+    """Creates pair-plots for a given joint prior.
 
     Parameters
     ----------
@@ -913,7 +913,7 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
         The number of random draws from the joint prior
     height : float, optional, default: 2.5
         The height of the pair plot
-    color : str, optional, defailt : '#8f2727'
+    color : str, optional, default : '#8f2727'
         The color of the plot
     **kwargs : dict, optional
         Additional keyword arguments passed to the sns.PairGrid constructor
@@ -943,14 +943,16 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
     # Generate plots
     g = sns.PairGrid(data_to_plot, height=height, **kwargs)
     g.map_diag(sns.histplot, fill=True, color=color, alpha=0.9, kde=True)
-    # Kernel density estimation (KDE) may not always be possible (e.g. with parameters whose correlation is close to 1 or -1).
+
+    # Kernel density estimation (KDE) may not always be possible
+    # (e.g. with parameters whose correlation is close to 1 or -1).
     # In this scenario, a scatter-plot is generated instead.
     try:
         g.map_lower(sns.kdeplot, fill=True, color=color, alpha=0.9)
     except Exception as e:
-        logging.warn("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
-        g.map_lower(plt.scatter, alpha=0.6, s=40, edgecolor="k", color=color)
-    g.map_upper(plt.scatter, alpha=0.6, s=40, edgecolor="k", color=color)
+        logging.warning("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
+        g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color)
+    g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color)
 
     # Add grids
     for i in range(dim):
@@ -961,8 +963,8 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
 
 
 def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
-    """Creates pairplots for the latent space learned by the inference network. Enables
-    visual inspection of the the latent space and whether its structrue corresponds to the
+    """Creates pair plots for the latent space learned by the inference network. Enables
+    visual inspection of the latent space and whether its structure corresponds to the
     one enforced by the optimization criterion.
 
     Parameters
@@ -971,7 +973,7 @@ def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
         The latent samples computed through a forward pass of the inference network.
     height : float, optional, default: 2.5
         The height of the pair plot.
-    color : str, optional, defailt : '#8f2727'
+    color : str, optional, default : '#8f2727'
         The color of the plot
     **kwargs : dict, optional
         Additional keyword arguments passed to the sns.PairGrid constructor
@@ -996,7 +998,7 @@ def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
     g = sns.PairGrid(data_to_plot, height=height, **kwargs)
     g.map_diag(sns.histplot, fill=True, color=color, alpha=0.9, kde=True)
     g.map_lower(sns.kdeplot, fill=True, color=color, alpha=0.9)
-    g.map_upper(plt.scatter, alpha=0.6, s=40, edgecolor="k", color=color)
+    g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color)
 
     # Add grids
     for i in range(z_dim):
@@ -1060,6 +1062,8 @@ def plot_calibration_curves(
     # Determine n_subplots dynamically
     n_row = int(np.ceil(num_models / 6))
     n_col = int(np.ceil(num_models / n_row))
+
+    # Compute calibration
     cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)
 
     # Initialize figure
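
As a quick check of the layout arithmetic above (the model count is an arbitrary example, not from the commit):

import numpy as np

num_models = 8
n_row = int(np.ceil(num_models / 6))      # -> 2 rows (at most six plots per row)
n_col = int(np.ceil(num_models / n_row))  # -> 4 columns
print(n_row, n_col)  # 2 4
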
@@ -1094,8 +1098,6 @@ def plot_calibration_curves(
         ax[j].spines["top"].set_visible(False)
         ax[j].set_xlim([0 - epsilon, 1 + epsilon])
         ax[j].set_ylim([0 - epsilon, 1 + epsilon])
-        ax[j].set_xlabel("Predicted probability", fontsize=label_fontsize)
-        ax[j].set_ylabel("True probability", fontsize=label_fontsize)
         ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
         ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
         ax[j].grid(alpha=0.5)
@@ -1111,6 +1113,18 @@ def plot_calibration_curves(
             size=legend_fontsize,
         )
 
+    # Only add x-labels to the bottom row
+    bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :]
+    for _ax in bottom_row:
+        _ax.set_xlabel("Predicted probability", fontsize=label_fontsize)
+
+    # Only add y-labels to left-most row
+    if n_row == 1:  # if there is only one row, the ax array is 1D
+        ax[0].set_ylabel("True probability", fontsize=label_fontsize)
+    else:  # if there is more than one row, the ax array is 2D
+        for _ax in axarr[:, 0]:
+            _ax.set_ylabel("True probability", fontsize=label_fontsize)
+
     fig.tight_layout()
     return fig

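A side note on the indexing in this hunk: `plt.subplots(n_row, n_col)` returns a 1D axes array when either dimension is 1 and a 2D array otherwise, which is what the `bottom_row` conditional and the `n_row == 1` branch account for. A minimal standalone sketch mirroring that logic (grid size and label text chosen arbitrarily here):

import matplotlib.pyplot as plt

n_row, n_col = 2, 3  # illustrative grid; axarr is 2D in this case
fig, axarr = plt.subplots(n_row, n_col)

# x-labels only on the bottom row, mirroring the diff above
bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :]
for _ax in bottom_row:
    _ax.set_xlabel("Predicted probability")

# y-labels only on the left-most column
if n_row == 1:  # 1D axes array
    axarr[0].set_ylabel("True probability")
else:  # 2D axes array
    for _ax in axarr[:, 0]:
        _ax.set_ylabel("True probability")

fig.tight_layout()
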
@@ -1223,32 +1237,31 @@ def plot_mmd_hypothesis_test(
 
     Parameters
     ----------
-    mmd_null: np.ndarray
-        samples from the MMD sampling distribution under the null hypothesis "the model is well-specified"
-    mmd_observed: float
-        observed MMD value
-    alpha_level: float
-        rejection probability (type I error)
-    null_color: color
-        color for the H0 sampling distribution
-    observed_color: color
-        color for the observed MMD
-    alpha_color: color
-        color for the rejection area
+    mmd_null : np.ndarray
+        The samples from the MMD sampling distribution under the null hypothesis "the model is well-specified"
+    mmd_observed : float
+        The observed MMD value
+    alpha_level : float
+        The rejection probability (type I error)
+    null_color : str or tuple
+        The color of the H0 sampling distribution
+    observed_color : str or tuple
+        The color of the observed MMD
+    alpha_color : str or tuple
+        The color of the rejection area
     truncate_vlines_at_kde: bool
         true: cut off the vlines at the kde
         false: continue kde lines across the plot
-    xmin: float
-        lower x axis limit
-    xmax: float
-        upper x axis limit
-    bw_factor: float, default: 1.5
+    xmin : float
+        The lower x-axis limit
+    xmax : float
+        The upper x-axis limit
+    bw_factor : float, optional, default: 1.5
         bandwidth (aka. smoothing parameter) of the kernel density estimate
 
     Returns
     -------
     f : plt.Figure - the figure instance for optional saving
-
     """
 
     def draw_vline_to_kde(x, kde_object, color, label=None, **kwargs):

bayesflow/losses.py

Lines changed: 2 additions & 3 deletions
@@ -19,7 +19,6 @@
 # SOFTWARE.
 
 import tensorflow as tf
-import tensorflow_probability as tfp
 
 from bayesflow.computational_utilities import maximum_mean_discrepancy

@@ -62,7 +61,7 @@ def kl_latent_space_student(v, z, log_det_J):
     z : tf.Tensor of shape (batch_size, ...)
         The (latent transformed) target variables
     log_det_J : tf.Tensor of shape (batch_size, ...)
-        The logartihm of the Jacobian determinant of the transformation.
+        The logarithm of the Jacobian determinant of the transformation.
 
     Returns
     -------
@@ -131,7 +130,7 @@ def mmd_summary_space(summary_outputs, z_dist=tf.random.normal, kernel="gaussian
         The kernel function to use for MMD computation.
     """
 
-    z_samples = z_dist(summary_outputs.shape)
+    z_samples = z_dist(tf.shape(summary_outputs))
     mmd_loss = maximum_mean_discrepancy(summary_outputs, z_samples, kernel)
     return mmd_loss

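The `tf.shape` change is worth a note: `summary_outputs.shape` is the static shape, which can contain `None` for the batch dimension when the loss runs inside a `tf.function`, whereas `tf.shape(...)` returns the concrete runtime shape as an int32 tensor. The commit gives no rationale, so this reading is an assumption; a small sketch of the difference under that assumption:

import tensorflow as tf

# With an unknown batch dimension, the static shape is (None, 8) and would
# break tf.random.normal; tf.shape yields the runtime shape as an int32 tensor.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 8], dtype=tf.float32)])
def reference_samples(summary_outputs):
    return tf.random.normal(tf.shape(summary_outputs))

print(reference_samples(tf.zeros((16, 8))).shape)  # (16, 8)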