@@ -28,7 +28,7 @@ def msg_with_rank(rank: int, msg: str):
2828 return f'rank[{ rank } ] - { msg } '
2929
3030
31- def cache_swapping (cache_engine : CacheEngine , swap_in_map : dict , swap_out_map : dict ):
31+ def cache_swapping (cache_engine : CacheEngine , swap_in_map : dict , swap_out_map : dict , copy_map : dict ):
3232 """perform cache swapping."""
3333 issued_cache_op = False
3434 if len (swap_in_map ) > 0 :
@@ -37,7 +37,9 @@ def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: d
3737 if len (swap_out_map ) > 0 :
3838 cache_engine .swap_out (swap_out_map )
3939 issued_cache_op = True
40-
40+ if len (copy_map ) > 0 :
41+ cache_engine .copy_to (copy_map )
42+ issued_cache_op = True
4143 if issued_cache_op :
4244 cache_engine .events .wait ()
4345
@@ -135,7 +137,7 @@ def all_context(self):
def _forward_impl(
    self,
    inputs: ModelInputs,
    swap_in_map: SwapMap,
    swap_out_map: SwapMap,
    copy_map: SwapMap = None,
):
    """forward implementation, to be provided by subclasses.

    Args:
        inputs (ModelInputs): The model inputs.
        swap_in_map (SwapMap): Cache maps to swap in.
        swap_out_map (SwapMap): Cache maps to swap out.
        copy_map (SwapMap): Cache maps to copy. Optional here so that
            existing three-argument callers keep working while the
            signature accepts the keyword that async_forward receives.

    Raises:
        NotImplementedError: Always; this base version has no implementation.
    """
    raise NotImplementedError('NotImplemented.')
137139
138- async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap ):
140+ async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap , copy_map : SwapMap ):
139141 """model forward.
140142
141143 Args:
@@ -200,6 +202,7 @@ async def _async_model_forward(
200202 inputs : ModelInputs ,
201203 swap_in_map : Dict ,
202204 swap_out_map : Dict ,
205+ copy_map : Dict ,
203206 return_logits : bool ,
204207 sync_long_context : bool ,
205208 ):
@@ -241,12 +244,12 @@ def get_output(self):
241244
242245 async def __forward (inputs ):
243246 """forward."""
244- nonlocal swap_done , swap_in_map , swap_out_map
247+ nonlocal swap_done , swap_in_map , swap_out_map , copy_map
245248 if swap_done :
246- return await self .async_forward (inputs , swap_in_map = dict (), swap_out_map = dict ())
249+ return await self .async_forward (inputs , swap_in_map = dict (), swap_out_map = dict (), copy_map = dict () )
247250 else :
248251 swap_done = True
249- return await self .async_forward (inputs , swap_in_map = swap_in_map , swap_out_map = swap_out_map )
252+ return await self .async_forward (inputs , swap_in_map = swap_in_map , swap_out_map = swap_out_map , copy_map = copy_map )
250253
251254 async def __long_context_single_forward (new_inputs , max_seqlen : int ):
252255 """one large sequence."""
@@ -334,6 +337,7 @@ async def _async_step_background(
334337 inputs : ModelInputs ,
335338 swap_in_map : Dict ,
336339 swap_out_map : Dict ,
340+ copy_map : Dict ,
337341 loop_count : int ,
338342 all_ids : torch .Tensor = None ,
339343 guided_input_ids : torch .Tensor = None ,
@@ -420,6 +424,7 @@ async def __await_distworker(worker, timeout: float = 0.001):
420424 inputs ,
421425 swap_in_map = swap_in_map ,
422426 swap_out_map = swap_out_map ,
427+ copy_map = copy_map ,
423428 return_logits = return_logits ,
424429 sync_long_context = sync_long_context ,
425430 )
@@ -467,6 +472,7 @@ async def __await_distworker(worker, timeout: float = 0.001):
467472 if is_decoding and idx < loop_count - 1 :
468473 swap_in_map = dict ()
469474 swap_out_map = dict ()
475+ copy_map = dict ()
470476 inputs .model_metas = model_metas
471477 __update_inputs (next_token_ids )
472478
@@ -637,8 +643,8 @@ def build_cache_engine(self):
637643
638644 self .cache_engine = CacheEngine (self .cache_config , self .model_config , world_size = tp )
639645
640- def _forward_impl (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap ):
641- cache_swapping (self .cache_engine , swap_in_map = swap_in_map , swap_out_map = swap_out_map )
646+ def _forward_impl (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap , copy_map : SwapMap ):
647+ cache_swapping (self .cache_engine , swap_in_map = swap_in_map , swap_out_map = swap_out_map , copy_map = copy_map )
642648 output = model_forward (
643649 self .patched_model ,
644650 inputs ,
@@ -647,15 +653,15 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map:
647653 )
648654 return output
649655
async def async_forward(
    self,
    inputs: ModelInputs,
    swap_in_map: SwapMap,
    swap_out_map: SwapMap,
    copy_map: SwapMap,
):
    """model forward.

    Args:
        inputs (Dict): The input data comes from _make_inputs.
        swap_in_map (SwapMap): Cache maps to swap in.
        swap_out_map (SwapMap): Cache maps to swap out.
        copy_map (SwapMap): Cache maps to copy.

    Returns:
        Whatever ``_forward_impl`` returns for these inputs.
    """
    result = self._forward_impl(
        inputs,
        swap_in_map=swap_in_map,
        swap_out_map=swap_out_map,
        copy_map=copy_map,
    )
    # Yield to the event loop once before handing the result back.
    await asyncio.sleep(0)
    return result
661667
0 commit comments