Merge branch 'develop' into warn-unused
Linux-cpp-lisp committed Mar 7, 2022
2 parents c65fb38 + a90d22d commit 98558ef
Showing 8 changed files with 92 additions and 16 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,12 @@ Most recent change on the bottom.


## [Unreleased] - 0.5.4
+### Added
+- `NequIPCalculator` now handles per-atom energies
+
+### Fixed
+- Better error in `Dataset.statistics` when field is missing
+- `NequIPCalculator` now outputs energy as scalar rather than `(1, 1)` array

## [0.5.3] - 2022-02-23
### Added
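A note on the per-atom energy feature: with `energies` now in `implemented_properties` (see the calculator diff below), ASE exposes these values through its standard interface. A minimal usage sketch — the structure file name is hypothetical, and the `from_deployed_model` constructor is assumed from the nequip ASE API:

    from ase.io import read
    from nequip.ase import NequIPCalculator

    atoms = read("structure.xyz")  # hypothetical input structure
    calc = NequIPCalculator.from_deployed_model("deployed_model.pth", device="cpu")
    atoms.calc = calc

    energy = atoms.get_potential_energy()      # now a plain scalar, not a (1, 1) array
    energies = atoms.get_potential_energies()  # per-atom energies, shape (len(atoms),)
    assert abs(energies.sum() - energy) < 1e-5  # per-atom terms sum to the total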
2 changes: 1 addition & 1 deletion configs/full.yaml
@@ -189,7 +189,7 @@ early_stopping_lower_bounds:
LR: 1.0e-5

early_stopping_upper_bounds: # stop early if a metric value is higher than the bound
-wall: 1.0e+100
+cumulative_wall: 1.0e+100

# loss function
loss_coeffs: # different weights to use in a weighted loss functions
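Note: configs that previously bounded `wall` under `early_stopping_upper_bounds` should switch the key to `cumulative_wall`, which (per the trainer changes below) bounds the total wall time accumulated across restarts rather than the wall time of a single run.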
15 changes: 13 additions & 2 deletions nequip/ase/nequip_calculator.py
@@ -24,7 +24,7 @@ class NequIPCalculator(Calculator):
"""

implemented_properties = ["energy", "forces"]
implemented_properties = ["energy", "energies", "forces"]

def __init__(
self,
@@ -113,11 +113,22 @@ def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
# predict + extract data
out = self.model(data)
forces = out[AtomicDataDict.FORCE_KEY].detach().cpu().numpy()
-energy = out[AtomicDataDict.TOTAL_ENERGY_KEY].detach().cpu().numpy()
+energy = (
+    out[AtomicDataDict.TOTAL_ENERGY_KEY].detach().cpu().numpy().reshape(tuple())
+)

# store results
self.results = {
"energy": energy * self.energy_units_to_eV,
# force has units eng / len:
"forces": forces * (self.energy_units_to_eV / self.length_units_to_A),
}

+if AtomicDataDict.PER_ATOM_ENERGY_KEY in out:
+    self.results["energies"] = self.energy_units_to_eV * (
+        out[AtomicDataDict.PER_ATOM_ENERGY_KEY]
+        .detach()
+        .squeeze(-1)
+        .cpu()
+        .numpy()
+    )
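A quick standalone NumPy illustration of the two shape fixes above — `reshape(tuple())` turns a `(1, 1)` array into a 0-d, scalar-like array, and `squeeze(-1)` drops the trailing feature axis from the per-atom energies:

    import numpy as np

    e = np.array([[-1.5]])       # total energy as previously returned: shape (1, 1)
    scalar = e.reshape(tuple())  # 0-d array: shape (), behaves like a scalar
    print(scalar.shape)          # ()
    print(scalar * 2.0)          # -3.0

    per_atom = np.zeros((5, 1))  # per-atom energies carry a trailing axis
    print(per_atom.squeeze(-1).shape)  # (5,)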
4 changes: 4 additions & 0 deletions nequip/data/dataset.py
@@ -422,6 +422,10 @@ def statistics(
assert arr_is_per in ("node", "graph", "edge")
else:
# Give a better error
+if field not in ff_transformed and field not in data_transformed:
+    raise RuntimeError(
+        f"Field `{field}` for which statistics were requested not found in data."
+    )
if field not in selectors:
# this means field is not selected and so not available
raise RuntimeError(
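The new guard runs before the selector check, so a misspelled or unregistered field now fails with a direct message. A minimal self-contained mirror of the logic (the container contents here are invented for illustration):

    ff_transformed = {"forces": [0.0]}          # hypothetical transformed fixed fields
    data_transformed = {"total_energy": [0.0]}  # hypothetical transformed data fields
    field = "typo_field"
    if field not in ff_transformed and field not in data_transformed:
        raise RuntimeError(
            f"Field `{field}` for which statistics were requested not found in data."
        )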
13 changes: 11 additions & 2 deletions nequip/train/trainer.py
@@ -256,6 +256,7 @@ def __init__(
config=None,
):
self._initialized = False
+self.cumulative_wall = 0
logging.debug("* Initialize Trainer")

assert isinstance(config, Config)
@@ -417,14 +418,14 @@ def init_objects(self):
)
n_args = 0
for key, item in kwargs.items():
-# prepand VALIDATION string if k is not with
+# prepend VALIDATION string if k is not with
if isinstance(item, dict):
new_dict = {}
for k, v in item.items():
if (
k.lower().startswith(VALIDATION)
or k.lower().startswith(TRAIN)
or k.lower() in ["lr", "wall"]
or k.lower() in ["lr", "wall", "cumulative_wall"]
):
new_dict[k] = item[k]
else:
@@ -516,6 +517,7 @@ def as_dict(
dictionary["state_dict"]["cuda_rng_state"] = torch.cuda.get_rng_state(
device=self.torch_device
)
dictionary["state_dict"]["cumulative_wall"] = self.cumulative_wall

if training_progress:
dictionary["progress"] = {}
@@ -651,6 +653,7 @@ def from_dict(cls, dictionary, append: Optional[bool] = None):
if item is not None:
item.load_state_dict(state_dict[key])
trainer._initialized = True
+trainer.cumulative_wall = state_dict["cumulative_wall"]

torch.set_rng_state(state_dict["rng_state"])
trainer.dataset_rng.set_state(state_dict["dataset_rng_state"])
@@ -728,6 +731,7 @@ def init(self):
self.init_objects()

self._initialized = True
+self.cumulative_wall = 0

def init_metrics(self):
if self.metrics_components is None:
@@ -767,6 +771,7 @@ def train(self):

self.init_log()
self.wall = perf_counter()
+self.previous_cumulative_wall = self.cumulative_wall

with atomic_write_group():
if self.iepoch == -1:
@@ -1050,7 +1055,9 @@ def final_log(self):

self.logger.info(f"! Stop training: {self.stop_arg}")
wall = perf_counter() - self.wall
+self.cumulative_wall = wall + self.previous_cumulative_wall
self.logger.info(f"Wall time: {wall}")
self.logger.info(f"Cumulative wall time: {self.cumulative_wall}")

def end_of_epoch_log(self):
"""
@@ -1059,10 +1066,12 @@

lr = self.optim.param_groups[0]["lr"]
wall = perf_counter() - self.wall
+self.cumulative_wall = wall + self.previous_cumulative_wall
self.mae_dict = dict(
LR=lr,
epoch=self.iepoch,
wall=wall,
+cumulative_wall=self.cumulative_wall,
)

header = "epoch, wall, LR"
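Taken together, the trainer changes persist a running total of wall time: `cumulative_wall` is saved into the checkpoint `state_dict`, restored in `from_dict`, and updated as this run's elapsed time plus the total from previous runs. A self-contained sketch of that pattern (field names mirror the trainer's; this is not the trainer itself):

    from time import perf_counter

    class WallClock:
        def __init__(self, restored: float = 0.0):
            # what from_dict restores from state_dict["cumulative_wall"]
            self.cumulative_wall = restored

        def start_run(self):
            self.wall = perf_counter()
            self.previous_cumulative_wall = self.cumulative_wall

        def update(self) -> float:
            elapsed = perf_counter() - self.wall
            # total across restarts = this run's elapsed + all previous runs
            self.cumulative_wall = elapsed + self.previous_cumulative_wall
            return self.cumulative_wall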
16 changes: 12 additions & 4 deletions tests/integration/test_deploy.py
@@ -27,7 +27,11 @@ def test_deploy(BENCHMARK_ROOT, device):
# # TODO: is this true?
# pytest.skip("CUDA and subprocesses have issues")

-keys = [AtomicDataDict.TOTAL_ENERGY_KEY, AtomicDataDict.FORCE_KEY]
+keys = [
+    AtomicDataDict.TOTAL_ENERGY_KEY,
+    AtomicDataDict.FORCE_KEY,
+    AtomicDataDict.PER_ATOM_ENERGY_KEY,
+]

config_path = pathlib.Path(__file__).parents[2] / "configs/minimal.yaml"
true_config = yaml.load(config_path.read_text(), Loader=yaml.Loader)
@@ -73,7 +77,7 @@ def test_deploy(BENCHMARK_ROOT, device):
dataset = dataset_from_config(Config.from_file(full_config_path))
data = AtomicData.to_AtomicDataDict(dataset[0].to(device))
for k in keys:
-data.pop(k)
+data.pop(k, None)
train_pred = best_mod(data)
train_pred = {k: train_pred[k].to("cpu") for k in keys}

@@ -92,7 +96,7 @@ def test_deploy(BENCHMARK_ROOT, device):
data_idx = 0
data = AtomicData.to_AtomicDataDict(dataset[data_idx].to("cpu"))
for k in keys:
-data.pop(k)
+data.pop(k, None)
deploy_pred = deploy_mod(data)
deploy_pred = {k: deploy_pred[k].to("cpu") for k in keys}
for k in keys:
@@ -127,10 +131,14 @@ def test_deploy(BENCHMARK_ROOT, device):
ase_pred = {
AtomicDataDict.TOTAL_ENERGY_KEY: atoms.get_potential_energy(),
AtomicDataDict.FORCE_KEY: atoms.get_forces(),
+AtomicDataDict.PER_ATOM_ENERGY_KEY: atoms.get_potential_energies(),
}
+assert ase_pred[AtomicDataDict.TOTAL_ENERGY_KEY].shape == tuple()
+assert ase_pred[AtomicDataDict.FORCE_KEY].shape == (len(atoms), 3)
+assert ase_pred[AtomicDataDict.PER_ATOM_ENERGY_KEY].shape == (len(atoms),)
for k in keys:
assert torch.allclose(
-deploy_pred[k],
+deploy_pred[k].squeeze(-1),
torch.as_tensor(ase_pred[k], dtype=torch.get_default_dtype()),
atol=atol,
)
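With `PER_ATOM_ENERGY_KEY` added to `keys`, not every data frame necessarily carries every key, so the test switches from `dict.pop(k)` to `dict.pop(k, None)`; the two-argument form returns the default instead of raising:

    d = {"total_energy": 1.0}
    d.pop("atomic_energy", None)  # missing key: returns None, no error
    # d.pop("atomic_energy")      # would raise KeyError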
15 changes: 8 additions & 7 deletions tests/unit/data/test_dataset.py
@@ -34,6 +34,7 @@ def ase_file(molecules):

@pytest.fixture(scope="function")
def npz():
+np.random.seed(0)
natoms = NATOMS
nframes = 8
yield dict(
@@ -219,6 +220,7 @@ def test_per_graph_field(self, npz_dataset, fixed_field, subset, key, dim):
if npz_dataset is None:
return

+torch.manual_seed(0)
E = torch.rand((npz_dataset.len(),) + dim)
ref_mean = torch.mean(E / NATOMS, dim=0)
ref_std = torch.std(E / NATOMS, dim=0)
@@ -296,9 +298,9 @@ def test_per_graph_field(
del Ns

if alpha == 1e-5:
-ref_mean, ref_std, E = generate_E(N, 1000, 0.0)
+ref_mean, ref_std, E = generate_E(N, 100, 1000, 0.0)
else:
-ref_mean, ref_std, E = generate_E(N, 1000, 0.5)
+ref_mean, ref_std, E = generate_E(N, 100, 1000, 0.5)

if subset:
E_orig_order = torch.zeros_like(
@@ -337,9 +339,8 @@ def test_per_graph_field(
if alpha == 1e-5:
assert torch.allclose(mean, ref_mean, rtol=1e-1)
else:
-assert torch.allclose(mean, ref_mean, rtol=2)
-# This test is disabled because it (correctly) fails sometimes
-# assert torch.allclose(std, torch.zeros_like(ref_mean), atol=alpha * 100)
+assert torch.allclose(mean, ref_mean, rtol=1)
+assert torch.allclose(std, torch.zeros_like(ref_mean), atol=alpha * 100)
elif regressor == "NormalizedGaussianProcess":
assert torch.std(mean).numpy() == 0
else:
@@ -436,9 +437,9 @@ def test_from_atoms(self, molecules):
)


-def generate_E(N, mean, std):
+def generate_E(N, mean_min, mean_max, std):
torch.manual_seed(0)
-ref_mean = torch.rand((N.shape[1])) * mean
+ref_mean = torch.rand((N.shape[1])) * (mean_max - mean_min) + mean_min
t_mean = torch.ones((N.shape[0], 1)) * ref_mean.reshape([1, -1])
ref_std = torch.rand((N.shape[1])) * std
t_std = torch.ones((N.shape[0], 1)) * ref_std.reshape([1, -1])
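The reworked `generate_E` draws each reference mean uniformly from `[mean_min, mean_max)` rather than `[0, mean)`, keeping targets well away from zero so the relative-tolerance assertions above stay meaningful. The sampling identity it relies on:

    import torch

    torch.manual_seed(0)
    mean_min, mean_max = 100.0, 1000.0
    ref_mean = torch.rand(4) * (mean_max - mean_min) + mean_min
    assert ((ref_mean >= mean_min) & (ref_mean < mean_max)).all()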
37 changes: 37 additions & 0 deletions tests/unit/utils/test_gp.py
@@ -0,0 +1,37 @@
import torch
import pytest

from nequip.utils.regressor import base_gp
from sklearn.gaussian_process.kernels import DotProduct


# @pytest.mark.parametrize("full_rank", [True, False])
@pytest.mark.parametrize("full_rank", [False])
@pytest.mark.parametrize("alpha", [0, 1e-3, 0.1, 1])
def test_random(full_rank, alpha):

if alpha == 0 and not full_rank:
return

torch.manual_seed(0)
n_samples = 10
n_dim = 3

if full_rank:
X = torch.randint(low=1, high=10, size=(n_samples, n_dim))
else:
X = torch.randint(low=1, high=10, size=(n_samples, 1)) * torch.ones(
(n_samples, n_dim)
)

ref_mean = torch.rand((n_dim, 1))
y = torch.matmul(X, ref_mean)

mean, std = base_gp(
X, y, DotProduct, {"sigma_0": 0, "sigma_0_bounds": "fixed"}, alpha=0.1
)

if full_rank:
assert torch.allclose(ref_mean, mean, rtol=0.5)
else:
assert torch.allclose(mean, mean[0], rtol=1e-3)
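A note on the degenerate branch this new test exercises: when every column of `X` is a copy of the same vector, the design matrix has rank one, so individual per-column coefficients are not identifiable — only their sum is. The test therefore checks that the recovered means agree with each other (`mean == mean[0]`) rather than matching `ref_mean`, which is asserted only in the (currently disabled) full-rank branch. A quick check of that rank deficiency:

    import torch

    X = torch.randint(low=1, high=10, size=(10, 1)) * torch.ones((10, 3))
    print(torch.linalg.matrix_rank(X))  # tensor(1): all columns identical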
