Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions .github/workflows/gpu_ci_h100.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Workflow that submits the H100 GPU test job described in
# ci/anyscale_gpu_ci_h100.yaml to Anyscale staging and waits for its result.
name: H100-GPU-CI

# Triggers:
#   * push to main that touches one of the listed paths,
#   * a label being added to a pull request (pull_request_target / labeled),
#   * manual dispatch.
on:
push:
branches:
- main
# Path filters; entries prefixed with `!` are exclusions.
# NOTE(review): ci/gpu_ci_run_h100.sh runs examples/train/gsm8k/gsm8k_dataset.py,
# so changes under examples/ can affect this job without triggering it —
# confirm that is intended.
paths:
- 'ci/**'
- 'skyrl/backends/skyrl_train/**'
- 'skyrl/train/**'
- 'tests/backends/skyrl_train/gpu/gpu_ci/**'
- 'pyproject.toml'
- '!docs/**'
- '!examples/**'
- '.github/workflows/**'
# Only the `labeled` activity type is subscribed, so later pushes to an
# already-labeled PR do not start a new run by themselves.
pull_request_target:
types: [labeled]
workflow_dispatch:


permissions:
checks: write # for status checks to appear
contents: read

jobs:

# Single job: install the Anyscale CLI on a GitHub-hosted runner, submit the
# GPU job, and wait for it to finish.
skyrl_train_tests_h100:
# Gate: always run for push and workflow_dispatch. For pull_request_target
# events, run only when the PR is not a draft, carries the
# `run_h100_gpu_ci` label, and its author association is MEMBER, OWNER or
# COLLABORATOR.
if: >
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch' ||
(
github.event_name == 'pull_request_target' &&
!github.event.pull_request.draft &&
contains(github.event.pull_request.labels.*.name, 'run_h100_gpu_ci') &&
(
github.event.pull_request.author_association == 'MEMBER' ||
github.event.pull_request.author_association == 'OWNER' ||
github.event.pull_request.author_association == 'COLLABORATOR'
)
)
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: .

steps:
# Check out the PR head commit when the event carries a pull request,
# otherwise the ref that triggered the run.
# NOTE(review): on pull_request_target this checks out PR head code in a run
# that also reads a repository secret (see the GPU tests step); the
# author_association gate above is what restricts this to trusted authors —
# confirm that is sufficient.
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha || github.ref }}
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip'
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v6
with:
activate-environment: true
# Pinned CLI versions used to submit and monitor the Anyscale job.
- name: Install basic dependencies
run: uv pip install anyscale==0.24.79 typer==0.9.0
# Run h100 tests via anyscale staging (compute config llm-team-h100-4x:1)
# The `--name` passed to `anyscale job wait` must stay in sync with the `name`
# field of ci/anyscale_gpu_ci_h100.yaml.
# NOTE(review): --timeout 5400 is presumably in seconds (90 minutes) — confirm
# against the Anyscale CLI documentation.
- name: GPU tests
env:
ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN_STAGING }}
ANYSCALE_HOST: https://console.anyscale-staging.com
run: |
anyscale job submit -f ci/anyscale_gpu_ci_h100.yaml --timeout 5400
anyscale job wait --name skyrl-train-gpu-ci-h100 --timeout 5400
7 changes: 7 additions & 0 deletions ci/anyscale_gpu_ci_h100.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Anyscale job definition submitted by .github/workflows/gpu_ci_h100.yaml.
# The job name is what the workflow passes to `anyscale job wait --name`,
# so the two must be changed together.
name: skyrl-train-gpu-ci-h100
# Script executed inside the job; it prepares the dataset and runs the
# h100-marked pytest suites.
entrypoint: bash ci/gpu_ci_run_h100.sh
image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8-megatron
# Same Ray version as the one embedded in the image tag above.
ray_version: "2.51.1"
compute_config: llm-team-h100-4x:1
# Upload the repository root as the job's working directory.
working_dir: .
# NOTE(review): presumably disables automatic retries so a failing run is
# reported once — confirm against the Anyscale job config reference.
max_retries: 0
15 changes: 15 additions & 0 deletions ci/gpu_ci_run_h100.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env bash
# Entrypoint of the H100 GPU CI job (referenced by ci/anyscale_gpu_ci_h100.yaml).
# Prepares the gsm8k dataset, then runs the h100-marked FSDP and Megatron tests.
#
# -x traces each command, -e aborts on the first failing command (so the
# Megatron suite is skipped if the FSDP suite fails), -u treats unset
# variables as errors, and pipefail propagates failures through pipelines.
set -xeuo pipefail

export CI=true

# Prepare datasets used in tests (Megatron test uses gsm8k env_class).
# The path is quoted so that a $HOME containing whitespace or glob characters
# is still passed to the script as a single argument.
uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir "$HOME/data/gsm8k"

# Run FSDP h100 tests. Use --extra fsdp so the fsdp-only tests can collect.
uv run --directory . --isolated --extra dev --extra fsdp pytest -s -vvv -m h100 \
    tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py

# Run Megatron h100 tests. Use --extra megatron so the megatron-only tests can collect.
uv run --directory . --isolated --extra dev --extra megatron pytest -s -vvv -m h100 \
    tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_megatron_models.py
27 changes: 27 additions & 0 deletions skyrl/backends/skyrl_train/distributed/fsdp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,33 @@ def _fsdp_init_model(self, model, is_train=True, is_wrapped=False):
}
module = model.model if is_wrapped else model
full_state = module.state_dict()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Calling state_dict() on all ranks materializes the full model state in memory on every process. Since fsdp2_load_full_state_dict uses set_model_state_dict with broadcast_from_rank0=True, only rank 0 needs the actual data. On non-zero ranks, this leads to significant and unnecessary CPU memory usage, which can cause OOMs on nodes with many GPUs when training large models.

Suggested change
full_state = module.state_dict()
full_state = module.state_dict() if dist.get_rank() == 0 else {}


# Move the entire module to meta before apply_fsdp2 so the sharded
# DTensors are allocated directly on GPU at their final sharded size,
# rather than first materializing the full model on each rank.
# Reference: huggingface/accelerate fsdp2_prepare_model uses the
# same pattern.
#
# Non-persistent buffers (e.g. RotaryEmbedding.inv_freq) are not in
# state_dict, so they would be wiped by the meta swap. We snapshot
# rank 0's values before the swap and restore them on rank 0 after,
# so _sync_non_persistent_buffers (inside fsdp2_load_full_state_dict)
# can broadcast real values to the other ranks.
non_persistent_snapshot = {}
if dist.get_rank() == 0:
for sub_name, sub in module.named_modules():
for bname in getattr(sub, "_non_persistent_buffers_set", set()):
buf = sub._buffers.get(bname)
if buf is not None and not buf.is_meta:
non_persistent_snapshot[(sub_name, bname)] = buf.detach().clone()

module.to(torch.device("meta"))

if dist.get_rank() == 0:
for (sub_name, bname), buf in non_persistent_snapshot.items():
sub = module.get_submodule(sub_name) if sub_name else module
sub._buffers[bname] = buf

apply_fsdp2(module, fsdp_kwargs, self.fsdp_config)
fsdp2_load_full_state_dict(module, full_state, cpu_offload)
return module
Expand Down
130 changes: 43 additions & 87 deletions skyrl/backends/skyrl_train/distributed/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,100 +129,56 @@ def _sync_non_persistent_buffers(model: torch.nn.Module, loaded_sd: dict):
module._buffers[key] = src.cpu()


# Fsdp2 load full state dict from `accelerate`
# Reference: https://github.com/huggingface/accelerate/blob/0af621bbecc0e43f5d43766a4945d3d2236bb8a9/src/accelerate/utils/fsdp_utils.py#L455
# NOTE (sumanthrh): The original code from `accelerate` assumes init on meta device - with cpu init only on rank 0, but the code is compatible with cpu init on all ranks.
def fsdp2_load_full_state_dict(model: torch.nn.Module, full_sd: dict, cpu_offload=None):
"""
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
parameters from rank 0 to all other ranks. This function modifies the model in-place.
Loads a full state dict (assumed populated on rank 0) into a sharded FSDP2
model by broadcasting from rank 0 and distributing per-parameter shards.

Uses PyTorch's `set_model_state_dict` with `full_state_dict=True` +
`broadcast_from_rank0=True`, which internally does the per-parameter
broadcast and DTensor distribution that we used to do manually. The
utility also releases intermediate staging tensors as it goes, so the
caching allocator can return memory to the device pool after init.

Args:
model (`torch.nn.Module`):
The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
full_sd (`dict`): The full state dict to load, can be only on rank 0
The model to load the state dict into, expected to be on meta
device on all ranks except rank 0 (or on meta on all ranks and
full_sd populated on rank 0).
full_sd (`dict`): The full state dict to load (only rank 0 needs
real data; non-rank-0 ranks may pass an empty dict).
"""
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor

# Model was previously copied to meta device
meta_sharded_sd = model.state_dict()
sharded_sd = {}

# Rank 0 distributes the full state dict to other ranks
def _infer_parameter_dtype(model, param_name, empty_param):
try:
old_param = model.get_parameter_or_buffer(param_name)
except AttributeError:
# Need this for LORA, as there some params are not *parameters* of sorts
base_param_name, local_param_name = param_name.rsplit(".", 1)
submodule = model.get_submodule(base_param_name)
old_param = getattr(submodule, local_param_name)

is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn

if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
casting_dtype = old_param.dtype

return old_param is not None and old_param.is_contiguous(), casting_dtype

def _cast_and_contiguous(tensor, to_contiguous, dtype):
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if to_contiguous:
tensor = tensor.contiguous()
return tensor

if dist.get_rank() == 0:
for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
full_param = full_param.detach().cuda()
mesh = sharded_param.device_mesh
dist.broadcast(full_param, src=0)
sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
full_param,
)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
sharded_sd[param_name] = sharded_tensor
# We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
else:
for param_name, sharded_param in meta_sharded_sd.items():
full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)
mesh = sharded_param.device_mesh
dist.broadcast(full_tensor, src=0)
sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
full_tensor,
)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
sharded_sd[param_name] = sharded_tensor

# we set `assign=True` because our params can be on meta device
model.load_state_dict(sharded_sd, assign=True)

# Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding) that
# are excluded from state_dict. On non-rank-0 meta-init these are still on
# meta device with no data; rank 0 has the correctly computed values.
_sync_non_persistent_buffers(model, sharded_sd)

# If we don't offload FSDP2 Module to CPU and then back to GPU,
# it will occupy a large amount of reserved GPU memory,which can not be released using torch.cuda.empty_cache()
# even if we are using cpu_offload
# TODO (erictang000): this requires an additional offload + backload, see if this can be avoided
# Credit: https://github.com/volcengine/verl/pull/1667
offload_fsdp2_model_to_cpu(model)

torch.cuda.synchronize()
torch.cuda.empty_cache()
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
set_model_state_dict,
)

set_model_state_dict(
model,
full_sd,
options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)

if not cpu_offload:
load_fsdp2_model_to_gpu(model)
# Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding)
# that are excluded from state_dict. On non-rank-0 meta-init these are
# still on the meta device with no data; rank 0 has the correct values.
_sync_non_persistent_buffers(model, model.state_dict())
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The second argument to _sync_non_persistent_buffers is not utilized by the function implementation. Calling model.state_dict() on an FSDP2 model, while relatively cheap as it returns sharded DTensors, still incurs unnecessary overhead and creates a large dictionary of tensors for no reason.

Suggested change
_sync_non_persistent_buffers(model, model.state_dict())
_sync_non_persistent_buffers(model, {})


# NOTE: removed the offload_fsdp2_model_to_cpu + load_fsdp2_model_to_gpu
# dance from verl PR #1667. That trick was meant to release reserved-but-
# unallocated cache memory, but for FSDP2 the offload is a no-op
# (`model.to("cpu")` doesn't move FSDPParam-managed storage) and the reload
# then *allocates a second copy*. Diagnostic on Qwen3.5-35B-A3B:
# after set_model_state_dict: 32.71 GiB
# after offload+empty_cache: 32.71 GiB (unchanged)
# after load: 65.42 GiB (doubled)
# set_model_state_dict already leaves us with exactly the shard on GPU;
# no clean-up is needed.
if cpu_offload:
# Caller asked for CPU-resident params; the offload path is still
# broken for FSDP2 but we leave the request explicit so a future fix
# has an obvious hook.
offload_fsdp2_model_to_cpu(model)
return model


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
enables chunked weight updates from training to inference using the
start/update/finish lifecycle:

start_weight_update -> one or more update_weights_chunk -> finish_weight_update
start_weight_update -> one or more update_weights_ipc -> finish_weight_update

This separates the layerwise reload initialization/finalization from individual
chunk transfers, allowing weights to be sent in bounded-memory chunks rather
Expand Down Expand Up @@ -64,7 +64,7 @@ class NewInferenceWorkerWrap:

Provides a three-phase weight update protocol via collective_rpc:
1. start_weight_update: Prepare model for receiving weights
2. update_weights_chunk: Receive and load one chunk of weights
2. update_weights_ipc: Receive and load one chunk of weights
3. finish_weight_update: Finalize the model after all chunks

Attributes accessed from the host GPUWorker (via mixin inheritance):
Expand All @@ -82,7 +82,7 @@ def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
machinery which moves layers to meta device and wraps weight loaders
to defer processing until all weights for each layer are loaded.

Must be called before any update_weights_chunk calls.
Must be called before any update_weights_ipc calls.

Args:
is_checkpoint_format: True if incoming weights are in checkpoint
Expand All @@ -108,7 +108,7 @@ def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
self._skyrl_is_checkpoint_format = is_checkpoint_format
self._skyrl_weight_update_active = True

def update_weights_chunk(self, update_info: dict) -> None:
def update_weights_ipc(self, update_info: dict) -> None:
"""
Receive and load a single chunk of weights.

Expand All @@ -127,7 +127,7 @@ def update_weights_chunk(self, update_info: dict) -> None:
- ipc_handles_pickled: b64(pickle({gpu_uuid: (func, args)}))
"""
if not getattr(self, "_skyrl_weight_update_active", False):
raise RuntimeError("start_weight_update must be called before update_weights_chunk.")
raise RuntimeError("start_weight_update must be called before update_weights_ipc.")

if self.weight_transfer_engine is None:
raise RuntimeError(
Expand Down Expand Up @@ -178,13 +178,53 @@ def update_weights_chunk(self, update_info: dict) -> None:
# before the sender drops its reference on the next barrier).
torch.accelerator.synchronize()

def update_weights_nccl(self, update_info: dict) -> None:
    """
    Load one batched weight update delivered through vLLM's NCCL weight
    transfer engine.

    This is the broadcast (non-IPC) counterpart of update_weights_ipc. The
    trainer starts an NCCL broadcast with
    NCCLWeightTransferEngine.trainer_send_weights, and every inference
    worker receives it here via weight_transfer_engine.receive_weights.

    The call goes through this skyrl wrap instead of vLLM's native
    /update_weights endpoint so that loading happens under
    set_current_vllm_config: on MoE models, process_weights_after_loading
    may construct kernels (e.g. FlashInfer CUTLASS) whose __init__ reads
    get_current_vllm_config().

    TODO: remove once the upstream vLLM patch lands (vllm-project/vllm
    weight-sync-fix), then route via the native /update_weights endpoint.
    https://github.com/vllm-project/vllm/pull/42577

    Args:
        update_info: Raw update description; it is converted with the
            engine's parse_update_info before being handed to
            receive_weights.

    Raises:
        RuntimeError: If start_weight_update has not been called, or if no
            weight transfer engine is configured.
    """
    # Guard 1: an update session must have been opened first.
    session_open = getattr(self, "_skyrl_weight_update_active", False)
    if not session_open:
        raise RuntimeError("start_weight_update must be called before update_weights_nccl.")

    # Guard 2: the worker needs a configured transfer engine.
    if self.weight_transfer_engine is None:
        raise RuntimeError(
            "Weight transfer not configured. Please set weight_transfer_config to enable weight transfer."
        )

    # Imported lazily, and only after both guards have passed.
    from vllm.config import set_current_vllm_config

    parsed_info = self.weight_transfer_engine.parse_update_info(update_info)
    target_model = self.model_runner.model

    # Receive and load with the vLLM config installed and the worker's device
    # set as the default torch device.
    with set_current_vllm_config(self.vllm_config):
        with torch.device(self.device):
            self.weight_transfer_engine.receive_weights(
                parsed_info,
                load_weights=target_model.load_weights,
            )

    torch.accelerator.synchronize()

def finish_weight_update(self) -> None:
"""
Finalize the current weight update.

For checkpoint-format weights, runs layerwise postprocessing
(quantization repacking, attention weight processing, etc.).
Must be called after all update_weights_chunk calls are done.
Must be called after all update_weights_ipc calls are done.
"""
if not getattr(self, "_skyrl_weight_update_active", False):
raise RuntimeError("start_weight_update must be called before finish_weight_update.")
Expand Down
Loading
Loading