Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Roll back earlier change altering metadata format, which was observed to cause
breakages.

### Changed

- Split v0 checkpoint format/layout logic out from `OrbaxLayout` and into
seperate `V0Layout`.

## [0.11.29] - 2025-11-25

### Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility
from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types

Expand Down Expand Up @@ -148,24 +147,23 @@ async def _validate_pytree(self, checkpointable_name: str | None):
"""Validates a checkpoint path written by `ocp.save_pytree`.

Args:
checkpointable_name: The name of the checkpointable to load. A
subdirectory with this name must exist in `directory`. If None then
`directory` is expected to contain the checkpoint directly. Defaults to
`pytree`.
checkpointable_name: The name of the checkpointable to load. For Orbax V1,
a subdirectory with this name must exist in `directory`.

Raises:
FileNotFoundError: If the path does not exist, or if `pytree` is not found
in the directory
ValueError: If the PyTree checkpoint is malformed.
"""
pytree_dir = (
self.path
if checkpointable_name is None
else self.path / checkpointable_name
)
if checkpointable_name is not None and not await async_path.exists(
pytree_dir
):
if checkpointable_name is None:
raise ValueError(
"A V1 checkpoint was saved and user is attempting to load it,"
" treating it as a V0 Orbax checkpoint saved directly from"
f" {self.path}, this is not a characteristic of V1 saved checkpoints"
)

pytree_dir = self.path / checkpointable_name
if not await async_path.exists(pytree_dir):
subdirs = [
d.name
for d in await _subpaths(self.path)
Expand All @@ -180,17 +178,13 @@ async def _validate_pytree(self, checkpointable_name: str | None):
" using"
" `ocp.load_checkpointables()`."
)
if not await _has_pytree_metadata_file(pytree_dir):
# TODO(niketkb): Add following details to the error message:
# 1. we should check other available subdirectories and see if any of them
# look like PyTree checkpoints, and instruct the user to consider
# whether they meant to specify any of those.
# 2. we need to check the directory - if it contains PyTree files, suggest
# loading with checkpointable_name=None

if not (pytree_dir / PYTREE_METADATA_FILE).exists():
raise FileNotFoundError(
f"Checkpoint path {self.path} does not contain a PyTree metadata"
" file."
)

if not await _has_tensorstore_data_files(pytree_dir):
logging.warning(
"TensorStore data files not found in checkpoint path %s. This may be"
Expand All @@ -200,17 +194,13 @@ async def _validate_pytree(self, checkpointable_name: str | None):
)

async def _validate(self):
"""Validates a checkpoint directory.
"""Validates a checkpoint directory to be a V1 Orbax checkpoint.

Must be:
Must fulfill all of the following:
- Existing
- A directory.
- Not a temporary path.
- OR
- Has orbax.checkpoint indicator file.
- OR
- Has _CHECKPOINT_METADATA file.
- A subdirectory has _METADATA file (PyTree checkpoint).
- A directory
- Not a temporary path
- Has orbax.checkpoint indicator file

Raises:
FileNotFoundError: If the path does not exist.
Expand All @@ -235,33 +225,10 @@ async def _validate(self):
if ORBAX_CHECKPOINT_INDICATOR_FILE in [p.name for p in subpaths]:
return

# Path points to a single step checkpoint with valid metadata.
if await async_path.exists(
metadata_serialization.checkpoint_metadata_file_path(self.path)
):
return

# The path itself points to a PyTree checkpointable.
if await async_path.exists(self.path / PYTREE_METADATA_FILE):
return
# The path points to a directory containing at least one PyTree
# checkpointable.
for subpath in subpaths:
if await async_path.is_dir(subpath) and await async_path.exists(
subpath / PYTREE_METADATA_FILE
):
return

raise FileNotFoundError(
f"Checkpoint path {self.path} could not be identified as a valid Orbax"
" checkpoint. The path must conform to one of the following"
" conditions:\n - Contain the indicator file"
f" {ORBAX_CHECKPOINT_INDICATOR_FILE}. This should be true of all"
" checkpoints saved with the Orbax V1 API. If not present, the"
" checkpoint may have been saved with the V0 API.\n - Contain the"
" _CHECKPOINT_METADATA file.\n - Point directly to a PyTree"
" checkpointable (contain _METADATA file).\n - Contain a subdirectory"
" which is a PyTree checkpointable (contain _METADATA file).\n"
" V1 checkpoint. It is missing the indicator file"
f" '{ORBAX_CHECKPOINT_INDICATOR_FILE}'."
)

async def validate(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
from etils import epath
import numpy as np
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.checkpointers import checkpointer
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.handlers import standard_checkpoint_handler
from orbax.checkpoint._src.metadata import value as value_metadata
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint.experimental.v1._src.handlers import composite_handler
Expand Down Expand Up @@ -102,7 +99,8 @@ async def test_validate_no_indicator_file(self):
/ composite_handler.ORBAX_CHECKPOINT_INDICATOR_FILE
)
indicator_path.rmtree() # Remove the indicator file
await layout.validate()
with self.assertRaises(InvalidLayoutError):
await layout.validate()

async def test_validate_no_metadata_file(self):
layout = OrbaxLayout(self.orbax_path / '0')
Expand Down Expand Up @@ -176,93 +174,6 @@ async def test_metadata(self):
self.assertGreater(result_metadata.commit_timestamp_nsecs, 0)


class V0ValidationTest(
unittest.IsolatedAsyncioTestCase, parameterized.TestCase
):

def setUp(self):
super().setUp()
self.directory = epath.Path(self.create_tempdir().full_path) / 'ckpt'
self.pytree, _ = array_test_utils.create_numpy_pytree()
# Save a checkpoint with a checkpointable name, `state`.
ckptr = checkpointer.Checkpointer(
composite_checkpoint_handler.CompositeCheckpointHandler()
)
ckptr.save(
self.directory,
composite_checkpoint_handler.CompositeArgs(
state=standard_checkpoint_handler.StandardSaveArgs(self.pytree)
),
)

async def test_nonexistent_path(self):
with self.assertRaises(FileNotFoundError):
await OrbaxLayout(self.directory / 'foo')._validate()

async def test_not_a_directory(self):
await async_path.write_text(self.directory / 'foo', 'foo')
with self.assertRaises(NotADirectoryError):
await OrbaxLayout(self.directory / 'foo')._validate()

@parameterized.product(checkpointable_name=['state', None])
async def test_no_checkpoint_metadata(self, checkpointable_name: str | None):
directory = (
self.directory / checkpointable_name
if checkpointable_name is not None
else self.directory
)
await _unlink_checkpoint_metadata(directory)

await OrbaxLayout(directory)._validate()
if checkpointable_name is None:
await OrbaxLayout(directory)._validate_pytree('state')
else:
await OrbaxLayout(directory)._validate_pytree(None)

async def test_deleted_pytree(self):
directory = self.directory
(directory / 'state').rmtree()

await OrbaxLayout(directory)._validate()
with self.assertRaises(FileNotFoundError):
await OrbaxLayout(directory)._validate_pytree('state')

async def test_missing_checkpointable_matching_name(self):
with self.assertRaises(FileNotFoundError):
await OrbaxLayout(self.directory)._validate_pytree('foo')

@parameterized.product(checkpointable_name=['state', None])
async def test_no_pytree_metadata(self, checkpointable_name: str | None):
directory = (
self.directory / checkpointable_name
if checkpointable_name is not None
else self.directory
)
await _unlink_pytree_metadata(directory)

if checkpointable_name is None:
# Passes because we still have the checkpoint metadata.
await OrbaxLayout(directory)._validate()
with self.assertRaises(FileNotFoundError):
await OrbaxLayout(directory)._validate_pytree('state')
else:
with self.assertRaises(FileNotFoundError):
await OrbaxLayout(directory)._validate()
await OrbaxLayout(directory)._validate_pytree(None)

@parameterized.product(checkpointable_name=['state', None])
async def test_valid_pytree(self, checkpointable_name: str | None):
directory = (
self.directory / checkpointable_name
if checkpointable_name is not None
else self.directory
)
if checkpointable_name is None:
await OrbaxLayout(directory)._validate_pytree('state')
else:
await OrbaxLayout(directory)._validate_pytree(None)


class V1ValidationTest(
unittest.IsolatedAsyncioTestCase, parameterized.TestCase
):
Expand Down Expand Up @@ -290,7 +201,8 @@ async def test_missing_indicator_file(self, checkpointable_name: str | None):
else self.directory
)
await _unlink_indicator(directory)
await OrbaxLayout(directory)._validate()
with self.assertRaises(FileNotFoundError):
await OrbaxLayout(directory)._validate()

async def test_deleted_pytree(self):
directory = self.directory
Expand All @@ -317,10 +229,12 @@ async def test_no_pytree_metadata(self, checkpointable_name: str | None):

with self.assertRaises(FileNotFoundError):
await OrbaxLayout(directory)._validate()
with self.assertRaises(FileNotFoundError):
if checkpointable_name is None:

if checkpointable_name is None:
with self.assertRaises(FileNotFoundError):
await OrbaxLayout(directory)._validate_pytree('pytree')
else:
else:
with self.assertRaises(ValueError):
await OrbaxLayout(directory)._validate_pytree(None)


Expand Down
38 changes: 26 additions & 12 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.layout import orbax_layout
from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout
from orbax.checkpoint.experimental.v1._src.layout import v0_layout
from orbax.checkpoint.experimental.v1._src.path import types as path_types

InvalidLayoutError = checkpoint_layout.InvalidLayoutError
Expand Down Expand Up @@ -47,7 +48,11 @@ async def get_checkpoint_layout(

match layout_enum:
case CheckpointLayoutEnum.ORBAX:
layout_class = orbax_layout.OrbaxLayout
# This allows us to restore a v0 checkpoint with its own layout.
if _is_v0_checkpoint(path):
layout_class = v0_layout.V0Layout
else:
layout_class = orbax_layout.OrbaxLayout
case CheckpointLayoutEnum.SAFETENSORS:
layout_class = safetensors_layout.SafetensorsLayout
case _:
Expand Down Expand Up @@ -104,21 +109,31 @@ async def _try_resolve_pytree_checkpointable(
if checkpointable_name is not None:
return layout, checkpointable_name
# Not a v0 checkpoint; use the default name.
if not _is_v0_checkpoint(layout):
if not _is_v0_checkpoint(layout.path):
if checkpointable_name is None:
raise ValueError(
"Cannot extract pytree from top-level V1 checkpoint directory."
)
return layout, checkpointable_name
# If it's a V0 checkpoint, we can try to resolve the checkpointable name from
# the path.
if not isinstance(layout, orbax_layout.OrbaxLayout):
raise AssertionError(f"Expected an OrbaxLayout, but got a {type(layout)}.")
# Option 1: It may be a direct path to the PyTree checkpointable.
if not isinstance(layout, (orbax_layout.OrbaxLayout, v0_layout.V0Layout)):
raise AssertionError(
f"Expected an OrbaxLayout or V0Layout, but got a {type(layout)}."
)

# If the path itself is a V0 PyTree checkpoint (flat structure), we can
# "zoom out" to the parent directory and treat the current directory as the
# checkpointable.
try:
original_path = layout.path
new_layout = orbax_layout.OrbaxLayout(original_path.parent)
new_layout = v0_layout.V0Layout(original_path.parent)
await new_layout.validate_pytree(original_path.name)
return new_layout, original_path.name
except checkpoint_layout.InvalidLayoutError:
pass
# Option 2: It may be a V0 checkpoint containing a PyTree checkpointable. It

# It may be a V0 checkpoint containing a PyTree checkpointable. It
# is possible for there to be multiple, but this would be unusual, and it is
# fine to just return the first one.
dir_names = [p.name for p in layout.path.iterdir() if p.is_dir()]
Expand All @@ -135,8 +150,7 @@ async def _try_resolve_pytree_checkpointable(
)


def _is_v0_checkpoint(layout: CheckpointLayout) -> bool:
return not isinstance(layout, orbax_layout.OrbaxLayout) or (
isinstance(layout, orbax_layout.OrbaxLayout)
and not layout.has_indicator_file
)
def _is_v0_checkpoint(path: path_types.PathLike) -> bool:
ctx = context_lib.get_context()
path = ctx.file_options.path_class(path)
return not (path / orbax_layout.ORBAX_CHECKPOINT_INDICATOR_FILE).exists()
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_root_directory(self):
):
await try_resolve_pytree_checkpointable(layout, None)

@parameterized.product(checkpointable_name=['state', 'params', None])
@parameterized.product(checkpointable_name=['state', 'params'])
async def test_v1(self, checkpointable_name):
layout = orbax_layout.OrbaxLayout(self.v1_directory)
self.assertTrue(layout.has_indicator_file)
Expand Down
Loading
Loading