Skip to content

Commit ddf0461

Browse files
authored
Add files via upload
1 parent 920c770 commit ddf0461

File tree

4 files changed

+11
-6
lines changed

4 files changed

+11
-6
lines changed
31 Bytes
Binary file not shown.
Binary file not shown.

bert4keras3/models.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@ def build_transformer_model(
4949
application='encoder',
5050
return_keras_model=True,
5151
keras_weights_path=None,
52+
initial=True,
5253
**kwargs
5354
):
5455
"""根据配置文件构建模型,可选加载checkpoint权重
@@ -126,7 +127,7 @@ def build_transformer_model(
126127

127128
transformer = MODEL(**configs)
128129
transformer.build(**configs)
129-
if keras.__version__>'3.0' and backlib=='torch':
130+
if keras.__version__>'3.0' and initial:
130131
#keras3不知道为什么attention需要走一次前向才能初始化
131132
inputs=[]
132133
for modelin in transformer.model.inputs:

bert4keras3/transformers.py

+9-5
Original file line number | Diff line number | Diff line change
@@ -444,7 +444,7 @@ def get_cache_inputs(self,lengths:list):
444444
def get_custom_position_ids(self):
445445
return self.custom_position_ids
446446
def build_cache_model(self,input_lengths:list,end_token,
447-
search_mode='greedy',k=1,progress_print=False,index_bias=0):
447+
search_mode='greedy',k=1,progress_print=False,index_bias=0,initial = True):
448448
if backlib=='torch':
449449
progress_print=False
450450
inputs=self.get_cache_inputs(input_lengths)
@@ -458,11 +458,15 @@ def build_cache_model(self,input_lengths:list,end_token,
458458
shape=keras.ops.shape(modelin)
459459
shape=[1 if t==None else t for t in shape]
460460
inputs.append(ops.convert_to_tensor(np.ones(shape),modelin.dtype))
461-
if backlib=='torch':
462-
import torch
463-
with torch.no_grad():
461+
if initial:
462+
if backlib=='torch':
463+
import torch
464+
with torch.no_grad():
465+
self.cache_call(inputs=inputs,input_lengths=input_lengths,end_token=end_token,
466+
search_mode=search_mode,k=k,progress_print=progress_print,index_bias=index_bias)
467+
else:
464468
self.cache_call(inputs=inputs,input_lengths=input_lengths,end_token=end_token,
465-
search_mode=search_mode,k=k,progress_print=progress_print,index_bias=index_bias)
469+
search_mode=search_mode,k=k,progress_print=progress_print,index_bias=index_bias)
466470

467471
return model
468472
class LM_Mask(object):

0 commit comments

Comments
 (0)