[workers] Clean teardown on SIGTERM: drain CUDA + destroy process group#1688
Open
xyuzh wants to merge 2 commits into
Open
[workers] Clean teardown on SIGTERM: drain CUDA + destroy process group#1688xyuzh wants to merge 2 commits into
xyuzh wants to merge 2 commits into
Conversation
…cess group When a k8s pod is evicted (preemption, scale-down, node drain) the container gets SIGTERM with a 25s grace period before SIGKILL. Without a handler, in-flight NCCL collectives leak communicators and the next run may hit stale process group state. Add a SIGTERM handler inside DistributedTorchRayActor.init_worker_process_group() that: - calls torch.cuda.synchronize() to drain any in-flight CUDA work - calls torch.distributed.destroy_process_group() to release NCCL - exits cleanly with sys.exit(0) Both calls are wrapped in try/except so a partial-state worker still tears down the half that's healthy. The whole sequence is well under the 25s grace window. Each call is guarded (`torch.distributed.is_available() and is_initialized()`) so it does nothing when distributed isn't set up yet.
Contributor
There was a problem hiding this comment.
Code Review
This pull request implements a graceful teardown mechanism for worker processes by registering a SIGTERM signal handler that synchronizes CUDA streams and destroys the distributed process group. The review suggests adding a check for torch.cuda.is_available() before calling torch.cuda.synchronize() to avoid unnecessary exceptions in CPU-only environments.
Skip the call entirely on CPU-only environments so we don't generate a noisy warning every time a non-CUDA worker is terminated. Only emit a warning if synchronize() actually fails on a CUDA-capable system.
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Why
When a SkyRL worker actor runs on Anyscale (or any k8s cluster) and the pod gets evicted — preemption, scale-down, spot reclaim, node drain — Kubernetes sends
SIGTERMto the container and starts a 25 s grace timer beforeSIGKILL. The worker process is in the middle of NCCL collectives. If it dies without tearing down, the NCCL communicator stays half-open until its 600 s watchdog timeout, and the rest of the actor group blocks on a collective that will never complete.That's exactly what we've been seeing on the staging cluster. A representative trace from a job that got preempted at step 6:
Then the very next attempt to start a job inherited the dangling state:
Same shape on the worker side without the handler — the eviction goes through but the NCCL teardown never happens:
The pod never gets to release its NCCL communicator before SIGKILL hits, so the next cluster spins up and one rank's process group state is already gone — the other 15 wait 601 s and the whole run dies.
What this PR changes
Inside
DistributedTorchRayActor.init_worker_process_group(), right aftertorch.distributed.init_process_group(...), install a SIGTERM handler:Two steps, both guarded:
torch.cuda.synchronize()— drain any in-flight CUDA streams so the GPU isn't in the middle of an op when we shut down.torch.distributed.destroy_process_group()— release the NCCL communicator cleanly, guarded byis_available() and is_initialized()so we don't blow up on CPU-only or pre-init workers.Each step has its own
try/except, so a worker that's only half-healthy still tears down the half that works.Timing
The whole handler finishes in well under a second when no collective is in flight, and within a couple of seconds when one is — easily inside the 25 s k8s grace window. The pod exits cleanly with code 0; k8s releases the GPU; the next job starts on a fresh slate.
Scope
DistributedTorchRayActor.init_worker_process_group()— i.e. all Megatron / FSDP policy, ref, critic, value workers.Test plan
After this PR, the anyscale job terminates successfully with releasing the GPU resources after the job exits
https://console.anyscale-staging.com/cld_82np1njz31y9lwk56mc2xcc23x/prj_69fd51u496ygpnaza62xsmis64/jobs/prodjob_vdybitri4yjxffmyepw9jy2mab?job-tab=overview&job-logs-section-tabs=application_logs