Skip to content

Commit 4329ad7

Browse files
angel-coreOrbax Authors
authored andcommitted
Add testing to verify backwards compatibility when missing _CHECKPOINT_METADATA or saving pytree directly to checkpoint.
PiperOrigin-RevId: 840391322
1 parent 23302a2 commit 4329ad7

File tree

9 files changed

+608
-182
lines changed

9 files changed

+608
-182
lines changed

checkpoint/CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
## [0.11.30] - 2025-11-26
1111

12+
### Added
13+
14+
- Added testing to verify backwards compatibility when a checkpoint is missing
15+
`_CHECKPOINT_METADATA` or saving pytree directly as a checkpoint with no
16+
checkpointable.
17+
1218
### Fixed
1319

1420
- Roll back earlier change altering metadata format, which was observed to cause
1521
breakages.
22+
- Fix `CompositeHandler` extraction logic for handler type strings when missing `_CHECKPOINT_METADATA`
23+
24+
### Changed
25+
26+
- Split v0 checkpoint format/layout logic out from `OrbaxLayout` and into
27+
seperate `V0Layout`.
1628

1729
## [0.11.29] - 2025-11-25
1830

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
from orbax.checkpoint._src.metadata import step_metadata_serialization
2626
from orbax.checkpoint._src.multihost import multihost
2727
from orbax.checkpoint._src.path import async_path
28+
from orbax.checkpoint._src.path import format_utils
2829
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
30+
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
2931
from orbax.checkpoint.experimental.v1._src.handlers import registration
3032
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
3133
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
@@ -262,7 +264,7 @@ def get_handlers_for_load(
262264

263265
def _get_saved_handler_typestrs(
264266
self, directory: path_types.Path
265-
) -> dict[str, str]:
267+
) -> dict[str, str | None]:
266268
"""Reads from the checkpoint metadata to get saved handler typestrs."""
267269
step_metadata_file_path = checkpoint_metadata.step_metadata_file_path(
268270
directory
@@ -291,28 +293,40 @@ def _get_saved_handler_typestrs(
291293
directory,
292294
)
293295

294-
saved_handler_typestrs: dict[str, str] = {}
296+
# The following generically handles the case where the checkpoint is missing
297+
# the checkpoint metadata file. This is a fallback for older v0 checkpoints.
298+
299+
# We check each subdirectory treating it as a checkpointable.
300+
# The order of presedence for mapping checkpointable names to handlers is:
301+
# 1. A pytree metadata file, indicating a pytree handler
302+
# 2. A checkpointable with a handler registered in the handler registry.
303+
# 3. If neither, we skip it and treat it as garbage since we don't know
304+
# how to handle it.
305+
306+
saved_handler_typestrs: dict[str, str | None] = {}
295307
for checkpointable_path in directory.iterdir():
296-
serialized_metadata = self._metadata_store.read(
297-
checkpoint_metadata.step_metadata_file_path(checkpointable_path)
298-
)
299-
if serialized_metadata is None:
308+
if not checkpointable_path.is_dir():
300309
continue
301-
saved_metadata = step_metadata_serialization.deserialize(
302-
serialized_metadata
310+
checkpointable_name = checkpointable_path.name
311+
312+
pytree_metadata_path = (
313+
checkpointable_path / format_utils.PYTREE_METADATA_FILE
303314
)
304-
if isinstance(saved_metadata.item_handlers, dict):
315+
316+
if pytree_metadata_path.exists():
317+
saved_handler_typestrs[checkpointable_name] = handler_types.typestr(
318+
pytree_handler.PyTreeHandler
319+
)
320+
elif self._handler_registry.has(checkpointable_name):
321+
# If the handler is registered we can assume it will be found in
322+
# resolve_handler_for_load.
323+
saved_handler_typestrs[checkpointable_name] = None
324+
else:
305325
raise ValueError(
306-
f'Path at {directory} contains subdirectories:'
307-
f' {_subdirs(directory)}, which are expected to'
308-
' match the keys given by the _CHECKPOINT_METADATA file:'
309-
f' {saved_metadata.item_handlers}. If you intended to load a pytree'
310-
' checkpoint from the given path, then please consider using'
311-
' `loading.load_pytree(..., checkpointable_name=None)` instead.'
312-
f' {_V0_ERROR_MESSAGE}'
326+
'Cannot determine handler for checkpointable'
327+
f" '{checkpointable_name}' in directory {directory}. The top-level"
328+
" '_CHECKPOINT_METADATA' is missing, and this item does not have a"
329+
" '_METADATA' file to be loaded as a PyTree, nor is it registered"
330+
" in the handler registry."
313331
)
314-
item_handlers = saved_metadata.item_handlers
315-
if item_handlers is not None:
316-
checkpointable_name = checkpointable_path.name
317-
saved_handler_typestrs[checkpointable_name] = item_handlers
318332
return saved_handler_typestrs

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py

Lines changed: 21 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from orbax.checkpoint.experimental.v1._src.handlers import registration
2626
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
2727
from orbax.checkpoint.experimental.v1._src.loading import v0_compatibility
28-
from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization
2928
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
3029
from orbax.checkpoint.experimental.v1._src.path import types as path_types
3130

@@ -148,24 +147,23 @@ async def _validate_pytree(self, checkpointable_name: str | None):
148147
"""Validates a checkpoint path written by `ocp.save_pytree`.
149148
150149
Args:
151-
checkpointable_name: The name of the checkpointable to load. A
152-
subdirectory with this name must exist in `directory`. If None then
153-
`directory` is expected to contain the checkpoint directly. Defaults to
154-
`pytree`.
150+
checkpointable_name: The name of the checkpointable to load. For Orbax V1,
151+
a subdirectory with this name must exist in `directory`.
155152
156153
Raises:
157154
FileNotFoundError: If the path does not exist, or if `pytree` is not found
158155
in the directory
159156
ValueError: If the PyTree checkpoint is malformed.
160157
"""
161-
pytree_dir = (
162-
self.path
163-
if checkpointable_name is None
164-
else self.path / checkpointable_name
165-
)
166-
if checkpointable_name is not None and not await async_path.exists(
167-
pytree_dir
168-
):
158+
if checkpointable_name is None:
159+
raise ValueError(
160+
"A V1 checkpoint was saved and user is attempting to load it,"
161+
" treating it as a V0 Orbax checkpoint saved directly from"
162+
f" {self.path}, this is not a characteristic of V1 saved checkpoints"
163+
)
164+
165+
pytree_dir = self.path / checkpointable_name
166+
if not await async_path.exists(pytree_dir):
169167
subdirs = [
170168
d.name
171169
for d in await _subpaths(self.path)
@@ -180,17 +178,13 @@ async def _validate_pytree(self, checkpointable_name: str | None):
180178
" using"
181179
" `ocp.load_checkpointables()`."
182180
)
183-
if not await _has_pytree_metadata_file(pytree_dir):
184-
# TODO(niketkb): Add following details to the error message:
185-
# 1. we should check other available subdirectories and see if any of them
186-
# look like PyTree checkpoints, and instruct the user to consider
187-
# whether they meant to specify any of those.
188-
# 2. we need to check the directory - if it contains PyTree files, suggest
189-
# loading with checkpointable_name=None
181+
182+
if not (pytree_dir / PYTREE_METADATA_FILE).exists():
190183
raise FileNotFoundError(
191184
f"Checkpoint path {self.path} does not contain a PyTree metadata"
192185
" file."
193186
)
187+
194188
if not await _has_tensorstore_data_files(pytree_dir):
195189
logging.warning(
196190
"TensorStore data files not found in checkpoint path %s. This may be"
@@ -200,17 +194,13 @@ async def _validate_pytree(self, checkpointable_name: str | None):
200194
)
201195

202196
async def _validate(self):
203-
"""Validates a checkpoint directory.
197+
"""Validates a checkpoint directory to be a V1 Orbax checkpoint.
204198
205-
Must be:
199+
Must fulfill all of the following:
206200
- Existing
207-
- A directory.
208-
- Not a temporary path.
209-
- OR
210-
- Has orbax.checkpoint indicator file.
211-
- OR
212-
- Has _CHECKPOINT_METADATA file.
213-
- A subdirectory has _METADATA file (PyTree checkpoint).
201+
- A directory
202+
- Not a temporary path
203+
- Has orbax.checkpoint indicator file
214204
215205
Raises:
216206
FileNotFoundError: If the path does not exist.
@@ -235,33 +225,10 @@ async def _validate(self):
235225
if ORBAX_CHECKPOINT_INDICATOR_FILE in [p.name for p in subpaths]:
236226
return
237227

238-
# Path points to a single step checkpoint with valid metadata.
239-
if await async_path.exists(
240-
metadata_serialization.checkpoint_metadata_file_path(self.path)
241-
):
242-
return
243-
244-
# The path itself points to a PyTree checkpointable.
245-
if await async_path.exists(self.path / PYTREE_METADATA_FILE):
246-
return
247-
# The path points to a directory containing at least one PyTree
248-
# checkpointable.
249-
for subpath in subpaths:
250-
if await async_path.is_dir(subpath) and await async_path.exists(
251-
subpath / PYTREE_METADATA_FILE
252-
):
253-
return
254-
255228
raise FileNotFoundError(
256229
f"Checkpoint path {self.path} could not be identified as a valid Orbax"
257-
" checkpoint. The path must conform to one of the following"
258-
" conditions:\n - Contain the indicator file"
259-
f" {ORBAX_CHECKPOINT_INDICATOR_FILE}. This should be true of all"
260-
" checkpoints saved with the Orbax V1 API. If not present, the"
261-
" checkpoint may have been saved with the V0 API.\n - Contain the"
262-
" _CHECKPOINT_METADATA file.\n - Point directly to a PyTree"
263-
" checkpointable (contain _METADATA file).\n - Contain a subdirectory"
264-
" which is a PyTree checkpointable (contain _METADATA file).\n"
230+
" V1 checkpoint. It is missing the indicator file"
231+
f" '{ORBAX_CHECKPOINT_INDICATOR_FILE}'."
265232
)
266233

267234
async def validate(self):

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_test.py

Lines changed: 9 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818
from etils import epath
1919
import numpy as np
2020
from orbax.checkpoint import test_utils
21-
from orbax.checkpoint._src.checkpointers import checkpointer
22-
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
23-
from orbax.checkpoint._src.handlers import standard_checkpoint_handler
2421
from orbax.checkpoint._src.metadata import value as value_metadata
2522
from orbax.checkpoint._src.path import async_path
2623
from orbax.checkpoint.experimental.v1._src.handlers import composite_handler
@@ -102,7 +99,8 @@ async def test_validate_no_indicator_file(self):
10299
/ composite_handler.ORBAX_CHECKPOINT_INDICATOR_FILE
103100
)
104101
indicator_path.rmtree() # Remove the indicator file
105-
await layout.validate()
102+
with self.assertRaises(InvalidLayoutError):
103+
await layout.validate()
106104

107105
async def test_validate_no_metadata_file(self):
108106
layout = OrbaxLayout(self.orbax_path / '0')
@@ -176,93 +174,6 @@ async def test_metadata(self):
176174
self.assertGreater(result_metadata.commit_timestamp_nsecs, 0)
177175

178176

179-
class V0ValidationTest(
180-
unittest.IsolatedAsyncioTestCase, parameterized.TestCase
181-
):
182-
183-
def setUp(self):
184-
super().setUp()
185-
self.directory = epath.Path(self.create_tempdir().full_path) / 'ckpt'
186-
self.pytree, _ = array_test_utils.create_numpy_pytree()
187-
# Save a checkpoint with a checkpointable name, `state`.
188-
ckptr = checkpointer.Checkpointer(
189-
composite_checkpoint_handler.CompositeCheckpointHandler()
190-
)
191-
ckptr.save(
192-
self.directory,
193-
composite_checkpoint_handler.CompositeArgs(
194-
state=standard_checkpoint_handler.StandardSaveArgs(self.pytree)
195-
),
196-
)
197-
198-
async def test_nonexistent_path(self):
199-
with self.assertRaises(FileNotFoundError):
200-
await OrbaxLayout(self.directory / 'foo')._validate()
201-
202-
async def test_not_a_directory(self):
203-
await async_path.write_text(self.directory / 'foo', 'foo')
204-
with self.assertRaises(NotADirectoryError):
205-
await OrbaxLayout(self.directory / 'foo')._validate()
206-
207-
@parameterized.product(checkpointable_name=['state', None])
208-
async def test_no_checkpoint_metadata(self, checkpointable_name: str | None):
209-
directory = (
210-
self.directory / checkpointable_name
211-
if checkpointable_name is not None
212-
else self.directory
213-
)
214-
await _unlink_checkpoint_metadata(directory)
215-
216-
await OrbaxLayout(directory)._validate()
217-
if checkpointable_name is None:
218-
await OrbaxLayout(directory)._validate_pytree('state')
219-
else:
220-
await OrbaxLayout(directory)._validate_pytree(None)
221-
222-
async def test_deleted_pytree(self):
223-
directory = self.directory
224-
(directory / 'state').rmtree()
225-
226-
await OrbaxLayout(directory)._validate()
227-
with self.assertRaises(FileNotFoundError):
228-
await OrbaxLayout(directory)._validate_pytree('state')
229-
230-
async def test_missing_checkpointable_matching_name(self):
231-
with self.assertRaises(FileNotFoundError):
232-
await OrbaxLayout(self.directory)._validate_pytree('foo')
233-
234-
@parameterized.product(checkpointable_name=['state', None])
235-
async def test_no_pytree_metadata(self, checkpointable_name: str | None):
236-
directory = (
237-
self.directory / checkpointable_name
238-
if checkpointable_name is not None
239-
else self.directory
240-
)
241-
await _unlink_pytree_metadata(directory)
242-
243-
if checkpointable_name is None:
244-
# Passes because we still have the checkpoint metadata.
245-
await OrbaxLayout(directory)._validate()
246-
with self.assertRaises(FileNotFoundError):
247-
await OrbaxLayout(directory)._validate_pytree('state')
248-
else:
249-
with self.assertRaises(FileNotFoundError):
250-
await OrbaxLayout(directory)._validate()
251-
await OrbaxLayout(directory)._validate_pytree(None)
252-
253-
@parameterized.product(checkpointable_name=['state', None])
254-
async def test_valid_pytree(self, checkpointable_name: str | None):
255-
directory = (
256-
self.directory / checkpointable_name
257-
if checkpointable_name is not None
258-
else self.directory
259-
)
260-
if checkpointable_name is None:
261-
await OrbaxLayout(directory)._validate_pytree('state')
262-
else:
263-
await OrbaxLayout(directory)._validate_pytree(None)
264-
265-
266177
class V1ValidationTest(
267178
unittest.IsolatedAsyncioTestCase, parameterized.TestCase
268179
):
@@ -290,7 +201,8 @@ async def test_missing_indicator_file(self, checkpointable_name: str | None):
290201
else self.directory
291202
)
292203
await _unlink_indicator(directory)
293-
await OrbaxLayout(directory)._validate()
204+
with self.assertRaises(FileNotFoundError):
205+
await OrbaxLayout(directory)._validate()
294206

295207
async def test_deleted_pytree(self):
296208
directory = self.directory
@@ -317,10 +229,12 @@ async def test_no_pytree_metadata(self, checkpointable_name: str | None):
317229

318230
with self.assertRaises(FileNotFoundError):
319231
await OrbaxLayout(directory)._validate()
320-
with self.assertRaises(FileNotFoundError):
321-
if checkpointable_name is None:
232+
233+
if checkpointable_name is None:
234+
with self.assertRaises(FileNotFoundError):
322235
await OrbaxLayout(directory)._validate_pytree('pytree')
323-
else:
236+
else:
237+
with self.assertRaises(ValueError):
324238
await OrbaxLayout(directory)._validate_pytree(None)
325239

326240

0 commit comments

Comments
 (0)