Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions init2winit/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ def maybe_restore_checkpoint(
global_step=uninitialized_global_step,
preemption_count=0,
sum_train_cost=0.0)
logging.info('Loading latest checkpoint from train_dir: %s', train_dir)
latest_ckpt = load_latest_checkpoint(train_dir,
target=unreplicated_checkpoint_state,
orbax_checkpointer=orbax_checkpointer)
logging.info('Loading checkpoint from train_dir %s complete.', train_dir)
# Load_latest_checkpoint() will return unreplicated_checkpoint_state if
# train_dir does not exist or if it exists and contains no checkpoints.
# Note that we could likely change the below line to:
Expand Down
32 changes: 29 additions & 3 deletions init2winit/model_lib/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

import copy
import functools
import time

from absl import logging
from flax import linen as nn
from init2winit import utils
from init2winit.model_lib import losses
Expand Down Expand Up @@ -197,17 +199,28 @@ def initialize(self, initializer, hps, rng, metrics_logger):
# mess up Pythonic boolean statements like `not train` inside the model
# construction.
# We initialize model params on host to avoid memory issues.

start_time = time.time()
model_init_fn = jax.jit(
functools.partial(self.flax_module.init, train=False),
backend='cpu')

init_dict = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
*fake_input_batch)

logging.info(
'Flax module init call took %f seconds.',
time.time() - start_time,
)
start_time = time.time()
# Trainable model parameters.
params = init_dict['params']
batch_stats = init_dict.get('batch_stats', {})

if hps.get('layer_rescale_factors'):
params = model_utils.rescale_layers(params, hps.layer_rescale_factors)
logging.info('Layers rescaled in %f seconds.', time.time() - start_time)
start_time = time.time()
# We don't pass batch_stats to the initializer, the initializer will just
# run batch_norm in train mode and does not need to maintain the
# batch_stats.
Expand All @@ -216,12 +229,25 @@ def initialize(self, initializer, hps, rng, metrics_logger):
# hyper_param?
# TODO(gilmer): instead of passing in weighted_xent, pass in the model and
# get the loss from that.
params = initializer(self.loss_fn, self.flax_module, params, hps,
hps.input_shape, hps.output_shape, init_rng,
metrics_logger)
params = initializer(
self.loss_fn,
self.flax_module,
params,
hps,
hps.input_shape,
hps.output_shape,
init_rng,
metrics_logger,
)
logging.info('Initializer run in %f seconds.', time.time() - start_time)
start_time = time.time()

self._param_shapes = model_utils.param_shapes(params)
self._param_types = model_utils.param_types(self._param_shapes)
logging.info(
'Param shapes and types computed in %f seconds.',
time.time() - start_time,
)

return params, batch_stats

Expand Down
35 changes: 29 additions & 6 deletions init2winit/trainer_lib/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def wait_until_orbax_checkpointer_finished(self):
def log_model_info(self, unreplicated_params):
if jax.process_index() == 0:
utils.log_pytree_shape_and_statistics(unreplicated_params)
logging.info('train_size: %d,', self._hps.train_size)
utils.tabulate_model(self._model, self._hps)
logging.info('train_size: %d,', self._hps.train_size)

def maybe_restore_from_checkpoint(self,
unreplicated_optimizer_state,
Expand Down Expand Up @@ -486,40 +486,53 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng):
data_rng: (jax.random.PRNGKey) Rng seed used in data shuffling.
callback_rng: (jax.random.PRNGKey) Rng seed used in callback functions.
"""
start_time = time.time()
self.training_algorithm = self._training_algorithm_class(
self._hps, self._model, self._num_train_steps
)

logging.info(
'Training algorithm set up in %f seconds', time.time() - start_time
)
start_time = time.time()
unreplicated_params, unreplicated_batch_stats = self._model.initialize(
self._initializer,
self._hps,
init_rng,
self._init_logger,
)
logging.info('Model initialized in %f seconds', time.time() - start_time)

self.log_model_info(unreplicated_params)

start_time = time.time()
unreplicated_optimizer_state = self.training_algorithm.init_optimizer_state(
self._model,
unreplicated_params,
unreplicated_batch_stats,
self._hps,
init_rng,
)
logging.info(
'Optimizer state initialized in %f seconds', time.time() - start_time
)

unreplicated_metrics_state = None
# TODO(kasimbeg): move this to initialization.
self._metrics_update_fn = None
self._metrics_summary_fn = None

start_time = time.time()
if self._training_metrics_config is not None:
(metrics_init_fn, self._metrics_update_fn,
self._metrics_summary_fn) = make_training_metrics(
self._num_train_steps, self._hps, **self._training_metrics_config)
unreplicated_metrics_state = metrics_init_fn(
unreplicated_params, unreplicated_batch_stats
)

logging.info(
'Metrics initialized in %f seconds', time.time() - start_time
)
start_time = time.time()
(
unreplicated_optimizer_state,
unreplicated_params,
Expand All @@ -532,6 +545,10 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng):
unreplicated_metrics_state,
)

logging.info(
'Checkpoint restored in %f seconds', time.time() - start_time
)
start_time = time.time()
(
self._params,
self._params_sharding,
Expand All @@ -547,9 +564,13 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng):
unreplicated_batch_stats,
unreplicated_metrics_state,
)
logging.info(
'Training state sharded in %f seconds', time.time() - start_time
)

self._dataset = self.setup_data_loader(data_rng, self._global_step)
self._eval_callbacks = self._setup_eval_callbacks(callback_rng)
logging.info('Training state setup complete')

def train(self):
"""All training logic.
Expand Down Expand Up @@ -578,9 +599,11 @@ def train(self):

self.setup_and_maybe_restore(init_rng, data_rng, callback_rng)

if jax.process_index() == 0:
trainer_utils.log_message(
'Starting training!', self._logging_pool, self._xm_work_unit)
trainer_utils.log_message(
'Setup and maybe restore completed!',
self._logging_pool,
self._xm_work_unit,
)

train_iter = itertools.islice(
self._dataset.train_iterator_fn(),
Expand Down