Skip to content

Commit 069399b

Browse files
authored
fix(cudagraph): pause GC during capture to prevent Triton finalizer abort (#1322) (#1339)
During CUDAGraph capture, MiniMax-M3's autotuned _topk_index_partial_kernel discards candidate CompiledKernels. A gen-0 GC firing inside the stream-capture region runs CompiledKernel.__del__ -> hipModuleUnload, which HIP forbids while a stream is capturing (HIP 900), corrupting the capture and aborting the custom_all_reduce IPC handshake (SIGABRT). gc.freeze() did not help because the discarded kernels are created mid-loop. Disable GC for the whole capture window and restore via try/finally.
1 parent b9cff14 commit 069399b

1 file changed

Lines changed: 18 additions & 4 deletions

File tree

atom/model_engine/model_runner.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# SPDX-License-Identifier: MIT
22
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
33

4+
import gc
45
import logging
56
import math
67
import os
78
import time
8-
from contextlib import nullcontext
9+
from contextlib import contextmanager, nullcontext
910
from typing import Any, Optional, Union
1011

1112
import numpy as np
@@ -2283,7 +2284,18 @@ def capture_cudagraph(self):
22832284
# TBO graphs don't capture compute_logits, so disable logits_in_graph.
22842285
self.logits_in_graph = self.world_size == 1 and not is_tbo
22852286

2286-
with graph_capture() as gc:
2287+
@contextmanager
2288+
def pause_gc():
2289+
# No GC during capture: a finalizer's hipModuleUnload aborts it (HIP 900).
2290+
gc.collect()
2291+
gc.disable()
2292+
try:
2293+
yield
2294+
finally:
2295+
gc.enable()
2296+
gc.collect()
2297+
2298+
with pause_gc(), graph_capture() as capture_ctx:
22872299
capture_range = (
22882300
tqdm.tqdm(self.graph_bs) if self.rank == 0 else self.graph_bs
22892301
)
@@ -2362,7 +2374,7 @@ def capture_cudagraph(self):
23622374
input_ids[:num_tokens],
23632375
positions[:num_tokens],
23642376
self.graph_pool,
2365-
gc.stream,
2377+
capture_ctx.stream,
23662378
output_buffer=outputs[:num_tokens],
23672379
)
23682380
graph_aux = None
@@ -2374,7 +2386,9 @@ def capture_cudagraph(self):
23742386
if self.use_mrope
23752387
else positions[:num_tokens]
23762388
)
2377-
with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream):
2389+
with torch.cuda.graph(
2390+
graph, self.graph_pool, stream=capture_ctx.stream
2391+
):
23782392
model_output = self.model(
23792393
input_ids[:num_tokens],
23802394
model_positions,

0 commit comments

Comments
 (0)