19
19
import typing
20
20
from typing import Container , Protocol , Sequence
21
21
from orbax .checkpoint import options as options_lib
22
- from orbax .checkpoint ._src .metadata import checkpoint_info
23
22
from orbax .checkpoint ._src .multihost import multihost
24
23
25
24
@@ -41,11 +40,7 @@ class SaveDecisionPolicy(Protocol):
41
40
"""
42
41
43
42
def should_save (
44
- self ,
45
- step : checkpoint_info .CheckpointInfo ,
46
- previous_steps : Sequence [checkpoint_info .CheckpointInfo ],
47
- * ,
48
- context : DecisionContext
43
+ self , step , previous_steps , * , context : DecisionContext
49
44
) -> bool :
50
45
...
51
46
@@ -57,11 +52,7 @@ class FixedIntervalPolicy(SaveDecisionPolicy):
57
52
interval : int
58
53
59
54
def should_save (
60
- self ,
61
- step : checkpoint_info .CheckpointInfo ,
62
- previous_steps : Sequence [checkpoint_info .CheckpointInfo ],
63
- * ,
64
- context : DecisionContext
55
+ self , step , previous_steps , * , context : DecisionContext
65
56
) -> bool :
66
57
del previous_steps
67
58
del context
@@ -75,11 +66,7 @@ class SpecificStepsPolicy(SaveDecisionPolicy):
75
66
steps : Container [int ]
76
67
77
68
def should_save (
78
- self ,
79
- step : checkpoint_info .CheckpointInfo ,
80
- previous_steps : Sequence [checkpoint_info .CheckpointInfo ],
81
- * ,
82
- context : DecisionContext
69
+ self , step , previous_steps , * , context : DecisionContext
83
70
) -> bool :
84
71
del previous_steps
85
72
del context
@@ -93,11 +80,7 @@ class ContinuousCheckpointingPolicy(SaveDecisionPolicy):
93
80
minimum_interval_secs : int | None = None
94
81
95
82
def should_save (
96
- self ,
97
- step : checkpoint_info .CheckpointInfo ,
98
- previous_steps : Sequence [checkpoint_info .CheckpointInfo ],
99
- * ,
100
- context : DecisionContext
83
+ self , step , previous_steps , * , context : DecisionContext
101
84
) -> bool :
102
85
if context .is_saving_in_progress :
103
86
return False
@@ -122,11 +105,7 @@ class PreemptionCheckpointingPolicy(SaveDecisionPolicy):
122
105
"""Save a checkpoint when a preemption is detected."""
123
106
124
107
def should_save (
125
- self ,
126
- step : checkpoint_info .CheckpointInfo ,
127
- previous_steps : Sequence [checkpoint_info .CheckpointInfo ],
128
- * ,
129
- context : DecisionContext
108
+ self , step , previous_steps , * , context : DecisionContext
130
109
) -> bool :
131
110
del step
132
111
del previous_steps
@@ -137,11 +116,7 @@ class InitialSavePolicy(SaveDecisionPolicy):
137
116
"""Checkpoint as soon as possible if no checkpoints already exist."""
138
117
139
118
def should_save (
140
- self ,
141
- step : checkpoint_info .CheckpointInfo ,
142
- previous_steps : Sequence [checkpoint_info .CheckpointInfo ],
143
- * ,
144
- context : DecisionContext
119
+ self , step , previous_steps , * , context : DecisionContext
145
120
) -> bool :
146
121
del step
147
122
del context
@@ -159,11 +134,7 @@ class AnySavePolicy(SaveDecisionPolicy):
159
134
policies : Sequence [SaveDecisionPolicy ]
160
135
161
136
def should_save (
162
- self ,
163
- step : checkpoint_info .CheckpointInfo ,
164
- previous_steps : Sequence [checkpoint_info .CheckpointInfo ],
165
- * ,
166
- context : DecisionContext
137
+ self , step , previous_steps , * , context : DecisionContext
167
138
) -> bool :
168
139
return any (
169
140
policy .should_save (step , previous_steps = previous_steps , context = context )
0 commit comments