Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tests/pytorch/distributed/test_comm_gemm_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path

import pytest
Expand All @@ -15,6 +14,8 @@
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import get_device_compute_capability

from utils import run_proctree_with_timeout as run_subprocess


if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
Expand Down Expand Up @@ -88,7 +89,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization
if aggregate:
test_cmd.append("--aggregate")

result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
result = run_subprocess(test_cmd, 120 if IS_HIP_EXTENSION else None,
env=os.environ, capture_output=True, check=False)
if (
result.returncode != 0
or "NUMERICAL CHECK FAILED" in result.stderr.decode()
Expand Down Expand Up @@ -143,7 +145,8 @@ def _run_layer_with_overlap(
# not show up in more recent GPUs.
os.environ["NVTE_FLASH_ATTN"] = "0"

result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
result = run_subprocess(test_cmd, 120 if IS_HIP_EXTENSION else None,
env=os.environ, capture_output=True, check=False)

os.unsetenv("PYTORCH_JIT")
os.unsetenv("NVTE_TORCH_COMPILE")
Expand Down
9 changes: 7 additions & 2 deletions tests/pytorch/distributed/test_torch_fsdp2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import pytest
import subprocess
from pathlib import Path
import transformer_engine.pytorch as te

import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

from utils import run_proctree_with_timeout as run_subprocess


fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
Expand All @@ -32,7 +36,8 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
test_cmd += ["--recipe", recipe]
test_cmd += ["--layer-type", layer_type]

result = subprocess.run(test_cmd, env=os.environ, check=True)
result = run_subprocess(test_cmd, 120 if IS_HIP_EXTENSION else None, env=os.environ,
check=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using run_proctree_with_timeout, could we just use coreutils' timeout command?

This could perhaps even become a pytest fixture:

# in conftest.py:
# Duration handed to the `timeout` command when the subprocess_timeout marker
# is used without a positional argument.
_DEFAULT_TIMEOUT = 120
# Value handed to `timeout --kill-after=...` when the subprocess_timeout marker
# does not supply a kill_after keyword.
_DEFAULT_KILL_AFTER = 30


def pytest_configure(config):
    """Declare the ``subprocess_timeout`` marker on the pytest config.

    Appends a single entry to the ``markers`` ini value of ``config`` that
    describes the marker's signature and purpose.
    """
    marker_help = (
        "subprocess_timeout(seconds, kill_after=30): "
        "wrap subprocess.run calls with coreutils timeout to detect hangs"
    )
    config.addinivalue_line("markers", marker_help)


@pytest.fixture(autouse=True)
def subprocess_timeout(request, monkeypatch):
    """Run ``subprocess.run`` commands under coreutils ``timeout`` for marked tests.

    The fixture is autouse but does nothing unless the requesting test carries
    a ``subprocess_timeout`` marker. For marked tests, ``subprocess.run`` is
    monkeypatched so that every command is prefixed with
    ``timeout --kill-after=<kill_after> <seconds>``, and a run that ends
    because of the timeout fails the test with an explicit hang message.

    Marker arguments:
        seconds: first positional argument; falls back to ``_DEFAULT_TIMEOUT``.
        kill_after: keyword argument; falls back to ``_DEFAULT_KILL_AFTER``.
    """
    import signal  # local import: only needed to name SIGKILL below

    marker = request.node.get_closest_marker("subprocess_timeout")
    if marker is None:
        return

    seconds = str(marker.args[0]) if marker.args else str(_DEFAULT_TIMEOUT)
    kill_after = str(marker.kwargs.get("kill_after", _DEFAULT_KILL_AFTER))

    # Keep a handle on the real implementation before it is patched.
    original_run = subprocess.run

    # Statuses seen when --kill-after had to escalate to SIGKILL: `timeout`
    # reports 128+9, or is itself killed by the group-wide SIGKILL, which
    # Python surfaces as a negative return code.
    hard_kill_statuses = (128 + signal.SIGKILL, -signal.SIGKILL)

    def _run_with_timeout(cmd, *args, **kwargs):
        cmd = ["timeout", f"--kill-after={kill_after}", seconds] + list(cmd)
        result = original_run(cmd, *args, **kwargs)
        if result.returncode == 124:
            # 124 is `timeout`'s own status for "command timed out".
            pytest.fail(f"Subprocess timed out after {seconds}s (hang detected)")
        if result.returncode in hard_kill_statuses:
            # Without this branch a hang that ignores SIGTERM would surface
            # only as an anonymous non-zero exit code.
            pytest.fail(
                f"Subprocess was killed with SIGKILL; it most likely ignored "
                f"SIGTERM after the {seconds}s timeout (hang detected)"
            )
        return result

    monkeypatch.setattr(subprocess, "run", _run_with_timeout)

# for the test:
@pytest.mark.subprocess_timeout(20, kill_after=2)
def test_xyz():

This would have a few advantages:

  • No sys.path modification, shorter code
  • No code changes inside test functions
  • signal/kill whole process group natively

@ipanfilo ipanfilo May 29, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a brilliant idea. It can be done even more simply by modifying the command line right in the tests. I've modified the PR. The first CI run is at https://github.com/ROCm/TransformerEngine/actions/runs/26611894136/job/78425262492

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good.

There are no timeout-related messages in the distributed-test logs at https://github.com/ROCm/TransformerEngine/actions/runs/26611894136/job/78425262492 as far as I can tell, and the distributed tests passed, so the mechanism was not exercised in that run. Not sure if you want to test this further within this PR, but either way, good to go from my side.



@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
Expand Down
9 changes: 6 additions & 3 deletions tests/pytorch/distributed/test_torch_fsdp2_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import os
from typing import List
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
import torch
from run_fsdp2_fp8_model import SimpleNet
from utils import run_proctree_with_timeout as run_subprocess

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
Expand Down Expand Up @@ -49,8 +49,11 @@ def _run_test(fp_init, recipe):
test_cmd += ["--fp8-init"]
test_cmd += ["--recipe", recipe]

subprocess.run(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True)
subprocess.run(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], env=os.environ, check=True)
timeout = 120
run_subprocess(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'],
timeout, env=os.environ, check=True)
run_subprocess(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], timeout,
env=os.environ, check=True)

# Load outputs
output_fsdp = torch.load("all_iters_fsdp2.pt", map_location="cpu")
Expand Down
60 changes: 60 additions & 0 deletions tests/pytorch/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

import os, signal, subprocess


def run_proctree_with_timeout(cmd, timeout, **kwargs):
"""Run a command in a subprocess and check for errors."""

if timeout is None:
return subprocess.run(cmd, **kwargs)

if "timeout" in kwargs:
raise ValueError("Timeout should be passed as a separate argument, not in kwargs")

capture_output = kwargs.pop("capture_output", False)
if capture_output:
kwargs["stdout"] = subprocess.PIPE
kwargs["stderr"] = subprocess.PIPE
else:
stdout, stderr = None, None

check = kwargs.pop("check", False)

kwargs["start_new_session"] = True # To use killpg as termination fallback
p = subprocess.Popen(cmd, **kwargs)
try:
if capture_output:
stdout, stderr = p.communicate(timeout=timeout)
else:
p.wait(timeout=timeout)
except subprocess.TimeoutExpired:
p.terminate()
try:
# Give the process time to terminate gracefully
if capture_output:
stdout, stderr = p.communicate(timeout=timeout)
else:
p.wait(timeout=timeout)
except subprocess.TimeoutExpired:
os.killpg(p.pid, signal.SIGKILL)
if capture_output:
stdout, stderr = p.communicate()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description states the goal is to "kill all child processes on timeout", but p.terminate() only sends SIGTERM to p.pid (the session leader — i.e. torchrun), not to the whole process group. The actual hung workers are children of torchrun, and torchrun generally won't forward SIGTERM to them promptly. So the "graceful" stage almost never reclaims the workers; only the killpg(SIGKILL) fallback does, and that's gated behind a second full timeout wait. With timeout=120, the worst-case wall-clock before the hang is actually cleared is ~240s.

Two suggestions:

  1. Make the graceful stage match the stated intent by signaling the whole group: os.killpg(p.pid, signal.SIGTERM) instead of p.terminate().
  2. Use a short, bounded grace window (e.g. 10–30s) for the second wait/communicate rather than reusing timeout, so SIGKILL fires promptly when graceful shutdown is ignored.

Secondary concern: swallowing TimeoutExpired and returning a CompletedProcess with a negative returncode makes hang-killed runs indistinguishable from ordinary failures in the test logs. Since the whole point of this wrapper is to surface hangs, consider either re-raising TimeoutExpired (matching subprocess.run(timeout=...) semantics) or at minimum logging a clear "timed out, killed" line so CI failures are diagnosable.


# Handle check=True
if check and p.returncode != 0:
raise subprocess.CalledProcessError(
cmd,
kwargs.get("args", None),
output=stdout,
stderr=stderr
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CalledProcessError constructor signature is CalledProcessError(returncode, cmd, output=None, stderr=None), but this call passes the wrong values in the wrong order: the cmd list goes where returncode (an int) is expected, and kwargs.get("args", None) is always None here because args was never inserted into kwargs — cmd is a positional parameter of this function, not a kwarg. The exception will still raise, but exc.returncode will be a list and exc.cmd will be None, breaking any caller that introspects the exception (and obscuring the failure in pytest tracebacks). Suggested fix:

Suggested change
raise subprocess.CalledProcessError(
cmd,
kwargs.get("args", None),
output=stdout,
stderr=stderr
)
# Handle check=True
if check and p.returncode != 0:
raise subprocess.CalledProcessError(
p.returncode,
cmd,
output=stdout,
stderr=stderr,
)


return subprocess.CompletedProcess(
cmd,
p.returncode,
stdout,
stderr
)
Loading