Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ __pycache__/
*.pyc
input.txt
env/
venv/
venv/
.venv/
26 changes: 26 additions & 0 deletions config/train_gpt2_1.3B.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# config for training GPT-2 (1.3B)

# these make the total batch size be ~1M
# 8 batch size * 64 block size * 2048 grad_accum_steps = 1,048,576
# Using strong scaling, so make grad_acc_steps multiple of maximum gpu count
batch_size = 8
block_size = 64
gradient_accumulation_steps = 1 * 2048

# USE FOR TESTING = ~0.25M batch size
# batch_size = 8
# block_size = 16
# gradient_accumulation_steps = 1 * 2048

# model - 1.3B from OPT paper table 1 (https://arxiv.org/pdf/2205.01068)
# # model params ~= 12 * n_layer * n_emd**2 + n_embd * vocab_size
n_layer = 24
n_head = 32
n_embd = 2048
learning_rate=2e-4

# max iters
max_iters = 10

use_pccl=True
bucket_cap_mb=32
69 changes: 69 additions & 0 deletions patch_ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import sys
import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from mpi4py import MPI

# PCCL imports
sys.path.append("/ccs/home/keshprad/sc2025-pccl-reproducer-v2/")
from pccl.process_groups import ProcessGroups
from pccl.build_kernels import build
from pccl.all_reduce import all_reduce_2D

pg = None
def get_heir_pg():
global pg
assert pg is not None, "did you call the patch_ddp function?"
return pg

def is_global_pg(group):
return (group is None) or (group == dist.group.WORLD) or (dist.get_world_size(group) == dist.get_world_size())

def pccl_all_reduce_hook(group: dist.ProcessGroup, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
if is_global_pg(group):
# define input tensor - get the flattened gradient tensor from the bucket
input_tensor = bucket.buffer()
# define output tensor
output_tensor = torch.empty(input_tensor.size(0),
device=input_tensor.device,
dtype=input_tensor.dtype)
# pccl is blocking for now
all_reduce_2D(output_tensor,
input_tensor,
group=get_heir_pg(),
async_op=False,
use_rh_and_rd=True,
use_pccl_cpp_backend=True)

# Create _dummy_ future and set the result
future = torch.futures.Future()
future.set_result(output_tensor)
return future
else:
return allreduce_hook(group, bucket)

def patch_ddp(ddp_model):
assert dist.is_initialized()
assert dist.get_world_size()

# auto-detect intra-node process group size
intra_node_pg_size = torch.cuda.device_count()

# build pccl
# if dist.get_rank() == 0:
# build()
# MPI.COMM_WORLD.Barrier()
# else:
# MPI.COMM_WORLD.Barrier()
# build()
# build PCCL on every node in node-local NVMe
build()
MPI.COMM_WORLD.Barrier()

# setup process groups sub-communicators
global pg
pg = ProcessGroups(intra_node_pg_size,
dist.get_world_size() // intra_node_pg_size)

# register communication hook with the DDP process group as state
ddp_model.register_comm_hook(ddp_model.process_group, pccl_all_reduce_hook)
68 changes: 68 additions & 0 deletions scripts/create_python_env_frontier.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/bin/bash
#
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

SCRATCH=/lustre/orion/csc547/scratch/keshprad
WRKSPC=${SCRATCH}/nanoGPT
# everything will be installed in $WRKSPC

ENV_NAME=".venv"
# this is the name of your python venv, change if needed

cd $WRKSPC
echo -e "${RED}Creating Python Environment in $WRKSPC:${GREEN}"

# Load modules
module reset
# load modules
rocm_version=6.4.1
module load PrgEnv-cray
module load rocm/${rocm_version}
module load cray-mpich/8.1.32
module load cpe/25.03
module load craype-accel-amd-gfx90a
module load cray-python/3.11.7
module load ninja
module list

#Step 1 - activate your venv
python -m venv $WRKSPC/$ENV_NAME
source $WRKSPC/$ENV_NAME/bin/activate
# upgrade pip
pip install -U pip

echo -e "${RED}Installing Dependencies:${GREEN}"
#Step 2 - install torch
pip install torch==2.8.0 --index-url https://download.pytorch.org/whl/rocm6.4
pip install --upgrade numpy

#Step 3 build mpi4py
MPICC="cc -shared" pip install --no-cache-dir --no-binary=mpi4py mpi4py

#Step 4 - install other packages
pip install numpy transformers datasets tiktoken wandb tqdm


#Step 5 - AWS-OFI-RCCL plugin
# skip for now as I already have it installed
# echo "Installing RCCL Plugin"
# cd ${SCRATCH}
# git clone --recursive https://github.com/ROCmSoftwarePlatform/aws-ofi-rccl
# cd aws-ofi-rccl
# libfabric_path=/opt/cray/libfabric/1.22.0
# ./autogen.sh
# export LD_LIBRARY_PATH=/opt/rocm-$rocm_version/lib:$LD_LIBRARY_PATH
# PLUG_PREFIX=$PWD
# CC=hipcc CFLAGS=-I/opt/rocm-$rocm_version/include ./configure \
# --with-libfabric=$libfabric_path --with-rccl=/opt/rocm-$rocm_version --enable-trace \
# --prefix=$PLUG_PREFIX --with-hip=/opt/rocm-$rocm_version --with-mpi=$MPICH_DIR
# make
# make install
# cd ..

echo -e "${RED}Your Python Environment is ready. To activate it run the following commands in the SAME order:${NC}"
echo -e "${GREEN}source $WRKSPC/$ENV_NAME/bin/activate${NC}"
echo ""
echo -e "${NC}"
1 change: 1 addition & 0 deletions scripts/get_rank.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/bin/bash
# select_gpu_device wrapper script
export RANK=${SLURM_PROCID}
export LOCAL_RANK=${SLURM_LOCALID}
exec $*
94 changes: 94 additions & 0 deletions scripts/run_frontier.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/bin/bash
#SBATCH -q normal
#SBATCH -J nanogpt
#SBATCH --gpu-bind none
#SBATCH -t 00:10:00
#SBATCH -A csc547
#SBATCH --exclusive
#SBATCH -C nvme
# Run like: sbatch run_frontier16.sh

echo "start run: $(date)"

export SCRATCH="/lustre/orion/csc547/scratch/keshprad"
export WRKSPC="${SCRATCH}/nanoGPT"
export HF_HOME="${SCRATCH}/.cache/hf"
export TRITON_CACHE_DIR="${SCRATCH}/.cache/triton"
cd $WRKSPC

# load modules
rocm_version="6.4.1"
module load PrgEnv-cray
module load rocm/${rocm_version}
module load cray-mpich/8.1.32
module load cpe/25.03
module load craype-accel-amd-gfx90a
module load cray-python/3.11.7
module load ninja
module list
export CXX=CC
export CC=cc

# activate env
source ${WRKSPC}/.venv/bin/activate

NNODES=$SLURM_JOB_NUM_NODES
GPUS=$(( NNODES * 8 ))
## master addr and port
# setting variables for torch.distributed
export MASTER_ADDR=$(hostname)
export MASTER_PORT=29500
export WORLD_SIZE=$GPUS
export OMP_NUM_THREADS=7

## some RCCL env variables
export FI_CXI_ATS=0
export HSA_FORCE_FINE_GRAIN_PCIE=1
export NCCL_CROSS_NIC=1
export NCCL_SOCKET_IFNAME=hsn0
export CUDA_VISIBLE_DEVICES=7,6,5,4,3,2,1,0
export CUDA_DEVICE_MAX_CONNECTIONS=1
# AWS-OFI-RCCL
export LD_LIBRARY_PATH="${SCRATCH}/aws-ofi-rccl/lib:$LD_LIBRARY_PATH"

# mpich gpu support
export MPICH_GPU_SUPPORT_ENABLED=1
export MPICH_OFI_VERBOSE=1
export MPICH_OFI_NIC_POLICY="USER"
export MPICH_OFI_NIC_MAPPING="0:0-1; 1:2-3; 2:4-5; 3:6-7"

# fi variables
export FI_CXI_RDZV_THRESHOLD=0
export FI_CXI_RDZV_GET_MIN=0
export FI_CXI_RDZV_EAGER_SIZE=0
export MPICH_OFI_CXI_COUNTER_VERBOSE=1
# collecting counter data
export MPICH_OFI_CXI_COUNTER_REPORT=5
export HSA_ENABLE_SDMA=0

# other
export MPICH_GPU_SUPPORT_ENABLED=1
export GPU_MAX_HW_QUEUES=1
export OFI_NCCL_USE_IPV6_TCP=1

MASK_0="0x00fe000000000000" # Cores 49-55
MASK_1="0xfe00000000000000" # Cores 57-64
MASK_2="0x0000000000fe0000" # Cores 17-23
MASK_3="0x00000000fe000000" # Cores 25-31
MASK_4="0x00000000000000fe" # Cores 1-7
MASK_5="0x000000000000fe00" # Cores 9-15
MASK_6="0x000000fe00000000" # Cores 33-39
MASK_7="0x0000fe0000000000" # Cores 41-47
CPU_MASK="--cpu-bind=mask_cpu:${MASK_0},${MASK_1},${MASK_2},${MASK_3},${MASK_4},${MASK_5},${MASK_6},${MASK_7}"


SCRIPT="scripts/get_rank.sh python -u train.py config/train_gpt2_1.3B.py"
# log start date
echo "start nanoGPT: $(date)"
run_cmd="srun -N $NNODES -n $GPUS --ntasks-per-node=8 -c 7 ${CPU_MASK} --mem-bind=map_mem:3,3,1,1,0,0,2,2 $SCRIPT"
echo $run_cmd
eval $run_cmd
# log end date
echo "end nanoGPT: $(date)"

echo "end run: $(date)"
49 changes: 35 additions & 14 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from axonn.intra_layer import optimize_communication

from model import GPTConfig, GPT
from patch_ddp import patch_ddp
import csv

# -----------------------------------------------------------------------------
# default config values designed to train a gpt2 (124M) on OpenWebText
Expand Down Expand Up @@ -75,15 +77,8 @@
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = True # use PyTorch 2.0 to compile the model to be faster

# model parallelism args
G_intra_r=1
G_intra_c=1
G_intra_d=1

# gradient checkpointing
gradient_checkpointing=False

use_pccl = False
bucket_cap_mb=None # uses default value from PyTorch DDP
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open('configurator.py').read()) # overrides from command line or config file
Expand Down Expand Up @@ -214,6 +209,7 @@ def get_batch(split):
model.crop_block_size(block_size)
model_args['block_size'] = block_size # so that the checkpoint will have the right value
model.to(device)
n_params = model.get_num_params()

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
Expand All @@ -233,8 +229,12 @@ def get_batch(split):
model = torch.compile(model) # requires PyTorch 2.0

# wrap model into DDP container
#if ddp:
# model = DDP(model, device_ids=[local_rank], process_group=ax.comm_handle.coll_nccl_comm)
if ddp:
model = DDP(model, device_ids=[ddp_local_rank], bucket_cap_mb=bucket_cap_mb)

# use pccl - register comm hook for pccl allreduce
if use_pccl:
patch_ddp(model)

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
Expand All @@ -256,7 +256,7 @@ def estimate_loss():
def get_lr(it):
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
return learning_rate * (it + 1) / (warmup_iters + 1)
# 2) if it > lr_decay_iters, return min learning rate
if it > lr_decay_iters:
return min_lr
Expand All @@ -278,6 +278,22 @@ def get_lr(it):
raw_model = model # row model and model are the same in axonn
running_mfu = -1.0

# Setup benchmark data collection
data_folder = f"./benchmark/pccl-ddp"
os.makedirs(data_folder, exist_ok=True)

# Create initial csv file
gpu_count = ddp_world_size if ddp else 1
slurm_job_id = os.environ.get('SLURM_JOB_ID', 'unknown')
all_reduce_library = "pccl" if use_pccl else "xccl"
csv_filename = os.path.join(data_folder, f"gpt2-{n_params/1e9:.2f}B_{all_reduce_library}_gpus_{gpu_count}_slurm_{slurm_job_id}.csv")
if master_process:
with open(csv_filename, 'w') as f:
writer = csv.writer(f)
# Write the header
header = ["gpu_count", "slurm_job_id", "model_size", "global_batch_size", "iter", "loss", "time (s)", "memory (GB)", "max_mem (GB)"]
writer.writerow(header)
f.flush()

while True:

Expand Down Expand Up @@ -351,8 +367,13 @@ def get_lr(it):
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
memory = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
peak = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
if torch.distributed.get_rank() == 0:
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%, mem = {memory:.2f} GB | max mem = {peak} GB")
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%, mem = {memory:.2f} GB, max mem = {peak:.2f} GB")

# master_process logs to CSV file
with open(csv_filename, 'a') as f:
writer = csv.writer(f)
writer.writerow([gpu_count, slurm_job_id, f"{n_params/1e9:.2f}B", tokens_per_iter, iter_num, lossf, dt, memory, peak])
f.flush()
iter_num += 1
local_iter_num += 1

Expand Down