Skip to content

Removing _track_best from manager #1870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number | Diff line number | Diff line change
Expand Up @@ -146,9 +146,9 @@ def should_preserve(

@dataclasses.dataclass(kw_only=True)
class BestN(PreservationPolicy):
"""A policy that preserves the best checkpoints based on a best_fn."""
"""A policy that preserves the best checkpoints based on a get_metric_fn."""

best_fn: Callable[[PyTree], float]
get_metric_fn: Callable[[PyTree], float]
reverse: bool
n: int | None = None

Expand All @@ -174,7 +174,7 @@ def should_preserve(
]
indexed_checkpoints_with_metrics = sorted(
indexed_checkpoints_with_metrics,
key=lambda item: self.best_fn(item[1].metrics),
key=lambda item: self.get_metric_fn(item[1].metrics),
reverse=self.reverse,
)
preserve_indices = [
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_custom_steps_policy(self, steps, expected_preserved_steps):
)
def test_best_n_policy(self, n, loss, expected_preserved_steps):
policy = preservation_policy_lib.BestN(
best_fn=lambda metrics: metrics['loss'],
get_metric_fn=lambda metrics: metrics['loss'],
reverse=True,
n=n,
)
Expand All @@ -199,7 +199,7 @@ def test_joint_preservation_policy(self):
), # 0, 3, 6, 9
preservation_policy_lib.CustomSteps(steps=[0, 3]), # 0, 3
preservation_policy_lib.BestN(
best_fn=lambda metrics: metrics['loss'],
get_metric_fn=lambda metrics: metrics['loss'],
reverse=True,
n=2,
), # 1, 2, 3, 4, 5, 7, 9, 11
Expand Down
42 changes: 16 additions & 26 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -247,7 +247,7 @@ def _get_default_preservation_policy(
if options.best_fn is not None:
preservation_policies.append(
preservation_policy_lib.BestN(
best_fn=options.best_fn,
get_metric_fn=options.best_fn,
reverse=(options.best_mode == 'min'),
n=options.max_to_keep,
)
Expand Down Expand Up @@ -835,7 +835,7 @@ def __init__(
)

self._checkpoints = checkpoint_info.CheckpointInfos(
self._load_checkpoint_infos()
self._load_checkpoint_infos
)

self._metadata_dir = self.directory / METADATA_ITEM_NAME
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def all_steps(self, read: bool = False) -> Sequence[int]:
logging.warning(
'`read` option is deprecated. Use `reload` to read from disk.'
)
self._checkpoints.set(self._load_checkpoint_infos())
self._checkpoints.set(self._load_checkpoint_infos)
return [ckpt.step for ckpt in self._checkpoints]

def latest_step(self) -> Optional[int]:
Expand All @@ -1132,8 +1132,6 @@ def best_step(self) -> Optional[int]:
Returns:
A step (int) or None if no steps are present.
"""
if not self._track_best:
return self.latest_step()
if self._checkpoints.empty():
return None
_, sorted_checkpoints = self._sort_checkpoints_by_metrics(self._checkpoints)
Expand All @@ -1147,7 +1145,7 @@ def reload(self):
Resets internal cache of checkpoint steps, in case the directory managed
by this object has been updated externally.
"""
self._checkpoints.set(self._load_checkpoint_infos())
self._checkpoints.set(self._load_checkpoint_infos)

def reached_preemption(self, step: int) -> bool:
"""Returns True if a preemption sync point has been reached."""
Expand Down Expand Up @@ -1397,7 +1395,7 @@ def save(
items = {DEFAULT_ITEM_NAME: items}
save_kwargs = {DEFAULT_ITEM_NAME: save_kwargs}

if self._track_best and metrics is None:
if self._options.best_fn is not None and metrics is None:
logging.warning('Requested `tracked_metric`; did not provide metrics.')

if args is None:
Expand Down Expand Up @@ -1429,7 +1427,7 @@ def save(
'Some provided items have prohibited reserved names:'
f' {args_dict.keys()}. Reserved names: {RESERVED_ITEM_NAMES}.'
)
if metrics is not None and self._track_best:
if metrics is not None:
args_dict['metrics'] = args_lib.JsonSave(metrics)
args = args_lib.Composite(**args_dict)

Expand Down Expand Up @@ -1701,20 +1699,16 @@ def item_metadata(

# TODO(b/370812224): Deprecate in favor of StepMetadata.metrics
def metrics(self, step: int) -> Optional[PyTree]:
if self._track_best:
try:
# Use handler directly, since this happens in a background thread and
# barriers cannot be used. This usage pattern is not
# recommended in other contexts.
return self._metrics_handler.restore(
self._get_read_step_directory(step, self.directory)
/ METRIC_ITEM_NAME
)
except FileNotFoundError as e:
logging.warning('Missing metrics for step %d', step)
logging.error(e)
return None
else:
try:
# Use handler directly, since this happens in a background thread and
# barriers cannot be used. This usage pattern is not
# recommended in other contexts.
return self._metrics_handler.restore(
self._get_read_step_directory(step, self.directory) / METRIC_ITEM_NAME
)
except FileNotFoundError as e:
logging.warning('Missing metrics for step %d', step)
logging.error(e)
return None

@property
Expand All @@ -1725,10 +1719,6 @@ def _metrics_handler(self) -> CheckpointHandler:
)

@property
def _track_best(self):
"""Returns true if we should track the best checkpoints by given metric."""
return self._options.best_fn is not None

def _load_checkpoint_infos(self) -> List[CheckpointInfo]:
"""Loads a list of CheckpointInfo for existing checkpoints.

Expand Down