@@ -28,7 +28,7 @@ def msg_with_rank(rank: int, msg: str):
2828 return f'rank[{ rank } ] - { msg } '
2929
3030
31- def cache_swapping (cache_engine : CacheEngine , swap_in_map : dict , swap_out_map : dict ):
31+ def cache_swapping (cache_engine : CacheEngine , swap_in_map : dict , swap_out_map : dict , copy_map : dict ):
3232 """perform cache swapping."""
3333 issued_cache_op = False
3434 if len (swap_in_map ) > 0 :
@@ -37,7 +37,9 @@ def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: d
3737 if len (swap_out_map ) > 0 :
3838 cache_engine .swap_out (swap_out_map )
3939 issued_cache_op = True
40-
40+ if len (copy_map ) > 0 :
41+ cache_engine .copy_to (copy_map )
42+ issued_cache_op = True
4143 if issued_cache_op :
4244 cache_engine .events .wait ()
4345
@@ -63,7 +65,6 @@ def model_forward(
6365 kv_quant_policy = cache_engine .cache_config .quant_policy ,
6466 )
6567 with ctx_mgr .context (context ):
66- model_metas = None
6768 model_metas = model .update_model_metas (
6869 past_key_values = cache_engine .gpu_cache ,
6970 context = context ,
@@ -123,7 +124,7 @@ def all_context(self):
123124 with device_mgr .context (self .device_ctx ), dist_mgr .context (self .dist_ctx ):
124125 yield
125126
126- async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap ):
127+ async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap , copy_map : SwapMap ):
127128 """model forward.
128129
129130 Args:
@@ -172,7 +173,7 @@ def get_free_mem(self):
172173 gpu_mem_physical_free , _ = get_gpu_memory ()
173174 return gpu_mem_physical_free
174175
175- async def _async_model_forward (self , inputs : ModelInputs , swap_in_map : Dict , swap_out_map : Dict ,
176+ async def _async_model_forward (self , inputs : ModelInputs , swap_in_map : Dict , swap_out_map : Dict , copy_map : Dict ,
176177 return_logits : bool ):
177178 """model forward."""
178179 max_prefill_token_num = self .cache_config .max_prefill_token_num
@@ -212,12 +213,15 @@ def get_output(self):
212213
213214 async def __forward (inputs ):
214215 """forward."""
215- nonlocal swap_done , swap_in_map , swap_out_map
216+ nonlocal swap_done , swap_in_map , swap_out_map , copy_map
216217 if swap_done :
217- return await self .async_forward (inputs , swap_in_map = dict (), swap_out_map = dict ())
218+ return await self .async_forward (inputs , swap_in_map = dict (), swap_out_map = dict (), copy_map = dict () )
218219 else :
219220 swap_done = True
220- return await self .async_forward (inputs , swap_in_map = swap_in_map , swap_out_map = swap_out_map )
221+ return await self .async_forward (inputs ,
222+ swap_in_map = swap_in_map ,
223+ swap_out_map = swap_out_map ,
224+ copy_map = copy_map )
221225
222226 async def __long_context_single_forward (inputs ):
223227 """one large sequence."""
@@ -278,7 +282,7 @@ def __get_last_logits():
278282
279283 return next_token_ids
280284
281- async def _async_step_background (self , inputs : ModelInputs , swap_in_map : Dict , swap_out_map : Dict ,
285+ async def _async_step_background (self , inputs : ModelInputs , swap_in_map : Dict , swap_out_map : Dict , copy_map : Dict ,
282286 all_ids : torch .Tensor , guided_input_ids : torch .Tensor ,
283287 sampling_inputs : SamplingInputs , num_appendable_ids : torch .LongTensor ,
284288 num_ignore_eos : torch .LongTensor , loop_count : int , return_logits : bool ,
@@ -322,6 +326,7 @@ def __update_inputs(next_token_ids):
322326 output = await self ._async_model_forward (inputs ,
323327 swap_in_map = swap_in_map ,
324328 swap_out_map = swap_out_map ,
329+ copy_map = copy_map ,
325330 return_logits = return_logits )
326331 logits = output ['logits' ]
327332 logits = logits [0 ] # [bs, seq, prob] -> [seq, prob]
@@ -359,6 +364,7 @@ def __update_inputs(next_token_ids):
359364 if is_decoding and idx < loop_count - 1 :
360365 swap_in_map = dict ()
361366 swap_out_map = dict ()
367+ copy_map = dict ()
362368 inputs .model_metas = model_metas
363369 __update_inputs (next_token_ids )
364370
@@ -516,8 +522,8 @@ def build_cache_engine(self):
516522 with self .all_context ():
517523 self .cache_engine = CacheEngine (self .cache_config , self .model_config , self .tp_rank , world_size = self .tp )
518524
519- def _forward_impl (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap ):
520- cache_swapping (self .cache_engine , swap_in_map = swap_in_map , swap_out_map = swap_out_map )
525+ def _forward_impl (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap , copy_map : SwapMap ):
526+ cache_swapping (self .cache_engine , swap_in_map = swap_in_map , swap_out_map = swap_out_map , copy_map = copy_map )
521527 output = model_forward (
522528 self .patched_model ,
523529 inputs ,
@@ -527,15 +533,15 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map:
527533 )
528534 return output
529535
530- async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap ):
536+ async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap , copy_map : SwapMap ):
531537 """model forward.
532538
533539 Args:
534540 inputs (Dict): The input data comes from _make_inputs.
535541 swap_in_map (SwapMap): Cache maps to swap in.
536542 swap_out_map (SwapMap): Cache maps to swap out.
537543 """
538- output = self ._forward_impl (inputs , swap_in_map = swap_in_map , swap_out_map = swap_out_map )
544+ output = self ._forward_impl (inputs , swap_in_map = swap_in_map , swap_out_map = swap_out_map , copy_map = copy_map )
539545 await asyncio .sleep (0 )
540546 return output
541547
0 commit comments