From cdc71bdc5b9d988eaddb07f12e23aa74982f6778 Mon Sep 17 00:00:00 2001
From: Nathan Azrak
Date: Tue, 1 Apr 2025 12:44:44 +0000
Subject: [PATCH 1/2] Perform all scaling on loss and token count calculation, remove separate grad scaling

Signed-off-by: Nathan Azrak
---
 recipes/full_finetune_distributed.py | 267 ++++++++++++++-------------
 torchtune/training/_distributed.py   |   6 +
 2 files changed, 146 insertions(+), 127 deletions(-)

diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index aa8b596ef6..c5c27a278d 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -767,147 +767,160 @@ def train(self) -> None:
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
             self._dataloader.sampler.set_epoch(curr_epoch)
-            for idx, batch in enumerate(self._dataloader):
+
+            dataloader_iter = iter(self._dataloader)
+
+            for update_step in range(self._steps_per_epoch):
                 # Start tracking CUDA memory for active steps for just the first epoch
                 if (
                     self._is_rank_zero
                     and curr_epoch == 0
                     and self.profiler_profile_memory
-                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+                    # TODO: confirm we want these steps according to *update steps*, not *forward steps*
+                    and update_step
+                    == self.profiler_wait_steps + self.profiler_warmup_steps
                     and self._device.type == "cuda"
                 ):
                     torch.cuda.memory._record_memory_history()
 
-                utils.batch_to_device(batch, self._device)
-
-                # Calculate the number of unmasked tokens in the current batch
-                # and increment the total number of tokens seen in the step
-                current_num_tokens = (
-                    batch["labels"] != self._loss_fn.ignore_index
-                ).sum()
-                num_tokens += current_num_tokens
-
-                # Shape [b, s], needed for the loss not the model
-                labels = batch.pop("labels")
-
-                with self.activations_handling_ctx:
-                    logits = self._model(**batch)
-                # Shift labels to compute loss
-                # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
-                # But this way we dont need to slice the logits. We just add an ignore index to labels.
-                labels = torch.hstack(
-                    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
-                )
-                if not isinstance(logits, list):
-                    labels = labels.reshape(-1)
-                    logits = logits.reshape(-1, logits.size(-1))
-
-                # Compute loss
-                # Loss is normalized by default so we multiply by the number of tokens
-                # This way we can normalize by the total number of tokens if we're accumulating gradients
-                current_loss = self._loss_fn(logits, labels) * current_num_tokens
-
-                # free logits otherwise it peaks backward memory
-                del logits
-
-                running_loss += current_loss
-
-                # For optimizer in backward, we need to normalize before calling backward
-                # This case and gradient accumulation are mutually exclusive
-                if self._optimizer_in_bwd:
-                    torch.distributed.all_reduce(num_tokens)
-                    torch.distributed.all_reduce(running_loss)
-                    current_loss = current_loss * (self.dp_degree / num_tokens)
-
-                current_loss.backward()
-                # Optimizer step (if not fused in backward call)
-                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                    if not self._optimizer_in_bwd:
-                        # Get total number of tokens across all ranks to normalize gradients
-                        torch.distributed.all_reduce(num_tokens)
-                        # This will ensure that the logged loss matches what we're optimizing
-                        torch.distributed.all_reduce(running_loss)
-                        # Manually scale the gradients from unnormalized loss by total # of tokens
-                        training.scale_grads(self._model, self.dp_degree / num_tokens)
-                        if self._clip_grad_norm is not None:
-                            grad_norm = torch.nn.utils.clip_grad_norm_(
-                                self._model.parameters(),
-                                max_norm=float(self._clip_grad_norm),
-                            )
-                            # If sharded, collect the DTensor here
-                            if isinstance(grad_norm, DTensor):
-                                grad_norm = grad_norm.full_tensor()
-                        self._optimizer.step()
-                        self._optimizer.zero_grad(set_to_none=True)
-
-                    # Update the number of steps when the weights are updated
-                    self.global_step += 1
-
-                    # Step the learning rate scheduler
-                    if self._lr_scheduler is not None:
-                        self._lr_scheduler.step()
-
-                    loss_to_log = running_loss.item() / num_tokens
-                    pbar.update(1)
-                    pbar.set_description(
-                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+
+                # pre-fetch examples and move to device
+                batches = []
+                for _ in range(self._gradient_accumulation_steps):
+                    try:
+                        batch = next(dataloader_iter)
+                    except StopIteration:
+                        break
+
+                    utils.batch_to_device(batch, self._device)
+
+                    # Calculate the number of unmasked tokens in the current batch
+                    # and increment the total number of tokens seen in the step
+                    current_num_tokens = (
+                        batch["labels"] != self._loss_fn.ignore_index
+                    ).sum()
+                    num_tokens += current_num_tokens
+                    batches.append((batch, current_num_tokens.item()))
+
+                if len(batches) == 0:
+                    # dataloader is empty
+                    break
+
+                torch.distributed.all_reduce(num_tokens)
+
+                # TODO: confirm this adjustment is correct, i believe workers in same DP group will overcount
+                # num_tokens
+                num_tokens = num_tokens / self.parallel_dims.non_data_parallel_size
+
+                for batch, token_count in batches:
+                    # Shape [b, s], needed for the loss not the model
+                    labels = batch.pop("labels")
+
+                    with self.activations_handling_ctx:
+                        logits = self._model(**batch)
+                    # Shift labels to compute loss
+                    # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
+                    # But this way we dont need to slice the logits. We just add an ignore index to labels.
+                    labels = torch.hstack(
+                        (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+                    )
+                    if not isinstance(logits, list):
+                        labels = labels.reshape(-1)
+                        logits = logits.reshape(-1, logits.size(-1))
+
+                    # Compute loss
+                    # Loss is normalized by default so we weigh by the size of the tokens in this batch divided
+                    # by total tokens in update step
+                    # TODO: confirm division by non data parallel size is correct
+                    # Appears required due to output layer outputting Replicate
+                    current_loss = self._loss_fn(logits, labels) * (
+                        token_count
+                        / num_tokens
+                        / self.parallel_dims.non_data_parallel_size
                     )
-                    # Log per-step metrics
-                    if (
-                        self.global_step % self._log_every_n_steps == 0
-                        and self._is_rank_zero
-                    ):
-                        time_per_step = time.perf_counter() - t0
-                        log_dict = {
-                            "loss": loss_to_log,
-                            "lr": get_lr(
-                                (
-                                    self._optimizer
-                                    if not self._optimizer_in_bwd
-                                    else self._optim_ckpt_wrapper
-                                ),
-                            ),
-                            "tokens_per_second_per_gpu": num_tokens
-                            / (time_per_step * self.world_size),
-                        }
-                        if self._log_peak_memory_stats:
-                            log_dict.update(
-                                training.get_memory_stats(device=self._device)
-                            )
-                        if self._clip_grad_norm is not None:
-                            log_dict.update({"grad_norm": grad_norm})
-                        self._metric_logger.log_dict(
-                            log_dict,
-                            step=self.global_step,
+                    # free logits otherwise it peaks backward memory
+                    del logits
+
+                    current_loss.backward()
+
+                    running_loss += current_loss
+
+                if not self._optimizer_in_bwd:
+                    if self._clip_grad_norm is not None:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            self._model.parameters(),
+                            max_norm=float(self._clip_grad_norm),
                         )
+                        # If sharded, collect the DTensor here
+                        if isinstance(grad_norm, DTensor):
+                            grad_norm = grad_norm.full_tensor()
+                    self._optimizer.step()
+                    self._optimizer.zero_grad(set_to_none=True)
+
+                # Update the number of steps when the weights are updated
+                self.global_step += 1
+
+                # Step the learning rate scheduler
+                if self._lr_scheduler is not None:
+                    self._lr_scheduler.step()
+
+                torch.distributed.all_reduce(running_loss)
+                loss_to_log = running_loss.item()
+                pbar.update(1)
+                pbar.set_description(
+                    f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                )
+
+                # Log per-step metrics
+                if (
+                    self.global_step % self._log_every_n_steps == 0
+                    and self._is_rank_zero
+                ):
+                    time_per_step = time.perf_counter() - t0
+                    log_dict = {
+                        "loss": loss_to_log,
+                        "lr": get_lr(
+                            (
+                                self._optimizer
+                                if not self._optimizer_in_bwd
+                                else self._optim_ckpt_wrapper
+                            ),
+                        ),
+                        "tokens_per_second_per_gpu": num_tokens
+                        / (time_per_step * self.world_size),
+                        "num_tokens": num_tokens.item(),
+                    }
+                    if self._log_peak_memory_stats:
+                        log_dict.update(training.get_memory_stats(device=self._device))
+                    if self._clip_grad_norm is not None:
+                        log_dict.update({"grad_norm": grad_norm})
+                    self._metric_logger.log_dict(
+                        log_dict,
+                        step=self.global_step,
                     )
-                    # Reset running stats for the next step
-                    running_loss = 0
-                    num_tokens = 0
-                    t0 = time.perf_counter()
-
-                # Stop tracking CUDA memory now that active steps are complete
-                if (
-                    self._is_rank_zero
-                    and curr_epoch == 0
-                    and self.profiler_profile_memory
-                    and idx
-                    == self.profiler_wait_steps
-                    + self.profiler_warmup_steps
-                    + self.profiler_active_steps
-                    and self._device.type == "cuda"
-                ):
-                    torch.cuda.memory._record_memory_history(enabled=None)
-
-                # Step profiler
-                # Note that this is called within gradient accumulation block, hence
-                # will include multiple forward / backward passes if gradient accumulation > 1
-                self._profiler.step()
+                # Reset running stats for the next step
+                running_loss = 0
+                num_tokens = 0
+                t0 = time.perf_counter()
+
+                # Stop tracking CUDA memory now that active steps are complete
                 if (
-                    (idx + 1) // self._gradient_accumulation_steps
-                ) == self.max_steps_per_epoch:
-                    break
+                    self._is_rank_zero
+                    and curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and update_step
+                    == self.profiler_wait_steps
+                    + self.profiler_warmup_steps
+                    + self.profiler_active_steps
+                    and self._device.type == "cuda"
+                ):
+                    torch.cuda.memory._record_memory_history(enabled=None)
+
+                # Step profiler
+                # Note that this is called within gradient accumulation block, hence
+                # will include multiple forward / backward passes if gradient accumulation > 1
+                self._profiler.step()
 
             self.epochs_run += 1
             self._checkpoint_client.save_checkpoint(
diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index 2346c58010..5fd601cf64 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -8,6 +8,7 @@
 import logging
 import os
 from dataclasses import dataclass
+from functools import cached_property
 from itertools import chain
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple
 
@@ -122,6 +123,11 @@ def dp_shard_enabled(self):
     def tp_enabled(self):
         return self.tp > 1
 
+    @cached_property
+    def non_data_parallel_size(self):
+        # update this as new parallelism strategies are added
+        return self.tp
+
 
 def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
     """Helper function to convert sharding strategy strings to ShardingStrategy enum."""

From 204f36e147ff7545c1737dbaac255ac445f2e638 Mon Sep 17 00:00:00 2001
From: Nathan Azrak
Date: Wed, 2 Apr 2025 00:49:30 +0000
Subject: [PATCH 2/2] Count tokens per batch on CPU, only move current batch to GPU

Signed-off-by: Nathan Azrak
---
 recipes/full_finetune_distributed.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index c5c27a278d..416c0573ed 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -791,13 +791,13 @@ def train(self) -> None:
                     except StopIteration:
                         break
 
-                    utils.batch_to_device(batch, self._device)
-
                     # Calculate the number of unmasked tokens in the current batch
                     # and increment the total number of tokens seen in the step
                     current_num_tokens = (
-                        batch["labels"] != self._loss_fn.ignore_index
-                    ).sum()
+                        (batch["labels"] != self._loss_fn.ignore_index)
+                        .sum()
+                        .to(self._device)
+                    )
                     num_tokens += current_num_tokens
                     batches.append((batch, current_num_tokens.item()))
 
@@ -812,6 +812,7 @@ def train(self) -> None:
                 num_tokens = num_tokens / self.parallel_dims.non_data_parallel_size
 
                 for batch, token_count in batches:
+                    utils.batch_to_device(batch, self._device)
                     # Shape [b, s], needed for the loss not the model
                     labels = batch.pop("labels")