diff --git a/train.py b/train.py index 6994fb9b..2e743974 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" import gc +import math import time from dataclasses import dataclass, asdict @@ -565,8 +566,8 @@ def get_weight_decay(progress): train_loss_f = train_loss.item() - # Fast fail: abort if loss is exploding - if train_loss_f > 100: + # Fast fail: abort if loss is exploding or NaN + if math.isnan(train_loss_f) or train_loss_f > 100: print("FAIL") exit(1)