@@ -61,29 +61,81 @@ def _is_cpu_clone_active() -> bool:
6161 return getattr (_CPU_CLONE_GUARD , "active" , False )
6262
6363
64+ def _full_zeros_preserving_strides (x : torch .Tensor , device ) -> torch .Tensor :
65+ """Allocate a zero-filled tensor matching ``x``'s size/stride/dtype on ``device``.
66+
67+ Used to re-synthesize KV-cache buffers whose storage was freed (``resize_(0)``)
68+ during the low-memory device move. KV content is all zeros, so this exactly
69+ reproduces the buffer for both the lifted graph value and serialization.
70+ """
71+ needed = 1
72+ for size , stride in zip (x .size (), x .stride ()):
73+ needed += (size - 1 ) * stride
74+ buf = torch .zeros (int (needed ), dtype = x .dtype , device = device )
75+ return torch .as_strided (buf , x .size (), x .stride ())
76+
77+
78+ def _is_emptied (x ) -> bool :
79+ return (
80+ isinstance (x , torch .Tensor )
81+ and x .numel () > 0
82+ and x .untyped_storage ().nbytes () == 0
83+ )
84+
85+
6486@contextlib .contextmanager
6587def _compile_time_cpu_clones (target_device : torch .device ):
6688 """Force AOTI's mutated-buffer clones onto CPU while preserving the
6789 serialized constants' target device."""
68- from torch ._inductor import compile_fx as _cfx
90+ from torch ._inductor import compile_fx as _cfx , graph as _graph
6991 from torch ._inductor .codegen .cpp_wrapper_cpu import CppWrapperCpu as _Cpp
92+ from torch ._inductor .graph import GraphLowering as _GL
7093
7194 orig_clone = _cfx .clone_preserve_strides
7295 orig_codegen_device = _Cpp .codegen_device
96+ orig_get_const = _GL .get_original_value_of_constant
97+ orig_is_same = _graph .is_same_tensor
98+
99+ def _is_same_skip_emptied (data , value ):
100+ # KV buffers freed via resize_(0) all have data_ptr 0, so the stock
101+ # is_same_tensor would treat every same-shape KV constant as a duplicate
102+ # and collapse the 60 layers' caches into one — the runtime needs each
103+ # FQN's own buffer, so the collapsed ones load uninitialized garbage.
104+ # Never dedup an emptied tensor.
105+ if _is_emptied (data ) or _is_emptied (value ):
106+ return False
107+ return orig_is_same (data , value )
73108
74109 def _cpu_clone_preserve_strides (x : torch .Tensor ) -> torch .Tensor :
75- # `clone_preserve_strides` is shared by `_unlift_graph` (clones
76- # lifted buffers — can be safely kept on CPU) and by autotuning code
77- # in `triton_heuristics.py` (clones for benchmark — must stay on
78- # GPU for Triton). Discriminate by caller frame so we only force
79- # CPU clones for the buffer-lifting path.
110+ # `clone_preserve_strides` is shared by `_unlift_graph` (clones lifted
111+ # buffers — can be safely kept on CPU) and by autotuning code in
112+ # `triton_heuristics.py` (clones for benchmark — must stay on GPU for
113+ # Triton). Discriminate by caller frame so we only force CPU clones for
114+ # the buffer-lifting path.
80115 import sys
81116
82117 caller = sys ._getframe (1 ).f_code .co_name
83118 if caller == "_unlift_graph" :
119+ # KV-cache buffers are emptied (storage resize_(0)) by the low-memory
120+ # device move so they never occupy GPU memory during compile. Their
121+ # content is all zeros, so re-synthesize zeros (on CPU, strides
122+ # preserved) instead of cloning the now-empty storage.
123+ if _is_emptied (x ):
124+ return _full_zeros_preserving_strides (x , "cpu" )
84125 return orig_clone (x ).cpu ()
85126 return orig_clone (x )
86127
128+ def _get_const_synthesize_zeros (self , name ):
129+ # AOTI serializes each constant via get_original_value_of_constant ->
130+ # _to_bytes. For KV buffers we freed with resize_(0) this would otherwise
131+ # fall back to the empty-storage constant and write 0 bytes, producing a
132+ # .ptd with an uninitialized cache. Re-synthesize the zeros so the blob
133+ # holds a correctly-zeroed KV cache.
134+ value = orig_get_const (self , name )
135+ if _is_emptied (value ):
136+ return _full_zeros_preserving_strides (value , "cpu" )
137+ return value
138+
87139 def _codegen_device_target_aware (self , device ):
88140 # Translate accidental CPU device strings back to the model target
89141 # device only when a constant we forced to CPU is being serialized.
@@ -99,6 +151,8 @@ def _codegen_device_target_aware(self, device):
99151
100152 _cfx .clone_preserve_strides = _cpu_clone_preserve_strides
101153 _Cpp .codegen_device = _codegen_device_target_aware
154+ _GL .get_original_value_of_constant = _get_const_synthesize_zeros
155+ _graph .is_same_tensor = _is_same_skip_emptied
102156 prev_active = getattr (_CPU_CLONE_GUARD , "active" , False )
103157 _CPU_CLONE_GUARD .active = True
104158 try :
@@ -107,6 +161,89 @@ def _codegen_device_target_aware(self, device):
107161 _CPU_CLONE_GUARD .active = prev_active
108162 _cfx .clone_preserve_strides = orig_clone
109163 _Cpp .codegen_device = orig_codegen_device
164+ _GL .get_original_value_of_constant = orig_get_const
165+ _graph .is_same_tensor = orig_is_same
166+
167+
168+ def _is_kv_buffer (name , v ) -> bool :
169+ return (
170+ isinstance (v , torch .Tensor )
171+ and not isinstance (v , torch .nn .Parameter )
172+ and "kv_cache" in name
173+ )
174+
175+
176+ def _empty_strided_on_device (v , location ):
177+ """A device tensor with v's shape/stride/dtype but zero (freed) storage."""
178+ t = torch .empty_strided (v .shape , v .stride (), dtype = v .dtype , device = location )
179+ t .untyped_storage ().resize_ (0 ) # free bytes, keep device + shape/stride
180+ return t
181+
182+
183+ def _move_graph_nodes_to_device (graph_module , location ):
184+ """Point node device kwargs / aten.to.device targets / meta vals at location."""
185+ import torch .utils ._pytree as pytree
186+
187+ def _to_loc (v ):
188+ return v .to (location ) if isinstance (v , torch .Tensor ) else v
189+
190+ for m in graph_module .modules ():
191+ if not isinstance (m , torch .fx .GraphModule ):
192+ continue
193+ for node in m .graph .nodes :
194+ if "device" in node .kwargs :
195+ node .kwargs = {** node .kwargs , "device" : location }
196+ if node .op == "call_function" and node .target is torch .ops .aten .to .device :
197+ args = list (node .args )
198+ args [1 ] = location
199+ node .args = tuple (args )
200+ node .meta ["val" ] = pytree .tree_map (_to_loc , node .meta .get ("val" ))
201+
202+
203+ def _move_to_device_resize_kv (ep , location ):
204+ """``move_to_device_pass`` variant that frees KV-cache storage on-device.
205+
206+ Mirrors ``torch.export.passes.move_to_device_pass`` exactly, except KV-cache
207+ buffers (FQN contains ``kv_cache``) are placed on ``location`` but with their
208+ storage immediately freed via ``resize_(0)``. This keeps ``device ==
209+ location`` — so the fake-tensor device check on the ``index_copy`` cache
210+ update passes (``self`` and ``values`` both on cuda) — while no real KV bytes
211+ occupy the device during the AOTI compile. KV content is all zeros, so the
212+ emptied tensors are re-synthesized as zeros at the ``_unlift_graph`` clone
213+ (see ``_compile_time_cpu_clones``), which is reused as both the lifted initial
214+ value and the serialized ``.ptd`` constant. The empty/free is interleaved per
215+ tensor so the transient device peak is a single KV buffer, not the whole cache.
216+ Only ``kv_cache`` tensors are emptied (they are the lone large zero-buffers);
217+ every other tensor is moved normally so non-zero content is never lost.
218+ """
219+ import torch .utils ._pytree as pytree
220+
221+ for k , v in ep .state_dict .items ():
222+ if isinstance (v , torch .nn .Parameter ):
223+ ep ._state_dict [k ] = torch .nn .Parameter (v .to (location ), v .requires_grad )
224+ elif _is_kv_buffer (k , v ):
225+ ep ._state_dict [k ] = _empty_strided_on_device (v , location )
226+ else :
227+ ep ._state_dict [k ] = v .to (location )
228+
229+ for k , v in ep .constants .items ():
230+ if isinstance (v , torch .Tensor ):
231+ ep ._constants [k ] = (
232+ _empty_strided_on_device (v , location )
233+ if _is_kv_buffer (k , v )
234+ else v .to (location )
235+ )
236+
237+ if ep .example_inputs is not None :
238+ args , kwargs = ep .example_inputs
239+ ep ._example_inputs = (
240+ pytree .tree_map_only (torch .Tensor , lambda t : t .to (location ), args ),
241+ pytree .tree_map_only (torch .Tensor , lambda t : t .to (location ), kwargs ),
242+ )
243+
244+ _move_graph_nodes_to_device (ep .graph_module , location )
245+ ep .validate ()
246+ return ep
110247
111248
112249@final
@@ -424,6 +561,29 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
424561 return spec .value .decode ("utf-8" ).upper () == "ON"
425562 return False
426563
564+ @classmethod
565+ def move_program_to_device (
566+ cls ,
567+ edge_program ,
568+ device : str ,
569+ compile_specs : List [CompileSpec ],
570+ ):
571+ """Move the program to ``device`` for AOTI compile.
572+
573+ On a low-memory export (``low_memory_mode="ON"``) the KV-cache buffers —
574+ which can be 10+ GiB at long context — are placed on-device but with their
575+ storage freed (``resize_(0)``), so they never occupy device memory during
576+ the autotune / cpp_wrapper compile while still satisfying the device-match
577+ check on the cache update. They are re-synthesized as zeros for the lifted
578+ graph and the serialized blob. This activates automatically with low-memory
579+ mode. Other (non-low-memory) exports use the stock pass.
580+ """
581+ from torch .export .passes import move_to_device_pass
582+
583+ if not cls ._is_low_memory_mode (compile_specs ):
584+ return move_to_device_pass (edge_program , device )
585+ return _move_to_device_resize_kv (edge_program , device )
586+
427587 @classmethod
428588 def release_moved_tensors (
429589 cls ,
0 commit comments