Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import gc
import logging
import math
import os
import time
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from typing import Any, Optional, Union

import numpy as np
Expand Down Expand Up @@ -2283,7 +2284,18 @@ def capture_cudagraph(self):
# TBO graphs don't capture compute_logits, so disable logits_in_graph.
self.logits_in_graph = self.world_size == 1 and not is_tbo

with graph_capture() as gc:
@contextmanager
def pause_gc():
# No GC during capture: a finalizer's hipModuleUnload aborts it (HIP 900).
gc.collect()
gc.disable()
try:
yield
finally:
gc.enable()
gc.collect()
Comment on lines +2288 to +2296

with pause_gc(), graph_capture() as capture_ctx:
capture_range = (
tqdm.tqdm(self.graph_bs) if self.rank == 0 else self.graph_bs
)
Expand Down Expand Up @@ -2362,7 +2374,7 @@ def capture_cudagraph(self):
input_ids[:num_tokens],
positions[:num_tokens],
self.graph_pool,
gc.stream,
capture_ctx.stream,
output_buffer=outputs[:num_tokens],
)
graph_aux = None
Expand All @@ -2374,7 +2386,9 @@ def capture_cudagraph(self):
if self.use_mrope
else positions[:num_tokens]
)
with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream):
with torch.cuda.graph(
graph, self.graph_pool, stream=capture_ctx.stream
):
model_output = self.model(
input_ids[:num_tokens],
model_positions,
Expand Down
Loading