Skip to content

Commit 1d0e8a9

Browse files
committed
Fix yapf
1 parent e31f2ca commit 1d0e8a9

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
# Make sure we inherit from the ViT base workload first.
2222
class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
2323

24-
def initialized(self,
25-
key: spec.RandomState,
24+
def initialized(self, key: spec.RandomState,
2625
model: nn.Module) -> spec.ModelInitState:
2726
input_shape = (1, 224, 224, 3)
2827
params_rng, dropout_rng = jax.random.split(key)

algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,10 @@ def init_model_fn(
235235
eval_config = replace(model_config, deterministic=True)
236236
self._eval_model = models.Transformer(eval_config)
237237
params_rng, dropout_rng = jax.random.split(rng)
238-
initial_variables = jax.jit(self._eval_model.init)(
239-
{'params': params_rng, 'dropout': dropout_rng},
240-
jnp.ones(input_shape, jnp.float32),
241-
jnp.ones(target_shape, jnp.float32))
238+
initial_variables = jax.jit(
239+
self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
240+
jnp.ones(input_shape, jnp.float32),
241+
jnp.ones(target_shape, jnp.float32))
242242

243243
initial_params = initial_variables['params']
244244
self._param_shapes = param_utils.jax_param_shapes(initial_params)

0 commit comments

Comments
 (0)