init2winit/main.py (15 changes: 14 additions & 1 deletion)
@@ -153,6 +153,11 @@
     'optax_training_algorithm',
     'Name of the training algorithm to use.',
 )
+flags.DEFINE_boolean(
+    'compile_init_on_cpu',
+    False,
+    'Whether to compile the init function on CPU rather than on device.',
+)

 FLAGS = flags.FLAGS

@@ -208,6 +213,7 @@ def _run(
     callback_configs,
     external_checkpoint_path,
     training_algorithm_name,
+    compile_init_on_cpu,
 ):
   """Function that runs a Jax experiment. See flag definitions for args."""
   model_cls = models.get_model(model_name)
@@ -246,7 +252,13 @@ def _run(
   rng = jax.random.PRNGKey(rng_seed)

   # Build the loss_fn, metrics_bundle, and flax_module.
-  model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name)
+  model = model_cls(
+      merged_hps,
+      dataset_meta_data,
+      loss_name,
+      metrics_name,
+      compile_init_on_cpu=compile_init_on_cpu,
+  )
   trial_dir = os.path.join(experiment_dir, str(worker_id))
   meta_data_path = os.path.join(trial_dir, 'meta_data.json')
   meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
@@ -373,6 +385,7 @@ def main(unused_argv):
       callback_configs=callback_configs,
       external_checkpoint_path=FLAGS.external_checkpoint_path,
       training_algorithm_name=FLAGS.training_algorithm,
+      compile_init_on_cpu=FLAGS.compile_init_on_cpu,
   )
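With the flag registered, compiling init on CPU can be toggled from the command line. A hypothetical invocation (the experiment directory and any other flags are placeholders; absl boolean flags also accept the --nocompile_init_on_cpu negative form):

    python3 -m init2winit.main \
        --experiment_dir=/tmp/test_exp \
        --compile_init_on_cpu=true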
init2winit/main_config_flags.py (10 changes: 9 additions & 1 deletion)
@@ -130,6 +130,7 @@ def _run(
     external_checkpoint_path,
     training_algorithm_name,
     checkpoint_ttl,
+    compile_init_on_cpu,
 ):
   """Function that runs a Jax experiment. See flag definitions for args."""
   model_cls = models.get_model(model_name)
@@ -168,7 +169,13 @@ def _run(
   rng = jax.random.PRNGKey(rng_seed)

   # Build the loss_fn, metrics_bundle, and flax_module.
-  model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name)
+  model = model_cls(
+      merged_hps,
+      dataset_meta_data,
+      loss_name,
+      metrics_name,
+      compile_init_on_cpu=compile_init_on_cpu,
+  )
   trial_dir = os.path.join(experiment_dir, str(worker_id))
   meta_data_path = os.path.join(trial_dir, 'meta_data.json')
   meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
@@ -307,6 +314,7 @@ def main(unused_argv):
       external_checkpoint_path=config.external_checkpoint_path,
       training_algorithm_name=config.training_algorithm,
       checkpoint_ttl=config.ttl,
+      compile_init_on_cpu=config.compile_init_on_cpu,
   )
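In the config-flags variant the same knob rides on the config object rather than on an individual flag. A minimal sketch, assuming this launcher consumes an ml_collections-style ConfigDict (only compile_init_on_cpu is taken from the diff; the other field is illustrative):

    import ml_collections

    def get_config():
      config = ml_collections.ConfigDict()
      config.compile_init_on_cpu = True  # compile model init on host CPU
      config.experiment_dir = '/tmp/test_exp'  # illustrative placeholder
      return config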
init2winit/model_lib/base_model.py (21 changes: 18 additions & 3 deletions)
@@ -142,13 +142,21 @@ class BaseModel(object):
   https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply.
   """

-  def __init__(self, hps, dataset_meta_data, loss_name, metrics_name):
+  def __init__(
+      self,
+      hps,
+      dataset_meta_data,
+      loss_name,
+      metrics_name,
+      compile_init_on_cpu=False,
+  ):
     self.hps = hps
     self.dataset_meta_data = dataset_meta_data
     self._loss_name = loss_name
     self.loss_fn = losses.get_loss_fn(loss_name, hps)
     self.output_activation_fn = losses.get_output_activation_fn(loss_name)
     self.metrics_bundle = metrics.get_metrics(metrics_name, hps)
+    self._compile_init_on_cpu = compile_init_on_cpu
     self.flax_module = self.build_flax_module()

   def initialize(self, initializer, hps, rng, metrics_logger):
@@ -200,10 +208,17 @@ def initialize(self, initializer, hps, rng, metrics_logger):
     # construction.
     # We initialize model params on host to avoid memory issues.

+    jit_kwargs = {}
+    if self._compile_init_on_cpu:
+      jit_kwargs['backend'] = 'cpu'
+    logging.info(
+        'Compiling model init on %s.',
+        'cpu' if self._compile_init_on_cpu else 'device',
+    )
     start_time = time.time()
     model_init_fn = jax.jit(
-        functools.partial(self.flax_module.init, train=False),
-        backend='cpu')
+        functools.partial(self.flax_module.init, train=False), **jit_kwargs
+    )

     init_dict = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
                               *fake_input_batch)
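The substance of the change is the jit_kwargs construction: previously init was always jitted with backend='cpu'; now the CPU backend is requested only when the option is set, so with the default compile_init_on_cpu=False the init function now compiles on the default device rather than on CPU. A self-contained sketch of the same pattern, assuming a toy init_fn as a stand-in for flax_module.init:

    import jax
    import jax.numpy as jnp

    def init_fn(rng, x):
      # Stand-in for flax_module.init: materializes a parameter tree.
      return {'w': jax.random.normal(rng, (x.shape[-1], 128))}

    compile_init_on_cpu = True  # mirrors the new constructor kwarg
    jit_kwargs = {}
    if compile_init_on_cpu:
      jit_kwargs['backend'] = 'cpu'  # pin compilation and execution to host

    model_init_fn = jax.jit(init_fn, **jit_kwargs)
    params = model_init_fn(jax.random.PRNGKey(0), jnp.ones((1, 8)))
    # With backend='cpu' the parameter arrays live in host memory and can be
    # moved to accelerators later; without it they follow default placement.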
init2winit/model_lib/nqm.py (12 changes: 9 additions & 3 deletions)
@@ -177,15 +177,21 @@ class NQM(base_model.BaseModel):
   generates isotropic Gaussian noise.
   """

-  def __init__(self, hps, dataset_meta_data, loss_name, metrics_name):
+  def __init__(
+      self,
+      hps,
+      dataset_meta_data,
+      loss_name,
+      metrics_name,
+      compile_init_on_cpu=False,
+  ):
     del loss_name

     # This is ignored, but is needed to satisfy the initializer API.
     self.loss_fn = None
     self.metrics_name = metrics_name

     self.hps = hps
     self.dataset_meta_data = dataset_meta_data
+    self._compile_init_on_cpu = compile_init_on_cpu
     self.flax_module = self.build_flax_module()

   def evaluate_batch(self, params, batch_stats, batch):
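NQM overrides the constructor, so it must accept and record the new keyword itself; any other BaseModel subclass with a custom __init__ needs the same treatment. A hypothetical subclass, purely for illustration:

    class MyModel(base_model.BaseModel):
      """Hypothetical subclass showing how the kwarg is threaded through."""

      def __init__(self, hps, dataset_meta_data, loss_name, metrics_name,
                   compile_init_on_cpu=False):
        super().__init__(hps, dataset_meta_data, loss_name, metrics_name,
                         compile_init_on_cpu=compile_init_on_cpu)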