forked from PaddlePaddle/PaddleVideo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_lr.py
338 lines (296 loc) · 12.7 KB
/
custom_lr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from paddle.optimizer.lr import *
import numpy as np
"""
PaddleVideo learning-rate schedules.

Use the schedulers in ``paddle.optimizer.lr`` directly, or define
custom learning-rate schedulers in this file.
"""
class CustomWarmupCosineDecay(LRScheduler):
    r"""
    Linear warmup followed by cosine annealing, as used in the SlowFast model.

    While ``last_epoch < warmup_epochs`` the learning rate rises linearly from
    ``warmup_start_lr`` to the cosine value at ``warmup_epochs``. After that it
    is ``cosine_base_lr * (cos(pi * last_epoch / max_epoch) + 1) * 0.5``.
    ``last_epoch`` is fractional: each ``step()`` call without an explicit
    ``epoch`` advances it by ``1 / num_iters``.

    Args:
        warmup_start_lr (float): learning rate at the start of warmup.
        warmup_epochs (int): number of warmup epochs.
        cosine_base_lr (float|int): cosine-schedule learning rate at epoch 0.
        max_epoch (int): total training epochs; the cosine term reaches zero
            learning rate at this epoch.
        num_iters (int): number of iterations in one epoch.
        last_epoch (int, optional): index of the last epoch; set it to resume
            training. Default: -1, i.e. start from the initial learning rate.
        verbose (bool, optional): if ``True``, print a message to stdout on
            every update. Default: ``False``.
    Returns:
        ``CustomWarmupCosineDecay`` instance to schedule learning rate.
    """

    def __init__(self,
                 warmup_start_lr,
                 warmup_epochs,
                 cosine_base_lr,
                 max_epoch,
                 num_iters,
                 last_epoch=-1,
                 verbose=False):
        self.num_iters = num_iters
        self.max_epoch = max_epoch
        self.cosine_base_lr = cosine_base_lr
        self.warmup_epochs = warmup_epochs
        self.warmup_start_lr = warmup_start_lr
        # The schedule attributes must exist before this call: the base
        # constructor invokes step() itself, which updates last_lr/last_epoch.
        super().__init__(last_epoch=last_epoch, verbose=verbose)

    def step(self, epoch=None):
        """
        Advance the schedule and refresh ``last_lr``.

        Call this after ``optimizer.step``; the new learning rate takes effect
        on the next ``optimizer.step``.

        Args:
            epoch (int, None): if given, jump directly to this epoch.
                Default: None, which auto-increments from ``last_epoch=-1``.
        Returns:
            None
        """
        if epoch is not None:
            self.last_epoch = epoch
        elif self.last_epoch == -1:
            # First automatic call: move from the "not started" marker to 0.
            self.last_epoch += 1
        else:
            # Later calls advance by one iteration, i.e. 1/num_iters epoch.
            self.last_epoch += 1 / self.num_iters
        self.last_lr = self.get_lr()
        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def _lr_func_cosine(self, cur_epoch, cosine_base_lr, max_epoch):
        """Half-cosine value of ``cosine_base_lr`` at (fractional) ``cur_epoch``."""
        shifted_cos = math.cos(math.pi * cur_epoch / max_epoch) + 1.0
        return cosine_base_lr * shifted_cos * 0.5

    def get_lr(self):
        """Return the learning rate for the current ``last_epoch``."""
        now = self.last_epoch
        if now < self.warmup_epochs:
            # Linear ramp whose end point is the cosine value at warmup end.
            ramp_end = self._lr_func_cosine(self.warmup_epochs,
                                            self.cosine_base_lr,
                                            self.max_epoch)
            slope = (ramp_end - self.warmup_start_lr) / self.warmup_epochs
            return now * slope + self.warmup_start_lr
        return self._lr_func_cosine(now, self.cosine_base_lr, self.max_epoch)
class CustomWarmupPiecewiseDecay(LRScheduler):
    r"""
    Linear warmup followed by a piecewise-constant (step) schedule with
    relative learning rates, as used in the SlowFast model.

    While ``last_epoch < warmup_epochs`` the learning rate ramps linearly from
    ``warmup_start_lr`` to the step-schedule value at ``warmup_epochs``.
    After that it is ``lrs[i] * step_base_lr``, where ``i`` is the index of
    the last boundary in ``steps`` that ``last_epoch`` has reached (assuming
    ``steps`` is sorted ascending).

    Args:
        warmup_start_lr (float): start learning rate used in warmup stage.
        warmup_epochs (int): the number epochs of warmup.
        step_base_lr (float|int): base learning rate of the step schedule;
            every entry of ``lrs`` is a multiplier on it.
        lrs (list[float]): relative learning-rate multipliers, one per entry
            of ``steps``.
        gamma (float): stored on the instance; not used by the learning-rate
            computation in this class.
        steps (Sequence[int|float]): epochs at which each segment starts.
            NOTE(review): the lookup returns ``lrs[ind - 1]``, so the schedule
            appears to assume ``steps[0] == 0``; an epoch before ``steps[0]``
            would wrap around to ``lrs[-1]``. Confirm against configs.
        max_epoch (int): total training epochs; appended as the final boundary.
        num_iters (int): number iterations of each epoch. Each ``step()``
            without ``epoch`` advances ``last_epoch`` by ``1 / num_iters``.
        last_epoch (int, optional): starting (possibly fractional) epoch.
            Default: 0.
        verbose (bool, optional): If ``True``, prints debug messages from
            ``step``, ``get_lr`` and the step lookup. Default: ``False``.
    Returns:
        ``CustomWarmupPiecewiseDecay`` instance to schedule learning rate.
    """

    def __init__(self,
                 warmup_start_lr,
                 warmup_epochs,
                 step_base_lr,
                 lrs,
                 gamma,
                 steps,
                 max_epoch,
                 num_iters,
                 last_epoch=0,
                 verbose=False):
        self.warmup_start_lr = warmup_start_lr
        self.warmup_epochs = warmup_epochs
        self.step_base_lr = step_base_lr
        self.lrs = lrs
        self.gamma = gamma
        self.steps = steps
        self.max_epoch = max_epoch
        self.num_iters = num_iters
        # The base-class constructor is not called here. Its bookkeeping
        # attributes are assigned directly, and last_lr starts at
        # warmup_start_lr instead of being computed by get_lr().
        self.last_epoch = last_epoch
        self.last_lr = self.warmup_start_lr  # used in first iter
        self.verbose = verbose
        self._var_name = None

    def step(self, epoch=None, rebuild=False):
        """
        Advance the schedule and refresh ``last_lr``.

        Call this after ``optimizer.step``; the new learning rate takes effect
        on the next ``optimizer.step``.

        Args:
            epoch (int|float, None): if given, jump directly to this epoch.
                Default: None, which advances ``last_epoch`` by one iteration
                (``1 / num_iters`` epoch) unless ``rebuild`` is True.
            rebuild (bool): when True and ``epoch`` is None, keep
                ``last_epoch`` unchanged and only recompute ``last_lr``
                (presumably used to refresh the lr after restoring state;
                confirm with callers).
        Returns:
            None
        """
        if epoch is None:
            if not rebuild:
                self.last_epoch += 1 / self.num_iters  # update step with iters
        else:
            self.last_epoch = epoch
        self.last_lr = self.get_lr()
        if self.verbose:
            print(
                'step Epoch {}: {} set learning rate to {}.self.num_iters={}, 1/self.num_iters={}'
                .format(self.last_epoch, self.__class__.__name__, self.last_lr,
                        self.num_iters, 1 / self.num_iters))

    def _lr_func_steps_with_relative_lrs(self, cur_epoch, lrs, base_lr, steps,
                                         max_epoch):
        """Return ``lrs[i] * base_lr`` for the segment containing ``cur_epoch``."""
        # Copy into a list first: concatenating to a tuple raises TypeError,
        # and numpy arrays would broadcast-add max_epoch instead of appending.
        steps = list(steps) + [max_epoch]
        # ind ends at the first boundary strictly greater than cur_epoch, or
        # at the last index when cur_epoch >= max_epoch.
        for ind, step in enumerate(steps):
            if cur_epoch < step:
                break
        if self.verbose:
            print(
                '_lr_func_steps_with_relative_lrs, cur_epoch {}: {}, steps {}, ind {}, step{}, max_epoch{}'
                .format(cur_epoch, self.__class__.__name__, steps, ind, step,
                        max_epoch))
        return lrs[ind - 1] * base_lr

    def get_lr(self):
        """Return the learning rate for the current ``last_epoch``."""
        lr = self._lr_func_steps_with_relative_lrs(
            self.last_epoch,
            self.lrs,
            self.step_base_lr,
            self.steps,
            self.max_epoch,
        )
        # Target of the warmup ramp: the step-schedule value at warmup end.
        lr_end = self._lr_func_steps_with_relative_lrs(
            self.warmup_epochs,
            self.lrs,
            self.step_base_lr,
            self.steps,
            self.max_epoch,
        )
        # Perform warm up.
        if self.last_epoch < self.warmup_epochs:
            lr_start = self.warmup_start_lr
            alpha = (lr_end - lr_start) / self.warmup_epochs
            lr = self.last_epoch * alpha + lr_start
        if self.verbose:
            print(
                'get_lr, Epoch {}: {}, lr {}, lr_end {}, self.lrs{}, self.step_base_lr{}, self.steps{}, self.max_epoch{}'
                .format(self.last_epoch, self.__class__.__name__, lr, lr_end,
                        self.lrs, self.step_base_lr, self.steps,
                        self.max_epoch))
        return lr
class CustomPiecewiseDecay(PiecewiseDecay):
    """
    ``PiecewiseDecay`` that tolerates an optional ``num_iters`` keyword.

    Other schedulers in this module take ``num_iters``. This wrapper discards
    it when present and forwards every other keyword argument to
    ``PiecewiseDecay`` unchanged.
    """

    def __init__(self, **kargs):
        # Use a default so that omitting num_iters is not a KeyError.
        kargs.pop('num_iters', None)
        super().__init__(**kargs)
class CustomWarmupCosineStepDecay(LRScheduler):
    r"""
    Iteration-level linear warmup followed by an epoch-wise cosine decay.

    The "regular" learning rate is a cosine anneal from ``base_lr`` down to
    ``min_lr``. It is evaluated at the integer epoch counter ``cnt_epoch``, so
    it changes once per epoch. During the first ``warmup_iters * num_iters``
    iterations, the learning rate instead ramps linearly from
    ``warmup_ratio * regular_lr`` up to ``regular_lr``.

    Progress is tracked in ``cnt_iters``/``cnt_epoch``. ``last_epoch`` keeps
    whatever value it was given and is never advanced by this class.

    Args:
        warmup_iters (int|float): number of warmup *epochs*. It is multiplied
            by ``num_iters`` to obtain the warmup length in iterations.
        warmup_ratio (float, optional): fraction of the regular lr used at the
            first warmup iteration. Default: 0.1.
        min_lr (float, optional): end value of the cosine anneal. Default: 0.
        base_lr (float, optional): start value of the cosine anneal.
            Default: 3e-5.
        max_epoch (int, optional): total training epochs. Default: 30.
        last_epoch (int, optional): forwarded to the base class. Default: -1.
        num_iters (int): number of iterations per epoch. Required despite the
            ``None`` default.
        verbose (bool, optional): if ``True``, print a message on every
            ``step``. Default: ``False``.
    Raises:
        ValueError: if ``num_iters`` is ``None`` or not positive.
    """

    def __init__(self,
                 warmup_iters,
                 warmup_ratio=0.1,
                 min_lr=0,
                 base_lr=3e-5,
                 max_epoch=30,
                 last_epoch=-1,
                 num_iters=None,
                 verbose=False):
        # Fail early with a clear message. Otherwise None surfaces as a
        # TypeError below and 0 as a ZeroDivisionError inside step().
        if num_iters is None or num_iters <= 0:
            raise ValueError(
                'CustomWarmupCosineStepDecay requires a positive num_iters, '
                'got {!r}'.format(num_iters))
        self.warmup_ratio = warmup_ratio
        self.min_lr = min_lr
        self.warmup_epochs = warmup_iters
        # Convert the warmup length from epochs to iterations.
        self.warmup_iters = warmup_iters * num_iters
        self.cnt_iters = 0
        self.cnt_epoch = 0
        self.num_iters = num_iters
        self.tot_iters = max_epoch * num_iters
        self.max_epoch = max_epoch
        self.cosine_base_lr = base_lr  # initial lr for all param groups
        self.regular_lr = self.get_regular_lr()
        super().__init__(last_epoch=last_epoch, verbose=verbose)

    def annealing_cos(self, start, end, factor, weight=1):
        """Cosine interpolation from ``start`` (factor 0) to ``end`` (factor 1)."""
        cos_out = math.cos(math.pi * factor) + 1
        return end + 0.5 * weight * (start - end) * cos_out

    def get_regular_lr(self):
        """Cosine-annealed lr for the current integer epoch ``cnt_epoch``."""
        progress = self.cnt_epoch
        max_progress = self.max_epoch
        target_lr = self.min_lr
        return self.annealing_cos(self.cosine_base_lr, target_lr, progress /
                                  max_progress)

    def get_warmup_lr(self, cur_iters):
        """Linear warmup: ``warmup_ratio * regular_lr`` at 0 up to ``regular_lr``."""
        k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
        warmup_lr = self.regular_lr * (1 - k)
        return warmup_lr

    def step(self, epoch=None):
        """
        Advance by one iteration and refresh ``last_lr``.

        Args:
            epoch: accepted for interface compatibility; ignored.
        """
        self.regular_lr = self.get_regular_lr()
        self.last_lr = self.get_lr()
        # Remember which epoch the lr just computed belongs to, before the
        # counters move on to the next iteration.
        lr_epoch = self.cnt_epoch
        self.cnt_epoch = (self.cnt_iters +
                          1) // self.num_iters  # update step with iters
        self.cnt_iters += 1
        if self.verbose:
            # last_epoch is never advanced by this class, so report the
            # epoch counter actually used for the lr.
            print('Epoch {}: {} set learning rate to {}.'.format(
                lr_epoch, self.__class__.__name__, self.last_lr))

    def get_lr(self):
        """Define lr policy: warmup lr before ``warmup_iters``, regular lr after."""
        cur_iter = self.cnt_iters
        if cur_iter >= self.warmup_iters:
            return self.regular_lr
        else:
            warmup_lr = self.get_warmup_lr(cur_iter)
            return warmup_lr
class CustomWarmupAdjustDecay(LRScheduler):
    r"""
    Linear warmup followed by multi-step decay at fixed epoch boundaries.

    While ``last_epoch < warmup_epochs`` the learning rate is
    ``step_base_lr * (last_epoch + 1) / warmup_epochs``. After that it is
    ``step_base_lr * lr_decay_rate ** k``, where ``k`` is the number of entries
    in ``boundaries`` that are ``<= last_epoch``.

    Args:
        step_base_lr (float): base learning rate of the schedule.
        warmup_epochs (int): the number epochs of warmup.
        lr_decay_rate (float|int): multiplicative decay applied once for every
            boundary that has been reached.
        boundaries (list[int|float]): epochs at which the learning rate decays.
        num_iters (int, optional): iterations per epoch. When given, each
            ``step()`` without ``epoch`` advances ``last_epoch`` by
            ``1 / num_iters``. When ``None`` (default), it advances by one
            whole epoch.
        last_epoch (int, optional): The index of last epoch. Can be set to
            restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for
            each update. Default: ``False`` .
    Returns:
        ``CustomWarmupAdjustDecay`` instance to schedule learning rate.
    """

    def __init__(self,
                 step_base_lr,
                 warmup_epochs,
                 lr_decay_rate,
                 boundaries,
                 num_iters=None,
                 last_epoch=-1,
                 verbose=False):
        self.step_base_lr = step_base_lr
        self.warmup_epochs = warmup_epochs
        self.lr_decay_rate = lr_decay_rate
        self.boundaries = boundaries
        self.num_iters = num_iters
        # call step() in base class, last_lr/last_epoch/base_lr will be update
        super(CustomWarmupAdjustDecay, self).__init__(last_epoch=last_epoch,
                                                      verbose=verbose)

    def step(self, epoch=None):
        """
        Advance the schedule and refresh ``last_lr``.

        Call this after ``optimizer.step``; the new learning rate takes effect
        on the next ``optimizer.step``.

        Args:
            epoch (int, None): if given, jump directly to this epoch.
                Default: None, which auto-increments from ``last_epoch=-1`` by
                ``1 / num_iters``, or by 1 when ``num_iters`` is None.
        Returns:
            None
        """
        if epoch is None:
            if self.last_epoch == -1:
                self.last_epoch += 1
            elif self.num_iters is None:
                # No per-epoch iteration count was supplied (the default), so
                # step whole epochs instead of failing on ``1 / None``.
                self.last_epoch += 1
            else:
                self.last_epoch += 1 / self.num_iters  # update step with iters
        else:
            self.last_epoch = epoch
        self.last_lr = self.get_lr()
        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def get_lr(self):
        """Return the learning rate for the current ``last_epoch``."""
        if self.last_epoch < self.warmup_epochs:
            # NOTE(review): with fractional (per-iteration) epochs, this hits
            # step_base_lr at last_epoch == warmup_epochs - 1. It then keeps
            # rising toward step_base_lr * (warmup_epochs + 1) / warmup_epochs
            # before dropping back at warmup_epochs. Confirm the overshoot is
            # intended.
            lr = self.step_base_lr * (self.last_epoch + 1) / self.warmup_epochs
        else:
            # Count of boundaries already reached = number of decays applied.
            lr = self.step_base_lr * (self.lr_decay_rate**np.sum(
                self.last_epoch >= np.array(self.boundaries)))
        return lr