diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index 96c0cb63c3..27ba8fbe21 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -277,6 +277,16 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None: """ self._swap(self.full_gpu_cache, self.full_cpu_cache, src_to_dst) + def copy_to(self, src_to_dst: Dict[int, int], cache_type: str = 'gpu') -> None: + """Copy cache. + + Args: + src_to_dst (Dict[int, int]): Map between src and dst. + cache_type (str): cache type 'cpu', 'gpu' + """ + target_cache = self.full_gpu_cache if cache_type == 'gpu' else self.full_cpu_cache + self._swap(target_cache, target_cache, src_to_dst) + @classmethod def get_cache_block_size(cls, block_size: int, diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index f583ed7224..b8e4428d72 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -931,6 +931,7 @@ def __need_schedule_again(prefill: bool, scheduler_output): running = scheduler_output.running swap_in_map = scheduler_output.swap_in_map swap_out_map = scheduler_output.swap_out_map + copy_map = scheduler_output.copy_map if len(running) == 0: return None @@ -949,6 +950,7 @@ def __need_schedule_again(prefill: bool, scheduler_output): inputs=inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map, + copy_map=copy_map, loop_count=num_loops, sampling_inputs=sampling_inputs, stopping_criteria=stopping_criteria, diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index b184fbf64e..b6d7c20176 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -201,18 +201,21 @@ def msg_with_rank(rank: int, msg: str): return f'rank[{rank}] - {msg}' -def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict): +def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict, copy_map: dict): """Perform cache swapping.""" issued_cache_op = False swap_in_map = swap_in_map or dict() swap_out_map = swap_out_map or dict() + copy_map = copy_map or dict() if len(swap_in_map) > 0: cache_engine.swap_in(swap_in_map) issued_cache_op = True if len(swap_out_map) > 0: cache_engine.swap_out(swap_out_map) issued_cache_op = True - + if len(copy_map) > 0: + cache_engine.copy_to(copy_map) + issued_cache_op = True if issued_cache_op: cache_engine.events.wait() @@ -592,6 +595,7 @@ async def _async_step_background( loop_count: int, swap_in_map: Dict = None, swap_out_map: Dict = None, + copy_map: Dict = None, sampling_inputs: SamplingInputs = None, stopping_criteria: StoppingCriteria = None, return_logits: bool = False, @@ -699,7 +703,7 @@ async def __prepare_dp(): # init state cache for first time prefill # I don't know if this is necessary... self.state_cache_engine.init_caches(inputs.state_offsets, inputs.history_lengths == 0) - cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map) for idx in range(loop_count): # inference logger.debug(f' rank[{rank}]: model forward [{idx}].') @@ -995,8 +999,6 @@ async def async_forward(self, inputs: ModelInputs): Args: inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. 
""" output = self._forward_impl(inputs) await asyncio.sleep(0) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index ec5e6098b8..c132e312d6 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -413,21 +413,67 @@ def __init__(self, multimodals: MultiModalInputs = None): if multimodals is None: multimodals = dict() self.multimodals = multimodals + self._init_mm_ranges() + + def _init_mm_ranges(self): + """Init mm ranges and sort it.""" + mm_ranges = [] + for _, modal_datas in self.multimodals.items(): + for modal_data in modal_datas: + data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None)) + mm_ranges.append(data) + mm_ranges.sort(key=lambda x: x[1]) + self._mm_ranges = mm_ranges + + @property + def mm_ranges(self): + """mm_ranges.""" + return self._mm_ranges def get_datas(self, start=0, end=-1): """Get multimodals from prompts position [start, end).""" outs = dict() - test_range = range(start, end) for modal_type, modal_datas in self.multimodals.items(): data = [] for modal_data in modal_datas: - if (modal_data.start not in test_range and modal_data.end not in test_range): - continue - data.append(modal_data) + if modal_data.start < end and modal_data.end > start: + data.append(modal_data) if len(data) > 0: outs[modal_type] = data return outs + def get_step(self, step: int) -> int: + """Get step that before a whole image.""" + real_step = step + for start, end, _ in self._mm_ranges: + if start <= real_step < end: + real_step = start + return real_step + + def has_data(self, start: int, end: int) -> bool: + """Whether has multimodal data in [start, end)""" + return any([s < end and e > start for s, e, _ in self._mm_ranges]) + + def get_hash_values(self, start: int, end: int): + """Get multimodals hash values that from [start, end)""" + mm_hash_values = [] + multimodal_ends = [] + + for mm_start, mm_end, hash_value in self._mm_ranges: + # the mm range intersect with the target range + if mm_start < end and mm_end > start: + mm_hash_values.append(hash_value) + # the mm end in the target range + if start < mm_end <= end: + cur_data = (tuple(mm_hash_values), mm_end) + multimodal_ends.append(cur_data) + + if len(mm_hash_values) == 0: + mm_hash_values = None + else: + mm_hash_values = tuple(mm_hash_values) + return mm_hash_values, multimodal_ends + def add_inputs(self, input_mms: MultiModalInputs): """Add new inputs.""" for modal_type, vals in input_mms.items(): @@ -436,7 +482,15 @@ def add_inputs(self, input_mms: MultiModalInputs): else: self.multimodals[modal_type] = vals - def empty(self): + # update mm_ranges + for modal_data in vals: + data = (modal_data.start, modal_data.end, modal_data.meta.get('hash_value', None)) + self._mm_ranges.append(data) + + # sort mm_ranges + self._mm_ranges.sort(key=lambda x: x[1]) + + def empty(self) -> bool: if len(self.multimodals) == 0: return True @@ -655,7 +709,7 @@ def _update_multimodals(self, multimodals: MultiModalInputs): if multimodals is None: self._num_cross = 0 return - multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_valid_ids) + multimodals = HistoryMultiModals.update_multimodals(multimodals, self._num_history_ids) self.history_multimodals.add_inputs(multimodals) # for mllama @@ -674,3 +728,15 @@ def update_token_ids(self, def set_step(self, step: int): """Set step.""" raise NotImplementedError('NotImplemented') + + def __repr__(self): + return (f'SchedulerSequence(seq_id={self.seq_id}, session_id={self.session_id}, ' + f'status={self.status}, 
arrive_time={self.arrive_time}, ' + f'return_logits={self.return_logits}, sampling_param={self.sampling_param}, ' + f'num_history_ids={self.num_history_ids}, num_all_tokens={self.num_all_ids}, ' + f'num_new_tokens={self.num_new_tokens}, all_token_ids={self.all_ids}, ' + f'mm_ranges={self.history_multimodals.mm_ranges}, ' + f'num_gpu_blocks={self.num_blocks}, gpu_blocks={self.logical_blocks.get_real_blocks()}, ' + f'last_shared_node={getattr(self.logical_blocks, "last_shared_node", None)})') + + __str__ = __repr__ diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index 56e3169bb7..e537c6adc3 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -863,13 +863,14 @@ def preprocess_input(self, offset = input_mm['offset'] num_pad = input_mm['image_tokens'] image_token_id = input_mm['image_token_id'] + hash_value = input_mm.get('hash_value', None) if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() mm_data = MultiModalTensor(data=pixel_values, start=offset, end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + meta=dict(image_token_id=image_token_id, hash_value=hash_value)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index ad8adc9739..94a1ec11ff 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -898,13 +898,14 @@ def preprocess_input(self, input_ids: List[int], input_multimodals=None, **kwarg offset = input_mm['offset'] image_token_id = input_mm['image_token_id'] num_pad = input_mm['image_tokens'] + hash_value = input_mm.get('hash_value', None) if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() mm_data = MultiModalTensor(data=pixel_values, start=offset, end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + meta=dict(image_token_id=image_token_id, hash_value=hash_value)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/deepseek_vl2.py b/lmdeploy/pytorch/models/deepseek_vl2.py index 290b9a4fc0..77221ed872 100644 --- a/lmdeploy/pytorch/models/deepseek_vl2.py +++ b/lmdeploy/pytorch/models/deepseek_vl2.py @@ -146,7 +146,7 @@ def __init__(self, # ----------- language model ------------ language_config = config.language_config self.language = DeepseekV2ForCausalLM(config=language_config, ctx_mgr=ctx_mgr, dtype=dtype, device=device) - + self.config = language_config # ----------- input processor ------------ self.input_processor = DeepSeekVLV2InputProcessor(config, dtype) @@ -436,6 +436,7 @@ def preprocess_input(self, offset = input_mm['offset'] image_token_id = input_mm['image_token_id'] num_pad = input_mm['image_tokens'] + hash_value = input_mm.get('hash_value', None) images_spatial_crop = input_mm.get('images_spatial_crop', None) if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() @@ -445,6 +446,7 @@ def preprocess_input(self, end=offset + num_pad, meta=dict( image_token_id=image_token_id, + hash_value=hash_value, images_spatial_crop=images_spatial_crop, )) diff --git a/lmdeploy/pytorch/models/gemma3_vl.py b/lmdeploy/pytorch/models/gemma3_vl.py index 8f4ea8e972..000e10913c 100644 --- a/lmdeploy/pytorch/models/gemma3_vl.py +++ b/lmdeploy/pytorch/models/gemma3_vl.py @@ -105,13 +105,14 @@ def preprocess_input(self, offset = input_mm['offset'] image_token_id = input_mm['image_token_id'] num_pad = input_mm['image_tokens'] + hash_value = input_mm.get('hash_value', None) if 
isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() mm_data = MultiModalTensor(data=pixel_values, start=offset, end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + meta=dict(image_token_id=image_token_id, hash_value=hash_value)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 2dbd9f9f3e..5230944c72 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -991,13 +991,14 @@ def preprocess_input(self, offset = input_mm['offset'] image_token_id = input_mm['image_token_id'] num_pad = input_mm['image_tokens'] + hash_value = input_mm.get('hash_value', None) if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() mm_data = MultiModalTensor(data=pixel_values, start=offset, end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + meta=dict(image_token_id=image_token_id, hash_value=hash_value)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py index e87242df4c..f33f2acffc 100644 --- a/lmdeploy/pytorch/models/llava.py +++ b/lmdeploy/pytorch/models/llava.py @@ -552,13 +552,14 @@ def preprocess_input(self, offset = input_mm['offset'] image_token_id = input_mm['image_token_id'] num_pad = input_mm['image_tokens'] + hash_value = input_mm.get('hash_value', None) if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() mm_data = MultiModalTensor(data=pixel_values, start=offset, end=offset + num_pad, - meta=dict(image_token_id=image_token_id)) + meta=dict(image_token_id=image_token_id, hash_value=hash_value)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py index a70d0bfecc..bd20cb681a 100644 --- a/lmdeploy/pytorch/models/phi3_v.py +++ b/lmdeploy/pytorch/models/phi3_v.py @@ -381,13 +381,16 @@ def preprocess_input(self, offset = input_mm['offset'] image_token_id = input_mm['image_token_id'] num_pad = input_mm['image_tokens'] + hash_value = input_mm.get('hash_value', None) if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() mm_data = MultiModalTensor(data=pixel_values, start=offset, end=offset + num_pad, - meta=dict(image_sizes=image_sizes, image_token_id=image_token_id)) + meta=dict(image_sizes=image_sizes, + image_token_id=image_token_id, + hash_value=hash_value)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/qwen2_5_vl.py b/lmdeploy/pytorch/models/qwen2_5_vl.py index 11d2948b7a..37a9ff296e 100644 --- a/lmdeploy/pytorch/models/qwen2_5_vl.py +++ b/lmdeploy/pytorch/models/qwen2_5_vl.py @@ -716,13 +716,16 @@ def preprocess_input(self, start = offset image_token_id = input_mm['image_token_id'] num_pad = input_mm['image_tokens'] + hash_value = input_mm.get('hash_value', None) if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() mm_data = MultiModalTensor(data=pixel_values, start=start, end=start + num_pad, - meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) + meta=dict(grid_thw=image_grid_thw, + image_token_id=image_token_id, + hash_value=hash_value)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index 77e025c638..11eb673502 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -924,13 +924,16 @@ def preprocess_input(self, 
start = offset image_token_id = input_mm['image_token_id'] num_pad = input_mm['image_tokens'] + hash_value = input_mm.get('hash_value', None) if isinstance(num_pad, torch.Tensor): num_pad = num_pad.item() mm_data = MultiModalTensor(data=pixel_values, start=start, end=start + num_pad, - meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) + meta=dict(grid_thw=image_grid_thw, + image_token_id=image_token_id, + hash_value=hash_value)) input_imgs.append(mm_data) result = PreprocessInputResult( diff --git a/lmdeploy/pytorch/paging/block_trie.py b/lmdeploy/pytorch/paging/block_trie.py index d00690af24..f0b1757042 100644 --- a/lmdeploy/pytorch/paging/block_trie.py +++ b/lmdeploy/pytorch/paging/block_trie.py @@ -1,23 +1,39 @@ # Copyright (c) OpenMMLab. All rights reserved. import heapq -from typing import Dict, Set +from dataclasses import dataclass +from typing import Dict, Optional, Set, Tuple import numpy as np from lmdeploy.pytorch.messages import SchedulerSequence +from lmdeploy.utils import get_logger, logging_timer from ..config import CacheConfig from .block_manager import BaseBlockManager +logger = get_logger('lmdeploy') + +def hash_block_tokens(tokens: np.ndarray, mm_hashes: Tuple[str] = None): + """Hash func.""" + if mm_hashes is None: + mm_hashes = 'random' + hash_data = (mm_hashes, tuple(tokens)) + hash_key = hash(hash_data) + return hash_key + + +@dataclass class Node: """Node of block trie.""" - - def __init__(self, hash_key: int, block: int, tokens: np.ndarray, num_matched: int = 0): - self.hash_key = hash_key - self.block = block - self.tokens = tokens - self.num_matched = num_matched + hash_key: int + block: int + tokens: np.ndarray + num_matched: int = 0 + mm_hashes: Optional[Tuple[str]] = None + + def __post_init__(self): + """Post init.""" self.children: Dict[int, 'Node'] = dict() self._parent: 'Node' = None @@ -34,12 +50,39 @@ def parent(self, val: 'Node'): val.children[self.hash_key] = self self._parent = val + def match_child(self, tokens: np.ndarray, mm_hashes=None): + """Match child.""" + hash_key = hash_block_tokens(tokens, mm_hashes=mm_hashes) + hash_collision = None + matched_child = None + if hash_key in self.children: + child = self.children[hash_key] + if child.mm_hashes == mm_hashes and np.array_equal(tokens, child.tokens): + matched_child = child + hash_collision = False + else: + hash_collision = True + logger.error(f'Hash collision found for tokens={tokens}, ' + f'mm_hashes={mm_hashes} with node={child}') + return matched_child, hash_collision, hash_key + + def __hash__(self): + return hash((self.block, self.num_matched, self.hash_key)) + def __lt__(self, other): return True def __le__(self, other): return True + def __repr__(self): + return (f'Node(hash_key={self.hash_key}, block={self.block}, ' + f'num_matched={self.num_matched}, mm_hashes={self.mm_hashes}, ' + f'num_children={len(self.children)}, is_root={self.parent is None}, ' + f'tokens={self.tokens})') + + __str__ = __repr__ + class BlockTrie: """Block trie for prefix caching.""" @@ -54,6 +97,7 @@ def __init__(self, cache_config: CacheConfig, block_manager: BaseBlockManager): # caches with different adapter should not be shared. 
self._roots: Dict[str, Node] = dict() self.leaves: Set[Node] = set() + self.hit_rates = [] def get_root(self, adapter_name: str): """Get root by adapter name.""" @@ -61,6 +105,7 @@ def get_root(self, adapter_name: str): self._roots[adapter_name] = Node(-1, -1, None) return self._roots[adapter_name] + @logging_timer('BlockTrie_Match', logger) def match(self, seq: SchedulerSequence): """Match sequence and cache.""" if not self.enable: @@ -74,6 +119,7 @@ def match(self, seq: SchedulerSequence): if curr is None: curr = self.get_root(seq.adapter_name) num_matched = curr.num_matched + init_num_matched = num_matched def __match_success(node: Node): nonlocal curr, num_matched @@ -81,18 +127,73 @@ def __match_success(node: Node): curr = node num_matched += block_size - while num_matched + block_size < seq.num_valid_ids: - curr_tokens = seq.history_cache[num_matched:num_matched + block_size] - - key = hash(('random', tuple(curr_tokens))) - if key not in curr.children: - break - - child = curr.children[key] - if not np.array_equal(curr_tokens, child.tokens): - break - - __match_success(child) + def __match_pure_text(): + nonlocal curr, num_matched + while num_matched + block_size <= seq.num_all_ids: + curr_tokens = seq.history_cache[num_matched:num_matched + block_size] + child, _, _ = curr.match_child(curr_tokens, mm_hashes=None) + if child is None: + break + __match_success(child) + + def __match_multimodals(): + nonlocal curr, num_matched, mm_ranges, matched_blocks + while num_matched + block_size <= seq.num_all_ids: + if len(mm_ranges) > 0 and num_matched <= mm_ranges[0][0] < (num_matched + block_size): + # find last block without img_data intersect + last_end_num_matched = -1 + for mm_idx, (_, end, _) in enumerate(mm_ranges): + end_num_matched = (((end - 1) // block_size) + 1) * block_size + if end_num_matched > seq.num_all_ids: + # last block that include end token is not full, just stop quickly + break + intersect_ranges = [data for data in mm_ranges[mm_idx + 1:] if data[0] < end_num_matched] + if len(intersect_ranges) == 0: + last_end_num_matched = end_num_matched + break + if last_end_num_matched == -1: + break + + mutimodal_matched_blocks = [] + all_match = True + multi_curr = curr + for multi_num_matched in range(num_matched, last_end_num_matched, block_size): + num_matched_end = multi_num_matched + block_size + curr_tokens = seq.history_cache[multi_num_matched:num_matched_end] + mm_hashes = tuple([data[2] for data in mm_ranges if data[0] < num_matched_end]) + child, _, _ = multi_curr.match_child(curr_tokens, mm_hashes=mm_hashes) + if child is not None: + mutimodal_matched_blocks.append(child.block) + mm_ranges = [data for data in mm_ranges if data[1] > num_matched_end] + multi_curr = child + else: + all_match = False + break + if all_match: + matched_blocks += mutimodal_matched_blocks + num_matched = last_end_num_matched + curr = multi_curr + else: + break + else: + curr_tokens = seq.history_cache[num_matched:num_matched + block_size] + child, _, _ = curr.match_child(curr_tokens, mm_hashes=None) + if child is None: + break + __match_success(child) + + mm_ranges = None + + if seq.history_multimodals is not None and len(seq.history_multimodals.mm_ranges) > 0: + mm_ranges = list(seq.history_multimodals.mm_ranges) + mm_ranges = [data for data in mm_ranges if num_matched < data[1]] + if len(mm_ranges) == 0: + mm_ranges = None + + if mm_ranges is None: + __match_pure_text() + else: + __match_multimodals() if len(matched_blocks) > 0: matched_blocks = np.array(matched_blocks) @@ -100,9 +201,13 @@ 
def __match_success(node: Node): self.allocator.add_ref_count(matched_blocks, 1) seq.logical_blocks.append(matched_blocks) seq.set_step(num_matched) - + hit_rate = 100 * len(matched_blocks) * block_size / float(seq.num_all_ids - init_num_matched) + self.hit_rates.append(hit_rate) seq.logical_blocks.last_shared_node = curr + logger.info(f'Block Trie current hit rate={hit_rate}%, ' + f'mean hit rate={np.mean(self.hit_rates)}%, matching seq={seq}') + @logging_timer('BlockTrie_Allocate', logger) def allocate(self, seq: SchedulerSequence): """allocate.""" if not self.enable: @@ -114,7 +219,7 @@ def allocate(self, seq: SchedulerSequence): if node is None: node = self.get_root(seq.adapter_name) logical_blocks.last_shared_node = node - + logger.info(f'Allocate seq={seq}') num_matched = node.num_matched num_valid_ids = seq.num_valid_ids @@ -124,29 +229,135 @@ def allocate(self, seq: SchedulerSequence): if len(node.children) == 0 and node.parent is not None: self.leaves.remove(node) - block_id = num_matched // block_size blocks = [] free_blocks = [] - while num_matched + block_size <= num_valid_ids: - curr_tokens = seq.history_cache[num_matched:num_matched + block_size] - block = logical_blocks[block_id] + def __allocate_text(): + nonlocal node, num_matched, blocks, free_blocks + + block_id = num_matched // block_size + while num_matched + block_size <= seq.num_all_ids: + curr_tokens = seq.history_cache[num_matched:num_matched + block_size] + + block = logical_blocks[block_id] + parent = node - hash_key = hash(('random', tuple(curr_tokens))) - parent = node - if hash_key in parent.children: - child = parent.children[hash_key] - if not np.array_equal(curr_tokens, child.tokens): + mm_hashes = None + matched_child, hash_collision, hash_key = node.match_child(curr_tokens, mm_hashes=mm_hashes) + if hash_collision: break - node = child - free_blocks.append(block) - logical_blocks[block_id] = node.block - else: - node = Node(hash_key=hash_key, block=block, tokens=curr_tokens, num_matched=num_matched + block_size) - node.parent = parent - blocks.append(node.block) - num_matched += block_size - block_id += 1 + + if matched_child is not None: + node = matched_child + free_blocks.append(block) + logical_blocks[block_id] = node.block + else: + node = Node(hash_key=hash_key, + block=block, + tokens=curr_tokens, + num_matched=num_matched + block_size) + node.parent = parent + blocks.append(node.block) + num_matched += block_size + block_id += 1 + + def __allocate_multimodals(): + nonlocal node, num_matched, blocks, free_blocks, mm_ranges + + block_id = num_matched // block_size + + while num_matched + block_size <= seq.num_all_ids: + if len(mm_ranges) > 0 and (mm_ranges[0][0] // block_size) == block_id: + # find last block without img_data intersect + last_end_num_matched = -1 + for mm_idx, (_, end, _) in enumerate(mm_ranges): + end_num_matched = (((end - 1) // block_size) + 1) * block_size + if end_num_matched > seq.num_all_ids: + # last block that include end token is not full, just stop quickly + break + intersect_ranges = [data for data in mm_ranges[mm_idx + 1:] if data[0] < end_num_matched] + if len(intersect_ranges) == 0: + last_end_num_matched = end_num_matched + break + if last_end_num_matched == -1: + break + + multi_blocks = [] + multi_free_blocks = [] + multi_node = node + all_allocate = True + multi_block_id = block_id + for multi_num_matched in range(num_matched, last_end_num_matched, block_size): + num_matched_end = multi_num_matched + block_size + curr_tokens = 
seq.history_cache[multi_num_matched:num_matched_end] + mm_hashes = tuple([data[2] for data in mm_ranges if data[0] < num_matched_end]) + matched_child, hash_collision, hash_key = multi_node.match_child(curr_tokens, + mm_hashes=mm_hashes) + if hash_collision: + all_allocate = False + break + block = logical_blocks[multi_block_id] + parent = multi_node + if matched_child is not None: + multi_node = matched_child + multi_free_blocks.append(block) + logical_blocks[multi_block_id] = matched_child.block + else: + multi_node = Node(hash_key=hash_key, + block=block, + tokens=curr_tokens, + num_matched=num_matched_end, + mm_hashes=mm_hashes) + multi_node.parent = parent + multi_blocks.append(multi_node.block) + multi_block_id += 1 + mm_ranges = [data for data in mm_ranges if data[1] > num_matched_end] + + if all_allocate: + blocks += multi_blocks + free_blocks += multi_free_blocks + num_matched = last_end_num_matched + node = multi_node + block_id = multi_block_id + else: + break + else: + curr_tokens = seq.history_cache[num_matched:num_matched + block_size] + + block = logical_blocks[block_id] + parent = node + + mm_hashes = None + matched_child, hash_collision, hash_key = node.match_child(curr_tokens, mm_hashes=mm_hashes) + if hash_collision: + break + + if matched_child is not None: + node = matched_child + free_blocks.append(block) + logical_blocks[block_id] = node.block + else: + node = Node(hash_key=hash_key, + block=block, + tokens=curr_tokens, + num_matched=num_matched + block_size) + node.parent = parent + blocks.append(node.block) + num_matched += block_size + block_id += 1 + + mm_ranges = None + + if seq.history_multimodals is not None and len(seq.history_multimodals.mm_ranges) > 0: + mm_ranges = list(seq.history_multimodals.mm_ranges) + mm_ranges = [data for data in mm_ranges if num_matched < data[1]] + if len(mm_ranges) == 0: + mm_ranges = None + + if mm_ranges is None: + __allocate_text() + else: + __allocate_multimodals() logical_blocks.last_shared_node = node if node.parent is not None and len(node.children) == 0: @@ -157,6 +368,7 @@ def allocate(self, seq: SchedulerSequence): if len(free_blocks) > 0: self.allocator.free(np.array(free_blocks)) + @logging_timer('BlockTrie_Evict', logger) def evict(self, max_num_blocks: int): """evict.""" if not self.enable: @@ -168,7 +380,7 @@ def __remove_leaf(leaves, evicted_blocks): parent = leaf.parent leaf.parent = None self.leaves.remove(leaf) - return parent + return parent, leaf def __add_leaf(leaves, parent): self.leaves.add(parent) @@ -197,10 +409,25 @@ def __add_leaf(leaves, parent): heapq.heapify(leaves) while len(leaves) > 0 and len(evicted_blocks) < max_num_blocks: - parent = __remove_leaf(leaves, evicted_blocks) + parent, removed_leaf = __remove_leaf(leaves, evicted_blocks) if parent.parent is None: # ignore root continue + + # remove nodes of with same mm_hashes + if removed_leaf.mm_hashes: + while removed_leaf.mm_hashes == parent.mm_hashes and len(parent.children) == 0: + tmp_parent = parent.parent + evicted_blocks.append(parent.block) + parent.parent = None + logger.info(f'Evict multimodal node={parent}') + parent = tmp_parent + logger.info(f'Next multimodal node={parent}') + + if parent.parent is None: + # ignore root + continue + if len(parent.children) == 0: __add_leaf(leaves, parent) diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index 91a3335f18..84f6181473 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -56,9 +56,8 
@@ def set_step(self, step: int): """Set step.""" num_all_ids = self.num_all_ids # update step for vlm - if len(self.history_embeddings) > 0: - new_step, self._num_history_images, self._num_images = \ - self.history_embeddings.get_step(step) + if self.history_multimodals is not None: + new_step = self.history_multimodals.get_step(step) assert 0 <= new_step <= step step = new_step self._num_history_ids = step diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index a784e67e74..1729922a9c 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -28,10 +28,13 @@ def __init__(self, **kwargs) -> None: if backend == 'pytorch': try_import_deeplink(backend_config.device_type) - if backend_config and backend_config.enable_prefix_caching: + if backend_config.enable_prefix_caching and backend == 'turbomind': backend_config.enable_prefix_caching = False - logger.warning('Prefix caching is disabled since LMDeploy hasn\'t support in on VL models yet') + logger.warning('VLM does not support prefix caching for turbomind engine.') self.vl_encoder = ImageEncoder(model_path, backend, vision_config, backend_config=backend_config) + if backend_config.enable_prefix_caching and not self.vl_encoder.model.support_prefix_caching: + logger.warning(f'Prefix caching is not supported for {model_path}') + super().__init__(model_path, backend=backend, backend_config=backend_config, **kwargs) if self.model_name == 'base': raise RuntimeError( diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py index f06a175195..a8b2f38dc1 100644 --- a/lmdeploy/vl/model/base.py +++ b/lmdeploy/vl/model/base.py @@ -15,13 +15,15 @@ class VisonModel(ABC): """Visual model which extract image feature.""" _arch: Union[str, List[str]] = None + support_prefix_caching: bool = True def __init__(self, model_path: str, with_llm: bool = False, max_memory: Dict[int, int] = None, hf_config: AutoConfig = None, - backend: str = ''): + backend: str = '', + enable_prefix_caching: bool = False): """init.""" self.model_path = model_path self.with_llm = with_llm @@ -31,6 +33,7 @@ def __init__(self, _, hf_config = get_model_arch(model_path) self.hf_config = hf_config self.image_token_id = self.get_pad_token_id(model_path, hf_config) or 0 + self.enable_prefix_caching = enable_prefix_caching def get_pad_token_id(self, model_path, hf_config): """Get pad_token_id from hf_config or tokenizer.""" diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py index 995f63c9f4..28c24c5b98 100644 --- a/lmdeploy/vl/model/builder.py +++ b/lmdeploy/vl/model/builder.py @@ -59,7 +59,15 @@ def load_vl_model(model_path: str, max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)} if backend == 'turbomind' else None _, hf_config = get_model_arch(model_path) - kwargs = dict(model_path=model_path, with_llm=with_llm, max_memory=max_memory, hf_config=hf_config, backend=backend) + + enable_prefix_caching = getattr(backend_config, 'enable_prefix_caching', False) + + kwargs = dict(model_path=model_path, + with_llm=with_llm, + max_memory=max_memory, + hf_config=hf_config, + backend=backend, + enable_prefix_caching=enable_prefix_caching) for name, module in VISION_MODELS.module_dict.items(): try: diff --git a/lmdeploy/vl/model/cogvlm.py b/lmdeploy/vl/model/cogvlm.py index a0f3e8b073..ebf5a812bf 100644 --- a/lmdeploy/vl/model/cogvlm.py +++ b/lmdeploy/vl/model/cogvlm.py @@ -12,6 +12,7 @@ class CogVLMVisionModel(VisonModel): """CogVLM vision model.""" _arch = 'CogVLMForCausalLM' + 
support_prefix_caching: bool = False def build_preprocessor(self): from torchvision import transforms @@ -43,7 +44,7 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: """Refer to the spec of `super().preprocess`""" images = self.collect_images(messages) outputs = [] - for image, _ in images: + for image, params in images: image = image.convert('RGB') pixel_values = self.image_transform(image) outputs.append( diff --git a/lmdeploy/vl/model/deepseek.py b/lmdeploy/vl/model/deepseek.py index 999479e186..16e351f22a 100644 --- a/lmdeploy/vl/model/deepseek.py +++ b/lmdeploy/vl/model/deepseek.py @@ -8,6 +8,7 @@ from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +from lmdeploy.vl.utils import hash_multimodal_data logger = get_logger('lmdeploy') @@ -90,9 +91,12 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: """Refers to the spec of `super.preprocess()""" images = self.collect_images(messages) outputs = [] - for image, _ in images: + for image, params in images: image = image.convert('RGB') pixel_values = self.image_processor([image], return_tensors='pt').pixel_values + hash_value = None + if self.enable_prefix_caching: + hash_value = hash_multimodal_data(model_id=self.model_path, image=image, params=params) outputs.append( dict( pixel_values=pixel_values, @@ -100,6 +104,7 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: # refer to https://github.com/deepseek-ai/DeepSeek-VL/blob/main/deepseek_vl/models/processing_vlm.py # noqa # which is hardcoded 576 image_tokens=576, + hash_value=hash_value, image_token_id=self.image_token_id)) messages.append(dict(role='preprocess', content=outputs)) return messages diff --git a/lmdeploy/vl/model/deepseek_vl2.py b/lmdeploy/vl/model/deepseek_vl2.py index 109424d3cd..c3bd7e4ce5 100644 --- a/lmdeploy/vl/model/deepseek_vl2.py +++ b/lmdeploy/vl/model/deepseek_vl2.py @@ -8,6 +8,7 @@ from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel +from lmdeploy.vl.utils import hash_multimodal_data logger = get_logger('lmdeploy') @@ -74,11 +75,15 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: # convert to upstream api formats images = [img_parameter[0] for img_parameter in images] formatted_messages = [] + hash_value = '' for message in messages: text_content = DeepSeek2VisionModel.proc_single_message(message) image_content = [x['image'] for x in message['content'] if x['type'] == 'image'] + if self.enable_prefix_caching: + hash_value += hash_multimodal_data(model_id=self.model_path, image=image_content) formatted_messages.append(dict(role=message['role'], content=text_content, images=image_content)) - + if hash_value == '': + hash_value = None # NOTE: DeepseekVLV2Processor inputs # conversations (List[Dict]): conversations with a list of messages; # images (List[ImageType]): the list of images; @@ -92,13 +97,12 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: messages.append( dict(role='preprocess', content=[ - dict( - pixel_values=prepare.images, - image_tokens=prepare.num_image_tokens[0], - image_token_id=self.image_processor.image_token_id, - image_size=self.image_processor.image_size, - images_spatial_crop=prepare.images_spatial_crop, - ) + dict(pixel_values=prepare.images, + image_tokens=prepare.num_image_tokens[0], + image_token_id=self.image_processor.image_token_id, + image_size=self.image_processor.image_size, + images_spatial_crop=prepare.images_spatial_crop, + 
hash_value=hash_value) ])) return messages diff --git a/lmdeploy/vl/model/gemma3_vl.py b/lmdeploy/vl/model/gemma3_vl.py index d488217436..772e41bd01 100644 --- a/lmdeploy/vl/model/gemma3_vl.py +++ b/lmdeploy/vl/model/gemma3_vl.py @@ -39,6 +39,7 @@ class Gemma3VisionModel(VisonModel): """Gemma3 vision model.""" _arch = 'Gemma3ForConditionalGeneration' + support_prefix_caching: bool = False def __init__(self, model_path: str, diff --git a/lmdeploy/vl/model/glm4_v.py b/lmdeploy/vl/model/glm4_v.py index 6032ed0532..472b8675c4 100644 --- a/lmdeploy/vl/model/glm4_v.py +++ b/lmdeploy/vl/model/glm4_v.py @@ -14,6 +14,7 @@ class GLM4VisionModel(VisonModel): """Glm-4v-9b vision model.""" _arch = ['ChatGLMModel', 'ChatGLMForConditionalGeneration'] + support_prefix_caching: bool = False @classmethod def match(cls, config: AutoConfig): @@ -58,12 +59,12 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: # model decide what to do images = [x.convert('RGB') for x in images] pixel_values = [self.image_transform(x) for x in images] - outputs.extend([ - dict(pixel_values=_2, - image_size=_1.size, - image_tokens=self.n_token_per_image, - image_token_id=self.image_token_id) for _1, _2 in zip(images, pixel_values) - ]) + for image, pixel_value in zip(images, pixel_values): + data = dict(pixel_values=pixel_value, + image_size=image.size, + image_tokens=self.n_token_per_image, + image_token_id=self.image_token_id) + outputs.append(data) messages.append(dict(role='preprocess', content=outputs)) return messages diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py index a2b8d7f9b7..eaed1e24d0 100644 --- a/lmdeploy/vl/model/internvl.py +++ b/lmdeploy/vl/model/internvl.py @@ -7,6 +7,7 @@ from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +from lmdeploy.vl.utils import hash_multimodal_data logger = get_logger('lmdeploy') @@ -74,8 +75,14 @@ def __init__(self, with_llm: bool = False, max_memory: Dict[int, int] = None, hf_config: AutoConfig = None, - backend: str = ''): - super().__init__(model_path, with_llm, max_memory, hf_config, backend) + backend: str = '', + enable_prefix_caching: bool = False): + super().__init__(model_path, + with_llm, + max_memory, + hf_config, + backend, + enable_prefix_caching=enable_prefix_caching) IMG_CONTEXT_TOKEN = '' tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False) self.image_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) @@ -197,11 +204,15 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: for image, params in images: image = image.convert('RGB') pixel_values = self.processor(image, params) + hash_value = None + if self.enable_prefix_caching: + hash_value = hash_multimodal_data(model_id=self.model_path, image=image, params=params) image_tokens = (pixel_values.shape[0] * self.image_tokens_per_patch) outputs.append( dict(pixel_values=pixel_values, image_tokens=image_tokens, image_token_id=self.image_token_id, + hash_value=hash_value, image_size=image.size)) messages.append(dict(role='preprocess', content=outputs)) return messages diff --git a/lmdeploy/vl/model/llava_hf.py b/lmdeploy/vl/model/llava_hf.py index 3eb348b390..fcf71019ad 100644 --- a/lmdeploy/vl/model/llava_hf.py +++ b/lmdeploy/vl/model/llava_hf.py @@ -8,6 +8,7 @@ from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +from 
lmdeploy.vl.utils import hash_multimodal_data logger = get_logger('lmdeploy') @@ -61,11 +62,15 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: outputs = [] for image, params in images: image = image.convert('RGB') + hash_value = None + if self.enable_prefix_caching: + hash_value = hash_multimodal_data(model_id=self.model_path, image=image, params=params) pixel_values = self.processor(image, return_tensors='pt', input_data_format='channels_last').pixel_values outputs.append( dict(pixel_values=pixel_values, image_size=image.size, image_tokens=self.n_token_per_image, + hash_value=hash_value, image_token_id=self.image_token_id)) messages.append(dict(role='preprocess', content=outputs)) return messages diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py index b705f237b8..9f6d758ccc 100644 --- a/lmdeploy/vl/model/llava_next.py +++ b/lmdeploy/vl/model/llava_next.py @@ -8,6 +8,7 @@ from lmdeploy.utils import get_logger from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel from lmdeploy.vl.model.utils import disable_logging +from lmdeploy.vl.utils import hash_multimodal_data logger = get_logger('lmdeploy') @@ -70,6 +71,9 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: outputs = [] for image, params in images: image = image.convert('RGB') + hash_value = None + if self.enable_prefix_caching: + hash_value = hash_multimodal_data(model_id=self.model_path, image=image, params=params) result = self.processor(image, return_tensors='pt', input_data_format='channels_last') # ! infer image_num_patches from image_sizes image_num_patches = [ @@ -93,6 +97,7 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: dict(image_size=image.size, image_patches=image_num_patches, image_tokens=image_tokens, + hash_value=hash_value, image_token_id=self.image_token_id)) outputs.append(result) messages.append(dict(role='preprocess', content=outputs)) diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py index 3c075c8ee4..13ee2e5110 100644 --- a/lmdeploy/vl/model/minicpmv.py +++ b/lmdeploy/vl/model/minicpmv.py @@ -10,6 +10,7 @@ from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +from lmdeploy.vl.utils import hash_multimodal_data logger = get_logger('lmdeploy') @@ -140,7 +141,11 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: if item['type'] == 'image': image = item['image'].convert('RGB') params = {k: v for k, v in item.items() if k not in {'type', 'image'}} + hash_value = None + if self.enable_prefix_caching: + hash_value = hash_multimodal_data(model_id=self.model_path, image=image, params=params) result = self._preprocess_func(image, params) + result['hash_value'] = hash_value outputs.append(result) messages[i].update(dict(preprocess=outputs)) return messages diff --git a/lmdeploy/vl/model/mllama.py b/lmdeploy/vl/model/mllama.py index 545badb977..1c733e25e2 100644 --- a/lmdeploy/vl/model/mllama.py +++ b/lmdeploy/vl/model/mllama.py @@ -18,6 +18,7 @@ class MllamaVLModel(VisonModel): """llama3.2 model.""" _arch = 'MllamaForConditionalGeneration' + support_prefix_caching: bool = False def build_preprocessor(self): from transformers import AutoProcessor diff --git a/lmdeploy/vl/model/phi3_vision.py b/lmdeploy/vl/model/phi3_vision.py index 683220c29c..446b406c65 100644 --- a/lmdeploy/vl/model/phi3_vision.py +++ b/lmdeploy/vl/model/phi3_vision.py @@ -5,6 +5,7 @@ from transformers import AutoProcessor from 
lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel +from lmdeploy.vl.utils import hash_multimodal_data @VISION_MODELS.register_module() @@ -34,9 +35,16 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: images = self.collect_images(messages) outputs = [] for image, params in images: + hash_value = None + if self.enable_prefix_caching: + hash_value = hash_multimodal_data(model_id=self.model_path, image=image, params=params) result = self.processor.image_processor([image], return_tensors='pt') image_tokens = result['num_img_tokens'] - result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id)) + result.update( + dict(image_size=image.size, + image_tokens=image_tokens, + image_token_id=self.image_token_id, + hash_value=hash_value)) outputs.append(result) messages.append(dict(role='preprocess', content=outputs)) return messages diff --git a/lmdeploy/vl/model/qwen.py b/lmdeploy/vl/model/qwen.py index 44f62619d6..5bd54989fb 100644 --- a/lmdeploy/vl/model/qwen.py +++ b/lmdeploy/vl/model/qwen.py @@ -8,6 +8,7 @@ from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel from lmdeploy.vl.model.utils import disable_logging +from lmdeploy.vl.utils import hash_multimodal_data logger = get_logger('lmdeploy') @@ -75,11 +76,15 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: outputs = [] for image, params in images: image = image.convert('RGB') + hash_value = None + if self.enable_prefix_caching: + hash_value = hash_multimodal_data(model_id=self.model_path, image=image, params=params) pixel_values = self.image_transform(image) outputs.append( dict(pixel_values=pixel_values, image_size=image.size, image_tokens=256, + hash_value=hash_value, image_token_id=self.image_token_id)) messages.append(dict(role='preprocess', content=outputs)) return messages diff --git a/lmdeploy/vl/model/qwen2.py b/lmdeploy/vl/model/qwen2.py index 43096be28b..cf7a484e8f 100644 --- a/lmdeploy/vl/model/qwen2.py +++ b/lmdeploy/vl/model/qwen2.py @@ -26,6 +26,7 @@ class Qwen2VLModel(VisonModel): """Qwen2VL model.""" _arch = ['Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration'] + support_prefix_caching: bool = False def build_preprocessor(self): check_qwen_vl_deps_install() @@ -44,7 +45,6 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: outputs = [] for image, params in images: image = image.convert('RGB') - item = dict(type='image', image=image) item.update({key: params[key] for key in params.keys() if key in optional_keys}) image_inputs, _ = process_vision_info([dict(content=[item])]) diff --git a/lmdeploy/vl/utils.py b/lmdeploy/vl/utils.py index d933a2611a..4037d571ea 100644 --- a/lmdeploy/vl/utils.py +++ b/lmdeploy/vl/utils.py @@ -4,7 +4,10 @@ from io import BytesIO from typing import Union +import numpy as np import requests +import torch +from blake3 import blake3 from PIL import Image, ImageFile from lmdeploy.utils import get_logger @@ -81,3 +84,50 @@ def load_image(image_url: Union[str, Image.Image]) -> Image.Image: img = Image.new('RGB', (32, 32)) return img + + +# from https://github.com/vllm-project/vllm/blob/f0ef37233ea0ba5251edaea7362984110411e7eb/vllm/multimodal/hasher.py # noqa: E501 + + +def hash_multimodal_data(multimodal_type: str = 'image', **multimodal_datas: object) -> str: + """Hash image related data.""" + + multimodal_datas['multimodal_type'] = multimodal_type + + def _convert_to_bytes(key: str, value: object): + """Recursively convert object to bytes.""" + if 
isinstance(value, (list, tuple)): + for idx, obj in enumerate(value): + yield from _convert_to_bytes(f'{key}.{idx}', obj) + elif isinstance(value, dict): + for k, v in value.items(): + yield from _convert_to_bytes(f'{key}.{k}', v) + else: + key_bytes = key.encode('utf-8') + if isinstance(value, str): + value_bytes = value.encode('utf-8') + elif isinstance(value, bytes): + value_bytes = value + elif isinstance(value, Image.Image): + value_bytes = value.tobytes() + else: + if isinstance(value, torch.Tensor): + value = value.cpu().numpy() + elif isinstance(value, (int, float)): + value = np.array(value) + + if isinstance(value, np.ndarray): + value_bytes = value.tobytes() + else: + import pickle + value_bytes = pickle.dumps(value) + yield key_bytes, value_bytes + + hasher = blake3() + for k, v in multimodal_datas.items(): + for k_bytes, v_bytes in _convert_to_bytes(k, v): + hasher.update(k_bytes) + hasher.update(v_bytes) + + hash_value = hasher.hexdigest() + return hash_value diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt index c9b887ccd4..7478a00142 100644 --- a/requirements/runtime_ascend.txt +++ b/requirements/runtime_ascend.txt @@ -1,4 +1,5 @@ accelerate>=0.29.3 +blake3 dlinfer-ascend>=0.1.3 einops fastapi diff --git a/requirements/runtime_camb.txt b/requirements/runtime_camb.txt index 5b37b003c0..11a19d4b49 100644 --- a/requirements/runtime_camb.txt +++ b/requirements/runtime_camb.txt @@ -1,4 +1,5 @@ accelerate==1.2.0 +blake3 einops fastapi fire diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt index 73d6b1dc81..f4ae6c308f 100644 --- a/requirements/runtime_cuda.txt +++ b/requirements/runtime_cuda.txt @@ -1,5 +1,6 @@ accelerate>=0.29.3 aiohttp +blake3 einops fastapi fire diff --git a/requirements/runtime_maca.txt b/requirements/runtime_maca.txt index 70202d5ce5..e754727e3f 100644 --- a/requirements/runtime_maca.txt +++ b/requirements/runtime_maca.txt @@ -1,4 +1,5 @@ accelerate==0.32.1 +blake3 einops fastapi fire diff --git a/requirements/runtime_rocm.txt b/requirements/runtime_rocm.txt index 47d6f66fcd..da8169c50b 100644 --- a/requirements/runtime_rocm.txt +++ b/requirements/runtime_rocm.txt @@ -1,4 +1,5 @@ accelerate>=0.29.3 +blake3 einops fastapi fire diff --git a/tests/pytorch/paging/test_block_trie.py b/tests/pytorch/paging/test_block_trie.py index 7d20c96dab..a4274740a8 100644 --- a/tests/pytorch/paging/test_block_trie.py +++ b/tests/pytorch/paging/test_block_trie.py @@ -3,11 +3,12 @@ from lmdeploy.pytorch.config import CacheConfig from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.paging.block_manager import build_block_manager from lmdeploy.pytorch.paging.block_trie import BlockTrie -class TestBlockTire: +class TestBlockTrie: @pytest.fixture def block_size(self): @@ -143,3 +144,364 @@ def test_evict(self, block_trie, seq_manager, num_gpu_blocks): new_leaf = next(iter(block_trie.leaves)) assert leaf != new_leaf assert block_mgr.get_num_free_gpu_blocks() == 5 + + def test_allocate_multimodals(self, block_trie, block_mgr, block_size, seq_manager): + allocator = block_trie.allocator + sess = SchedulerSession(0, seq_manager) + half_block_size = block_size // 2 + # test case 1 single block + token_ids = [1] * block_size + [2] * half_block_size + multimodals = dict(image=[ + MultiModalTensor(data=None, start=0, end=block_size, meta=dict(hash_value='image_0')), + ]) + seq = sess.add_sequence(token_ids, 
multimodals=multimodals) + + block_mgr.allocate(seq) + block_trie.allocate(seq) + logical_blocks = seq.logical_blocks + assert len(logical_blocks) == 2 + ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [2, 1]) + node = getattr(seq.logical_blocks, 'last_shared_node', None) + assert node is not None + assert node.mm_hashes == tuple(['image_0']) + assert node.num_matched == block_size + assert np.array_equal(node.tokens, [1] * block_size) + assert node in block_trie.leaves + assert len(block_trie.leaves) == 1 + assert node.parent not in block_trie.leaves + assert block_mgr.get_num_free_gpu_blocks() == (block_mgr.num_gpu_blocks - 2) + block_mgr.free(seq) + block_trie.evict(2) + assert block_mgr.get_num_free_gpu_blocks() == block_mgr.num_gpu_blocks + assert len(block_trie.leaves) == 0 + + # test case 2 multi blocks, but last block not full + + token_ids = [1] * (block_size + half_block_size) + [2] * 2 * block_size + multimodals = dict(image=[ + MultiModalTensor(data=None, + start=block_size + half_block_size, + end=3 * block_size + half_block_size, + meta=dict(hash_value='image_0')), + ]) + seq = sess.add_sequence(token_ids, multimodals=multimodals) + block_mgr.allocate(seq) + block_trie.allocate(seq) + logical_blocks = seq.logical_blocks + assert len(logical_blocks) == 4 + ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [2, 1, 1, 1]) + node = getattr(seq.logical_blocks, 'last_shared_node', None) + assert node is not None + assert node.mm_hashes is None + assert node.num_matched == block_size + assert np.array_equal(node.tokens, [1] * block_size) + assert node in block_trie.leaves + assert len(block_trie.leaves) == 1 + assert node.parent not in block_trie.leaves + assert block_mgr.get_num_free_gpu_blocks() == (block_mgr.num_gpu_blocks - 4) + block_mgr.free(seq) + block_trie.evict(1) + assert block_mgr.get_num_free_gpu_blocks() == block_mgr.num_gpu_blocks + assert len(block_trie.leaves) == 0 + + # append text token to make last block full + seq.update_token_ids([3] * block_size) + block_mgr.allocate(seq) + block_trie.allocate(seq) + logical_blocks = seq.logical_blocks + assert len(logical_blocks) == 5 + ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [2, 2, 2, 2, 1]) + node = getattr(seq.logical_blocks, 'last_shared_node', None) + assert node is not None + assert node.num_matched == block_size * 4 + assert node.mm_hashes == tuple(['image_0']) + expect_tokens = [2] * half_block_size + [3] * half_block_size + assert np.array_equal(node.tokens, expect_tokens) + assert np.array_equal(node.parent.tokens, [2] * block_size) + assert node in block_trie.leaves + assert len(block_trie.leaves) == 1 + assert block_mgr.get_num_free_gpu_blocks() == (block_mgr.num_gpu_blocks - 5) + block_mgr.free(seq) + block_trie.evict(5) + assert block_mgr.get_num_free_gpu_blocks() == block_mgr.num_gpu_blocks + assert len(block_trie.leaves) == 0 + + # test 3 multi images + quarter_block_size = block_size // 4 + token_ids = [1] * quarter_block_size + [2] * half_block_size + [3] * quarter_block_size + [ + 4 + ] * quarter_block_size + multimodals = dict(image=[ + MultiModalTensor(data=None, start=0, end=quarter_block_size, meta=dict(hash_value='image_0')), + MultiModalTensor( + data=None, start=block_size - quarter_block_size, end=block_size, meta=dict(hash_value='image_1')), + ]) + seq = sess.add_sequence(token_ids, multimodals=multimodals) + + block_mgr.allocate(seq) + 
block_trie.allocate(seq) + logical_blocks = seq.logical_blocks + assert len(logical_blocks) == 2 + ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [2, 1]) + node = getattr(seq.logical_blocks, 'last_shared_node', None) + assert node is not None + assert node.mm_hashes == tuple(['image_0', 'image_1']) + assert node.num_matched == block_size + expect_tokens = [1] * quarter_block_size + [2] * half_block_size + [3] * quarter_block_size + assert np.array_equal(node.tokens, expect_tokens) + assert node in block_trie.leaves + assert len(block_trie.leaves) == 1 + assert node.parent not in block_trie.leaves + + # append image token, but last vision block is not full + token_ids = [4] * block_size + [5] * half_block_size + [6] * block_size + multimodals = dict(image=[ + MultiModalTensor(data=None, start=0, end=block_size, meta=dict(hash_value='image_2')), + MultiModalTensor(data=None, + start=2 * block_size - half_block_size, + end=3 * block_size - half_block_size, + meta=dict(hash_value='image_3')), + ]) + seq.update_token_ids(token_ids, multimodals=multimodals) + block_mgr.allocate(seq) + block_trie.allocate(seq) + logical_blocks = seq.logical_blocks + assert len(logical_blocks) == 4 + ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [2, 2, 2, 1]) + node = getattr(seq.logical_blocks, 'last_shared_node', None) + assert node is not None + assert node.mm_hashes == tuple(['image_3']) + assert node.num_matched == block_size * 3 + expect_tokens = [4] * quarter_block_size + [5] * half_block_size + [6] * quarter_block_size + assert np.array_equal(node.tokens, expect_tokens) + + # append text to make last vision block full + seq.update_token_ids([7] * block_size) + block_mgr.allocate(seq) + block_trie.allocate(seq) + logical_blocks = seq.logical_blocks + assert len(logical_blocks) == 5 + ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [2, 2, 2, 2, 1]) + node = getattr(seq.logical_blocks, 'last_shared_node', None) + assert node is not None + assert node.mm_hashes is None + assert node.num_matched == 4 * block_size + expect_tokens = [6] * (block_size - quarter_block_size) + [7] * quarter_block_size + assert np.array_equal(node.tokens, expect_tokens) + parent = node.parent + assert parent is not None + assert parent.mm_hashes == tuple(['image_3']) + assert parent.num_matched == 3 * block_size + expect_tokens = [4] * quarter_block_size + [5] * half_block_size + [6] * quarter_block_size + assert np.array_equal(parent.tokens, expect_tokens) + + @pytest.mark.test + def test_match_multimodals(self, block_trie, block_mgr, block_size, seq_manager): + allocator = block_trie.allocator + sess = SchedulerSession(0, seq_manager) + half_block_size = block_size // 2 + quarter_block_size = block_size // 4 + + # initialize cache with single image + + token_ids = [1] * half_block_size # text + token_ids += [2] * block_size # img0 + multimodals = dict(image=[ + MultiModalTensor(data=None, start=half_block_size, end=3 * + half_block_size, meta=dict(hash_value='image_0')), + ]) + seq0 = sess.add_sequence(token_ids, multimodals=multimodals) + + block_mgr.allocate(seq0) + block_trie.allocate(seq0) + + # test with last vision block unfull + token_ids = [1] * half_block_size # text + token_ids += [2] * block_size # img0 + multimodals = dict(image=[ + MultiModalTensor(data=None, start=half_block_size, end=3 * + half_block_size, meta=dict(hash_value='image_0')), + ]) + seq_prob = 
sess.add_sequence(token_ids, multimodals=multimodals) + + block_trie.match(seq_prob) + last_node = getattr(seq_prob.logical_blocks, 'last_shared_node', None) + assert last_node.parent is None + assert last_node.num_matched == 0 + assert last_node.mm_hashes is None + + seq0.update_token_ids(token_ids=[3] * block_size) + block_mgr.allocate(seq0) + block_trie.allocate(seq0) + + # prob seq last vision block is unfull + block_trie.match(seq_prob) + last_node = getattr(seq_prob.logical_blocks, 'last_shared_node', None) + assert last_node.parent is None + assert last_node.num_matched == 0 + assert last_node.mm_hashes is None + + # prob seq last vision block is full + seq_prob.update_token_ids([3] * block_size + [0] * block_size) + block_trie.match(seq_prob) + last_node = getattr(seq_prob.logical_blocks, 'last_shared_node', None) + assert last_node.parent is not None + assert last_node.num_matched == block_size * 2 + assert last_node.mm_hashes == tuple(['image_0']) + assert np.array_equal(last_node.tokens, [2] * half_block_size + [3] * half_block_size) + parent = last_node.parent + assert parent is not None + assert parent.num_matched == block_size + assert parent.mm_hashes == tuple(['image_0']) + assert np.array_equal(parent.tokens, [1] * half_block_size + [2] * half_block_size) + ref_cnt = allocator.get_ref_count(seq_prob.logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [3, 3]) + block_mgr.free(seq_prob) + assert len(seq_prob.logical_blocks) == 0 + ref_cnt = allocator.get_ref_count(seq0.logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [2, 2, 1]) + + # test with different image + token_ids = [1] * half_block_size # text + token_ids += [2] * block_size # img1 + token_ids += [3] * block_size + multimodals = dict(image=[ + MultiModalTensor(data=None, start=half_block_size, end=3 * + half_block_size, meta=dict(hash_value='image_1')), + ]) + seq_prob = sess.add_sequence(token_ids, multimodals=multimodals) + + block_trie.match(seq_prob) + last_node = getattr(seq_prob.logical_blocks, 'last_shared_node', None) + assert last_node.parent is None + assert last_node.num_matched == 0 + assert last_node.mm_hashes is None + + # test with multi image + block_mgr.free(seq0) + block_trie.evict(3) + assert block_mgr.get_num_free_gpu_blocks() == block_mgr.num_gpu_blocks + + # test multi images + token_ids = [1] * half_block_size # text + token_ids += [100] * 3 * half_block_size # img 0 + token_ids += [2] * block_size # text + token_ids += [200] * quarter_block_size # img 1 + token_ids += [3] * half_block_size # text + token_ids += [300] * half_block_size # img 2 + token_ids += [4] * half_block_size # text + + multimodals = dict(image=[ + MultiModalTensor(data=None, start=half_block_size, end=2 * block_size, meta=dict(hash_value='image_0')), + MultiModalTensor(data=None, + start=3 * block_size, + end=3 * block_size + quarter_block_size, + meta=dict(hash_value='image_1')), + MultiModalTensor(data=None, + start=4 * block_size - quarter_block_size, + end=4 * block_size + quarter_block_size, + meta=dict(hash_value='image_2')), + ]) + seq0 = sess.add_sequence(token_ids, multimodals=multimodals) + block_mgr.allocate(seq0) + block_trie.allocate(seq0) + + # test one same image + token_ids = [1] * half_block_size # text + token_ids += [100] * 3 * half_block_size # img 0 + token_ids += [2] * half_block_size # haft text, not match + + multimodals = dict(image=[ + MultiModalTensor(data=None, start=half_block_size, end=2 * block_size, meta=dict(hash_value='image_0')), + ]) + seq_prob = 
sess.add_sequence(token_ids, multimodals=multimodals) + + block_trie.match(seq_prob) + last_node = getattr(seq_prob.logical_blocks, 'last_shared_node', None) + assert last_node.parent is not None + assert last_node.num_matched == 2 * block_size + assert last_node.mm_hashes == tuple(['image_0']) + assert np.array_equal(last_node.tokens, [100] * block_size) + ref_cnt = allocator.get_ref_count(seq_prob.logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [3, 3]) + block_mgr.free(seq_prob) + + # test with two same image + token_ids = [1] * half_block_size # text + token_ids += [100] * 3 * half_block_size # img 0 + token_ids += [2] * block_size # text + token_ids += [200] * quarter_block_size # img 1 + token_ids += [6] * block_size # diff text + + multimodals = dict(image=[ + MultiModalTensor(data=None, start=half_block_size, end=2 * block_size, meta=dict(hash_value='image_0')), + MultiModalTensor(data=None, + start=3 * block_size, + end=3 * block_size + quarter_block_size, + meta=dict(hash_value='image_1')), + ]) + seq_prob = sess.add_sequence(token_ids, multimodals=multimodals) + + block_trie.match(seq_prob) + last_node = getattr(seq_prob.logical_blocks, 'last_shared_node', None) + assert last_node.parent is not None + assert last_node.num_matched == 3 * block_size + assert last_node.mm_hashes is None + assert np.array_equal(last_node.tokens, [2] * block_size) + ref_cnt = allocator.get_ref_count(seq_prob.logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [3, 3, 3]) + block_mgr.free(seq_prob) + + # test with two same image + token_ids = [1] * half_block_size # text + token_ids += [100] * 3 * half_block_size # img 0 + token_ids += [2] * block_size # text + token_ids += [200] * quarter_block_size # img 1 + token_ids += [3] * half_block_size # text + token_ids += [300] * half_block_size # img 2 + token_ids += [4] * half_block_size # text + token_ids += [5] * block_size # text + + multimodals = dict(image=[ + MultiModalTensor(data=None, start=half_block_size, end=2 * block_size, meta=dict(hash_value='image_0')), + MultiModalTensor(data=None, + start=3 * block_size, + end=3 * block_size + quarter_block_size, + meta=dict(hash_value='image_1')), + MultiModalTensor(data=None, + start=4 * block_size - quarter_block_size, + end=4 * block_size + quarter_block_size, + meta=dict(hash_value='image_2')), + ]) + seq_prob = sess.add_sequence(token_ids, multimodals=multimodals) + + block_trie.match(seq_prob) + last_node = getattr(seq_prob.logical_blocks, 'last_shared_node', None) + assert last_node.parent is not None + assert last_node.num_matched == 3 * block_size + assert last_node.mm_hashes is None + assert np.array_equal(last_node.tokens, [2] * block_size) + ref_cnt = allocator.get_ref_count(seq_prob.logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [3, 3, 3]) + block_mgr.free(seq_prob) + + # test with all images match + seq0.update_token_ids([5] * 2 * block_size) + block_mgr.allocate(seq0) + block_trie.allocate(seq0) + + block_trie.match(seq_prob) + last_node = getattr(seq_prob.logical_blocks, 'last_shared_node', None) + assert last_node.parent is not None + assert last_node.num_matched == 5 * block_size + assert last_node.mm_hashes == tuple(['image_2']) + assert np.array_equal(last_node.tokens, + [300] * quarter_block_size + [4] * half_block_size + [5] * quarter_block_size) + ref_cnt = allocator.get_ref_count(seq_prob.logical_blocks.get_real_blocks()) + assert np.array_equal(ref_cnt, [3, 3, 3, 3, 3])
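
A minimal sketch (standalone names, not part of the patch) of the multimodal-aware block hashing that `hash_block_tokens` and `Node.match_child` above rely on: two blocks with identical token ids only match when they also carry the same image hash tuple, so a cached vision block is never reused for a different image.

# Sketch mirroring hash_block_tokens() in block_trie.py.
# Assumption for illustration: each image contributes a precomputed digest string,
# as produced by hash_multimodal_data() over the raw image bytes.
from typing import Optional, Tuple

import numpy as np


def hash_block(tokens: np.ndarray, mm_hashes: Optional[Tuple[str, ...]] = None) -> int:
    """Hash one full block of token ids, folding in the hashes of overlapping images."""
    if mm_hashes is None:
        # pure-text blocks use a fixed salt, keeping them in a separate hash space
        mm_hashes = 'random'
    return hash((mm_hashes, tuple(tokens)))


block = np.array([2] * 16)            # 16 image-pad tokens filling one block
key_a = hash_block(block, ('image_0',))
key_b = hash_block(block, ('image_1',))
key_text = hash_block(np.array([7] * 16))
# key_a != key_b in practice: identical pad tokens coming from different images do
# not share a trie node, so prefix caching only reuses KV blocks of the same image.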