@@ -106,10 +106,8 @@ def wrapper(*args, **kwargs):
106106def set_up_loggers (train_dir , xm_work_unit = None ):
107107 """Creates a logger for eval metrics as well as initialization metrics."""
108108 csv_path = os .path .join (train_dir , 'measurements.csv' )
109- pytree_path = os .path .join (train_dir , 'ttl=180d' , 'training_metrics' )
110109 metrics_logger = MetricLogger (
111110 csv_path = csv_path ,
112- pytree_path = pytree_path ,
113111 xm_work_unit = xm_work_unit ,
114112 events_dir = train_dir )
115113
@@ -171,6 +169,68 @@ def add_log_file(logfile):
171169 absl_logger .addHandler (handler )
172170
173171
172+ class PytreeMetricLogger (object ):
173+ """Used to log pytree metrics during training.
174+
175+ This class is used to log pytree metrics during training. It is similar to
176+ MetricLogger, but it is designed to log pytree metrics.
177+ """
178+
179+ def __init__ (
180+ self ,
181+ pytree_path ,
182+ ):
183+ self ._pytree_path = pytree_path
184+
185+ orbax_file_options = ocp .checkpoint_manager .FileOptions (
186+ path_permission_mode = 0o775 ,
187+ cns2_storage_options = ocp .options .Cns2StorageOptions (
188+ choose_store_cell = True ,
189+ ),
190+ )
191+ self ._orbax_checkpoint_manager = ocp .CheckpointManager (
192+ self ._pytree_path ,
193+ options = ocp .CheckpointManagerOptions (
194+ file_options = orbax_file_options ,
195+ max_to_keep = 1 , create = True ,
196+ ),
197+ )
198+
199+ def write_pytree (self , pytree , step = 0 ):
200+ """Record a serializable pytree to disk at the given step.
201+
202+ Args:
203+ pytree: Any serializable pytree
204+ step: Integer. The global step.
205+ """
206+ state = dict (pytree = pytree )
207+ checkpoint .save_checkpoint (
208+ step ,
209+ state = state ,
210+ orbax_checkpoint_manager = self ._orbax_checkpoint_manager ,
211+ )
212+
213+ def wait_until_pytree_checkpoint_finished (self ):
214+ self ._orbax_checkpoint_manager .wait_until_finished ()
215+
216+ def latest_pytree_checkpoint_step (self ):
217+ return self ._orbax_checkpoint_manager .latest_step ()
218+
219+ def load_latest_pytree (self , target = None ):
220+ """Load pytree from checkpoint."""
221+ if target :
222+ target = dict (pytree = target )
223+ logging .info ('target: %s' , target )
224+ loaded_target = checkpoint .load_latest_checkpoint (
225+ target = target ,
226+ orbax_checkpoint_manager = self ._orbax_checkpoint_manager ,
227+ )
228+ if loaded_target :
229+ return loaded_target ['pytree' ]
230+ else :
231+ return target
232+
233+
174234# TODO(gdahl,gilmer): Use atomic writes to avoid file corruptions due to
175235# preemptions.
176236class MetricLogger (object ):
@@ -183,7 +243,6 @@ class MetricLogger(object):
183243 def __init__ (self ,
184244 csv_path = '' ,
185245 json_path = '' ,
186- pytree_path = '' ,
187246 events_dir = None ,
188247 ** logger_kwargs ):
189248 """Create a recorder for metrics, as CSV or JSON.
@@ -192,7 +251,6 @@ def __init__(self,
192251 Args:
193252 csv_path: A filepath to a CSV file to append to.
194253 json_path: An optional filepath to a JSON file to append to.
195- pytree_path: Where to save trees of numeric arrays.
196254 events_dir: Optional. If specified, save tfevents summaries to this
197255 directory.
198256 **logger_kwargs: Optional keyword arguments, whose only valid parameter
@@ -202,7 +260,6 @@ def __init__(self,
202260 self ._measurements = {}
203261 self ._csv_path = csv_path
204262 self ._json_path = json_path
205- self ._pytree_path = pytree_path
206263 if logger_kwargs :
207264 if len (logger_kwargs .keys ()) > 1 or 'xm_work_unit' not in logger_kwargs :
208265 raise ValueError (
@@ -216,33 +273,6 @@ def __init__(self,
216273 if events_dir :
217274 self ._tb_metric_writer = metric_writers .create_default_writer (events_dir )
218275
219- orbax_file_options = ocp .checkpoint_manager .FileOptions (
220- path_permission_mode = 0o775 ,
221- cns2_storage_options = ocp .options .Cns2StorageOptions (
222- choose_store_cell = True ,
223- ),
224- )
225- self ._orbax_checkpoint_manager = ocp .CheckpointManager (
226- self ._pytree_path ,
227- options = ocp .CheckpointManagerOptions (
228- max_to_keep = 1 , create = True , file_options = orbax_file_options
229- ),
230- )
231-
232- def load_latest_pytree (self , target = None ):
233- """Load pytree from checkpoint."""
234- if target :
235- target = dict (pytree = target )
236- logging .info ('target: %s' , target )
237- loaded_target = checkpoint .load_latest_checkpoint (
238- target = target ,
239- orbax_checkpoint_manager = self ._orbax_checkpoint_manager ,
240- )
241- if loaded_target :
242- return loaded_target ['pytree' ]
243- else :
244- return target
245-
246276 def append_scalar_metrics (self , metrics ):
247277 """Record a dictionary of scalar metrics at a given step.
248278
@@ -289,54 +319,6 @@ def append_scalar_metrics(self, metrics):
289319 # size 512. We could only flush at the end of training to optimize this.
290320 self ._tb_metric_writer .flush ()
291321
292- def write_pytree (self , pytree , step = 0 ):
293- """Record a serializable pytree to disk at the given step.
294-
295- Args:
296- pytree: Any serializable pytree
297- step: Integer. The global step.
298- """
299- state = dict (pytree = pytree )
300- checkpoint .save_checkpoint (
301- step ,
302- state = state ,
303- orbax_checkpoint_manager = self ._orbax_checkpoint_manager ,
304- )
305-
306- def wait_until_pytree_checkpoint_finished (self ):
307- self ._orbax_checkpoint_manager .wait_until_finished ()
308-
309- def latest_pytree_checkpoint_step (self ):
310- return self ._orbax_checkpoint_manager .latest_step ()
311-
312- def append_pytree (self , pytree , step = 0 ):
313- """Append and record a serializable pytree to disk.
314-
315- The pytree will be saved to disk as a list of pytree objects. Everytime
316- this function is called, it will load the previous saved state, append the
317- next pytree to the list, then save the appended list.
318-
319- Args:
320- pytree: Any serializable pytree.
321- step: Integer. The global step.
322- """
323- # Read the latest (and only) checkpoint if it exists, then append the new
324- # state to it before saving back to disk.
325- try :
326- old_state = self .load_latest_pytree (target = None )
327- except ValueError :
328- old_state = None
329- # Because we pass target=None, checkpointing will return the raw state
330- # dict, where 'pytree' is a dict with keys ['0', '1', ...] instead of a
331- # list.
332- if old_state :
333- state_list = [old_state [i ] for i in range (len (old_state ))]
334- else :
335- state_list = []
336- state_list .append (pytree )
337-
338- self .write_pytree (state_list , step = step )
339-
340322 def append_json_object (self , json_obj ):
341323 """Append a json serializable object to the json file."""
342324
0 commit comments