Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial checkpoints #861

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions config/harness/eval_llama3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,32 @@ eval_harness:
task_spec:
- task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios
num_fewshot: 10
- task: agieval_lsat_ar # 3-shot tests in legal domain
num_fewshot: 3
- task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science
num_fewshot: 10
- task: arc_challenge # a (harder) version of arc_easy
num_fewshot: 10
- task: boolq # answer yes/no questions based on a passage
num_fewshot: 10
- task: copa # use causal reasoning to predict the correct outcome of a given scenario
num_fewshot: 0
- task: hellaswag # 4-way multiple choice commonsense reasoning dataset
num_fewshot: 0
task_alias: hellaswag_0shot
- task: hellaswag # 4-way multiple choice commonsense reasoning dataset
num_fewshot: 10
task_alias: hellaswag_10shot
- task: lambada # predict the endings of text passages
num_fewshot: 0
- task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning
num_fewshot: 0
- task: piqa # answer questions based on a passage
num_fewshot: 10
- task: wsc273 # Winograd Schema Challenge
num_fewshot: 0
- task: winogrande # Winograd challenge, extended to more domains
num_fewshot: 0
# - task: agieval_lsat_ar # 3-shot tests in legal domain
# num_fewshot: 3
# - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science
# num_fewshot: 10
# - task: arc_challenge # a (harder) version of arc_easy
# num_fewshot: 10
# - task: boolq # answer yes/no questions based on a passage
# num_fewshot: 10
# - task: copa # use causal reasoning to predict the correct outcome of a given scenario
# num_fewshot: 0
# - task: hellaswag # 4-way multiple choice commonsense reasoning dataset
# num_fewshot: 0
# task_alias: hellaswag_0shot
# - task: hellaswag # 4-way multiple choice commonsense reasoning dataset
# num_fewshot: 10
# task_alias: hellaswag_10shot
# - task: lambada # predict the endings of text passages
# num_fewshot: 0
# - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning
# num_fewshot: 0
# - task: piqa # answer questions based on a passage
# num_fewshot: 10
# - task: wsc273 # Winograd Schema Challenge
# num_fewshot: 0
# - task: winogrande # Winograd challenge, extended to more domains
# num_fewshot: 0
# requires generation
## - task: squadv2 # reading comprehension benchmark
# num_fewshot: 10
Expand Down
11 changes: 10 additions & 1 deletion src/levanter/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def load_checkpoint(
discover_latest=True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[jax.sharding.Mesh] = None,
allow_partial: bool = False,
) -> M:
"""
Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint
Expand All @@ -367,6 +368,7 @@ def load_checkpoint(
discover_latest: whether to discover the latest checkpoint in the given path
axis_mapping: the axis mapping to use for loading the checkpoint
mesh: the mesh to use for loading the checkpoint
allow_partial: if True, allow partial loading of the checkpoint. If False, all parameters must be present in the checkpoint.
Returns:
the loaded checkpoint, with the same structure as the exemplar tree

Expand Down Expand Up @@ -397,7 +399,9 @@ def load_checkpoint(

ser, non_ser = equinox.partition(tree, is_jax_array_like)
try:
tree = tree_deserialize_leaves_tensorstore(checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh)
tree = tree_deserialize_leaves_tensorstore(
checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh, allow_missing=allow_partial
)
tree = equinox.combine(tree, non_ser)
return tree
except: # noqa
Expand Down Expand Up @@ -445,6 +449,7 @@ def load_checkpoint_or_initialize(
donate_args: FilterSpec = True,
donate_kwargs: Optional[FilterSpec] = None,
do_load: Optional[bool] = None,
allow_partial: bool = False,
) -> Callable[Sig, M]:
"""
Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint
Expand Down Expand Up @@ -476,6 +481,7 @@ def load_checkpoint_or_initialize(
donate_args: a FilterSpec that specifies which arguments to donate to init_fn if we need to initialize
donate_kwargs: a FilterSpec that specifies which kwargs to donate to init_fn if we need to initialize
do_load: if True, always load the checkpoint. If False, always initialize. If None, load if the checkpoint exists, otherwise initialize
allow_partial: if True, allow partial loading of the checkpoint. If False, all parameters must be present in the checkpoint.

Returns:
A function that takes the same arguments as init_fn, but loads the checkpoint if it exists and returns the
Expand All @@ -493,6 +499,8 @@ def load_checkpoint_or_initialize(
)
def init_and_merge(state, *args, **kwargs):
init_state = init_fn(*args, **kwargs)
# remove all ShapeDTypeStructs from the state
state = equinox.filter(state, lambda x: not isinstance(x, jax.ShapeDtypeStruct))
return equinox.combine(state, init_state)

def load_or_init(*args, **kwargs):
Expand All @@ -516,6 +524,7 @@ def load_or_init(*args, **kwargs):
discover_latest=discover_latest,
axis_mapping=axis_mapping,
mesh=mesh,
allow_partial=allow_partial,
)
except FileNotFoundError:
if do_load is True:
Expand Down
51 changes: 42 additions & 9 deletions src/levanter/tensorstore_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from haliax.partitioning import ResourceMapping
from haliax.util import is_named_array

from levanter.utils import jax_utils
from levanter.utils import fsspec_utils, jax_utils


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -119,6 +119,8 @@ def tree_deserialize_leaves_tensorstore(
axis_mapping: Optional[ResourceMapping] = None,
mesh: Optional[Mesh] = None,
manager: Optional[array_ser.GlobalAsyncCheckpointManager] = None,
*,
allow_missing: bool = False,
):
"""
Deserializes a PyTree of Arrays and NamedArrays from a Tensorstore checkpoint, returning a pytree with the same shape
Expand All @@ -132,6 +134,7 @@ def tree_deserialize_leaves_tensorstore(
axis_mapping: optional, the axis mapping for the NamedArrays (if they are not yet arrays)
mesh: optional, the mesh for the NamedArrays (if they are not yet arrays)
manager: optional, the checkpoint manager to use. If not provided, a new one will be created
allow_missing: if True, missing leaves will be allowed and kept as-is

Returns:
A pytree with the same shape as the exemplar pytree, but with the arrays deserialized from the checkpoint
Expand All @@ -151,26 +154,56 @@ def tree_deserialize_leaves_tensorstore(
shardings_leaves, shardings_structure = jtu.tree_flatten(shardings, is_leaf=_is_named_or_none)

assert len(shardings_leaves) == len(paths)

# ok, so, jax really doesn't want any Nones in the leaves here, so we need to temporarily partition the pytree
real_indices = [i for i, x in enumerate(shardings_leaves) if x is not None]
real_leaves = [x for x in shardings_leaves if x is not None]
real_paths = [paths[i] for i in real_indices]
paths_to_load = []
indices_to_load = []
shardings_to_load = []

missing_paths = []
missing_indices = []

for i in real_indices:
path = paths[i]

if not fsspec_utils.exists(path):
missing_paths.append(path)
missing_indices.append(i)
continue

assert len(real_leaves) == len(real_paths), f"{len(real_leaves)} != {len(real_paths)}"
paths_to_load.append(path)
indices_to_load.append(i)
shardings_to_load.append(shardings_leaves[i])

# ok now check for missing paths
if missing_paths:
if not allow_missing:
raise FileNotFoundError(f"Missing paths: {missing_paths}")
else:
to_log = f"Several keys were missing from the checkpoint directory {checkpoint_dir}:"
leaf_paths = jtu.tree_leaves(leaf_key_paths, is_leaf=_is_named_or_none)
for i in missing_indices:
to_log += f"\n - {leaf_paths[i]}"
logger.warning(to_log)

deser_leaves = manager.deserialize_with_paths(shardings=shardings_to_load, paths=paths_to_load)

deser_leaves = manager.deserialize_with_paths(shardings=real_leaves, paths=real_paths)
# now we need to recreate the original structure

out_leaves = [None] * len(shardings_leaves)
for i, x in zip(real_indices, deser_leaves):
out_leaves = jax.tree_leaves(pytree, is_leaf=_is_named_or_none)
assert len(out_leaves) == len(shardings_leaves)
# out_leaves = [None] * len(shardings_leaves)
for i, x in zip(indices_to_load, deser_leaves):
out_leaves[i] = x

deser_arrays = jtu.tree_unflatten(shardings_structure, out_leaves)

# deser_arrays only has arrays, but we need named arrays for at least some.
# deser_arrays only has arrays for the deserialized arrays, but we need named arrays for at least some.
# The original pytree has the structure we want, so we'll use that to rebuild the named arrays
def _rebuild_named_array(like, array):
if is_named_array(array):
return array

if is_named_array(like):
return hax.NamedArray(array, like.axes)
else:
Expand Down
5 changes: 5 additions & 0 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def initial_state(
mesh=self.device_mesh,
subpath="model",
do_load=True,
allow_partial=self.config.allow_partial_checkpoint,
)()
model_init = jax.tree_util.Partial(lambda m: m, loaded_model)

Expand All @@ -369,6 +370,7 @@ def init_state_and_model(model_init, training_key):
mesh=self.device_mesh,
is_checkpointed=saveable_train_state,
do_load=load_checkpoint,
allow_partial=self.config.allow_partial_checkpoint,
)(model_init, training_key)

return state
Expand Down Expand Up @@ -629,6 +631,9 @@ class TrainerConfig:
load_checkpoint_path: Optional[str] = None
"""can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path."""
initialize_from: Optional[str] = None # Levanter trainer checkpoint to initialize from
allow_partial_checkpoint: bool = False
"""If True, we allow loading a checkpoint that doesn't have all the parameters in the model.
Missing parameters are initialized from the model_init function."""

jax_config: Mapping[str, JsonAtom] = field(
default_factory=lambda: copy.deepcopy(DEFAULT_JAX_CONFIG)
Expand Down
37 changes: 37 additions & 0 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,40 @@ def init_fn(key):
jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))),
jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed))),
)


def test_load_from_checkpoint_allows_partial_checkpoints():
    """With allow_partial=True, leaves missing from the checkpoint are
    filled in from init_fn instead of raising.

    We save a model whose ``b`` field is None, then load with an init
    that produces a real ``b``: ``a`` must come from the checkpoint and
    ``b`` from the freshly-initialized model.
    """
    In = Axis("in", 2)
    Out = Axis("out", 1)

    class MyModule(eqx.Module):
        a: hax.NamedArray
        b: hax.NamedArray | None

    def build(key, use_b):
        key_a, key_b = jax.random.split(key)
        maybe_b = hax.random.normal(key_b, (In, Out)) if use_b else None
        return MyModule(a=hax.random.normal(key_a, (In, Out)), b=maybe_b)

    key_zero = jax.random.PRNGKey(0)
    key_one = jax.random.PRNGKey(1)

    saved_model = build(key_zero, False)  # has no 'b' -> partial checkpoint
    reference_model = build(key_one, True)  # what init with key_one yields

    is_checkpointed = True

    with jax.sharding.Mesh(jax.devices(), ("devices",)), tempfile.TemporaryDirectory() as tmpdir:

        save_checkpoint(eqx.filter(saved_model, is_checkpointed), step=0, checkpoint_path=tmpdir)

        restored = load_checkpoint_or_initialize(
            build,
            tmpdir,
            is_checkpointed=is_checkpointed,
            allow_partial=True,
        )(key_one, True)

        # no unfilled placeholder leaves may survive the merge
        assert not any(jax.tree_util.tree_leaves(eqx.filter(restored, lambda x: isinstance(x, ShapeDtypeStruct))))
        # 'a' was in the checkpoint; 'b' was missing and comes from init
        assert hax.all(hax.equal(restored.a, saved_model.a))
        assert restored.b is not None
        assert hax.all(hax.equal(restored.b, reference_model.b))
21 changes: 20 additions & 1 deletion tests/test_tensorstore_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,24 @@ class MyModule(eqx.Module):
m3 = MyModule(a=hax.zeros(A), b=hax.ones(A))
with TemporaryDirectory() as tmpdir:
tree_serialize_leaves_tensorstore(tmpdir, m2)
with pytest.raises(ValueError):
with pytest.raises(FileNotFoundError):
tree_deserialize_leaves_tensorstore(tmpdir, m3)


def test_tensorstore_ok_with_missing():
    """With allow_missing=True, deserialization keeps the exemplar's value
    for any leaf whose array is absent from the checkpoint directory,
    while still loading the leaves that were serialized.
    """
    mesh = jax.sharding.Mesh(jax.devices(), ("device",))
    with mesh:
        A = hax.Axis("A", 10)

        class MyModule(eqx.Module):
            a: Any
            b: Any

        saved = MyModule(a=None, b=hax.zeros(A))  # 'a' never written to disk
        exemplar = MyModule(a=hax.full(A, 4), b=hax.ones(A))

        with TemporaryDirectory() as tmpdir:
            tree_serialize_leaves_tensorstore(tmpdir, saved)
            restored = tree_deserialize_leaves_tensorstore(tmpdir, exemplar, allow_missing=True)
            # 'a' missing from the checkpoint -> exemplar value survives
            assert hax.all(restored.a == hax.full(A, 4))
            # 'b' round-trips from the checkpoint
            assert hax.all(restored.b == hax.zeros(A))
Loading