
Commit e372e6a

Add files via upload
1 parent a196d17 commit e372e6a

File tree

1 file changed: +10 −25 lines changed

bert4keras3/optimizers.py

```diff
@@ -929,12 +929,16 @@ def __init__(self, *args, **kwargs):
             super(NewOptimizer, self).__init__(*args, **kwargs)
             self.accum_grads = {}
 
-
+        def build(self, var_list):
+            super().build(var_list)
+            for var in var_list:
+                self.accum_grads[var] = self.add_variable_from_reference(
+                    reference_variable=var, name="momentum"
+                )
+
+
         def update_step(self, gradient, variable, learning_rate):
-            if variable not in self.accum_grads.keys():
-                self.accum_grads[variable] = ops.zeros(
-                    int_shape(variable), dtype=variable.dtype
-                )
+
             # update condition
             cond = ops.equal(self.iterations % self.grad_accum_steps, 0)
             cond = K.cast(cond, K.floatx())
```
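The new `build` hook replaces the lazy zero-initialization that previously sat inside `update_step`: accumulator slots are now created once, eagerly, through the Keras 3 optimizer API. Below is a minimal, self-contained sketch of that same pattern; `AccumSGD`, the `accum_grad` slot name, and the plain-SGD update are illustrative assumptions, not bert4keras3's actual code.

```python
import keras
from keras import ops


class AccumSGD(keras.optimizers.Optimizer):
    """Hypothetical plain-SGD optimizer showing the build/update_step pattern."""

    def __init__(self, learning_rate=0.01, grad_accum_steps=4, **kwargs):
        super().__init__(learning_rate=learning_rate, **kwargs)
        self.grad_accum_steps = grad_accum_steps
        self.accum_grads = {}

    def build(self, var_list):
        # Create one zero-initialized accumulator per variable, once,
        # instead of lazily allocating zeros inside update_step.
        if self.built:
            return
        super().build(var_list)
        for var in var_list:
            self.accum_grads[var] = self.add_variable_from_reference(
                reference_variable=var, name="accum_grad"
            )

    def update_step(self, gradient, variable, learning_rate):
        acc = self.accum_grads[variable]
        lr = ops.cast(learning_rate, variable.dtype)
        steps = ops.cast(self.grad_accum_steps, variable.dtype)
        # Mirrors the diff's condition: 1.0 on apply steps, 0.0 on pure
        # accumulation steps, keeping the update graph-safe (no branching).
        cond = ops.cast(
            ops.equal(ops.mod(self.iterations, self.grad_accum_steps), 0),
            variable.dtype,
        )
        new_acc = acc + ops.cast(gradient, variable.dtype)
        # Apply the averaged accumulated gradient only on apply steps.
        self.assign_sub(variable, cond * lr * new_acc / steps)
        # Reset the accumulator after an apply step; keep it otherwise.
        self.assign(acc, (1.0 - cond) * new_acc)
```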
```diff
@@ -974,16 +978,6 @@ def new_assign_sub(variable, value):
             # get gradients
 
 
-        def get_gradients(self, loss, params):
-            accum_grads = []
-            for p in params:
-                if p not in self.accum_grads:
-                    self.accum_grads[p] = K.zeros(
-                        K.int_shape(p), dtype=K.dtype(p)
-                    )
-                accum_grads.append(self.accum_grads[p])
-
-            return [ag / self.grad_accum_steps for ag in accum_grads]
         def get_config(self):
             config = {
                 'grad_accum_steps': self.grad_accum_steps,
```
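The deleted `get_gradients` hook returned `ag / self.grad_accum_steps`, i.e. the mean gradient over one accumulation window; in the new code path that division happens inside `update_step` instead. A standalone check, with made-up numbers, that accumulate-then-divide equals the mean over the microbatches:

```python
import numpy as np

grad_accum_steps = 4
# Per-microbatch gradients for one parameter (illustrative values only).
grads = [np.array([0.2, -0.4]), np.array([0.1, 0.0]),
         np.array([0.3, 0.2]), np.array([-0.2, 0.6])]

accum = np.sum(grads, axis=0)         # running accumulator
effective = accum / grad_accum_steps  # what get_gradients returned
assert np.allclose(effective, np.mean(grads, axis=0))
```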
```diff
@@ -1444,7 +1438,7 @@ def get_config(self):
 
 @export_to_custom_objects
 def extend_with_isnan_skip_v2(BaseOptimizer):
-    """Return a new optimizer class with gradient accumulation added
+    """Return a new optimizer class with gradient accumulation added (tf.keras version)
     """
     class NewOptimizer(BaseOptimizer):
         """Optimizer with gradient accumulation
```
```diff
@@ -1525,16 +1519,7 @@ def new_assign_sub(variable, value):
             # get gradients
 
 
-        def get_gradients(self, loss, params):
-            accum_grads = []
-            for p in params:
-                if p not in self.accum_grads:
-                    self.accum_grads[p] = K.zeros(
-                        K.int_shape(p), dtype=K.dtype(p)
-                    )
-                accum_grads.append(self.accum_grads[p])
 
-            return [ag / self.grad_accum_steps for ag in accum_grads]
         def get_config(self):
             config = {
                 'grad_accum_steps': self.grad_accum_steps,
```
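The diff context cuts off inside `get_config` in both hunks. A sketch of the conventional completion, following the usual Keras base-config merge pattern (assumed, not read from the file), so the wrapped optimizer round-trips through serialization:

```python
def get_config(self):
    config = {
        'grad_accum_steps': self.grad_accum_steps,
    }
    # Merge the subclass entry into the base optimizer's config.
    base_config = super(NewOptimizer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
```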
