diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py
index b1da41d3..248c54c1 100644
--- a/init2winit/trainer_lib/base_trainer.py
+++ b/init2winit/trainer_lib/base_trainer.py
@@ -447,7 +447,7 @@ def _eval(self, start_step, start_time, save=True):
     if self._eval_use_ema:
       eval_params = self.training_algorithm.get_ema_eval_params(
-          self._optimizer_state
+          self._optimizer_state, self._params
      )
    else:
      eval_params = self._params
diff --git a/init2winit/trainer_lib/training_algorithm.py b/init2winit/trainer_lib/training_algorithm.py
index a7e473d9..f0f57a22 100644
--- a/init2winit/trainer_lib/training_algorithm.py
+++ b/init2winit/trainer_lib/training_algorithm.py
@@ -313,11 +313,12 @@ def init_optimizer_state(
     return optax_optimizer_state

   # TODO(b/436634470): Consolidate this with the prepare_for_eval API
-  def get_ema_eval_params(self, optimizer_state):
+  def get_ema_eval_params(self, optimizer_state, params):
     """Extracts the exponential moving average (EMA) parameters from the optimizer state.

     Args:
       optimizer_state: The current state of the optimizer.
+      params: The current model parameters.

     Returns:
       The EMA parameters.
@@ -326,6 +327,7 @@ def get_ema_eval_params(self, optimizer_state):
       ValueError: If the EMA parameters cannot be extracted from the optimizer
         state.
     """
+    del params  # Unused
     if isinstance(optimizer_state, optax.InjectStatefulHyperparamsState):
       eval_params = optimizer_state.inner_state[0][0].ema
     elif isinstance(
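For illustration only (not part of the diff): a minimal sketch of how a caller uses the updated two-argument signature, mirroring the _eval change above. The helper name select_eval_params and its argument list are hypothetical.

# Hypothetical helper, not from the repository: demonstrates the new
# get_ema_eval_params(optimizer_state, params) call introduced by this diff,
# with the same fallback to the raw parameters that _eval uses when EMA
# evaluation is disabled. The base implementation ignores `params`
# (it does `del params  # Unused`), but subclasses may make use of it.
def select_eval_params(training_algorithm, optimizer_state, params, use_ema):
  if use_ema:
    return training_algorithm.get_ema_eval_params(optimizer_state, params)
  return params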