@@ -371,47 +371,66 @@ def __init__(self, multimodals: MultiModalInputs):
371371 if multimodals is None :
372372 multimodals = dict ()
373373 self .multimodals = multimodals
374+ self ._init_mm_ranges ()
375+
def _init_mm_ranges(self):
    """Build the cached list of multimodal ranges, ordered by end position.

    Every item of every modality in ``self.multimodals`` contributes one
    ``(start, end, hash_value)`` tuple, where ``hash_value`` is ``None``
    when the item's ``meta`` has no ``'hash_value'`` entry. The resulting
    list is stored in ``self._mm_ranges``.
    """
    collected = [
        (item.start, item.end, item.meta.get('hash_value', None))
        for items in self.multimodals.values()
        for item in items
    ]
    # sorted() is stable: ranges sharing an end keep their insertion order.
    self._mm_ranges = sorted(collected, key=lambda rng: rng[1])
385+
@property
def mm_ranges(self):
    """Return the cached ``(start, end, hash_value)`` tuples.

    This is the same list object held in ``self._mm_ranges`` (not a copy).
    """
    return self._mm_ranges
374390
def get_datas(self, start=0, end=-1):
    """Collect the multimodal items overlapping prompt positions [start, end).

    An item is selected when ``item.start < end and item.end > start``.
    Modalities without any selected item are left out of the result.

    Note: ``end`` is compared literally, so with the default ``end=-1`` no
    item whose ``start`` is non-negative can match.

    Returns:
        dict mapping modality name to the list of selected items.
    """
    selected = dict()
    for kind, items in self.multimodals.items():
        hits = [item for item in items if item.start < end and item.end > start]
        if hits:
            selected[kind] = hits
    return selected
388402
def get_step(self, step: int) -> int:
    """Move ``step`` back so it does not fall inside a multimodal range.

    The cached ranges in ``self._mm_ranges`` are visited in order; whenever
    the current position satisfies ``start <= position < end`` for a range,
    the position is replaced by that range's ``start``.

    Returns:
        The adjusted step; equal to ``step`` when no range contains it.
    """
    snapped = step
    for begin, stop, _hash in self._mm_ranges:
        inside = begin <= snapped < stop
        if inside:
            snapped = begin
    return snapped
401410
def has_data(self, start: int, end: int) -> bool:
    """Tell whether any multimodal range overlaps positions [start, end).

    A cached range ``(s, e, _)`` overlaps when ``s < end and e > start``.

    Returns:
        ``True`` if at least one range in ``self._mm_ranges`` overlaps,
        ``False`` otherwise (including when there are no ranges).
    """
    # Generator instead of a list: any() can stop at the first overlap
    # without materialising a result for every range.
    return any(s < end and e > start for s, e, _ in self._mm_ranges)
414+
def get_hash_values(self, start: int, end: int):
    """Gather hash values of the multimodal ranges overlapping [start, end).

    Ranges are read from ``self._mm_ranges`` in their stored order. A range
    overlaps when ``mm_start < end and mm_end > start``; its hash value is
    recorded as-is (it may be ``None``).

    Returns:
        A pair ``(hash_values, multimodal_ends)``:

        * ``hash_values``: tuple of the hash values of all overlapping
          ranges, or ``None`` when nothing overlaps.
        * ``multimodal_ends``: list of ``(hashes_so_far, mm_end)`` entries,
          one for each overlapping range whose end satisfies
          ``start < mm_end <= end``; ``hashes_so_far`` is the tuple of
          hash values collected up to and including that range.
    """
    seen = []
    ends = []

    for lo, hi, digest in self._mm_ranges:
        # Skip ranges that do not intersect the requested window.
        if not (lo < end and hi > start):
            continue
        seen.append(digest)
        # The range finishes inside the window: snapshot the hashes so far.
        if start < hi <= end:
            ends.append((tuple(seen), hi))

    hashes = tuple(seen) if seen else None
    return hashes, ends
415434
416435 def add_inputs (self , input_mms : MultiModalInputs ):
417436 """add new inputs."""
@@ -421,9 +440,17 @@ def add_inputs(self, input_mms: MultiModalInputs):
421440 else :
422441 self .multimodals [modal_type ] = vals
423442
424- def empty (self ):
443+ # update mm_ranges
444+ for modal_data in vals :
445+ data = (modal_data .start , modal_data .end , modal_data .meta .get ('hash_value' , None ))
446+ self ._mm_ranges .append (data )
447+
448+ # sort mm_ranges
449+ self ._mm_ranges .sort (key = lambda x : x [1 ])
450+
def empty(self) -> bool:
    """Tell whether no multimodal item is stored.

    Returns:
        ``True`` when ``self.multimodals`` has no modality at all or every
        modality maps to an empty collection; ``False`` otherwise.
    """
    # Iterate the per-modality item lists, not the dict keys: checking
    # len() of a key would test the modality *name*, not its contents.
    # all() over an empty mapping is True, covering the no-modality case.
    return all(len(vals) == 0 for vals in self.multimodals.values())
429456
@@ -609,7 +636,7 @@ def update_token_ids(self,
609636
610637 # update multimodals
611638 if multimodals is not None :
612- multimodals = HistoryMultiModals .update_multimodals (multimodals , self .num_all_ids )
639+ multimodals = HistoryMultiModals .update_multimodals (multimodals , self ._num_history_ids )
613640 self .history_multimodals .add_inputs (multimodals )
614641
615642 # cross
@@ -641,6 +668,7 @@ def set_step(self, step: int):
641668 new_step = self .history_multimodals .get_step (step )
642669 assert 0 <= new_step <= step
643670 step = new_step
671+
644672 self ._num_history_ids = step
645673 self ._num_token_ids = num_all_ids - step
646674 self .num_ignored_history = min (step , self .num_ignored_history )
@@ -651,3 +679,16 @@ def set_step(self, step: int):
651679 if self .history_multimodals is not None :
652680 self ._num_history_cross = self .history_multimodals .get_encoder_len (0 , self .num_history_ids )
653681 self ._num_cross = self .history_multimodals .get_encoder_len (self ._num_history_ids , num_all_ids )
682+
def __repr__(self):
    """Return a one-line debug summary of this sequence.

    The summary lists identifiers, status, timing, sampling settings, token
    counts and ids, the block count with the result of
    ``self.logical_blocks.get_real_blocks()``, and ``last_shared_node``
    (shown as ``None`` when the attribute is not set).
    """
    fields = (
        ('seq_id', self.seq_id),
        ('session_id', self.session_id),
        ('status', self.status),
        ('arrive_time', self.arrive_time),
        ('return_logits', self.return_logits),
        ('sampling_param', self.sampling_param),
        ('num_history_tokens', self.history_len),
        ('num_all_tokens', self.num_all_ids),
        ('num_new_tokens', self.num_new_tokens),
        ('all_token_ids', self.all_ids),
        ('num_gpu_blocks', self.num_blocks),
        ('gpu_blocks', self.logical_blocks.get_real_blocks()),
        ('last_shared_node', getattr(self, "last_shared_node", None)),
    )
    body = ', '.join(f'{label}={value}' for label, value in fields)
    return f'SchedulerSequence({body})'


__str__ = __repr__
0 commit comments