Skip to content

Commit 26de450

Browse files
authored
Merge pull request #92 from elseml/Development
Fix train_from_presimulation for model comparison
2 parents 1d7cb0f + d32ef29 commit 26de450

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

bayesflow/trainers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
import logging
2222
import os
2323
from pickle import load as pickle_load
24-
import tensorflow as tf
2524

2625
import numpy as np
26+
import tensorflow as tf
2727
from tqdm.autonotebook import tqdm
2828

2929
from bayesflow.amortizers import (
@@ -737,7 +737,10 @@ def train_from_presimulation(
737737
input_dict = self.configurator(epoch_data[index])
738738

739739
# Like the number of iterations, the batch size is inferred from presimulated dictionary or list
740-
batch_size = epoch_data[index][DEFAULT_KEYS["sim_data"]].shape[0]
740+
if isinstance(self.amortizer, AmortizedModelComparison):
741+
batch_size = input_dict[DEFAULT_KEYS["summary_conditions"]].shape[0]
742+
else:
743+
batch_size = epoch_data[index][DEFAULT_KEYS["sim_data"]].shape[0]
741744
loss = self._train_step(batch_size, _backprop_step, input_dict, **kwargs)
742745

743746
# Store returned loss

0 commit comments

Comments
 (0)