diff --git a/README.md b/README.md
index f1bfc5d..b037c01 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,13 @@ pip install flash-attn --no-build-isolation
 3. Play with the source in `train.py`
 
 ```
-python train.py
+torchrun --nproc_per_node=1 train.py
+```
+
+With multiple GPUs:
+
+```
+torchrun --nproc_per_node=8 train.py
 ```
 
 ### Inspiration
diff --git a/ckpt_utils.py b/ckpt_utils.py
new file mode 100644
index 0000000..3fa8189
--- /dev/null
+++ b/ckpt_utils.py
@@ -0,0 +1,34 @@
+from torch.distributed.checkpoint.state_dict import (
+    set_optimizer_state_dict,
+    set_model_state_dict,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+)
+import torch.distributed.checkpoint as dcp
+
+
+def save_checkpoint(model, optimizer, path):
+    """Save model and optimizer state using distributed checkpoint"""
+    model_state = get_model_state_dict(model=model)
+    optimizer_state = get_optimizer_state_dict(model=model, optimizers=optimizer)
+
+    state_dict = {"model": model_state, "optimizer": optimizer_state}
+
+    # dcp.save is a collective: every rank must call it
+    dcp.save(state_dict=state_dict, storage_writer=dcp.FileSystemWriter(path))
+
+
+def load_checkpoint(model, optimizer, path):
+    """Load model and optimizer state using distributed checkpoint"""
+    dcp_state_dict = {
+        "model": get_model_state_dict(model=model),
+        "optimizer": get_optimizer_state_dict(model=model, optimizers=optimizer),
+    }
+
+    # dcp.load fills dcp_state_dict in place; it is collective as well
+    dcp.load(dcp_state_dict, storage_reader=dcp.FileSystemReader(path))
+
+    set_model_state_dict(model=model, model_state_dict=dcp_state_dict["model"])
+    set_optimizer_state_dict(
+        model=model, optimizers=optimizer, optim_state_dict=dcp_state_dict["optimizer"]
+    )
diff --git a/train.py b/train.py
index c4a6ad9..6239130 100644
--- a/train.py
+++ b/train.py
@@ -1,5 +1,6 @@
 from collections.abc import Callable
 import json
+import os
 from pathlib import Path
 import random
 import re
@@ -10,12 +11,16 @@ import torch.nn.functional as F
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.data import DataLoader
 
+import torch.distributed as dist
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+
 from transformers import (
     AutoTokenizer,
     PreTrainedTokenizer,
     LlamaForCausalLM,
     GenerationConfig,
 )
+from ckpt_utils import save_checkpoint
 
 from loss import approx_kl_divergence, GRPOLoss
 from replay_buffer import ReplayBuffer, Experience, join_experience_batch
@@ -156,7 +161,7 @@ def sequences_log_probs(
 ) -> torch.Tensor:
     position_ids = attention_mask.long().cumsum(dim=-1) - 1
     position_ids.masked_fill_(mask=(attention_mask == 0), value=1)
-    output = model.forward(
+    output = model(
         input_ids=sequence_ids,
         attention_mask=attention_mask,
         position_ids=position_ids,
@@ -191,10 +196,64 @@ def read_prompts(
     return rows
 
 
+def setup_dist():
+    """Initialize the process group and pin this rank to its local GPU"""
+    dist.init_process_group("nccl")
+
+    # torchrun sets LOCAL_RANK; bind each process to its own device
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+
+
+def load_model_fsdp(
+    model_name_or_path: str,
+    reduce_fp32: bool = False,
+    reshard_after_forward: bool = True,
+    trust_remote_code: bool = False,
+) -> tuple[LlamaForCausalLM, PreTrainedTokenizer]:
+    """Load model and apply composable FSDP"""
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    model = LlamaForCausalLM.from_pretrained(
+        model_name_or_path,
+        trust_remote_code=trust_remote_code,
+        attn_implementation="flash_attention_2",
+        # torch_dtype=torch.bfloat16,
+    ).to(torch.cuda.current_device())
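+
+    # NOTE: each rank materializes the full model before sharding, which is
+    # fine at 1B parameters; for larger models the usual pattern is to
+    # initialize on the meta device so that fully_shard allocates only the
+    # local shards.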
+
+    # Mixed precision: cast parameters to bf16 for compute, and optionally
+    # keep gradient reductions in fp32 for numerical stability
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=torch.bfloat16,
+        reduce_dtype=torch.float32 if reduce_fp32 else None,
+    )
+
+    # Apply FSDP to each transformer layer
+    for layer_id, transformer_block in enumerate(model.model.layers):
+        # The last layer's parameters are needed again right away in
+        # backward, so skip resharding it even when resharding is enabled
+        should_reshard = (
+            reshard_after_forward and layer_id < len(model.model.layers) - 1
+        )
+
+        fully_shard(
+            transformer_block,
+            mp_policy=mp_policy,
+            reshard_after_forward=should_reshard,
+        )
+
+    # Apply FSDP to the whole model to pick up the remaining parameters
+    # (embeddings, final norm, lm_head)
+    fully_shard(
+        model,
+        mp_policy=mp_policy,
+        reshard_after_forward=reshard_after_forward,
+    )
+
+    return model, tokenizer
+
+
 def main():
     seed = 42
     wandb_project = None  # "tiny_grpo"
-    device_index = 0
     model_name = "meta-llama/Llama-3.2-1B-Instruct"
     checkpoint_path = Path("./output")
     checkpoint_interval = 20
@@ -203,6 +262,10 @@ def main():
     kl_weight = 0.01
     clip_eps = 0.2
 
+    # FSDP-specific configs
+    reduce_fp32 = False
+    reshard_after_forward = False
+
     group_size = 12
     rollouts_per_step = 32
     epochs_per_step = 1
@@ -213,15 +276,28 @@ def main():
     top_p = 1.0
     temperature = 1.0
 
-    device = torch.device("cuda", device_index)
-    cpu_device = torch.device("cpu")
+    # Initialize distributed setup
+    setup_dist()
 
     init_rng(seed)
 
-    reference_model, _ = load_model(model_name, device_map=device)
-    model, tokenizer = load_model(model_name, device_map=device)
+    # Load models with composable FSDP
+    reference_model, _ = load_model_fsdp(
+        model_name,
+        reduce_fp32=reduce_fp32,
+        reshard_after_forward=reshard_after_forward,
+    )
+    model, tokenizer = load_model_fsdp(
+        model_name,
+        reduce_fp32=reduce_fp32,
+        reshard_after_forward=reshard_after_forward,
+    )
+
     optimizer = optim.Adam(model.parameters(), lr=lr)
 
     reference_model.eval()
+    model.train()
+
+    # Enable gradient checkpointing
     model.gradient_checkpointing_enable(
         gradient_checkpointing_kwargs={"use_reentrant": False}
     )
@@ -235,22 +311,27 @@ def main():
         and x["num_digits"] <= 3,
         max_rows=64 * 1024,
     )
-    print(f"found {len(prompts)} matching prompts")
+
+    if dist.get_rank() == 0:
+        print(f"found {len(prompts)} matching prompts")
+
     prompt_loader = DataLoader(
         prompts,
         batch_size=rollouts_per_step,
         shuffle=True,
         drop_last=True,
-        pin_memory=False,
+        pin_memory=True,
     )
 
     replay_buffer = ReplayBuffer()
     objective = GRPOLoss(clip_eps=clip_eps, kl_weight=kl_weight)
 
-    if wandb_project is None:
-        wandb.init(mode="disabled")
-    else:
-        wandb.init(project=wandb_project)
+    # Initialize wandb only on rank 0
+    if dist.get_rank() == 0:
+        if wandb_project is None:
+            wandb.init(mode="disabled")
+        else:
+            wandb.init(project=wandb_project)
 
     for k, prompt_batch in enumerate(prompt_loader):
         rollout_returns = []
@@ -273,9 +354,12 @@ def main():
                 top_p=top_p,
             )
 
-            print(
-                f"rollout q='{q}', a='{a}', returns={returns.sum().item():.2f}, replay_buffer_size={len(replay_buffer)}, sequence_ids={sequence_ids.shape}"
-            )
+            if dist.get_rank() == 0:
+                print(
+                    f"rollout q='{q}', a='{a}', returns={returns.sum().item():.2f}, "
+                    f"replay_buffer_size={len(replay_buffer)}, sequence_ids={sequence_ids.shape}"
+                )
+
             rollout_returns.append(returns.cpu())
 
             advantages = group_advantages(returns)
@@ -307,12 +391,12 @@ def main():
                 action_mask=action_mask,
                 kl=kl,
             )
-            replay_buffer.append(experience.to(cpu_device))
+            replay_buffer.append(experience.to("cpu"))
 
-        torch.cuda.empty_cache()
-        episode_return_sum = torch.stack(rollout_returns).sum()
-        print(f"returns of step {k}: {episode_return_sum:.4f}")
-        wandb.log({"returns": episode_return_sum})
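+        # NOTE: rollout_returns holds only this rank's rollouts, so with
+        # several GPUs the block below logs rank 0's local return sum; a
+        # global sum would need an all_reduce first.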
+        if dist.get_rank() == 0:
+            episode_return_sum = torch.stack(rollout_returns).sum()
+            print(f"returns of step {k}: {episode_return_sum:.4f}")
+            wandb.log({"returns": episode_return_sum})
 
         experience_sampler = DataLoader(
             replay_buffer,
@@ -325,10 +409,8 @@ def main():
 
         for step_epoch in range(epochs_per_step):
             model.train()
 
             for exp in experience_sampler:
-                exp: Experience
-
-                exp = exp.to(device)
+                exp = exp.to(torch.cuda.current_device())
 
                 optimizer.zero_grad()
 
@@ -336,29 +418,45 @@ def main():
                     model, sequence_ids=exp.sequences, attention_mask=exp.attention_mask
                 )
 
                 loss, kl = objective(log_probs=log_probs, experience=exp)
 
                 if not loss.isfinite():
-                    print(f"Loss not finite, skipping backward, loss={loss}")
-                    print(f"experience.advantages={experience.advantages}")
+                    if dist.get_rank() == 0:
+                        print(f"Loss not finite, skipping backward, loss={loss}")
+                        print(f"exp.advantages={exp.advantages}")
                     continue
 
                 loss.backward()
-                grad_norm = clip_grad_norm_(model.parameters(), max_norm=max_norm)
-                print(f"{step_epoch}: kl={kl: .4f}, grad_norm={grad_norm: .4f}")
-                wandb.log({"kl": kl, "grad_norm": grad_norm})
+
+                # Under FSDP2 the gradients are DTensors, so clip_grad_norm_
+                # returns a DTensor; full_tensor() gathers the global norm
+                grad_norm = clip_grad_norm_(
+                    model.parameters(), max_norm=max_norm
+                ).full_tensor()
+
+                if dist.get_rank() == 0:
+                    print(f"{step_epoch}: kl={kl: .4f}, grad_norm={grad_norm: .4f}")
+                    wandb.log({"kl": kl, "grad_norm": grad_norm})
 
                 optimizer.step()
 
+        # dcp.save is collective, so every rank calls save_checkpoint;
+        # only rank 0 prints. DCP writes a directory of shards at the path.
         if (
            checkpoint_path is not None
             and checkpoint_interval is not None
             and (k + 1) % checkpoint_interval == 0
         ):
-            model.save_pretrained(checkpoint_path / f"step_{k}")
+            if dist.get_rank() == 0:
+                print(f"saving checkpoint to {checkpoint_path / f'step_{k}'}")
+            save_checkpoint(model, optimizer, checkpoint_path / f"step_{k}")
+
+    # Final save, again on every rank; a rank-0-only call would deadlock
+    # the collective
+    if checkpoint_path is not None:
+        save_checkpoint(model, optimizer, checkpoint_path / f"step_{k}")
 
-        if checkpoint_path is not None:
-            model.save_pretrained(checkpoint_path / f"step_{k}")
+    # Cleanup
+    dist.destroy_process_group()
 
 
 if __name__ == "__main__":
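
A note on the checkpoints this produces: `save_checkpoint` writes a sharded DCP directory per step (one shard file per rank), and the `load_checkpoint` helper in `ckpt_utils.py` (not yet wired into `train.py`) restores it in place for resuming; like `dcp.save`, it is collective and must run on every rank. To use the weights in a single process afterwards, PyTorch ships an offline converter. A minimal sketch, assuming a checkpoint was saved under `output/step_19` (an example path):

```
# Consolidate a sharded DCP checkpoint into a single torch.save file.
# Runs in one process; no process group or GPU is required.
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save("output/step_19", "output/step_19_full.pt")

# The resulting file holds the dict that was saved, i.e.
# {"model": ..., "optimizer": ...}; the "model" entry can be passed to
# LlamaForCausalLM.load_state_dict().
```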