|
92 | 92 | FileOptions = options_lib.FileOptions
|
93 | 93 |
|
94 | 94 | DEFAULT_ITEM_NAME = 'default'
|
| 95 | +DESCRIPTOR_ITEM_NAME = 'descriptor' |
95 | 96 | METRIC_ITEM_NAME = 'metrics'
|
96 | 97 | METADATA_ITEM_NAME = 'metadata'
|
97 |
| -RESERVED_ITEM_NAMES = [] |
| 98 | +RESERVED_ITEM_NAMES = [DESCRIPTOR_ITEM_NAME, METRIC_ITEM_NAME] |
98 | 99 |
|
99 | 100 | _INIT_TIME = datetime.datetime.now(tz=datetime.timezone.utc)
|
100 | 101 |
|
@@ -318,7 +319,6 @@ class CheckpointManagerOptions:
|
318 | 319 | is the sole means of determining when a checkpoint should be saved. If not
|
319 | 320 | provided, these other options are used instead. Prefer to use this option
|
320 | 321 | over others.
|
321 |
| - prevent_write_metrics: False by default. If True, metrics will not be written. |
322 | 322 | """
|
323 | 323 |
|
324 | 324 | save_interval_steps: int = 1
|
@@ -351,7 +351,6 @@ class CheckpointManagerOptions:
|
351 | 351 | save_decision_policy: Optional[
|
352 | 352 | save_decision_policy_lib.SaveDecisionPolicy
|
353 | 353 | ] = None
|
354 |
| - prevent_write_metrics: bool = False |
355 | 354 |
|
356 | 355 | def __post_init__(self):
|
357 | 356 | step_name_format_single_host_load_and_broadcast = (
|
@@ -911,6 +910,7 @@ def _configure_checkpointer_legacy_init(
|
911 | 910 | f'Invalid type for `checkpointers`. Found {checkpointers}.'
|
912 | 911 | )
|
913 | 912 |
|
| 913 | + # if options.best_fn: |
914 | 914 | item_handlers[METRIC_ITEM_NAME] = self._metrics_handler
|
915 | 915 | if options.async_options is None:
|
916 | 916 | options.async_options = (
|
@@ -1366,10 +1366,8 @@ def save(
|
1366 | 1366 | 'Some provided items have prohibited reserved names:'
|
1367 | 1367 | f' {args_dict.keys()}. Reserved names: {RESERVED_ITEM_NAMES}.'
|
1368 | 1368 | )
|
1369 |
| - if ( |
1370 |
| - metrics is not None and self._track_best |
1371 |
| - ) and not self._options.prevent_write_metrics: |
1372 |
| - args_dict[METRIC_ITEM_NAME] = args_lib.JsonSave(metrics) |
| 1369 | + if metrics is not None and self._track_best: |
| 1370 | + args_dict['metrics'] = args_lib.JsonSave(metrics) |
1373 | 1371 | args = args_lib.Composite(**args_dict)
|
1374 | 1372 |
|
1375 | 1373 | save_directory = self._get_write_step_directory(step, self.directory)
|
@@ -1640,18 +1638,20 @@ def item_metadata(
|
1640 | 1638 |
|
1641 | 1639 | # TODO(b/370812224): Deprecate in favor of StepMetadata.metrics
|
1642 | 1640 | def metrics(self, step: int) -> Optional[PyTree]:
|
1643 |
| - try: |
1644 |
| - # Use handler directly, since this happens in a background thread and |
1645 |
| - # barriers cannot be used. This usage pattern is not |
1646 |
| - # recommended in other contexts. |
1647 |
| - metrics = self._metrics_handler.restore( |
1648 |
| - self._get_read_step_directory(step, self.directory) |
1649 |
| - / METRIC_ITEM_NAME |
1650 |
| - ) |
1651 |
| - return metrics |
1652 |
| - except FileNotFoundError as e: |
1653 |
| - logging.warning('Missing metrics for step %d', step) |
1654 |
| - logging.error(e) |
| 1641 | + if self._track_best: |
| 1642 | + try: |
| 1643 | + # Use handler directly, since this happens in a background thread and |
| 1644 | + # barriers cannot be used. This usage pattern is not |
| 1645 | + # recommended in other contexts. |
| 1646 | + return self._metrics_handler.restore( |
| 1647 | + self._get_read_step_directory(step, self.directory) |
| 1648 | + / METRIC_ITEM_NAME |
| 1649 | + ) |
| 1650 | + except FileNotFoundError as e: |
| 1651 | + logging.warning('Missing metrics for step %d', step) |
| 1652 | + logging.error(e) |
| 1653 | + return None |
| 1654 | + else: |
1655 | 1655 | return None
|
1656 | 1656 |
|
1657 | 1657 | @property
|
|
0 commit comments