diff --git a/checkpoint/orbax/checkpoint/BUILD b/checkpoint/orbax/checkpoint/BUILD index c5e3ad25f..8c5a7056b 100644 --- a/checkpoint/orbax/checkpoint/BUILD +++ b/checkpoint/orbax/checkpoint/BUILD @@ -140,6 +140,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src/path:utils", "//orbax/checkpoint/_src:threading", "//orbax/checkpoint/_src/checkpoint_managers:policy_checkpoint_info", + "//orbax/checkpoint/_src/checkpoint_managers:preservation_policy", ], ) @@ -354,5 +355,6 @@ py_library( ":abstract_checkpoint_manager", ":checkpoint_manager", "//checkpoint/orbax/checkpoint/_src/checkpoint_managers:save_decision_policy", + "//orbax/checkpoint/_src/checkpoint_managers:preservation_policy", ], ) diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 0119d3a66..2d7c33fbe 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -36,6 +36,7 @@ from orbax.checkpoint import utils from orbax.checkpoint._src import threading as threading_lib from orbax.checkpoint._src.checkpoint_managers import policy_checkpoint_info +from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib from orbax.checkpoint._src.checkpointers import abstract_checkpointer from orbax.checkpoint._src.checkpointers import async_checkpointer @@ -206,6 +207,40 @@ def _get_default_save_decision_policy( return save_decision_policy_lib.AnySavePolicy(save_interval_policies) +def _get_default_preservation_policy( + options: CheckpointManagerOptions, +) -> preservation_policy_lib.PreservationPolicy: + """Returns a default preservation policy.""" + # Must have set max_to_keep in order to remove any checkpoints. + preservation_policies = [] + if options.keep_period is not None: + preservation_policies.append( + preservation_policy_lib.EveryNSteps(options.keep_period) + ) + if options.keep_time_interval is not None: + total_seconds = int(options.keep_time_interval.total_seconds()) + preservation_policies.append( + preservation_policy_lib.EveryNSeconds( + interval_secs=total_seconds + ) + ) + if options.best_fn is not None: + preservation_policies.append( + preservation_policy_lib.BestN( + best_fn=options.best_fn, + reverse=(options.best_mode == 'min'), + n=options.max_to_keep, + ) + ) + else: + preservation_policies.append( + preservation_policy_lib.LatestN(n=options.max_to_keep) + ) + return preservation_policy_lib.AnyPreservationPolicy( + preservation_policies + ) + + # TODO(b/268051457) Clean up when no longer depended upon by internal users. def is_async_checkpointer(checkpointer: AbstractCheckpointer): return isinstance( @@ -319,6 +354,12 @@ class CheckpointManagerOptions: is the sole means of determining when a checkpoint should be saved. If not provided, these other options are used instead. Prefer to use this option over others. + preservation_policy: An object used to determine which checkpoints to + preserve. If provided, overrides any other options dealing with this + subject, including `max_to_keep`, `keep_time_interval`, `keep_period`, and + `should_keep_fn`, `best_fn`, and is the sole means of determining which + checkpoints to preserve. If not provided, these other options are used + instead. Prefer to use this option over others. prevent_write_metrics: False by default. If True, metrics will not be written. """ @@ -352,6 +393,9 @@ class CheckpointManagerOptions: save_decision_policy: Optional[ save_decision_policy_lib.SaveDecisionPolicy ] = None + preservation_policy: Optional[ + preservation_policy_lib.PreservationPolicy + ] = None prevent_write_metrics: bool = False def __post_init__(self): @@ -634,7 +678,10 @@ def __init__( self._options.save_decision_policy or _get_default_save_decision_policy(self._options) ) - + self._preservation_policy = ( + self._options.preservation_policy + or _get_default_preservation_policy(self._options) + ) if self._options.best_mode not in ['min', 'max']: raise ValueError('`best_mode` must be one of: "min", "max"') @@ -1705,22 +1752,6 @@ def build_checkpoint_info(step_metadata): ) return checkpoint_infos - def _get_interval_preserved_checkpoints( - self, checkpoints: checkpoint_info.CheckpointInfos - ) -> List[CheckpointInfo]: - """Gets which checkpoints should be kept based on keep_time_interval.""" - if checkpoints.empty(): - return [] - interval_preserved_checkpoints = [checkpoints[0]] - if self._options.keep_time_interval is not None: - for info in checkpoints[1:]: - if info.time >= ( - interval_preserved_checkpoints[-1].time - + self._options.keep_time_interval - ): - interval_preserved_checkpoints.append(info) - return interval_preserved_checkpoints - def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]): self._checkpoints.append( CheckpointInfo( @@ -1867,102 +1898,30 @@ def _cleanup_tmp_directories(self): def _get_old_steps_to_remove(self) -> List[int]: """Returns checkpoints that should be deleted.""" - # Must have set max_to_keep in order to remove any checkpoints. - if self._options.max_to_keep is None: - return [] - # Not enough checkpoints accumulated to consider deletion. - if self._checkpoints.size() <= self._options.max_to_keep: - return [] - # This isn't a duration but there isn't a general counter that we can use so # we abuse a duration metric to count the number of steps examined. jax.monitoring.record_event_duration_secs( '/jax/checkpoint/write/old_steps_examined_count', self._checkpoints.size(), ) - - if self._track_best: - # Best steps (to keep) are at the end, after sorting. - ( - checkpoints_without_metrics, - sorted_checkpoints, - ) = self._sort_checkpoints_by_metrics(self._checkpoints) - else: - # checkpoints already sorted by ascending step - checkpoints_without_metrics = [] - sorted_checkpoints = [info for info in self._checkpoints] - - keep = int(self._options.max_to_keep) - if self._options.keep_checkpoints_without_metrics: - maybe_delete = ( - sorted_checkpoints[:-keep] if keep > 0 else sorted_checkpoints - ) - active_checkpoints = set( - checkpoints_without_metrics - + (sorted_checkpoints[-keep:] if keep > 0 else []) - ) - else: - all_checkpoints = checkpoints_without_metrics + sorted_checkpoints - maybe_delete = all_checkpoints[:-keep] if keep > 0 else sorted_checkpoints - active_checkpoints = set(all_checkpoints[-keep:] if keep > 0 else []) - - interval_preserved_checkpoints = self._get_interval_preserved_checkpoints( - self._checkpoints + preservation_result = self._preservation_policy.should_preserve( + [info for info in self._checkpoints], + context=preservation_policy_lib.PreservationContext(), ) - kept_checkpoints = set() - for info in maybe_delete: - if ( - self._options.keep_time_interval is not None - and interval_preserved_checkpoints - ): - if info in interval_preserved_checkpoints: - logging.info( - 'Preserving %s: (Reason: older falling on keep_time_interval).', - info, - ) - kept_checkpoints.add(info) - continue - elif info.time >= ( - interval_preserved_checkpoints[-1].time - + self._options.keep_time_interval + result = [] + for i in range(len(self._checkpoints)): + if not preservation_result[i]: + if ( + self._options.should_keep_fn is not None + and self._options.should_keep_fn(self._checkpoints[i].step) ): - interval_preserved_checkpoints.append(info) logging.info( - 'Preserving %s: (Reason: latest falling on keep_time_interval).', - info, + 'Preserving %s: (Reason: on should_keep_fn callback).', + self._checkpoints[i], ) - kept_checkpoints.add(info) - continue - if ( - self._options.should_keep_fn is not None - and self._options.should_keep_fn(info.step) - ): - logging.info( - 'Preserving %s: (Reason: on should_keep_fn callback).', info - ) - kept_checkpoints.add(info) - continue - if ( - self._options.keep_period is not None - and info.step % self._options.keep_period == 0 - ): - logging.info( - 'Preserving %s: (Reason: on keep_period=%s).', - info, - self._options.keep_period, - ) - kept_checkpoints.add(info) - continue - - kept_checkpoints.update(active_checkpoints) - - steps_to_remove = [] - for info in self._checkpoints: - if info not in kept_checkpoints: - reason = 'worse metric' if self._track_best else 'old checkpoint' - logging.info('Deleting %s: (Reason: %s).', info, reason) - steps_to_remove.append(info.step) - return steps_to_remove + else: + result.append(self._checkpoints[i].step) + return result def _wait_for_checkpointers(self): if is_async_checkpointer(self._checkpointer): diff --git a/checkpoint/orbax/checkpoint/checkpoint_managers.py b/checkpoint/orbax/checkpoint/checkpoint_managers.py index 2c987140e..92b5e96a1 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_managers.py +++ b/checkpoint/orbax/checkpoint/checkpoint_managers.py @@ -26,6 +26,17 @@ AnySavePolicy, ) +from orbax.checkpoint._src.checkpoint_managers import preservation_policy +from orbax.checkpoint._src.checkpoint_managers.preservation_policy import ( + PreservationPolicy, + LatestN, + EveryNSeconds, + EveryNSteps, + CustomSteps, + AnyPreservationPolicy, + BestN, +) + from orbax.checkpoint.checkpoint_manager import CheckpointManagerOptions from orbax.checkpoint.checkpoint_manager import CheckpointManager