Skip to content

Commit

Permalink
Pre-commit formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Jun 24, 2024
1 parent b27ec10 commit 3021d4e
Show file tree
Hide file tree
Showing 17 changed files with 45 additions and 22 deletions.
1 change: 1 addition & 0 deletions nequip/data/AtomicDataDict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Authors: Albert Musaelian
"""

from typing import Dict, Any

import torch
Expand Down
10 changes: 6 additions & 4 deletions nequip/data/_dataset/_ase_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def _ase_dataset_reader(
datas.append(
(
global_index,
AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs)
if global_index in include_frames
# in-memory dataset will ignore this later, but needed for indexing to work out
else None,
(
AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs)
if global_index in include_frames
# in-memory dataset will ignore this later, but needed for indexing to work out
else None
),
)
)
# Save to a tempfile---
Expand Down
1 change: 1 addition & 0 deletions nequip/data/_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This is a seperate module to compensate for a TorchScript bug that can only recognize constants when they are accessed as attributes of an imported module.
"""

import sys
from typing import List

Expand Down
1 change: 1 addition & 0 deletions nequip/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class PartialSampler(Sampler[int]):
If `None`, defaults to `len(data_source)`.
generator (Generator): Generator used in sampling.
"""

data_source: Dataset
num_samples_per_epoch: int
shuffle: bool
Expand Down
6 changes: 3 additions & 3 deletions nequip/nn/_convnetlayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def __init__(
# updated with whatever the convolution outputs (which is a full graph module)
self.irreps_out.update(self.conv.irreps_out)
# but with the features updated by the nonlinearity
self.irreps_out[
AtomicDataDict.NODE_FEATURES_KEY
] = self.equivariant_nonlin.irreps_out
self.irreps_out[AtomicDataDict.NODE_FEATURES_KEY] = (
self.equivariant_nonlin.irreps_out
)

def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
# save old features for resnet
Expand Down
3 changes: 3 additions & 0 deletions nequip/nn/_grad_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class GradientOutput(GraphModuleMixin, torch.nn.Module):
out_field: the field in which to return the computed gradients. Defaults to ``f"d({of})/d({wrt})"`` for each field in ``wrt``.
sign: either 1 or -1; the returned gradient is multiplied by this.
"""

sign: float
_negate: bool
skip: bool
Expand Down Expand Up @@ -119,6 +120,7 @@ class PartialForceOutput(GraphModuleMixin, torch.nn.Module):
vectorize: the vectorize option to ``torch.autograd.functional.jacobian``,
false by default since it doesn't work well.
"""

vectorize: bool

def __init__(
Expand Down Expand Up @@ -183,6 +185,7 @@ class StressOutput(GraphModuleMixin, torch.nn.Module):
func: the energy model to wrap
do_forces: whether to compute forces as well
"""

do_forces: bool

def __init__(
Expand Down
1 change: 1 addition & 0 deletions nequip/nn/_interaction_block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Interaction Block """

from typing import Optional, Dict, Callable

import torch
Expand Down
8 changes: 5 additions & 3 deletions nequip/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,11 @@ def main(args=None, running_as_script: bool = True):
if do_metrics:
display_bar = context_stack.enter_context(
tqdm(
bar_format=""
if prog.disable # prog.ncols doesn't exist if disabled
else ("{desc:." + str(prog.ncols) + "}"),
bar_format=(
""
if prog.disable # prog.ncols doesn't exist if disabled
else ("{desc:." + str(prog.ncols) + "}")
),
disable=None,
)
)
Expand Down
4 changes: 3 additions & 1 deletion nequip/scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Train a network."""

import logging
import argparse
import warnings
Expand All @@ -24,7 +25,8 @@
from nequip.scripts._logger import set_up_script_logger

warnings.filterwarnings( # unnecessary e3nn-related JIT warning
"ignore", message="The TorchScript type system doesn't support instance-level annotations"
"ignore",
message="The TorchScript type system doesn't support instance-level annotations",
)
default_config = dict(
root="./",
Expand Down
1 change: 1 addition & 0 deletions nequip/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
If a parameter is updated, the updated value will be formatted back to the same type.
"""

from typing import Set, Dict, Any, List

import inspect
Expand Down
1 change: 1 addition & 0 deletions nequip/utils/savenload.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
utilities that involve file searching and operations (i.e. save/load)
"""

from typing import Union, List, Tuple, Optional, Callable
import sys
import logging
Expand Down
8 changes: 5 additions & 3 deletions nequip/utils/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ def assert_permutation_equivariant(

if tolerance is None:
atol = PERMUTATION_FLOAT_TOLERANCE[
func.model_dtype
if isinstance(func, GraphModel)
else torch.get_default_dtype()
(
func.model_dtype
if isinstance(func, GraphModel)
else torch.get_default_dtype()
)
]
else:
atol = tolerance
Expand Down
10 changes: 6 additions & 4 deletions nequip/utils/unittests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,12 @@ def test_partial_forces(self, config, atomic_batch, device, strict_locality):
assert torch.allclose(
output[k],
output_partial[k],
atol=1e-8
if k == AtomicDataDict.TOTAL_ENERGY_KEY
and torch.get_default_dtype() == torch.float64
else 1e-5,
atol=(
1e-8
if k == AtomicDataDict.TOTAL_ENERGY_KEY
and torch.get_default_dtype() == torch.float64
else 1e-5
),
)
else:
assert torch.equal(output[k], output_partial[k])
Expand Down
8 changes: 5 additions & 3 deletions tests/integration/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,11 @@ def runit(params: dict):
assert np.allclose(
err,
0.0,
atol=1e-8
if true_identity
else (1e-2 if metric.startswith("e") else 1e-4),
atol=(
1e-8
if true_identity
else (1e-2 if metric.startswith("e") else 1e-4)
),
), f"Metric `{metric}` wasn't zero!"
elif builder == ConstFactorModel:
# TODO: check comperable to naive numpy compute
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/model/test_pair/test_zbl.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_lammps_repro(self, config):
# $ lmp -in zbl_data.lmps
# $ python -c "import numpy as np; d = np.loadtxt('zbl.dat', skiprows=1); np.save('zbl.npy', d)"
refdata = np.load(Path(__file__).parent / "zbl.npy")
for (r, Zi, Zj, pe, fxi, fxj) in refdata:
for r, Zi, Zj, pe, fxi, fxj in refdata:
if r >= r_max:
continue
atoms.positions[1, 0] = r
Expand Down
1 change: 1 addition & 0 deletions tests/unit/utils/test_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Config tests
"""

import pytest

from os import remove
Expand Down
1 change: 1 addition & 0 deletions tests/unit/utils/test_output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Config tests
"""

import pytest
import tempfile

Expand Down

0 comments on commit 3021d4e

Please sign in to comment.