File tree Expand file tree Collapse file tree 2 files changed +5
-6
lines changed
algorithmic_efficiency/workloads
imagenet_vit/imagenet_jax Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Original file line number Diff line number Diff line change 2121# Make sure we inherit from the ViT base workload first.
2222class ImagenetVitWorkload (BaseImagenetVitWorkload , ImagenetResNetWorkload ):
2323
24- def initialized (self ,
25- key : spec .RandomState ,
24+ def initialized (self , key : spec .RandomState ,
2625 model : nn .Module ) -> spec .ModelInitState :
2726 input_shape = (1 , 224 , 224 , 3 )
2827 params_rng , dropout_rng = jax .random .split (key )
Original file line number Diff line number Diff line change @@ -235,10 +235,10 @@ def init_model_fn(
235235 eval_config = replace (model_config , deterministic = True )
236236 self ._eval_model = models .Transformer (eval_config )
237237 params_rng , dropout_rng = jax .random .split (rng )
238- initial_variables = jax .jit (self . _eval_model . init )(
239- {'params' : params_rng , 'dropout' : dropout_rng },
240- jnp .ones (input_shape , jnp .float32 ),
241- jnp .ones (target_shape , jnp .float32 ))
238+ initial_variables = jax .jit (
239+ self . _eval_model . init )( {'params' : params_rng , 'dropout' : dropout_rng },
240+ jnp .ones (input_shape , jnp .float32 ),
241+ jnp .ones (target_shape , jnp .float32 ))
242242
243243 initial_params = initial_variables ['params' ]
244244 self ._param_shapes = param_utils .jax_param_shapes (initial_params )
You can’t perform that action at this time.
0 commit comments