Skip to content

Commit d41c715

Browse files
init2winit Team authored and copybara-github committed
internal
PiperOrigin-RevId: 872540150
1 parent 26eaa80 commit d41c715

4 files changed

Lines changed: 50 additions & 8 deletions

File tree

init2winit/main.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,11 @@
153153
'optax_training_algorithm',
154154
'Name of the training algorithm to use.',
155155
)
156+
flags.DEFINE_boolean(
157+
'compile_init_on_cpu',
158+
False,
159+
'Whether or not to compile the init function on CPU or on device.',
160+
)
156161

157162
FLAGS = flags.FLAGS
158163

@@ -208,6 +213,7 @@ def _run(
208213
callback_configs,
209214
external_checkpoint_path,
210215
training_algorithm_name,
216+
compile_init_on_cpu,
211217
):
212218
"""Function that runs a Jax experiment. See flag definitions for args."""
213219
model_cls = models.get_model(model_name)
@@ -246,7 +252,13 @@ def _run(
246252
rng = jax.random.PRNGKey(rng_seed)
247253

248254
# Build the loss_fn, metrics_bundle, and flax_module.
249-
model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name)
255+
model = model_cls(
256+
merged_hps,
257+
dataset_meta_data,
258+
loss_name,
259+
metrics_name,
260+
compile_init_on_cpu=compile_init_on_cpu,
261+
)
250262
trial_dir = os.path.join(experiment_dir, str(worker_id))
251263
meta_data_path = os.path.join(trial_dir, 'meta_data.json')
252264
meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
@@ -373,6 +385,7 @@ def main(unused_argv):
373385
callback_configs=callback_configs,
374386
external_checkpoint_path=FLAGS.external_checkpoint_path,
375387
training_algorithm_name=FLAGS.training_algorithm,
388+
compile_init_on_cpu=FLAGS.compile_init_on_cpu,
376389
)
377390

378391

init2winit/main_config_flags.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,7 @@ def _run(
130130
external_checkpoint_path,
131131
training_algorithm_name,
132132
checkpoint_ttl,
133+
compile_init_on_cpu,
133134
):
134135
"""Function that runs a Jax experiment. See flag definitions for args."""
135136
model_cls = models.get_model(model_name)
@@ -168,7 +169,13 @@ def _run(
168169
rng = jax.random.PRNGKey(rng_seed)
169170

170171
# Build the loss_fn, metrics_bundle, and flax_module.
171-
model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name)
172+
model = model_cls(
173+
merged_hps,
174+
dataset_meta_data,
175+
loss_name,
176+
metrics_name,
177+
compile_init_on_cpu=compile_init_on_cpu,
178+
)
172179
trial_dir = os.path.join(experiment_dir, str(worker_id))
173180
meta_data_path = os.path.join(trial_dir, 'meta_data.json')
174181
meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
@@ -307,6 +314,7 @@ def main(unused_argv):
307314
external_checkpoint_path=config.external_checkpoint_path,
308315
training_algorithm_name=config.training_algorithm,
309316
checkpoint_ttl=config.ttl,
317+
compile_init_on_cpu=config.compile_init_on_cpu,
310318
)
311319

312320

init2winit/model_lib/base_model.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -142,13 +142,21 @@ class BaseModel(object):
142142
https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply.
143143
"""
144144

145-
def __init__(self, hps, dataset_meta_data, loss_name, metrics_name):
145+
def __init__(
146+
self,
147+
hps,
148+
dataset_meta_data,
149+
loss_name,
150+
metrics_name,
151+
compile_init_on_cpu=False,
152+
):
146153
self.hps = hps
147154
self.dataset_meta_data = dataset_meta_data
148155
self._loss_name = loss_name
149156
self.loss_fn = losses.get_loss_fn(loss_name, hps)
150157
self.output_activation_fn = losses.get_output_activation_fn(loss_name)
151158
self.metrics_bundle = metrics.get_metrics(metrics_name, hps)
159+
self._compile_init_on_cpu = compile_init_on_cpu
152160
self.flax_module = self.build_flax_module()
153161

154162
def initialize(self, initializer, hps, rng, metrics_logger):
@@ -200,10 +208,17 @@ def initialize(self, initializer, hps, rng, metrics_logger):
200208
# construction.
201209
# We initialize model params on host to avoid memory issues.
202210

211+
jit_kwargs = {}
212+
if self._compile_init_on_cpu:
213+
jit_kwargs['backend'] = 'cpu'
214+
logging.info(
215+
'Compiling model init on %s.',
216+
'cpu' if self._compile_init_on_cpu else 'device',
217+
)
203218
start_time = time.time()
204219
model_init_fn = jax.jit(
205-
functools.partial(self.flax_module.init, train=False),
206-
backend='cpu')
220+
functools.partial(self.flax_module.init, train=False), **jit_kwargs
221+
)
207222

208223
init_dict = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
209224
*fake_input_batch)

init2winit/model_lib/nqm.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -177,15 +177,21 @@ class NQM(base_model.BaseModel):
177177
generates isotropic Gaussian noise.
178178
"""
179179

180-
def __init__(self, hps, dataset_meta_data, loss_name, metrics_name):
180+
def __init__(
181+
self,
182+
hps,
183+
dataset_meta_data,
184+
loss_name,
185+
metrics_name,
186+
compile_init_on_cpu=False,
187+
):
181188
del loss_name
182-
183189
# This is ignored, but is needed to satisfy the initializer API.
184190
self.loss_fn = None
185191
self.metrics_name = metrics_name
186-
187192
self.hps = hps
188193
self.dataset_meta_data = dataset_meta_data
194+
self._compile_init_on_cpu = compile_init_on_cpu
189195
self.flax_module = self.build_flax_module()
190196

191197
def evaluate_batch(self, params, batch_stats, batch):

0 commit comments

Comments
 (0)