Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions hessian/model_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
"""

import functools
import os

import flax
import flax.linen as nn
from init2winit.checkpoint import load_pytree
from init2winit.model_lib import partition_tree
from init2winit.utils import array_append
from init2winit.utils import tree_norm_sql2
Expand Down Expand Up @@ -342,11 +340,8 @@ def __init__(self,
self._stored_metrics = {}

# In the case of preemption we want to restore prior metrics.
if metrics_logger:
metrics_file = os.path.join(metrics_logger._pytree_path,
'training_metrics')
if exists(metrics_file):
self._stored_metrics = load_pytree(metrics_file)
if metrics_logger and metrics_logger.latest_pytree_checkpoint_step():
self._stored_metrics = metrics_logger.load_latest_pytree()

def _grab_statistics(self,
step,
Expand Down Expand Up @@ -380,9 +375,9 @@ def _maybe_save_metrics(self, step):
save_dict = self._stored_metrics.copy()
if self._save_every:
if step % self._save_every == 0:
self._metrics_logger.write_pytree(save_dict)
self._metrics_logger.write_pytree(save_dict, step=step)
else:
self._metrics_logger.write_pytree(save_dict)
self._metrics_logger.write_pytree(save_dict, step=step)

@property
def stored_metrics(self):
Expand Down
20 changes: 9 additions & 11 deletions hessian/test_model_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from absl.testing import absltest
import flax
from flax import linen as nn
from init2winit import checkpoint
from init2winit import utils
from init2winit.hessian import model_debugger
from init2winit.hessian.model_debugger import skip_bwd
Expand Down Expand Up @@ -282,11 +281,10 @@ def fake_loss_fn(ps, x, rng, module_flags=None):
self.assertEqual(og['B_0']['Dense_0']['kernel'], 16.0)
self.assertEqual(og['C_0']['Dense_0']['kernel'], 36.0)

def test_model_debugger_pmap(self):
def test_model_debugger_restore(self):
"""Test training for two epochs on MNIST with a small model."""

rep_variables = set_up_cnn()

pytree_path = os.path.join(self.test_dir, 'metrics')
metrics_logger = utils.MetricLogger(
pytree_path=pytree_path, events_dir=self.test_dir)
Expand All @@ -297,18 +295,18 @@ def grad_fn(params, batch, rng):
return rep_variables['params']

debugger = model_debugger.ModelDebugger(
use_pmap=True, grad_fn=grad_fn, metrics_logger=metrics_logger)
use_pmap=False, grad_fn=grad_fn, metrics_logger=metrics_logger)

# eval twice to test the concat
extra_metrics = {'train_loss': 1.0}
extra_metrics2 = {'train_loss': 1.0}
metrics = debugger.full_eval(
_ = debugger.full_eval(
10,
params=rep_variables['params'],
grad=rep_variables['params'],
extra_scalar_metrics=extra_metrics)
metrics = debugger.full_eval(
10,
20,
params=rep_variables['params'],
grad=None, # use internal gradient comp
extra_scalar_metrics=extra_metrics2)
Expand All @@ -321,9 +319,8 @@ def grad_fn(params, batch, rng):
'train_loss',
]

metrics_file = os.path.join(self.test_dir, 'metrics/training_metrics')

loaded_metrics = checkpoint.load_checkpoint(metrics_file)['pytree']
metrics_logger.wait_until_pytree_checkpoint_finished()
loaded_metrics = metrics_logger.load_latest_pytree(None)

self.assertEqual(set(expected_keys), set(metrics.keys()))
expected_shape = ()
Expand All @@ -341,11 +338,12 @@ def grad_fn(params, batch, rng):
# Test restore of prior metrics.
new_debugger = model_debugger.ModelDebugger(
use_pmap=True, metrics_logger=metrics_logger)
metrics = new_debugger.full_eval(
10,
_ = new_debugger.full_eval(
30,
params=rep_variables['params'],
grad=rep_variables['params'],
extra_scalar_metrics=extra_metrics2)
metrics_logger.wait_until_pytree_checkpoint_finished()
self.assertEqual(
new_debugger.stored_metrics['param_norms_sql2']['Conv_0']
['kernel'].shape, (3,))
Expand Down
158 changes: 63 additions & 95 deletions init2winit/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,23 @@
This is useful for training neural networks with stax, where model parameters
are nested numpy arrays.
"""
import os
import sys

from absl import flags
from absl import logging
from flax.training import checkpoints as flax_checkpoints
import jax
# pylint: disable=g-importing-member
from jax.experimental.multihost_utils import process_allgather
import orbax.checkpoint as ocp

FLAGS = flags.FLAGS


def load_pytree(pytree_file, orbax_checkpointer=None):
"""Loads the checkpointed pytree."""
latest = load_latest_checkpoint(pytree_file,
target=None,
orbax_checkpointer=orbax_checkpointer)
def load_pytree(pytree_file, orbax_checkpoint_manager=None):
"""Loads a checkpointed pytree."""
if not orbax_checkpoint_manager:
orbax_checkpoint_manager = ocp.CheckpointManager(pytree_file)
latest = load_latest_checkpoint(
target=None, orbax_checkpoint_manager=orbax_checkpoint_manager
)
if latest:
# Because we pass target=None, flax checkpointing will return the raw
# state dict, where 'state' will be a dict with keys ['0', '1', ...]
Expand All @@ -49,26 +48,22 @@ def maybe_restore_checkpoint(
unreplicated_params,
unreplicated_batch_stats,
unreplicated_training_metrics_state,
train_dir,
external_checkpoint_path=None,
orbax_checkpointer=None):
orbax_checkpoint_manager=None,
orbax_checkpoint_manager_external=None):
"""Optionally restores from a checkpoint.

The checkpoint logic is as follows: if there is a checkpoint in `train_dir`,
restore it. Else, if `external_checkpoint_path` is set, restore the
checkpoint found there. Else, don't restore any checkpoint, and just
return the passed-in optimizer_state, params, batch_stats, and
The checkpoint logic is as follows: if `orbax_checkpoint_manager` contains
a latest checkpoint, restore it. Otherwise, don't restore any checkpoint,
and just return the passed-in optimizer_state, params, batch_stats, and
metrics_grabber.

Args:
unreplicated_optimizer_state: unreplicated optimizer state
unreplicated_params: unreplicated params
unreplicated_batch_stats: unreplicated batch stats
unreplicated_training_metrics_state: unreplicated metrics state
train_dir: (str) The training directory where we will look for a checkpoint.
external_checkpoint_path: (str) If this argument is set, then we will load
the external checkpoint stored there.
orbax_checkpointer: orbax.Checkpointer
orbax_checkpoint_manager: orbax.CheckpointManager
orbax_checkpoint_manager_external: orbax.CheckpointManager

Returns:
unreplicated_optimizer_state
Expand All @@ -89,12 +84,14 @@ def maybe_restore_checkpoint(
training_metrics_grabber=unreplicated_training_metrics_state,
global_step=uninitialized_global_step,
preemption_count=0,
sum_train_cost=0.0)
logging.info('Loading latest checkpoint from train_dir: %s', train_dir)
latest_ckpt = load_latest_checkpoint(train_dir,
target=unreplicated_checkpoint_state,
orbax_checkpointer=orbax_checkpointer)
logging.info('Loading checkpoint from train_dir %s complete.', train_dir)
sum_train_cost=0.0,
)
logging.info('Loading latest checkpoint')
latest_ckpt = load_latest_checkpoint(
target=unreplicated_checkpoint_state,
orbax_checkpoint_manager=orbax_checkpoint_manager,
)
logging.info('Loading checkpoint from complete.')
# Load_latest_checkpoint() will return unreplicated_checkpoint_state if
# train_dir does not exist or if it exists and contains no checkpoints.
# Note that we could likely change the below line to:
Expand All @@ -105,23 +102,14 @@ def maybe_restore_checkpoint(
if found_checkpoint:
ckpt_to_return = latest_ckpt
is_restored = True # We do want trainer to increment preemption_count.
logging.info('Restoring checkpoint from ckpt_%d',
latest_ckpt['global_step'])
# Else, if external_checkpoint_path is non-null, restore from that checkpoint.
elif external_checkpoint_path is not None:
# TODO(jeremycohen) This code will crash if we try to load an external
# checkpoint which was trained with a different num_train_steps. The issue
# is that some of the fields in the training metrics state are arrays of
# shape [num_train_steps]. In the future we may want to handle these
# arrays explicitly, in order to avoid this crash.
logging.info(
'Restoring checkpoint from external_checkpoint_path %s',
external_checkpoint_path,
'Restoring checkpoint from ckpt_%d', latest_ckpt['global_step']
)
ckpt_to_return = load_checkpoint(
external_checkpoint_path,
elif not found_checkpoint and orbax_checkpoint_manager_external:
logging.info('Restoring checkpoint from external checkpoint.')
ckpt_to_return = load_latest_checkpoint(
target=unreplicated_checkpoint_state,
orbax_checkpointer=orbax_checkpointer,
orbax_checkpoint_manager=orbax_checkpoint_manager_external,
)
is_restored = False # We don't want trainer to increment preemption_count.

Expand Down Expand Up @@ -158,17 +146,15 @@ def maybe_restore_checkpoint(
is_restored) # is_restored


def save_unreplicated_checkpoint(
train_dir,
def unreplicate_and_save_checkpoint(
optimizer_state,
params,
batch_stats,
training_metrics_state,
global_step,
preemption_count,
sum_train_cost,
orbax_checkpointer,
max_to_keep=1):
orbax_checkpoint_manager):
"""Saves pytree, step, preemption_count, and sum_train_cost to train_dir."""
logging.info('Saving checkpoint to ckpt_%d', global_step)
# jax.device_get doesn't work if jax.Array lives on multiple hosts.
Expand All @@ -191,91 +177,73 @@ def save_unreplicated_checkpoint(
params=unreplicated_params,
batch_stats=unreplicated_batch_stats,
training_metrics_grabber=unreplicated_training_metrics_state)
save_checkpoint(train_dir,
global_step,
save_checkpoint(global_step,
state,
max_to_keep=max_to_keep,
orbax_checkpointer=orbax_checkpointer)
orbax_checkpoint_manager=orbax_checkpoint_manager)
logging.info('Done saving checkpoint.')


def save_checkpoint(step,
                    state,
                    orbax_checkpoint_manager):
  """Saves a checkpoint for `step` using the Orbax checkpoint manager.

  The checkpoint is written under the manager's root directory in a
  per-step folder. If the manager rejects the save (e.g. the step folder
  already exists), nothing is written and a warning is logged.

  Args:
    step: (int) Step of the checkpoint.
    state: (pytree) The state to save.
    orbax_checkpoint_manager: orbax.checkpoint.CheckpointManager used to
      write the checkpoint.

  Returns:
    The checkpoint manager's root directory.
  """
  saved = orbax_checkpoint_manager.save(step, args=ocp.args.StandardSave(state))
  if not saved:
    # `save` returns False instead of raising when the write is skipped.
    logging.warning(
        'Checkpoint at step %d was not saved! Perhaps it already exists?', step
    )
  return orbax_checkpoint_manager.directory


def load_checkpoint(
    checkpoint_path=None,
    target=None,
    step=0,
    orbax_checkpoint_manager=None,
):
  """Loads the checkpoint stored for `step`.

  Args:
    checkpoint_path: (str) Root checkpoint directory; used to construct a
      CheckpointManager when `orbax_checkpoint_manager` is not supplied
      (kept for backwards compatibility).
    target: pytree whose structure will be used to structure the restored
      checkpoint data.
    step: (int) Step of the checkpoint to restore. Defaults to 0.
    orbax_checkpoint_manager: orbax.checkpoint.CheckpointManager to restore
      from.

  Returns:
    The restored checkpoint state.

  Raises:
    ValueError: if neither `checkpoint_path` nor `orbax_checkpoint_manager`
      is provided.
  """
  # for backwards compatibility
  if checkpoint_path and not orbax_checkpoint_manager:
    orbax_checkpoint_manager = ocp.CheckpointManager(checkpoint_path)
  if orbax_checkpoint_manager is None:
    # Fail with a clear message instead of an AttributeError on None below.
    raise ValueError(
        'Either checkpoint_path or orbax_checkpoint_manager is required.')
  restored = orbax_checkpoint_manager.restore(
      step,
      args=ocp.args.StandardRestore(target),
  )
  return restored


def load_latest_checkpoint(target=None, orbax_checkpoint_manager=None):
  """Loads the most recent checkpoint tracked by the checkpoint manager.

  Args:
    target: pytree whose structure will be used to structure the restored
      checkpoint data.
    orbax_checkpoint_manager: An orbax.CheckpointManager instance.

  Returns:
    The state restored from the latest checkpoint. If the manager tracks no
    checkpoints, or the checkpoint directory is missing, the original
    `target` is returned unchanged.
  """
  restore_step = orbax_checkpoint_manager.latest_step()
  if restore_step is None:
    # No checkpoint has been written yet; restore(None, ...) is not
    # guaranteed to raise FileNotFoundError, so fall back explicitly.
    return target
  try:
    restored = orbax_checkpoint_manager.restore(
        restore_step, args=ocp.args.StandardRestore(target)
    )
    return restored
  except FileNotFoundError:
    return target
Loading