RFC: Update-based training loop refactor, scale loss and token count calculation rather than gradients #2543

Closed · wants to merge 2 commits
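
The heart of the change is where normalization happens. The current recipe accumulates an unnormalized, token-weighted loss across micro-batches and rescales the gradients by `dp_degree / num_tokens` just before the optimizer step; this refactor instead pre-fetches all micro-batches for an update step, all-reduces the token count up front, and scales each micro-batch loss before calling `backward()`. A minimal single-process sketch of the two schemes (placeholder names such as `micro_batches`, `loss_fn`, and `scale_grads`; no distributed collectives) looks like:

```python
# Illustrative sketch only -- placeholder objects, single rank, no TP/FSDP.

# Current scheme: accumulate unnormalized losses, rescale gradients once per update.
def update_scaling_grads(model, loss_fn, optimizer, micro_batches, scale_grads):
    num_tokens = 0
    for inputs, labels in micro_batches:
        n = (labels != loss_fn.ignore_index).sum()
        num_tokens += n
        # loss_fn returns a per-token mean, so multiply by n to undo the normalization
        (loss_fn(model(inputs), labels) * n).backward()
    scale_grads(model, 1.0 / num_tokens)  # normalize by total tokens in the update
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

# Proposed scheme: count tokens first, scale each micro-batch loss before backward.
def update_scaling_loss(model, loss_fn, optimizer, micro_batches):
    token_counts = [(labels != loss_fn.ignore_index).sum() for _, labels in micro_batches]
    total_tokens = sum(token_counts)
    for (inputs, labels), n in zip(micro_batches, token_counts):
        # gradients already carry the correct 1 / total_tokens weight
        (loss_fn(model(inputs), labels) * (n / total_tokens)).backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```
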
268 changes: 141 additions & 127 deletions recipes/full_finetune_distributed.py
@@ -767,147 +767,161 @@ def train(self) -> None:
for curr_epoch in range(self.epochs_run, self.total_epochs):
pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
self._dataloader.sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader):

dataloader_iter = iter(self._dataloader)

for update_step in range(self._steps_per_epoch):
# Start tracking CUDA memory for active steps for just the first epoch
if (
self._is_rank_zero
and curr_epoch == 0
and self.profiler_profile_memory
and idx == self.profiler_wait_steps + self.profiler_warmup_steps
# TODO: confirm we want these steps according to *update steps*, not *forward steps*
and update_step
== self.profiler_wait_steps + self.profiler_warmup_steps
and self._device.type == "cuda"
):
torch.cuda.memory._record_memory_history()
utils.batch_to_device(batch, self._device)

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

with self.activations_handling_ctx:
logits = self._model(**batch)
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we don't need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
(labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss

# For optimizer in backward, we need to normalize before calling backward
# This case and gradient accumulation are mutually exclusive
if self._optimizer_in_bwd:
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss * (self.dp_degree / num_tokens)

current_loss.backward()
# Optimizer step (if not fused in backward call)
if (idx + 1) % self._gradient_accumulation_steps == 0:
if not self._optimizer_in_bwd:
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, self.dp_degree / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
# If sharded, collect the DTensor here
if isinstance(grad_norm, DTensor):
grad_norm = grad_norm.full_tensor()
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1

# Step the learning rate scheduler
if self._lr_scheduler is not None:
self._lr_scheduler.step()

loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"

# pre-fetch examples and move to device
batches = []
for _ in range(self._gradient_accumulation_steps):
try:
batch = next(dataloader_iter)
except StopIteration:
break

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
(batch["labels"] != self._loss_fn.ignore_index)
.sum()
.to(self._device)
)
num_tokens += current_num_tokens
batches.append((batch, current_num_tokens.item()))

# Log per-step metrics
if (
self.global_step % self._log_every_n_steps == 0
and self._is_rank_zero
):
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": get_lr(
(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
),
"tokens_per_second_per_gpu": num_tokens
/ (time_per_step * self.world_size),
}
if self._log_peak_memory_stats:
log_dict.update(
training.get_memory_stats(device=self._device)
)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
if len(batches) == 0:
# dataloader is empty
break

torch.distributed.all_reduce(num_tokens)

# TODO: confirm this adjustment is correct; I believe workers in the same DP group will overcount
# num_tokens
num_tokens = num_tokens / self.parallel_dims.non_data_parallel_size

for batch, token_count in batches:
utils.batch_to_device(batch, self._device)
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

with self.activations_handling_ctx:
logits = self._model(**batch)
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we don't need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
(labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
# Loss is normalized by default, so we weight by the number of tokens in this batch divided
# by the total tokens in the update step
# TODO: confirm division by non data parallel size is correct
# Appears required due to output layer outputting Replicate
current_loss = self._loss_fn(logits, labels) * (
token_count
/ num_tokens
/ self.parallel_dims.non_data_parallel_size
)

# free logits otherwise it peaks backward memory
del logits

current_loss.backward()

running_loss += current_loss

if not self._optimizer_in_bwd:
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
# If sharded, collect the DTensor here
if isinstance(grad_norm, DTensor):
grad_norm = grad_norm.full_tensor()
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1

# Step the learning rate scheduler
if self._lr_scheduler is not None:
self._lr_scheduler.step()

torch.distributed.all_reduce(running_loss)
loss_to_log = running_loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)

# Log per-step metrics
if (
self.global_step % self._log_every_n_steps == 0
and self._is_rank_zero
):
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": get_lr(
(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
),
"tokens_per_second_per_gpu": num_tokens
/ (time_per_step * self.world_size),
"num_tokens": num_tokens.item(),
}
if self._log_peak_memory_stats:
log_dict.update(training.get_memory_stats(device=self._device))
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
)

# Reset running stats for the next step
running_loss = 0
num_tokens = 0
t0 = time.perf_counter()

# Stop tracking CUDA memory now that active steps are complete
if (
self._is_rank_zero
and curr_epoch == 0
and self.profiler_profile_memory
and idx
== self.profiler_wait_steps
+ self.profiler_warmup_steps
+ self.profiler_active_steps
and self._device.type == "cuda"
):
torch.cuda.memory._record_memory_history(enabled=None)

# Step profiler
# Note that this is called within gradient accumulation block, hence
# will include multiple forward / backward passes if gradient accumulation > 1
self._profiler.step()
# Reset running stats for the next step
running_loss = 0
num_tokens = 0
t0 = time.perf_counter()

# Stop tracking CUDA memory now that active steps are complete
if (
(idx + 1) // self._gradient_accumulation_steps
) == self.max_steps_per_epoch:
break
self._is_rank_zero
and curr_epoch == 0
and self.profiler_profile_memory
and update_step
== self.profiler_wait_steps
+ self.profiler_warmup_steps
+ self.profiler_active_steps
and self._device.type == "cuda"
):
torch.cuda.memory._record_memory_history(enabled=None)

# Step profiler
# Note that this is called within gradient accumulation block, hence
# will include multiple forward / backward passes if gradient accumulation > 1
self._profiler.step()

self.epochs_run += 1
self._checkpoint_client.save_checkpoint(
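
As a sanity check on the premise of the change above: gradients are linear in the loss, so scaling each micro-batch loss by `token_count / num_tokens` before `backward()` should produce the same gradients as accumulating unnormalized losses and dividing the gradients by the total token count afterwards. A toy single-process verification (made-up model and shapes, not recipe code) would look like:

```python
import torch

# Two copies of the same tiny model; shapes and "token counts" are made up.
torch.manual_seed(0)
model_a = torch.nn.Linear(4, 1)
model_b = torch.nn.Linear(4, 1)
model_b.load_state_dict(model_a.state_dict())

micro_batches = [(torch.randn(3, 4), torch.randn(3, 1)) for _ in range(2)]
token_counts = [3, 3]  # pretend every element is an unmasked token
total = sum(token_counts)

# Scheme A: accumulate unnormalized losses, then scale the gradients.
for (x, y), n in zip(micro_batches, token_counts):
    (torch.nn.functional.mse_loss(model_a(x), y) * n).backward()
for p in model_a.parameters():
    p.grad.mul_(1.0 / total)

# Scheme B: scale each micro-batch loss before backward.
for (x, y), n in zip(micro_batches, token_counts):
    (torch.nn.functional.mse_loss(model_b(x), y) * (n / total)).backward()

# The resulting gradients match up to floating-point rounding.
for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    assert torch.allclose(pa.grad, pb.grad)
```
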
6 changes: 6 additions & 0 deletions torchtune/training/_distributed.py
@@ -8,6 +8,7 @@
import logging
import os
from dataclasses import dataclass
from functools import cached_property
from itertools import chain
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

@@ -122,6 +123,11 @@ def dp_shard_enabled(self):
def tp_enabled(self):
return self.tp > 1

@cached_property
def non_data_parallel_size(self):
# update this as new parallelism strategies are added
return self.tp


def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
"""Helper function to convert sharding strategy strings to ShardingStrategy enum."""