Commit d7f40fc
Parent: 0f8dd72

support matching unfilled blocks for multimodals
6 files changed, +322 -47 lines

lmdeploy/pytorch/engine/cache_engine.py

Lines changed: 10 additions & 0 deletions
@@ -246,6 +246,16 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None:
         """
         self._swap(self.full_gpu_cache, self.full_cpu_cache, src_to_dst)
 
+    def copy_to(self, src_to_dst: Dict[int, int], cache_type: str = 'gpu') -> None:
+        """Copy cache.
+
+        Args:
+            src_to_dst (Dict[int, int]): Map between src and dst.
+            cache_type (str): cache type 'cpu', 'gpu'
+        """
+        target_cache = self.full_gpu_cache if cache_type == 'gpu' else self.full_cpu_cache
+        self._swap(target_cache, target_cache, src_to_dst)
+
     @classmethod
     def get_cache_block_size(cls,
                              block_size: int,
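For orientation, copy_to turns the existing swap primitive into an intra-device copy by passing the same cache as both source and destination of _swap. Below is a minimal, self-contained sketch of that pattern; the tensor layout and the copy_blocks helper are assumptions for illustration, not lmdeploy APIs.

import torch

# Assumed layout: one cache tensor of [num_blocks, block_size, num_heads, head_dim].
num_blocks, block_size, num_heads, head_dim = 8, 16, 4, 32
gpu_cache = torch.randn(num_blocks, block_size, num_heads, head_dim)

def copy_blocks(cache: torch.Tensor, src_to_dst: dict) -> None:
    """Copy block src -> dst inside a single cache, mirroring
    copy_to's _swap(target_cache, target_cache, src_to_dst)."""
    for src, dst in src_to_dst.items():
        cache[dst].copy_(cache[src])

copy_blocks(gpu_cache, {3: 7})  # e.g. duplicate a partially filled block
assert torch.equal(gpu_cache[3], gpu_cache[7])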

lmdeploy/pytorch/engine/engine.py

Lines changed: 2 additions & 0 deletions
@@ -628,6 +628,7 @@ def __need_logits(seqs: SeqList):
         running = scheduler_output.running
         swap_in_map = scheduler_output.swap_in_map
         swap_out_map = scheduler_output.swap_out_map
+        copy_map = scheduler_output.copy_map
         assert len(running) > 0
 
         # create inputs
@@ -645,6 +646,7 @@ def __need_logits(seqs: SeqList):
             inputs=inputs,
             swap_in_map=swap_in_map,
             swap_out_map=swap_out_map,
+            copy_map=copy_map,
             all_ids=all_ids,
             guided_input_ids=guided_input_ids,
             sampling_inputs=sampling_inputs,
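The engine change is pure plumbing: copy_map travels from the scheduler output into the model agent alongside the two swap maps. For orientation only, a hypothetical dataclass mirroring the four fields consumed above (the real SchedulerOutput is defined in the scheduler, not here):

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class SchedulerOutputSketch:
    """Illustrative stand-in for the scheduler output fields used above."""
    running: List                                               # scheduled sequences
    swap_in_map: Dict[int, int] = field(default_factory=dict)   # host block -> device block
    swap_out_map: Dict[int, int] = field(default_factory=dict)  # device block -> host block
    copy_map: Dict[int, int] = field(default_factory=dict)      # device block -> device block (new)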

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 19 additions & 13 deletions
@@ -28,7 +28,7 @@ def msg_with_rank(rank: int, msg: str):
     return f'rank[{rank}] - {msg}'
 
 
-def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict):
+def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict, copy_map: dict):
     """perform cache swapping."""
     issued_cache_op = False
     if len(swap_in_map) > 0:
@@ -37,7 +37,9 @@ def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: d
     if len(swap_out_map) > 0:
         cache_engine.swap_out(swap_out_map)
         issued_cache_op = True
-
+    if len(copy_map) > 0:
+        cache_engine.copy_to(copy_map)
+        issued_cache_op = True
     if issued_cache_op:
         cache_engine.events.wait()
 
@@ -63,7 +65,6 @@ def model_forward(
         kv_quant_policy=cache_engine.cache_config.quant_policy,
     )
     with ctx_mgr.context(context):
-        model_metas = None
         model_metas = model.update_model_metas(
             past_key_values=cache_engine.gpu_cache,
             context=context,
@@ -123,7 +124,7 @@ def all_context(self):
         with device_mgr.context(self.device_ctx), dist_mgr.context(self.dist_ctx):
             yield
 
-    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
+    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, copy_map: SwapMap):
         """model forward.
 
         Args:
@@ -172,7 +173,7 @@ def get_free_mem(self):
         gpu_mem_physical_free, _ = get_gpu_memory()
         return gpu_mem_physical_free
 
-    async def _async_model_forward(self, inputs: ModelInputs, swap_in_map: Dict, swap_out_map: Dict,
+    async def _async_model_forward(self, inputs: ModelInputs, swap_in_map: Dict, swap_out_map: Dict, copy_map: Dict,
                                    return_logits: bool):
         """model forward."""
         max_prefill_token_num = self.cache_config.max_prefill_token_num
@@ -212,12 +213,15 @@ def get_output(self):
 
         async def __forward(inputs):
             """forward."""
-            nonlocal swap_done, swap_in_map, swap_out_map
+            nonlocal swap_done, swap_in_map, swap_out_map, copy_map
             if swap_done:
-                return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict())
+                return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict(), copy_map=dict())
             else:
                 swap_done = True
-                return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+                return await self.async_forward(inputs,
+                                                swap_in_map=swap_in_map,
+                                                swap_out_map=swap_out_map,
+                                                copy_map=copy_map)
 
         async def __long_context_single_forward(inputs):
             """one large sequence."""
@@ -278,7 +282,7 @@ def __get_last_logits():
 
         return next_token_ids
 
-    async def _async_step_background(self, inputs: ModelInputs, swap_in_map: Dict, swap_out_map: Dict,
+    async def _async_step_background(self, inputs: ModelInputs, swap_in_map: Dict, swap_out_map: Dict, copy_map: Dict,
                                      all_ids: torch.Tensor, guided_input_ids: torch.Tensor,
                                      sampling_inputs: SamplingInputs, num_appendable_ids: torch.LongTensor,
                                      num_ignore_eos: torch.LongTensor, loop_count: int, return_logits: bool,
@@ -322,6 +326,7 @@ def __update_inputs(next_token_ids):
             output = await self._async_model_forward(inputs,
                                                      swap_in_map=swap_in_map,
                                                      swap_out_map=swap_out_map,
+                                                     copy_map=copy_map,
                                                      return_logits=return_logits)
             logits = output['logits']
             logits = logits[0]  # [bs, seq, prob] -> [seq, prob]
@@ -359,6 +364,7 @@ def __update_inputs(next_token_ids):
             if is_decoding and idx < loop_count - 1:
                 swap_in_map = dict()
                 swap_out_map = dict()
+                copy_map = dict()
             inputs.model_metas = model_metas
             __update_inputs(next_token_ids)
 
@@ -516,8 +522,8 @@ def build_cache_engine(self):
         with self.all_context():
             self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.tp_rank, world_size=self.tp)
 
-    def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
-        cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+    def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, copy_map: SwapMap):
+        cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map)
         output = model_forward(
             self.patched_model,
             inputs,
@@ -527,15 +533,15 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map:
         )
         return output
 
-    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
+    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, copy_map: SwapMap):
         """model forward.
 
         Args:
             inputs (Dict): The input data comes from _make_inputs.
             swap_in_map (SwapMap): Cache maps to swap in.
             swap_out_map (SwapMap): Cache maps to swap out.
         """
-        output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+        output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map)
         await asyncio.sleep(0)
         return output
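One pattern worth noting in the hunks above: __forward issues the swap/copy maps only for the first chunk of a chunked prefill and passes empty dicts afterwards, so each cache op runs exactly once per engine step. A self-contained sketch of that gating, with a stand-in coroutine in place of the real async_forward:

import asyncio

async def fake_forward(chunk, swap_in_map, swap_out_map, copy_map):
    # Stand-in for async_forward; reports whether any cache op was requested.
    return chunk, bool(swap_in_map or swap_out_map or copy_map)

async def run_chunks(chunks, swap_in_map, swap_out_map, copy_map):
    swap_done = False
    outputs = []
    for chunk in chunks:
        if swap_done:
            outputs.append(await fake_forward(chunk, dict(), dict(), dict()))
        else:
            swap_done = True
            outputs.append(await fake_forward(chunk, swap_in_map, swap_out_map, copy_map))
    return outputs

print(asyncio.run(run_chunks([0, 1, 2], {}, {}, {5: 6})))
# [(0, True), (1, False), (2, False)] -- cache ops fire only on the first chunk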

lmdeploy/pytorch/messages.py

Lines changed: 33 additions & 19 deletions
@@ -371,6 +371,17 @@ def __init__(self, multimodals: MultiModalInputs):
         if multimodals is None:
             multimodals = dict()
         self.multimodals = multimodals
+        self._init_mm_ranges()
+
+    def _init_mm_ranges(self):
+        """init mm ranges and sort it."""
+        mm_ranges = []
+        for _, modal_datas in self.multimodals.items():
+            for modal_data in modal_datas:
+                data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None))
+                mm_ranges.append(data)
+        mm_ranges.sort(key=lambda x: x[1])
+        self._mm_ranges = mm_ranges
 
     def get_datas(self, start=0, end=-1):
         """get multimodals from prompts position [start, end)."""
@@ -389,29 +400,24 @@ def get_datas(self, start=0, end=-1):
     def get_step(self, step: int):
         """get step that before a whole image."""
         real_step = step
-        for modal_type, modal_datas in self.multimodals.items():
-            for modal_data in modal_datas:
-                if modal_data.start > real_step:
-                    continue
-                elif modal_data.end <= real_step:
-                    continue
-                else:
-                    real_step = modal_data.start
+        for start, end, _ in self._mm_ranges:
+            if start <= real_step < end:
+                real_step = start
         return real_step
 
     def get_hash_values(self, start: int, end: int):
         """get multimodals hash values that from [start, end)"""
-        hash_values = []
-        for modal_type, modal_datas in self.multimodals.items():
-            for modal_data in modal_datas:
-                if modal_data.start < end and modal_data.end > start:
-                    if modal_data.meta.get('hash_value', None):
-                        hash_values.append(modal_data.meta['hash_value'])
-        if hash_values:
-            hash_values = tuple(hash_values)
-        else:
-            hash_values = None
-        return hash_values
+        mm_hash_values = []
+        multimodal_ends = []
+        for mm_start, mm_end, hash_value in self._mm_ranges:
+            # the mm range intersect with the target range
+            if mm_start < end and mm_end > start:
+                mm_hash_values.append(hash_value)
+            # the mm end in the target range
+            if start < mm_end <= end:
+                cur_data = (tuple(mm_hash_values), mm_end)
+                multimodal_ends.append(cur_data)
+        return tuple(mm_hash_values), multimodal_ends
 
     def add_inputs(self, input_mms: MultiModalInputs):
         """add new inputs."""
@@ -421,6 +427,14 @@ def add_inputs(self, input_mms: MultiModalInputs):
             else:
                 self.multimodals[modal_type] = vals
 
+            # update mm_ranges
+            for modal_data in vals:
+                data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None))
+                self._mm_ranges.append(data)
+
+        # sort mm_ranges
+        self._mm_ranges.sort(key=lambda x: x[1])
+
     def empty(self):
         if len(self.multimodals) == 0:
             return 0
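Taken together, _mm_ranges lets block matching reason about where each image ends, which is what makes partially filled multimodal blocks matchable. A runnable sketch of the new get_hash_values contract on hypothetical data: ranges are (start, end, hash_value) tuples sorted by end, and for a token window [start, end) the method returns every intersecting hash plus, for each image that finishes inside the window, the hash prefix accumulated up to that point.

mm_ranges = [(5, 10, 'img0'), (12, 20, 'img1')]  # sorted by end, as in _init_mm_ranges

def get_hash_values(start: int, end: int):
    mm_hash_values, multimodal_ends = [], []
    for mm_start, mm_end, hash_value in mm_ranges:
        if mm_start < end and mm_end > start:   # range intersects the window
            mm_hash_values.append(hash_value)
        if start < mm_end <= end:               # image finishes inside the window
            multimodal_ends.append((tuple(mm_hash_values), mm_end))
    return tuple(mm_hash_values), multimodal_ends

print(get_hash_values(0, 16))
# (('img0', 'img1'), [(('img0',), 10)]) -- 'img1' overlaps the window but only 'img0' ends in it

get_step behaves analogously: a step landing inside a [start, end) range snaps back to start, so decoding never resumes mid-image.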
