Skip to content

Commit 2f8bb7c

Browse files
BlaziusMaximusOrbax Authors
authored and
Orbax Authors
committed
Add SaveDecisionPolicy and PreservationPolicy to v1, use CheckpointMetadata instead of CheckpointInfo.
PiperOrigin-RevId: 751099231
1 parent 1328f6d commit 2f8bb7c

File tree

10 files changed

+347
-57
lines changed

10 files changed

+347
-57
lines changed

checkpoint/orbax/checkpoint/_src/checkpoint_managers/BUILD

-5
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,13 @@ py_library(
55
srcs = ["save_decision_policy.py"],
66
deps = [
77
"//checkpoint/orbax/checkpoint:options",
8-
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint_info",
98
"//checkpoint/orbax/checkpoint/_src/multihost",
109
],
1110
)
1211

1312
py_library(
1413
name = "preservation_policy",
1514
srcs = ["preservation_policy.py"],
16-
deps = [
17-
"//checkpoint/orbax/checkpoint:options",
18-
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint_info",
19-
],
2015
)
2116

2217
py_test(

checkpoint/orbax/checkpoint/_src/checkpoint_managers/preservation_policy.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import datetime
1919
from typing import Any, Callable, Dict, Protocol, Sequence, Set
2020
import numpy as np
21-
from orbax.checkpoint._src.metadata import checkpoint_info
2221

2322

2423
NestedDict = Dict[str, Any]
@@ -36,7 +35,7 @@ class PreservationPolicy(Protocol):
3635

3736
def should_preserve(
3837
self,
39-
checkpoints: Sequence[checkpoint_info.CheckpointInfo],
38+
checkpoints,
4039
*,
4140
context: PreservationContext,
4241
) -> Sequence[bool]:
@@ -52,7 +51,7 @@ class LatestN(PreservationPolicy):
5251

5352
def should_preserve(
5453
self,
55-
checkpoints: Sequence[checkpoint_info.CheckpointInfo],
54+
checkpoints,
5655
*,
5756
context: PreservationContext,
5857
) -> Sequence[bool]:
@@ -69,7 +68,7 @@ class EveryNSeconds(PreservationPolicy):
6968

7069
def should_preserve(
7170
self,
72-
checkpoints: Sequence[checkpoint_info.CheckpointInfo],
71+
checkpoints,
7372
*,
7473
context: PreservationContext,
7574
) -> Sequence[bool]:
@@ -96,7 +95,7 @@ class EveryNSteps(PreservationPolicy):
9695

9796
def should_preserve(
9897
self,
99-
checkpoints: Sequence[checkpoint_info.CheckpointInfo],
98+
checkpoints,
10099
*,
101100
context: PreservationContext,
102101
) -> Sequence[bool]:
@@ -118,7 +117,7 @@ def __post_init__(self, steps_init: Sequence[int]):
118117

119118
def should_preserve(
120119
self,
121-
checkpoints: Sequence[checkpoint_info.CheckpointInfo],
120+
checkpoints,
122121
*,
123122
context: PreservationContext,
124123
) -> Sequence[bool]:
@@ -133,7 +132,7 @@ class AnyPreservationPolicy(PreservationPolicy):
133132

134133
def should_preserve(
135134
self,
136-
checkpoints: Sequence[checkpoint_info.CheckpointInfo],
135+
checkpoints,
137136
*,
138137
context: PreservationContext,
139138
) -> Sequence[bool]:
@@ -154,7 +153,7 @@ class BestN(PreservationPolicy):
154153

155154
def should_preserve(
156155
self,
157-
checkpoints: Sequence[checkpoint_info.CheckpointInfo],
156+
checkpoints,
158157
*,
159158
context: PreservationContext,
160159
) -> Sequence[bool]:

checkpoint/orbax/checkpoint/_src/checkpoint_managers/save_decision_policy.py

+7-36
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import typing
2020
from typing import Container, Protocol, Sequence
2121
from orbax.checkpoint import options as options_lib
22-
from orbax.checkpoint._src.metadata import checkpoint_info
2322
from orbax.checkpoint._src.multihost import multihost
2423

2524

@@ -41,11 +40,7 @@ class SaveDecisionPolicy(Protocol):
4140
"""
4241

4342
def should_save(
44-
self,
45-
step: checkpoint_info.CheckpointInfo,
46-
previous_steps: Sequence[checkpoint_info.CheckpointInfo],
47-
*,
48-
context: DecisionContext
43+
self, step, previous_steps, *, context: DecisionContext
4944
) -> bool:
5045
...
5146

@@ -57,11 +52,7 @@ class FixedIntervalPolicy(SaveDecisionPolicy):
5752
interval: int
5853

5954
def should_save(
60-
self,
61-
step: checkpoint_info.CheckpointInfo,
62-
previous_steps: Sequence[checkpoint_info.CheckpointInfo],
63-
*,
64-
context: DecisionContext
55+
self, step, previous_steps, *, context: DecisionContext
6556
) -> bool:
6657
del previous_steps
6758
del context
@@ -75,11 +66,7 @@ class SpecificStepsPolicy(SaveDecisionPolicy):
7566
steps: Container[int]
7667

7768
def should_save(
78-
self,
79-
step: checkpoint_info.CheckpointInfo,
80-
previous_steps: Sequence[checkpoint_info.CheckpointInfo],
81-
*,
82-
context: DecisionContext
69+
self, step, previous_steps, *, context: DecisionContext
8370
) -> bool:
8471
del previous_steps
8572
del context
@@ -93,11 +80,7 @@ class ContinuousCheckpointingPolicy(SaveDecisionPolicy):
9380
minimum_interval_secs: int | None = None
9481

9582
def should_save(
96-
self,
97-
step: checkpoint_info.CheckpointInfo,
98-
previous_steps: Sequence[checkpoint_info.CheckpointInfo],
99-
*,
100-
context: DecisionContext
83+
self, step, previous_steps, *, context: DecisionContext
10184
) -> bool:
10285
if context.is_saving_in_progress:
10386
return False
@@ -122,11 +105,7 @@ class PreemptionCheckpointingPolicy(SaveDecisionPolicy):
122105
"""Save a checkpoint when a preemption is detected."""
123106

124107
def should_save(
125-
self,
126-
step: checkpoint_info.CheckpointInfo,
127-
previous_steps: Sequence[checkpoint_info.CheckpointInfo],
128-
*,
129-
context: DecisionContext
108+
self, step, previous_steps, *, context: DecisionContext
130109
) -> bool:
131110
del step
132111
del previous_steps
@@ -137,11 +116,7 @@ class InitialSavePolicy(SaveDecisionPolicy):
137116
"""Checkpoint as soon as possible if no checkpoints already exist."""
138117

139118
def should_save(
140-
self,
141-
step: checkpoint_info.CheckpointInfo,
142-
previous_steps: Sequence[checkpoint_info.CheckpointInfo],
143-
*,
144-
context: DecisionContext
119+
self, step, previous_steps, *, context: DecisionContext
145120
) -> bool:
146121
del step
147122
del context
@@ -159,11 +134,7 @@ class AnySavePolicy(SaveDecisionPolicy):
159134
policies: Sequence[SaveDecisionPolicy]
160135

161136
def should_save(
162-
self,
163-
step: checkpoint_info.CheckpointInfo,
164-
previous_steps: Sequence[checkpoint_info.CheckpointInfo],
165-
*,
166-
context: DecisionContext
137+
self, step, previous_steps, *, context: DecisionContext
167138
) -> bool:
168139
return any(
169140
policy.should_save(step, previous_steps=previous_steps, context=context)

checkpoint/orbax/checkpoint/checkpoint_manager.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ class _ShouldSaveFnPolicy(save_decision_policy_lib.SaveDecisionPolicy):
169169

170170
def should_save(
171171
self,
172-
step: checkpoint_info.CheckpointInfo,
173-
previous_steps: Sequence[checkpoint_info.CheckpointInfo],
172+
step,
173+
previous_steps,
174174
*,
175175
context: save_decision_policy_lib.DecisionContext,
176176
) -> bool:

checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/multihost.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616

1717
import threading
1818
import time
19-
from typing import Collection
19+
from typing import Collection, Optional
2020
from absl import logging
2121
import jax
22+
from jax.experimental import multihost_utils
2223
from orbax.checkpoint.experimental.v1._src.synchronization import signaling_client
2324

2425
# Default timeout in seconds.
@@ -139,3 +140,10 @@ def process_index() -> int:
139140
# global_state.process_id. We rely on the latter to work with barriers over a
140141
# subset of processes.
141142
return jax._src.distributed.global_state.process_id # pylint: disable=protected-access
143+
144+
145+
def broadcast_one_to_all(in_tree, is_source: Optional[bool] = None):
146+
"""Broadcast data from a source host to all other hosts."""
147+
if is_source is None:
148+
is_source = process_index() == 0
149+
return multihost_utils.broadcast_one_to_all(in_tree, is_source=is_source)

checkpoint/orbax/checkpoint/experimental/v1/_src/training/BUILD

+22-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,28 @@ package(default_visibility = ["//visibility:public"])
33
py_library(
44
name = "save_decision_policies",
55
srcs = ["save_decision_policies.py"],
6-
deps = ["//checkpoint/orbax/checkpoint/_src/checkpoint_managers:save_decision_policy"],
6+
deps = [
7+
"//checkpoint/orbax/checkpoint/_src/checkpoint_managers:save_decision_policy",
8+
"//orbax/checkpoint/experimental/v1/_src/training/metadata:types",
9+
],
10+
)
11+
12+
py_library(
13+
name = "preservation_policy",
14+
srcs = ["preservation_policy.py"],
15+
deps = [
16+
"//orbax/checkpoint/_src/checkpoint_managers:preservation_policy",
17+
"//orbax/checkpoint/experimental/v1/_src/training/metadata:types",
18+
],
19+
)
20+
21+
py_test(
22+
name = "preservation_policy_test",
23+
srcs = ["preservation_policy_test.py"],
24+
deps = [
25+
":preservation_policy",
26+
"//orbax/checkpoint/experimental/v1/_src/training/metadata:types",
27+
],
728
)
829

930
py_library(

checkpoint/orbax/checkpoint/experimental/v1/_src/training/metadata/types.py

+13
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Metadata for `training.Checkpointer`."""
1616

1717
import dataclasses
18+
import datetime
1819
from orbax.checkpoint._src.metadata import checkpoint_info
1920
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
2021
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
@@ -70,6 +71,18 @@ def step(self) -> int:
7071
def metrics(self) -> tree_types.JsonType | None:
7172
return self._metrics
7273

74+
@metrics.setter
75+
def metrics(self, new_metrics):
76+
if new_metrics is not None and not isinstance(new_metrics, (list, dict)):
77+
raise ValueError("Metrics must be a JSON-serializable object.")
78+
self._metrics = new_metrics
79+
80+
@property
81+
def time(self) -> datetime.datetime:
82+
return datetime.datetime.fromtimestamp(
83+
self.commit_timestamp_nsecs / 1e9, tz=datetime.timezone.utc
84+
)
85+
7386

7487
@dataclasses.dataclass(frozen=True, kw_only=True)
7588
class RootMetadata:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2024 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Defines policies for when a checkpoint is preserved."""
16+
17+
from typing import Any, Dict, Protocol, Sequence
18+
from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib
19+
from orbax.checkpoint.experimental.v1._src.training.metadata import types
20+
21+
22+
NestedDict = Dict[str, Any]
23+
PyTree = Any
24+
25+
PreservationContext = preservation_policy_lib.PreservationContext
26+
LatestN = preservation_policy_lib.LatestN
27+
EveryNSeconds = preservation_policy_lib.EveryNSeconds
28+
EveryNSteps = preservation_policy_lib.EveryNSteps
29+
CustomSteps = preservation_policy_lib.CustomSteps
30+
AnyPreservationPolicy = preservation_policy_lib.AnyPreservationPolicy
31+
BestN = preservation_policy_lib.BestN
32+
33+
34+
class PreservationPolicy(Protocol):
35+
"""A policy that defines when checkpoints should be preserved."""
36+
37+
def should_preserve(
38+
self,
39+
checkpoints: Sequence[types.CheckpointMetadata],
40+
*,
41+
context: PreservationContext,
42+
) -> Sequence[bool]:
43+
"""Indicates which checkpoints should be preserved.."""
44+
...

0 commit comments

Comments
 (0)