@@ -28,7 +28,7 @@ def msg_with_rank(rank: int, msg: str):
2828 return f'rank[{ rank } ] - { msg } '
2929
3030
31- def cache_swapping (cache_engine : CacheEngine , swap_in_map : dict , swap_out_map : dict ):
31+ def cache_swapping (cache_engine : CacheEngine , swap_in_map : dict , swap_out_map : dict , copy_map : dict ):
3232 """perform cache swapping."""
3333 issued_cache_op = False
3434 if len (swap_in_map ) > 0 :
@@ -37,7 +37,9 @@ def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: d
3737 if len (swap_out_map ) > 0 :
3838 cache_engine .swap_out (swap_out_map )
3939 issued_cache_op = True
40-
40+ if len (copy_map ) > 0 :
41+ cache_engine .copy_to (copy_map )
42+ issued_cache_op = True
4143 if issued_cache_op :
4244 cache_engine .events .wait ()
4345
@@ -135,7 +137,7 @@ def all_context(self):
135137 def _forward_impl (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap ):
136138 raise NotImplementedError ('NotImplemented.' )
137139
138- async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap ):
140+ async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap , copy_map : SwapMap ):
139141 """model forward.
140142
141143 Args:
@@ -200,6 +202,7 @@ async def _async_model_forward(
200202 inputs : ModelInputs ,
201203 swap_in_map : Dict ,
202204 swap_out_map : Dict ,
205+ copy_map : Dict ,
203206 return_logits : bool ,
204207 sync_long_context : bool ,
205208 ):
@@ -241,12 +244,13 @@ def get_output(self):
241244
242245 async def __forward (inputs ):
243246 """forward."""
244- nonlocal swap_done , swap_in_map , swap_out_map
247+ nonlocal swap_done , swap_in_map , swap_out_map , copy_map
248+ breakpoint ()
245249 if swap_done :
246- return await self .async_forward (inputs , swap_in_map = dict (), swap_out_map = dict ())
250+ return await self .async_forward (inputs , swap_in_map = dict (), swap_out_map = dict (), copy_map = dict () )
247251 else :
248252 swap_done = True
249- return await self .async_forward (inputs , swap_in_map = swap_in_map , swap_out_map = swap_out_map )
253+ return await self .async_forward (inputs , swap_in_map = swap_in_map , swap_out_map = swap_out_map , copy_map = copy_map )
250254
251255 async def __long_context_single_forward (new_inputs , max_seqlen : int ):
252256 """one large sequence."""
@@ -334,6 +338,7 @@ async def _async_step_background(
334338 inputs : ModelInputs ,
335339 swap_in_map : Dict ,
336340 swap_out_map : Dict ,
341+ copy_map : Dict ,
337342 loop_count : int ,
338343 all_ids : torch .Tensor = None ,
339344 guided_input_ids : torch .Tensor = None ,
@@ -420,6 +425,7 @@ async def __await_distworker(worker, timeout: float = 0.001):
420425 inputs ,
421426 swap_in_map = swap_in_map ,
422427 swap_out_map = swap_out_map ,
428+ copy_map = copy_map ,
423429 return_logits = return_logits ,
424430 sync_long_context = sync_long_context ,
425431 )
@@ -467,6 +473,7 @@ async def __await_distworker(worker, timeout: float = 0.001):
467473 if is_decoding and idx < loop_count - 1 :
468474 swap_in_map = dict ()
469475 swap_out_map = dict ()
476+ copy_map = dict ()
470477 inputs .model_metas = model_metas
471478 __update_inputs (next_token_ids )
472479
@@ -637,8 +644,8 @@ def build_cache_engine(self):
637644
638645 self .cache_engine = CacheEngine (self .cache_config , self .model_config , world_size = tp )
639646
640- def _forward_impl (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap ):
641- cache_swapping (self .cache_engine , swap_in_map = swap_in_map , swap_out_map = swap_out_map )
647+ def _forward_impl (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap , copy_map : SwapMap ):
648+ cache_swapping (self .cache_engine , swap_in_map = swap_in_map , swap_out_map = swap_out_map , copy_map = copy_map )
642649 output = model_forward (
643650 self .patched_model ,
644651 inputs ,
@@ -647,15 +654,15 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map:
647654 )
648655 return output
649656
650- async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap ):
657+ async def async_forward (self , inputs : ModelInputs , swap_in_map : SwapMap , swap_out_map : SwapMap , copy_map : SwapMap ):
651658 """model forward.
652659
653660 Args:
654661 inputs (Dict): The input data comes from _make_inputs.
655662 swap_in_map (SwapMap): Cache maps to swap in.
656663 swap_out_map (SwapMap): Cache maps to swap out.
657664 """
658- output = self ._forward_impl (inputs , swap_in_map = swap_in_map , swap_out_map = swap_out_map )
665+ output = self ._forward_impl (inputs , swap_in_map = swap_in_map , swap_out_map = swap_out_map , copy_map = copy_map )
659666 await asyncio .sleep (0 )
660667 return output
661668
0 commit comments