10 changes: 10 additions & 0 deletions lmdeploy/pytorch/engine/cache_engine.py
@@ -277,6 +277,16 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None:
"""
self._swap(self.full_gpu_cache, self.full_cpu_cache, src_to_dst)

def copy_to(self, src_to_dst: Dict[int, int], cache_type: str = 'gpu') -> None:
"""Copy cache.

Args:
src_to_dst (Dict[int, int]): Map between src and dst.
cache_type (str): The cache to copy within, either 'gpu' or 'cpu'.
"""
target_cache = self.full_gpu_cache if cache_type == 'gpu' else self.full_cpu_cache
self._swap(target_cache, target_cache, src_to_dst)

@classmethod
def get_cache_block_size(cls,
block_size: int,
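Note: copy_to drives the same _swap path as swap_in/swap_out, but with source and target set to the same cache, so it duplicates blocks within the GPU (or CPU) cache rather than moving data across devices. A minimal usage sketch, assuming cache_engine is an already-constructed CacheEngine and with made-up block ids:

# duplicate GPU block 3 into block 7 and block 4 into block 9
cache_engine.copy_to({3: 7, 4: 9}, cache_type='gpu')
# the copy is issued asynchronously, like the swaps; callers wait on the
# engine's events afterwards (see cache_swapping in model_agent.py below)
cache_engine.events.wait()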
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -931,6 +931,7 @@ def __need_schedule_again(prefill: bool, scheduler_output):
running = scheduler_output.running
swap_in_map = scheduler_output.swap_in_map
swap_out_map = scheduler_output.swap_out_map
copy_map = scheduler_output.copy_map

if len(running) == 0:
return None
@@ -949,6 +950,7 @@ def __need_schedule_again(prefill: bool, scheduler_output):
inputs=inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map,
copy_map=copy_map,
loop_count=num_loops,
sampling_inputs=sampling_inputs,
stopping_criteria=stopping_criteria,
12 changes: 7 additions & 5 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -201,18 +201,21 @@ def msg_with_rank(rank: int, msg: str):
return f'rank[{rank}] - {msg}'


def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict, copy_map: dict):
"""Perform cache swapping."""
issued_cache_op = False
swap_in_map = swap_in_map or dict()
swap_out_map = swap_out_map or dict()
copy_map = copy_map or dict()
if len(swap_in_map) > 0:
cache_engine.swap_in(swap_in_map)
issued_cache_op = True
if len(swap_out_map) > 0:
cache_engine.swap_out(swap_out_map)
issued_cache_op = True

if len(copy_map) > 0:
cache_engine.copy_to(copy_map)
issued_cache_op = True
if issued_cache_op:
cache_engine.events.wait()
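A sketch of how the extended helper can be invoked, with illustrative block-id maps and an already-constructed cache_engine; the `or dict()` guards above make None maps a no-op:

# swap one block in, two blocks out, and duplicate one GPU block
cache_swapping(cache_engine, swap_in_map={12: 3}, swap_out_map={5: 40, 6: 41}, copy_map={3: 7})

# with every map empty or None, no cache op is issued and no wait happens
cache_swapping(cache_engine, swap_in_map=None, swap_out_map=None, copy_map=None)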

@@ -592,6 +595,7 @@ async def _async_step_background(
loop_count: int,
swap_in_map: Dict = None,
swap_out_map: Dict = None,
copy_map: Dict = None,
sampling_inputs: SamplingInputs = None,
stopping_criteria: StoppingCriteria = None,
return_logits: bool = False,
@@ -699,7 +703,7 @@ async def __prepare_dp():
# init state cache for first time prefill
# I don't know if this is necessary...
self.state_cache_engine.init_caches(inputs.state_offsets, inputs.history_lengths == 0)
cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map)
for idx in range(loop_count):
# inference
logger.debug(f'<ForwardTask> rank[{rank}]: model forward [{idx}].')
@@ -995,8 +999,6 @@ async def async_forward(self, inputs: ModelInputs):

Args:
inputs (Dict): The input data comes from _make_inputs.
"""
output = self._forward_impl(inputs)
await asyncio.sleep(0)
78 changes: 72 additions & 6 deletions lmdeploy/pytorch/messages.py
@@ -413,21 +413,67 @@ def __init__(self, multimodals: MultiModalInputs = None):
if multimodals is None:
multimodals = dict()
self.multimodals = multimodals
self._init_mm_ranges()

def _init_mm_ranges(self):
"""Init mm ranges and sort it."""
mm_ranges = []
for _, modal_datas in self.multimodals.items():
for modal_data in modal_datas:
data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None))
mm_ranges.append(data)
mm_ranges.sort(key=lambda x: x[1])
self._mm_ranges = mm_ranges

@property
def mm_ranges(self):
"""mm_ranges."""
return self._mm_ranges

def get_datas(self, start=0, end=-1):
"""Get multimodals from prompts position [start, end)."""
outs = dict()
for modal_type, modal_datas in self.multimodals.items():
data = []
for modal_data in modal_datas:
if modal_data.start < end and modal_data.end > start:
data.append(modal_data)
if len(data) > 0:
outs[modal_type] = data
return outs

def get_step(self, step: int) -> int:
"""Get step that before a whole image."""
real_step = step
for start, end, _ in self._mm_ranges:
if start <= real_step < end:
real_step = start
return real_step

def has_data(self, start: int, end: int) -> bool:
"""Whether has multimodal data in [start, end)"""
return any([s < end and e > start for s, e, _ in self._mm_ranges])

def get_hash_values(self, start: int, end: int):
"""Get multimodals hash values that from [start, end)"""
mm_hash_values = []
multimodal_ends = []

for mm_start, mm_end, hash_value in self._mm_ranges:
# the mm range intersects the target range
if mm_start < end and mm_end > start:
mm_hash_values.append(hash_value)
# the mm range ends within the target range
if start < mm_end <= end:
cur_data = (tuple(mm_hash_values), mm_end)
multimodal_ends.append(cur_data)

if len(mm_hash_values) == 0:
mm_hash_values = None
else:
mm_hash_values = tuple(mm_hash_values)
return mm_hash_values, multimodal_ends

def add_inputs(self, input_mms: MultiModalInputs):
"""Add new inputs."""
for modal_type, vals in input_mms.items():
@@ -436,7 +482,15 @@ def add_inputs(self, input_mms: MultiModalInputs):
else:
self.multimodals[modal_type] = vals

# update mm_ranges
for modal_data in vals:
data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None))
self._mm_ranges.append(data)

# sort mm_ranges
self._mm_ranges.sort(key=lambda x: x[1])

def empty(self) -> bool:
if len(self.multimodals) == 0:
return True

@@ -655,7 +709,7 @@ def _update_multimodals(self, multimodals: MultiModalInputs):
if multimodals is None:
self._num_cross = 0
return
multimodals = HistoryMultiModals.update_multimodals(multimodals, self._num_history_ids)
self.history_multimodals.add_inputs(multimodals)

# for mllama
@@ -674,3 +728,15 @@ def update_token_ids(self,
def set_step(self, step: int):
"""Set step."""
raise NotImplementedError('NotImplemented')

def __repr__(self):
return (f'SchedulerSequence(seq_id={self.seq_id}, session_id={self.session_id}, '
f'status={self.status}, arrive_time={self.arrive_time}, '
f'return_logits={self.return_logits}, sampling_param={self.sampling_param}, '
f'num_history_ids={self.num_history_ids}, num_all_tokens={self.num_all_ids}, '
f'num_new_tokens={self.num_new_tokens}, all_token_ids={self.all_ids}, '
f'mm_ranges={self.history_multimodals.mm_ranges}, '
f'num_gpu_blocks={self.num_blocks}, gpu_blocks={self.logical_blocks.get_real_blocks()}, '
f'last_shared_node={getattr(self.logical_blocks, "last_shared_node", None)})')

__str__ = __repr__
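A small worked example of the new multimodal bookkeeping, with invented token positions and hash strings; it pokes the private _mm_ranges field directly purely for illustration:

from lmdeploy.pytorch.messages import HistoryMultiModals

# a sequence with two images covering token ranges [4, 8) and [20, 26)
hist = HistoryMultiModals()
hist._mm_ranges = [(4, 8, 'img_a'), (20, 26, 'img_b')]  # (start, end, hash_value)

hist.get_step(6)     # -> 4: step 6 falls inside the first image, move back to its start
hist.get_step(10)    # -> 10: not inside any image, unchanged
hist.has_data(0, 4)  # -> False: [0, 4) touches no image
hist.has_data(7, 9)  # -> True: overlaps the tail of the first image

hist.get_hash_values(0, 16)
# -> (('img_a',), [(('img_a',), 8)])
#    the first image intersects [0, 16), and its end (8) lies inside the
#    range, so it is also reported as a usable boundary
hist.get_hash_values(0, 32)
# -> (('img_a', 'img_b'), [(('img_a',), 8), (('img_a', 'img_b'), 26)])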
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/chatglm2.py
@@ -863,13 +863,14 @@ def preprocess_input(self,
offset = input_mm['offset']
num_pad = input_mm['image_tokens']
image_token_id = input_mm['image_token_id']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
meta=dict(image_token_id=image_token_id))
meta=dict(image_token_id=image_token_id, hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
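The model files below all follow the same pattern as chatglm2 above: read hash_value from the preprocessor's input_mm dict and thread it into MultiModalTensor.meta. How the hash is produced is outside this diff; purely as an illustration of the kind of key assumed (not lmdeploy's actual implementation), a content hash of the raw image could look like:

import hashlib

from PIL import Image  # assumption: images are available as PIL images


def image_hash(image: Image.Image) -> str:
    # illustrative content hash usable as a multimodal prefix-cache key
    return hashlib.sha256(image.tobytes()).hexdigest()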
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/cogvlm.py
@@ -898,13 +898,14 @@ def preprocess_input(self, input_ids: List[int], input_multimodals=None, **kwarg
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
meta=dict(image_token_id=image_token_id))
meta=dict(image_token_id=image_token_id, hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/models/deepseek_vl2.py
@@ -146,7 +146,7 @@ def __init__(self,
# ----------- language model ------------
language_config = config.language_config
self.language = DeepseekV2ForCausalLM(config=language_config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)

self.config = language_config
# ----------- input processor ------------
self.input_processor = DeepSeekVLV2InputProcessor(config, dtype)

@@ -436,6 +436,7 @@ def preprocess_input(self,
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
images_spatial_crop = input_mm.get('images_spatial_crop', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()
Expand All @@ -445,6 +446,7 @@ def preprocess_input(self,
end=offset + num_pad,
meta=dict(
image_token_id=image_token_id,
hash_value=hash_value,
images_spatial_crop=images_spatial_crop,
))

3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/gemma3_vl.py
@@ -105,13 +105,14 @@ def preprocess_input(self,
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
meta=dict(image_token_id=image_token_id))
meta=dict(image_token_id=image_token_id, hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/internvl.py
@@ -991,13 +991,14 @@ def preprocess_input(self,
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
meta=dict(image_token_id=image_token_id))
meta=dict(image_token_id=image_token_id, hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/llava.py
@@ -552,13 +552,14 @@ def preprocess_input(self,
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
meta=dict(image_token_id=image_token_id))
meta=dict(image_token_id=image_token_id, hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/models/phi3_v.py
@@ -381,13 +381,16 @@ def preprocess_input(self,
offset = input_mm['offset']
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=offset,
end=offset + num_pad,
meta=dict(image_sizes=image_sizes, image_token_id=image_token_id))
meta=dict(image_sizes=image_sizes,
image_token_id=image_token_id,
hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/models/qwen2_5_vl.py
@@ -716,13 +716,16 @@ def preprocess_input(self,
start = offset
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=start,
end=start + num_pad,
meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))
meta=dict(grid_thw=image_grid_thw,
image_token_id=image_token_id,
hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/models/qwen2_vl.py
@@ -924,13 +924,16 @@ def preprocess_input(self,
start = offset
image_token_id = input_mm['image_token_id']
num_pad = input_mm['image_tokens']
hash_value = input_mm.get('hash_value', None)
if isinstance(num_pad, torch.Tensor):
num_pad = num_pad.item()

mm_data = MultiModalTensor(data=pixel_values,
start=start,
end=start + num_pad,
meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))
meta=dict(grid_thw=image_grid_thw,
image_token_id=image_token_id,
hash_value=hash_value))
input_imgs.append(mm_data)

result = PreprocessInputResult(