Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions skyrl/backends/skyrl_train/workers/worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import logging
import os
import signal
import socket
import sys
from collections import defaultdict
from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p
from datetime import timedelta
Expand Down Expand Up @@ -121,6 +123,28 @@ def init_worker_process_group(self):
backend="cpu:gloo,cuda:nccl", timeout=timedelta(seconds=SKYRL_WORKER_NCCL_TIMEOUT_IN_S)
)

# Clean teardown on k8s SIGTERM: drain CUDA streams + release NCCL
# communicators before the 25s grace period elapses.
rank = self._rank

def _sigterm_cleanup(signum, frame):
    """Best-effort teardown run when the worker process receives SIGTERM.

    Waits for outstanding CUDA work, destroys the default torch.distributed
    process group, and then exits the process with status 0. Each cleanup
    step is attempted independently; a failure is logged as a warning and
    does not prevent the following step or the final exit.

    Args:
        signum: Number of the signal being handled (supplied by the
            ``signal`` module).
        frame: Stack frame that was interrupted by the signal; unused.

    Raises:
        SystemExit: Always, via ``sys.exit(0)``, once cleanup was attempted.
    """
    # Restore the default disposition first: if one of the cleanup calls
    # below blocks inside native code, a repeated signal then terminates
    # the process immediately instead of only queueing this handler again.
    signal.signal(signum, signal.SIG_DFL)

    # ``rank`` is captured from the enclosing scope.
    logger.warning(f"SIGTERM received in worker rank={rank}, cleaning up...")

    try:
        # Only synchronize when a CUDA context already exists; calling
        # synchronize() otherwise would lazily initialize CUDA during
        # shutdown, which is wasted work inside the grace period.
        if torch.cuda.is_available() and torch.cuda.is_initialized():
            torch.cuda.synchronize()
    except Exception as e:
        logger.warning(f"cuda.synchronize() failed: {e}")

    try:
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
    except Exception as e:
        logger.warning(f"destroy_process_group() failed: {e}")

    sys.exit(0)

signal.signal(signal.SIGTERM, _sigterm_cleanup)

# setup device mesh
# TODO: Support TP / PP for additional backends
# NOTE (sumanthrh): Device mesh and mesh rank are rank specific attributes. For the current way the strategy is defined,
Expand Down
Loading