Skip to content

Commit ab9e3fb

Browse files
Merge pull request #672 from runame/fix-rng
Use separate `rng` for dropout in FastMRI `model_init`
2 parents 12ebb82 + e23a889 commit ab9e3fb

File tree

1 file changed

+4
-2
lines changed
  • algorithmic_efficiency/workloads/fastmri/fastmri_jax

1 file changed

+4
-2
lines changed

algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ def init_model_fn(
3333
use_tanh=self.use_tanh,
3434
use_layer_norm=self.use_layer_norm,
3535
dropout_rate=dropout_rate)
36-
37-
variables = jax.jit(self._model.init)({'params': rng}, fake_batch)
36+
params_rng, dropout_rng = jax.random.split(rng)
37+
variables = jax.jit(
38+
self._model.init)({'params': params_rng, 'dropout': dropout_rng},
39+
fake_batch)
3840
params = variables['params']
3941
self._param_shapes = param_utils.jax_param_shapes(params)
4042
self._param_types = param_utils.jax_param_types(self._param_shapes)

0 commit comments

Comments
 (0)