diff --git a/hessian/model_debugger.py b/hessian/model_debugger.py
index 943f0b07..9825653d 100644
--- a/hessian/model_debugger.py
+++ b/hessian/model_debugger.py
@@ -18,11 +18,9 @@
 """

 import functools
-import os

 import flax
 import flax.linen as nn
-from init2winit.checkpoint import load_pytree
 from init2winit.model_lib import partition_tree
 from init2winit.utils import array_append
 from init2winit.utils import tree_norm_sql2
@@ -342,11 +340,8 @@ def __init__(self,
     self._stored_metrics = {}

     # In the case of preemption we want to restore prior metrics.
-    if metrics_logger:
-      metrics_file = os.path.join(metrics_logger._pytree_path,
-                                  'training_metrics')
-      if exists(metrics_file):
-        self._stored_metrics = load_pytree(metrics_file)
+    # Note: latest_pytree_checkpoint_step() may legitimately return step 0,
+    # so compare against None rather than relying on truthiness.
+    if (metrics_logger and
+        metrics_logger.latest_pytree_checkpoint_step() is not None):
+      self._stored_metrics = metrics_logger.load_latest_pytree()

   def _grab_statistics(self,
                        step,
@@ -380,9 +375,9 @@ def _maybe_save_metrics(self, step):
     save_dict = self._stored_metrics.copy()
     if self._save_every:
       if step % self._save_every == 0:
-        self._metrics_logger.write_pytree(save_dict)
+        self._metrics_logger.write_pytree(save_dict, step=step)
     else:
-      self._metrics_logger.write_pytree(save_dict)
+      self._metrics_logger.write_pytree(save_dict, step=step)

   @property
   def stored_metrics(self):
diff --git a/hessian/test_model_debugger.py b/hessian/test_model_debugger.py
index 3c0b98fe..47f00204 100644
--- a/hessian/test_model_debugger.py
+++ b/hessian/test_model_debugger.py
@@ -24,7 +24,6 @@
 from absl.testing import absltest
 import flax
 from flax import linen as nn
-from init2winit import checkpoint
 from init2winit import utils
 from init2winit.hessian import model_debugger
 from init2winit.hessian.model_debugger import skip_bwd
@@ -282,11 +281,10 @@ def fake_loss_fn(ps, x, rng, module_flags=None):
     self.assertEqual(og['B_0']['Dense_0']['kernel'], 16.0)
     self.assertEqual(og['C_0']['Dense_0']['kernel'], 36.0)

-  def test_model_debugger_pmap(self):
-    """Test training for two epochs on MNIST with a small model."""
+  def test_model_debugger_restore(self):
+    """Test that debugger metrics are saved and restored across runs."""
     rep_variables = set_up_cnn()
     pytree_path = os.path.join(self.test_dir, 'metrics')
     metrics_logger = utils.MetricLogger(
         pytree_path=pytree_path, events_dir=self.test_dir)
@@ -297,18 +295,18 @@ def grad_fn(params, batch, rng):
       return rep_variables['params']

     debugger = model_debugger.ModelDebugger(
-        use_pmap=True, grad_fn=grad_fn, metrics_logger=metrics_logger)
+        use_pmap=False, grad_fn=grad_fn, metrics_logger=metrics_logger)
     # eval twice to test the concat
     extra_metrics = {'train_loss': 1.0}
     extra_metrics2 = {'train_loss': 1.0}
-    metrics = debugger.full_eval(
+    _ = debugger.full_eval(
         10,
         params=rep_variables['params'],
         grad=rep_variables['params'],
         extra_scalar_metrics=extra_metrics)
     metrics = debugger.full_eval(
-        10,
+        20,
         params=rep_variables['params'],
         grad=None,  # use internal gradient comp
         extra_scalar_metrics=extra_metrics2)
@@ -321,9 +319,8 @@ def grad_fn(params, batch, rng):
       'train_loss',
     ]

-    metrics_file = os.path.join(self.test_dir, 'metrics/training_metrics')
-
-    loaded_metrics = checkpoint.load_checkpoint(metrics_file)['pytree']
+    metrics_logger.wait_until_pytree_checkpoint_finished()
+    loaded_metrics = metrics_logger.load_latest_pytree(None)

     self.assertEqual(set(expected_keys), set(metrics.keys()))
     expected_shape = ()
@@ -341,11 +338,12 @@ def grad_fn(params, batch, rng):

     # Test restore of prior metrics.
     new_debugger = model_debugger.ModelDebugger(
-        use_pmap=True, metrics_logger=metrics_logger)
-    metrics = new_debugger.full_eval(
-        10,
+        use_pmap=False, metrics_logger=metrics_logger)
+    _ = new_debugger.full_eval(
+        30,
         params=rep_variables['params'],
         grad=rep_variables['params'],
         extra_scalar_metrics=extra_metrics2)
+    metrics_logger.wait_until_pytree_checkpoint_finished()
     self.assertEqual(
         new_debugger.stored_metrics['param_norms_sql2']['Conv_0']
         ['kernel'].shape, (3,))
diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py
index eb94a1c2..d1b02303 100644
--- a/init2winit/checkpoint.py
+++ b/init2winit/checkpoint.py
@@ -18,24 +18,23 @@ This is useful for training neural networks with stax, where model parameters
 are nested numpy arrays.
 """
-import os
-import sys
-
 from absl import flags
 from absl import logging
-from flax.training import checkpoints as flax_checkpoints
 import jax
 # pylint: disable=g-importing-member
 from jax.experimental.multihost_utils import process_allgather
+import orbax.checkpoint as ocp

 FLAGS = flags.FLAGS


-def load_pytree(pytree_file, orbax_checkpointer=None):
-  """Loads the checkpointed pytree."""
-  latest = load_latest_checkpoint(pytree_file,
-                                  target=None,
-                                  orbax_checkpointer=orbax_checkpointer)
+def load_pytree(pytree_file, orbax_checkpoint_manager=None):
+  """Loads a checkpointed pytree."""
+  if not orbax_checkpoint_manager:
+    orbax_checkpoint_manager = ocp.CheckpointManager(pytree_file)
+  latest = load_latest_checkpoint(
+      target=None, orbax_checkpoint_manager=orbax_checkpoint_manager
+  )
   if latest:
-    # Because we pass target=None, flax checkpointing will return the raw
+    # Because we pass target=None, checkpoint restoration will return the raw
     # state dict, where 'state' will be a dict with keys ['0', '1', ...]
@@ -49,15 +48,13 @@ def maybe_restore_checkpoint(
     unreplicated_params,
     unreplicated_batch_stats,
     unreplicated_training_metrics_state,
-    train_dir,
-    external_checkpoint_path=None,
-    orbax_checkpointer=None):
+    orbax_checkpoint_manager=None,
+    orbax_checkpoint_manager_external=None):
   """Optionally restores from a checkpoint.

-  The checkpoint logic is as follows: if there is a checkpoint in `train_dir`,
-  restore it. Else, if `external_checkpoint_path` is set, restore the
-  checkpoint found there. Else, don't restore any checkpoint, and just
-  return the passed-in optimizer_state, params, batch_stats, and
+  The checkpoint logic is as follows: if `orbax_checkpoint_manager` contains
+  a latest checkpoint, restore it. Else, if `orbax_checkpoint_manager_external`
+  is set, restore its latest checkpoint. Else, don't restore any checkpoint,
+  and just return the passed-in optimizer_state, params, batch_stats, and
   metrics_grabber.

   Args:
@@ -65,10 +62,8 @@
     unreplicated_params: unreplicated params
     unreplicated_batch_stats: unreplicated batch stats
     unreplicated_training_metrics_state: unreplicated metrics state
-    train_dir: (str) The training directory where we will look for a checkpoint.
-    external_checkpoint_path: (str) If this argument is set, then we will load
-      the external checkpoint stored there.
-    orbax_checkpointer: orbax.Checkpointer
+    orbax_checkpoint_manager: orbax.CheckpointManager
+    orbax_checkpoint_manager_external: orbax.CheckpointManager

   Returns:
     unreplicated_optimizer_state
@@ -89,12 +84,14 @@ def maybe_restore_checkpoint(
       training_metrics_grabber=unreplicated_training_metrics_state,
       global_step=uninitialized_global_step,
       preemption_count=0,
-      sum_train_cost=0.0)
-  logging.info('Loading latest checkpoint from train_dir: %s', train_dir)
-  latest_ckpt = load_latest_checkpoint(train_dir,
-                                       target=unreplicated_checkpoint_state,
-                                       orbax_checkpointer=orbax_checkpointer)
-  logging.info('Loading checkpoint from train_dir %s complete.', train_dir)
+      sum_train_cost=0.0,
+  )
+  logging.info('Loading latest checkpoint.')
+  latest_ckpt = load_latest_checkpoint(
+      target=unreplicated_checkpoint_state,
+      orbax_checkpoint_manager=orbax_checkpoint_manager,
+  )
+  logging.info('Loading latest checkpoint complete.')
-  # Load_latest_checkpoint() will return unreplicated_checkpoint_state if
-  # train_dir does not exist or if it exists and contains no checkpoints.
+  # load_latest_checkpoint() will return unreplicated_checkpoint_state if the
+  # checkpoint manager's directory contains no checkpoints.
   # Note that we could likely change the below line to:
@@ -105,23 +102,14 @@
   if found_checkpoint:
     ckpt_to_return = latest_ckpt
     is_restored = True  # We do want trainer to increment preemption_count.
-    logging.info('Restoring checkpoint from ckpt_%d',
-                 latest_ckpt['global_step'])
-  # Else, if external_checkpoint_path is non-null, restore from that checkpoint.
-  elif external_checkpoint_path is not None:
-    # TODO(jeremycohen) This code will crash if we try to load an external
-    # checkpoint which was trained with a different num_train_steps. The issue
-    # is that some of the fields in the training metrics state are arrays of
-    # shape [num_train_steps]. In the future we may want to handle these
-    # arrays explicitly, in order to avoid this crash.
     logging.info(
-        'Restoring checkpoint from external_checkpoint_path %s',
-        external_checkpoint_path,
+        'Restoring checkpoint from step %d', latest_ckpt['global_step']
     )
-    ckpt_to_return = load_checkpoint(
-        external_checkpoint_path,
+  elif orbax_checkpoint_manager_external:
+    logging.info('Restoring checkpoint from external checkpoint.')
+    ckpt_to_return = load_latest_checkpoint(
         target=unreplicated_checkpoint_state,
-        orbax_checkpointer=orbax_checkpointer,
+        orbax_checkpoint_manager=orbax_checkpoint_manager_external,
     )
     is_restored = False  # We don't want trainer to increment preemption_count.
@@ -158,8 +146,7 @@
                                             is_restored)  # is_restored


-def save_unreplicated_checkpoint(
-    train_dir,
+def unreplicate_and_save_checkpoint(
     optimizer_state,
     params,
     batch_stats,
@@ -167,8 +154,7 @@
     training_metrics_grabber,
     global_step,
     preemption_count,
     sum_train_cost,
-    orbax_checkpointer,
-    max_to_keep=1):
+    orbax_checkpoint_manager):
-  """Saves pytree, step, preemption_count, and sum_train_cost to train_dir."""
+  """Unreplicates the training state and saves it via the checkpoint manager."""
-  logging.info('Saving checkpoint to ckpt_%d', global_step)
+  logging.info('Saving checkpoint at step %d', global_step)
   # jax.device_get doesn't work if jax.Array lives on multiple hosts.
@@ -191,79 +177,61 @@ def save_unreplicated_checkpoint(
       params=unreplicated_params,
       batch_stats=unreplicated_batch_stats,
       training_metrics_grabber=unreplicated_training_metrics_state)
-  save_checkpoint(train_dir,
-                  global_step,
+  save_checkpoint(global_step,
                   state,
-                  max_to_keep=max_to_keep,
-                  orbax_checkpointer=orbax_checkpointer)
+                  orbax_checkpoint_manager=orbax_checkpoint_manager)
   logging.info('Done saving checkpoint.')


-def save_checkpoint(train_dir,
-                    step,
+def save_checkpoint(step,
                     state,
-                    prefix='ckpt_',
-                    max_to_keep=None,
-                    orbax_checkpointer=None):
-  """Saves checkpoint to train_dir/{prefix}{step}.
+                    orbax_checkpoint_manager):
+  """Saves a checkpoint for the given step via the Orbax checkpoint manager.

-  A list of checkpoints will be stored in train_dir. The user
-  is responsible for using unique checkpoint names when calling save_checkpoint
-  repeatedly. If the same train_dir and checkpoint name are used more than once,
-  the latest file will become corrupt. This may become an issue if max_to_keep
-  is not None.
+  Checkpoints are stored in the manager's directory, in a subdirectory named
+  after the step. If the step folder already exists, the checkpoint will not
+  be saved and a warning will be logged.

   Args:
-    train_dir: (str) Directory to create the checkpoint directory in.
     step: (int) Step of the checkpoint.
-    state: (dict) The state to save.
-    prefix: (str) Prefix of the checkpoint name.
-    max_to_keep: (int) Checkpoints older than the max_to_keep'th will be
-      deleted. Defaults to never deleting.
-    orbax_checkpointer: orbax.Checkpointer
+    state: (pytree) The state to save.
+    orbax_checkpoint_manager: orbax.CheckpointManager

   Returns:
     The path of the checkpoint directory.
   """
-  if max_to_keep is None:
-    max_to_keep = sys.maxsize
-  flax_checkpoints.save_checkpoint_multiprocess(
-      train_dir,
-      target=state,
-      step=step,
-      prefix=prefix,
-      keep=max_to_keep,
-      overwrite=True,
-      orbax_checkpointer=orbax_checkpointer,
-  )
-  save_dir = os.path.join(train_dir, prefix + str(step))
-  return save_dir
+  saved = orbax_checkpoint_manager.save(step, args=ocp.args.StandardSave(state))
+  if not saved:
+    logging.warning(
+        'Checkpoint at step %d was not saved! Perhaps it already exists?', step
+    )
+  return orbax_checkpoint_manager.directory


 def load_checkpoint(
-    checkpoint_path, target=None, prefix='ckpt_', orbax_checkpointer=None
+    checkpoint_path=None,
+    target=None,
+    step=0,
+    orbax_checkpoint_manager=None,
 ):
   """Loads the specified checkpoint."""
-  restored = flax_checkpoints.restore_checkpoint(
-      checkpoint_path,
-      target=target,
-      prefix=prefix,
-      orbax_checkpointer=orbax_checkpointer,
+  # For backwards compatibility with call sites that pass a path.
+  if checkpoint_path and not orbax_checkpoint_manager:
+    orbax_checkpoint_manager = ocp.CheckpointManager(checkpoint_path)
+  restored = orbax_checkpoint_manager.restore(
+      step,
+      args=ocp.args.StandardRestore(target),
   )
   return restored


-def load_latest_checkpoint(
-    train_dir, target=None, prefix='ckpt_', orbax_checkpointer=None
-):
-  """Loads the most recent checkpoint listed in train_dir.
+def load_latest_checkpoint(target=None, orbax_checkpoint_manager=None):
+  """Loads the most recent checkpoint tracked by the checkpoint manager.

   Args:
-    train_dir: the directory to read checkpoints from.
-    target: used for Flax checkpointing, a pytree whose structure will be used
+    target: used for checkpointing, a pytree whose structure will be used
       to structure the restored checkpoint data.
-    prefix: the prefix of the names of checkpoint files.
-    orbax_checkpointer: orbax.Checkpointer
+    orbax_checkpoint_manager: An orbax.CheckpointManager instance.

   Returns:
     The state restored from the checkpoint.
-    If using Flax checkpointing and target=None,
-    this will return a unstructured dictionary containing the
+    If target=None, this will return an unstructured dictionary containing the
@@ -271,11 +239,11 @@
     https://github.com/google/flax/blob/master/flax/serialization.py#L67.
     If the directory doesn't exist, it will return the original target.
   """
+  restore_step = orbax_checkpoint_manager.latest_step()
+  if restore_step is None:
+    # No checkpoint has been written yet; fall back to the passed-in target.
+    return target
   try:
-    restored = flax_checkpoints.restore_checkpoint(
-        train_dir, target=target, prefix=prefix,
-        orbax_checkpointer=orbax_checkpointer
+    restored = orbax_checkpoint_manager.restore(
+        restore_step, args=ocp.args.StandardRestore(target)
     )
     return restored
-  except ValueError:
+  except FileNotFoundError:
     return target
diff --git a/init2winit/gradient_statistics_callback.py b/init2winit/gradient_statistics_callback.py
index 76fbe78e..af123cec 100644
--- a/init2winit/gradient_statistics_callback.py
+++ b/init2winit/gradient_statistics_callback.py
@@ -26,6 +26,7 @@
 from init2winit.dataset_lib import data_utils
 import jax
 import jax.numpy as jnp
+import orbax.checkpoint as ocp


 class GradientStatisticsCallback(base_callback.BaseCallBack):
@@ -66,6 +67,12 @@ def __init__(
     ]

     self.num_updates = 0
+    # max_to_keep is left at its default (None) so that every measurement is
+    # kept, matching the previous max_to_keep=None behavior.
+    self.orbax_checkpoint_manager = ocp.CheckpointManager(
+        self.save_path,
+        options=ocp.CheckpointManagerOptions(create=True),
+    )

     def update(params, batch, batch_stats, dropout_rng):
       def opt_cost(params):
@@ -146,10 +153,8 @@ def run_eval(self, params, batch_stats, optimizer_state, global_step):
       )

       checkpoint.save_checkpoint(
-          self.save_path,
           step=global_step,
           state=state,
-          prefix='measurement_',
-          max_to_keep=None)
+          orbax_checkpoint_manager=self.orbax_checkpoint_manager)

     return {}
diff --git a/init2winit/test_checkpoint.py b/init2winit/test_checkpoint.py
index 7d2fff7f..cc2d74a4 100644
--- a/init2winit/test_checkpoint.py
+++ b/init2winit/test_checkpoint.py
@@ -30,7 +30,7 @@
 import jax.numpy as jnp
 import jax.tree_util
 import numpy as np
-import orbax.checkpoint as orbax_checkpoint
+import orbax.checkpoint as ocp
 from tensorflow.io import gfile


@@ -59,8 +59,6 @@ def setUp(self):
     model_init_fn = jax.jit(
         functools.partial(model.flax_module.init, train=False))
     init_dict = model_init_fn({'params': params_rng}, xs)
-    self.orbax_checkpointer = orbax_checkpoint.AsyncCheckpointer(
-        orbax_checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
     self.params = init_dict['params']

   def tearDown(self):
@@ -71,45 +69,53 @@ def tearDown(self):

   # We could supply the params pytree as a fake gradient and do an update.
  def test_save_load_roundtrip(self):
     """Test that saving and loading produces the original state."""
-    baz = ['a', 'b', 'ccc']
-    state = dict(params=self.params, global_step=5, completed_epochs=4, baz=baz)
-    checkpoint.save_checkpoint(self.test_dir, 0, state,
-                               orbax_checkpointer=self.orbax_checkpointer)
+    orbax_checkpoint_manager = ocp.CheckpointManager(
+        self.test_dir,
+        options=ocp.CheckpointManagerOptions(
+            max_to_keep=1, create=True),
+    )
+    state = dict(params=self.params, global_step=5, completed_epochs=4)
+    checkpoint.save_checkpoint(
+        0,
+        state,
+        orbax_checkpoint_manager=orbax_checkpoint_manager,
+    )
+    orbax_checkpoint_manager.wait_until_finished()
     latest = checkpoint.load_latest_checkpoint(
-        self.test_dir, target=state, orbax_checkpointer=self.orbax_checkpointer
+        target=state, orbax_checkpoint_manager=orbax_checkpoint_manager
     )
-    self.assertEqual(latest['baz'], baz)
     assert pytree_equal(latest['params'], self.params)
     self.assertEqual(latest['global_step'], 5)
     self.assertEqual(latest['completed_epochs'], 4)

   def test_delete_old_checkpoints(self):
     """Test that old checkpoints are deleted."""
+    orbax_checkpoint_manager = ocp.CheckpointManager(
+        self.test_dir,
+        options=ocp.CheckpointManagerOptions(
+            max_to_keep=1, create=True
+        ),
+    )
     state1 = dict(params=self.params, global_step=5, completed_epochs=4,)
     checkpoint.save_checkpoint(
-        self.test_dir, 0, state1,
-        orbax_checkpointer=self.orbax_checkpointer,
-        max_to_keep=1)
+        0,
+        state1,
+        orbax_checkpoint_manager=orbax_checkpoint_manager)
     state2 = dict(params=self.params, global_step=10, completed_epochs=8)
     checkpoint.save_checkpoint(
-        self.test_dir, 1, state2,
-        orbax_checkpointer=self.orbax_checkpointer,
-        max_to_keep=1)
-    self.orbax_checkpointer.wait_until_finished()
+        1,
+        state2,
+        orbax_checkpoint_manager=orbax_checkpoint_manager)
+    orbax_checkpoint_manager.wait_until_finished()
     dir_contents = gfile.glob(os.path.join(self.test_dir, '*'))
-    # Due to Flax Orbax migration using Orbax AsyncCheckpointer will result
-    # in 'max_to_keep + 1' files.
-    self.assertLen(dir_contents, 1 + 1)
+
+    self.assertLen(dir_contents, 1)

   def test_all_variables_restored(self):
     """Test that all variables are properly restored.
@@ -134,8 +140,14 @@
     initial_batch_stats = {'mean': 0}
     initial_training_metrics = {'ema': 0}

+    orbax_checkpoint_manager = ocp.CheckpointManager(
+        fresh_train_dir,
+        options=ocp.CheckpointManagerOptions(
+            max_to_keep=1, create=True
+        ),
+    )
+
     checkpoint.save_checkpoint(
-        train_dir=fresh_train_dir,
         step=global_step,
         state=dict(global_step=global_step,
                    preemption_count=preemption_count,
@@ -144,8 +156,7 @@
                    params=saved_params,
                    batch_stats=saved_batch_stats,
                    training_metrics_grabber=saved_training_metrics),
-        orbax_checkpointer=self.orbax_checkpointer,
-        max_to_keep=1)
+        orbax_checkpoint_manager=orbax_checkpoint_manager)
+    orbax_checkpoint_manager.wait_until_finished()

     (
         ret_state,
@@ -161,8 +172,7 @@
         initial_params,
         initial_batch_stats,
         initial_training_metrics,
-        fresh_train_dir,
-        orbax_checkpointer=self.orbax_checkpointer,
+        orbax_checkpoint_manager=orbax_checkpoint_manager,
     )

     assert pytree_equal(
@@ -189,74 +199,87 @@
   def test_maybe_restore_from_checkpoint_logic(self):
     """Test that the right checkpoint is returned.

-    1. If no external_checkpoint_path was passed, and if there is no
-    latest checkpoint in the train_dir, then the function should return
-    the passed-in params, batch_stats, etc.
-    2. If an external checkpoint was provided but no latest checkpoint
-    exists in the train_dir, then the function should return the external
-    checkpoint.
-    3. If a latest checkpoint exists in the train dir, then the function
-    should return that checkpoint.
-
+    1. If the checkpoint manager has no latest checkpoint, then the function
+    should return the passed-in params, batch_stats, etc.
+    2. If the checkpoint manager has a latest checkpoint, then the function
+    should return that checkpoint.
+
     In the interest of conciseness, this test only checks the params,
-    not the batch_stats, optimizer_state, or training_metics. The below test
+    not the batch_stats, optimizer_state, or training_metrics. The below test
     test_all_variables_restored() covers the other three.
     """
     # mock parameters.
     initial_params = {'foo': 1.0}
-    latest_params = {'foo': 2.0}
-    external_params = {'foo': 3.0}
+    latest_params = {'foo': 3.0}

-    fresh_train_dir = tempfile.mkdtemp()
-    external_dir = tempfile.mkdtemp()
+    checkpoint_dir = tempfile.mkdtemp()
+
+    orbax_checkpoint_manager = ocp.CheckpointManager(
+        checkpoint_dir,
+        options=ocp.CheckpointManagerOptions(
+            max_to_keep=1, create=True
+        ),
+    )

     # two helper functions
-    def save_checkpoint(train_dir, global_step, preemption_count,
-                        sum_train_cost, params):
+    def save_checkpoint(
+        orbax_checkpoint_manager,
+        global_step,
+        preemption_count,
+        sum_train_cost,
+        params,
+    ):
       """Helper function to save a checkpoint."""
       checkpoint.save_checkpoint(
-          train_dir=train_dir,
           step=global_step,
-          state=dict(global_step=global_step,
-                     preemption_count=preemption_count,
-                     sum_train_cost=sum_train_cost,
-                     optimizer_state={},
-                     params=params,
-                     batch_stats={},
-                     training_metrics_grabber={}),
-          orbax_checkpointer=self.orbax_checkpointer,
-          max_to_keep=1)
-
-    def maybe_restore_checkpoint(params, train_dir, external_checkpoint_path):
+          state=dict(
+              global_step=global_step,
+              preemption_count=preemption_count,
+              sum_train_cost=sum_train_cost,
+              optimizer_state={},
+              params=params,
+              batch_stats={},
+              training_metrics_grabber={},
+          ),
+          orbax_checkpoint_manager=orbax_checkpoint_manager,
+      )
+
+    def maybe_restore_checkpoint(orbax_checkpoint_manager, params):
       """Helper function to replicate_and_maybe_restore a checkpoint."""
-      (_, ret_params, _, _,
-       ret_global_step, ret_sum_train_cost, ret_preemption_count,
-       ret_is_restored) = checkpoint.maybe_restore_checkpoint(
-           {}, params, {}, {}, train_dir, external_checkpoint_path,
-           orbax_checkpointer=self.orbax_checkpointer)
+      (
+          _,
+          ret_params,
+          _,
+          _,
+          ret_global_step,
+          ret_sum_train_cost,
+          ret_preemption_count,
+          ret_is_restored,
+      ) = checkpoint.maybe_restore_checkpoint(
+          {}, params, {}, {}, orbax_checkpoint_manager=orbax_checkpoint_manager
+      )
       ret_params_unrep = ret_params
-      return (ret_params_unrep, ret_global_step, ret_sum_train_cost,
-              ret_preemption_count, ret_is_restored)
-
-    # Save external checkpoint.
-    save_checkpoint(train_dir=external_dir,
-                    global_step=5,
-                    preemption_count=4,
-                    sum_train_cost=7.0,
-                    params=external_params)
-    external_checkpoint_path = os.path.join(external_dir, 'ckpt_' + str(5))
+      return (
+          ret_params_unrep,
+          ret_global_step,
+          ret_sum_train_cost,
+          ret_preemption_count,
+          ret_is_restored,
+      )

-    # If no latest checkpoint exists, and no external checkpoint was provided,
-    # the function should return the passed-in params.
+    # If no latest checkpoint exists, the function should return the passed-in
+    # params.
-    (ret_params, ret_global_step, ret_sum_train_cost, ret_preemption_count,
-     ret_is_restored) = maybe_restore_checkpoint(initial_params,
-                                                 fresh_train_dir,
-                                                 None)
+    (
+        ret_params,
+        ret_global_step,
+        ret_sum_train_cost,
+        ret_preemption_count,
+        ret_is_restored,
+    ) = maybe_restore_checkpoint(orbax_checkpoint_manager, initial_params)

     self.assertEqual(ret_preemption_count, 0)
     self.assertEqual(ret_global_step, 0)
@@ -267,40 +290,32 @@ def maybe_restore_checkpoint(params, train_dir, external_checkpoint_path):

-    # If no latest checkpoint exists, and an external checkpoint was provided,
-    # the function should return the external checkpoint.
-    (ret_params, ret_global_step, ret_sum_train_cost, ret_preemption_count,
-     ret_is_restored) = maybe_restore_checkpoint(initial_params,
-                                                 fresh_train_dir,
-                                                 external_checkpoint_path)
+    # If a latest checkpoint exists, the function should return that
+    # checkpoint. Save a checkpoint so that a latest checkpoint exists.
+    save_checkpoint(
+        orbax_checkpoint_manager,
+        global_step=5,
+        preemption_count=4,
+        sum_train_cost=7.0,
+        params=latest_params,
+    )
+
+    orbax_checkpoint_manager.wait_until_finished()
+
+    (
+        ret_params,
+        ret_global_step,
+        ret_sum_train_cost,
+        ret_preemption_count,
+        ret_is_restored,
+    ) = maybe_restore_checkpoint(orbax_checkpoint_manager, initial_params)

     self.assertEqual(ret_preemption_count, 4)
     self.assertEqual(ret_global_step, 5)
     self.assertEqual(ret_sum_train_cost, 7.0)
-    self.assertFalse(ret_is_restored)
-    assert pytree_equal(ret_params, external_params)
-
-    # Save latest checkpoint.
-    save_checkpoint(train_dir=fresh_train_dir,
-                    global_step=10,
-                    preemption_count=2,
-                    sum_train_cost=2.2,
-                    params=latest_params)
-
-    # If a latest checkpoint exists, then even if an external checkpoint was
-    # provided, the function should return the latest checkpoint.
-
-    (ret_params, ret_global_step, ret_sum_train_cost, ret_preemption_count,
-     ret_is_restored) = maybe_restore_checkpoint(initial_params,
-                                                 fresh_train_dir,
-                                                 external_checkpoint_path)
-
-    self.assertEqual(ret_preemption_count, 2)
-    self.assertEqual(ret_global_step, 10)
-    self.assertEqual(ret_sum_train_cost, 2.2)
     self.assertTrue(ret_is_restored)
     assert pytree_equal(ret_params, latest_params)

-    shutil.rmtree(fresh_train_dir)
-    shutil.rmtree(external_dir)
+    shutil.rmtree(checkpoint_dir)


 if __name__ == '__main__':
diff --git a/init2winit/test_utils.py b/init2winit/test_utils.py
index cee4d033..1b6076a6 100644
--- a/init2winit/test_utils.py
+++ b/init2winit/test_utils.py
@@ -23,7 +23,6 @@
 from absl.testing import absltest
 from absl.testing import parameterized
-from init2winit import checkpoint
 from init2winit import utils
 import jax.numpy as jnp
 import numpy as np
@@ -75,31 +74,35 @@ def tearDown(self):
           expected=list(range(1, 10)),
       ),
   )
-  def testRunInParallel(self, input_list_dict, num_workers, expected):
+  def test_run_in_parallel(self, input_list_dict, num_workers, expected):
     """Test running successful fns in parallel, originally from mlbileschi."""
     actual = utils.run_in_parallel(_identity, input_list_dict, num_workers)
     self.assertEqual(actual, expected)

-  def testRunInParallelOnFailingFn(self):
+  def test_run_in_parallel_on_failing_fn(self):
     """Test running failing fns in parallel, originally from mlbileschi."""
     with self.assertRaisesRegex(ValueError, 'I always fail.'):
       utils.run_in_parallel(_fn_that_always_fails, [dict(arg='hi')], 10)

-  def testAppendPytree(self):
+  def test_append_pytree(self):
     """Test appending and loading pytrees."""
     pytrees = [{'a': i} for i in range(10)]
-    pytree_path = os.path.join(self.test_dir, 'pytree.ckpt')
+    pytree_path = os.path.join(self.test_dir, 'pytrees')
     logger = utils.MetricLogger(pytree_path=pytree_path)
-    for pytree in pytrees:
-      logger.append_pytree(pytree)
+    for i, pytree in enumerate(pytrees):
+      logger.append_pytree(pytree, step=i)
+
+    # Appends are written asynchronously; wait before reading back.
+    logger.wait_until_pytree_checkpoint_finished()
+    latest = logger.load_latest_pytree(target=None)
+    saved_pytrees = latest if latest else []

-    latest = checkpoint.load_latest_checkpoint(pytree_path, prefix='')
-    saved_pytrees = latest['pytree'] if latest else []
     self.assertEqual(
-        pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))])
+        pytrees, [saved_pytrees[i] for i in range(len(saved_pytrees))]
+    )

-  def testArrayAppend(self):
+  def test_array_append(self):
     """Test appending to an array."""
     np.testing.assert_allclose(
         utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4]))
@@ -107,13 +110,13 @@
     np.testing.assert_allclose(
         utils.array_append(jnp.array([[1, 2], [3, 4]]), jnp.array([5, 6])),
         jnp.array([[1, 2], [3, 4], [5, 6]]))

-  def testTreeNormSqL2(self):
+  def test_tree_norm_sq_l2(self):
     """Test computing the squared L2 norm of a pytree."""
     pytree = {'foo': jnp.ones(10), 'baz': jnp.ones(20)}
     self.assertEqual(utils.tree_norm_sql2(pytree), {'foo': 10.0, 'baz': 20.0})
     self.assertEqual(utils.total_tree_norm_sql2(pytree), 30.0)

-  def testTreeSum(self):
+  def test_tree_sum(self):
     """Test computing the sum of a pytree."""
     pytree = {'foo': 2*jnp.ones(10), 'baz': jnp.ones(20)}
     self.assertEqual(utils.total_tree_sum(pytree), 40)
diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py
index 18ac768d..c439464b 100644
--- a/init2winit/trainer_lib/base_trainer.py
+++ b/init2winit/trainer_lib/base_trainer.py
@@ -174,10 +174,11 @@ def __init__(
             choose_store_cell=True,
         ),
     )
-    self._orbax_checkpointer = ocp.AsyncCheckpointer(
-        ocp.PyTreeCheckpointHandler(use_ocdbt=False),
-        timeout_secs=600,
-        file_options=orbax_file_options,
+    self._orbax_checkpoint_manager = ocp.CheckpointManager(
+        self._checkpoint_dir,
+        options=ocp.CheckpointManagerOptions(
+            max_to_keep=1, create=True, file_options=orbax_file_options
+        ),
     )
     self._early_stopping_target_name = early_stopping_target_name
     self._early_stopping_target_value = early_stopping_target_value
@@ -195,8 +196,15 @@ def __init__(
       self._callback_configs = [callback_configs]
     else:
       self._callback_configs = callback_configs
-    self._external_checkpoint_path = external_checkpoint_path
-
+    if external_checkpoint_path is not None:
+      self._orbax_checkpoint_manager_external = ocp.CheckpointManager(
+          external_checkpoint_path,
+          options=ocp.CheckpointManagerOptions(
+              create=True, file_options=orbax_file_options
+          ),
+      )
+    else:
+      self._orbax_checkpoint_manager_external = None
+
     # For logging / processing off the main thread
     self._logging_pool = multiprocessing.pool.ThreadPool()
@@ -208,10 +216,16 @@ def __init__(
     assert eval_batch_size % (jax.device_count()) == 0

     # Only used if checkpoints_steps is non-empty. Standard checkpoints are
-    # saved in train_dir.
+    # saved in _checkpoint_dir.
     self._extra_checkpoint_dir = os.path.join(
         self._checkpoint_dir, 'checkpoints'
     )
+    self._orbax_checkpoint_manager_extra = ocp.CheckpointManager(
+        self._extra_checkpoint_dir,
+        options=ocp.CheckpointManagerOptions(
+            create=True, file_options=orbax_file_options
+        ),
+    )

     # During eval, we can donate the 'batch' buffer. We don't donate the
    # 'params' and 'batch_stats' buffers as we don't re-assign those values in
@@ -228,7 +242,7 @@ def __init__(
         self._training_algorithm_class)

   def wait_until_orbax_checkpointer_finished(self):
-    self._orbax_checkpointer.wait_until_finished()
+    self._orbax_checkpoint_manager.wait_until_finished()

   def log_model_info(self, unreplicated_params):
     if jax.process_index() == 0:
@@ -267,9 +281,8 @@ def maybe_restore_from_checkpoint(self,
         unreplicated_params,
         unreplicated_batch_stats,
         unreplicated_metrics_state,
-        train_dir=self._checkpoint_dir,
-        external_checkpoint_path=self._external_checkpoint_path,
-        orbax_checkpointer=self._orbax_checkpointer,
+        orbax_checkpoint_manager=self._orbax_checkpoint_manager,
+        orbax_checkpoint_manager_external=self._orbax_checkpoint_manager_external,
     )

     if self._is_restored:
@@ -335,13 +348,14 @@ def setup_data_loader(self, data_rng, global_step):

     return dataset

-  def _save(self, checkpoint_dir, max_to_keep=1):
+  def _save(self, checkpoint_manager=None):
     if utils.use_mock_tpu_backend():
       logging.info('Skip saving checkpoint when running with mock backend.')
       return
+    if checkpoint_manager is None:
+      checkpoint_manager = self._orbax_checkpoint_manager

-    checkpoint.save_unreplicated_checkpoint(
-        checkpoint_dir,
+    checkpoint.unreplicate_and_save_checkpoint(
         self._optimizer_state,
         self._params,
         self._batch_stats,
@@ -349,8 +363,7 @@ def _save(self, checkpoint_dir, max_to_keep=1):
         self._global_step,
         self._preemption_count,
         self._sum_train_cost,
-        self._orbax_checkpointer,
-        max_to_keep=max_to_keep,
+        checkpoint_manager,
     )

   def _get_step_frequency(self, cur_step, start_step, start_time):
@@ -451,7 +464,7 @@ def _eval(self, start_step, start_time, save=True):
     )
     self._run_eval_callbacks(report)
     if save:
-      self._save(self._checkpoint_dir)
+      self._save()
     steps_since_last_eval = self._global_step - self._prev_eval_step
     steps_per_sec_no_eval = steps_since_last_eval / time_since_last_eval
     run_time = time.time() - self._time_at_prev_eval_end
@@ -646,7 +659,7 @@ def train(self):
     self._prev_eval_step = self._global_step

     if self._global_step in self._checkpoint_steps:
-      self._save(self._extra_checkpoint_dir, max_to_keep=None)
+      self._save(checkpoint_manager=self._orbax_checkpoint_manager_extra)

     for _ in range(start_step, self._num_train_steps):
       with jax.profiler.StepTraceAnnotation(
@@ -682,13 +695,15 @@ def train(self):
             self._sum_train_cost,
         )
         if self._global_step in self._checkpoint_steps:
-          self._save(self._extra_checkpoint_dir, max_to_keep=None)
+          logging.info('Saving checkpoint at step %d', self._global_step)
+          self._save(checkpoint_manager=self._orbax_checkpoint_manager_extra)

       # TODO(gdahl, gilmer): consider moving this test up.
       # NB: Since this test is after we increment self._global_step, having 0
       # in eval_steps does nothing.
       if trainer_utils.should_eval(
-          self._global_step, self._eval_frequency, self._eval_steps):
+          self._global_step, self._eval_frequency, self._eval_steps
+      ):
         try:
           report = self._eval(start_step, start_time)
         except utils.TrainingDivergedError as e:
@@ -709,6 +724,8 @@ def train(self):
       yield report

     # To make sure the last checkpoint was correctly saved.
     self.wait_until_orbax_checkpointer_finished()
+    self._orbax_checkpoint_manager.close()
+    self._orbax_checkpoint_manager_extra.close()
+    if self._orbax_checkpoint_manager_external:
+      self._orbax_checkpoint_manager_external.close()

   @abc.abstractmethod
   def update(self, batch, rng, metrics_update_fn, metrics_state, training_cost):
diff --git a/init2winit/trainer_lib/test_trainer.py b/init2winit/trainer_lib/test_trainer.py
index d88bc1f8..87e324e8 100644
--- a/init2winit/trainer_lib/test_trainer.py
+++ b/init2winit/trainer_lib/test_trainer.py
@@ -741,8 +741,8 @@ def as_dataset(self, *args, **kwargs):
     checkpoint_dir = os.path.join(self.test_dir, 'ttl=180d', 'checkpoints')
     saved_steps = []
     for f in tf.io.gfile.listdir(checkpoint_dir):
-      if f[:5] == 'ckpt_':
-        saved_steps.append(int(f[5:]))
+      if f.isdigit():
+        saved_steps.append(int(f))
     self.assertEqual(set(saved_steps), set(checkpoint_steps))

diff --git a/init2winit/utils.py b/init2winit/utils.py
index f0ba297a..e307e12d 100644
--- a/init2winit/utils.py
+++ b/init2winit/utils.py
@@ -27,14 +27,15 @@
 from clu import metric_writers
 import flax
 import flax.linen as nn
-from flax.training import checkpoints as flax_checkpoints
 from init2winit import checkpoint
 import jax
 import jax.numpy as jnp
 import numpy as np
+import orbax.checkpoint as ocp
 import pandas as pd
 from tensorflow.io import gfile
+

 exists = gfile.exists
@@ -215,6 +216,33 @@ def __init__(self,
     if events_dir:
       self._tb_metric_writer = metric_writers.create_default_writer(events_dir)

+    orbax_file_options = ocp.checkpoint_manager.FileOptions(
+        path_permission_mode=0o775,
+        cns2_storage_options=ocp.options.Cns2StorageOptions(
+            choose_store_cell=True,
+        ),
+    )
+    self._orbax_checkpoint_manager = ocp.CheckpointManager(
+        self._pytree_path,
+        options=ocp.CheckpointManagerOptions(
+            max_to_keep=1, create=True, file_options=orbax_file_options
+        ),
+    )
+
+  def load_latest_pytree(self, target=None):
+    """Loads the latest written pytree, or returns `target` if none exists."""
+    if target:
+      target = dict(pytree=target)
+    loaded_target = checkpoint.load_latest_checkpoint(
+        target=target,
+        orbax_checkpoint_manager=self._orbax_checkpoint_manager,
+    )
+    if loaded_target:
+      return loaded_target['pytree']
+    else:
+      return target

   def append_scalar_metrics(self, metrics):
     """Record a dictionary of scalar metrics at a given step.
@@ -261,23 +289,27 @@ def append_scalar_metrics(self, metrics):
       # size 512. We could only flush at the end of training to optimize this.
       self._tb_metric_writer.flush()

-  def write_pytree(self, pytree, prefix='training_metrics'):
-    """Record a serializable pytree to disk, overwriting any previous state.
+  def write_pytree(self, pytree, step=0):
+    """Record a serializable pytree to disk at the given step.

     Args:
       pytree: Any serializable pytree
-      prefix: The prefix for the checkpoint. Save path is
-        self._pytree_path/prefix
+      step: Integer. The global step.
     """
     state = dict(pytree=pytree)
     checkpoint.save_checkpoint(
-        self._pytree_path,
-        step='',
+        step,
         state=state,
-        prefix=prefix,
-        max_to_keep=None)
+        orbax_checkpoint_manager=self._orbax_checkpoint_manager,
+    )
+
+  def wait_until_pytree_checkpoint_finished(self):
+    self._orbax_checkpoint_manager.wait_until_finished()
+
+  def latest_pytree_checkpoint_step(self):
+    return self._orbax_checkpoint_manager.latest_step()

-  def append_pytree(self, pytree, prefix='training_metrics'):
+  def append_pytree(self, pytree, step=0):
     """Append and record a serializable pytree to disk.

-    The pytree will be saved to disk as a list of pytree objects. Everytime
+    The pytree will be saved to disk as a list of pytree objects. Every time
@@ -286,26 +318,24 @@
     Args:
       pytree: Any serializable pytree.
-      prefix: The prefix for the checkpoint.
+      step: Integer. The global step.
     """
     # Read the latest (and only) checkpoint if it exists, then append the new
     # state to it before saving back to disk.
-    try:
-      old_state = flax_checkpoints.restore_checkpoint(
-          self._pytree_path, target=None, prefix=prefix)
-    except ValueError:
-      old_state = None
-    # Because we pass target=None, checkpointing will return the raw state
-    # dict, where 'pytree' is a dict with keys ['0', '1', ...] instead of a
-    # list.
+    # Saves are asynchronous, so wait for any in-flight write before reading;
+    # otherwise we could read a stale checkpoint and drop earlier appends.
+    self.wait_until_pytree_checkpoint_finished()
+    # load_latest_pytree returns the previously appended list (Orbax restores
+    # saved lists as lists), or None if nothing has been written yet.
+    old_state = self.load_latest_pytree(target=None)
     if old_state:
-      state_list = old_state['pytree']
-      state_list = [state_list[str(i)] for i in range(len(state_list))]
+      state_list = [old_state[i] for i in range(len(old_state))]
     else:
       state_list = []
     state_list.append(pytree)
-    self.write_pytree(state_list)
+    self.write_pytree(state_list, step=step)

   def append_json_object(self, json_obj):
     """Append a json serializable object to the json file."""
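
Reviewer note (not part of the patch): for anyone migrating other call sites, the sketch below walks through the new save/restore round trip end to end. It is a minimal illustration against the modified `init2winit.checkpoint` API above; the `train_dir`, `state` contents, and step number are placeholders, and the retention/async behavior shown is the standard `orbax.checkpoint.CheckpointManager` contract rather than anything this patch adds.

```python
import tempfile

import orbax.checkpoint as ocp

from init2winit import checkpoint

# The CheckpointManager now owns the directory, retention (max_to_keep), and
# async-save machinery that save_checkpoint used to configure per call.
train_dir = tempfile.mkdtemp()  # Placeholder directory.
manager = ocp.CheckpointManager(
    train_dir,
    options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True),
)

# Placeholder training state; real callers save optimizer state, params, etc.
state = dict(global_step=5, params={'w': 1.0}, sum_train_cost=0.0)

# Checkpoints now land under <train_dir>/<step> rather than
# <train_dir>/ckpt_<step>.
checkpoint.save_checkpoint(5, state, orbax_checkpoint_manager=manager)

# Saves are asynchronous: block before reading the checkpoint back.
manager.wait_until_finished()

restored = checkpoint.load_latest_checkpoint(
    target=state, orbax_checkpoint_manager=manager
)
assert restored['global_step'] == 5

manager.close()
```

Callers that previously passed `prefix=` or `max_to_keep=` to `save_checkpoint` should move those decisions into `CheckpointManagerOptions` when constructing the manager, as the trainer and `MetricLogger` changes above do.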