11# Copyright (c) OpenMMLab. All rights reserved.
22import heapq
33from typing import Dict , List , Set , Tuple
4+ from collections import defaultdict
45
56import numpy as np
67
@@ -185,6 +186,8 @@ def _allocate_text(self, seq: SchedulerSequence):
185186 logical_blocks .last_shared_node = node
186187 if node .parent is not None and len (node .children ) == 0 :
187188 # ignore root
189+ if node is None or node .block == - 1 :
190+ breakpoint ()
188191 self .leaves .add (node )
189192 if len (blocks ) > 0 :
190193 self .allocator .add_ref_count (np .array (blocks ), 1 )
@@ -219,6 +222,12 @@ def _match_multimodals(self, seq: SchedulerSequence) -> Dict[int, int]:
219222 else :
220223 last_max_num_matched = curr .num_matched
221224
225+ # if there is a full block or the rest blocks contain vision tokens
226+ if (curr .num_matched + block_size ) > seq .num_all_ids :
227+ mm_hash_values , _ = seq .history_multimodals .get_hash_values (last_max_num_matched , seq .num_all_ids )
228+ if not mm_hash_values :
229+ return copy_map
230+
222231 num_matched = curr .num_matched
223232 logger .debug (f'Matching seq-{ seq .seq_id } { num_matched } /{ seq .num_all_ids } ' )
224233
@@ -285,6 +294,7 @@ def __match_success(node: Node):
285294
286295 add_ref_blocks = matched_blocks
287296 if len (copy_map ):
297+ self .allocator .update_access_time (np .array (list (copy_map .keys ())))
288298 add_ref_blocks = [b for b in add_ref_blocks if b not in copy_map .values ()]
289299 add_ref_blocks = np .array (add_ref_blocks )
290300 matched_blocks = np .array (matched_blocks )
@@ -326,6 +336,12 @@ def _allocate_multimodals(self, seq: SchedulerSequence) -> Dict[int, int]:
326336 else :
327337 last_max_num_matched = node .num_matched
328338
339+ # if there is a full block or the rest blocks contain vision tokens
340+ if (node .num_matched + block_size ) > seq .num_all_ids :
341+ mm_hash_values , _ = seq .history_multimodals .get_hash_values (last_max_num_matched , seq .num_all_ids )
342+ if not mm_hash_values :
343+ return copy_map
344+
329345 num_matched = node .num_matched
330346 num_all_ids = seq .num_all_ids
331347
@@ -443,10 +459,19 @@ def __add_full_node(node, mm_hash_values):
443459 for cur_node in (unfull_nodes + [last_node ]):
444460 if cur_node .parent is not None and len (cur_node .children ) == 0 :
445461 # ignore root
462+ if cur_node is None or cur_node .block == - 1 :
463+ breakpoint ()
446464 self .leaves .add (cur_node )
447465
448466 if len (blocks ) > 0 :
449- self .allocator .add_ref_count (np .array (blocks ), 1 )
467+ update_time_blocks = blocks
468+ if copy_map :
469+ update_time_blocks += list (copy_map .keys ())
470+ update_time_blocks += list (copy_map .values ())
471+ self .allocator .update_access_time (np .array (update_time_blocks ))
472+ blocks = np .array (blocks )
473+ self .allocator .add_ref_count (blocks , 1 )
474+
450475 if len (free_blocks ) > 0 :
451476 self .allocator .free (np .array (free_blocks ))
452477
@@ -455,7 +480,7 @@ def __add_full_node(node, mm_hash_values):
455480 @logging_timer ('BlockTrie_Evict' , logger )
456481 def evict (self , max_num_blocks : int ):
457482 """evict."""
458- if not self .enable :
483+ if not self .enable or len ( self . leaves ) == 0 :
459484 return 0
460485 logger .debug (f'Need to evict max_num_blocks={ max_num_blocks } ' )
461486
@@ -465,15 +490,45 @@ def __remove_leaf(leaves, evicted_blocks):
465490 parent = leaf .parent
466491 leaf .parent = None
467492 self .leaves .remove (leaf )
468- logger .debug (f'Evict block={ leaf .block } node. mm_hashes={ leaf .mm_hashes } ' )
493+ logger .debug (f'Evict block={ leaf .block } mm_hashes={ leaf .mm_hashes } num_matched= { leaf . num_matched } ' )
469494 return parent , leaf
470495
471496 def __add_leaf (leaves , parent ):
497+ if parent is None or parent .block == - 1 :
498+ breakpoint ()
472499 self .leaves .add (parent )
473500 if self .allocator .get_ref_count (parent .block ) == 1 :
474501 access_time = self .allocator .get_access_time (parent .block )
502+ logger .debug (f'Evict heappush block={ parent .block } mm_hashes={ parent .mm_hashes } num_matched={ parent .num_matched } ' )
475503 heapq .heappush (leaves , (access_time , parent ))
476504
505+ def __filter_leaf (leaves , ref_cnt ):
506+ # when the same block is referenced by multiple nodes
507+ # we need to remove the full block first
508+ groups = defaultdict (list )
509+ for idx in range (len (leaves )):
510+ groups [leaves [idx ].block ].append (idx )
511+
512+ indices = []
513+ deduce_ref_blocks = []
514+ for gp in groups .values ():
515+ num = len (gp )
516+ # only deal with a block is refed by a unfull node and a full node case
517+ if num == 2 and num == ref_cnt [gp [0 ]]:
518+ full , unfull = gp
519+ if leaves [unfull ].is_full :
520+ full , unfull = unfull , full
521+ # remove full node
522+ leaves [full ].parent = None
523+ logger .debug (f'Evict remove duplicate full block={ leaves [full ].block } mm_hashes={ leaves [full ].mm_hashes } num_matched={ leaves [full ].num_matched } ' )
524+ self .leaves .remove (leaves [full ])
525+ deduce_ref_blocks .append (leaves [full ].block )
526+ indices .append (unfull )
527+
528+ if len (deduce_ref_blocks ) > 0 :
529+ self .allocator .add_ref_count (np .array (deduce_ref_blocks ), - 1 )
530+ return indices
531+
477532 evicted_blocks = []
478533 leaves = list (self .leaves )
479534
@@ -482,8 +537,10 @@ def __add_leaf(leaves, parent):
482537 ref_cnt = self .allocator .get_ref_count (leave_blocks )
483538 indices = (ref_cnt == 1 ).nonzero ()[0 ]
484539 if len (indices ) == 0 :
485- return 0
486-
540+ indices = __filter_leaf (leaves , ref_cnt )
541+ if len (indices ) == 0 :
542+ return 0
543+
487544 # make heap
488545 leaves = list (leaves [i ] for i in indices )
489546 access_times = self .allocator .get_access_time (leave_blocks )
@@ -492,6 +549,8 @@ def __add_leaf(leaves, parent):
492549 heapq .heapify (leaves )
493550
494551 while len (leaves ) > 0 and len (evicted_blocks ) < max_num_blocks :
552+ if any ([l [1 ] is None or l [1 ].parent is None for l in leaves ]):
553+ breakpoint ()
495554 parent , removed_leaf = __remove_leaf (leaves , evicted_blocks )
496555 if parent .parent is None :
497556 # ignore root
@@ -502,10 +561,13 @@ def __add_leaf(leaves, parent):
502561 parent .children ) == 0 and self .allocator .get_ref_count (parent .block ) == 1 :
503562 tmp_parent = parent .parent
504563 evicted_blocks .append (parent .block )
505- logger .debug (f'Evict block={ parent .block } node. mm_hashes={ parent .mm_hashes } ' )
564+ logger .debug (f'Evict block={ parent .block } mm_hashes={ parent .mm_hashes } num_matched= { parent . num_matched } ' )
506565 parent .parent = None
507566 parent = tmp_parent
508-
567+
568+ if parent .parent is None :
569+ # ignore root
570+ continue
509571 if len (parent .children ) == 0 :
510572 __add_leaf (leaves , parent )
511573
0 commit comments