
Commit

Merge branch 'refs/heads/develop' into stratified_metrics
kavanase committed Jul 19, 2024
2 parents 1059248 + d1afaf2 commit ea0f8fe
Showing 16 changed files with 367 additions and 96 deletions.
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -7,8 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Most recent change on the bottom.


## Unreleased
## Unreleased - 0.7.0
### Added
- `--override` now supported as a `nequip-train` flag (similar to its use in `nequip-deploy`)
- add SoftAdapt (https://arxiv.org/abs/2403.18122) callback option

### Changed
- [Breaking] training restart behavior altered: file-wise consistency checks performed between original config and config passed to `nequip-train` on restart (instead of checking the config dicts)
- [Breaking] config format for callbacks changed (see `configs/full.yaml` for an example)

### Fixed
- fixed `wandb_watch` bug

## [0.6.1] - 2024-7-9
### Added
51 changes: 51 additions & 0 deletions CITATION.cff
@@ -0,0 +1,51 @@
cff-version: "1.2.0"
message: "If you use this software, please cite our article."
authors:
- family-names: Batzner
given-names: Simon
- family-names: Musaelian
given-names: Albert
- family-names: Sun
given-names: Lixin
- family-names: Geiger
given-names: Mario
- family-names: Mailoa
given-names: Jonathan P.
- family-names: Kornbluth
given-names: Mordechai
- family-names: Molinari
given-names: Nicola
- family-names: Smidt
given-names: Tess E.
- family-names: Kozinsky
given-names: Boris
doi: 10.1038/s41467-022-29939-5
preferred-citation:
authors:
- family-names: Batzner
given-names: Simon
- family-names: Musaelian
given-names: Albert
- family-names: Sun
given-names: Lixin
- family-names: Geiger
given-names: Mario
- family-names: Mailoa
given-names: Jonathan P.
- family-names: Kornbluth
given-names: Mordechai
- family-names: Molinari
given-names: Nicola
- family-names: Smidt
given-names: Tess E.
- family-names: Kozinsky
given-names: Boris
doi: 10.1038/s41467-022-29939-5
date-published: 2022-05-04
issn: 2041-1723
journal: Nature Communications
start: 2453
title: "E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials"
type: article
url: "https://www.nature.com/articles/s41467-022-29939-5"
volume: 13
4 changes: 3 additions & 1 deletion README.md
@@ -141,7 +141,9 @@ Details on writing and using plugins can be found in the [Allegro tutorial](http

## References & citing

The theory behind NequIP is described in our preprint (1). NequIP's backend builds on e3nn, a general framework for building E(3)-equivariant neural networks (2). If you use this repository in your work, please consider citing NequIP (1) and e3nn (3):
The theory behind NequIP is described in our [article](https://www.nature.com/articles/s41467-022-29939-5) (1).
NequIP's backend builds on [`e3nn`](https://e3nn.org), a general framework for building E(3)-equivariant
neural networks (2). If you use this repository in your work, please consider citing `NequIP` (1) and `e3nn` (3):

1. https://www.nature.com/articles/s41467-022-29939-5
2. https://e3nn.org
15 changes: 11 additions & 4 deletions configs/full.yaml
@@ -256,10 +256,17 @@ loss_coeffs:
# In the "schedule" key each entry is a two-element list of:
# - the 1-based epoch index at which to start the new loss coefficients
# - the new loss coefficients as a dict
#
# start_of_epoch_callbacks:
# - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
#
# callbacks:
# start_of_epoch:
# - !!python/object:nequip.train.callbacks.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}

# You can also try using the SoftAdapt strategy for adaptively changing loss coefficients
# (see https://arxiv.org/abs/2403.18122)
#callbacks:
# end_of_batch:
# - !!python/object:nequip.train.callbacks.SoftAdapt {"batches_per_update": 5, "beta": 1.1}



# output metrics
metrics_components:
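
For orientation, here is a minimal sketch of the SoftAdapt weighting idea referenced in the `configs/full.yaml` hunk above (illustrative only — this is not the actual `nequip.train.callbacks.SoftAdapt` implementation, and the function name and numbers are made up): loss coefficients are set by a softmax over the recent rate of change of each loss component, scaled by `beta`.

```python
import numpy as np

def softadapt_weights(prev_losses, curr_losses, beta=1.1):
    """Softmax over the recent rate of change of each loss component:
    components that stagnate or get worse receive more weight."""
    rates = np.asarray(curr_losses, dtype=float) - np.asarray(prev_losses, dtype=float)
    exps = np.exp(beta * (rates - rates.max()))  # subtract max for numerical stability
    return exps / exps.sum()

# e.g. energy loss dropping quickly while the force loss stalls -> forces upweighted
print(softadapt_weights(prev_losses=[1.0, 1.0], curr_losses=[0.5, 0.98]))
```
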
24 changes: 23 additions & 1 deletion docs/cite.rst
@@ -1,3 +1,25 @@
Citing Nequip
Citing NequIP
=============
If you use ``NequIP`` in your research, please cite our `article <https://doi.org/10.1038/s41467-022-29939-5>`_:

.. code-block:: bibtex
@article{batzner_e3-equivariant_2022,
title = {E(3)-Equivariant Graph Neural Networks for Data-Efficient and Accurate Interatomic Potentials},
author = {Batzner, Simon and Musaelian, Albert and Sun, Lixin and Geiger, Mario and Mailoa, Jonathan P. and Kornbluth, Mordechai and Molinari, Nicola and Smidt, Tess E. and Kozinsky, Boris},
year = {2022},
month = may,
journal = {Nature Communications},
volume = {13},
number = {1},
pages = {2453},
issn = {2041-1723},
doi = {10.1038/s41467-022-29939-5},
}
The theory behind NequIP is described in our `article <https://doi.org/10.1038/s41467-022-29939-5>`_ above.
NequIP's backend builds on `e3nn <https://e3nn.org>`_, a general framework for building E(3)-equivariant
neural networks (1). If you use this repository in your work, please consider citing ``NequIP`` and ``e3nn`` (2):

1. https://e3nn.org
2. https://doi.org/10.5281/zenodo.3724963
31 changes: 15 additions & 16 deletions examples/plot_dimers.py
@@ -39,19 +39,20 @@
print("Computing dimers...")
potential = {}
N_sample = args.n_samples
N_combs = len(list(itertools.combinations_with_replacement(range(num_types), 2)))
r = torch.zeros(N_sample * N_combs, 2, 3, device=args.device)
rs_one = torch.linspace(args.r_min, model_r_max, 500, device=args.device)
rs = rs_one.repeat([N_combs])
assert rs.shape == (N_combs * N_sample,)
type_combos = [
list(e) for e in itertools.combinations_with_replacement(range(num_types), 2)
]
N_combos = len(type_combos)
r = torch.zeros(N_sample * N_combos, 2, 3, device=args.device)
rs_one = torch.linspace(args.r_min, model_r_max, N_sample, device=args.device)
rs = rs_one.repeat([N_combos])
assert rs.shape == (N_combos * N_sample,)
r[:, 1, 0] += rs # offset second atom along x axis
types = torch.as_tensor(
[list(e) for e in itertools.combinations_with_replacement(range(num_types), 2)]
)
types = types.reshape(N_combs, 1, 2).expand(N_combs, N_sample, 2).reshape(-1)
types = torch.as_tensor(type_combos)
types = types.reshape(N_combos, 1, 2).expand(N_combos, N_sample, 2).reshape(-1)
r = r.reshape(-1, 3)
assert types.shape == r.shape[:1]
N_at_total = N_sample * N_combs * 2
N_at_total = N_sample * N_combos * 2
assert len(types) == N_at_total
edge_index = torch.vstack(
(
@@ -61,14 +62,14 @@
)
)
data = AtomicData(pos=r, atom_types=types, edge_index=edge_index)
data.batch = torch.arange(N_sample * N_combs, device=args.device).repeat_interleave(2)
data.ptr = torch.arange(0, 2 * N_sample * N_combs + 1, 2, device=args.device)
data.batch = torch.arange(N_sample * N_combos, device=args.device).repeat_interleave(2)
data.ptr = torch.arange(0, 2 * N_sample * N_combos + 1, 2, device=args.device)
result = model(AtomicData.to_AtomicDataDict(data.to(device=args.device)))

print("Plotting...")
energies = (
result[AtomicDataDict.TOTAL_ENERGY_KEY]
.reshape(N_combs, N_sample)
.reshape(N_combos, N_sample)
.cpu()
.detach()
.numpy()
@@ -83,9 +84,7 @@
dpi=120,
)

for i, (type1, type2) in enumerate(
itertools.combinations_with_replacement(range(num_types), 2)
):
for i, (type1, type2) in enumerate(type_combos):
ax = axs[i]
ax.set_ylabel(f"{type_names[type1]}-{type_names[type2]}")
ax.plot(rs_one, energies[i])
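
As a standalone illustration of the `type_combos` construction introduced in the hunk above (with `num_types = 3` chosen arbitrarily):

```python
import itertools

num_types = 3  # arbitrary example value
type_combos = [
    list(e) for e in itertools.combinations_with_replacement(range(num_types), 2)
]
print(type_combos)       # [[0, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 2]]
print(len(type_combos))  # 6 == num_types * (num_types + 1) // 2
```
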
6 changes: 5 additions & 1 deletion nequip/__init__.py
@@ -1,3 +1,4 @@
import os
import sys

from ._version import __version__ # noqa: F401
@@ -16,7 +17,10 @@
), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found"

# warn if using 1.13* or 2.0.*
if packaging.version.parse("1.13.0") <= torch_version:
if (
packaging.version.parse("1.13.0") <= torch_version
and int(os.environ.get("PYTORCH_VERSION_WARNING", 1)) != 0
):
warnings.warn(
f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.* have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best tested PyTorch version to use with CUDA devices is 1.11; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue."
)
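
The `nequip/__init__.py` hunk above gates the PyTorch version warning behind a `PYTORCH_VERSION_WARNING` environment variable. A minimal usage sketch, assuming `nequip` is installed and reads the variable at import time as shown:

```python
import os

# Setting the variable to "0" before importing nequip suppresses the warning,
# since the hunk above checks int(os.environ.get("PYTORCH_VERSION_WARNING", 1)) != 0.
os.environ["PYTORCH_VERSION_WARNING"] = "0"

import nequip  # noqa: E402  (import placed after the environment tweak on purpose)
```
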
100 changes: 73 additions & 27 deletions nequip/scripts/train.py
@@ -3,6 +3,9 @@
import logging
import argparse
import warnings
import shutil
import difflib
import yaml

# This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch.
# Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance.
@@ -29,6 +32,8 @@
root="./",
tensorboard=False,
wandb=False,
wandb_watch=False,
wandb_watch_kwargs={},
model_builders=[
"SimpleIrrepsConfig",
"EnergyModel",
@@ -46,7 +51,7 @@
equivariance_test=False,
grad_anomaly_mode=False,
gpu_oom_offload=False,
append=False,
append=True,
warn_unused=False,
_jit_bailout_depth=2, # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
# Quote from eelison in PyTorch slack:
@@ -68,32 +73,61 @@


def main(args=None, running_as_script: bool = True):
config = parse_command_line(args)
config, path_to_config, override_options = parse_command_line(args)

if running_as_script:
set_up_script_logger(config.get("log", None), config.verbose)

found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth")
train_dir = f"{config.root}/{config.run_name}"
found_restart_file = exists(f"{train_dir}/trainer.pth")
if found_restart_file and not config.append:
raise RuntimeError(
f"Training instance exists at {config.root}/{config.run_name}; "
f"Training instance exists at {train_dir}; "
"either set append to True or use a different root or runname"
)
elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"):
elif not found_restart_file and isdir(train_dir):
# output directory exists but no ``trainer.pth`` file, suggesting previous run crash during
# first training epoch (usually due to memory):
warnings.warn(
f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model "
f"Previous run folder at {train_dir} exists, but a saved model "
f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will "
f"be started."
)
rmtree(f"{config.root}/{config.run_name}")
rmtree(train_dir)

# for fresh new train
if not found_restart_file:
if not found_restart_file: # fresh start
# update config with override parameters for setting up train-dir
config.update(override_options)
trainer = fresh_start(config)
else:
trainer = restart(config)
# copy original config to training directory
shutil.copyfile(path_to_config, f"{train_dir}/original_config.yaml")
else: # restart
# perform string matching for original config and restart config
# throw error if they are different
with (
open(f"{train_dir}/original_config.yaml") as orig_f,
open(path_to_config) as current_f,
):
diffs = [
x
for x in difflib.Differ().compare(
orig_f.readlines(), current_f.readlines()
)
if x[0] in ("+", "-")
]
if diffs:
raise RuntimeError(
f"Config {path_to_config} used for restart differs from original config for training run in {train_dir}.\n"
+ "The following differences were found:\n\n"
+ "".join(diffs)
+ "\n"
+ "If you intend to override the original config parameters, use the --override flag. For example, use\n"
+ f'`nequip-train {path_to_config} --override "max_epochs: 42"`\n'
+ 'on the command line to override the config parameter "max_epochs"\n'
+ "BE WARNED that use of the --override flag is not protected by consistency checks performed by NequIP."
)
else:
trainer = restart(config, override_options)

# Train
trainer.save()
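
For clarity, a small standalone illustration of what the `difflib`-based restart consistency check above collects (the config lines are arbitrary examples):

```python
import difflib

orig = ["max_epochs: 100\n", "r_max: 4.0\n"]
curr = ["max_epochs: 42\n", "r_max: 4.0\n"]

# Keep only added/removed lines, exactly as in the restart check above.
diffs = [x for x in difflib.Differ().compare(orig, curr) if x[0] in ("+", "-")]
print("".join(diffs))
# - max_epochs: 100
# + max_epochs: 42
```
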
@@ -157,6 +191,12 @@ def parse_command_line(args=None):
help="Warn instead of error when the config contains unused keys",
action="store_true",
)
parser.add_argument(
"--override",
help="Override top-level configuration keys from the `--train-dir`/`--model`'s config YAML file. This should be a valid YAML string. Unless you know why you need to, do not use this option.",
type=str,
default=None,
)
args = parser.parse_args(args=args)

config = Config.from_file(args.config, defaults=default_config)
@@ -169,10 +209,26 @@
):
config[flag] = getattr(args, flag) or config[flag]

return config
# Set override options before _set_global_options so that things like allow_tf32 are correctly handled
if args.override is not None:
override_options = yaml.load(args.override, Loader=yaml.Loader)
assert isinstance(
override_options, dict
), "--override's YAML string must define a dictionary of top-level options"
overridden_keys = set(config.keys()).intersection(override_options.keys())
set_keys = set(override_options.keys()) - set(overridden_keys)
logging.info(
f"--override: overrode keys {list(overridden_keys)} and set new keys {list(set_keys)}"
)
del overridden_keys, set_keys
else:
override_options = {}

return config, args.config, override_options
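
A small illustration of what the `--override` parsing above produces (the override string is an arbitrary example):

```python
import yaml

override_options = yaml.load("max_epochs: 42\nallow_tf32: true", Loader=yaml.Loader)
assert isinstance(override_options, dict)
print(override_options)  # {'max_epochs': 42, 'allow_tf32': True}
```
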


def fresh_start(config):

# we use add_to_config cause it's a fresh start and need to record it
check_code_version(config, add_to_config=True)
_set_global_options(config)
@@ -267,7 +323,7 @@ def _unused_check():
return trainer


def restart(config):
def restart(config, override_options):
# load the dictionary
restart_file = f"{config.root}/{config.run_name}/trainer.pth"
dictionary = load_file(
@@ -276,20 +332,6 @@ def restart(config, override_options):
enforced_format="torch",
)

# compare dictionary to config and update stop condition related arguments
for k in config.keys():
if config[k] != dictionary.get(k, ""):
if k == "max_epochs":
dictionary[k] = config[k]
logging.info(f'Update "{k}" to {dictionary[k]}')
elif k.startswith("early_stop"):
dictionary[k] = config[k]
logging.info(f'Update "{k}" to {dictionary[k]}')
elif isinstance(config[k], type(dictionary.get(k, ""))):
raise ValueError(
f'Key "{k}" is different in config and the result trainer.pth file. Please double check'
)

# note, "trainer.pth"/dictionary also store code versions,
# which will not be stored in config and thus not checked here
check_code_version(config)
@@ -299,6 +341,10 @@ def restart(config, override_options):

config = Config(dictionary, exclude_keys=["state_dict", "progress"])

# override configs loaded from save
dictionary.update(override_options)
config.update(override_options)

# dtype, etc.
_set_global_options(config)
