Skip to content

Commit a160af3

Browse files
committed
add debug ref
1 parent f50c956 commit a160af3

File tree

1 file changed

+69
-7
lines changed

1 file changed

+69
-7
lines changed

lmdeploy/pytorch/paging/block_trie.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import heapq
33
from typing import Dict, List, Set, Tuple
4+
from collections import defaultdict
45

56
import numpy as np
67

@@ -185,6 +186,8 @@ def _allocate_text(self, seq: SchedulerSequence):
185186
logical_blocks.last_shared_node = node
186187
if node.parent is not None and len(node.children) == 0:
187188
# ignore root
189+
if node is None or node.block == -1:
190+
breakpoint()
188191
self.leaves.add(node)
189192
if len(blocks) > 0:
190193
self.allocator.add_ref_count(np.array(blocks), 1)
@@ -219,6 +222,12 @@ def _match_multimodals(self, seq: SchedulerSequence) -> Dict[int, int]:
219222
else:
220223
last_max_num_matched = curr.num_matched
221224

225+
# proceed only if a full block remains or the remaining tokens contain vision tokens; otherwise stop here
226+
if (curr.num_matched + block_size) > seq.num_all_ids:
227+
mm_hash_values, _ = seq.history_multimodals.get_hash_values(last_max_num_matched, seq.num_all_ids)
228+
if not mm_hash_values:
229+
return copy_map
230+
222231
num_matched = curr.num_matched
223232
logger.debug(f'Matching seq-{seq.seq_id} {num_matched}/{seq.num_all_ids}')
224233

@@ -285,6 +294,7 @@ def __match_success(node: Node):
285294

286295
add_ref_blocks = matched_blocks
287296
if len(copy_map):
297+
self.allocator.update_access_time(np.array(list(copy_map.keys())))
288298
add_ref_blocks = [b for b in add_ref_blocks if b not in copy_map.values()]
289299
add_ref_blocks = np.array(add_ref_blocks)
290300
matched_blocks = np.array(matched_blocks)
@@ -326,6 +336,12 @@ def _allocate_multimodals(self, seq: SchedulerSequence) -> Dict[int, int]:
326336
else:
327337
last_max_num_matched = node.num_matched
328338

339+
# proceed only if a full block remains or the remaining tokens contain vision tokens; otherwise stop here
340+
if (node.num_matched + block_size) > seq.num_all_ids:
341+
mm_hash_values, _ = seq.history_multimodals.get_hash_values(last_max_num_matched, seq.num_all_ids)
342+
if not mm_hash_values:
343+
return copy_map
344+
329345
num_matched = node.num_matched
330346
num_all_ids = seq.num_all_ids
331347

@@ -443,10 +459,19 @@ def __add_full_node(node, mm_hash_values):
443459
for cur_node in (unfull_nodes + [last_node]):
444460
if cur_node.parent is not None and len(cur_node.children) == 0:
445461
# ignore root
462+
if cur_node is None or cur_node.block == -1:
463+
breakpoint()
446464
self.leaves.add(cur_node)
447465

448466
if len(blocks) > 0:
449-
self.allocator.add_ref_count(np.array(blocks), 1)
467+
update_time_blocks = blocks
468+
if copy_map:
469+
update_time_blocks += list(copy_map.keys())
470+
update_time_blocks += list(copy_map.values())
471+
self.allocator.update_access_time(np.array(update_time_blocks))
472+
blocks = np.array(blocks)
473+
self.allocator.add_ref_count(blocks, 1)
474+
450475
if len(free_blocks) > 0:
451476
self.allocator.free(np.array(free_blocks))
452477

@@ -455,7 +480,7 @@ def __add_full_node(node, mm_hash_values):
455480
@logging_timer('BlockTrie_Evict', logger)
456481
def evict(self, max_num_blocks: int):
457482
"""evict."""
458-
if not self.enable:
483+
if not self.enable or len(self.leaves) == 0:
459484
return 0
460485
logger.debug(f'Need to evict max_num_blocks={max_num_blocks}')
461486

@@ -465,15 +490,45 @@ def __remove_leaf(leaves, evicted_blocks):
465490
parent = leaf.parent
466491
leaf.parent = None
467492
self.leaves.remove(leaf)
468-
logger.debug(f'Evict block={leaf.block} node.mm_hashes={leaf.mm_hashes}')
493+
logger.debug(f'Evict block={leaf.block} mm_hashes={leaf.mm_hashes} num_matched={leaf.num_matched}')
469494
return parent, leaf
470495

471496
def __add_leaf(leaves, parent):
    """Record ``parent`` as a trie leaf and, if evictable, push it on the heap.

    Args:
        leaves: list maintained as a heap via ``heapq``; entries are
            ``(access_time, node)`` tuples.
        parent: node to register as a leaf (the caller passes a node whose
            ``children`` is empty).
    """
    # Skip invalid nodes instead of dropping into the debugger: a blocking
    # ``breakpoint()`` here would stall the whole process during eviction.
    # NOTE(review): ``block == -1`` is presumably the "no block assigned"
    # sentinel -- confirm against the Node definition.
    if parent is None or parent.block == -1:
        logger.error(f'Evict skip invalid leaf candidate: {parent}')
        return

    self.leaves.add(parent)
    # only a block whose ref count is exactly 1 is queued for eviction
    if self.allocator.get_ref_count(parent.block) == 1:
        access_time = self.allocator.get_access_time(parent.block)
        logger.debug(f'Evict heappush block={parent.block} mm_hashes={parent.mm_hashes} num_matched={parent.num_matched}')
        heapq.heappush(leaves, (access_time, parent))
476504

505+
def __filter_leaf(leaves, ref_cnt):
    """Resolve leaves that share a block so the surviving node can be evicted.

    When one block is held by exactly two leaf nodes (and its ref count is
    also two), the full node is detached and dropped from ``self.leaves``,
    the block's ref count is decremented by one, and the index of the other
    node is reported.

    Args:
        leaves: list of leaf nodes.
        ref_cnt: per-leaf ref counts, indexed like ``leaves``.

    Returns:
        List of indices into ``leaves`` for the nodes kept from each
        resolved pair.
    """
    # bucket leaf indices by the block they reference
    block_to_indices = defaultdict(list)
    for pos, leaf in enumerate(leaves):
        block_to_indices[leaf.block].append(pos)

    kept_indices = []
    blocks_to_deref = []
    for members in block_to_indices.values():
        # only the "one unfull node + one full node on the same block" case is handled
        if len(members) != 2 or ref_cnt[members[0]] != 2:
            continue

        first, second = members
        if leaves[second].is_full:
            full_idx, unfull_idx = second, first
        else:
            full_idx, unfull_idx = first, second

        # detach the full node from the trie and forget it as a leaf
        leaves[full_idx].parent = None
        logger.debug(f'Evict remove duplicate full block={leaves[full_idx].block} mm_hashes={leaves[full_idx].mm_hashes} num_matched={leaves[full_idx].num_matched}')
        self.leaves.remove(leaves[full_idx])
        blocks_to_deref.append(leaves[full_idx].block)
        kept_indices.append(unfull_idx)

    # release the reference that each removed full node was holding
    if blocks_to_deref:
        self.allocator.add_ref_count(np.array(blocks_to_deref), -1)
    return kept_indices
531+
477532
evicted_blocks = []
478533
leaves = list(self.leaves)
479534

@@ -482,8 +537,10 @@ def __add_leaf(leaves, parent):
482537
ref_cnt = self.allocator.get_ref_count(leave_blocks)
483538
indices = (ref_cnt == 1).nonzero()[0]
484539
if len(indices) == 0:
485-
return 0
486-
540+
indices = __filter_leaf(leaves, ref_cnt)
541+
if len(indices) == 0:
542+
return 0
543+
487544
# make heap
488545
leaves = list(leaves[i] for i in indices)
489546
access_times = self.allocator.get_access_time(leave_blocks)
@@ -492,6 +549,8 @@ def __add_leaf(leaves, parent):
492549
heapq.heapify(leaves)
493550

494551
while len(leaves) > 0 and len(evicted_blocks) < max_num_blocks:
552+
if any([l[1] is None or l[1].parent is None for l in leaves]):
553+
breakpoint()
495554
parent, removed_leaf = __remove_leaf(leaves, evicted_blocks)
496555
if parent.parent is None:
497556
# ignore root
@@ -502,10 +561,13 @@ def __add_leaf(leaves, parent):
502561
parent.children) == 0 and self.allocator.get_ref_count(parent.block) == 1:
503562
tmp_parent = parent.parent
504563
evicted_blocks.append(parent.block)
505-
logger.debug(f'Evict block={parent.block} node.mm_hashes={parent.mm_hashes}')
564+
logger.debug(f'Evict block={parent.block} mm_hashes={parent.mm_hashes} num_matched={parent.num_matched}')
506565
parent.parent = None
507566
parent = tmp_parent
508-
567+
568+
if parent.parent is None:
569+
# ignore root
570+
continue
509571
if len(parent.children) == 0:
510572
__add_leaf(leaves, parent)
511573

0 commit comments

Comments
 (0)