@@ -929,12 +929,16 @@ def __init__(self, *args, **kwargs):
             super(NewOptimizer, self).__init__(*args, **kwargs)
             self.accum_grads = {}
 
-
+        def build(self, var_list):
+            super(NewOptimizer, self).build(var_list)
+            for var in var_list:
+                self.accum_grads[var] = self.add_variable_from_reference(
+                    reference_variable=var, name="momentum"
+                )
+
+
         def update_step(self, gradient, variable, learning_rate):
-            if variable not in self.accum_grads.keys():
-                self.accum_grads[variable] = ops.zeros(
-                    int_shape(variable), dtype=variable.dtype
-                )
+
             # update criterion
             cond = ops.equal(self.iterations % self.grad_accum_steps, 0)
             cond = K.cast(cond, K.floatx())
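For orientation, the hunk above moves accumulator creation from lazy, per-call dict initialisation inside update_step() to an up-front build() that uses add_variable_from_reference(), which is the pattern of the newer Keras optimizer API. The following is a minimal, self-contained sketch of that pattern, assuming a Keras 3 environment; the class name AccumSGD, the slot name "accum_grad", and the plain-SGD update rule are illustrative choices, not this repository's code.

import keras
from keras import ops


class AccumSGD(keras.optimizers.Optimizer):
    """Hypothetical plain SGD that applies an averaged update every
    `grad_accum_steps` iterations (illustration, not this repo's class)."""

    def __init__(self, learning_rate=0.01, grad_accum_steps=4, **kwargs):
        super().__init__(learning_rate=learning_rate, **kwargs)
        self.grad_accum_steps = grad_accum_steps

    def build(self, var_list):
        if self.built:
            return
        super().build(var_list)
        # One zero-initialized slot per trainable variable, created up front,
        # as in the build() added by the hunk above.
        self.accum_grads = [
            self.add_variable_from_reference(reference_variable=v, name="accum_grad")
            for v in var_list
        ]

    def update_step(self, gradient, variable, learning_rate):
        acc = self.accum_grads[self._get_variable_index(variable)]
        new_acc = ops.add(acc, gradient)
        # 1.0 on every grad_accum_steps-th call, 0.0 otherwise.
        step = ops.add(self.iterations, 1)
        apply_now = ops.cast(
            ops.equal(ops.mod(step, self.grad_accum_steps), 0), variable.dtype
        )
        # Apply the averaged gradient only when the counter wraps around.
        self.assign_sub(
            variable, apply_now * learning_rate * new_acc / self.grad_accum_steps
        )
        # Reset the accumulator after an update, keep accumulating otherwise.
        self.assign(acc, (1.0 - apply_now) * new_acc)

    def get_config(self):
        config = super().get_config()
        config.update({"grad_accum_steps": self.grad_accum_steps})
        return config

Creating the slots in build() keeps variable creation out of the traced update_step(), which is why the lazy ops.zeros() branch removed above is no longer needed.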
@@ -974,16 +978,6 @@ def new_assign_sub(variable, value):
         # get gradients
 
 
-        def get_gradients(self, loss, params):
-            accum_grads = []
-            for p in params:
-                if p not in self.accum_grads:
-                    self.accum_grads[p] = K.zeros(
-                        K.int_shape(p), dtype=K.dtype(p)
-                    )
-                accum_grads.append(self.accum_grads[p])
-
-            return [ag / self.grad_accum_steps for ag in accum_grads]
         def get_config(self):
             config = {
                 'grad_accum_steps': self.grad_accum_steps,
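A hypothetical smoke test of the AccumSGD sketch above (model shape, data, and hyperparameters are arbitrary): with grad_accum_steps=4 and batch_size=8, weights only move on every fourth optimizer step, roughly emulating an effective batch size of 32, which is the averaging the removed get_gradients() hook used to provide by returning accum_grad / grad_accum_steps.

import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer=AccumSGD(learning_rate=0.1, grad_accum_steps=4), loss="mse")

x = np.random.rand(256, 4).astype("float32")
y = np.random.rand(256, 1).astype("float32")
model.fit(x, y, batch_size=8, epochs=1, verbose=0)  # 32 steps, 8 weight updates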
@@ -1444,7 +1438,7 @@ def get_config(self):
 
 @export_to_custom_objects
 def extend_with_isnan_skip_v2(BaseOptimizer):
-    """Return a new optimizer class with gradient accumulation added
+    """Return a new optimizer class with gradient accumulation added (tf.keras version)
     """
     class NewOptimizer(BaseOptimizer):
         """Optimizer with gradient accumulation
@@ -1525,16 +1519,7 @@ def new_assign_sub(variable, value):
         # get gradients
 
 
-        def get_gradients(self, loss, params):
-            accum_grads = []
-            for p in params:
-                if p not in self.accum_grads:
-                    self.accum_grads[p] = K.zeros(
-                        K.int_shape(p), dtype=K.dtype(p)
-                    )
-                accum_grads.append(self.accum_grads[p])
 
-            return [ag / self.grad_accum_steps for ag in accum_grads]
         def get_config(self):
             config = {
                 'grad_accum_steps': self.grad_accum_steps,
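Finally, a hedged usage sketch for the tf.keras-version factory touched by the last two hunks. The import path, the wrapped base class, and the grad_accum_steps keyword are assumptions inferred from the get_config entry above; only the factory name and the wrap-a-base-class pattern are confirmed by this diff.

import tensorflow as tf
# The import path below is assumed; adjust to wherever this optimizers module lives.
from bert4keras.optimizers import extend_with_isnan_skip_v2

AccumAdam = extend_with_isnan_skip_v2(tf.keras.optimizers.Adam)
optimizer = AccumAdam(learning_rate=2e-5, grad_accum_steps=4)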