Skip to content

Commit 0373e18

Browse files
authored
[Core][CUDA Graph] add output buffer for cudagraph to reduce memory footprint (vllm-project#5074)
1 parent c09dade commit 0373e18

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

Diff for: vllm/worker/model_runner.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gc
12
import time
23
import warnings
34
from collections import defaultdict
@@ -894,6 +895,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
894895
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
895896
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
896897

898+
# Prepare buffer for outputs. These will be reused for all batch sizes.
899+
# It will be filled after the first graph capture.
900+
hidden_states: Optional[torch.Tensor] = None
901+
897902
graph_batch_size = _get_graph_batch_size(
898903
self.scheduler_config.max_num_seqs)
899904
batch_size_capture_list = [
@@ -930,9 +935,11 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
930935
self.set_active_loras(set(), lora_mapping)
931936

932937
graph_runner = CUDAGraphRunner(self.model)
933-
graph_runner.capture(
938+
hidden_states = graph_runner.capture(
934939
input_tokens[:batch_size],
935940
input_positions[:batch_size],
941+
hidden_states[:batch_size]
942+
if hidden_states is not None else None,
936943
kv_caches,
937944
attn_metadata,
938945
memory_pool=self.graph_memory_pool,
@@ -969,12 +976,13 @@ def capture(
969976
self,
970977
input_ids: torch.Tensor,
971978
positions: torch.Tensor,
979+
hidden_states: Optional[torch.Tensor],
972980
kv_caches: List[torch.Tensor],
973981
attn_metadata: AttentionMetadata,
974982
memory_pool: Optional[Tuple[int, int]],
975983
stream: torch.cuda.Stream,
976984
**kwargs,
977-
) -> None:
985+
) -> torch.Tensor:
978986
assert self._graph is None
979987
# Run the model a few times without capturing the graph.
980988
# This is to make sure that the captured graph does not include the
@@ -993,13 +1001,21 @@ def capture(
9931001
# Capture the graph.
9941002
self._graph = torch.cuda.CUDAGraph()
9951003
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
996-
hidden_states = self.model(
1004+
output_hidden_states = self.model(
9971005
input_ids,
9981006
positions,
9991007
kv_caches,
10001008
attn_metadata,
10011009
**kwargs,
10021010
)
1011+
if hidden_states is not None:
1012+
hidden_states.copy_(output_hidden_states)
1013+
else:
1014+
hidden_states = output_hidden_states
1015+
del output_hidden_states
1016+
# make sure `output_hidden_states` is deleted
1017+
# in the graph's memory pool
1018+
gc.collect()
10031019
torch.cuda.synchronize()
10041020

10051021
# Save the input and output buffers.
@@ -1012,7 +1028,7 @@ def capture(
10121028
"block_tables": attn_metadata.decode_metadata.block_tables,
10131029
}
10141030
self.output_buffers = {"hidden_states": hidden_states}
1015-
return
1031+
return hidden_states
10161032

10171033
def forward(
10181034
self,

0 commit comments

Comments (0)