diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 5e8ba0915..e1284ce79 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -22,6 +22,8 @@ breakages. include an arbitrary `step_prefix` with any character(s) such as underscores. - Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`. - Fix using jax.eval_shape with StandardRestore +- #v1 Fix missing `_CHECKPOINT_METADATA` upon restoration and reading for + checkpointable handlers. ### Changed diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py index 1043ff8ef..be7ad1948 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py @@ -25,7 +25,9 @@ from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import async_path +from orbax.checkpoint._src.path import format_utils from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import @@ -262,7 +264,7 @@ def get_handlers_for_load( def _get_saved_handler_typestrs( self, directory: path_types.Path - ) -> dict[str, str]: + ) -> dict[str, str | None]: """Reads from the checkpoint metadata to get saved handler typestrs.""" step_metadata_file_path = checkpoint_metadata.step_metadata_file_path( directory @@ -291,28 +293,40 @@ def _get_saved_handler_typestrs( directory, ) - saved_handler_typestrs: dict[str, str] = {} + # 
The following generically handles the case where the checkpoint is missing + # the checkpoint metadata file. This is a fallback for older v0 checkpoints. + + # We check each subdirectory treating it as a checkpointable. + # The order of precedence for mapping checkpointable names to handlers is: + # 1. A pytree metadata file, indicating a pytree handler + # 2. A checkpointable with a handler registered in the handler registry. + # 3. If neither, we raise a ValueError since we don't know + # how to handle it. + + saved_handler_typestrs: dict[str, str | None] = {} for checkpointable_path in directory.iterdir(): - serialized_metadata = self._metadata_store.read( - checkpoint_metadata.step_metadata_file_path(checkpointable_path) - ) - if serialized_metadata is None: + if not checkpointable_path.is_dir(): continue - saved_metadata = step_metadata_serialization.deserialize( - serialized_metadata + checkpointable_name = checkpointable_path.name + + pytree_metadata_path = ( + checkpointable_path / format_utils.PYTREE_METADATA_FILE ) - if isinstance(saved_metadata.item_handlers, dict): + + if pytree_metadata_path.exists(): + saved_handler_typestrs[checkpointable_name] = handler_types.typestr( + pytree_handler.PyTreeHandler + ) + elif self._handler_registry.has(checkpointable_name): + # If the handler is registered we can assume it will be found in + # resolve_handler_for_load. + saved_handler_typestrs[checkpointable_name] = None + else: raise ValueError( - f'Path at {directory} contains subdirectories:' - f' {_subdirs(directory)}, which are expected to' - ' match the keys given by the _CHECKPOINT_METADATA file:' - f' {saved_metadata.item_handlers}. If you intended to load a pytree' - ' checkpoint from the given path, then please consider using' - ' `loading.load_pytree(..., checkpointable_name=None)` instead.' - f' {_V0_ERROR_MESSAGE}' + 'Cannot determine handler for checkpointable' + f" '{checkpointable_name}' in directory {directory}. 
The top-level" + " '_CHECKPOINT_METADATA' is missing, and this item does not have a" + " '_METADATA' file to be loaded as a PyTree, nor is it registered" + " in the handler registry." ) - item_handlers = saved_metadata.item_handlers - if item_handlers is not None: - checkpointable_name = checkpointable_path.name - saved_handler_typestrs[checkpointable_name] = item_handlers return saved_handler_typestrs diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py index 1fe8aa60a..fc5bbfc6e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py @@ -25,7 +25,6 @@ from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility -from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import types as path_types @@ -148,24 +147,23 @@ async def _validate_pytree(self, checkpointable_name: str | None): """Validates a checkpoint path written by `ocp.save_pytree`. Args: - checkpointable_name: The name of the checkpointable to load. A - subdirectory with this name must exist in `directory`. If None then - `directory` is expected to contain the checkpoint directly. Defaults to - `pytree`. + checkpointable_name: The name of the checkpointable to load. For Orbax V1, + a subdirectory with this name must exist in `directory`. Raises: FileNotFoundError: If the path does not exist, or if `pytree` is not found in the directory ValueError: If the PyTree checkpoint is malformed. 
""" - pytree_dir = ( - self.path - if checkpointable_name is None - else self.path / checkpointable_name - ) - if checkpointable_name is not None and not await async_path.exists( - pytree_dir - ): + if checkpointable_name is None: + raise ValueError( + "A V1 checkpoint was saved and user is attempting to load it," + " treating it as a V0 Orbax checkpoint saved directly from" + f" {self.path}, this is not a characteristic of V1 saved checkpoints" + ) + + pytree_dir = self.path / checkpointable_name + if not await async_path.exists(pytree_dir): subdirs = [ d.name for d in await _subpaths(self.path) @@ -180,17 +178,13 @@ async def _validate_pytree(self, checkpointable_name: str | None): " using" " `ocp.load_checkpointables()`." ) - if not await _has_pytree_metadata_file(pytree_dir): - # TODO(niketkb): Add following details to the error message: - # 1. we should check other available subdirectories and see if any of them - # look like PyTree checkpoints, and instruct the user to consider - # whether they meant to specify any of those. - # 2. we need to check the directory - if it contains PyTree files, suggest - # loading with checkpointable_name=None + + if not (pytree_dir / PYTREE_METADATA_FILE).exists(): raise FileNotFoundError( f"Checkpoint path {self.path} does not contain a PyTree metadata" " file." ) + if not await _has_tensorstore_data_files(pytree_dir): logging.warning( "TensorStore data files not found in checkpoint path %s. This may be" @@ -200,17 +194,13 @@ async def _validate_pytree(self, checkpointable_name: str | None): ) async def _validate(self): - """Validates a checkpoint directory. + """Validates a checkpoint directory to be a V1 Orbax checkpoint. - Must be: + Must fulfill all of the following: - Existing - - A directory. - - Not a temporary path. - - OR - - Has orbax.checkpoint indicator file. - - OR - - Has _CHECKPOINT_METADATA file. - - A subdirectory has _METADATA file (PyTree checkpoint). 
+ - A directory + - Not a temporary path + - Has orbax.checkpoint indicator file Raises: FileNotFoundError: If the path does not exist. @@ -235,33 +225,10 @@ async def _validate(self): if ORBAX_CHECKPOINT_INDICATOR_FILE in [p.name for p in subpaths]: return - # Path points to a single step checkpoint with valid metadata. - if await async_path.exists( - metadata_serialization.checkpoint_metadata_file_path(self.path) - ): - return - - # The path itself points to a PyTree checkpointable. - if await async_path.exists(self.path / PYTREE_METADATA_FILE): - return - # The path points to a directory containing at least one PyTree - # checkpointable. - for subpath in subpaths: - if await async_path.is_dir(subpath) and await async_path.exists( - subpath / PYTREE_METADATA_FILE - ): - return - raise FileNotFoundError( f"Checkpoint path {self.path} could not be identified as a valid Orbax" - " checkpoint. The path must conform to one of the following" - " conditions:\n - Contain the indicator file" - f" {ORBAX_CHECKPOINT_INDICATOR_FILE}. This should be true of all" - " checkpoints saved with the Orbax V1 API. If not present, the" - " checkpoint may have been saved with the V0 API.\n - Contain the" - " _CHECKPOINT_METADATA file.\n - Point directly to a PyTree" - " checkpointable (contain _METADATA file).\n - Contain a subdirectory" - " which is a PyTree checkpointable (contain _METADATA file).\n" + " V1 checkpoint. It is missing the indicator file" + f" '{ORBAX_CHECKPOINT_INDICATOR_FILE}'." 
) async def validate(self): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py index 66283fb7d..f03808be5 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py @@ -18,9 +18,6 @@ from etils import epath import numpy as np from orbax.checkpoint import test_utils -from orbax.checkpoint._src.checkpointers import checkpointer -from orbax.checkpoint._src.handlers import composite_checkpoint_handler -from orbax.checkpoint._src.handlers import standard_checkpoint_handler from orbax.checkpoint._src.metadata import value as value_metadata from orbax.checkpoint._src.path import async_path from orbax.checkpoint.experimental.v1._src.handlers import composite_handler @@ -102,7 +99,8 @@ async def test_validate_no_indicator_file(self): / composite_handler.ORBAX_CHECKPOINT_INDICATOR_FILE ) indicator_path.rmtree() # Remove the indicator file - await layout.validate() + with self.assertRaises(InvalidLayoutError): + await layout.validate() async def test_validate_no_metadata_file(self): layout = OrbaxLayout(self.orbax_path / '0') @@ -176,93 +174,6 @@ async def test_metadata(self): self.assertGreater(result_metadata.commit_timestamp_nsecs, 0) -class V0ValidationTest( - unittest.IsolatedAsyncioTestCase, parameterized.TestCase -): - - def setUp(self): - super().setUp() - self.directory = epath.Path(self.create_tempdir().full_path) / 'ckpt' - self.pytree, _ = array_test_utils.create_numpy_pytree() - # Save a checkpoint with a checkpointable name, `state`. 
- ckptr = checkpointer.Checkpointer( - composite_checkpoint_handler.CompositeCheckpointHandler() - ) - ckptr.save( - self.directory, - composite_checkpoint_handler.CompositeArgs( - state=standard_checkpoint_handler.StandardSaveArgs(self.pytree) - ), - ) - - async def test_nonexistent_path(self): - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(self.directory / 'foo')._validate() - - async def test_not_a_directory(self): - await async_path.write_text(self.directory / 'foo', 'foo') - with self.assertRaises(NotADirectoryError): - await OrbaxLayout(self.directory / 'foo')._validate() - - @parameterized.product(checkpointable_name=['state', None]) - async def test_no_checkpoint_metadata(self, checkpointable_name: str | None): - directory = ( - self.directory / checkpointable_name - if checkpointable_name is not None - else self.directory - ) - await _unlink_checkpoint_metadata(directory) - - await OrbaxLayout(directory)._validate() - if checkpointable_name is None: - await OrbaxLayout(directory)._validate_pytree('state') - else: - await OrbaxLayout(directory)._validate_pytree(None) - - async def test_deleted_pytree(self): - directory = self.directory - (directory / 'state').rmtree() - - await OrbaxLayout(directory)._validate() - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(directory)._validate_pytree('state') - - async def test_missing_checkpointable_matching_name(self): - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(self.directory)._validate_pytree('foo') - - @parameterized.product(checkpointable_name=['state', None]) - async def test_no_pytree_metadata(self, checkpointable_name: str | None): - directory = ( - self.directory / checkpointable_name - if checkpointable_name is not None - else self.directory - ) - await _unlink_pytree_metadata(directory) - - if checkpointable_name is None: - # Passes because we still have the checkpoint metadata. 
- await OrbaxLayout(directory)._validate() - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(directory)._validate_pytree('state') - else: - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(directory)._validate() - await OrbaxLayout(directory)._validate_pytree(None) - - @parameterized.product(checkpointable_name=['state', None]) - async def test_valid_pytree(self, checkpointable_name: str | None): - directory = ( - self.directory / checkpointable_name - if checkpointable_name is not None - else self.directory - ) - if checkpointable_name is None: - await OrbaxLayout(directory)._validate_pytree('state') - else: - await OrbaxLayout(directory)._validate_pytree(None) - - class V1ValidationTest( unittest.IsolatedAsyncioTestCase, parameterized.TestCase ): @@ -290,7 +201,8 @@ async def test_missing_indicator_file(self, checkpointable_name: str | None): else self.directory ) await _unlink_indicator(directory) - await OrbaxLayout(directory)._validate() + with self.assertRaises(FileNotFoundError): + await OrbaxLayout(directory)._validate() async def test_deleted_pytree(self): directory = self.directory @@ -317,10 +229,12 @@ async def test_no_pytree_metadata(self, checkpointable_name: str | None): with self.assertRaises(FileNotFoundError): await OrbaxLayout(directory)._validate() - with self.assertRaises(FileNotFoundError): - if checkpointable_name is None: + + if checkpointable_name is None: + with self.assertRaises(FileNotFoundError): await OrbaxLayout(directory)._validate_pytree('pytree') - else: + else: + with self.assertRaises(ValueError): await OrbaxLayout(directory)._validate_pytree(None) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py index 22b5dd3a3..624078f3b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py @@ -19,6 +19,7 @@ from 
orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.layout import orbax_layout from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout +from orbax.checkpoint.experimental.v1._src.layout import v0_layout from orbax.checkpoint.experimental.v1._src.path import types as path_types InvalidLayoutError = checkpoint_layout.InvalidLayoutError @@ -47,7 +48,11 @@ async def get_checkpoint_layout( match layout_enum: case CheckpointLayoutEnum.ORBAX: - layout_class = orbax_layout.OrbaxLayout + # This allows us to restore a v0 checkpoint with its own layout. + if _is_v0_checkpoint(path): + layout_class = v0_layout.V0Layout + else: + layout_class = orbax_layout.OrbaxLayout case CheckpointLayoutEnum.SAFETENSORS: layout_class = safetensors_layout.SafetensorsLayout case _: @@ -104,21 +109,31 @@ async def _try_resolve_pytree_checkpointable( if checkpointable_name is not None: return layout, checkpointable_name # Not a v0 checkpoint; use the default name. - if not _is_v0_checkpoint(layout): + if not _is_v0_checkpoint(layout.path): + if checkpointable_name is None: + raise ValueError( + "Cannot extract pytree from top-level V1 checkpoint directory." + ) return layout, checkpointable_name # If it's a V0 checkpoint, we can try to resolve the checkpointable name from # the path. - if not isinstance(layout, orbax_layout.OrbaxLayout): - raise AssertionError(f"Expected an OrbaxLayout, but got a {type(layout)}.") - # Option 1: It may be a direct path to the PyTree checkpointable. + if not isinstance(layout, (orbax_layout.OrbaxLayout, v0_layout.V0Layout)): + raise AssertionError( + f"Expected an OrbaxLayout or V0Layout, but got a {type(layout)}." + ) + + # If the path itself is a V0 PyTree checkpoint (flat structure), we can + # "zoom out" to the parent directory and treat the current directory as the + # checkpointable. 
try: original_path = layout.path - new_layout = orbax_layout.OrbaxLayout(original_path.parent) + new_layout = v0_layout.V0Layout(original_path.parent) await new_layout.validate_pytree(original_path.name) return new_layout, original_path.name except checkpoint_layout.InvalidLayoutError: pass - # Option 2: It may be a V0 checkpoint containing a PyTree checkpointable. It + + # It may be a V0 checkpoint containing a PyTree checkpointable. It # is possible for there to be multiple, but this would be unusual, and it is # fine to just return the first one. dir_names = [p.name for p in layout.path.iterdir() if p.is_dir()] @@ -135,8 +150,7 @@ async def _try_resolve_pytree_checkpointable( ) -def _is_v0_checkpoint(layout: CheckpointLayout) -> bool: - return not isinstance(layout, orbax_layout.OrbaxLayout) or ( - isinstance(layout, orbax_layout.OrbaxLayout) - and not layout.has_indicator_file - ) +def _is_v0_checkpoint(path: path_types.PathLike) -> bool: + ctx = context_lib.get_context() + path = ctx.file_options.path_class(path) + return not (path / orbax_layout.ORBAX_CHECKPOINT_INDICATOR_FILE).exists() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py index 93c133e4f..ebe4423f7 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py @@ -60,7 +60,7 @@ async def test_root_directory(self): ): await try_resolve_pytree_checkpointable(layout, None) - @parameterized.product(checkpointable_name=['state', 'params', None]) + @parameterized.product(checkpointable_name=['state', 'params']) async def test_v1(self, checkpointable_name): layout = orbax_layout.OrbaxLayout(self.v1_directory) self.assertTrue(layout.has_indicator_file) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout.py 
b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout.py new file mode 100644 index 000000000..ec9b7b3d0 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout.py @@ -0,0 +1,204 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines `V0Layout`, a class to handle Orbax V0 checkpoint formats.""" + +import asyncio +from typing import Any, Awaitable + +from absl import logging +from orbax.checkpoint._src.metadata import checkpoint as checkpoint_metadata +from orbax.checkpoint._src.path import async_path +from orbax.checkpoint._src.path import temporary_paths +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.handlers import composite_handler +from orbax.checkpoint.experimental.v1._src.handlers import registration +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout +from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility +from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types +from orbax.checkpoint.experimental.v1._src.path import types as path_types + + +InvalidLayoutError = checkpoint_layout.InvalidLayoutError +Path = path_types.Path +CheckpointLayout = checkpoint_layout.CheckpointLayout + +PYTREE_METADATA_FILE = "_METADATA" +ORBAX_CHECKPOINT_INDICATOR_FILE = "orbax.checkpoint" +_OCDBT_MANIFEST_FILE = "ocdbt.manifest" +_ZARRAY_FILE = 
".zarray" + + +async def _has_pytree_metadata_file(path: Path) -> bool: + return await async_path.exists(path / PYTREE_METADATA_FILE) + + +async def _has_ocdbt_manifest_file(path: Path) -> bool: + return await async_path.exists(path / _OCDBT_MANIFEST_FILE) + + +async def _has_zarray_files(path: Path) -> bool: + paths = list(await async_path.iterdir(path)) + awaitables = [async_path.exists(p / _ZARRAY_FILE) for p in paths] + return any(await asyncio.gather(*awaitables)) + + +async def _has_tensorstore_data_files(path: Path) -> bool: + return await _has_ocdbt_manifest_file(path) or await _has_zarray_files(path) + + +class V0Layout(CheckpointLayout): + """V0Layout. + + This class handles Orbax V0 checkpoint formats. It inherits + abstract methods from CheckpointLayout. It performs a few core functions: + - Resolves handlers for saving and loading. + - Saves and loads checkpointables to/from individual subdirectories by + delegating to the resolved handlers. + """ + + def __init__(self, path: Path): + self._context = context_lib.get_context() + self._handler_registry = registration.local_registry( + self._context.checkpointables_options.registry, + include_global_registry=False, + ) + self._composite_handler = composite_handler.CompositeHandler( + self._handler_registry + ) + self._path = path + + @property + def path(self) -> Path: + return self._path + + async def metadata(self) -> metadata_types.CheckpointMetadata[dict[str, Any]]: + # Uses the v0 checkpointer to get v0 StepMetadata + checkpointer, _ = v0_compatibility.get_v0_checkpointer_and_args( + self._path, None, context=context_lib.get_context() + ) + step_metadata = checkpointer.metadata(self._path) + + item_metadata = {k: v for k, v in step_metadata.item_metadata.items()} + # Exclude `metrics` if present. 
+ item_metadata.pop("metrics", None) + + return metadata_types.CheckpointMetadata[dict[str, Any]]( + metadata=item_metadata, + init_timestamp_nsecs=step_metadata.init_timestamp_nsecs, + commit_timestamp_nsecs=step_metadata.commit_timestamp_nsecs, + custom_metadata=step_metadata.custom_metadata, + ) + + async def _validate(self) -> None: + """Validates a V0 checkpoint directory.""" + if not await async_path.exists(self.path): + raise FileNotFoundError(f"Checkpoint path {self.path} does not exist.") + if not await async_path.is_dir(self.path): + raise NotADirectoryError( + f"Checkpoint path {self.path} is not a directory." + ) + context = context_lib.get_context() + if await temporary_paths.is_path_temporary( + self.path, + temporary_path_cls=context.file_options.temporary_path_class, + ): + raise ValueError(f"Found incomplete checkpoint at {self.path}.") + + # Path points to a single step checkpoint with valid checkpoint metadata. + if await async_path.exists( + checkpoint_metadata.step_metadata_file_path(self.path) + ): + return + + # The path itself points to a PyTree checkpointable. + if await async_path.exists(self.path / PYTREE_METADATA_FILE): + return + + subpaths = await async_path.iterdir(self.path) + # The path points to a directory containing at least one PyTree + # checkpointable. + for subpath in subpaths: + if await async_path.is_dir(subpath) and await async_path.exists( + subpath / PYTREE_METADATA_FILE + ): + return + + raise ValueError( + f"Checkpoint path {self.path} is not a valid V0 checkpoint." 
+ ) + + async def _validate_pytree(self, checkpointable_name: str | None) -> None: + """Validates the given path as a V0 PyTree checkpoint.""" + pytree_dir = ( + self.path + if checkpointable_name is None + else self.path / checkpointable_name + ) + + if checkpointable_name is not None and not await async_path.exists( + pytree_dir + ): + raise FileNotFoundError( + f"Checkpoint path {self.path} must contain a subdirectory named" + f' "{checkpointable_name}".' + ) + + if not await _has_pytree_metadata_file(pytree_dir): + # TODO(angelmau): Add following details to the error message: + # 1. we should check other available subdirectories and see if any of them + # look like PyTree checkpoints, and instruct the user to consider + # whether they meant to specify any of those. + # 2. we need to check the directory - if it contains PyTree files, suggest + # loading with checkpointable_name=None + raise FileNotFoundError( + f"Checkpoint path {self.path} does not contain a PyTree metadata" + " file." + ) + + if not await _has_tensorstore_data_files(pytree_dir): + logging.warning( + "TensorStore data files not found in checkpoint path %s. This may be" + " a sign of a malformed checkpoint, unless your checkpoint consists" + " entirely of strings or other non-standard PyTree leaves.", + self.path, + ) + + async def validate(self) -> None: + """Validates a V0 checkpoint directory.""" + try: + await self._validate() + except BaseException as e: + raise InvalidLayoutError( + f"Failed to interpret path {self._path} as an Orbax V0 checkpoint." + ) from e + + async def validate_pytree(self, checkpointable_name: str | None) -> None: + """Validates the given path as a V0 PyTree checkpoint.""" + try: + await self._validate_pytree(checkpointable_name) + except BaseException as e: + raise InvalidLayoutError( + f"Failed to interpret path {self._path} as an Orbax V0 PyTree" + " checkpoint." 
+ ) from e + + async def load( + self, + abstract_checkpointables: dict[str, Any] | None = None, + ) -> Awaitable[dict[str, Any]]: + load_awaitable = await self._composite_handler.load( + self._path, abstract_checkpointables + ) + return load_awaitable diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout_test.py new file mode 100644 index 000000000..11956d933 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout_test.py @@ -0,0 +1,227 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.checkpointers import checkpointer +from orbax.checkpoint._src.handlers import composite_checkpoint_handler +from orbax.checkpoint._src.handlers import standard_checkpoint_handler +from orbax.checkpoint._src.path import async_path +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout +from orbax.checkpoint.experimental.v1._src.layout import v0_layout +from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types +from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils +import safetensors.numpy + + +InvalidLayoutError = checkpoint_layout.InvalidLayoutError +np_save_file = safetensors.numpy.save_file + + +async def _unlink_checkpoint_metadata(path: epath.Path): + await async_path.unlink(path / '_CHECKPOINT_METADATA', missing_ok=True) + + +async def _unlink_pytree_metadata(path: epath.Path): + await async_path.unlink(path / '_METADATA', missing_ok=True) + for subdir in await async_path.iterdir(path): + if not await async_path.is_dir(subdir): + continue + await async_path.unlink(subdir / '_METADATA', missing_ok=True) + + +class V0LayoutTest(unittest.IsolatedAsyncioTestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_dir = self.create_tempdir() + self.orbax_path = epath.Path(self.test_dir.full_path) / 'test_checkpoint' + self.safetensors_path = ( + epath.Path(self.test_dir.full_path) / 'test_checkpoint.safetensors' + ) + + # Create a mock SafeTensors and Orbax V0 checkpoint + self.object_to_save = { + 'a': np.array(3 * [1, 2, 3], dtype=np.int32), + 'b': np.array([0, 1, 0.2], dtype=np.float32), + } + self.custom_metadata = {'framework': 'JAX', 'version': '1.0'} + np_save_file(self.object_to_save, self.safetensors_path) + + # Save V0 checkpoint + ckptr = 
checkpointer.Checkpointer( + composite_checkpoint_handler.CompositeCheckpointHandler() + ) + ckptr.save( + self.orbax_path / '0', + composite_checkpoint_handler.CompositeArgs( + state=standard_checkpoint_handler.StandardSaveArgs( + self.object_to_save + ) + ), + ) + + async def test_valid_v0_checkpoint(self): + layout = v0_layout.V0Layout(self.orbax_path / '0') + await layout.validate() + + async def test_invalid_v0_checkpoint(self): + layout = v0_layout.V0Layout(self.safetensors_path) + with self.assertRaises(InvalidLayoutError): + await layout.validate() + + async def test_validate_fails_not_directory(self): + layout = v0_layout.V0Layout(self.orbax_path / '1') + with self.assertRaises(InvalidLayoutError): + await layout.validate() + + async def test_validate_no_metadata_file(self): + # V0Layout checks for _CHECKPOINT_METADATA or _METADATA in subdirs. + layout = v0_layout.V0Layout(self.orbax_path / '0') + metadata_path = self.orbax_path / '0' / '_CHECKPOINT_METADATA' + self.assertTrue(metadata_path.exists()) + metadata_path.rmtree() # Remove the metadata file + # Should still pass if subdirs have metadata + await layout.validate() + + async def test_validate_no_metadata_files(self): + layout = v0_layout.V0Layout(self.orbax_path / '0') + metadata_path = self.orbax_path / '0' / '_CHECKPOINT_METADATA' + metadata_path.rmtree() + # Also remove subdir metadata + # The V0 checkpoint structure from CompositeCheckpointHandler is: + # 0/ + # state/ + # _METADATA + # ... 
+ pytree_metadata_path = self.orbax_path / '0' / 'state' / '_METADATA' + pytree_metadata_path.rmtree() + + with self.assertRaises(InvalidLayoutError): + await layout.validate() + + async def test_load_v0_checkpoint(self): + layout = v0_layout.V0Layout(self.orbax_path / '0') + restored_checkpointables_await = await layout.load() + restored_checkpointables = await restored_checkpointables_await + # restored_checkpointables will be {'state': ...} because of CompositeArgs + test_utils.assert_tree_equal( + self, restored_checkpointables['state'], self.object_to_save + ) + + async def test_metadata(self): + """Tests the metadata() method.""" + # V0Layout.metadata() obtains the step metadata through the V0 checkpointer + # returned by v0_compatibility.get_v0_checkpointer_and_args(). + layout = v0_layout.V0Layout(self.orbax_path / '0') + result_metadata = await layout.metadata() + self.assertIsInstance(result_metadata, metadata_types.CheckpointMetadata) + + +class V0InternalValidationTest( + unittest.IsolatedAsyncioTestCase, parameterized.TestCase +): + + def setUp(self): + super().setUp() + self.directory = epath.Path(self.create_tempdir().full_path) / 'ckpt' + self.pytree, _ = array_test_utils.create_numpy_pytree() + # Save a checkpoint with a checkpointable name, `state`. 
+ ckptr = checkpointer.Checkpointer( + composite_checkpoint_handler.CompositeCheckpointHandler() + ) + ckptr.save( + self.directory, + composite_checkpoint_handler.CompositeArgs( + state=standard_checkpoint_handler.StandardSaveArgs(self.pytree) + ), + ) + + async def test_nonexistent_path(self): + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(self.directory / 'foo')._validate() + + async def test_not_a_directory(self): + await async_path.write_text(self.directory / 'foo', 'foo') + with self.assertRaises(NotADirectoryError): + await v0_layout.V0Layout(self.directory / 'foo')._validate() + + @parameterized.product(checkpointable_name=['state', None]) + async def test_no_checkpoint_metadata(self, checkpointable_name: str | None): + directory = ( + self.directory / checkpointable_name + if checkpointable_name is not None + else self.directory + ) + await _unlink_checkpoint_metadata(directory) + + # V0Layout should validate successfully even without _CHECKPOINT_METADATA + # if subdirectories contain PyTree metadata. 
+ await v0_layout.V0Layout(directory)._validate() + if checkpointable_name is None: + await v0_layout.V0Layout(directory)._validate_pytree('state') + else: + await v0_layout.V0Layout(directory)._validate_pytree(None) + + async def test_deleted_pytree(self): + directory = self.directory + (directory / 'state').rmtree() + + await v0_layout.V0Layout(directory)._validate() + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(directory)._validate_pytree('state') + + async def test_missing_checkpointable_matching_name(self): + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(self.directory)._validate_pytree('foo') + + @parameterized.product(checkpointable_name=['state', None]) + async def test_no_pytree_metadata(self, checkpointable_name: str | None): + directory = ( + self.directory / checkpointable_name + if checkpointable_name is not None + else self.directory + ) + await _unlink_pytree_metadata(directory) + + if checkpointable_name is None: + # Passes because we still have the checkpoint metadata. 
+ await v0_layout.V0Layout(directory)._validate() + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(directory)._validate_pytree('state') + else: + with self.assertRaises(ValueError): + await v0_layout.V0Layout(directory)._validate() + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(directory)._validate_pytree(None) + + @parameterized.product(checkpointable_name=['state', None]) + async def test_valid_pytree(self, checkpointable_name: str | None): + directory = ( + self.directory / checkpointable_name + if checkpointable_name is not None + else self.directory + ) + if checkpointable_name is None: + await v0_layout.V0Layout(directory)._validate_pytree('state') + else: + await v0_layout.V0Layout(directory)._validate_pytree(None) + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py index 5b9f5ac08..1b83c227b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py @@ -105,6 +105,7 @@ def load_pytree( logging.info('Loading checkpoint from %s.', path) ctx = context_lib.get_context() path = ctx.file_options.path_class(path) + layout, checkpointable_name = asyncio.run( layout_registry.get_checkpoint_layout_pytree( path, ctx.checkpoint_layout, checkpointable_name diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index 8075b700a..6fb054289 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -35,6 +35,7 @@ import numpy as np import optax from orbax.checkpoint import test_utils +from orbax.checkpoint._src.checkpointers import 
standard_checkpointer from orbax.checkpoint._src.multihost import multihost as multihost_v0 from orbax.checkpoint._src.path import atomicity from orbax.checkpoint._src.serialization import serialization @@ -125,6 +126,16 @@ def test_load_default(self, use_async): loaded = self.load_and_wait(self.directory, use_async=use_async) test_utils.assert_tree_equal(self, self.pytree, loaded) + def test_flat_v0_checkpoint(self): + flat_dir = self.directory / 'flat_ckpt' + pytree = {'a': np.array([1, 2, 3])} + with standard_checkpointer.StandardCheckpointer() as ckptr: + ckptr.save(flat_dir, pytree) + + # Verify load_pytree + loaded = ocp.load_pytree(flat_dir, checkpointable_name=None) + test_utils.assert_tree_equal(self, loaded, pytree) + def test_save_pytree_async(self): start_serialize = threading.Event() original_serialize = serialization.async_serialize_from_host @@ -811,6 +822,69 @@ def test_partial_restore_with_placeholder(self, use_async): ): self.load_and_wait(directory, reference_item, use_async=use_async) + def test_missing_checkpoint_metadata_checkpointables(self): + """Checkpointables API save and restore test for missing _CHECKPOINT_METADATA.""" + step_dir = self.directory / 'step_0' + checkpointables = {'a': self.pytree, 'b': self.pytree} + ocp.save_checkpointables(step_dir, checkpointables) + + # Delete the _CHECKPOINT_METADATA file + metadata_file = step_dir / '_CHECKPOINT_METADATA' + self.assertTrue(metadata_file.exists()) + metadata_file.unlink(missing_ok=True) + + loaded = ocp.load_checkpointables(step_dir) + + test_utils.assert_tree_equal(self, self.pytree, loaded['a']) + test_utils.assert_tree_equal(self, self.pytree, loaded['b']) + + def test_missing_checkpoint_metadata_pytree(self): + """Pytree API save and restore test for missing _CHECKPOINT_METADATA.""" + step_dir = self.directory / 'step_2' + ocp.save_pytree(step_dir, self.pytree) + + # Delete the _CHECKPOINT_METADATA file + metadata_file = step_dir / '_CHECKPOINT_METADATA' + 
self.assertTrue(metadata_file.exists()) + metadata_file.unlink(missing_ok=True) + + loaded = ocp.load_pytree(step_dir) + test_utils.assert_tree_equal(self, self.pytree, loaded) + + @parameterized.parameters(True, False) + def test_missing_metadata_with_registered_handler_succeeds( + self, registered + ): + """Tests fallback for non-PyTree items with/without a registered handler.""" + step_dir = self.directory / 'step_3' + checkpointables = {'a': self.pytree, 'b': {'some': 'data'}} + # Ensure 'b' is saved with JsonHandler + + options = ocp.options.CheckpointablesOptions.create_with_handlers( + b=ocp.handlers.JsonHandler + ) + with ocp.Context(checkpointables_options=options): + ocp.save_checkpointables(step_dir, checkpointables) + + # Delete the _CHECKPOINT_METADATA file + metadata_file = step_dir / '_CHECKPOINT_METADATA' + self.assertTrue(metadata_file.exists()) + metadata_file.unlink(missing_ok=True) + + if registered: + # Register the handler for 'b' during load as well + with ocp.Context(checkpointables_options=options): + loaded = ocp.load_checkpointables(step_dir) + + test_utils.assert_tree_equal(self, self.pytree, loaded['a']) + self.assertEqual(checkpointables['b'], loaded['b']) + else: + # Try to load without registering 'b' + with self.assertRaisesRegex( + ValueError, "Cannot determine handler for checkpointable 'b'" + ): + ocp.load_checkpointables(step_dir) + def test_checkpointable_with_stateful_checkpointable(self): point = handler_utils.Point(1, 2) checkpointables = {'point': point}