@@ -444,7 +444,7 @@ def get_cache_inputs(self,lengths:list):
444
444
def get_custom_position_ids (self ):
445
445
return self .custom_position_ids
446
446
def build_cache_model (self ,input_lengths :list ,end_token ,
447
- search_mode = 'greedy' ,k = 1 ,progress_print = False ,index_bias = 0 ):
447
+ search_mode = 'greedy' ,k = 1 ,progress_print = False ,index_bias = 0 , initial = True ):
448
448
if backlib == 'torch' :
449
449
progress_print = False
450
450
inputs = self .get_cache_inputs (input_lengths )
@@ -458,11 +458,15 @@ def build_cache_model(self,input_lengths:list,end_token,
458
458
shape = keras .ops .shape (modelin )
459
459
shape = [1 if t == None else t for t in shape ]
460
460
inputs .append (ops .convert_to_tensor (np .ones (shape ),modelin .dtype ))
461
- if backlib == 'torch' :
462
- import torch
463
- with torch .no_grad ():
461
+ if initial :
462
+ if backlib == 'torch' :
463
+ import torch
464
+ with torch .no_grad ():
465
+ self .cache_call (inputs = inputs ,input_lengths = input_lengths ,end_token = end_token ,
466
+ search_mode = search_mode ,k = k ,progress_print = progress_print ,index_bias = index_bias )
467
+ else :
464
468
self .cache_call (inputs = inputs ,input_lengths = input_lengths ,end_token = end_token ,
465
- search_mode = search_mode ,k = k ,progress_print = progress_print ,index_bias = index_bias )
469
+ search_mode = search_mode ,k = k ,progress_print = progress_print ,index_bias = index_bias )
466
470
467
471
return model
468
472
class LM_Mask (object ):
0 commit comments