We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 12ebb82 + e23a889 commit ab9e3fbCopy full SHA for ab9e3fb
algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py
@@ -33,8 +33,10 @@ def init_model_fn(
33
use_tanh=self.use_tanh,
34
use_layer_norm=self.use_layer_norm,
35
dropout_rate=dropout_rate)
36
-
37
- variables = jax.jit(self._model.init)({'params': rng}, fake_batch)
+ params_rng, dropout_rng = jax.random.split(rng)
+ variables = jax.jit(
38
+ self._model.init)({'params': params_rng, 'dropout': dropout_rng},
39
+ fake_batch)
40
params = variables['params']
41
self._param_shapes = param_utils.jax_param_shapes(params)
42
self._param_types = param_utils.jax_param_types(self._param_shapes)
0 commit comments