
Commit f0fc8c2

Add files via upload

1 parent bc1d1f8 commit f0fc8c2

12 files changed (+176 −74 lines)
Nine binary files changed (contents not shown); listed size changes: 1.49 KB, 1.67 KB, 0 Bytes, 0 Bytes, −9 Bytes, plus four without listed sizes.

bert4keras3/backend.py (+108 −74)
@@ -9,40 +9,72 @@
 is_tf_keras = strtobool(os.environ.get('TF_KERAS', '0'))
 os.environ["KERAS_BACKEND"]=os.environ.get("KERAS_BACKEND", 'tensorflow')
 backlib=os.environ["KERAS_BACKEND"]
-if backlib=='torch':
+if backlib=='tfkeras':
+    is_tf_keras = True
+elif backlib=='torch':
     import torch
 elif backlib=='jax':
     import jax
 if is_tf_keras:
     sys.modules['keras'] = tf.keras
+
 import keras
 import keras.backend as K
+do_recompute = strtobool(os.environ.get('RECOMPUTE', '0'))
+use_keras_2 = is_tf_keras or keras.__version__<'3.0'
 
-
-
-if keras.__version__<'3.0':
+if use_keras_2:
 
     from tensorflow.python.client import device_lib
     from tensorflow.python.util import nest, tf_inspect
     from tensorflow.python.eager import tape
     from tensorflow.python.ops.custom_gradient import _graph_mode_decorator
     import bert4keras3.ops as ops
     load_variable=tf.train.load_variable
-
+    norm=tf.norm
 else:
     from keras import ops
-    if backlib==torch:
+
+    if backlib=='torch':
+        from torch.utils.checkpoint import checkpoint
         def norm(tensor, ord='euclidean', axis=None, keepdims=None):
             if ord=='euclidean':
                 ord=None
             return torch.linalg.norm(tensor, ord, axis, keepdims)
+        def recompute_grad(call):
+            if not do_recompute:
+                return call
+
+            def inner(self, inputs, **kwargs):
+                """Define the function to differentiate and redefine its gradient
+                (adapted from the built-in tf.recompute_grad)
+                """
+                flat_inputs = nest.flatten(inputs)
+                call_args = tf_inspect.getfullargspec(call).args
+                for key in ['mask', 'training']:
+                    if key not in call_args and key in kwargs:
+                        del kwargs[key]
+                def kernel_call():
+                    return call(self, inputs, **kwargs)
+                return checkpoint(kernel_call,inputs, **kwargs)
+
+            return inner
     elif backlib=='jax':
+        import jax
+        def recompute_grad(call):
+            if not do_recompute:
+                return call
+            return jax.checkpoint(call)
         def norm(tensor, ord='euclidean', axis=None, keepdims=None):
             if ord=='euclidean':
                 ord=None
             return jax.numpy.linalg.norm(tensor, ord, axis, keepdims)
 
     else:
+        def recompute_grad(call):
+            if not do_recompute:
+                return call
+            return tf.recompute_grad(call)
         norm=tf.norm
     ops.norm=norm
 # Decide whether recompute is enabled (trades time for memory)
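
Everything in this hunk keys off environment variables that are read at import time, so they must be set before any bert4keras3 import. A minimal usage sketch (the bert4keras3.backend import surface follows this diff; treat the exact names as an assumption):

    import os
    # Must be set before importing bert4keras3; backend.py reads them on import.
    os.environ["KERAS_BACKEND"] = "jax"  # or 'torch', 'tensorflow', or the new 'tfkeras' alias
    os.environ["RECOMPUTE"] = "1"        # opt in to recompute (trades time for memory)

    from bert4keras3.backend import ops, recompute_grad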
@@ -171,7 +203,7 @@ def dtype(x):
         pass
 K.dtype=dtype
 
-if keras.__version__<'3.0':
+if use_keras_2:
     def where(cond, x, y):
         """Add automatic broadcasting to tf.where
         """
@@ -207,9 +239,10 @@ def sequence_masking(
         bias: extra bias term, or an additional mask;
         return_mask: whether to also return the aligned mask.
     """
+
    if not (mask is None and bias is None):
        if mask is None:
-            if K.dtype(bias) == 'bool':
+            if K.dtype(bias) == 'bool' or (backlib=='torch' and K.dtype(bias) == torch.bool):
                mask = bias
                x = ops.where(mask, x, value)
        else:
@@ -226,6 +259,8 @@ def sequence_masking(
 
            if K.dtype(mask) != 'bool':
                mask = ops.cast(mask, 'bool')
+            elif backlib=='torch' and K.dtype(bias) == torch.bool:
+                mask = ops.cast(mask, torch.bool)
 
            full_mask = align(mask, [0, axes[0]], K.ndim(x))
            for axis in axes[1:]:
@@ -234,7 +269,7 @@ def sequence_masking(
            mask = full_mask
    if bias is None:
        x = ops.where(mask, x, value)
-    elif K.dtype(bias) == 'bool':
+    elif K.dtype(bias) == 'bool' or (backlib=='torch' and K.dtype(bias) == torch.bool):
        mask = mask & bias
        x = ops.where(mask, x, value)
    else:
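
The repeated backlib=='torch' clauses in these hunks exist because torch reports dtypes as torch.dtype objects rather than strings, so the plain string comparison silently fails. A quick illustration (assumes torch is installed):

    import torch

    mask = torch.ones(2, 3) > 0
    print(mask.dtype)                # torch.bool
    print(mask.dtype == 'bool')      # False: a torch.dtype never equals the string 'bool'
    print(mask.dtype == torch.bool)  # True, which is what the new branches test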
@@ -280,7 +315,7 @@ def attention_normalize(a, mask=None, axis=-1, method='softmax', bias=None):
        softmax_plus: from https://kexue.fm/archives/8823 .
    """
    a, mask = sequence_masking(a, mask, -np.inf, axis, bias, True)
-    if method == 'softmax':
+    if method == 'softmax' :
        return ops.softmax(a, axis=axis)
    else:
        if mask is None:
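
For context, the softmax_plus method named in this docstring (per https://kexue.fm/archives/8823) rescales the logits by log(n)/log(512), where n is the effective sequence length, so attention stays sharp as length grows. A self-contained numpy sketch of the idea, not the library's exact implementation:

    import numpy as np

    def softmax_plus_sketch(a, n, axis=-1):
        a = a * np.log(n) / np.log(512)          # length-adaptive temperature
        a = a - a.max(axis=axis, keepdims=True)  # subtract max for numerical stability
        e = np.exp(a)
        return e / e.sum(axis=axis, keepdims=True)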
@@ -433,80 +468,79 @@ def graph_mode_decorator(f, *args, **kwargs):
    else:
        return _graph_mode_decorator(f, args, kwargs)
 
-
-def recompute_grad(call):
-    """Recompute decorator (for decorating a Keras layer's call function)
-    On recomputation, see: https://arxiv.org/abs/1604.06174
-    """
-    if not do_recompute:
-        return call
-
-    def inner(self, inputs, **kwargs):
-        """Define the function to differentiate and redefine its gradient
-        (adapted from the built-in tf.recompute_grad)
+def recompute_grad(call):
+    """Recompute decorator (for decorating a Keras layer's call function)
+    On recomputation, see: https://arxiv.org/abs/1604.06174
    """
-        flat_inputs = nest.flatten(inputs)
-        call_args = tf_inspect.getfullargspec(call).args
-        for key in ['mask', 'training']:
-            if key not in call_args and key in kwargs:
-                del kwargs[key]
-
-        def kernel_call():
-            """Define the forward computation
-            """
-            return call(self, inputs, **kwargs)
-
-        def call_and_grad(*inputs):
-            """Define the forward and backward computation
+    if not do_recompute:
+        return call
+
+    def inner(self, inputs, **kwargs):
+        """Define the function to differentiate and redefine its gradient
+        (adapted from the built-in tf.recompute_grad)
        """
-            if is_tf_keras:
-                with tape.stop_recording():
-                    outputs = kernel_call()
-                    outputs = tf.identity(outputs)
-            else:
-                outputs = kernel_call()
-
-            def grad_fn(doutputs, variables=None):
-                watches = list(inputs)
-                if variables is not None:
-                    watches += list(variables)
-                with tf.GradientTape() as t:
-                    t.watch(watches)
-                    with tf.control_dependencies([doutputs]):
+        flat_inputs = nest.flatten(inputs)
+        call_args = tf_inspect.getfullargspec(call).args
+        for key in ['mask', 'training']:
+            if key not in call_args and key in kwargs:
+                del kwargs[key]
+
+        def kernel_call():
+            """Define the forward computation
+            """
+            return call(self, inputs, **kwargs)
+
+        def call_and_grad(*inputs):
+            """Define the forward and backward computation
+            """
+            if is_tf_keras:
+                with tape.stop_recording():
                    outputs = kernel_call()
-            grads = t.gradient(
-                outputs, watches, output_gradients=[doutputs]
+                    outputs = tf.identity(outputs)
+            else:
+                outputs = kernel_call()
+
+            def grad_fn(doutputs, variables=None):
+                watches = list(inputs)
+                if variables is not None:
+                    watches += list(variables)
+                with tf.GradientTape() as t:
+                    t.watch(watches)
+                    with tf.control_dependencies([doutputs]):
+                        outputs = kernel_call()
+                    grads = t.gradient(
+                        outputs, watches, output_gradients=[doutputs]
+                    )
+                del t
+                return grads[:len(inputs)], grads[len(inputs):]
+
+            return outputs, grad_fn
+
+        if is_tf_keras:  # only available with tf >= 2.0
+            outputs, grad_fn = call_and_grad(*flat_inputs)
+            flat_outputs = nest.flatten(outputs)
+
+            def actual_grad_fn(*doutputs):
+                grads = grad_fn(*doutputs, variables=self.trainable_weights)
+                return grads[0] + grads[1]
+
+            watches = flat_inputs + self.trainable_weights
+            watches = [tf.convert_to_tensor(x) for x in watches]
+            tape.record_operation(
+                call.__name__, flat_outputs, watches, actual_grad_fn
            )
-                del t
-                return grads[:len(inputs)], grads[len(inputs):]
-
-            return outputs, grad_fn
-
-        if is_tf_keras:  # only available with tf >= 2.0
-            outputs, grad_fn = call_and_grad(*flat_inputs)
-            flat_outputs = nest.flatten(outputs)
-
-            def actual_grad_fn(*doutputs):
-                grads = grad_fn(*doutputs, variables=self.trainable_weights)
-                return grads[0] + grads[1]
-
-            watches = flat_inputs + self.trainable_weights
-            watches = [tf.convert_to_tensor(x) for x in watches]
-            tape.record_operation(
-                call.__name__, flat_outputs, watches, actual_grad_fn
-            )
-            return outputs
-        else:  # works with keras + tf >= 1.14
-            return graph_mode_decorator(call_and_grad, *flat_inputs)
-
-    return inner
+            return outputs
+        else:  # works with keras + tf >= 1.14
+            return graph_mode_decorator(call_and_grad, *flat_inputs)
+
+    return inner
 
 
 ops.reshape = reshape
 ops.flatten = flatten
 
 
-if keras.__version__<'3.0':
+if use_keras_2:
 
    # Add symbolic (a decorator) to older keras for compatibility with optimizers.py
    keras.backend.symbolic = getattr(keras.backend, 'symbolic', None) or symbolic
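
In every branch the decorator is used the same way: it wraps a layer's call so activations are recomputed during the backward pass instead of being stored. A hedged sketch of the intended use (MyDense is a hypothetical layer; the decorator is a no-op unless RECOMPUTE=1):

    import keras
    from bert4keras3.backend import recompute_grad  # assumed import path

    class MyDense(keras.layers.Dense):
        @recompute_grad
        def call(self, inputs):
            # The forward pass runs again on the backward pass when recompute is on.
            return super().call(inputs)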

bert4keras3/layers.py (+3)
@@ -517,7 +517,9 @@ def pay_attention_to(self, inputs, mask=None, **kwargs):
        q_mask, v_mask = mask
        a_bias, p_bias = kwargs.get('a_bias'), kwargs.get('p_bias')
        if a_bias:
+
            a_bias = inputs[n]
+
            n += 1
        if p_bias == 'rotary':
            qw, kw = apply_rotary_position_embeddings(inputs[n], qw, kw)
@@ -536,6 +538,7 @@ def pay_attention_to(self, inputs, mask=None, **kwargs):
        if a_bias is not None and ops.ndim(a_bias) == 3:
            a_bias = align(a_bias, [0, -2, -1], ops.ndim(a))
        A = attention_normalize(a, v_mask, -1, self.normalization, a_bias)
+
        if self.attention_dropout:
            A = self.dropout(A)
        # Finish the output

examples/basic_T5PEGASUS_test.py (+65)
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Dec 29 19:09:29 2023
+
+@author: Administrator
+"""
+
+# Quick test of T5-PEGASUS
+import os
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+os.environ["KERAS_BACKEND"] = "jax"
+import numpy as np
+import jieba
+from bert4keras3.models import build_transformer_model
+from bert4keras3.tokenizers import Tokenizer
+from bert4keras3.snippets import AutoRegressiveDecoder
+base_path='models/chinese_t5_pegasus_base/'
+config_path = base_path+'config.json'
+checkpoint_path = base_path+ 'model.ckpt'
+dict_path = base_path+ 'vocab.txt'
+tokenizer= Tokenizer(
+    dict_path,
+    do_lower_case=True,
+    pre_tokenize=lambda s: jieba.cut(s, HMM=False)
+)
+
+
+t5 = build_transformer_model(
+    config_path=config_path,
+    checkpoint_path=checkpoint_path,
+    model='mt5.1.1',
+    return_keras_model=False,
+    name='T5',
+    dropout_rate=0,
+)
+
+encoder = t5.encoder
+decoder = t5.decoder
+t5.model.summary()
+class AutoTitle(AutoRegressiveDecoder):
+    def generate(self, text, topk=1):
+        c_encoded = encoder.predict(np.array([tokenizer.encode(text)[0]]))[0]
+        output_ids=[self.start_id]
+        while output_ids[-1]!=self.end_id and len(output_ids)<128:
+
+            outs= self.last_token(decoder).predict([np.expand_dims(c_encoded,0),np.reshape(output_ids,[1,-1])],verbose=3)  # greedy search: argmax over the last-token distribution
+            out=np.argmax(outs)
+            output_ids.append(out)
+
+        return tokenizer.decode(output_ids).replace(' ','')
+
+
+autotitle = AutoTitle(
+    start_id=tokenizer._token_start_id,
+    end_id=tokenizer._token_end_id,
+    maxlen=128
+
+)
+print(autotitle.generate('针对以超立方体网络为蓝本的多处理机系统的可靠性和容错能力的精准度量问题,结合多处理机系统遭受计算机病毒攻击时常常发生结构性故障的特点,研究了n维超立方体网络的结构连通性和子结构连通性评价问题。首先,使 用构造n维超立方体网络的3路结构割的方法得到其3路结构连通度的一个上界;然后,使用构造n维超立方体网络的3路子结构集的等价变换或约简变换的方法,得到其3路结构子连通度的一个下界;最后,利用任意网络的3路结构连通度不小于3路子结构连通度的性质,证实了超立方体网络的3路结构连通度和子结构连通度均为该超立方体网络维数'))
+
+'''
+The output of the original bert4keras is:
+针对以超立方体网络为蓝本的多处理机系统的可靠性和容错能力的精准度量问题, 结合多处理机系统遭受计算机病毒攻击时常常发生结构性故障的特点, 研究了n维超立方体网络的结构连通性和子结构连通性评价问题。
+
+'''
