diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py
index 6ec7bcba..3a09fab2 100644
--- a/init2winit/trainer_lib/base_trainer.py
+++ b/init2winit/trainer_lib/base_trainer.py
@@ -618,7 +618,7 @@ def train(self):
         batch = next(train_iter)
         batch = self.finalize_batch_fn(batch)
         if self._global_step == 0:
-          batch_size_pytree = trainer_utils.get_batch_size(batch)
+          batch_size_pytree = self.get_batch_size(batch)
           if any(
               bsz != self._hps.batch_size
               for bsz in jax.tree.leaves(batch_size_pytree)
@@ -673,10 +673,10 @@ def train(self):
   @abc.abstractmethod
   def update(self, batch, rng, metrics_update_fn, metrics_state, training_cost):
     """Single step of the training loop.
-    
+
     Note this method is responsible for updating the private _global_step
     attribute of the Trainer.
-    
+
     Args:
       batch: the per-device batch of data to process.
       rng: the RNG used for calling the model. `step` and `local_device_index`
@@ -723,3 +723,6 @@ def finalize_batch_fn(self, batch):
 
   def fetch_learning_rate(self, optimizer_state):
     return trainer_utils.fetch_learning_rate(optimizer_state)
+
+  def get_batch_size(self, batch):
+    return trainer_utils.get_batch_size(batch)
diff --git a/init2winit/trainer_lib/trainer_utils.py b/init2winit/trainer_lib/trainer_utils.py
index 4fba6352..f414d827 100644
--- a/init2winit/trainer_lib/trainer_utils.py
+++ b/init2winit/trainer_lib/trainer_utils.py
@@ -192,6 +192,7 @@ def evaluate(
     metrics = metrics.merge(computed_metrics)
 
   metrics = jax.device_get(process_allgather(metrics, tiled=True))
+  metrics = jax.tree_util.tree_map(lambda x: x[0] if x.ndim > 0 else x, metrics)
   # For data splits with no data (e.g. Imagenet no test set) no values
   # will appear for that split.
   if metrics is not None:
diff --git a/init2winit/trainer_lib/training_algorithm.py b/init2winit/trainer_lib/training_algorithm.py
index b8d19d65..e0a8ca7f 100644
--- a/init2winit/trainer_lib/training_algorithm.py
+++ b/init2winit/trainer_lib/training_algorithm.py
@@ -155,7 +155,7 @@ def update_params(
   @abc.abstractmethod
   def init_optimizer_state(
       self,
-      workload=None,
+      model=None,
       params=None,
       model_state=None,
       hyperparameters=None,
@@ -164,7 +164,7 @@
     """Initializes the optimizer state.
 
     Args:
-      workload: The workload being trained.
+      model: The model being trained.
       params: The initial model parameters.
       model_state: The initial state of the model.
       hyperparameters: The hyperparameters for the training.
@@ -273,7 +273,7 @@ def update_params(
 
   def init_optimizer_state(
       self,
-      workload=None,
+      model=None,
       params=None,
       model_state=None,
       hyperparameters=None,
@@ -282,7 +282,7 @@
     """Initializes the optimizer state.
 
     Args:
-      workload: The workload being trained.
+      model: The model being trained.
       params: The initial model parameters.
       model_state: The initial state of the model.
       hyperparameters: The hyperparameters for the training.
@@ -291,7 +291,7 @@
     Returns:
       Optimizer state: Pytree of optimizer state.
     """
-    del workload, model_state, hyperparameters, rng # Unused
+    del model, model_state, hyperparameters, rng # Unused
     stretch_factor = 1
     if self.hps.get('total_accumulated_batch_size') is not None:
       stretch_factor = (
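# Illustrative aside, not part of the diff above: a minimal sketch of what the
# tree_map line added to evaluate() does. The assumption (made up here) is that
# after jax.device_get(process_allgather(metrics, tiled=True)) each array leaf
# holds one identical copy of a metric per host; indexing [0] keeps a single
# copy, while 0-d leaves pass through unchanged. `gathered_metrics` below is
# fabricated stand-in data, not output from the library.
import jax
import numpy as np

# Stand-in for the gathered metrics on a hypothetical 4-host run.
gathered_metrics = {
    'accuracy': np.array([0.81, 0.81, 0.81, 0.81]),
    'loss': np.array([0.42, 0.42, 0.42, 0.42]),
}

deduped = jax.tree_util.tree_map(
    lambda x: x[0] if x.ndim > 0 else x, gathered_metrics)
print(deduped)  # every leaf collapsed to a single value, e.g. accuracy -> 0.81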