@@ -537,3 +537,122 @@ def update(
         ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
         v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
         return k_out, v_out
+
+
+# This is a hack for now, until we get around to merging this code with the HybridCache class.
+# We don't really need to inherit the transformers cache classes, as those are built to work with PyTorch
+# while ours are built to work with AIC.
+class QEffHybridCacheForGPTOSS:
+    def __init__(self, config, batch_size, max_cache_len, sliding_window_len):
+        self.max_cache_len = max_cache_len
+        self.batch_size = batch_size
+        self.sliding_window_len = sliding_window_len
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+
+    @classmethod
+    def from_legacy_cache(
+        cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    ) -> "QEffHybridCacheForGPTOSS":
+        """Converts a cache in the legacy cache format into an equivalent `QEffHybridCacheForGPTOSS`. Used for
+        backward compatibility."""
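+        # Cache geometry is inferred from the legacy tuples below: batch size and the sliding-window length are
+        # read from layer 0's cache and the full context length from layer 1's, i.e. layer 0 is assumed to be a
+        # sliding-attention layer and layer 1 a full-attention layer.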
+        cache = cls(
+            config,
+            batch_size=past_key_values[0][0].shape[0],
+            max_cache_len=past_key_values[1][0].shape[2],
+            sliding_window_len=past_key_values[0][0].shape[2],
+        )
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
+        is_empty_layer = (
+            len(self.key_cache) == 0  # no cache in any layer
+            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
+        )
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        return layer_seq_length
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `QEffHybridCacheForGPTOSS` instance into its equivalent in the legacy cache format. Used for
+        backward compatibility."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+            k_out, v_out = key_states, value_states
+        else:
+            position_ids = cache_kwargs.get("position_ids")
+            is_sliding_layer = cache_kwargs.get("is_sliding")
+            sliding_window = cache_kwargs.get("sliding_window")
+            batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs
+
+            if is_sliding_layer:
+                # Positions on sliding layers wrap around the sliding window; invalid positions (-1) are left as-is.
+                kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
+            else:
+                kv_position_ids = position_ids
+
+            if batch_index is not None:
+                if torch.onnx.is_in_onnx_export():
+                    # During ONNX export, redirect invalid (negative) positions to INT32 max so they do not
+                    # scatter into a valid cache slot.
+                    invalid_scatter_index = torch.iinfo(torch.int32).max
+                    scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids)
+                else:
+                    scatter_position_ids = kv_position_ids
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], kv_position_ids, value_states
+                )
+
+            k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+            # Original Gather
+            ctx_len = self.key_cache[layer_idx].shape[2]
+            ctx_indices = torch.arange(ctx_len)[None, None, ...]
+            gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+            invalid_mask = ctx_indices > gather_limit
+            if torch.onnx.is_in_onnx_export():
+                invalid_idx_value = torch.iinfo(torch.int32).max
+            else:
+                invalid_idx_value = 0
+            ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+            if batch_index is not None:
+                k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
+                v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+            else:
+                k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+                v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+
+            # Zero out value entries gathered from context positions beyond each sequence's current position.
+            v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+        return k_out, v_out
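
For reference, here is a rough sketch of how the class above could be exercised: build it from legacy per-layer (key, value) tuples via `from_legacy_cache`, then call `update` per layer with the kwargs the GPT-OSS attention passes in (`position_ids`, `is_sliding`, `sliding_window`, and optionally `batch_index`). The import path, the toy tensor shapes, and passing `config=None` (the `__init__` above does not use it) are illustrative assumptions, not part of this commit.

    import torch

    # Hypothetical import path for the class defined in this diff.
    from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import QEffHybridCacheForGPTOSS

    batch_size, num_kv_heads, head_dim = 1, 8, 64
    sliding_window, ctx_len = 128, 256

    # Legacy format: one (key, value) tuple per layer; layer 0 sliding-window, layer 1 full-context.
    past_key_values = (
        (
            torch.zeros(batch_size, num_kv_heads, sliding_window, head_dim),
            torch.zeros(batch_size, num_kv_heads, sliding_window, head_dim),
        ),
        (
            torch.zeros(batch_size, num_kv_heads, ctx_len, head_dim),
            torch.zeros(batch_size, num_kv_heads, ctx_len, head_dim),
        ),
    )
    cache = QEffHybridCacheForGPTOSS.from_legacy_cache(config=None, past_key_values=past_key_values)

    # Decode a single new token at position 130 on layer 0 (a sliding layer); the write position
    # wraps to 130 % 128 = 2 inside the sliding-window buffer.
    key_states = torch.rand(batch_size, num_kv_heads, 1, head_dim)
    value_states = torch.rand(batch_size, num_kv_heads, 1, head_dim)
    k_out, v_out = cache.update(
        key_states,
        value_states,
        layer_idx=0,
        cache_kwargs={
            "position_ids": torch.tensor([[130]]),
            "is_sliding": True,
            "sliding_window": sliding_window,
        },
    )
    # k_out / v_out are the gathered cache contents handed back to the attention computation.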