Merge pull request #38 from mir-group/develop
Release: 0.3.1
simonbatzner authored May 14, 2021
2 parents 6e34915 + 9e9a3fa commit 7146fd5
Showing 10 changed files with 234 additions and 45 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,13 @@ Most recent change on the bottom.

## [Unreleased]

## [0.3.1]
### Fixed
- `iepoch` is no longer off-by-one when restarting a training run that hit `max_epochs`
- Builders, and not just sub-builders, use the class name as a default prefix
### Added
- `early_stopping_xxx` arguments added to enable early stopping for plateaued values or values that fall outside lower/upper bounds

## [0.3.0] - 2021-05-07
### Added
- Sub-builders can be skipped in `instantiate` by setting them to `None`
15 changes: 15 additions & 0 deletions configs/full.yaml
@@ -87,6 +87,21 @@ use_ema: false
ema_decay: 0.999 # ema weight, commonly set to 0.999
ema_use_num_updates: true # whether to use number of updates when computing averages

# early stopping based on metric values.
# LR, wall and any keys printed in the log file can be used.
# The key can start with Training or Validation; if no prefix is given, the validation value is used.
early_stopping_patiences: # stop early if a metric value has stopped decreasing for n epochs
  Validation_loss: 50 #
  Training_loss: 100 #
  e_mae: 100 #
early_stopping_delta: # if defined, a decrease smaller than delta does not count as a decrease
  Training_loss: 0.005 #
early_stopping_cumulative_delta: false # if true, the recorded minimum is not updated when the decrease is smaller than delta
early_stopping_lower_bounds: # stop early if a metric value is lower than the bound
  LR: 1.0e-10 #
early_stopping_upper_bounds: # stop early if a metric value is higher than the bound
  wall: 1.0e+100 #

# loss function
loss_coeffs: # different weights to use in a weighted loss function
  forces: 100 # for MD applications, we recommend a force weight of 100 and an energy weight of 1
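
For reference, these `early_stopping_*` keys appear to correspond one-to-one to the constructor arguments of the new `EarlyStopping` class added below in `nequip/train/early_stopping.py`. A minimal sketch of that wiring, assuming the YAML has been parsed into a plain dict (the `yaml.safe_load` step and the `config` name are illustrative, not the trainer's actual plumbing):

import yaml  # assumes PyYAML is installed

from nequip.train.early_stopping import EarlyStopping

with open("configs/full.yaml") as f:
    config = yaml.safe_load(f)

# hypothetical wiring: strip the early_stopping_ prefix from each config key
early_stopping = EarlyStopping(
    lower_bounds=config.get("early_stopping_lower_bounds", {}),
    upper_bounds=config.get("early_stopping_upper_bounds", {}),
    patiences=config.get("early_stopping_patiences", {}),
    delta=config.get("early_stopping_delta", {}),
    cumulative_delta=config.get("early_stopping_cumulative_delta", False),
)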
1 change: 0 additions & 1 deletion configs/minimal.yaml
@@ -22,7 +22,6 @@ dataset_file_name: benchmark_data/aspirin_ccsd-train.npz

# logging
wandb: false
wandb_project: aspirin
# verbose: debug

# training
2 changes: 1 addition & 1 deletion nequip/_version.py
@@ -2,4 +2,4 @@
# See Python packaging guide
# https://packaging.python.org/guides/single-sourcing-package-version/

__version__ = "0.3.0"
__version__ = "0.3.1"
7 changes: 0 additions & 7 deletions nequip/scripts/restart.py
@@ -59,13 +59,6 @@ def restart(file_name, config, mode="update"):
{"float32": torch.float32, "float64": torch.float64}[config.default_dtype]
)

# increase max_epochs if training has hit maximum epochs
if "progress" in dictionary:
stop_args = dictionary["progress"].pop("stop_arg", None)
if stop_args is not None:
dictionary["progress"]["stop_arg"] = None
dictionary["max_epochs"] *= 2

if config.wandb:
from nequip.train.trainer_wandb import TrainerWandB

105 changes: 105 additions & 0 deletions nequip/train/early_stopping.py
@@ -0,0 +1,105 @@
from collections import OrderedDict
from copy import deepcopy
from typing import Mapping, Optional, Tuple


class EarlyStopping:
    """
    Early stopping conditions.

    There are three early stopping conditions:
    1. a value is lower than a defined lower bound
    2. a value is higher than a defined upper bound
    3. a value hasn't decreased for x epochs within a delta tolerance

    Args:
        lower_bounds (dict): defines the key and lower bound for condition 1
        upper_bounds (dict): defines the key and upper bound for condition 2
        patiences (dict): defines the number of epochs x for condition 3
        delta (dict): defines the delta tolerance for condition 3. defaults to 0.0
        cumulative_delta (bool): if True, the minimum value recorded for condition 3
            will not be updated when the newer value decreases
            by only a tiny amount (< delta). default False
    """

    def __init__(
        self,
        lower_bounds: dict = {},
        upper_bounds: dict = {},
        patiences: dict = {},
        delta: dict = {},
        cumulative_delta: bool = False,
    ):

        self.patiences = deepcopy(patiences)
        self.lower_bounds = deepcopy(lower_bounds)
        self.upper_bounds = deepcopy(upper_bounds)
        self.cumulative_delta = cumulative_delta

        self.delta = {}
        self.counters = {}
        self.minimums = {}
        for key, pat in self.patiences.items():
            self.patiences[key] = int(pat)
            self.counters[key] = 0
            self.minimums[key] = None
            self.delta[key] = delta.get(key, 0.0)

            if pat < 1:
                raise ValueError(
                    f"Argument patience for {key} should be a positive integer."
                )
            if self.delta[key] < 0.0:
                raise ValueError("Argument delta should not be a negative number.")

        for key in self.delta:
            if key not in self.patiences:
                raise ValueError(f"A patience for {key} should also be defined")

    def __call__(self, metrics: Mapping) -> Tuple[bool, str, Optional[str]]:

        stop = False
        stop_args = "Early stopping:"
        debug_args = None

        # check whether a key in metrics hasn't decreased for x epochs
        for key, pat in self.patiences.items():

            value = metrics[key]
            minimum = self.minimums[key]
            delta = self.delta[key]

            if minimum is None:
                self.minimums[key] = value
            elif value >= (minimum - delta):
                # no significant decrease; unless cumulative_delta is set,
                # a small decrease still updates the recorded minimum
                if not self.cumulative_delta and value < minimum:
                    self.minimums[key] = value
                self.counters[key] += 1
                debug_args = f"EarlyStopping: {self.counters[key]} / {pat}"
                if self.counters[key] >= pat:
                    stop_args += f" {key} has not reduced for {pat} epochs"
                    stop = True
            else:
                self.minimums[key] = value
                self.counters[key] = 0

        for key, bound in self.lower_bounds.items():
            if metrics[key] < bound:
                stop_args += f" {key} is smaller than {bound}"
                stop = True

        for key, bound in self.upper_bounds.items():
            if metrics[key] > bound:
                stop_args += f" {key} is larger than {bound}"
                stop = True

        return stop, stop_args, debug_args

    def state_dict(self) -> "OrderedDict[str, dict]":
        return OrderedDict([("counters", self.counters), ("minimums", self.minimums)])

    def load_state_dict(self, state_dict: Mapping) -> None:
        self.counters = state_dict["counters"]
        self.minimums = state_dict["minimums"]
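
To illustrate the intended call pattern, a minimal usage sketch (the `max_epochs` loop and `run_one_epoch` function are hypothetical stand-ins for the trainer's own loop and metrics collection):

from nequip.train.early_stopping import EarlyStopping

early_stopping = EarlyStopping(
    patiences={"Validation_loss": 50},
    delta={"Validation_loss": 0.005},
    lower_bounds={"LR": 1.0e-10},
)

for epoch in range(max_epochs):  # hypothetical training loop
    metrics = run_one_epoch()  # assumed to return a dict like {"Validation_loss": ..., "LR": ...}
    stop, stop_args, debug_args = early_stopping(metrics)
    if debug_args is not None:
        print(debug_args)  # e.g. "EarlyStopping: 3 / 50"
    if stop:
        print(stop_args)
        break

The `state_dict` / `load_state_dict` pair lets the per-key counters and recorded minimums survive a checkpoint and restart.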