Skip to content

Commit 0709a3c

Browse files
committed
Docfix and raise ConfigurationError if summary_loss_fun is provided without a summary_net
1 parent b3c562c commit 0709a3c

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

bayesflow/amortizers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,10 @@ def _determine_latent_dist(self, latent_dist):
427427
def _determine_summary_loss(self, loss_fun):
428428
"""Determines which summary loss to use if default `None` argument provided, otherwise return identity."""
429429

430+
# Throw, if summary loss without a summary network provided
431+
if loss_fun is not None and self.summary_net is None:
432+
raise ConfigurationError('You need to provide a summary_net if you want to use a summary_loss_fun.')
433+
430434
# If callable, return provided loss
431435
if loss_fun is None or callable(loss_fun):
432436
return loss_fun
@@ -566,12 +570,12 @@ def sample(self, input_dict, n_samples, to_numpy=True, **kwargs):
566570
return lik_samples
567571

568572
def sample_loop(self, input_list, n_samples, to_numpy=True, **kwargs):
569-
"""Generates random draws from the surrogate network given a list of dicts with conditonal variables.
573+
"""Generates random draws from the surrogate network given a list of dicts with conditional variables.
570574
Useful when GPU memory is limited or data sets have a different (non-Tensor) structure.
571575
572576
Parameters
573577
----------
574-
input_list : list of dictionaries, each dictionary having the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
578+
input_list : list of dictionaries, each dictionary having the following mandatory keys (default):
575579
``conditions`` - the conditioning variables that the directly passed to the surrogate network
576580
n_samples : int
577581
The number of posterior draws (samples) to obtain from the approximate posterior

0 commit comments

Comments
 (0)