Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions .github/workflows/gpu_ci_h100.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# H100 GPU CI: submits the skyrl-train H100 test job to Anyscale (staging)
# and blocks until that job finishes.
name: H100-GPU-CI

on:
  # Pushes to main that touch training code, its GPU tests, CI scripts,
  # pyproject.toml, or any workflow file.
  push:
    branches:
      - main
    paths:
      - 'ci/**'
      - 'skyrl/backends/skyrl_train/**'
      - 'skyrl/train/**'
      - 'tests/backends/skyrl_train/gpu/gpu_ci/**'
      - 'pyproject.toml'
      - '!docs/**'
      - '!examples/**'
      - '.github/workflows/**'
  # PRs only trigger when a label is applied; the job-level `if` below narrows
  # this further to the `run_h100_gpu_ci` label and trusted authors.
  pull_request_target:
    types: [labeled]
  workflow_dispatch:


permissions:
  checks: write # for status checks to appear
  contents: read

jobs:

  skyrl_train_tests_h100:
    # Always run for push and manual dispatch. For pull_request_target the job
    # has access to repository secrets, so require a non-draft PR that carries
    # the `run_h100_gpu_ci` label and whose author is a MEMBER, OWNER, or
    # COLLABORATOR.
    if: >
      github.event_name == 'push' ||
      github.event_name == 'workflow_dispatch' ||
      (
        github.event_name == 'pull_request_target' &&
        !github.event.pull_request.draft &&
        contains(github.event.pull_request.labels.*.name, 'run_h100_gpu_ci') &&
        (
          github.event.pull_request.author_association == 'MEMBER' ||
          github.event.pull_request.author_association == 'OWNER' ||
          github.event.pull_request.author_association == 'COLLABORATOR'
        )
      )
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
        working-directory: .

    steps:
      - uses: actions/checkout@v4
        with:
          # For PR events, test the PR head commit; for push / dispatch events
          # there is no pull_request payload, so fall back to github.ref.
          ref: ${{ github.event.pull_request.head.sha || github.ref }}
          # The checked-out tree can be PR-controlled and is submitted with
          # `working_dir: .`, so do not leave the GITHUB_TOKEN in the local
          # git config. No later step needs git credentials.
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
          cache: 'pip'
      - name: Install the latest version of uv
        uses: astral-sh/setup-uv@v6
        with:
          activate-environment: true
      - name: Install basic dependencies
        run: uv pip install anyscale==0.24.79 typer==0.9.0
      # Run h100 tests via anyscale staging (compute config llm-team-h100-4x:1)
      - name: GPU tests
        env:
          ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN_STAGING }}
          ANYSCALE_HOST: https://console.anyscale-staging.com
        # The job name passed to `wait` must match `name:` in
        # ci/anyscale_gpu_ci_h100.yaml.
        run: |
          anyscale job submit -f ci/anyscale_gpu_ci_h100.yaml --timeout 5400
          anyscale job wait --name skyrl-train-gpu-ci-h100 --timeout 5400
7 changes: 7 additions & 0 deletions ci/anyscale_gpu_ci_h100.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Anyscale job definition for the H100 GPU CI run.
# `name` is the job name that .github/workflows/gpu_ci_h100.yaml passes to
# `anyscale job wait --name`, so the two must be kept in sync.
name: skyrl-train-gpu-ci-h100
# Script executed as the job entrypoint (runs the h100-marked pytest suites).
entrypoint: bash ci/gpu_ci_run_h100.sh
image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8-megatron
# NOTE(review): presumably must agree with the Ray version baked into
# image_uri (both say 2.51.1) — confirm when bumping either.
ray_version: "2.51.1"
# Named compute config, written as <name>:<version>.
compute_config: llm-team-h100-4x:1
working_dir: .
max_retries: 0
15 changes: 15 additions & 0 deletions ci/gpu_ci_run_h100.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env bash
# Entrypoint for the H100 GPU CI job: prepares the GSM8K dataset, then runs
# the FSDP and Megatron test files restricted to the `h100` pytest marker.
#
# -x traces each command, -e and -o pipefail abort on the first failing
# command or pipeline stage, -u treats unset variables as errors.
set -xeuo pipefail

export CI=true

# Prepare datasets used in tests (Megatron test uses gsm8k env_class).
# The path is quoted so that a $HOME containing whitespace or glob characters
# is still passed as a single argument.
uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir "$HOME/data/gsm8k"

# Run FSDP h100 tests. Use --extra fsdp so the fsdp-only tests can collect.
uv run --directory . --isolated --extra dev --extra fsdp pytest -s -vvv -m h100 \
    tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py

# Run Megatron h100 tests. Use --extra megatron so the megatron-only tests can collect.
uv run --directory . --isolated --extra dev --extra megatron pytest -s -vvv -m h100 \
    tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_megatron_models.py
27 changes: 27 additions & 0 deletions skyrl/backends/skyrl_train/distributed/fsdp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,33 @@ def _fsdp_init_model(self, model, is_train=True, is_wrapped=False):
}
module = model.model if is_wrapped else model
full_state = module.state_dict()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Calling state_dict() on all ranks materializes the full model state in memory on every process. Since fsdp2_load_full_state_dict uses set_model_state_dict with broadcast_from_rank0=True, only rank 0 needs the actual data. On non-zero ranks, this leads to significant and unnecessary CPU memory usage, which can cause OOMs on nodes with many GPUs when training large models.

Suggested change
full_state = module.state_dict()
full_state = module.state_dict() if dist.get_rank() == 0 else {}


# Move the entire module to meta before apply_fsdp2 so the sharded
# DTensors are allocated directly on GPU at their final sharded size,
# rather than first materializing the full model on each rank.
# Reference: huggingface/accelerate fsdp2_prepare_model uses the
# same pattern.
#
# Non-persistent buffers (e.g. RotaryEmbedding.inv_freq) are not in
# state_dict, so they would be wiped by the meta swap. We snapshot
# rank 0's values before the swap and restore them on rank 0 after,
# so _sync_non_persistent_buffers (inside fsdp2_load_full_state_dict)
# can broadcast real values to the other ranks.
non_persistent_snapshot = {}
if dist.get_rank() == 0:
for sub_name, sub in module.named_modules():
for bname in getattr(sub, "_non_persistent_buffers_set", set()):
buf = sub._buffers.get(bname)
if buf is not None and not buf.is_meta:
non_persistent_snapshot[(sub_name, bname)] = buf.detach().clone()

module.to(torch.device("meta"))

if dist.get_rank() == 0:
for (sub_name, bname), buf in non_persistent_snapshot.items():
sub = module.get_submodule(sub_name) if sub_name else module
sub._buffers[bname] = buf

apply_fsdp2(module, fsdp_kwargs, self.fsdp_config)
fsdp2_load_full_state_dict(module, full_state, cpu_offload)
return module
Expand Down
130 changes: 43 additions & 87 deletions skyrl/backends/skyrl_train/distributed/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,100 +129,56 @@ def _sync_non_persistent_buffers(model: torch.nn.Module, loaded_sd: dict):
module._buffers[key] = src.cpu()


# Fsdp2 load full state dict from `accelerate`
# Reference: https://github.com/huggingface/accelerate/blob/0af621bbecc0e43f5d43766a4945d3d2236bb8a9/src/accelerate/utils/fsdp_utils.py#L455
# NOTE (sumanthrh): The original code from `accelerate` assumes init on meta device - with cpu init only on rank 0, but the code is compatible with cpu init on all ranks.
def fsdp2_load_full_state_dict(model: torch.nn.Module, full_sd: dict, cpu_offload=None):
"""
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
parameters from rank 0 to all other ranks. This function modifies the model in-place.
Loads a full state dict (assumed populated on rank 0) into a sharded FSDP2
model by broadcasting from rank 0 and distributing per-parameter shards.

Uses PyTorch's `set_model_state_dict` with `full_state_dict=True` +
`broadcast_from_rank0=True`, which internally does the per-parameter
broadcast and DTensor distribution that we used to do manually. The
utility also releases intermediate staging tensors as it goes, so the
caching allocator can return memory to the device pool after init.

Args:
model (`torch.nn.Module`):
The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
full_sd (`dict`): The full state dict to load, can be only on rank 0
The model to load the state dict into, expected to be on meta
device on all ranks except rank 0 (or on meta on all ranks and
full_sd populated on rank 0).
full_sd (`dict`): The full state dict to load (only rank 0 needs
real data; non-rank-0 ranks may pass an empty dict).
"""
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor

# Model was previously copied to meta device
meta_sharded_sd = model.state_dict()
sharded_sd = {}

# Rank 0 distributes the full state dict to other ranks
def _infer_parameter_dtype(model, param_name, empty_param):
try:
old_param = model.get_parameter_or_buffer(param_name)
except AttributeError:
# Need this for LORA, as there some params are not *parameters* of sorts
base_param_name, local_param_name = param_name.rsplit(".", 1)
submodule = model.get_submodule(base_param_name)
old_param = getattr(submodule, local_param_name)

is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn

if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
casting_dtype = old_param.dtype

return old_param is not None and old_param.is_contiguous(), casting_dtype

def _cast_and_contiguous(tensor, to_contiguous, dtype):
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if to_contiguous:
tensor = tensor.contiguous()
return tensor

if dist.get_rank() == 0:
for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
full_param = full_param.detach().cuda()
mesh = sharded_param.device_mesh
dist.broadcast(full_param, src=0)
sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
full_param,
)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
sharded_sd[param_name] = sharded_tensor
# We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
else:
for param_name, sharded_param in meta_sharded_sd.items():
full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)
mesh = sharded_param.device_mesh
dist.broadcast(full_tensor, src=0)
sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
full_tensor,
)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
sharded_sd[param_name] = sharded_tensor

# we set `assign=True` because our params can be on meta device
model.load_state_dict(sharded_sd, assign=True)

# Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding) that
# are excluded from state_dict. On non-rank-0 meta-init these are still on
# meta device with no data; rank 0 has the correctly computed values.
_sync_non_persistent_buffers(model, sharded_sd)

# If we don't offload FSDP2 Module to CPU and then back to GPU,
# it will occupy a large amount of reserved GPU memory,which can not be released using torch.cuda.empty_cache()
# even if we are using cpu_offload
# TODO (erictang000): this requires an additional offload + backload, see if this can be avoided
# Credit: https://github.com/volcengine/verl/pull/1667
offload_fsdp2_model_to_cpu(model)

torch.cuda.synchronize()
torch.cuda.empty_cache()
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
set_model_state_dict,
)

set_model_state_dict(
model,
full_sd,
options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)

if not cpu_offload:
load_fsdp2_model_to_gpu(model)
# Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding)
# that are excluded from state_dict. On non-rank-0 meta-init these are
# still on the meta device with no data; rank 0 has the correct values.
_sync_non_persistent_buffers(model, model.state_dict())
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The second argument to _sync_non_persistent_buffers is not utilized by the function implementation. Calling model.state_dict() on an FSDP2 model, while relatively cheap as it returns sharded DTensors, still incurs unnecessary overhead and creates a large dictionary of tensors for no reason.

Suggested change
_sync_non_persistent_buffers(model, model.state_dict())
_sync_non_persistent_buffers(model, {})


# NOTE: removed the offload_fsdp2_model_to_cpu + load_fsdp2_model_to_gpu
# dance from verl PR #1667. That trick was meant to release reserved-but-
# unallocated cache memory, but for FSDP2 the offload is a no-op
# (`model.to("cpu")` doesn't move FSDPParam-managed storage) and the reload
# then *allocates a second copy*. Diagnostic on Qwen3.5-35B-A3B:
# after set_model_state_dict: 32.71 GiB
# after offload+empty_cache: 32.71 GiB (unchanged)
# after load: 65.42 GiB (doubled)
# set_model_state_dict already leaves us with exactly the shard on GPU;
# no clean-up is needed.
if cpu_offload:
# Caller asked for CPU-resident params; the offload path is still
# broken for FSDP2 but we leave the request explicit so a future fix
# has an obvious hook.
offload_fsdp2_model_to_cpu(model)
return model


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,22 +180,32 @@ def setup_distributed(self, timeout=timedelta(minutes=30)) -> None:
def offload_to_cpu(self, model, optimizer, offload_optimizer=True, offload_model=True):
"""
Offload model weights and optimizer to CPU memory.

The grad buffer belongs to the DDP-wrapped model, not the optimizer,
so it is offloaded whenever ``offload_optimizer`` is requested even if
``optimizer is None`` (e.g. ``policy.skip_optimizer_init=True`` flows).
"""
if offload_model:
offload_megatron_model_to_cpu(model)
if optimizer and offload_optimizer:
if offload_optimizer:
offload_megatron_grads_to_cpu(model)
offload_megatron_optimizer(optimizer)
if optimizer is not None:
offload_megatron_optimizer(optimizer)
torch.cuda.synchronize()
torch.cuda.empty_cache()

def backload_to_gpu(self, model, optimizer, backload_optimizer=True, backload_model=True):
"""Reload model weights back to GPU."""
"""Reload model weights back to GPU.

See :meth:`offload_to_cpu` for why the grad-buffer half is decoupled
from optimizer existence.
"""
if backload_model:
load_megatron_model_to_gpu(model)
if optimizer and backload_optimizer:
if backload_optimizer:
load_megatron_grads_to_gpu(model)
load_megatron_optimizer(optimizer)
if optimizer is not None:
load_megatron_optimizer(optimizer)
torch.cuda.synchronize()

def backward(self, loss: torch.Tensor, model, optimizer: optim.Optimizer, **kwargs) -> None:
Expand Down
Loading
Loading