diff --git a/nequip/model/_scaling.py b/nequip/model/_scaling.py
index 736aa25e..647961f3 100644
--- a/nequip/model/_scaling.py
+++ b/nequip/model/_scaling.py
@@ -4,7 +4,7 @@
 from nequip.data import AtomicDataDict
 
 import warnings
-from typing import Union
+from typing import List
 
 RESCALE_THRESHOLD = 1e-6
 
@@ -15,30 +15,21 @@ def GlobalRescale(
     config,
     initialize: bool,
     module_prefix: str,
-    default_scale: Union[str, float, list],
-    default_scale_keys: list,
+    default_scale: float,
+    default_scale_keys: List[str],
 ):
-    """Add global rescaling for energy(-based quantities).
+    """Rescales a set of fields."""
+    scale = config.get(f"{module_prefix}_scale", default_scale)
 
-    If ``initialize`` is false, doesn't compute statistics.
-    """
-    global_scale = config.get(f"{module_prefix}_scale", default_scale)
-
-    if global_scale is None:
+    if scale is None:
         warnings.warn(
-            f"Module `{module_prefix}` added but global_scale is `None`. Please check to ensure this is intended. To set global_scale, `{module_prefix}_global_scale` must be provided in the config."
+            f"Module `{module_prefix}` added but scale is `None`. Please check to ensure this is intended. To set scale, `{module_prefix}_scale` must be provided in the config."
         )
 
-    # = Get statistics of training dataset =
-    if initialize:
-        if global_scale is not None and global_scale < RESCALE_THRESHOLD:
-            raise ValueError(
-                f"Global energy scaling was very low: {global_scale}. If dataset values were used, does the dataset contain insufficient variation? Maybe try disabling global scaling with global_scale=None."
-            )
-    else:
-        # Put dummy values
-        if global_scale is not None:
-            global_scale = 1.0
+    if scale is not None and scale < RESCALE_THRESHOLD:
+        raise ValueError(
+            f"Global energy scaling was very low: {scale}. If dataset values were used, does the dataset contain insufficient variation? Maybe try disabling global scaling with scale=None."
+        )
 
     assert isinstance(default_scale_keys, list), "keys need to be a list"
 
@@ -46,8 +37,7 @@ def GlobalRescale(
     return RescaleOutputModule(
         model=model,
         scale_keys=[k for k in default_scale_keys if k in model.irreps_out],
-        scale_by=global_scale,
-        default_dtype=config.get("default_dtype", None),
+        scale_by=scale,
     )
diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py
index 54cc18fb..c985dc03 100644
--- a/nequip/nn/_rescale.py
+++ b/nequip/nn/_rescale.py
@@ -1,47 +1,36 @@
-from typing import Sequence, List, Union, Optional
-
 import torch
 
 from e3nn.util.jit import compile_mode
 
 from nequip.data import AtomicDataDict
 from nequip.nn import GraphModuleMixin
-from nequip.utils.misc import dtype_from_name
+from nequip.utils.global_dtype import _GLOBAL_DTYPE
 
-# TODO: merge this into GraphModel
+from typing import Sequence, List, Dict, Union
 
 
 @compile_mode("script")
 class RescaleOutput(GraphModuleMixin, torch.nn.Module):
     """Wrap a model and rescale its outputs.
 
-    Note that scaling is always done (casting into) ``default_dtype``, even if ``model_dtype`` is lower precision.
+    Note that scaling is always done (casting into) ``_GLOBAL_DTYPE=torch.float64``, even if ``model_dtype`` is of lower precision.
 
     Args:
-        model : GraphModuleMixin
-            The model whose outputs are to be rescaled.
-        scale_keys : list of keys, default []
-            Which fields to rescale.
-        scale_by : floating or Tensor, default 1.
-            The scaling factor by which to multiply fields in ``scale``.
-        irreps_in : dict, optional
-            Extra inputs expected by this beyond those of `model`; this is only present for compatibility.
+        model (GraphModuleMixin): model whose outputs are to be rescaled
+        scale_keys (List[str]) : fields to rescale
+        scale_by (float): scaling factor by which to multiply fields in ``scale_keys``
    """

    scale_keys: List[str]
    _all_keys: List[str]

-    has_scale: bool
-    default_dtype: torch.dtype
-
    def __init__(
        self,
        model: GraphModuleMixin,
-        scale_keys: Union[Sequence[str], str] = [],
-        scale_by=None,
-        default_dtype: Optional[str] = None,
-        irreps_in: dict = {},
+        scale_keys: Union[Sequence[str], str],
+        scale_by: float,
+        irreps_in: Dict = {},
    ):
        super().__init__()
@@ -67,17 +56,8 @@ def __init__(
        self.scale_keys = list(scale_keys)
        self._all_keys = list(all_keys)

-        self.default_dtype = dtype_from_name(
-            torch.get_default_dtype() if default_dtype is None else default_dtype
-        )
-
-        self.has_scale = scale_by is not None
-        if self.has_scale:
-            scale_by = torch.as_tensor(scale_by, dtype=self.default_dtype)
-            self.register_buffer("scale_by", scale_by)
-        else:
-            # register dummy for TorchScript
-            self.register_buffer("scale_by", torch.Tensor())
+        scale_by = torch.as_tensor(scale_by, dtype=_GLOBAL_DTYPE)
+        self.register_buffer("scale_by", scale_by)

        # Finally, we tell all the modules in the model that there is rescaling
        # This allows them to update parameters, like physical constants with units,
@@ -109,8 +89,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        # https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
        # confirmed in PyTorch slack
        # https://pytorch.slack.com/archives/C3PDTEV8E/p1671652283801129
-        if self.has_scale:
-            for field in self.scale_keys:
-                v = data[field]
-                data[field] = v * self.scale_by.expand(v.shape)
+        for field in self.scale_keys:
+            v = data[field]
+            data[field] = v * self.scale_by.expand(v.shape)
        return data
diff --git a/tests/unit/nn/test_rescale.py b/tests/unit/nn/test_rescale.py
index 64165ba9..8196da30 100644
--- a/tests/unit/nn/test_rescale.py
+++ b/tests/unit/nn/test_rescale.py
@@ -11,12 +11,10 @@
 from nequip.utils.test import assert_AtomicData_equivariant
 
 
-@pytest.mark.parametrize("scale_by", [0.77, 1.0, None])
-@pytest.mark.parametrize("shift_trainable", [True, False])
+@pytest.mark.parametrize("scale_by", [0.77, 1.0, 2.6])
 def test_rescale(
     CH3CHO,
     scale_by,
-    shift_trainable,
     model_dtype,
 ):
     _, data = CH3CHO
@@ -54,8 +52,5 @@ def test_rescale(
     oh_out = oh_out.to(dtype=rescale_out.dtype)
 
     # node attrs are a one hot, so we know orig then are zeros and ones
-    if scale_by is None:
-        assert torch.all(oh_out == rescale_out)
-    else:
-        ratio = torch.nan_to_num(rescale_out / oh_out)
-        assert torch.allclose(ratio[oh_out != 0.0], torch.as_tensor(scale_by))
+    ratio = torch.nan_to_num(rescale_out / oh_out)
+    assert torch.allclose(ratio[oh_out != 0.0], torch.as_tensor(scale_by))
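
For reference, a minimal sketch of the arithmetic the simplified forward now performs: the scale factor is held as a 0-dim float64 buffer (_GLOBAL_DTYPE) and multiplied into each field listed in scale_keys, which promotes lower-precision model outputs to float64. The plain dict and field names below are only illustrative stand-ins for an AtomicDataDict, not the real nequip data structure.

import torch

_GLOBAL_DTYPE = torch.float64  # mirrors nequip.utils.global_dtype._GLOBAL_DTYPE

# Illustrative stand-in for model outputs produced under model_dtype=float32.
data = {
    "total_energy": torch.tensor([[-1.5], [-2.0]], dtype=torch.float32),
    "forces": torch.randn(4, 3, dtype=torch.float32),
}
scale_keys = ["total_energy", "forces"]

# RescaleOutput registers scale_by as a buffer in the global dtype.
scale_by = torch.as_tensor(0.77, dtype=_GLOBAL_DTYPE)

for field in scale_keys:
    v = data[field]
    # Multiplying by the float64 scalar promotes the result to float64,
    # so rescaled outputs always come back in _GLOBAL_DTYPE.
    data[field] = v * scale_by.expand(v.shape)

assert data["total_energy"].dtype == torch.float64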