@@ -35,10 +35,11 @@
     AmortizedPosterior,
     AmortizedPosteriorLikelihood,
 )
+from bayesflow.computational_utilities import maximum_mean_discrepancy
 from bayesflow.configuration import *
 from bayesflow.default_settings import DEFAULT_KEYS, OPTIMIZER_DEFAULTS
 from bayesflow.diagnostics import plot_latent_space_2d, plot_sbc_histograms
-from bayesflow.exceptions import SimulationError, ArgumentError
+from bayesflow.exceptions import ArgumentError, SimulationError
 from bayesflow.helper_classes import (
     EarlyStopper,
     LossHistory,
@@ -49,7 +50,6 @@
 )
 from bayesflow.helper_functions import backprop_step, extract_current_lr, format_loss_string, loss_to_string
 from bayesflow.simulation import GenerativeModel, MultiGenerativeModel
-from bayesflow.computational_utilities import maximum_mean_discrepancy


 class Trainer:
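The relocated import is the kernel-based MMD estimator that the new `mmd_hypothesis_test` method builds on. As a rough sketch of how it could compare two sets of learned summary statistics (the call signature of `maximum_mean_discrepancy` is assumed here; only the import is shown in this diff):

```python
import numpy as np

from bayesflow.computational_utilities import maximum_mean_discrepancy

# Hypothetical summary-network outputs: 100 reference data sets and 20 observed
# data sets, each reduced to 16 learned summary statistics.
reference_summary = np.random.default_rng(1).normal(size=(100, 16)).astype("float32")
observed_summary = np.random.default_rng(2).normal(size=(20, 16)).astype("float32")

# Assumed call pattern: two sample arrays in, a scalar MMD estimate out, with
# larger values indicating a larger mismatch between the two distributions.
mmd = maximum_mean_discrepancy(reference_summary, observed_summary)
print(float(mmd))
```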
@@ -116,7 +116,7 @@ def __init__(
         max_to_keep=3,
         default_lr=0.0005,
         skip_checks=False,
-        memory=True,
+        memory=False,
         **kwargs,
     ):
         """Creates a trainer which will use a generative model (or data simulated from it) to optimize
@@ -139,7 +139,7 @@ def __init__(
            The default learning rate to use for default optimizers.
        skip_checks : bool, optional, default: False
            If True, do not perform consistency checks, i.e., simulator runs and passed through nets
-       memory : bool or bayesflow.SimulationMemory, optional, default: True
+       memory : bool or bayesflow.SimulationMemory, optional, default: False
            If ``True``, store a pre-defined amount of simulations for later use (validation, etc.).
            If ``SimulationMemory`` instance provided, stores a reference to the instance.
            Otherwise the corresponding attribute will be set to None.
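Since the default is now `memory=False`, simulation memory becomes opt-in. A minimal sketch of opting back in when constructing a trainer, assuming the amortizer, generative model, and configurator are set up elsewhere (the placeholder objects below are not part of this diff):

```python
from bayesflow.trainers import Trainer

# Placeholders standing in for a concrete BayesFlow setup; any working
# amortizer/generative-model pair from the documentation would slot in here.
amortizer = ...          # e.g., an AmortizedPosterior with a summary network
generative_model = ...   # e.g., a GenerativeModel wrapping prior and simulator
configurator = ...       # optional; a default is inferred when omitted

trainer = Trainer(
    amortizer=amortizer,
    generative_model=generative_model,
    configurator=configurator,
    memory=True,  # store simulations explicitly, since the default is now False
)
```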
@@ -1010,12 +1010,9 @@ def train_rounds(
         self.optimizer = None
         return self.loss_history.get_plottable()

-    def mmd_hypothesis_test(self,
-                            observed_data,
-                            reference_data=None,
-                            num_reference_simulations=1000,
-                            num_null_samples=100,
-                            bootstrap=False):
+    def mmd_hypothesis_test(
+        self, observed_data, reference_data=None, num_reference_simulations=1000, num_null_samples=100, bootstrap=False
+    ):
         """

         Parameters
@@ -1048,12 +1045,12 @@ def mmd_hypothesis_test(self,

             reference_data = self.configurator(self.generative_model(num_reference_simulations))

-        if type(reference_data) == dict and 'summary_conditions' in reference_data.keys():
+        if type(reference_data) == dict and "summary_conditions" in reference_data.keys():
             reference_summary = self.amortizer.summary_net(reference_data["summary_conditions"])
         else:
             reference_summary = self.amortizer.summary_net(reference_data)

-        if type(observed_data) == dict and 'summary_conditions' in observed_data.keys():
+        if type(observed_data) == dict and "summary_conditions" in observed_data.keys():
             observed_summary = self.amortizer.summary_net(observed_data["summary_conditions"])
         else:
             observed_summary = self.amortizer.summary_net(observed_data)
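Putting the pieces together, the new method embeds both observed and reference data with the summary network and compares the embeddings via MMD. A hedged usage sketch, reusing the `trainer` placeholder from the sketch above; the return value is not shown in this hunk and is assumed here to be a null distribution of MMD samples:

```python
# Configure a small batch of "observed" data sets via the trainer's own
# generative model and configurator, mirroring the calls shown in the hunk above.
observed_data = trainer.configurator(trainer.generative_model(20))

# Compare against 1000 freshly simulated reference data sets; the null
# distribution is approximated with 100 samples, without bootstrapping.
mmd_null = trainer.mmd_hypothesis_test(
    observed_data,
    num_reference_simulations=1000,
    num_null_samples=100,
    bootstrap=False,
)
```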