Commit ac3843d

Add files via upload
1 parent 5bdad0e commit ac3843d

16 files changed, +99 -22 lines

bert4keras3/__init__.py (+1 -1)

@@ -1,6 +1,6 @@
 #! -*- coding: utf-8 -*-
 
-__version__ = '1.0.2'
+__version__ = '1.1.2'
 
 from bert4keras3 import backend,layers,models,snippets,tokenizers
 from bert4keras3.backend import ops
12 binary files changed (contents not shown)

bert4keras3/backend.py (+15 -1)

@@ -8,13 +8,24 @@
 import tensorflow as tf
 from functools import wraps
 is_tf_keras = strtobool(os.environ.get('TF_KERAS', '0'))
+lora_model = strtobool(os.environ.get('ENABLE_LORA', '0'))
+# For flash attention under jax, install it from https://github.com/nshepperd/flash_attn_jax/releases
+enable_flashatt = strtobool(os.environ.get('FLASH_ATTN', '0'))
 os.environ["KERAS_BACKEND"]=os.environ.get("KERAS_BACKEND", 'tensorflow')
 backlib=os.environ["KERAS_BACKEND"]
 if backlib=='tfkeras':
     is_tf_keras = True
+    if enable_flashatt:
+        raise ValueError('tensorflow does not support flash-attention')
 elif backlib=='torch':
     import torch
+    if enable_flashatt:
+        from flash_attn import flash_attn_func
+        def flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1)):
+            return flash_attn_func(q, k, v, softmax_scale=softmax_scale, causal=is_causal,window_size=window_size)
 elif backlib=='jax':
+    if enable_flashatt:
+        from flash_attn_jax import flash_mha
     import jax
 if is_tf_keras:
     sys.modules['keras'] = tf.keras
@@ -534,7 +545,10 @@ def actual_grad_fn(*doutputs):
 else:
     sys.modules['keras.ops']=ops
 
-
+def slices_index(x,index,axis):
+    shape = list(ops.shape(x))
+    shape[axis] = index
+    return ops.slice(x,ops.zeros_like(shape),shape)
 custom_objects = {
     'gelu_erf': gelu_erf,
     'gelu_tanh': ops.gelu,
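The flags added above are read once, at import time, from environment variables. A minimal usage sketch (assuming bert4keras3 and the flash_attn / flash_attn_jax package referenced in the comment are installed; nothing below comes from the repo's docs):

import os
# These must be set before anything from bert4keras3 is imported,
# because backend.py evaluates strtobool(os.environ.get(...)) at import time.
os.environ['KERAS_BACKEND'] = 'torch'   # or 'jax'; the tfkeras path raises when FLASH_ATTN is set
os.environ['FLASH_ATTN'] = '1'          # routes attention through flash_mha
os.environ['ENABLE_LORA'] = '1'         # sets lora_model, consumed later by build_transformer_model

from bert4keras3.backend import enable_flashatt, lora_model
print(enable_flashatt, lora_model)      # True True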

bert4keras3/layers.py (+44 -9)

@@ -3,11 +3,13 @@
 
 import numpy as np
 
-from bert4keras3.backend import keras, ops, is_tf_keras,K,tf
-from bert4keras3.backend import align, sequence_masking
+from bert4keras3.backend import keras, ops, is_tf_keras,K,tf,enable_flashatt
+if enable_flashatt:
+    from bert4keras3.backend import flash_mha
+from bert4keras3.backend import align, sequence_masking,backlib
 from bert4keras3.backend import recompute_grad,int_shape
 from bert4keras3.backend import attention_normalize,divide_no_nan
-from bert4keras3.backend import sinusoidal_embeddings
+from bert4keras3.backend import sinusoidal_embeddings,slices_index
 from bert4keras3.backend import apply_rotary_position_embeddings
 from keras import initializers, activations
 from keras.layers import *
@@ -27,7 +29,10 @@ def get_config(self):
         return dict(list(base_config.items()) + list(config.items()))
     def call(self,inputs, **kwargs):
         index = kwargs.get('index')
-        return ops.expand_dims(ops.take(inputs,index,self.axis),self.axis)
+        out = ops.expand_dims(ops.take(inputs,index,self.axis),self.axis)
+        if backlib=='torch':
+            return slices_index(out,index+1,1)
+        return out
     def compute_output_shape(self, input_shape):
         input_shape = list(input_shape)
         input_shape[self.axis]=1
@@ -637,6 +642,16 @@ def pay_attention_to(self, inputs, mask=None, **kwargs):
             vw = value_cache
         else:
             cache = None
+
+        if enable_flashatt:
+            is_causal = False
+            if a_bias is not None:
+                is_causal = True
+            softmax_scale = 1.
+            if self.attention_scale:
+                softmax_scale = 1 / self.key_size**0.5
+            o = flash_mha(qw,kw,vw,softmax_scale=softmax_scale, is_causal=is_causal)
+            return o,[],[]
         # Attention
         a = ops.einsum('bjhd,bkhd->bhjk', qw, kw)
         # Handle the position encodings
@@ -645,6 +660,7 @@ def pay_attention_to(self, inputs, mask=None, **kwargs):
             a = a + ops.einsum('bjhd,jkd->bhjk', qw, position_bias)
         elif p_bias == 't5_relative':
             position_bias = ops.transpose(inputs[n], (2, 0, 1))
+            #print(a.shape,position_bias.shape)
             a = a + ops.expand_dims(position_bias, 0)
         # Attention (continued)
         if self.attention_scale:
@@ -823,21 +839,40 @@ def call(self, inputs, mask=None, a_bias=None, p_bias=None):
         if p_bias == 'rotary':
             q, k = apply_rotary_position_embeddings(inputs[n], q, k)
         # Attention
+        if enable_flashatt and ops.shape(k)==ops.shape(v):
+            z = self.pay_flash_attention_to(q,k,v, a_bias)
+        else:
+            z = self.pay_attention_to(q,k,v,mask, a_bias)
+        # Compute the output
+        if self.self_attention==False and self.factorization:
+            z = self.vW_dense(z)
+        o = self.o_dense(u * z)
+        return o
+    def pay_flash_attention_to(self, q,k,v, a_bias):
+        is_causal = False
+        if a_bias is not None:
+            is_causal = True
+        softmax_scale = 1.
+        if self.attention_scale:
+            softmax_scale = 1 / self.key_size**0.5
+        if ops.ndim(q)==3:
+            k = ops.expand_dims(k,2)
+            q = ops.expand_dims(q,2)
+            v = ops.expand_dims(v,2)
+        o = flash_mha(q,k,v,softmax_scale=softmax_scale, is_causal=is_causal)
+        return ops.squeeze(o,2)
+    def pay_attention_to(self, q,k,v,mask, a_bias):
         a = ops.einsum('bmd,bnd->bmn', q, k)
         if self.attention_scale:
             a = a / self.key_size**0.5
         A = attention_normalize(a, mask, -1, self.normalization, a_bias)
         if self.attention_dropout:
             A = self.dropout(A)
-        # Compute the output
         try:
             z=ops.einsum('bmn,bnd->bmd', A, v)
         except:
             pass
-        if self.self_attention==False and self.factorization:
-            z = self.vW_dense(z)
-        o = self.o_dense(u * z)
-        return o
+        return z
 
     def compute_mask(self, inputs, mask=None):
         if isinstance(mask, list):
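For reference, a small numpy sketch of what the slices_index helper used in the torch branch above does: it keeps the first index entries along the chosen axis and drops the rest (the shapes below are illustrative only, not taken from the repo):

import numpy as np

def slices_index_np(x, index, axis):
    # numpy equivalent of backend.slices_index: set shape[axis] = index and slice from zero
    slicer = [slice(None)] * x.ndim
    slicer[axis] = slice(0, index)
    return x[tuple(slicer)]

out = np.ones((2, 16, 8))                      # e.g. (batch, sequence, features)
print(slices_index_np(out, 5, axis=1).shape)   # (2, 5, 8)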

bert4keras3/models.py (+39 -11)

@@ -2,7 +2,7 @@
 # Main models
 
 import numpy as np
-from bert4keras3.backend import tf,keras,backlib
+from bert4keras3.backend import tf,keras,backlib,lora_model
 from bert4keras3.layers import *
 from bert4keras3.snippets import insert_arguments
 from bert4keras3.snippets import delete_arguments
@@ -473,6 +473,16 @@ def cache_call(self,inputs:list,input_lengths:list,end_token,
         caches = self.initial_cache(inputs[:1])
         key = 0
         x = inputs[key]
+
+        class start_index(keras.Layer):
+            def call(self,x):
+                z = x!=0
+                if index_bias>0:
+                    t = ops.ones([ops.shape(z)[0],index_bias],dtype=z.dtype)
+                    z = ops.slice_update(z,[0,0],t)
+                return ops.max(ops.sum(z,-1))-1
+
+
         length = input_lengths[key]
         self.cache_attention_bias=None
         self.cache_position_bias=None
@@ -494,26 +504,23 @@ def cache_call(self,inputs:list,input_lengths:list,end_token,
                 attention_mask=self.cache_attention_bias,
                 position_bias=self.cache_position_bias)
             z,cache = out[:-1],out[-1]
+
             caches[index*j:index*j+j]=cache
 
-        class start_index(keras.Layer):
-            def call(self,x):
-                z = x!=0
-                if index_bias>0:
-                    t = ops.ones([ops.shape(z)[0],index_bias],dtype=z.dtype)
-                    z = ops.slice_update(z,[0,0],t)
-                return ops.max(ops.sum(z,-1))-1
+
+
         index = self.apply(
             inputs=x,
             layer=start_index,
             name='start_index'
         )
+
         def cond(inputs, caches, index , flags):
             cond1 = ops.less(index,length-1)
             cond2 = ops.logical_not(ops.all(ops.equal(inputs[key][:,index],end_token),-1))
             return ops.logical_and(cond1,cond2)
 
-        def body(inputs, caches, index , flags):
+        def body(inputs, caches, index , flags,cache_shape_torch=None):
             if progress_print:
 
                 print('\r',index,end='')
@@ -529,7 +536,10 @@ def body(inputs, caches, index , flags):
             position_bias = self.compute_cache_position_bias(self_cache_update_index = index)
 
             for i in range(self.num_hidden_layers):
+
                 layer_caches = caches[i*j:i*j+j]
+                if backlib=='torch':
+                    layer_caches[0]=ops.concatenate([layer_caches[0],ops.zeros(cache_shape_torch,dtype=layer_caches[0].dtype)],axis=2)
                 out=self.apply_main_cache_layers(z+[layer_caches], i,self_cache_update_index=index,
                                 cross_cache_update_index=None,
                                 attention_mask=attention_mask,
@@ -546,13 +556,18 @@ def body(inputs, caches, index , flags):
             search_in = [o,index,inputs[key],flags]
             inputs[key],flags = self.Search(search_in,k=k,mode=search_mode)
             return (inputs, caches, index , flags)
+        num_hidden_layers = self.num_hidden_layers
         class WhileLayer(keras.Layer):
             def call(self, x):
                 inputs, caches, index = x[:]
                 flags = ops.ones([ops.shape(caches[0])[0],1],dtype='bool')
                 if backlib=='torch':
+                    cache_shape_torch = list(ops.shape(caches[0]))
+                    cache_shape_torch[2] = 1
+                    for i in range(num_hidden_layers):
+                        caches[i*j]=slices_index(caches[i*j],index,2)
                     while cond(inputs, caches, index , flags):
-                        inputs, caches, index , flags = body(inputs, caches, index , flags)
+                        inputs, caches, index , flags = body(inputs, caches, index , flags,cache_shape_torch)
                     return (inputs, caches, index)
                 outs=ops.while_loop(
                     cond,
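A rough numpy illustration of the torch-only cache handling above: WhileLayer first trims each preallocated cache to the current index along axis 2, and body() then appends one zero-filled step per decoded token before the cache update (the shapes are invented for the sketch):

import numpy as np

cache = np.zeros((2, 8, 32, 64))   # illustrative cache; axis 2 plays the role of the sequence axis
index = 5
cache_shape_torch = list(cache.shape)
cache_shape_torch[2] = 1                          # room for exactly one new step
cache = cache[:, :, :index]                       # WhileLayer: slices_index(caches[i*j], index, 2)
for _ in range(3):                                # body(): grow the cache once per decoded token
    cache = np.concatenate([cache, np.zeros(cache_shape_torch, dtype=cache.dtype)], axis=2)
print(cache.shape)                                # (2, 8, 8, 64)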
@@ -3639,6 +3654,7 @@ def build_transformer_model(
     model='bert',
     application='encoder',
     return_keras_model=True,
+    keras_weights_path=None,
     **kwargs
 ):
     """Build the model from a config file, optionally loading checkpoint weights
@@ -3718,7 +3734,19 @@
             shape=[1 if t==None else t for t in shape]
             inputs.append(np.zeros(shape,modelin.dtype))
         transformer.model.predict(inputs,verbose=3)
-
+    if keras_weights_path is not None:
+        transformer.model.load_weights(keras_weights_path, skip_mismatch=True)
+    if lora_model:
+
+        def enable_lora(t):
+            if isinstance(t,keras.layers.Embedding) or isinstance(t,keras.layers.Dense):
+                t.enable_lora(True)
+        for layer in transformer.model.layers:
+            layer.trainable=False
+            enable_lora(layer)
+            for kid in dir(layer):
+                t = getattr(layer,kid)
+                enable_lora(t)
 
     if checkpoint_path is not None:
         transformer.load_weights_from_checkpoint(checkpoint_path)
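A hedged usage sketch of the new keras_weights_path argument combined with the ENABLE_LORA flag; the file names below are placeholders, and config_path is assumed to be part of the existing build_transformer_model signature:

import os
os.environ['ENABLE_LORA'] = '1'   # must be set before importing bert4keras3 (read by backend.py)
from bert4keras3.models import build_transformer_model

# Base layers are frozen and Dense/Embedding sublayers get LoRA adapters because lora_model is truthy.
model = build_transformer_model(
    config_path='config.json',                # placeholder config file
    model='bert',
    return_keras_model=True,
    keras_weights_path='model.weights.h5',    # new: loaded with skip_mismatch=True
)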
