Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions init2winit/trainer_lib/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def train(self):
batch = next(train_iter)
batch = self.finalize_batch_fn(batch)
if self._global_step == 0:
batch_size_pytree = trainer_utils.get_batch_size(batch)
batch_size_pytree = self.get_batch_size(batch)
if any(
bsz != self._hps.batch_size
for bsz in jax.tree.leaves(batch_size_pytree)
Expand Down Expand Up @@ -673,10 +673,10 @@ def train(self):
@abc.abstractmethod
def update(self, batch, rng, metrics_update_fn, metrics_state, training_cost):
"""Single step of the training loop.

Note this method is responsible for updating the private _global_step
attribute of the Trainer.

Args:
batch: the per-device batch of data to process.
rng: the RNG used for calling the model. `step` and `local_device_index`
Expand Down Expand Up @@ -723,3 +723,6 @@ def finalize_batch_fn(self, batch):

def fetch_learning_rate(self, optimizer_state):
  """Extract the current learning rate from `optimizer_state`.

  Thin delegation hook so subclasses can override how the learning rate
  is read out of the optimizer state.

  Args:
    optimizer_state: the optimizer state pytree to read the rate from.

  Returns:
    Whatever `trainer_utils.fetch_learning_rate` returns for this state.
  """
  current_lr = trainer_utils.fetch_learning_rate(optimizer_state)
  return current_lr

def get_batch_size(self, batch):
  """Compute the per-leaf batch sizes of `batch`.

  Thin delegation hook so subclasses can override batch-size inspection
  (e.g. for non-standard batch pytree layouts).

  Args:
    batch: a batch of data, as produced by the input pipeline.

  Returns:
    Whatever `trainer_utils.get_batch_size` returns for this batch
    (presumably a pytree of leaf batch sizes — confirm in trainer_utils).
  """
  sizes = trainer_utils.get_batch_size(batch)
  return sizes
1 change: 1 addition & 0 deletions init2winit/trainer_lib/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def evaluate(
metrics = metrics.merge(computed_metrics)

metrics = jax.device_get(process_allgather(metrics, tiled=True))
metrics = jax.tree_util.tree_map(lambda x: x[0] if x.ndim > 0 else x, metrics)
# For data splits with no data (e.g. Imagenet no test set) no values
# will appear for that split.
if metrics is not None:
Expand Down
10 changes: 5 additions & 5 deletions init2winit/trainer_lib/training_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def update_params(
@abc.abstractmethod
def init_optimizer_state(
self,
workload=None,
model=None,
params=None,
model_state=None,
hyperparameters=None,
Expand All @@ -164,7 +164,7 @@ def init_optimizer_state(
"""Initializes the optimizer state.

Args:
workload: The workload being trained.
model: The model being trained.
params: The initial model parameters.
model_state: The initial state of the model.
hyperparameters: The hyperparameters for the training.
Expand Down Expand Up @@ -273,7 +273,7 @@ def update_params(

def init_optimizer_state(
self,
workload=None,
model=None,
params=None,
model_state=None,
hyperparameters=None,
Expand All @@ -282,7 +282,7 @@ def init_optimizer_state(
"""Initializes the optimizer state.

Args:
workload: The workload being trained.
model: The model being trained.
params: The initial model parameters.
model_state: The initial state of the model.
hyperparameters: The hyperparameters for the training.
Expand All @@ -291,7 +291,7 @@ def init_optimizer_state(
Returns:
Optimizer state: Pytree of optimizer state.
"""
del workload, model_state, hyperparameters, rng # Unused
del model, model_state, hyperparameters, rng # Unused
stretch_factor = 1
if self.hps.get('total_accumulated_batch_size') is not None:
stretch_factor = (
Expand Down