
Commit 18214da

Merge pull request #75 from elseml/Development
Set memory=False as default to prevent conflict with checkpoint
2 parents: 749c0a6 + fc7c034


1 file changed (9 additions, 12 deletions)

bayesflow/trainers.py

Lines changed: 9 additions & 12 deletions
@@ -35,10 +35,11 @@
     AmortizedPosterior,
     AmortizedPosteriorLikelihood,
 )
+from bayesflow.computational_utilities import maximum_mean_discrepancy
 from bayesflow.configuration import *
 from bayesflow.default_settings import DEFAULT_KEYS, OPTIMIZER_DEFAULTS
 from bayesflow.diagnostics import plot_latent_space_2d, plot_sbc_histograms
-from bayesflow.exceptions import SimulationError, ArgumentError
+from bayesflow.exceptions import ArgumentError, SimulationError
 from bayesflow.helper_classes import (
     EarlyStopper,
     LossHistory,
@@ -49,7 +50,6 @@
 )
 from bayesflow.helper_functions import backprop_step, extract_current_lr, format_loss_string, loss_to_string
 from bayesflow.simulation import GenerativeModel, MultiGenerativeModel
-from bayesflow.computational_utilities import maximum_mean_discrepancy


 class Trainer:
@@ -116,7 +116,7 @@ def __init__(
         max_to_keep=3,
         default_lr=0.0005,
         skip_checks=False,
-        memory=True,
+        memory=False,
         **kwargs,
     ):
         """Creates a trainer which will use a generative model (or data simulated from it) to optimize
@@ -139,7 +139,7 @@ def __init__(
             The default learning rate to use for default optimizers.
         skip_checks : bool, optional, default: False
             If True, do not perform consistency checks, i.e., simulator runs and passed through nets
-        memory : bool or bayesflow.SimulationMemory, optional, default: True
+        memory : bool or bayesflow.SimulationMemory, optional, default: False
             If ``True``, store a pre-defined amount of simulations for later use (validation, etc.).
             If ``SimulationMemory`` instance provided, stores a reference to the instance.
             Otherwise the corresponding attribute will be set to None.
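
With this change, simulation memory stays off unless explicitly requested. A minimal opt-in sketch follows; the constructor argument names amortizer and generative_model, and the placeholder objects my_amortizer and my_generative_model, are assumptions for illustration and are not shown in this diff:

from bayesflow.trainers import Trainer

# Sketch only: my_amortizer and my_generative_model are hypothetical placeholders,
# not objects defined in this commit.
trainer = Trainer(
    amortizer=my_amortizer,
    generative_model=my_generative_model,
    memory=True,  # explicitly re-enable simulation memory now that the default is False
)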
@@ -1010,12 +1010,9 @@ def train_rounds(
         self.optimizer = None
         return self.loss_history.get_plottable()

-    def mmd_hypothesis_test(self,
-                            observed_data,
-                            reference_data=None,
-                            num_reference_simulations=1000,
-                            num_null_samples=100,
-                            bootstrap=False):
+    def mmd_hypothesis_test(
+        self, observed_data, reference_data=None, num_reference_simulations=1000, num_null_samples=100, bootstrap=False
+    ):
         """

         Parameters
@@ -1048,12 +1045,12 @@ def mmd_hypothesis_test(self,

             reference_data = self.configurator(self.generative_model(num_reference_simulations))

-        if type(reference_data) == dict and 'summary_conditions' in reference_data.keys():
+        if type(reference_data) == dict and "summary_conditions" in reference_data.keys():
             reference_summary = self.amortizer.summary_net(reference_data["summary_conditions"])
         else:
             reference_summary = self.amortizer.summary_net(reference_data)

-        if type(observed_data) == dict and 'summary_conditions' in observed_data.keys():
+        if type(observed_data) == dict and "summary_conditions" in observed_data.keys():
             observed_summary = self.amortizer.summary_net(observed_data["summary_conditions"])
         else:
             observed_summary = self.amortizer.summary_net(observed_data)
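
For reference, a call sketch against the refactored signature above; the trainer and observed_data objects and the handling of the return value are placeholders, not taken from this commit:

# Sketch only: compare observed data against simulated reference data.
# Passing reference_data=None makes the trainer simulate
# num_reference_simulations datasets through its generative model and
# configurator, as in the code shown above.
result = trainer.mmd_hypothesis_test(
    observed_data,
    reference_data=None,
    num_reference_simulations=1000,
    num_null_samples=100,
    bootstrap=False,
)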

0 commit comments
