Add scheduler options for long training
- Implement RepeatedMultiStepLR
itsnamgyu committed Mar 2, 2022
1 parent 3b3b7fb commit f15d431
Showing 3 changed files with 56 additions and 3 deletions.
io_utils.py: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ def parse_args(mode):
    parser.add_argument('--epochs', default=1000, type=int, help='Pre-training epochs.') # similar to aug_mode
    parser.add_argument('--model_save_interval', default=50, type=int, help='Save model state every N epochs during pre-training.') # similar to aug_mode
    parser.add_argument('--optimizer', default=None, type=str, help="Optimizer used during pre-training {'sgd', 'adam'}. Default if None") # similar to aug_mode
    parser.add_argument('--scheduler', default="MultiStepLR", type=str, help="Scheduler to use (refer to `pretrain_new.py`)")
    parser.add_argument('--scheduler_milestones', default=[400, 600, 800], type=int, nargs="+", help="Milestones for (Repeated)MultiStepLR scheduler")
    parser.add_argument('--num_workers', default=None, type=int)

    # New ft params
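
For reference, a minimal sketch of how the two new flags behave when parsed. This is a standalone argparse stand-in, not the project's parse_args(); only the two arguments added above are reproduced here:

import argparse

# Stand-in parser containing only the two new pre-training flags
# (the real parse_args() in io_utils.py defines many more arguments).
parser = argparse.ArgumentParser()
parser.add_argument('--scheduler', default="MultiStepLR", type=str)
parser.add_argument('--scheduler_milestones', default=[400, 600, 800], type=int, nargs="+")

args = parser.parse_args(["--scheduler", "RepeatedMultiStepLR",
                          "--scheduler_milestones", "400", "800", "1200"])
print(args.scheduler)             # RepeatedMultiStepLR
print(args.scheduler_milestones)  # [400, 800, 1200]
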
pretrain_new.py: 7 additions & 3 deletions
@@ -13,6 +13,7 @@
from model import get_model_class
from paths import get_output_directory, get_final_pretrain_state_path, get_pretrain_state_path, \
    get_pretrain_params_path, get_pretrain_history_path
from scheduler import RepeatedMultiStepLR


def _get_dataloaders(params):
@@ -89,9 +90,12 @@ def main(params):
    else:
        raise ValueError('Invalid value for params.optimizer: {}'.format(params.optimizer))

-    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
-                                                     milestones=[400, 600, 800],
-                                                     gamma=0.1)
+    if params.scheduler == "MultiStepLR":
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=params.scheduler_milestones, gamma=0.1)
+    elif params.scheduler == "RepeatedMultiStepLR":
+        scheduler = RepeatedMultiStepLR(optimizer, milestones=params.scheduler_milestones, interval=1000, gamma=0.1)
+    else:
+        raise ValueError("Invalid value for params.scheduler: {}".format(params.scheduler))

    pretrain_history = {
        'loss': [0] * params.epochs,
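
For intuition, a small sketch (not part of the commit) of the decay factor that the RepeatedMultiStepLR branch applies to the base learning rate. The helper name `factor` is ours; the logic mirrors the lambda defined in scheduler.py below, with gamma=0.1 and interval=1000 as passed above:

# Mirrors RepeatedMultiStepLR._lambda (see scheduler.py below): the decay
# pattern restarts every `interval` epochs instead of staying at the last
# milestone's value like plain MultiStepLR.
def factor(epoch, milestones=(400, 600, 800), gamma=0.1, interval=1000):
    f = 1.0
    for m in milestones:
        if epoch % interval >= m:
            f *= gamma
    return f

print([round(factor(e), 4) for e in (0, 400, 600, 800, 1000, 1400, 1600, 1800)])
# [1.0, 0.1, 0.01, 0.001, 1.0, 0.1, 0.01, 0.001]

With the plain MultiStepLR branch the factor stays at 0.001 from epoch 800 onward; the repeated variant keeps the decay cycle going over long pre-training runs where params.epochs exceeds the last milestone.
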
scheduler.py: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torchvision.models import resnet18


class RepeatedMultiStepLR(LambdaLR):
    """MultiStepLR-style decay that restarts every `interval` epochs.

    Within each interval, the base learning rate is multiplied by `gamma`
    once for every milestone that has been passed; the pattern then repeats
    from the start of the next interval.
    """

    def __init__(self, optimizer, milestones=(400, 600, 800), gamma=0.1, interval=1000, **kwargs):
        self.milestones = milestones
        self.interval = interval
        self.gamma = gamma
        super().__init__(optimizer, self._lambda, **kwargs)

    def _lambda(self, epoch):
        factor = 1
        for milestone in self.milestones:
            if epoch % self.interval >= milestone:
                factor *= self.gamma
        return factor


def main():
    resnet = resnet18()

    optimizer1 = Adam(resnet.parameters(), lr=0.1)
    optimizer2 = Adam(resnet.parameters(), lr=0.1)

    s1 = torch.optim.lr_scheduler.MultiStepLR(optimizer1, milestones=[400, 600, 800], gamma=0.1)
    s2 = RepeatedMultiStepLR(optimizer2, milestones=[400, 600, 800])
    s1_history = []
    s2_history = []

    for i in range(2000):
        # print("Epoch {:04d}: {:.6f} / {:.6f}".format(i, s1.get_last_lr()[0], s2.get_last_lr()[0]))
        s1_history.append(s1.get_last_lr()[0])
        s2_history.append(s2.get_last_lr()[0])
        s1.step()
        s2.step()

    # The repeated scheduler matches MultiStepLR over the first interval...
    assert (s1_history[:1000] == s2_history[:1000])
    # ...and repeats the same learning-rate pattern over the second interval.
    assert (s1_history[:1000] == s2_history[1000:])

    print("Manual test passed!")


if __name__ == "__main__": # manual unit test
    main()
