[workers] Clean teardown on SIGTERM: drain CUDA + destroy process group#1688

Open
xyuzh wants to merge 2 commits into
NovaSky-AI:mainfrom
xyuzh:xinyu/worker-sigterm-cleanup

@xyuzh xyuzh commented May 18, 2026

Why

When a SkyRL worker actor runs on Anyscale (or any k8s cluster) and the pod gets evicted — preemption, scale-down, spot reclaim, node drain — Kubernetes sends SIGTERM to the container and starts a 25 s grace timer before SIGKILL. The worker process is in the middle of NCCL collectives. If it dies without tearing down, the NCCL communicator stays half-open until its 600 s watchdog timeout, and the rest of the actor group blocks on a collective that will never complete.

That's exactly what we've been seeing on the staging cluster. A representative trace from a job that got preempted at step 6:

(autoscaler +44m27s) Instance k-1a5e56fa0001 (node IP: 10.1.140.199) will be terminated soon (reason: kubernetes-pod-termination).
(autoscaler +44m27s) [autoscaler] Cluster resized to {184 CPU, 8 GPU}.
(autoscaler +44m27s) Instance k-6e0b1ffb0000 (node IP: 10.1.39.138) will be terminated soon (reason: kubernetes-pod-termination).
(autoscaler +44m27s) Instance k-1a5e56fa0000 (node IP: 10.1.159.37) will be terminated soon (reason: kubernetes-pod-termination).
Traceback (most recent call last):
  ...
  File ".../skyrl/train/utils/utils.py", line 801, in get_ray_pg_ready_with_timeout
    raise RuntimeError(
RuntimeError: Failed to create placement group with 1 bundles (requiring 2.0 GPUs, 1.0 CPUs total) in 180 seconds.

Then the very next attempt to start a job inherited the dangling state:

ray.exceptions.LocalRayletDiedError: The task's local raylet died. Check raylet.out for more information.

Same shape on the worker side without the handler — the eviction goes through but the NCCL teardown never happens:

torch.distributed.DistStoreError: Timed out after 601 seconds waiting for clients. 1/16 clients joined.

The pod is SIGKILLed before it can release its NCCL communicator, so when the next cluster spins up, one rank's process group state is already gone. The other 15 ranks wait 601 s for it and the whole run dies.

What this PR changes

Inside DistributedTorchRayActor.init_worker_process_group(), right after torch.distributed.init_process_group(...), install a SIGTERM handler:

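def _sigterm_cleanup(signum, frame):
    # `rank` and `logger` are expected to come from the enclosing scope.
    logger.warning(f"SIGTERM received in worker rank={rank}, cleaning up...")
    # Step 1: wait for queued CUDA work to finish before tearing anything down.
    try:
        torch.cuda.synchronize()
    except Exception as e:
        logger.warning(f"cuda.synchronize() failed: {e}")
    # Step 2: release the NCCL communicator, but only if a process group exists.
    try:
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
    except Exception as e:
        logger.warning(f"destroy_process_group() failed: {e}")
    # Exit with code 0 so the eviction is reported as a clean shutdown.
    sys.exit(0)

# Requires the `signal` and `sys` modules to be imported in worker.py.
signal.signal(signal.SIGTERM, _sigterm_cleanup)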

Two steps, both guarded:

  1. torch.cuda.synchronize() — drain any in-flight CUDA streams so the GPU isn't in the middle of an op when we shut down.
  2. torch.distributed.destroy_process_group() — release the NCCL communicator cleanly, guarded by is_available() and is_initialized() so we don't blow up on CPU-only or pre-init workers.

Each step has its own try/except, so a worker that's only half-healthy still tears down the half that works.
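
To illustrate why the two steps are guarded separately, here is a minimal, standard-library-only sketch of the same pattern. The helper name `run_guarded_steps` and the logger name are hypothetical and not part of this PR; the real handler inlines the two try/except blocks instead of looping.

```python
import logging
from typing import Callable, List, Tuple

# Hypothetical logger name, used only for this sketch.
logger = logging.getLogger("sigterm_cleanup_sketch")


def run_guarded_steps(steps: List[Tuple[str, Callable[[], None]]]) -> List[str]:
    """Run every cleanup step in order and return the names of those that failed.

    Each step gets its own try/except, so a failure in one step is logged
    and does not prevent the later steps from running.
    """
    failed = []
    for name, step in steps:
        try:
            step()
        except Exception as exc:
            logger.warning("%s failed: %s", name, exc)
            failed.append(name)
    return failed


def _broken_sync() -> None:
    # Stand-in for a CUDA drain that fails on a half-healthy worker.
    raise RuntimeError("device unavailable")


released = []
failures = run_guarded_steps([
    ("cuda.synchronize()", _broken_sync),
    ("destroy_process_group()", lambda: released.append("process group")),
])
```

In this run the first step fails and is recorded in `failures`, while the second step still executes. That is the property the handler relies on when only part of the worker is healthy.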

Timing

The whole handler finishes in well under a second when no collective is in flight, and within a couple of seconds when one is — easily inside the 25 s k8s grace window. The pod exits cleanly with code 0; k8s releases the GPU; the next job starts on a fresh slate.
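
One way to check the "exits cleanly with code 0" behaviour without a cluster is to send SIGTERM to a small child process. This is a sketch under assumptions: the child's handler is a stand-in that only calls `sys.exit(0)` (no CUDA or NCCL work), `exit_code_after_sigterm` is a hypothetical helper name, and POSIX signal semantics are assumed.

```python
import signal
import subprocess
import sys

# Child program: installs a SIGTERM handler, reports readiness, then idles.
CHILD_SOURCE = r"""
import signal, sys, time

def _on_sigterm(signum, frame):
    # Stand-in for the real cleanup (CUDA drain, process-group teardown).
    sys.exit(0)

signal.signal(signal.SIGTERM, _on_sigterm)
print("ready", flush=True)
while True:
    time.sleep(0.1)
"""


def exit_code_after_sigterm(timeout_s: float = 10.0) -> int:
    """Start the child, send SIGTERM once its handler is installed, return its exit code."""
    child = subprocess.Popen(
        [sys.executable, "-c", CHILD_SOURCE],
        stdout=subprocess.PIPE,
        text=True,
    )
    try:
        # Block until the child reports that the handler is in place.
        if child.stdout.readline().strip() != "ready":
            raise RuntimeError("child did not start correctly")
        child.send_signal(signal.SIGTERM)
        return child.wait(timeout=timeout_s)
    finally:
        # Make sure no child is left behind if anything above failed.
        if child.poll() is None:
            child.kill()
            child.wait()
        child.stdout.close()
```

With the handler installed the child should report exit code 0. Without a handler, Python's `subprocess` reports a process killed by SIGTERM with a negative return code instead, which makes the two cases easy to tell apart.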

Scope

  • Applies to every actor that goes through DistributedTorchRayActor.init_worker_process_group() — i.e. all Megatron / FSDP policy, ref, critic, value workers.
  • No behavior change for jobs that don't get evicted: the handler just sits installed.

Test plan

After this PR, the Anyscale job terminates successfully and releases its GPU resources when it exits:
https://console.anyscale-staging.com/cld_82np1njz31y9lwk56mc2xcc23x/prj_69fd51u496ygpnaza62xsmis64/jobs/prodjob_vdybitri4yjxffmyepw9jy2mab?job-tab=overview&job-logs-section-tabs=application_logs

…cess group

When a k8s pod is evicted (preemption, scale-down, node drain) the container
gets SIGTERM with a 25s grace period before SIGKILL. Without a handler, in-flight
NCCL collectives leak communicators and the next run may hit stale process group
state.

Add a SIGTERM handler inside DistributedTorchRayActor.init_worker_process_group()
that:
  - calls torch.cuda.synchronize() to drain any in-flight CUDA work
  - calls torch.distributed.destroy_process_group() to release NCCL
  - exits cleanly with sys.exit(0)

Both calls are wrapped in try/except so a partial-state worker still tears down
the half that's healthy. The whole sequence is well under the 25s grace window.

The destroy_process_group() call is additionally guarded
(`torch.distributed.is_available() and is_initialized()`) so it does nothing
when distributed isn't set up yet.
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements a graceful teardown mechanism for worker processes by registering a SIGTERM signal handler that synchronizes CUDA streams and destroys the distributed process group. The review suggests adding a check for torch.cuda.is_available() before calling torch.cuda.synchronize() to avoid unnecessary exceptions in CPU-only environments.

Comment thread skyrl/backends/skyrl_train/workers/worker.py Outdated
Skip the call entirely on CPU-only environments so we don't generate a noisy
warning every time a non-CUDA worker is terminated. Only emit a warning if
synchronize() actually fails on a CUDA-capable system.
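
The suggestion above could look roughly like the following. This is a sketch, not the code that was committed: the helper name `drain_cuda` and its optional `cuda` parameter are hypothetical, added so the logic can be exercised without a GPU or a PyTorch install. Only `torch.cuda.is_available()` and `torch.cuda.synchronize()` are real PyTorch calls.

```python
import logging

# Hypothetical logger name, used only for this sketch.
logger = logging.getLogger("sigterm_cleanup_sketch")


def drain_cuda(cuda=None) -> bool:
    """Synchronize CUDA if it is present; return True only if a synchronize succeeded."""
    if cuda is None:
        # Imported lazily so the sketch can be tested with a stand-in object.
        import torch
        cuda = torch.cuda
    if not cuda.is_available():
        # CPU-only worker: nothing to drain, and nothing worth logging.
        return False
    try:
        cuda.synchronize()
        return True
    except Exception as exc:
        # Reached only on a CUDA-capable system where the drain really failed.
        logger.warning("cuda.synchronize() failed: %s", exc)
        return False
```

Inside the handler, the first try/except block would then reduce to a single `drain_cuda()` call, and CPU-only workers would shut down without emitting a warning.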