diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py
index a366391d..3a52b0b4 100644
--- a/init2winit/hyperparameters.py
+++ b/init2winit/hyperparameters.py
@@ -150,6 +150,8 @@ def build_hparams(model_name,
   if merged_dict.get('use_shallue_label_smoothing', False):
     num_classes = merged_dict['output_shape'][-1]
     merged_dict['label_smoothing'] *= num_classes / float(num_classes - 1)
+  if 'compile_init_on_cpu' not in merged_dict:
+    merged_dict['compile_init_on_cpu'] = False
 
   merged = config_dict.ConfigDict(merged_dict)
   merged.lock()
diff --git a/init2winit/model_lib/base_model.py b/init2winit/model_lib/base_model.py
index 35741660..63017f34 100644
--- a/init2winit/model_lib/base_model.py
+++ b/init2winit/model_lib/base_model.py
@@ -200,10 +200,18 @@ def initialize(self, initializer, hps, rng, metrics_logger):
     # construction.
 
     # We initialize model params on host to avoid memory issues.
+    compile_init_on_cpu = hps.get('compile_init_on_cpu', False)
+    jit_kwargs = {}
+    if compile_init_on_cpu:
+      jit_kwargs['backend'] = 'cpu'
+    logging.info(
+        'Compiling model init on %s.',
+        'cpu' if compile_init_on_cpu else 'device',
+    )
     start_time = time.time()
     model_init_fn = jax.jit(
-        functools.partial(self.flax_module.init, train=False),
-        backend='cpu')
+        functools.partial(self.flax_module.init, train=False), **jit_kwargs
+    )
     init_dict = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
                               *fake_input_batch)
 
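
The diff makes the `backend='cpu'` pin on the jitted init function opt-in via a new `compile_init_on_cpu` hparam (defaulting to False) instead of always forcing initialization onto the host. Below is a minimal standalone sketch of the same pattern, not the init2winit implementation itself: `make_init_fn` and `toy_init` are hypothetical names, and it assumes a JAX version where `jax.jit` still accepts the (now deprecated) `backend` argument.

```python
import functools

import jax
import jax.numpy as jnp


def make_init_fn(init_fn, compile_init_on_cpu):
  """Jit-compiles init_fn, pinning it to the CPU backend only when asked."""
  jit_kwargs = {}
  if compile_init_on_cpu:
    # Pinning the backend keeps initialization on the host, so the initial
    # parameters of a large model never allocate accelerator memory.
    jit_kwargs['backend'] = 'cpu'
  return jax.jit(init_fn, **jit_kwargs)


def toy_init(rng, x, train=False):
  # Stand-in for flax_module.init: returns a small parameter pytree.
  del train
  return {'w': jax.random.normal(rng, (x.shape[-1], 4))}


init_fn = make_init_fn(
    functools.partial(toy_init, train=False), compile_init_on_cpu=True)
params = init_fn(jax.random.PRNGKey(0), jnp.ones((2, 8)))
print(jax.tree_util.tree_map(lambda p: p.shape, params))  # {'w': (8, 4)}
```

With `compile_init_on_cpu=False` the same call simply jits on the default backend, matching the new behavior when the hparam is left unset.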