Commit

clean up rescale
cw-tan committed Oct 22, 2024
1 parent cb437ba commit 70daf0e
Showing 3 changed files with 29 additions and 65 deletions.
34 changes: 12 additions & 22 deletions nequip/model/_scaling.py
@@ -4,7 +4,7 @@
 from nequip.data import AtomicDataDict
 
 import warnings
-from typing import Union
+from typing import List
 
 
 RESCALE_THRESHOLD = 1e-6
@@ -15,39 +15,29 @@ def GlobalRescale(
     config,
     initialize: bool,
     module_prefix: str,
-    default_scale: Union[str, float, list],
-    default_scale_keys: list,
+    default_scale: float,
+    default_scale_keys: List[str],
 ):
-    """Add global rescaling for energy(-based quantities).
-
-    If ``initialize`` is false, doesn't compute statistics.
-    """
-    global_scale = config.get(f"{module_prefix}_scale", default_scale)
+    """Rescales a set of fields."""
+    scale = config.get(f"{module_prefix}_scale", default_scale)
 
-    if global_scale is None:
+    if scale is None:
         warnings.warn(
-            f"Module `{module_prefix}` added but global_scale is `None`. Please check to ensure this is intended. To set global_scale, `{module_prefix}_global_scale` must be provided in the config."
+            f"Module `{module_prefix}` added but scale is `None`. Please check to ensure this is intended. To set scale, `{module_prefix}_scale` must be provided in the config."
         )
 
-    # = Get statistics of training dataset =
-    if initialize:
-        if global_scale is not None and global_scale < RESCALE_THRESHOLD:
-            raise ValueError(
-                f"Global energy scaling was very low: {global_scale}. If dataset values were used, does the dataset contain insufficient variation? Maybe try disabling global scaling with global_scale=None."
-            )
-    else:
-        # Put dummy values
-        if global_scale is not None:
-            global_scale = 1.0
+    if scale is not None and scale < RESCALE_THRESHOLD:
+        raise ValueError(
+            f"Global energy scaling was very low: {scale}. If dataset values were used, does the dataset contain insufficient variation? Maybe try disabling global scaling with scale=None."
+        )
 
     assert isinstance(default_scale_keys, list), "keys need to be a list"
 
     # == Build the model ==
     return RescaleOutputModule(
         model=model,
         scale_keys=[k for k in default_scale_keys if k in model.irreps_out],
-        scale_by=global_scale,
-        default_dtype=config.get("default_dtype", None),
+        scale_by=scale,
     )


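For reference, a minimal runnable sketch (not from the commit) of how the simplified builder now resolves the scale from the config; the `global_rescale` prefix and the example config dicts are illustrative assumptions:

    import warnings

    RESCALE_THRESHOLD = 1e-6

    def resolve_scale(config: dict, module_prefix: str, default_scale: float):
        # Same logic as the simplified GlobalRescale above: one config key, one float scale.
        scale = config.get(f"{module_prefix}_scale", default_scale)
        if scale is None:
            warnings.warn(f"Module `{module_prefix}` added but scale is `None`.")
        if scale is not None and scale < RESCALE_THRESHOLD:
            raise ValueError(f"Global energy scaling was very low: {scale}.")
        return scale

    print(resolve_scale({"global_rescale_scale": 4.2}, "global_rescale", 1.0))  # 4.2, taken from the config
    print(resolve_scale({}, "global_rescale", 1.0))  # key absent, falls back to the default: 1.0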
49 changes: 14 additions & 35 deletions nequip/nn/_rescale.py
@@ -1,47 +1,36 @@
-from typing import Sequence, List, Union, Optional
-
 import torch
 
 from e3nn.util.jit import compile_mode
 
 from nequip.data import AtomicDataDict
 from nequip.nn import GraphModuleMixin
-from nequip.utils.misc import dtype_from_name
+from nequip.utils.global_dtype import _GLOBAL_DTYPE
 
-# TODO: merge this into GraphModel
+from typing import Sequence, List, Dict, Union
 
 
 @compile_mode("script")
 class RescaleOutput(GraphModuleMixin, torch.nn.Module):
     """Wrap a model and rescale its outputs.
-    Note that scaling is always done (casting into) ``default_dtype``, even if ``model_dtype`` is lower precision.
+    Note that scaling is always done (casting into) ``_GLOBAL_DTYPE=torch.float64``, even if ``model_dtype`` is of lower precision.
     Args:
-        model : GraphModuleMixin
-            The model whose outputs are to be rescaled.
-        scale_keys : list of keys, default []
-            Which fields to rescale.
-        scale_by : floating or Tensor, default 1.
-            The scaling factor by which to multiply fields in ``scale``.
-        irreps_in : dict, optional
-            Extra inputs expected by this beyond those of `model`; this is only present for compatibility.
+        model (GraphModuleMixin): model whose outputs are to be rescaled
+        scale_keys (List[str]) : fields to rescale
+        scale_by (float): scaling factor by which to multiply fields in ``scale_keys``
     """
 
     scale_keys: List[str]
     _all_keys: List[str]
 
-    has_scale: bool
-
-    default_dtype: torch.dtype
-
     def __init__(
         self,
         model: GraphModuleMixin,
-        scale_keys: Union[Sequence[str], str] = [],
-        scale_by=None,
-        default_dtype: Optional[str] = None,
-        irreps_in: dict = {},
+        scale_keys: Union[Sequence[str], str],
+        scale_by: float,
+        irreps_in: Dict = {},
     ):
         super().__init__()
 
Expand All @@ -67,17 +56,8 @@ def __init__(
self.scale_keys = list(scale_keys)
self._all_keys = list(all_keys)

self.default_dtype = dtype_from_name(
torch.get_default_dtype() if default_dtype is None else default_dtype
)

self.has_scale = scale_by is not None
if self.has_scale:
scale_by = torch.as_tensor(scale_by, dtype=self.default_dtype)
self.register_buffer("scale_by", scale_by)
else:
# register dummy for TorchScript
self.register_buffer("scale_by", torch.Tensor())
scale_by = torch.as_tensor(scale_by, dtype=_GLOBAL_DTYPE)
self.register_buffer("scale_by", scale_by)

# Finally, we tell all the modules in the model that there is rescaling
# This allows them to update parameters, like physical constants with units,
@@ -109,8 +89,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         # https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
         # confirmed in PyTorch slack
         # https://pytorch.slack.com/archives/C3PDTEV8E/p1671652283801129
-        if self.has_scale:
-            for field in self.scale_keys:
-                v = data[field]
-                data[field] = v * self.scale_by.expand(v.shape)
+        for field in self.scale_keys:
+            v = data[field]
+            data[field] = v * self.scale_by.expand(v.shape)
         return data
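To illustrate the simplified behavior, a self-contained toy module (an assumption for illustration, not the actual nequip class, which also handles irreps and TorchScript compilation) that follows the same pattern: a float64 scale buffer multiplied unconditionally into the selected fields:

    import torch

    class TinyRescale(torch.nn.Module):
        """Toy stand-in for the simplified RescaleOutput: scaling is now unconditional."""

        def __init__(self, scale_keys, scale_by: float):
            super().__init__()
            self.scale_keys = list(scale_keys)
            # Mirrors the change above: the buffer is always created, in float64 (_GLOBAL_DTYPE).
            self.register_buffer("scale_by", torch.as_tensor(scale_by, dtype=torch.float64))

        def forward(self, data: dict) -> dict:
            for field in self.scale_keys:
                v = data[field]
                data[field] = v * self.scale_by.expand(v.shape)
            return data

    # Hypothetical usage with a made-up field name:
    out = TinyRescale(["total_energy"], 0.77)({"total_energy": torch.ones(2, 1)})
    print(out["total_energy"].dtype)  # torch.float64: the float64 buffer upcasts the product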
11 changes: 3 additions & 8 deletions tests/unit/nn/test_rescale.py
@@ -11,12 +11,10 @@
 from nequip.utils.test import assert_AtomicData_equivariant
 
 
-@pytest.mark.parametrize("scale_by", [0.77, 1.0, None])
-@pytest.mark.parametrize("shift_trainable", [True, False])
+@pytest.mark.parametrize("scale_by", [0.77, 1.0, 2.6])
 def test_rescale(
     CH3CHO,
     scale_by,
-    shift_trainable,
     model_dtype,
 ):
     _, data = CH3CHO
@@ -54,8 +52,5 @@ def test_rescale(
     oh_out = oh_out.to(dtype=rescale_out.dtype)
 
     # node attrs are a one hot, so we know orig then are zeros and ones
-    if scale_by is None:
-        assert torch.all(oh_out == rescale_out)
-    else:
-        ratio = torch.nan_to_num(rescale_out / oh_out)
-        assert torch.allclose(ratio[oh_out != 0.0], torch.as_tensor(scale_by))
+    ratio = torch.nan_to_num(rescale_out / oh_out)
+    assert torch.allclose(ratio[oh_out != 0.0], torch.as_tensor(scale_by))
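With `scale_by=None` no longer allowed, the assertion only has to cover one case: wherever the one-hot output is nonzero, the rescaled output divided by the original equals scale_by. A standalone illustration of that check with toy tensors (not the actual test fixtures):

    import torch

    scale_by = 0.77
    oh_out = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # stand-in for the one-hot model output
    rescale_out = oh_out * scale_by                  # stand-in for the RescaleOutput result

    ratio = torch.nan_to_num(rescale_out / oh_out)   # 0/0 entries become 0 instead of NaN
    assert torch.allclose(ratio[oh_out != 0.0], torch.as_tensor(scale_by))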
