diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 8f16d4c10..aaf878949 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import gc import logging import math import os import time -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from typing import Any, Optional, Union import numpy as np @@ -2283,7 +2284,18 @@ def capture_cudagraph(self): # TBO graphs don't capture compute_logits, so disable logits_in_graph. self.logits_in_graph = self.world_size == 1 and not is_tbo - with graph_capture() as gc: + @contextmanager + def pause_gc(): + # No GC during capture: a finalizer's hipModuleUnload aborts it (HIP 900). + gc.collect() + gc.disable() + try: + yield + finally: + gc.enable() + gc.collect() + + with pause_gc(), graph_capture() as capture_ctx: capture_range = ( tqdm.tqdm(self.graph_bs) if self.rank == 0 else self.graph_bs ) @@ -2362,7 +2374,7 @@ def capture_cudagraph(self): input_ids[:num_tokens], positions[:num_tokens], self.graph_pool, - gc.stream, + capture_ctx.stream, output_buffer=outputs[:num_tokens], ) graph_aux = None @@ -2374,7 +2386,9 @@ def capture_cudagraph(self): if self.use_mrope else positions[:num_tokens] ) - with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream): + with torch.cuda.graph( + graph, self.graph_pool, stream=capture_ctx.stream + ): model_output = self.model( input_ids[:num_tokens], model_positions,