diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 5e8ba0915..0a13c26ab 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -14,6 +14,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Roll back earlier change altering metadata format, which was observed to cause breakages. +### Changed + +- Split v0 checkpoint format/layout logic out from `OrbaxLayout` and into +separate `V0Layout`. + ## [0.11.29] - 2025-11-25 ### Fixed diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py index 1fe8aa60a..fc5bbfc6e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py @@ -25,7 +25,6 @@ from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility -from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import types as path_types @@ -148,24 +147,23 @@ async def _validate_pytree(self, checkpointable_name: str | None): """Validates a checkpoint path written by `ocp.save_pytree`. Args: - checkpointable_name: The name of the checkpointable to load. A - subdirectory with this name must exist in `directory`. If None then - `directory` is expected to contain the checkpoint directly. Defaults to - `pytree`. + checkpointable_name: The name of the checkpointable to load. For Orbax V1, + a subdirectory with this name must exist in `directory`. 
Raises: FileNotFoundError: If the path does not exist, or if `pytree` is not found in the directory ValueError: If the PyTree checkpoint is malformed. """ - pytree_dir = ( - self.path - if checkpointable_name is None - else self.path / checkpointable_name - ) - if checkpointable_name is not None and not await async_path.exists( - pytree_dir - ): + if checkpointable_name is None: + raise ValueError( + "A V1 checkpoint was saved and user is attempting to load it," + " treating it as a V0 Orbax checkpoint saved directly from" + f" {self.path}, this is not a characteristic of V1 saved checkpoints" + ) + + pytree_dir = self.path / checkpointable_name + if not await async_path.exists(pytree_dir): subdirs = [ d.name for d in await _subpaths(self.path) @@ -180,17 +178,13 @@ async def _validate_pytree(self, checkpointable_name: str | None): " using" " `ocp.load_checkpointables()`." ) - if not await _has_pytree_metadata_file(pytree_dir): - # TODO(niketkb): Add following details to the error message: - # 1. we should check other available subdirectories and see if any of them - # look like PyTree checkpoints, and instruct the user to consider - # whether they meant to specify any of those. - # 2. we need to check the directory - if it contains PyTree files, suggest - # loading with checkpointable_name=None + + if not (pytree_dir / PYTREE_METADATA_FILE).exists(): raise FileNotFoundError( f"Checkpoint path {self.path} does not contain a PyTree metadata" " file." ) + if not await _has_tensorstore_data_files(pytree_dir): logging.warning( "TensorStore data files not found in checkpoint path %s. This may be" @@ -200,17 +194,13 @@ async def _validate_pytree(self, checkpointable_name: str | None): ) async def _validate(self): - """Validates a checkpoint directory. + """Validates a checkpoint directory to be a V1 Orbax checkpoint. - Must be: + Must fulfill all of the following: - Existing - - A directory. - - Not a temporary path. - - OR - - Has orbax.checkpoint indicator file. 
- - OR - - Has _CHECKPOINT_METADATA file. - - A subdirectory has _METADATA file (PyTree checkpoint). + - A directory + - Not a temporary path + - Has orbax.checkpoint indicator file Raises: FileNotFoundError: If the path does not exist. @@ -235,33 +225,10 @@ async def _validate(self): if ORBAX_CHECKPOINT_INDICATOR_FILE in [p.name for p in subpaths]: return - # Path points to a single step checkpoint with valid metadata. - if await async_path.exists( - metadata_serialization.checkpoint_metadata_file_path(self.path) - ): - return - - # The path itself points to a PyTree checkpointable. - if await async_path.exists(self.path / PYTREE_METADATA_FILE): - return - # The path points to a directory containing at least one PyTree - # checkpointable. - for subpath in subpaths: - if await async_path.is_dir(subpath) and await async_path.exists( - subpath / PYTREE_METADATA_FILE - ): - return - raise FileNotFoundError( f"Checkpoint path {self.path} could not be identified as a valid Orbax" - " checkpoint. The path must conform to one of the following" - " conditions:\n - Contain the indicator file" - f" {ORBAX_CHECKPOINT_INDICATOR_FILE}. This should be true of all" - " checkpoints saved with the Orbax V1 API. If not present, the" - " checkpoint may have been saved with the V0 API.\n - Contain the" - " _CHECKPOINT_METADATA file.\n - Point directly to a PyTree" - " checkpointable (contain _METADATA file).\n - Contain a subdirectory" - " which is a PyTree checkpointable (contain _METADATA file).\n" + " V1 checkpoint. It is missing the indicator file" + f" '{ORBAX_CHECKPOINT_INDICATOR_FILE}'." 
) async def validate(self): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py index 66283fb7d..f03808be5 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py @@ -18,9 +18,6 @@ from etils import epath import numpy as np from orbax.checkpoint import test_utils -from orbax.checkpoint._src.checkpointers import checkpointer -from orbax.checkpoint._src.handlers import composite_checkpoint_handler -from orbax.checkpoint._src.handlers import standard_checkpoint_handler from orbax.checkpoint._src.metadata import value as value_metadata from orbax.checkpoint._src.path import async_path from orbax.checkpoint.experimental.v1._src.handlers import composite_handler @@ -102,7 +99,8 @@ async def test_validate_no_indicator_file(self): / composite_handler.ORBAX_CHECKPOINT_INDICATOR_FILE ) indicator_path.rmtree() # Remove the indicator file - await layout.validate() + with self.assertRaises(InvalidLayoutError): + await layout.validate() async def test_validate_no_metadata_file(self): layout = OrbaxLayout(self.orbax_path / '0') @@ -176,93 +174,6 @@ async def test_metadata(self): self.assertGreater(result_metadata.commit_timestamp_nsecs, 0) -class V0ValidationTest( - unittest.IsolatedAsyncioTestCase, parameterized.TestCase -): - - def setUp(self): - super().setUp() - self.directory = epath.Path(self.create_tempdir().full_path) / 'ckpt' - self.pytree, _ = array_test_utils.create_numpy_pytree() - # Save a checkpoint with a checkpointable name, `state`. 
- ckptr = checkpointer.Checkpointer( - composite_checkpoint_handler.CompositeCheckpointHandler() - ) - ckptr.save( - self.directory, - composite_checkpoint_handler.CompositeArgs( - state=standard_checkpoint_handler.StandardSaveArgs(self.pytree) - ), - ) - - async def test_nonexistent_path(self): - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(self.directory / 'foo')._validate() - - async def test_not_a_directory(self): - await async_path.write_text(self.directory / 'foo', 'foo') - with self.assertRaises(NotADirectoryError): - await OrbaxLayout(self.directory / 'foo')._validate() - - @parameterized.product(checkpointable_name=['state', None]) - async def test_no_checkpoint_metadata(self, checkpointable_name: str | None): - directory = ( - self.directory / checkpointable_name - if checkpointable_name is not None - else self.directory - ) - await _unlink_checkpoint_metadata(directory) - - await OrbaxLayout(directory)._validate() - if checkpointable_name is None: - await OrbaxLayout(directory)._validate_pytree('state') - else: - await OrbaxLayout(directory)._validate_pytree(None) - - async def test_deleted_pytree(self): - directory = self.directory - (directory / 'state').rmtree() - - await OrbaxLayout(directory)._validate() - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(directory)._validate_pytree('state') - - async def test_missing_checkpointable_matching_name(self): - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(self.directory)._validate_pytree('foo') - - @parameterized.product(checkpointable_name=['state', None]) - async def test_no_pytree_metadata(self, checkpointable_name: str | None): - directory = ( - self.directory / checkpointable_name - if checkpointable_name is not None - else self.directory - ) - await _unlink_pytree_metadata(directory) - - if checkpointable_name is None: - # Passes because we still have the checkpoint metadata. 
- await OrbaxLayout(directory)._validate() - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(directory)._validate_pytree('state') - else: - with self.assertRaises(FileNotFoundError): - await OrbaxLayout(directory)._validate() - await OrbaxLayout(directory)._validate_pytree(None) - - @parameterized.product(checkpointable_name=['state', None]) - async def test_valid_pytree(self, checkpointable_name: str | None): - directory = ( - self.directory / checkpointable_name - if checkpointable_name is not None - else self.directory - ) - if checkpointable_name is None: - await OrbaxLayout(directory)._validate_pytree('state') - else: - await OrbaxLayout(directory)._validate_pytree(None) - - class V1ValidationTest( unittest.IsolatedAsyncioTestCase, parameterized.TestCase ): @@ -290,7 +201,8 @@ async def test_missing_indicator_file(self, checkpointable_name: str | None): else self.directory ) await _unlink_indicator(directory) - await OrbaxLayout(directory)._validate() + with self.assertRaises(FileNotFoundError): + await OrbaxLayout(directory)._validate() async def test_deleted_pytree(self): directory = self.directory @@ -317,10 +229,12 @@ async def test_no_pytree_metadata(self, checkpointable_name: str | None): with self.assertRaises(FileNotFoundError): await OrbaxLayout(directory)._validate() - with self.assertRaises(FileNotFoundError): - if checkpointable_name is None: + + if checkpointable_name is None: + with self.assertRaises(FileNotFoundError): await OrbaxLayout(directory)._validate_pytree('pytree') - else: + else: + with self.assertRaises(ValueError): await OrbaxLayout(directory)._validate_pytree(None) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py index 22b5dd3a3..624078f3b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py @@ -19,6 +19,7 @@ from 
orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.layout import orbax_layout from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout +from orbax.checkpoint.experimental.v1._src.layout import v0_layout from orbax.checkpoint.experimental.v1._src.path import types as path_types InvalidLayoutError = checkpoint_layout.InvalidLayoutError @@ -47,7 +48,11 @@ async def get_checkpoint_layout( match layout_enum: case CheckpointLayoutEnum.ORBAX: - layout_class = orbax_layout.OrbaxLayout + # This allows us to restore a v0 checkpoint with its own layout. + if _is_v0_checkpoint(path): + layout_class = v0_layout.V0Layout + else: + layout_class = orbax_layout.OrbaxLayout case CheckpointLayoutEnum.SAFETENSORS: layout_class = safetensors_layout.SafetensorsLayout case _: @@ -104,21 +109,31 @@ async def _try_resolve_pytree_checkpointable( if checkpointable_name is not None: return layout, checkpointable_name # Not a v0 checkpoint; use the default name. - if not _is_v0_checkpoint(layout): + if not _is_v0_checkpoint(layout.path): + if checkpointable_name is None: + raise ValueError( + "Cannot extract pytree from top-level V1 checkpoint directory." + ) return layout, checkpointable_name # If it's a V0 checkpoint, we can try to resolve the checkpointable name from # the path. - if not isinstance(layout, orbax_layout.OrbaxLayout): - raise AssertionError(f"Expected an OrbaxLayout, but got a {type(layout)}.") - # Option 1: It may be a direct path to the PyTree checkpointable. + if not isinstance(layout, (orbax_layout.OrbaxLayout, v0_layout.V0Layout)): + raise AssertionError( + f"Expected an OrbaxLayout or V0Layout, but got a {type(layout)}." + ) + + # If the path itself is a V0 PyTree checkpoint (flat structure), we can + # "zoom out" to the parent directory and treat the current directory as the + # checkpointable. 
try: original_path = layout.path - new_layout = orbax_layout.OrbaxLayout(original_path.parent) + new_layout = v0_layout.V0Layout(original_path.parent) await new_layout.validate_pytree(original_path.name) return new_layout, original_path.name except checkpoint_layout.InvalidLayoutError: pass - # Option 2: It may be a V0 checkpoint containing a PyTree checkpointable. It + + # It may be a V0 checkpoint containing a PyTree checkpointable. It # is possible for there to be multiple, but this would be unusual, and it is # fine to just return the first one. dir_names = [p.name for p in layout.path.iterdir() if p.is_dir()] @@ -135,8 +150,7 @@ async def _try_resolve_pytree_checkpointable( ) -def _is_v0_checkpoint(layout: CheckpointLayout) -> bool: - return not isinstance(layout, orbax_layout.OrbaxLayout) or ( - isinstance(layout, orbax_layout.OrbaxLayout) - and not layout.has_indicator_file - ) +def _is_v0_checkpoint(path: path_types.PathLike) -> bool: + ctx = context_lib.get_context() + path = ctx.file_options.path_class(path) + return not (path / orbax_layout.ORBAX_CHECKPOINT_INDICATOR_FILE).exists() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py index 93c133e4f..ebe4423f7 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry_test.py @@ -60,7 +60,7 @@ async def test_root_directory(self): ): await try_resolve_pytree_checkpointable(layout, None) - @parameterized.product(checkpointable_name=['state', 'params', None]) + @parameterized.product(checkpointable_name=['state', 'params']) async def test_v1(self, checkpointable_name): layout = orbax_layout.OrbaxLayout(self.v1_directory) self.assertTrue(layout.has_indicator_file) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout.py 
b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout.py new file mode 100644 index 000000000..ec9b7b3d0 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout.py @@ -0,0 +1,204 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines `V0Layout`, a class to handle Orbax V0 checkpoint formats.""" + +import asyncio +from typing import Any, Awaitable + +from absl import logging +from orbax.checkpoint._src.metadata import checkpoint as checkpoint_metadata +from orbax.checkpoint._src.path import async_path +from orbax.checkpoint._src.path import temporary_paths +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.handlers import composite_handler +from orbax.checkpoint.experimental.v1._src.handlers import registration +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout +from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility +from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types +from orbax.checkpoint.experimental.v1._src.path import types as path_types + + +InvalidLayoutError = checkpoint_layout.InvalidLayoutError +Path = path_types.Path +CheckpointLayout = checkpoint_layout.CheckpointLayout + +PYTREE_METADATA_FILE = "_METADATA" +ORBAX_CHECKPOINT_INDICATOR_FILE = "orbax.checkpoint" +_OCDBT_MANIFEST_FILE = "ocdbt.manifest" +_ZARRAY_FILE = 
".zarray" + + +async def _has_pytree_metadata_file(path: Path) -> bool: + return await async_path.exists(path / PYTREE_METADATA_FILE) + + +async def _has_ocdbt_manifest_file(path: Path) -> bool: + return await async_path.exists(path / _OCDBT_MANIFEST_FILE) + + +async def _has_zarray_files(path: Path) -> bool: + paths = list(await async_path.iterdir(path)) + awaitables = [async_path.exists(p / _ZARRAY_FILE) for p in paths] + return any(await asyncio.gather(*awaitables)) + + +async def _has_tensorstore_data_files(path: Path) -> bool: + return await _has_ocdbt_manifest_file(path) or await _has_zarray_files(path) + + +class V0Layout(CheckpointLayout): + """V0Layout. + + This class handles Orbax V0 checkpoint formats. It inherits + abstract methods from CheckpointLayout. It performs a few core functions: + - Resolves handlers for saving and loading. + - Saves and loads checkpointables to/from individual subdirectories by + delegating to the resolved handlers. + """ + + def __init__(self, path: Path): + self._context = context_lib.get_context() + self._handler_registry = registration.local_registry( + self._context.checkpointables_options.registry, + include_global_registry=False, + ) + self._composite_handler = composite_handler.CompositeHandler( + self._handler_registry + ) + self._path = path + + @property + def path(self) -> Path: + return self._path + + async def metadata(self) -> metadata_types.CheckpointMetadata[dict[str, Any]]: + # Uses the v0 checkpointer to get v0 StepMetadata + checkpointer, _ = v0_compatibility.get_v0_checkpointer_and_args( + self._path, None, context=context_lib.get_context() + ) + step_metadata = checkpointer.metadata(self._path) + + item_metadata = {k: v for k, v in step_metadata.item_metadata.items()} + # Exclude `metrics` if present. 
+ item_metadata.pop("metrics", None) + + return metadata_types.CheckpointMetadata[dict[str, Any]]( + metadata=item_metadata, + init_timestamp_nsecs=step_metadata.init_timestamp_nsecs, + commit_timestamp_nsecs=step_metadata.commit_timestamp_nsecs, + custom_metadata=step_metadata.custom_metadata, + ) + + async def _validate(self) -> None: + """Validates a V0 checkpoint directory.""" + if not await async_path.exists(self.path): + raise FileNotFoundError(f"Checkpoint path {self.path} does not exist.") + if not await async_path.is_dir(self.path): + raise NotADirectoryError( + f"Checkpoint path {self.path} is not a directory." + ) + context = context_lib.get_context() + if await temporary_paths.is_path_temporary( + self.path, + temporary_path_cls=context.file_options.temporary_path_class, + ): + raise ValueError(f"Found incomplete checkpoint at {self.path}.") + + # Path points to a single step checkpoint with valid checkpoint metadata. + if await async_path.exists( + checkpoint_metadata.step_metadata_file_path(self.path) + ): + return + + # The path itself points to a PyTree checkpointable. + if await async_path.exists(self.path / PYTREE_METADATA_FILE): + return + + subpaths = await async_path.iterdir(self.path) + # The path points to a directory containing at least one PyTree + # checkpointable. + for subpath in subpaths: + if await async_path.is_dir(subpath) and await async_path.exists( + subpath / PYTREE_METADATA_FILE + ): + return + + raise ValueError( + f"Checkpoint path {self.path} is not a valid V0 checkpoint." 
+ ) + + async def _validate_pytree(self, checkpointable_name: str | None) -> None: + """Validates the given path as a V0 PyTree checkpoint.""" + pytree_dir = ( + self.path + if checkpointable_name is None + else self.path / checkpointable_name + ) + + if checkpointable_name is not None and not await async_path.exists( + pytree_dir + ): + raise FileNotFoundError( + f"Checkpoint path {self.path} must contain a subdirectory named" + f' "{checkpointable_name}".' + ) + + if not await _has_pytree_metadata_file(pytree_dir): + # TODO(angelmau): Add following details to the error message: + # 1. we should check other available subdirectories and see if any of them + # look like PyTree checkpoints, and instruct the user to consider + # whether they meant to specify any of those. + # 2. we need to check the directory - if it contains PyTree files, suggest + # loading with checkpointable_name=None + raise FileNotFoundError( + f"Checkpoint path {self.path} does not contain a PyTree metadata" + " file." + ) + + if not await _has_tensorstore_data_files(pytree_dir): + logging.warning( + "TensorStore data files not found in checkpoint path %s. This may be" + " a sign of a malformed checkpoint, unless your checkpoint consists" + " entirely of strings or other non-standard PyTree leaves.", + self.path, + ) + + async def validate(self) -> None: + """Validates a V0 checkpoint directory.""" + try: + await self._validate() + except BaseException as e: + raise InvalidLayoutError( + f"Failed to interpret path {self._path} as an Orbax V0 checkpoint." + ) from e + + async def validate_pytree(self, checkpointable_name: str | None) -> None: + """Validates the given path as a V0 PyTree checkpoint.""" + try: + await self._validate_pytree(checkpointable_name) + except BaseException as e: + raise InvalidLayoutError( + f"Failed to interpret path {self._path} as an Orbax V0 PyTree" + " checkpoint." 
+ ) from e + + async def load( + self, + abstract_checkpointables: dict[str, Any] | None = None, + ) -> Awaitable[dict[str, Any]]: + load_awaitable = await self._composite_handler.load( + self._path, abstract_checkpointables + ) + return load_awaitable diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout_test.py new file mode 100644 index 000000000..11956d933 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/v0_layout_test.py @@ -0,0 +1,227 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.checkpointers import checkpointer +from orbax.checkpoint._src.handlers import composite_checkpoint_handler +from orbax.checkpoint._src.handlers import standard_checkpoint_handler +from orbax.checkpoint._src.path import async_path +from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout +from orbax.checkpoint.experimental.v1._src.layout import v0_layout +from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types +from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils +import safetensors.numpy + + +InvalidLayoutError = checkpoint_layout.InvalidLayoutError +np_save_file = safetensors.numpy.save_file + + +async def _unlink_checkpoint_metadata(path: epath.Path): + await async_path.unlink(path / '_CHECKPOINT_METADATA', missing_ok=True) + + +async def _unlink_pytree_metadata(path: epath.Path): + await async_path.unlink(path / '_METADATA', missing_ok=True) + for subdir in await async_path.iterdir(path): + if not await async_path.is_dir(subdir): + continue + await async_path.unlink(subdir / '_METADATA', missing_ok=True) + + +class V0LayoutTest(unittest.IsolatedAsyncioTestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_dir = self.create_tempdir() + self.orbax_path = epath.Path(self.test_dir.full_path) / 'test_checkpoint' + self.safetensors_path = ( + epath.Path(self.test_dir.full_path) / 'test_checkpoint.safetensors' + ) + + # Create a mock SafeTensors and Orbax V0 checkpoint + self.object_to_save = { + 'a': np.array(3 * [1, 2, 3], dtype=np.int32), + 'b': np.array([0, 1, 0.2], dtype=np.float32), + } + self.custom_metadata = {'framework': 'JAX', 'version': '1.0'} + np_save_file(self.object_to_save, self.safetensors_path) + + # Save V0 checkpoint + ckptr = 
checkpointer.Checkpointer( + composite_checkpoint_handler.CompositeCheckpointHandler() + ) + ckptr.save( + self.orbax_path / '0', + composite_checkpoint_handler.CompositeArgs( + state=standard_checkpoint_handler.StandardSaveArgs( + self.object_to_save + ) + ), + ) + + async def test_valid_v0_checkpoint(self): + layout = v0_layout.V0Layout(self.orbax_path / '0') + await layout.validate() + + async def test_invalid_v0_checkpoint(self): + layout = v0_layout.V0Layout(self.safetensors_path) + with self.assertRaises(InvalidLayoutError): + await layout.validate() + + async def test_validate_fails_not_directory(self): + layout = v0_layout.V0Layout(self.orbax_path / '1') + with self.assertRaises(InvalidLayoutError): + await layout.validate() + + async def test_validate_no_metadata_file(self): + # V0Layout checks for _CHECKPOINT_METADATA or _METADATA in subdirs. + layout = v0_layout.V0Layout(self.orbax_path / '0') + metadata_path = self.orbax_path / '0' / '_CHECKPOINT_METADATA' + self.assertTrue(metadata_path.exists()) + metadata_path.rmtree() # Remove the metadata file + # Should still pass if subdirs have metadata + await layout.validate() + + async def test_validate_no_metadata_files(self): + layout = v0_layout.V0Layout(self.orbax_path / '0') + metadata_path = self.orbax_path / '0' / '_CHECKPOINT_METADATA' + metadata_path.rmtree() + # Also remove subdir metadata + # The V0 checkpoint structure from CompositeCheckpointHandler is: + # 0/ + # state/ + # _METADATA + # ... 
+ pytree_metadata_path = self.orbax_path / '0' / 'state' / '_METADATA' + pytree_metadata_path.rmtree() + + with self.assertRaises(InvalidLayoutError): + await layout.validate() + + async def test_load_v0_checkpoint(self): + layout = v0_layout.V0Layout(self.orbax_path / '0') + restored_checkpointables_await = await layout.load() + restored_checkpointables = await restored_checkpointables_await + # restored_checkpointables will be {'state': ...} because of CompositeArgs + test_utils.assert_tree_equal( + self, restored_checkpointables['state'], self.object_to_save + ) + + async def test_metadata(self): + """Tests the metadata() method.""" + # V0Layout.metadata() uses the v0 checkpointer to read the v0 StepMetadata + # and converts it into a V1 CheckpointMetadata. + layout = v0_layout.V0Layout(self.orbax_path / '0') + result_metadata = await layout.metadata() + self.assertIsInstance(result_metadata, metadata_types.CheckpointMetadata) + + +class V0InternalValidationTest( + unittest.IsolatedAsyncioTestCase, parameterized.TestCase +): + + def setUp(self): + super().setUp() + self.directory = epath.Path(self.create_tempdir().full_path) / 'ckpt' + self.pytree, _ = array_test_utils.create_numpy_pytree() + # Save a checkpoint with a checkpointable name, `state`. 
+ ckptr = checkpointer.Checkpointer( + composite_checkpoint_handler.CompositeCheckpointHandler() + ) + ckptr.save( + self.directory, + composite_checkpoint_handler.CompositeArgs( + state=standard_checkpoint_handler.StandardSaveArgs(self.pytree) + ), + ) + + async def test_nonexistent_path(self): + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(self.directory / 'foo')._validate() + + async def test_not_a_directory(self): + await async_path.write_text(self.directory / 'foo', 'foo') + with self.assertRaises(NotADirectoryError): + await v0_layout.V0Layout(self.directory / 'foo')._validate() + + @parameterized.product(checkpointable_name=['state', None]) + async def test_no_checkpoint_metadata(self, checkpointable_name: str | None): + directory = ( + self.directory / checkpointable_name + if checkpointable_name is not None + else self.directory + ) + await _unlink_checkpoint_metadata(directory) + + # V0Layout should validate successfully even without _CHECKPOINT_METADATA + # if subdirectories contain PyTree metadata. 
+ await v0_layout.V0Layout(directory)._validate() + if checkpointable_name is None: + await v0_layout.V0Layout(directory)._validate_pytree('state') + else: + await v0_layout.V0Layout(directory)._validate_pytree(None) + + async def test_deleted_pytree(self): + directory = self.directory + (directory / 'state').rmtree() + + await v0_layout.V0Layout(directory)._validate() + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(directory)._validate_pytree('state') + + async def test_missing_checkpointable_matching_name(self): + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(self.directory)._validate_pytree('foo') + + @parameterized.product(checkpointable_name=['state', None]) + async def test_no_pytree_metadata(self, checkpointable_name: str | None): + directory = ( + self.directory / checkpointable_name + if checkpointable_name is not None + else self.directory + ) + await _unlink_pytree_metadata(directory) + + if checkpointable_name is None: + # Passes because we still have the checkpoint metadata. + await v0_layout.V0Layout(directory)._validate() + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(directory)._validate_pytree('state') + else: + with self.assertRaises(ValueError): + await v0_layout.V0Layout(directory)._validate() + with self.assertRaises(FileNotFoundError): + await v0_layout.V0Layout(directory)._validate_pytree(None) + + @parameterized.product(checkpointable_name=['state', None]) + async def test_valid_pytree(self, checkpointable_name: str | None): + directory = ( + self.directory / checkpointable_name + if checkpointable_name is not None + else self.directory + ) + if checkpointable_name is None: + await v0_layout.V0Layout(directory)._validate_pytree('state') + else: + await v0_layout.V0Layout(directory)._validate_pytree(None) + + +if __name__ == '__main__': + absltest.main()