Skip to content

Use the new preservation policy in CheckpointManager. #1880

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/path:utils",
"//orbax/checkpoint/_src:threading",
"//orbax/checkpoint/_src/checkpoint_managers:policy_checkpoint_info",
"//orbax/checkpoint/_src/checkpoint_managers:preservation_policy",
],
)

Expand Down Expand Up @@ -354,5 +355,6 @@ py_library(
":abstract_checkpoint_manager",
":checkpoint_manager",
"//checkpoint/orbax/checkpoint/_src/checkpoint_managers:save_decision_policy",
"//orbax/checkpoint/_src/checkpoint_managers:preservation_policy",
],
)
165 changes: 62 additions & 103 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from orbax.checkpoint import utils
from orbax.checkpoint._src import threading as threading_lib
from orbax.checkpoint._src.checkpoint_managers import policy_checkpoint_info
from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib
from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib
from orbax.checkpoint._src.checkpointers import abstract_checkpointer
from orbax.checkpoint._src.checkpointers import async_checkpointer
Expand Down Expand Up @@ -206,6 +207,40 @@ def _get_default_save_decision_policy(
return save_decision_policy_lib.AnySavePolicy(save_interval_policies)


def _get_default_preservation_policy(
    options: CheckpointManagerOptions,
) -> preservation_policy_lib.PreservationPolicy:
  """Builds a preservation policy from the legacy retention options.

  Policies are collected in a fixed order:
    1. `EveryNSteps(keep_period)` when `keep_period` is set.
    2. `EveryNSeconds(...)` when `keep_time_interval` is set.
    3. `BestN(...)` when `best_fn` is set, otherwise `LatestN(...)`; both
       receive `n=max_to_keep` (which may be None).

  `should_keep_fn` is not read here and is therefore not translated into a
  policy by this function.

  Args:
    options: The manager options supplying `keep_period`,
      `keep_time_interval`, `best_fn`, `best_mode` and `max_to_keep`.

  Returns:
    An `AnyPreservationPolicy` wrapping the collected policies (presumably a
    checkpoint is preserved when any member policy preserves it -- confirm
    against the preservation_policy module).
  """
  policies = []

  keep_period = options.keep_period
  if keep_period is not None:
    policies.append(preservation_policy_lib.EveryNSteps(keep_period))

  keep_time_interval = options.keep_time_interval
  if keep_time_interval is not None:
    # int() truncates, so any fractional part of the interval is dropped.
    interval_secs = int(keep_time_interval.total_seconds())
    policies.append(
        preservation_policy_lib.EveryNSeconds(interval_secs=interval_secs)
    )

  if options.best_fn is None:
    count_policy = preservation_policy_lib.LatestN(n=options.max_to_keep)
  else:
    count_policy = preservation_policy_lib.BestN(
        best_fn=options.best_fn,
        # `reverse` is True exactly when the best metric is the minimum.
        reverse=(options.best_mode == 'min'),
        n=options.max_to_keep,
    )
  policies.append(count_policy)

  return preservation_policy_lib.AnyPreservationPolicy(policies)


# TODO(b/268051457) Clean up when no longer depended upon by internal users.
def is_async_checkpointer(checkpointer: AbstractCheckpointer):
return isinstance(
Expand Down Expand Up @@ -319,6 +354,12 @@ class CheckpointManagerOptions:
is the sole means of determining when a checkpoint should be saved. If not
provided, these other options are used instead. Prefer to use this option
over others.
preservation_policy: An object used to determine which checkpoints to
preserve. If provided, overrides the other options dealing with this
subject, including `max_to_keep`, `keep_time_interval`, `keep_period`, and
`best_fn`. Note that `should_keep_fn`, if set, is still consulted and can
preserve a checkpoint that the policy would otherwise delete. If not
provided, a default policy is built from these other options. Prefer to
use this option over others.
prevent_write_metrics: False by default. If True, metrics will not be written.
"""

Expand Down Expand Up @@ -352,6 +393,9 @@ class CheckpointManagerOptions:
save_decision_policy: Optional[
save_decision_policy_lib.SaveDecisionPolicy
] = None
preservation_policy: Optional[
preservation_policy_lib.PreservationPolicy
] = None
prevent_write_metrics: bool = False

def __post_init__(self):
Expand Down Expand Up @@ -634,7 +678,10 @@ def __init__(
self._options.save_decision_policy
or _get_default_save_decision_policy(self._options)
)

self._preservation_policy = (
self._options.preservation_policy
or _get_default_preservation_policy(self._options)
)
if self._options.best_mode not in ['min', 'max']:
raise ValueError('`best_mode` must be one of: "min", "max"')

Expand Down Expand Up @@ -1705,22 +1752,6 @@ def build_checkpoint_info(step_metadata):
)
return checkpoint_infos

def _get_interval_preserved_checkpoints(
    self, checkpoints: checkpoint_info.CheckpointInfos
) -> List[CheckpointInfo]:
  """Selects checkpoints spaced at least `keep_time_interval` apart.

  The first checkpoint is always selected. Each later checkpoint is selected
  only if its `time` is at least `keep_time_interval` after the `time` of the
  most recently selected one.

  Args:
    checkpoints: The checkpoints to scan, in iteration order (presumably
      oldest first -- confirm against the caller).

  Returns:
    An empty list when `checkpoints` is empty; only the first checkpoint when
    `keep_time_interval` is None; otherwise the first checkpoint followed by
    every checkpoint that satisfies the spacing rule.
  """
  if checkpoints.empty():
    return []

  preserved = [checkpoints[0]]
  interval = self._options.keep_time_interval
  if interval is None:
    return preserved

  for candidate in checkpoints[1:]:
    # Spacing is measured from the last *preserved* checkpoint, not from the
    # immediately preceding one.
    threshold = preserved[-1].time + interval
    if candidate.time >= threshold:
      preserved.append(candidate)
  return preserved

def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]):
self._checkpoints.append(
CheckpointInfo(
Expand Down Expand Up @@ -1867,102 +1898,30 @@ def _cleanup_tmp_directories(self):

def _get_old_steps_to_remove(self) -> List[int]:
"""Returns checkpoints that should be deleted."""
# Must have set max_to_keep in order to remove any checkpoints.
if self._options.max_to_keep is None:
return []
# Not enough checkpoints accumulated to consider deletion.
if self._checkpoints.size() <= self._options.max_to_keep:
return []

# This isn't a duration but there isn't a general counter that we can use so
# we abuse a duration metric to count the number of steps examined.
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/old_steps_examined_count',
self._checkpoints.size(),
)

if self._track_best:
# Best steps (to keep) are at the end, after sorting.
(
checkpoints_without_metrics,
sorted_checkpoints,
) = self._sort_checkpoints_by_metrics(self._checkpoints)
else:
# checkpoints already sorted by ascending step
checkpoints_without_metrics = []
sorted_checkpoints = [info for info in self._checkpoints]

keep = int(self._options.max_to_keep)
if self._options.keep_checkpoints_without_metrics:
maybe_delete = (
sorted_checkpoints[:-keep] if keep > 0 else sorted_checkpoints
)
active_checkpoints = set(
checkpoints_without_metrics
+ (sorted_checkpoints[-keep:] if keep > 0 else [])
)
else:
all_checkpoints = checkpoints_without_metrics + sorted_checkpoints
maybe_delete = all_checkpoints[:-keep] if keep > 0 else sorted_checkpoints
active_checkpoints = set(all_checkpoints[-keep:] if keep > 0 else [])

interval_preserved_checkpoints = self._get_interval_preserved_checkpoints(
self._checkpoints
preservation_result = self._preservation_policy.should_preserve(
[info for info in self._checkpoints],
context=preservation_policy_lib.PreservationContext(),
)
kept_checkpoints = set()
for info in maybe_delete:
if (
self._options.keep_time_interval is not None
and interval_preserved_checkpoints
):
if info in interval_preserved_checkpoints:
logging.info(
'Preserving %s: (Reason: older falling on keep_time_interval).',
info,
)
kept_checkpoints.add(info)
continue
elif info.time >= (
interval_preserved_checkpoints[-1].time
+ self._options.keep_time_interval
result = []
for i in range(len(self._checkpoints)):
if not preservation_result[i]:
if (
self._options.should_keep_fn is not None
and self._options.should_keep_fn(self._checkpoints[i].step)
):
interval_preserved_checkpoints.append(info)
logging.info(
'Preserving %s: (Reason: latest falling on keep_time_interval).',
info,
'Preserving %s: (Reason: on should_keep_fn callback).',
self._checkpoints[i],
)
kept_checkpoints.add(info)
continue
if (
self._options.should_keep_fn is not None
and self._options.should_keep_fn(info.step)
):
logging.info(
'Preserving %s: (Reason: on should_keep_fn callback).', info
)
kept_checkpoints.add(info)
continue
if (
self._options.keep_period is not None
and info.step % self._options.keep_period == 0
):
logging.info(
'Preserving %s: (Reason: on keep_period=%s).',
info,
self._options.keep_period,
)
kept_checkpoints.add(info)
continue

kept_checkpoints.update(active_checkpoints)

steps_to_remove = []
for info in self._checkpoints:
if info not in kept_checkpoints:
reason = 'worse metric' if self._track_best else 'old checkpoint'
logging.info('Deleting %s: (Reason: %s).', info, reason)
steps_to_remove.append(info.step)
return steps_to_remove
else:
result.append(self._checkpoints[i].step)
return result

def _wait_for_checkpointers(self):
if is_async_checkpointer(self._checkpointer):
Expand Down
11 changes: 11 additions & 0 deletions checkpoint/orbax/checkpoint/checkpoint_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@
AnySavePolicy,
)

from orbax.checkpoint._src.checkpoint_managers import preservation_policy
from orbax.checkpoint._src.checkpoint_managers.preservation_policy import (
PreservationPolicy,
LatestN,
EveryNSeconds,
EveryNSteps,
CustomSteps,
AnyPreservationPolicy,
BestN,
)

from orbax.checkpoint.checkpoint_manager import CheckpointManagerOptions
from orbax.checkpoint.checkpoint_manager import CheckpointManager

Expand Down