From adcd71c8f1b9b9733096cf661f5516f9b67638be Mon Sep 17 00:00:00 2001 From: Eric Ihli Date: Wed, 8 May 2024 13:40:06 -0700 Subject: [PATCH] fix divide by 0 error in scheduler --- gato/training/schedulers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gato/training/schedulers.py b/gato/training/schedulers.py index 03f76f3..300aa17 100644 --- a/gato/training/schedulers.py +++ b/gato/training/schedulers.py @@ -19,7 +19,7 @@ def get_linear_warmup_cosine_decay_scheduler(optimizer: Optimizer, num_warmup_st return LambdaLR(optimizer, lr_lambda, last_epoch) def _linear_warmup_cosine_decay(current_step: int, *, num_warmup_steps: int,num_training_steps: int, base_lr: float, init_lr: float, min_lr: float, cosine_decay: bool): - if current_step <= num_warmup_steps: + if current_step <= num_warmup_steps and num_warmup_steps != 0: lr = init_lr + (base_lr - init_lr) * current_step / num_warmup_steps elif cosine_decay: # cosine decay from base_lr to min_lr over remaining steps @@ -44,4 +44,4 @@ def _linear_warmup_cosine_decay(current_step: int, *, num_warmup_steps: int,num_ for step in current_steps: lr[step - 1] = _linear_warmup_cosine_decay(step, warmup_steps, max_steps, base_lr, init_lr, min_lr) plt.plot(current_steps, lr) - plt.show() \ No newline at end of file + plt.show()