diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py
index e4d0f1ec..eb94a1c2 100644
--- a/init2winit/checkpoint.py
+++ b/init2winit/checkpoint.py
@@ -90,9 +90,11 @@ def maybe_restore_checkpoint(
       global_step=uninitialized_global_step,
       preemption_count=0,
       sum_train_cost=0.0)
+  logging.info('Loading latest checkpoint from train_dir: %s', train_dir)
   latest_ckpt = load_latest_checkpoint(train_dir,
                                        target=unreplicated_checkpoint_state,
                                        orbax_checkpointer=orbax_checkpointer)
+  logging.info('Loading checkpoint from train_dir %s complete.', train_dir)
   # Load_latest_checkpoint() will return unreplicated_checkpoint_state if
   # train_dir does not exist or if it exists and contains no checkpoints.
   # Note that we could likely change the below line to:
diff --git a/init2winit/model_lib/base_model.py b/init2winit/model_lib/base_model.py
index eecb2bcf..be089c62 100644
--- a/init2winit/model_lib/base_model.py
+++ b/init2winit/model_lib/base_model.py
@@ -17,7 +17,9 @@
 
 import copy
 import functools
+import time
 
+from absl import logging
 from flax import linen as nn
 from init2winit import utils
 from init2winit.model_lib import losses
@@ -197,17 +199,28 @@ def initialize(self, initializer, hps, rng, metrics_logger):
     # mess up Pythonic boolean statements like `not train` inside the model
     # construction.
     # We initialize model params on host to avoid memory issues.
+
+    start_time = time.time()
     model_init_fn = jax.jit(
         functools.partial(self.flax_module.init, train=False), backend='cpu')
+
     init_dict = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
                               *fake_input_batch)
+
+    logging.info(
+        'Flax module init call took %f seconds.',
+        time.time() - start_time,
+    )
+    start_time = time.time()
     # Trainable model parameters.
     params = init_dict['params']
     batch_stats = init_dict.get('batch_stats', {})
 
     if hps.get('layer_rescale_factors'):
       params = model_utils.rescale_layers(params, hps.layer_rescale_factors)
+    logging.info('Layers rescaled in %f seconds.', time.time() - start_time)
 
+    start_time = time.time()
     # We don't pass batch_stats to the initializer, the initializer will just
     # run batch_norm in train mode and does not need to maintain the
     # batch_stats.
@@ -216,12 +229,25 @@ def initialize(self, initializer, hps, rng, metrics_logger):
     # hyper_param?
     # TODO(gilmer): instead of passing in weighted_xent, pass in the model and
     # get the loss from that.
-    params = initializer(self.loss_fn, self.flax_module, params, hps,
-                         hps.input_shape, hps.output_shape, init_rng,
-                         metrics_logger)
+    params = initializer(
+        self.loss_fn,
+        self.flax_module,
+        params,
+        hps,
+        hps.input_shape,
+        hps.output_shape,
+        init_rng,
+        metrics_logger,
+    )
+    logging.info('Initializer run in %f seconds.', time.time() - start_time)
 
+    start_time = time.time()
     self._param_shapes = model_utils.param_shapes(params)
     self._param_types = model_utils.param_types(self._param_shapes)
+    logging.info(
+        'Param shapes and types computed in %f seconds.',
+        time.time() - start_time,
+    )
 
     return params, batch_stats
 
diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py
index ed9ddd67..0c219cf8 100644
--- a/init2winit/trainer_lib/base_trainer.py
+++ b/init2winit/trainer_lib/base_trainer.py
@@ -217,8 +217,8 @@ def wait_until_orbax_checkpointer_finished(self):
   def log_model_info(self, unreplicated_params):
     if jax.process_index() == 0:
       utils.log_pytree_shape_and_statistics(unreplicated_params)
-      logging.info('train_size: %d,', self._hps.train_size)
       utils.tabulate_model(self._model, self._hps)
+      logging.info('train_size: %d,', self._hps.train_size)
 
   def maybe_restore_from_checkpoint(self,
                                     unreplicated_optimizer_state,
@@ -486,19 +486,25 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng):
       data_rng: (jax.random.PRNGKey) Rng seed used in data shuffling.
      callback_rng: (jax.random.PRNGKey) Rng seed used in callback functions.
     """
+    start_time = time.time()
     self.training_algorithm = self._training_algorithm_class(
         self._hps, self._model, self._num_train_steps
     )
-
+    logging.info(
+        'Training algorithm set up in %f seconds', time.time() - start_time
+    )
+    start_time = time.time()
     unreplicated_params, unreplicated_batch_stats = self._model.initialize(
         self._initializer,
         self._hps,
         init_rng,
         self._init_logger,
     )
+    logging.info('Model initialized in %f seconds', time.time() - start_time)
 
     self.log_model_info(unreplicated_params)
 
+    start_time = time.time()
     unreplicated_optimizer_state = self.training_algorithm.init_optimizer_state(
         self._model,
         unreplicated_params,
@@ -506,12 +512,16 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng):
         self._hps,
         init_rng,
     )
+    logging.info(
+        'Optimizer state initialized in %f seconds', time.time() - start_time
+    )
 
     unreplicated_metrics_state = None
     # TODO(kasimbeg): move this to initialization.
     self._metrics_update_fn = None
     self._metrics_summary_fn = None
 
+    start_time = time.time()
     if self._training_metrics_config is not None:
       (metrics_init_fn, self._metrics_update_fn,
        self._metrics_summary_fn) = make_training_metrics(
@@ -519,7 +529,10 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng):
       unreplicated_metrics_state = metrics_init_fn(
           unreplicated_params, unreplicated_batch_stats
       )
-
+    logging.info(
+        'Metrics initialized in %f seconds', time.time() - start_time
+    )
+    start_time = time.time()
     (
         unreplicated_optimizer_state,
         unreplicated_params,
@@ -532,6 +545,10 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng):
         unreplicated_metrics_state,
     )
 
+    logging.info(
+        'Checkpoint restored in %f seconds', time.time() - start_time
+    )
+    start_time = time.time()
     (
         self._params,
         self._params_sharding,
@@ -547,9 +564,13 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng):
         unreplicated_batch_stats,
         unreplicated_metrics_state,
     )
+    logging.info(
+        'Training state sharded in %f seconds', time.time() - start_time
+    )
 
     self._dataset = self.setup_data_loader(data_rng, self._global_step)
     self._eval_callbacks = self._setup_eval_callbacks(callback_rng)
+    logging.info('Training state setup complete')
 
   def train(self):
     """All training logic.
@@ -578,9 +599,11 @@ def train(self):
 
     self.setup_and_maybe_restore(init_rng, data_rng, callback_rng)
 
-    if jax.process_index() == 0:
-      trainer_utils.log_message(
-          'Starting training!', self._logging_pool, self._xm_work_unit)
+    trainer_utils.log_message(
+        'Setup and maybe restore completed!',
+        self._logging_pool,
+        self._xm_work_unit,
+    )
 
     train_iter = itertools.islice(
         self._dataset.train_iterator_fn(),
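
Note (not part of the patch): every hunk above applies the same instrumentation pattern — capture time.time() before a setup phase, run the phase, and report the elapsed wall-clock time through absl logging. Below is a minimal sketch of that pattern factored into a reusable context manager; the log_timing helper and the phase name in the usage comment are hypothetical illustrations, not part of init2winit.

import contextlib
import time

from absl import logging


@contextlib.contextmanager
def log_timing(description):
  """Logs the wall-clock duration of the enclosed block.

  Hypothetical helper; equivalent to the repeated
  start_time = time.time() / logging.info(...) pairs in the patch.
  """
  start_time = time.time()
  try:
    yield
  finally:
    # Report elapsed time even if the phase raises, so partial setup
    # still leaves a timing record in the logs.
    logging.info('%s took %f seconds.', description, time.time() - start_time)


# Usage mirroring one of the instrumented setup steps, e.g.:
# with log_timing('Model initialization'):
#   params, batch_stats = model.initialize(...)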