Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion init2winit/trainer_lib/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def _eval(self, start_step, start_time, save=True):

if self._eval_use_ema:
eval_params = self.training_algorithm.get_ema_eval_params(
self._optimizer_state
self._optimizer_state, self._params
)
else:
eval_params = self._params
Expand Down
4 changes: 3 additions & 1 deletion init2winit/trainer_lib/training_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,12 @@ def init_optimizer_state(
return optax_optimizer_state

# TODO(b/436634470): Consolidate this with the prepare_for_eval API
def get_ema_eval_params(self, optimizer_state):
def get_ema_eval_params(self, optimizer_state, params):
"""Extracts the exponential moving average (EMA) parameters from the optimizer state.

Args:
optimizer_state: The current state of the optimizer.
params: The current model parameters.

Returns:
The EMA parameters.
Expand All @@ -326,6 +327,7 @@ def get_ema_eval_params(self, optimizer_state):
ValueError: If the EMA parameters cannot be extracted from the optimizer
state.
"""
del params # Unused
if isinstance(optimizer_state, optax.InjectStatefulHyperparamsState):
eval_params = optimizer_state.inner_state[0][0].ema
elif isinstance(
Expand Down