Skip to content

Commit 291e110

Browse files
author
Orbax Authors
committed
#v1 Add checkpointables support for training.Checkpointer.
PiperOrigin-RevId: 752424341
1 parent fa1a881 commit 291e110

21 files changed

+154
-568
lines changed

checkpoint/CHANGELOG.md

-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ would want them to be preferred.
2121

2222
- #v1 Add `JsonHandler`.
2323
- #v1 Add `training.Checkpointer`.
24-
- #v1 Add checkpointables support for `training.Checkpointer`.
2524
- `PartsOf` structure which holds a PyTree whose leaf nodes may be missing.
2625
- #v1 Add compatibility tests for save-by-v0-load-by-v1 and also fix code.
2726

checkpoint/orbax/checkpoint/_src/handlers/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ py_library(
1313
":checkpoint_handler",
1414
":handler_registration",
1515
":handler_type_registry",
16+
":proto_checkpoint_handler",
1617
"//checkpoint/orbax/checkpoint:checkpoint_args",
1718
"//checkpoint/orbax/checkpoint:options",
1819
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",

checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
import asyncio
4949
import concurrent.futures
5050
import dataclasses
51-
from typing import Any, Dict, List, Mapping, MutableSet, Optional, Tuple, Type
51+
from typing import Any, Coroutine, Dict, List, Mapping, MutableSet, Optional, Tuple, Type
5252

5353
from absl import logging
5454
from etils import epath
@@ -63,6 +63,7 @@
6363
from orbax.checkpoint._src.handlers import checkpoint_handler
6464
from orbax.checkpoint._src.handlers import handler_registration
6565
from orbax.checkpoint._src.handlers import handler_type_registry
66+
from orbax.checkpoint._src.handlers import proto_checkpoint_handler
6667
from orbax.checkpoint._src.metadata import checkpoint
6768
from orbax.checkpoint._src.metadata import step_metadata_serialization
6869
from orbax.checkpoint._src.path import atomicity
@@ -78,6 +79,10 @@
7879
CompositeItemMetadata = checkpoint.CompositeItemMetadata
7980
AsyncCheckpointHandler = async_checkpoint_handler.AsyncCheckpointHandler
8081
register_with_handler = checkpoint_args.register_with_handler
82+
ProtoCheckpointHandler = proto_checkpoint_handler.ProtoCheckpointHandler
83+
ProtoSaveArgs = proto_checkpoint_handler.ProtoSaveArgs
84+
ProtoRestoreArgs = proto_checkpoint_handler.ProtoRestoreArgs
85+
AsyncSaveCoroutine = Coroutine[Any, Any, Optional[List[Future]]]
8186
Composite = composite.Composite
8287
HandlerAwaitableSignal = synchronization.HandlerAwaitableSignal
8388

checkpoint/orbax/checkpoint/checkpoint_manager.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,10 @@
9292
FileOptions = options_lib.FileOptions
9393

9494
DEFAULT_ITEM_NAME = 'default'
95+
DESCRIPTOR_ITEM_NAME = 'descriptor'
9596
METRIC_ITEM_NAME = 'metrics'
9697
METADATA_ITEM_NAME = 'metadata'
97-
RESERVED_ITEM_NAMES = []
98+
RESERVED_ITEM_NAMES = [DESCRIPTOR_ITEM_NAME, METRIC_ITEM_NAME]
9899

99100
_INIT_TIME = datetime.datetime.now(tz=datetime.timezone.utc)
100101

@@ -318,7 +319,6 @@ class CheckpointManagerOptions:
318319
is the sole means of determining when a checkpoint should be saved. If not
319320
provided, these other options are used instead. Prefer to use this option
320321
over others.
321-
prevent_write_metrics: False by default. If True, metrics will not be written.
322322
"""
323323

324324
save_interval_steps: int = 1
@@ -351,7 +351,6 @@ class CheckpointManagerOptions:
351351
save_decision_policy: Optional[
352352
save_decision_policy_lib.SaveDecisionPolicy
353353
] = None
354-
prevent_write_metrics: bool = False
355354

356355
def __post_init__(self):
357356
step_name_format_single_host_load_and_broadcast = (
@@ -911,6 +910,7 @@ def _configure_checkpointer_legacy_init(
911910
f'Invalid type for `checkpointers`. Found {checkpointers}.'
912911
)
913912

913+
# if options.best_fn:
914914
item_handlers[METRIC_ITEM_NAME] = self._metrics_handler
915915
if options.async_options is None:
916916
options.async_options = (
@@ -1366,10 +1366,8 @@ def save(
13661366
'Some provided items have prohibited reserved names:'
13671367
f' {args_dict.keys()}. Reserved names: {RESERVED_ITEM_NAMES}.'
13681368
)
1369-
if (
1370-
metrics is not None and self._track_best
1371-
) and not self._options.prevent_write_metrics:
1372-
args_dict[METRIC_ITEM_NAME] = args_lib.JsonSave(metrics)
1369+
if metrics is not None and self._track_best:
1370+
args_dict['metrics'] = args_lib.JsonSave(metrics)
13731371
args = args_lib.Composite(**args_dict)
13741372

13751373
save_directory = self._get_write_step_directory(step, self.directory)
@@ -1640,18 +1638,20 @@ def item_metadata(
16401638

16411639
# TODO(b/370812224): Deprecate in favor of StepMetadata.metrics
16421640
def metrics(self, step: int) -> Optional[PyTree]:
1643-
try:
1644-
# Use handler directly, since this happens in a background thread and
1645-
# barriers cannot be used. This usage pattern is not
1646-
# recommended in other contexts.
1647-
metrics = self._metrics_handler.restore(
1648-
self._get_read_step_directory(step, self.directory)
1649-
/ METRIC_ITEM_NAME
1650-
)
1651-
return metrics
1652-
except FileNotFoundError as e:
1653-
logging.warning('Missing metrics for step %d', step)
1654-
logging.error(e)
1641+
if self._track_best:
1642+
try:
1643+
# Use handler directly, since this happens in a background thread and
1644+
# barriers cannot be used. This usage pattern is not
1645+
# recommended in other contexts.
1646+
return self._metrics_handler.restore(
1647+
self._get_read_step_directory(step, self.directory)
1648+
/ METRIC_ITEM_NAME
1649+
)
1650+
except FileNotFoundError as e:
1651+
logging.warning('Missing metrics for step %d', step)
1652+
logging.error(e)
1653+
return None
1654+
else:
16551655
return None
16561656

16571657
@property

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/compatibility.py

-6
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,11 @@ def typestr(cls) -> str:
114114
"""A unique identifier for the CheckpointHandler type."""
115115
...
116116

117-
def __repr__(self):
118-
return f'CompatibilityCheckpointHandler({handler_types.typestr(type(self._handler))})'
119-
120117

121118
@dataclasses.dataclass
122119
class Args(checkpoint_args.CheckpointArgs):
123120
checkpointable: Any
124121

125-
def __repr__(self):
126-
return f'CompatibilityArgs({type(self.checkpointable)})'
127-
128122

129123
def get_compatibility_handler(
130124
handler: handler_types.CheckpointableHandler,

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/global_registration.py

-4
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,5 @@
3131

3232
registration.global_registry().add(json_handler.JsonHandler)
3333

34-
registration.global_registry().add(
35-
json_handler.MetricsHandler,
36-
format_utils.METRICS_CHECKPOINTABLE_KEY,
37-
)
3834

3935
registration.global_registry().add(pytree_handler.PyTreeHandler)

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@
3333
class JsonHandler(CheckpointableHandler[JsonType, None]):
3434
"""An implementation of `CheckpointableHandler` for Json."""
3535

36-
def __init__(self, filename: str = 'metadata'):
37-
self._filename = filename
36+
def __init__(
37+
self,
38+
):
39+
self._filename = 'metadata'
3840

3941
async def _background_save(
4042
self,
@@ -85,9 +87,3 @@ def is_handleable(self, checkpointable: Any) -> bool:
8587

8688
def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool | None:
8789
return None
88-
89-
90-
class MetricsHandler(JsonHandler):
91-
92-
def __init__(self):
93-
super().__init__(filename='metrics')

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -278,15 +278,7 @@ def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool:
278278

279279
@contextlib.contextmanager
280280
def pytree_handler_context():
281-
"""Creates a local context for PyTree handling.
282-
283-
`PYTREE_CHECKPOINTABLE_KEY` is explicitly registered linking to
284-
`PyTreeHandler`. Note that all globally-registered handlers are still included
285-
as backup options. Other options from the parent context are carried through.
286-
287-
Yields:
288-
A new context.
289-
"""
281+
"""Creates a local context where only `PyTreeHandler` is registered."""
290282
# TODO(b/398310070): Verify behavior with nested Contexts.
291283
checkpointables_options = options_lib.CheckpointablesOptions(
292284
registry=registration.local_registry(include_global_registry=True).add(

checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py

-4
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,6 @@ def checkpointables_metadata(
118118
)
119119
metadata = checkpointer.metadata(path)
120120
item_metadata = {k: v for k, v in metadata.item_metadata.items()}
121-
# Exclude `metrics` if present. This is relevant only for
122-
# `training.Checkpointer`, and is separately added to the
123-
# `training.CheckpointMetadata` object.
124-
item_metadata.pop('metrics', None)
125121
return CheckpointMetadata[dict[str, Any]](
126122
metadata=item_metadata,
127123
init_timestamp_nsecs=metadata.init_timestamp_nsecs,

checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/types.py

-15
Original file line numberDiff line numberDiff line change
@@ -107,18 +107,3 @@ def from_metadata(
107107
cls, metadata: CheckpointableMetadataT
108108
) -> CheckpointMetadata[CheckpointableMetadataT]:
109109
return cls(metadata=metadata)
110-
111-
def _properties_strings(self) -> dict[str, str]:
112-
return {
113-
'metadata': str(self.metadata),
114-
'init_timestamp_nsecs': str(self.init_timestamp_nsecs),
115-
'commit_timestamp_nsecs': str(self.commit_timestamp_nsecs),
116-
'custom_metadata': str(self.custom_metadata),
117-
}
118-
119-
def __repr__(self):
120-
s = 'CheckpointMetadata('
121-
for k, v in self._properties_strings().items():
122-
s += f' {k} = {v}, '
123-
s += ')'
124-
return s

checkpoint/orbax/checkpoint/experimental/v1/_src/path/format_utils.py

-3
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@
2828

2929
PYTREE_CHECKPOINTABLE_KEY = 'pytree'
3030

31-
METRICS_CHECKPOINTABLE_KEY = 'metrics'
32-
3331
RESERVED_CHECKPOINTABLE_KEYS = frozenset({
34-
METRICS_CHECKPOINTABLE_KEY,
3532
})
3633

3734

checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py

-4
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ def save_checkpointables_async(
239239
def get_v0_checkpointer_and_args(
240240
checkpointables: dict[str, Any],
241241
*,
242-
metrics: tree_types.JsonType | None = None,
243242
context: context_lib.Context,
244243
) -> tuple[
245244
async_checkpointer.AsyncCheckpointer,
@@ -253,9 +252,6 @@ def get_v0_checkpointer_and_args(
253252
raise ValueError(
254253
f'Provided reserved checkpointable keys: {provided_reserved_keys}.'
255254
)
256-
# Global registration ties metrics key to JsonHandler.
257-
if metrics:
258-
checkpointables[format_utils.METRICS_CHECKPOINTABLE_KEY] = metrics
259255

260256

261257
handlers = {

checkpoint/orbax/checkpoint/experimental/v1/_src/testing/BUILD

-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ py_library(
66
deps = [
77
":array_utils",
88
":handler_utils",
9-
":tree_utils",
109
"//checkpoint/orbax/checkpoint:test_utils",
1110
"//checkpoint/orbax/checkpoint/_src/path:atomicity",
1211
"//checkpoint/orbax/checkpoint/_src/serialization",
@@ -45,16 +44,6 @@ py_library(
4544
deps = ["//orbax/checkpoint/experimental/v1/_src/path:types"],
4645
)
4746

48-
py_library(
49-
name = "tree_utils",
50-
srcs = ["tree_utils.py"],
51-
deps = [
52-
"//orbax/checkpoint/experimental/v1/_src/path:format_utils",
53-
"//orbax/checkpoint/experimental/v1/_src/path:types",
54-
"//orbax/checkpoint/experimental/v1/_src/synchronization:multihost",
55-
],
56-
)
57-
5847
py_library(
5948
name = "v0v1_compatibility_save_load_test_base",
6049
srcs = ["v0v1_compatibility_save_load_test_base.py"],

checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
4343
from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils
4444
from orbax.checkpoint.experimental.v1._src.testing import handler_utils
45-
from orbax.checkpoint.experimental.v1._src.testing import tree_utils as tree_test_utils
4645
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
4746

4847

@@ -72,6 +71,33 @@ async def _sleep_and_create_paths(*args, **kwargs):
7271
return await _original_create_paths(*args, **kwargs)
7372

7473

74+
def get_d_files_mtimes(path: Path) -> list[int]:
75+
mtimes = []
76+
matching_dirs = list(path.parent.glob(f'{path.name}*'))
77+
if not matching_dirs:
78+
# Temp path not created yet.
79+
return []
80+
assert (
81+
len(matching_dirs) == 1
82+
), f'Expected exactly one matching directory, got {matching_dirs}.'
83+
tmpdir = matching_dirs[0]
84+
matching_pytree_dirs = list(tmpdir.glob(f'{PYTREE_CHECKPOINTABLE_KEY}*'))
85+
if not matching_pytree_dirs:
86+
# Temp path not created yet.
87+
return []
88+
assert len(matching_pytree_dirs) == 1, (
89+
'Expected exactly one matching pytree directory, got'
90+
f' {matching_pytree_dirs}.'
91+
)
92+
pytree_dir = matching_pytree_dirs[0]
93+
for idx in range(multihost.process_count()):
94+
d_path = pytree_dir / f'ocdbt.process_{idx}' / 'd'
95+
if not d_path.exists():
96+
continue
97+
mtimes.extend([f.stat().mtime for f in d_path.iterdir()])
98+
return mtimes
99+
100+
75101
class SaveLoadTestBase:
76102

77103
class Test(parameterized.TestCase):
@@ -121,6 +147,11 @@ def test_load_default(self, use_async):
121147
test_utils.assert_tree_equal(self, self.pytree, loaded)
122148

123149
def test_save_pytree_async(self):
150+
def is_save_complete(directory):
151+
return (
152+
directory / PYTREE_CHECKPOINTABLE_KEY / 'manifest.ocdbt'
153+
).exists()
154+
124155
start_serialize = threading.Event()
125156
original_serialize = serialization.async_serialize_from_host
126157

@@ -136,21 +167,15 @@ def mock_serialize(*args, **kwargs):
136167
)
137168

138169
response = ocp.save_pytree_async(self.directory, self.pytree)
139-
initial_d_files_mtimes = tree_test_utils.get_d_files_mtimes(
140-
self.directory
141-
)
142-
self.assertFalse(
143-
tree_test_utils.is_pytree_checkpoint_complete(self.directory)
144-
)
170+
initial_d_files_mtimes = get_d_files_mtimes(self.directory)
171+
self.assertFalse(is_save_complete(self.directory))
145172
start_serialize.set()
146173

147174
response.result()
148-
final_d_files_mtimes = tree_test_utils.get_d_files_mtimes(self.directory)
175+
final_d_files_mtimes = get_d_files_mtimes(self.directory)
149176
self.assertNotEmpty(final_d_files_mtimes)
150177
self.assertNotEqual(initial_d_files_mtimes, final_d_files_mtimes)
151-
self.assertTrue(
152-
tree_test_utils.is_pytree_checkpoint_complete(self.directory)
153-
)
178+
self.assertTrue(is_save_complete(self.directory))
154179

155180
restored = ocp.load_pytree(
156181
self.directory,

0 commit comments

Comments
 (0)