is_tf_keras = strtobool(os.environ.get('TF_KERAS', '0'))
os.environ["KERAS_BACKEND"] = os.environ.get("KERAS_BACKEND", 'tensorflow')
backlib = os.environ["KERAS_BACKEND"]
-if backlib == 'torch':
+if backlib == 'tfkeras':
+    is_tf_keras = True
+elif backlib == 'torch':
    import torch
elif backlib == 'jax':
    import jax
if is_tf_keras:
    sys.modules['keras'] = tf.keras
+
import keras
import keras.backend as K
+do_recompute = strtobool(os.environ.get('RECOMPUTE', '0'))
+use_keras_2 = is_tf_keras or keras.__version__ < '3.0'

-
-
-if keras.__version__ < '3.0':
+if use_keras_2:

    from tensorflow.python.client import device_lib
    from tensorflow.python.util import nest, tf_inspect
    from tensorflow.python.eager import tape
    from tensorflow.python.ops.custom_gradient import _graph_mode_decorator
    import bert4keras3.ops as ops
    load_variable = tf.train.load_variable
-
+    norm = tf.norm
else:
    from keras import ops
-    if backlib == torch:
+
+    if backlib == 'torch':
+        from torch.utils.checkpoint import checkpoint
        def norm(tensor, ord='euclidean', axis=None, keepdims=None):
            if ord == 'euclidean':
                ord = None
            return torch.linalg.norm(tensor, ord, axis, keepdims)
+        def recompute_grad(call):
+            if not do_recompute:
+                return call
+
+            def inner(self, inputs, **kwargs):
+                """Define the function whose gradient is needed and redefine how
+                the gradient is computed (adapted from the built-in tf.recompute_grad).
+                """
+                flat_inputs = nest.flatten(inputs)
+                call_args = tf_inspect.getfullargspec(call).args
+                for key in ['mask', 'training']:
+                    if key not in call_args and key in kwargs:
+                        del kwargs[key]
+                def kernel_call():
+                    return call(self, inputs, **kwargs)
+                return checkpoint(kernel_call, inputs, **kwargs)
+
+            return inner
    elif backlib == 'jax':
+        import jax
+        def recompute_grad(call):
+            if not do_recompute:
+                return call
+            return jax.checkpoint(call)
        def norm(tensor, ord='euclidean', axis=None, keepdims=None):
            if ord == 'euclidean':
                ord = None
            return jax.numpy.linalg.norm(tensor, ord, axis, keepdims)

    else:
+        def recompute_grad(call):
+            if not do_recompute:
+                return call
+            return tf.recompute_grad(call)
        norm = tf.norm
ops.norm = norm
# Decide whether recomputation is enabled (trading time for memory)
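Note: with this change the backend and the recompute switch are chosen entirely through environment variables read at import time. A minimal configuration sketch follows; the import path bert4keras3.backend is an assumption, while TF_KERAS, KERAS_BACKEND and RECOMPUTE are the variables used in the code above:

import os

# Pick the backend ('tensorflow' is the default; 'torch', 'jax', or 'tfkeras' to force tf.keras)
# and turn on gradient recomputation; both must be set before the first import below.
os.environ['KERAS_BACKEND'] = 'torch'
os.environ['RECOMPUTE'] = '1'

from bert4keras3.backend import keras, ops, recompute_grad  # assumed module path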
@@ -171,7 +203,7 @@ def dtype(x):
        pass
K.dtype = dtype

-if keras.__version__ < '3.0':
+if use_keras_2:
    def where(cond, x, y):
        """Add automatic broadcasting to tf.where.
        """
@@ -207,9 +239,10 @@ def sequence_masking(
        bias: an extra bias term, or an additional mask;
        return_mask: whether to also return the aligned mask.
    """
+
    if not (mask is None and bias is None):
        if mask is None:
-            if K.dtype(bias) == 'bool':
+            if K.dtype(bias) == 'bool' or (backlib == 'torch' and K.dtype(bias) == torch.bool):
                mask = bias
                x = ops.where(mask, x, value)
        else:
@@ -226,6 +259,8 @@ def sequence_masking(

            if K.dtype(mask) != 'bool':
                mask = ops.cast(mask, 'bool')
+            elif backlib == 'torch' and K.dtype(bias) == torch.bool:
+                mask = ops.cast(mask, torch.bool)

            full_mask = align(mask, [0, axes[0]], K.ndim(x))
            for axis in axes[1:]:
@@ -234,7 +269,7 @@ def sequence_masking(
            mask = full_mask
        if bias is None:
            x = ops.where(mask, x, value)
-        elif K.dtype(bias) == 'bool':
+        elif K.dtype(bias) == 'bool' or (backlib == 'torch' and K.dtype(bias) == torch.bool):
            mask = mask & bias
            x = ops.where(mask, x, value)
        else:
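A brief note on the torch.bool comparisons added in the hunks above: under the torch backend a tensor's dtype is a torch.dtype object rather than the plain string 'bool', so checking K.dtype(...) against 'bool' alone can miss boolean masks. A small illustration in plain PyTorch, independent of this module:

import torch

mask = torch.ones(2, 3, dtype=torch.bool)
print(str(mask.dtype))           # 'torch.bool', not the bare string 'bool'
print(mask.dtype == torch.bool)  # True, which is what the extra comparison catches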
@@ -280,7 +315,7 @@ def attention_normalize(a, mask=None, axis=-1, method='softmax', bias=None):
        softmax_plus: from https://kexue.fm/archives/8823 .
    """
    a, mask = sequence_masking(a, mask, -np.inf, axis, bias, True)
-    if method == 'softmax':
+    if method == 'softmax':
        return ops.softmax(a, axis=axis)
    else:
        if mask is None:
@@ -433,80 +468,79 @@ def graph_mode_decorator(f, *args, **kwargs):
        else:
            return _graph_mode_decorator(f, args, kwargs)

-
-def recompute_grad(call):
-    """Recompute decorator (used to decorate the call method of a Keras layer).
-    For recomputation details, see: https://arxiv.org/abs/1604.06174
-    """
-    if not do_recompute:
-        return call
-
-    def inner(self, inputs, **kwargs):
-        """Define the function whose gradient is needed and redefine how
-        the gradient is computed (adapted from the built-in tf.recompute_grad).
+    def recompute_grad(call):
+        """Recompute decorator (used to decorate the call method of a Keras layer).
+        For recomputation details, see: https://arxiv.org/abs/1604.06174
        """
-        flat_inputs = nest.flatten(inputs)
-        call_args = tf_inspect.getfullargspec(call).args
-        for key in ['mask', 'training']:
-            if key not in call_args and key in kwargs:
-                del kwargs[key]
-
-        def kernel_call():
-            """Define the forward computation.
-            """
-            return call(self, inputs, **kwargs)
-
-        def call_and_grad(*inputs):
-            """Define the forward and backward computation.
+        if not do_recompute:
+            return call
+
+        def inner(self, inputs, **kwargs):
+            """Define the function whose gradient is needed and redefine how
+            the gradient is computed (adapted from the built-in tf.recompute_grad).
            """
-            if is_tf_keras:
-                with tape.stop_recording():
-                    outputs = kernel_call()
-                    outputs = tf.identity(outputs)
-            else:
-                outputs = kernel_call()
-
-            def grad_fn(doutputs, variables=None):
-                watches = list(inputs)
-                if variables is not None:
-                    watches += list(variables)
-                with tf.GradientTape() as t:
-                    t.watch(watches)
-                    with tf.control_dependencies([doutputs]):
+            flat_inputs = nest.flatten(inputs)
+            call_args = tf_inspect.getfullargspec(call).args
+            for key in ['mask', 'training']:
+                if key not in call_args and key in kwargs:
+                    del kwargs[key]
+
+            def kernel_call():
+                """Define the forward computation.
+                """
+                return call(self, inputs, **kwargs)
+
+            def call_and_grad(*inputs):
+                """Define the forward and backward computation.
+                """
+                if is_tf_keras:
+                    with tape.stop_recording():
                        outputs = kernel_call()
-                    grads = t.gradient(
-                        outputs, watches, output_gradients=[doutputs]
+                        outputs = tf.identity(outputs)
+                else:
+                    outputs = kernel_call()
+
+                def grad_fn(doutputs, variables=None):
+                    watches = list(inputs)
+                    if variables is not None:
+                        watches += list(variables)
+                    with tf.GradientTape() as t:
+                        t.watch(watches)
+                        with tf.control_dependencies([doutputs]):
+                            outputs = kernel_call()
+                        grads = t.gradient(
+                            outputs, watches, output_gradients=[doutputs]
+                        )
+                    del t
+                    return grads[:len(inputs)], grads[len(inputs):]
+
+                return outputs, grad_fn
+
+            if is_tf_keras:  # only usable with tf >= 2.0
+                outputs, grad_fn = call_and_grad(*flat_inputs)
+                flat_outputs = nest.flatten(outputs)
+
+                def actual_grad_fn(*doutputs):
+                    grads = grad_fn(*doutputs, variables=self.trainable_weights)
+                    return grads[0] + grads[1]
+
+                watches = flat_inputs + self.trainable_weights
+                watches = [tf.convert_to_tensor(x) for x in watches]
+                tape.record_operation(
+                    call.__name__, flat_outputs, watches, actual_grad_fn
                )
-                del t
-                return grads[:len(inputs)], grads[len(inputs):]
-
-            return outputs, grad_fn
-
-        if is_tf_keras:  # only usable with tf >= 2.0
-            outputs, grad_fn = call_and_grad(*flat_inputs)
-            flat_outputs = nest.flatten(outputs)
-
-            def actual_grad_fn(*doutputs):
-                grads = grad_fn(*doutputs, variables=self.trainable_weights)
-                return grads[0] + grads[1]
-
-            watches = flat_inputs + self.trainable_weights
-            watches = [tf.convert_to_tensor(x) for x in watches]
-            tape.record_operation(
-                call.__name__, flat_outputs, watches, actual_grad_fn
-            )
-            return outputs
-        else:  # works with keras + tf >= 1.14
-            return graph_mode_decorator(call_and_grad, *flat_inputs)
-
-    return inner
+                return outputs
+            else:  # works with keras + tf >= 1.14
+                return graph_mode_decorator(call_and_grad, *flat_inputs)
+
+        return inner


ops.reshape = reshape
ops.flatten = flatten


-if keras.__version__ < '3.0':
+if use_keras_2:

    # Add a symbolic decorator to older keras for compatibility with optimizers.py
    keras.backend.symbolic = getattr(keras.backend, 'symbolic', None) or symbolic
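For reference, a minimal usage sketch of the recompute path. The layer is hypothetical, the import path bert4keras3.backend is assumed, and RECOMPUTE=1 must be set before the backend module is imported, otherwise do_recompute is false and the decorator returns call unchanged:

import os
os.environ['RECOMPUTE'] = '1'  # must be set before the backend module is imported

from bert4keras3.backend import keras, recompute_grad

class RecomputedDense(keras.layers.Layer):
    # Hypothetical layer: activations from its forward pass are rebuilt
    # during backprop instead of being cached, trading compute for memory.
    def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], input_shape[-1]))

    @recompute_grad
    def call(self, inputs):
        return keras.activations.relu(inputs @ self.kernel)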