diff --git a/init2winit/main.py b/init2winit/main.py index 69fc74ec..64dcfd2b 100644 --- a/init2winit/main.py +++ b/init2winit/main.py @@ -153,6 +153,11 @@ 'optax_training_algorithm', 'Name of the training algorithm to use.', ) +flags.DEFINE_boolean( + 'compile_init_on_cpu', + False, + 'Whether to compile the init function on CPU (True) or on device (False).', +) FLAGS = flags.FLAGS @@ -208,6 +213,7 @@ def _run( callback_configs, external_checkpoint_path, training_algorithm_name, + compile_init_on_cpu, ): """Function that runs a Jax experiment. See flag definitions for args.""" model_cls = models.get_model(model_name) @@ -246,7 +252,13 @@ def _run( rng = jax.random.PRNGKey(rng_seed) # Build the loss_fn, metrics_bundle, and flax_module. - model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name) + model = model_cls( + merged_hps, + dataset_meta_data, + loss_name, + metrics_name, + compile_init_on_cpu=compile_init_on_cpu, + ) trial_dir = os.path.join(experiment_dir, str(worker_id)) meta_data_path = os.path.join(trial_dir, 'meta_data.json') meta_data = {'worker_id': worker_id, 'status': 'incomplete'} @@ -373,6 +385,7 @@ def main(unused_argv): callback_configs=callback_configs, external_checkpoint_path=FLAGS.external_checkpoint_path, training_algorithm_name=FLAGS.training_algorithm, + compile_init_on_cpu=FLAGS.compile_init_on_cpu, ) diff --git a/init2winit/main_config_flags.py b/init2winit/main_config_flags.py index ef991754..f0fe83e4 100644 --- a/init2winit/main_config_flags.py +++ b/init2winit/main_config_flags.py @@ -130,6 +130,7 @@ def _run( external_checkpoint_path, training_algorithm_name, checkpoint_ttl, + compile_init_on_cpu, ): """Function that runs a Jax experiment. See flag definitions for args.""" model_cls = models.get_model(model_name) @@ -168,7 +169,13 @@ def _run( rng = jax.random.PRNGKey(rng_seed) # Build the loss_fn, metrics_bundle, and flax_module. 
- model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name) + model = model_cls( + merged_hps, + dataset_meta_data, + loss_name, + metrics_name, + compile_init_on_cpu=compile_init_on_cpu, + ) trial_dir = os.path.join(experiment_dir, str(worker_id)) meta_data_path = os.path.join(trial_dir, 'meta_data.json') meta_data = {'worker_id': worker_id, 'status': 'incomplete'} @@ -307,6 +314,7 @@ def main(unused_argv): external_checkpoint_path=config.external_checkpoint_path, training_algorithm_name=config.training_algorithm, checkpoint_ttl=config.ttl, + compile_init_on_cpu=config.compile_init_on_cpu, ) diff --git a/init2winit/model_lib/base_model.py b/init2winit/model_lib/base_model.py index 35741660..a19a07b2 100644 --- a/init2winit/model_lib/base_model.py +++ b/init2winit/model_lib/base_model.py @@ -142,13 +142,21 @@ class BaseModel(object): https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply. """ - def __init__(self, hps, dataset_meta_data, loss_name, metrics_name): + def __init__( + self, + hps, + dataset_meta_data, + loss_name, + metrics_name, + compile_init_on_cpu=False, + ): self.hps = hps self.dataset_meta_data = dataset_meta_data self._loss_name = loss_name self.loss_fn = losses.get_loss_fn(loss_name, hps) self.output_activation_fn = losses.get_output_activation_fn(loss_name) self.metrics_bundle = metrics.get_metrics(metrics_name, hps) + self._compile_init_on_cpu = compile_init_on_cpu self.flax_module = self.build_flax_module() def initialize(self, initializer, hps, rng, metrics_logger): @@ -200,10 +208,17 @@ def initialize(self, initializer, hps, rng, metrics_logger): # construction. - # We initialize model params on host to avoid memory issues. + # When compile_init_on_cpu is set, initialize model params on host to avoid device memory issues. 
+ jit_kwargs = {} + if self._compile_init_on_cpu: + jit_kwargs['backend'] = 'cpu' + logging.info( + 'Compiling model init on %s.', + 'cpu' if self._compile_init_on_cpu else 'device', + ) start_time = time.time() model_init_fn = jax.jit( - functools.partial(self.flax_module.init, train=False), - backend='cpu') + functools.partial(self.flax_module.init, train=False), **jit_kwargs + ) init_dict = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch) diff --git a/init2winit/model_lib/nqm.py b/init2winit/model_lib/nqm.py index 9c51ffe4..818bddf7 100644 --- a/init2winit/model_lib/nqm.py +++ b/init2winit/model_lib/nqm.py @@ -177,15 +177,21 @@ class NQM(base_model.BaseModel): generates isotropic Gaussian noise. """ - def __init__(self, hps, dataset_meta_data, loss_name, metrics_name): + def __init__( + self, + hps, + dataset_meta_data, + loss_name, + metrics_name, + compile_init_on_cpu=False, + ): del loss_name - # This is ignored, but is needed to satisfy the initializer API. self.loss_fn = None self.metrics_name = metrics_name - self.hps = hps self.dataset_meta_data = dataset_meta_data + self._compile_init_on_cpu = compile_init_on_cpu self.flax_module = self.build_flax_module() def evaluate_batch(self, params, batch_stats, batch):