8 changes: 7 additions & 1 deletion README.md
@@ -21,7 +21,13 @@ pip install flash-attn --no-build-isolation
3. Play with the source in `train.py`

```
-python train.py
+torchrun --nproc_per_node=1 train.py
```

With multiple GPUs (`torchrun` spawns one process per GPU and sets `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` for each):

```
torchrun --nproc_per_node=8 train.py
```
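
For a multi-node launch (untested here, shown only as a sketch; substitute your own rendezvous host and port):

```
torchrun --nnodes=2 --nproc_per_node=8 \
    --rdzv_backend=c10d --rdzv_endpoint=$HOST:29500 \
    train.py
```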

### Inspiration
34 changes: 34 additions & 0 deletions ckpt_utils.py
@@ -0,0 +1,34 @@
from torch.distributed.checkpoint.state_dict import (
set_optimizer_state_dict,
set_model_state_dict,
get_model_state_dict,
get_optimizer_state_dict,
)
import torch.distributed.checkpoint as dcp
import torch.distributed as dist


def save_checkpoint(model, optimizer, path):
"""Save model and optimizer state using distributed checkpoint"""
model_state = get_model_state_dict(model=model)
optimizer_state = get_optimizer_state_dict(model=model, optimizers=optimizer)

state_dict = {"model": model_state, "optimizer": optimizer_state}

dcp.save(state_dict=state_dict, storage_writer=dcp.FileSystemWriter(path))


def load_checkpoint(model, optimizer, path):
"""Load model and optimizer state using distributed checkpoint"""

dcp_state_dict = {
"model": get_model_state_dict(model=model),
"optimizer": get_optimizer_state_dict(model=model, optimizers=optimizer),
}

dcp.load(dcp_state_dict, storage_reader=dcp.FileSystemReader(path))

set_model_state_dict(model=model, model_state_dict=dcp_state_dict["model"])
set_optimizer_state_dict(
model=model, optimizers=optimizer, optim_state_dict=dcp_state_dict["optimizer"]
)
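
`load_checkpoint` is defined above but not yet called from `train.py`. A minimal resume sketch under that assumption (the checkpoint paths and learning rate are hypothetical; `setup_dist` and `load_model_fsdp` come from `train.py`):

```
import torch

from ckpt_utils import load_checkpoint, save_checkpoint
from train import load_model_fsdp, setup_dist

# Sketch only: every rank runs this. dcp.save/dcp.load are collectives,
# so there is no rank-0 gating around checkpoint calls.
setup_dist()
model, _tokenizer = load_model_fsdp("meta-llama/Llama-3.2-1B-Instruct")
optimizer = torch.optim.Adam(model.parameters(), lr=3e-6)  # hypothetical lr

load_checkpoint(model, optimizer, "./output/step_19.pt")  # restores state in place
# ... continue training ...
save_checkpoint(model, optimizer, "./output/step_39.pt")
```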
164 changes: 131 additions & 33 deletions train.py
@@ -1,5 +1,6 @@
from collections.abc import Callable
import json
import os
from pathlib import Path
import random
import re
@@ -10,12 +11,16 @@
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

from transformers import (
AutoTokenizer,
PreTrainedTokenizer,
LlamaForCausalLM,
GenerationConfig,
)
from ckpt_utils import save_checkpoint
from loss import approx_kl_divergence, GRPOLoss
from replay_buffer import ReplayBuffer, Experience, join_experience_batch

@@ -156,7 +161,7 @@ def sequences_log_probs(
) -> torch.Tensor:
position_ids = attention_mask.long().cumsum(dim=-1) - 1
position_ids.masked_fill_(mask=(attention_mask == 0), value=1)
-    output = model.forward(
+    output = model(
input_ids=sequence_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -191,10 +196,64 @@ def read_prompts(
return rows


def setup_dist():
    """Initialize the NCCL process group and bind this process to its local GPU"""
    dist.init_process_group("nccl")

    # torchrun sets LOCAL_RANK for each process; pin this process to that GPU
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)


def load_model_fsdp(
model_name_or_path: str,
reduce_fp32: bool = False,
reshard_after_forward: bool = True,
trust_remote_code: bool = False,
) -> tuple[LlamaForCausalLM, PreTrainedTokenizer]:
"""Load model and apply composable FSDP"""
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

    model = LlamaForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=trust_remote_code,
        attn_implementation="flash_attention_2",
        # param dtype is handled by the MixedPrecisionPolicy below
    ).to(torch.cuda.current_device())  # device bound in setup_dist via LOCAL_RANK

# Define mixed precision policy
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32 if reduce_fp32 else None,
)

# Apply FSDP to transformer layers
for layer_id, transformer_block in enumerate(model.model.layers):
        # Reshard all but the last layer when enabled: backward starts at the
        # last layer, so keeping it gathered avoids an immediate re-all-gather
        should_reshard = (
            reshard_after_forward and layer_id < len(model.model.layers) - 1
        )

fully_shard(
transformer_block,
mp_policy=mp_policy,
reshard_after_forward=should_reshard,
)

# Apply FSDP to the whole model
fully_shard(
model,
mp_policy=mp_policy,
reshard_after_forward=reshard_after_forward,
)

return model, tokenizer
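
After `fully_shard` runs, each parameter is stored as a `DTensor` sharded across the ranks, which is why the training loop later calls `.full_tensor()` on the clipped grad norm. A small illustration (assumes a multi-GPU `torchrun` launch and a torch version where `torch.distributed.tensor` is public):

```
# Illustrative sketch: inspect one parameter of the sharded model
from torch.distributed.tensor import DTensor

p = next(model.parameters())
assert isinstance(p, DTensor)  # FSDP2 stores parameters as DTensors
print(p.shape)                 # global (logical) shape
print(p.to_local().shape)      # this rank's local shard
```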


def main():
seed = 42
wandb_project = None # "tiny_grpo"
-    device_index = 0
model_name = "meta-llama/Llama-3.2-1B-Instruct"
checkpoint_path = Path("./output")
checkpoint_interval = 20
@@ -203,6 +262,10 @@ def main():
kl_weight = 0.01
clip_eps = 0.2

# FSDP specific configs
reduce_fp32 = False
reshard_after_forward = False

group_size = 12
rollouts_per_step = 32
epochs_per_step = 1
@@ -213,15 +276,28 @@
top_p = 1.0
temperature = 1.0

-    device = torch.device("cuda", device_index)
-    cpu_device = torch.device("cpu")
+    # Initialize distributed setup
+    setup_dist()
init_rng(seed)

-    reference_model, _ = load_model(model_name, device_map=device)
-    model, tokenizer = load_model(model_name, device_map=device)
+    # Load models with composable FSDP
+    reference_model, _ = load_model_fsdp(
+        model_name,
+        reduce_fp32=reduce_fp32,
+        reshard_after_forward=reshard_after_forward,
+    )
+    model, tokenizer = load_model_fsdp(
+        model_name,
+        reduce_fp32=reduce_fp32,
+        reshard_after_forward=reshard_after_forward,
+    )

optimizer = optim.Adam(model.parameters(), lr=lr)

reference_model.eval()
model.train()

# Enable gradient checkpointing
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
@@ -235,22 +311,27 @@
and x["num_digits"] <= 3,
max_rows=64 * 1024,
)
print(f"found {len(prompts)} matching prompts")

if dist.get_rank() == 0:
print(f"found {len(prompts)} matching prompts")

prompt_loader = DataLoader(
prompts,
batch_size=rollouts_per_step,
shuffle=True,
drop_last=True,
-        pin_memory=False,
+        pin_memory=True,
)

replay_buffer = ReplayBuffer()
objective = GRPOLoss(clip_eps=clip_eps, kl_weight=kl_weight)

-    if wandb_project is None:
-        wandb.init(mode="disabled")
-    else:
-        wandb.init(project=wandb_project)
+    # Initialize wandb only on rank 0
+    if dist.get_rank() == 0:
+        if wandb_project is None:
+            wandb.init(mode="disabled")
+        else:
+            wandb.init(project=wandb_project)

for k, prompt_batch in enumerate(prompt_loader):
rollout_returns = []
@@ -273,9 +354,12 @@
top_p=top_p,
)

-            print(
-                f"rollout q='{q}', a='{a}', returns={returns.sum().item():.2f}, replay_buffer_size={len(replay_buffer)}, sequence_ids={sequence_ids.shape}"
-            )
+            if dist.get_rank() == 0:
+                print(
+                    f"rollout q='{q}', a='{a}', returns={returns.sum().item():.2f}, "
+                    f"replay_buffer_size={len(replay_buffer)}, sequence_ids={sequence_ids.shape}"
+                )

rollout_returns.append(returns.cpu())

advantages = group_advantages(returns)
@@ -307,12 +391,12 @@
action_mask=action_mask,
kl=kl,
)
-            replay_buffer.append(experience.to(cpu_device))
+            replay_buffer.append(experience.to("cpu"))

torch.cuda.empty_cache()
-        episode_return_sum = torch.stack(rollout_returns).sum()
-        print(f"returns of step {k}: {episode_return_sum:.4f}")
-        wandb.log({"returns": episode_return_sum})
+        if dist.get_rank() == 0:
+            episode_return_sum = torch.stack(rollout_returns).sum()
+            print(f"returns of step {k}: {episode_return_sum:.4f}")
+            wandb.log({"returns": episode_return_sum})

experience_sampler = DataLoader(
replay_buffer,
@@ -325,40 +409,54 @@
for step_epoch in range(epochs_per_step):
model.train()

-            for exp in experience_sampler:
-                exp: Experience
-
-                exp = exp.to(device)
+            for i, exp in enumerate(experience_sampler):
+                # use the device bound in setup_dist (local rank), not the
+                # global rank, so multi-node launches address a valid GPU
+                exp = exp.to(torch.cuda.current_device())

optimizer.zero_grad()

log_probs = sequences_log_probs(
model, sequence_ids=exp.sequences, attention_mask=exp.attention_mask
)

-                loss, kl = objective(log_probs=log_probs, experience=exp)
+                loss, kl = objective.forward(log_probs=log_probs, experience=exp)

if not loss.isfinite():
print(f"Loss not finite, skipping backward, loss={loss}")
print(f"experience.advantages={experience.advantages}")
if dist.get_rank() == 0:
print(f"Loss not finite, skipping backward, loss={loss}")
print(f"Loss not finite, skipping backward, loss={loss}")
print(f"experience.advantages={experience.advantages}")
print(f"Loss not finite, skipping backward, loss={loss}")
print(f"experience.advantages={experience.advantages}")
continue

loss.backward()
-                grad_norm = clip_grad_norm_(model.parameters(), max_norm=max_norm)
-                print(f"{step_epoch}: kl={kl: .4f}, grad_norm={grad_norm: .4f}")
-                wandb.log({"kl": kl, "grad_norm": grad_norm})
+                # under FSDP2, clip_grad_norm_ returns a DTensor;
+                # .full_tensor() gathers the global norm from all ranks
+                grad_norm = clip_grad_norm_(
+                    model.parameters(), max_norm=max_norm
+                ).full_tensor()
+
+                if dist.get_rank() == 0:
+                    print(f"{step_epoch}: kl={kl: .4f}, grad_norm={grad_norm: .4f}")
+                    wandb.log({"kl": kl, "grad_norm": grad_norm})

optimizer.step()

            # Save a distributed checkpoint (dcp.save is collective: all ranks call it)
if (
checkpoint_path is not None
and checkpoint_interval is not None
and (k + 1) % checkpoint_interval == 0
):
-                model.save_pretrained(checkpoint_path / f"step_{k}")
+                print(f"saving checkpoint to {checkpoint_path / f'step_{k}.pt'}")
+                save_checkpoint(model, optimizer, checkpoint_path / f"step_{k}.pt")

-    if checkpoint_path is not None:
-        model.save_pretrained(checkpoint_path / f"step_{k}")
+    # Final save (dcp.save is collective, so every rank must call it)
+    if checkpoint_path is not None:
+        save_checkpoint(model, optimizer, checkpoint_path / f"step_{k}.pt")

    # Cleanup
    dist.destroy_process_group()


if __name__ == "__main__":
    main()