Skip to content

Commit 26eaa80

Browse files
priyakasimbeg authored and copybara-github committed
Logger fix
PiperOrigin-RevId: 869952160
1 parent d937d46 commit 26eaa80

6 files changed

Lines changed: 87 additions & 112 deletions

File tree

hessian/model_debugger.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def __init__(self,
288288
use_pmap=True,
289289
save_every=1,
290290
metrics_logger=None,
291+
pytree_metrics_logger=None,
291292
skip_flags=None,
292293
skip_groups=None):
293294
"""Used to inspect a models forward and backward pass.
@@ -305,6 +306,7 @@ def __init__(self,
305306
save_every == 0
306307
metrics_logger: utils.MetricsLogger object. If provided then all
307308
calculations will be saved to disk.
309+
pytree_metrics_logger: utils.PytreeMetricLogger object.
308310
skip_flags: A list of strings of modules to selectively turn off when
309311
doing the skip analysis on the backward pass.
310312
skip_groups: A list of registered functions (defined in partition_tree.py)
@@ -317,6 +319,7 @@ def __init__(self,
317319
' path must be specified when building metrics_logger')
318320
self._save_every = save_every
319321
self._metrics_logger = metrics_logger
322+
self._pytree_metrics_logger = pytree_metrics_logger
320323
self._use_pmap = use_pmap
321324
self.forward_pass = None
322325
self.grad_fn = None
@@ -340,8 +343,11 @@ def __init__(self,
340343
self._stored_metrics = {}
341344

342345
# In the case of preemption we want to restore prior metrics.
343-
if metrics_logger and metrics_logger.latest_pytree_checkpoint_step():
344-
self._stored_metrics = metrics_logger.load_latest_pytree()
346+
if (
347+
pytree_metrics_logger
348+
and pytree_metrics_logger.latest_pytree_checkpoint_step()
349+
):
350+
self._stored_metrics = pytree_metrics_logger.load_latest_pytree()
345351

346352
def _grab_statistics(self,
347353
step,
@@ -375,9 +381,9 @@ def _maybe_save_metrics(self, step):
375381
save_dict = self._stored_metrics.copy()
376382
if self._save_every:
377383
if step % self._save_every == 0:
378-
self._metrics_logger.write_pytree(save_dict, step=step)
384+
self._pytree_metrics_logger.write_pytree(save_dict, step=step)
379385
else:
380-
self._metrics_logger.write_pytree(save_dict, step=step)
386+
self._pytree_metrics_logger.write_pytree(save_dict, step=step)
381387

382388
@property
383389
def stored_metrics(self):

hessian/model_debugger_callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(
9999
self.dataset = dataset
100100
checkpoint_dir = os.path.join(train_dir, 'checkpoints')
101101
pytree_path = os.path.join(checkpoint_dir, 'debugger')
102-
logger = utils.MetricLogger(pytree_path=pytree_path)
102+
logger = utils.PytreeMetricLogger(pytree_path=pytree_path)
103103

104104
get_act_stats_fn = model_debugger.create_forward_pass_stats_fn(
105105
model.apply_on_batch,

hessian/test_model_debugger.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,16 +286,20 @@ def test_model_debugger_restore(self):
286286

287287
rep_variables = set_up_cnn()
288288
pytree_path = os.path.join(self.test_dir, 'metrics')
289-
metrics_logger = utils.MetricLogger(
290-
pytree_path=pytree_path, events_dir=self.test_dir)
289+
pytree_metrics_logger = utils.PytreeMetricLogger(
290+
pytree_path=pytree_path)
291+
metrics_logger = utils.MetricLogger(self.test_dir)
291292

292293
# Fake grad_fn for testing.
293294
def grad_fn(params, batch, rng):
294295
del params, batch, rng
295296
return rep_variables['params']
296297

297298
debugger = model_debugger.ModelDebugger(
298-
use_pmap=False, grad_fn=grad_fn, metrics_logger=metrics_logger)
299+
use_pmap=False,
300+
grad_fn=grad_fn,
301+
metrics_logger=metrics_logger,
302+
pytree_metrics_logger=pytree_metrics_logger)
299303

300304
# eval twice to test the concat
301305
extra_metrics = {'train_loss': 1.0}
@@ -319,8 +323,8 @@ def grad_fn(params, batch, rng):
319323
'train_loss',
320324
]
321325

322-
metrics_logger.wait_until_pytree_checkpoint_finished()
323-
loaded_metrics = metrics_logger.load_latest_pytree(None)
326+
pytree_metrics_logger.wait_until_pytree_checkpoint_finished()
327+
loaded_metrics = pytree_metrics_logger.load_latest_pytree(None)
324328

325329
self.assertEqual(set(expected_keys), set(metrics.keys()))
326330
expected_shape = ()
@@ -337,13 +341,15 @@ def grad_fn(params, batch, rng):
337341

338342
# Test restore of prior metrics.
339343
new_debugger = model_debugger.ModelDebugger(
340-
use_pmap=True, metrics_logger=metrics_logger)
344+
use_pmap=True,
345+
metrics_logger=metrics_logger,
346+
pytree_metrics_logger=pytree_metrics_logger)
341347
_ = new_debugger.full_eval(
342348
30,
343349
params=rep_variables['params'],
344350
grad=rep_variables['params'],
345351
extra_scalar_metrics=extra_metrics2)
346-
metrics_logger.wait_until_pytree_checkpoint_finished()
352+
pytree_metrics_logger.wait_until_pytree_checkpoint_finished()
347353
self.assertEqual(
348354
new_debugger.stored_metrics['param_norms_sql2']['Conv_0']
349355
['kernel'].shape, (3,))

init2winit/model_lib/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,7 @@ def test_vit_model_apply_sharding_overrides_failure(self):
10181018
str(error),
10191019
'Param shape (1536,) is not compatible with sharding '
10201020
'NamedSharding(mesh=Mesh(\'devices\': 8, axis_types=(Auto,)), '
1021-
'spec=PartitionSpec(None, \'devices\'), memory_kind=device)',
1021+
'spec=P(None, \'devices\'), memory_kind=device)',
10221022
)
10231023

10241024
good_overrides = {

init2winit/test_utils.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
1818
"""
1919

20-
import os
2120
import shutil
2221
import tempfile
2322

@@ -84,24 +83,6 @@ def test_run_in_parallel_on_failing_fn(self):
8483
with self.assertRaisesRegex(ValueError, 'I always fail.'):
8584
utils.run_in_parallel(_fn_that_always_fails, [dict(arg='hi')], 10)
8685

87-
def test_append_pytree(self):
88-
"""Test appending and loading pytrees."""
89-
pytrees = [{'a': i} for i in range(10)]
90-
pytree_path = os.path.join(self.test_dir, 'pytrees')
91-
logger = utils.MetricLogger(pytree_path=pytree_path)
92-
93-
for i, pytree in enumerate(pytrees):
94-
logger.append_pytree(pytree, step=i)
95-
96-
latest = logger.load_latest_pytree(
97-
target=None,
98-
)
99-
saved_pytrees = latest if latest else []
100-
101-
self.assertEqual(
102-
pytrees, [saved_pytrees[i] for i in range(len(saved_pytrees))]
103-
)
104-
10586
def test_array_append(self):
10687
"""Test appending to an array."""
10788
np.testing.assert_allclose(

init2winit/utils.py

Lines changed: 62 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,8 @@ def wrapper(*args, **kwargs):
106106
def set_up_loggers(train_dir, xm_work_unit=None):
107107
"""Creates a logger for eval metrics as well as initialization metrics."""
108108
csv_path = os.path.join(train_dir, 'measurements.csv')
109-
pytree_path = os.path.join(train_dir, 'ttl=180d', 'training_metrics')
110109
metrics_logger = MetricLogger(
111110
csv_path=csv_path,
112-
pytree_path=pytree_path,
113111
xm_work_unit=xm_work_unit,
114112
events_dir=train_dir)
115113

@@ -171,6 +169,68 @@ def add_log_file(logfile):
171169
absl_logger.addHandler(handler)
172170

173171

172+
class PytreeMetricLogger(object):
  """Used to log pytree metrics during training.

  This class is similar to MetricLogger, but it is designed to log pytree
  metrics. Pytrees are persisted through an Orbax CheckpointManager that is
  configured with max_to_keep=1.
  """

  def __init__(
      self,
      pytree_path,
  ):
    """Builds the Orbax checkpoint manager used to persist pytrees.

    Args:
      pytree_path: Directory handed to the checkpoint manager as the location
        for pytree checkpoints. The manager is constructed with create=True.
    """
    self._pytree_path = pytree_path

    orbax_file_options = ocp.checkpoint_manager.FileOptions(
        path_permission_mode=0o775,
        cns2_storage_options=ocp.options.Cns2StorageOptions(
            choose_store_cell=True,
        ),
    )
    self._orbax_checkpoint_manager = ocp.CheckpointManager(
        self._pytree_path,
        options=ocp.CheckpointManagerOptions(
            file_options=orbax_file_options,
            max_to_keep=1, create=True,
        ),
    )

  def write_pytree(self, pytree, step=0):
    """Record a serializable pytree to disk at the given step.

    Args:
      pytree: Any serializable pytree
      step: Integer. The global step.
    """
    # The pytree is stored under the 'pytree' key; load_latest_pytree unwraps
    # the same key when reading it back.
    state = dict(pytree=pytree)
    checkpoint.save_checkpoint(
        step,
        state=state,
        orbax_checkpoint_manager=self._orbax_checkpoint_manager,
    )

  def wait_until_pytree_checkpoint_finished(self):
    """Delegates to the checkpoint manager's wait_until_finished()."""
    self._orbax_checkpoint_manager.wait_until_finished()

  def latest_pytree_checkpoint_step(self):
    """Returns the checkpoint manager's latest_step()."""
    return self._orbax_checkpoint_manager.latest_step()

  def load_latest_pytree(self, target=None):
    """Load pytree from checkpoint.

    Args:
      target: Optional pytree forwarded to the checkpoint loader. A truthy
        target is wrapped as dict(pytree=target) to mirror the layout written
        by write_pytree; a falsy target is forwarded unchanged.

    Returns:
      The value stored under the 'pytree' key of the loaded checkpoint, or the
      caller's `target` (unwrapped) if the loader returns nothing.
    """
    # Keep the wrapper in its own local so the fallback below can return the
    # caller's target rather than the internal {'pytree': target} wrapper.
    wrapped_target = dict(pytree=target) if target else target
    logging.info('target: %s', wrapped_target)
    loaded_target = checkpoint.load_latest_checkpoint(
        target=wrapped_target,
        orbax_checkpoint_manager=self._orbax_checkpoint_manager,
    )
    if loaded_target:
      return loaded_target['pytree']
    return target
232+
233+
174234
# TODO(gdahl,gilmer): Use atomic writes to avoid file corruptions due to
175235
# preemptions.
176236
class MetricLogger(object):
@@ -183,7 +243,6 @@ class MetricLogger(object):
183243
def __init__(self,
184244
csv_path='',
185245
json_path='',
186-
pytree_path='',
187246
events_dir=None,
188247
**logger_kwargs):
189248
"""Create a recorder for metrics, as CSV or JSON.
@@ -192,7 +251,6 @@ def __init__(self,
192251
Args:
193252
csv_path: A filepath to a CSV file to append to.
194253
json_path: An optional filepath to a JSON file to append to.
195-
pytree_path: Where to save trees of numeric arrays.
196254
events_dir: Optional. If specified, save tfevents summaries to this
197255
directory.
198256
**logger_kwargs: Optional keyword arguments, whose only valid parameter
@@ -202,7 +260,6 @@ def __init__(self,
202260
self._measurements = {}
203261
self._csv_path = csv_path
204262
self._json_path = json_path
205-
self._pytree_path = pytree_path
206263
if logger_kwargs:
207264
if len(logger_kwargs.keys()) > 1 or 'xm_work_unit' not in logger_kwargs:
208265
raise ValueError(
@@ -216,33 +273,6 @@ def __init__(self,
216273
if events_dir:
217274
self._tb_metric_writer = metric_writers.create_default_writer(events_dir)
218275

219-
orbax_file_options = ocp.checkpoint_manager.FileOptions(
220-
path_permission_mode=0o775,
221-
cns2_storage_options=ocp.options.Cns2StorageOptions(
222-
choose_store_cell=True,
223-
),
224-
)
225-
self._orbax_checkpoint_manager = ocp.CheckpointManager(
226-
self._pytree_path,
227-
options=ocp.CheckpointManagerOptions(
228-
max_to_keep=1, create=True, file_options=orbax_file_options
229-
),
230-
)
231-
232-
def load_latest_pytree(self, target=None):
233-
"""Load pytree from checkpoint."""
234-
if target:
235-
target = dict(pytree=target)
236-
logging.info('target: %s', target)
237-
loaded_target = checkpoint.load_latest_checkpoint(
238-
target=target,
239-
orbax_checkpoint_manager=self._orbax_checkpoint_manager,
240-
)
241-
if loaded_target:
242-
return loaded_target['pytree']
243-
else:
244-
return target
245-
246276
def append_scalar_metrics(self, metrics):
247277
"""Record a dictionary of scalar metrics at a given step.
248278
@@ -289,54 +319,6 @@ def append_scalar_metrics(self, metrics):
289319
# size 512. We could only flush at the end of training to optimize this.
290320
self._tb_metric_writer.flush()
291321

292-
def write_pytree(self, pytree, step=0):
293-
"""Record a serializable pytree to disk at the given step.
294-
295-
Args:
296-
pytree: Any serializable pytree
297-
step: Integer. The global step.
298-
"""
299-
state = dict(pytree=pytree)
300-
checkpoint.save_checkpoint(
301-
step,
302-
state=state,
303-
orbax_checkpoint_manager=self._orbax_checkpoint_manager,
304-
)
305-
306-
def wait_until_pytree_checkpoint_finished(self):
307-
self._orbax_checkpoint_manager.wait_until_finished()
308-
309-
def latest_pytree_checkpoint_step(self):
310-
return self._orbax_checkpoint_manager.latest_step()
311-
312-
def append_pytree(self, pytree, step=0):
313-
"""Append and record a serializable pytree to disk.
314-
315-
The pytree will be saved to disk as a list of pytree objects. Everytime
316-
this function is called, it will load the previous saved state, append the
317-
next pytree to the list, then save the appended list.
318-
319-
Args:
320-
pytree: Any serializable pytree.
321-
step: Integer. The global step.
322-
"""
323-
# Read the latest (and only) checkpoint if it exists, then append the new
324-
# state to it before saving back to disk.
325-
try:
326-
old_state = self.load_latest_pytree(target=None)
327-
except ValueError:
328-
old_state = None
329-
# Because we pass target=None, checkpointing will return the raw state
330-
# dict, where 'pytree' is a dict with keys ['0', '1', ...] instead of a
331-
# list.
332-
if old_state:
333-
state_list = [old_state[i] for i in range(len(old_state))]
334-
else:
335-
state_list = []
336-
state_list.append(pytree)
337-
338-
self.write_pytree(state_list, step=step)
339-
340322
def append_json_object(self, json_obj):
341323
"""Append a json serializable object to the json file."""
342324

0 commit comments

Comments
 (0)