File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change 21
21
import logging
22
22
import os
23
23
from pickle import load as pickle_load
24
- import tensorflow as tf
25
24
26
25
import numpy as np
26
+ import tensorflow as tf
27
27
from tqdm .autonotebook import tqdm
28
28
29
29
from bayesflow .amortizers import (
@@ -737,7 +737,10 @@ def train_from_presimulation(
737
737
input_dict = self .configurator (epoch_data [index ])
738
738
739
739
# Like the number of iterations, the batch size is inferred from presimulated dictionary or list
740
- batch_size = epoch_data [index ][DEFAULT_KEYS ["sim_data" ]].shape [0 ]
740
+ if isinstance (self .amortizer , AmortizedModelComparison ):
741
+ batch_size = input_dict [DEFAULT_KEYS ["summary_conditions" ]].shape [0 ]
742
+ else :
743
+ batch_size = epoch_data [index ][DEFAULT_KEYS ["sim_data" ]].shape [0 ]
741
744
loss = self ._train_step (batch_size , _backprop_step , input_dict , ** kwargs )
742
745
743
746
# Store returned loss
You can’t perform that action at this time.
0 commit comments