Skip to content

Commit ec722c5

Browse files
Merge pull request #677 from runame/fix-rng
Add missing `rng` for ImageNet-ViT and WMT
2 parents 46e624a + 1d0e8a9 commit ec722c5

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
2424
def initialized(self, key: spec.RandomState,
2525
model: nn.Module) -> spec.ModelInitState:
2626
input_shape = (1, 224, 224, 3)
27-
variables = jax.jit(model.init)({'params': key}, jnp.ones(input_shape))
27+
params_rng, dropout_rng = jax.random.split(key)
28+
variables = jax.jit(
29+
model.init)({'params': params_rng, 'dropout': dropout_rng},
30+
jnp.ones(input_shape))
2831
model_state, params = variables.pop('params')
2932
return params, model_state
3033

algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,11 @@ def init_model_fn(
234234
self._train_model = models.Transformer(model_config)
235235
eval_config = replace(model_config, deterministic=True)
236236
self._eval_model = models.Transformer(eval_config)
237-
initial_variables = jax.jit(self._eval_model.init)(
238-
rng,
239-
jnp.ones(input_shape, jnp.float32),
240-
jnp.ones(target_shape, jnp.float32))
237+
params_rng, dropout_rng = jax.random.split(rng)
238+
initial_variables = jax.jit(
239+
self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
240+
jnp.ones(input_shape, jnp.float32),
241+
jnp.ones(target_shape, jnp.float32))
241242

242243
initial_params = initial_variables['params']
243244
self._param_shapes = param_utils.jax_param_shapes(initial_params)

0 commit comments

Comments
 (0)