[ci] H100 CI #1679
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| name: H100-GPU-CI | ||
|
|
||
| on: | ||
| push: | ||
| branches: | ||
| - main | ||
| paths: | ||
| - 'ci/**' | ||
| - 'skyrl/backends/skyrl_train/**' | ||
| - 'skyrl/train/**' | ||
| - 'tests/backends/skyrl_train/gpu/gpu_ci/**' | ||
| - 'pyproject.toml' | ||
| - '!docs/**' | ||
| - '!examples/**' | ||
| - '.github/workflows/**' | ||
| pull_request_target: | ||
| types: [labeled] | ||
| workflow_dispatch: | ||
|
|
||
|
|
||
| permissions: | ||
| checks: write # for status checks to appear | ||
| contents: read | ||
|
|
||
| jobs: | ||
|
|
||
| skyrl_train_tests_h100: | ||
| if: > | ||
| github.event_name == 'push' || | ||
| github.event_name == 'workflow_dispatch' || | ||
| ( | ||
| github.event_name == 'pull_request_target' && | ||
| !github.event.pull_request.draft && | ||
| contains(github.event.pull_request.labels.*.name, 'run_h100_gpu_ci') && | ||
| ( | ||
| github.event.pull_request.author_association == 'MEMBER' || | ||
| github.event.pull_request.author_association == 'OWNER' || | ||
| github.event.pull_request.author_association == 'COLLABORATOR' | ||
| ) | ||
| ) | ||
| runs-on: ubuntu-latest | ||
| defaults: | ||
| run: | ||
| shell: bash | ||
| working-directory: . | ||
|
|
||
| steps: | ||
| - uses: actions/checkout@v4 | ||
| with: | ||
| ref: ${{ github.event.pull_request.head.sha || github.ref }} | ||
| - name: Set up Python | ||
| uses: actions/setup-python@v5 | ||
| with: | ||
| python-version: '3.12' | ||
| cache: 'pip' | ||
| - name: Install the latest version of uv | ||
| uses: astral-sh/setup-uv@v6 | ||
| with: | ||
| activate-environment: true | ||
| - name: Install basic dependencies | ||
| run: uv pip install anyscale==0.24.79 typer==0.9.0 | ||
| # Run h100 tests via anyscale staging (compute config llm-team-h100-4x:1) | ||
| - name: GPU tests | ||
| env: | ||
| ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN_STAGING }} | ||
| ANYSCALE_HOST: https://console.anyscale-staging.com | ||
| run: | | ||
| anyscale job submit -f ci/anyscale_gpu_ci_h100.yaml --timeout 5400 | ||
| anyscale job wait --name skyrl-train-gpu-ci-h100 --timeout 5400 |
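
The `if:` gate on `skyrl_train_tests_h100` can be read as a small predicate. The sketch below restates it in Python for clarity only; the function name `should_run_h100_job` and the dictionary shape of `pull_request` are illustrative assumptions, not part of the workflow. It also compares strings exactly, whereas GitHub Actions expressions ignore case when comparing strings.

```python
# Hypothetical restatement of the workflow's job-level `if:` condition.
TRUSTED_ASSOCIATIONS = {"MEMBER", "OWNER", "COLLABORATOR"}
REQUIRED_LABEL = "run_h100_gpu_ci"


def should_run_h100_job(event_name, pull_request=None):
    """Return True when the gate in the workflow would let the job run."""
    # Pushes to main and manual dispatches always run.
    if event_name in ("push", "workflow_dispatch"):
        return True
    # Any other event must be a pull_request_target event with a payload.
    if event_name != "pull_request_target" or pull_request is None:
        return False
    labels = {label["name"] for label in pull_request.get("labels", [])}
    return (
        not pull_request.get("draft", False)
        and REQUIRED_LABEL in labels
        and pull_request.get("author_association") in TRUSTED_ASSOCIATIONS
    )


# Example payload (assumed shape): a non-draft, labeled PR from a member.
example_pr = {
    "draft": False,
    "labels": [{"name": "run_h100_gpu_ci"}],
    "author_association": "MEMBER",
}
print(should_run_h100_job("pull_request_target", example_pr))
```

Read this way, a labeled PR from an outside contributor does not trigger the job, which matters because `pull_request_target` runs with access to repository secrets.
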
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| name: skyrl-train-gpu-ci-h100 | ||
| entrypoint: bash ci/gpu_ci_run_h100.sh | ||
| image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8-megatron | ||
| ray_version: "2.51.1" | ||
| compute_config: llm-team-h100-4x:1 | ||
| working_dir: . | ||
| max_retries: 0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| #!/usr/bin/env bash | ||
| set -xeuo pipefail | ||
|
|
||
| export CI=true | ||
|
|
||
| # Prepare datasets used in tests (Megatron test uses gsm8k env_class). | ||
| uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k | ||
|
|
||
| # Run FSDP h100 tests. Use --extra fsdp so the fsdp-only tests can collect. | ||
| uv run --directory . --isolated --extra dev --extra fsdp pytest -s -vvv -m h100 \ | ||
| tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py | ||
|
|
||
| # Run Megatron h100 tests. Use --extra megatron so the megatron-only tests can collect. | ||
| uv run --directory . --isolated --extra dev --extra megatron pytest -s -vvv -m h100 \ | ||
| tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_megatron_models.py |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -129,100 +129,56 @@ def _sync_non_persistent_buffers(model: torch.nn.Module, loaded_sd: dict): | |||||
| module._buffers[key] = src.cpu() | ||||||
|
|
||||||
|
|
||||||
| # Fsdp2 load full state dict from `accelerate` | ||||||
| # Reference: https://github.com/huggingface/accelerate/blob/0af621bbecc0e43f5d43766a4945d3d2236bb8a9/src/accelerate/utils/fsdp_utils.py#L455 | ||||||
| # NOTE (sumanthrh): The original code from `accelerate` assumes init on meta device - with cpu init only on rank 0, but the code is compatible with cpu init on all ranks. | ||||||
| def fsdp2_load_full_state_dict(model: torch.nn.Module, full_sd: dict, cpu_offload=None): | ||||||
| """ | ||||||
| Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the | ||||||
| parameters from rank 0 to all other ranks. This function modifies the model in-place. | ||||||
| Loads a full state dict (assumed populated on rank 0) into a sharded FSDP2 | ||||||
| model by broadcasting from rank 0 and distributing per-parameter shards. | ||||||
|
|
||||||
| Uses PyTorch's `set_model_state_dict` with `full_state_dict=True` + | ||||||
| `broadcast_from_rank0=True`, which internally does the per-parameter | ||||||
| broadcast and DTensor distribution that we used to do manually. The | ||||||
| utility also releases intermediate staging tensors as it goes, so the | ||||||
| caching allocator can return memory to the device pool after init. | ||||||
|
|
||||||
| Args: | ||||||
| model (`torch.nn.Module`): | ||||||
| The model to load the state dict into, expected to be on meta device or a VRAM spike can occur | ||||||
| full_sd (`dict`): The full state dict to load, can be only on rank 0 | ||||||
| The model to load the state dict into, expected to be on meta | ||||||
| device on all ranks except rank 0 (or on meta on all ranks and | ||||||
| full_sd populated on rank 0). | ||||||
| full_sd (`dict`): The full state dict to load (only rank 0 needs | ||||||
| real data; non-rank-0 ranks may pass an empty dict). | ||||||
| """ | ||||||
| import torch.distributed as dist | ||||||
| from torch.distributed.tensor import distribute_tensor | ||||||
|
|
||||||
| # Model was previously copied to meta device | ||||||
| meta_sharded_sd = model.state_dict() | ||||||
| sharded_sd = {} | ||||||
|
|
||||||
| # Rank 0 distributes the full state dict to other ranks | ||||||
| def _infer_parameter_dtype(model, param_name, empty_param): | ||||||
| try: | ||||||
| old_param = model.get_parameter_or_buffer(param_name) | ||||||
| except AttributeError: | ||||||
| # Need this for LORA, as there some params are not *parameters* of sorts | ||||||
| base_param_name, local_param_name = param_name.rsplit(".", 1) | ||||||
| submodule = model.get_submodule(base_param_name) | ||||||
| old_param = getattr(submodule, local_param_name) | ||||||
|
|
||||||
| is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") | ||||||
| casting_dtype = None | ||||||
| is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn | ||||||
|
|
||||||
| if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: | ||||||
| casting_dtype = old_param.dtype | ||||||
|
|
||||||
| return old_param is not None and old_param.is_contiguous(), casting_dtype | ||||||
|
|
||||||
| def _cast_and_contiguous(tensor, to_contiguous, dtype): | ||||||
| if dtype is not None: | ||||||
| tensor = tensor.to(dtype=dtype) | ||||||
| if to_contiguous: | ||||||
| tensor = tensor.contiguous() | ||||||
| return tensor | ||||||
|
|
||||||
| if dist.get_rank() == 0: | ||||||
| for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()): | ||||||
| full_param = full_param.detach().cuda() | ||||||
| mesh = sharded_param.device_mesh | ||||||
| dist.broadcast(full_param, src=0) | ||||||
| sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements) | ||||||
| to_contiguous, casting_dtype = _infer_parameter_dtype( | ||||||
| model, | ||||||
| param_name, | ||||||
| full_param, | ||||||
| ) | ||||||
| sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) | ||||||
| sharded_sd[param_name] = sharded_tensor | ||||||
| # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock | ||||||
| else: | ||||||
| for param_name, sharded_param in meta_sharded_sd.items(): | ||||||
| full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype) | ||||||
| mesh = sharded_param.device_mesh | ||||||
| dist.broadcast(full_tensor, src=0) | ||||||
| sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements) | ||||||
| to_contiguous, casting_dtype = _infer_parameter_dtype( | ||||||
| model, | ||||||
| param_name, | ||||||
| full_tensor, | ||||||
| ) | ||||||
| sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) | ||||||
| sharded_sd[param_name] = sharded_tensor | ||||||
|
|
||||||
| # we set `assign=True` because our params can be on meta device | ||||||
| model.load_state_dict(sharded_sd, assign=True) | ||||||
|
|
||||||
| # Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding) that | ||||||
| # are excluded from state_dict. On non-rank-0 meta-init these are still on | ||||||
| # meta device with no data; rank 0 has the correctly computed values. | ||||||
| _sync_non_persistent_buffers(model, sharded_sd) | ||||||
|
|
||||||
| # If we don't offload FSDP2 Module to CPU and then back to GPU, | ||||||
| # it will occupy a large amount of reserved GPU memory,which can not be released using torch.cuda.empty_cache() | ||||||
| # even if we are using cpu_offload | ||||||
| # TODO (erictang000): this requires an additional offload + backload, see if this can be avoided | ||||||
| # Credit: https://github.com/volcengine/verl/pull/1667 | ||||||
| offload_fsdp2_model_to_cpu(model) | ||||||
|
|
||||||
| torch.cuda.synchronize() | ||||||
| torch.cuda.empty_cache() | ||||||
| from torch.distributed.checkpoint.state_dict import ( | ||||||
| StateDictOptions, | ||||||
| set_model_state_dict, | ||||||
| ) | ||||||
|
|
||||||
| set_model_state_dict( | ||||||
| model, | ||||||
| full_sd, | ||||||
| options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), | ||||||
| ) | ||||||
|
|
||||||
| if not cpu_offload: | ||||||
| load_fsdp2_model_to_gpu(model) | ||||||
| # Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding) | ||||||
| # that are excluded from state_dict. On non-rank-0 meta-init these are | ||||||
| # still on the meta device with no data; rank 0 has the correct values. | ||||||
| _sync_non_persistent_buffers(model, model.state_dict()) | ||||||
|
The second argument to
Suggested change
|
||||||
|
|
||||||
| # NOTE: removed the offload_fsdp2_model_to_cpu + load_fsdp2_model_to_gpu | ||||||
| # dance from verl PR #1667. That trick was meant to release reserved-but- | ||||||
| # unallocated cache memory, but for FSDP2 the offload is a no-op | ||||||
| # (`model.to("cpu")` doesn't move FSDPParam-managed storage) and the reload | ||||||
| # then *allocates a second copy*. Diagnostic on Qwen3.5-35B-A3B: | ||||||
| # after set_model_state_dict: 32.71 GiB | ||||||
| # after offload+empty_cache: 32.71 GiB (unchanged) | ||||||
| # after load: 65.42 GiB (doubled) | ||||||
| # set_model_state_dict already leaves us with exactly the shard on GPU; | ||||||
| # no clean-up is needed. | ||||||
| if cpu_offload: | ||||||
| # Caller asked for CPU-resident params; the offload path is still | ||||||
| # broken for FSDP2 but we leave the request explicit so a future fix | ||||||
| # has an obvious hook. | ||||||
| offload_fsdp2_model_to_cpu(model) | ||||||
| return model | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
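
The diagnostic quoted in the NOTE comment can be checked with a few lines of arithmetic. This is only a sketch over the three readings reported in that comment; the dictionary and the `growth_factor` helper are illustrative names, and no new measurements are implied.

```python
# Readings in GiB, copied from the NOTE comment in the diff.
readings_gib = {
    "after set_model_state_dict": 32.71,
    "after offload+empty_cache": 32.71,
    "after load": 65.42,
}


def growth_factor(before: float, after: float) -> float:
    """Ratio of a later memory reading to an earlier one."""
    return after / before


baseline = readings_gib["after set_model_state_dict"]
# The offload step frees nothing (ratio 1), and the reload roughly doubles usage.
offload_factor = growth_factor(baseline, readings_gib["after offload+empty_cache"])
reload_factor = growth_factor(baseline, readings_gib["after load"])
print(f"offload: x{offload_factor:.2f}, reload: x{reload_factor:.2f}")
```
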
Calling `state_dict()` on all ranks materializes the full model state in memory on every process. Since `fsdp2_load_full_state_dict` uses `set_model_state_dict` with `broadcast_from_rank0=True`, only rank 0 needs the actual data. On non-zero ranks, this leads to significant and unnecessary CPU memory usage, which can cause OOMs on nodes with many GPUs when training large models.
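
One way to act on this comment is to build the full state dict only on rank 0 and pass an empty dict elsewhere, which matches the updated docstring ("non-rank-0 ranks may pass an empty dict"). The sketch below is a minimal, framework-free illustration; `full_sd_for_rank` and `load_fn` are hypothetical names, and in real code the rank would presumably come from `torch.distributed.get_rank()`.

```python
from typing import Any, Callable, Dict


def full_sd_for_rank(rank: int, load_fn: Callable[[], Dict[str, Any]]) -> Dict[str, Any]:
    """Return real weights on rank 0 and an empty dict on every other rank.

    `load_fn` is a hypothetical callable that reads the checkpoint; it is
    invoked only on rank 0, so other ranks never hold the full weights.
    """
    if rank == 0:
        return load_fn()
    return {}


def fake_load() -> Dict[str, Any]:
    # Stand-in for an expensive checkpoint read (assumption for the demo).
    return {"layer.weight": [0.0, 1.0]}


for rank in range(4):
    sd = full_sd_for_rank(rank, fake_load)
    print(rank, len(sd))
```

The resulting dict would then be handed to `fsdp2_load_full_state_dict` unchanged, leaving the broadcast to `set_model_state_dict`.
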