@@ -142,13 +142,21 @@ class BaseModel(object):
142142 https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply.
143143 """
144144
145- def __init__ (self , hps , dataset_meta_data , loss_name , metrics_name ):
145+ def __init__ (
146+ self ,
147+ hps ,
148+ dataset_meta_data ,
149+ loss_name ,
150+ metrics_name ,
151+ compile_init_on_cpu = False ,
152+ ):
146153 self .hps = hps
147154 self .dataset_meta_data = dataset_meta_data
148155 self ._loss_name = loss_name
149156 self .loss_fn = losses .get_loss_fn (loss_name , hps )
150157 self .output_activation_fn = losses .get_output_activation_fn (loss_name )
151158 self .metrics_bundle = metrics .get_metrics (metrics_name , hps )
159+ self ._compile_init_on_cpu = compile_init_on_cpu
152160 self .flax_module = self .build_flax_module ()
153161
154162 def initialize (self , initializer , hps , rng , metrics_logger ):
@@ -200,10 +208,17 @@ def initialize(self, initializer, hps, rng, metrics_logger):
200208 # construction.
201209 # We initialize model params on host (only when compile_init_on_cpu is set) to avoid memory issues.
202210
211+ jit_kwargs = {}
212+ if self ._compile_init_on_cpu :
213+ jit_kwargs ['backend' ] = 'cpu'
214+ logging .info (
215+ 'Compiling model init on %s.' ,
216+ 'cpu' if self ._compile_init_on_cpu else 'device' ,
217+ )
203218 start_time = time .time ()
204219 model_init_fn = jax .jit (
205- functools .partial (self .flax_module .init , train = False ),
206- backend = 'cpu' )
220+ functools .partial (self .flax_module .init , train = False ), ** jit_kwargs
221+ )
207222
208223 init_dict = model_init_fn ({'params' : params_rng , 'dropout' : dropout_rng },
209224 * fake_input_batch )
0 commit comments