1
+ import gc
1
2
import time
2
3
import warnings
3
4
from collections import defaultdict
@@ -894,6 +895,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
894
895
seq_lens = torch .ones (max_batch_size , dtype = torch .int32 ).cuda ()
895
896
block_tables = torch .from_numpy (self .graph_block_tables ).cuda ()
896
897
898
+ # Prepare the output buffer. It is reused across all batch sizes.
899
+ # It will be filled after the first graph capture.
900
+ hidden_states : Optional [torch .Tensor ] = None
901
+
897
902
graph_batch_size = _get_graph_batch_size (
898
903
self .scheduler_config .max_num_seqs )
899
904
batch_size_capture_list = [
@@ -930,9 +935,11 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
930
935
self .set_active_loras (set (), lora_mapping )
931
936
932
937
graph_runner = CUDAGraphRunner (self .model )
933
- graph_runner .capture (
938
+ hidden_states = graph_runner .capture (
934
939
input_tokens [:batch_size ],
935
940
input_positions [:batch_size ],
941
+ hidden_states [:batch_size ]
942
+ if hidden_states is not None else None ,
936
943
kv_caches ,
937
944
attn_metadata ,
938
945
memory_pool = self .graph_memory_pool ,
@@ -969,12 +976,13 @@ def capture(
969
976
self ,
970
977
input_ids : torch .Tensor ,
971
978
positions : torch .Tensor ,
979
+ hidden_states : Optional [torch .Tensor ],
972
980
kv_caches : List [torch .Tensor ],
973
981
attn_metadata : AttentionMetadata ,
974
982
memory_pool : Optional [Tuple [int , int ]],
975
983
stream : torch .cuda .Stream ,
976
984
** kwargs ,
977
- ) -> None :
985
+ ) -> torch . Tensor :
978
986
assert self ._graph is None
979
987
# Run the model a few times without capturing the graph.
980
988
# This is to make sure that the captured graph does not include the
@@ -993,13 +1001,21 @@ def capture(
993
1001
# Capture the graph.
994
1002
self ._graph = torch .cuda .CUDAGraph ()
995
1003
with torch .cuda .graph (self ._graph , pool = memory_pool , stream = stream ):
996
- hidden_states = self .model (
1004
+ output_hidden_states = self .model (
997
1005
input_ids ,
998
1006
positions ,
999
1007
kv_caches ,
1000
1008
attn_metadata ,
1001
1009
** kwargs ,
1002
1010
)
1011
+ if hidden_states is not None :
1012
+ hidden_states .copy_ (output_hidden_states )
1013
+ else :
1014
+ hidden_states = output_hidden_states
1015
+ del output_hidden_states
1016
+ # make sure `output_hidden_states` is deleted
1017
+ # in the graph's memory pool
1018
+ gc .collect ()
1003
1019
torch .cuda .synchronize ()
1004
1020
1005
1021
# Save the input and output buffers.
@@ -1012,7 +1028,7 @@ def capture(
1012
1028
"block_tables" : attn_metadata .decode_metadata .block_tables ,
1013
1029
}
1014
1030
self .output_buffers = {"hidden_states" : hidden_states }
1015
- return
1031
+ return hidden_states
1016
1032
1017
1033
def forward (
1018
1034
self ,
0 commit comments