Skip to content

Make SOAP initialization more intuitive w.r.t. species lists #353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 102 additions & 32 deletions bindings/rascal/representations/spherical_expansion.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,47 @@
from collections.abc import Iterable
import json
import logging

from .base import CalculatorFactory, cutoff_function_dict_switch
from ..neighbourlist import AtomsList
import numpy as np
from ..utils import BaseIO
from copy import deepcopy

LOGGER = logging.getLogger(__name__)


def _parse_species_list(species_list):
    """Parse the species_list keyword for initialization

    Translate the user-facing ``species_list`` argument into the two
    hyperparameters used to initialize the underlying C++ object.

    Parameters
    ----------
    species_list : str or iterable
        Either one of the strings 'environment wise' or 'structure wise',
        or a (non-string) iterable of species, which selects the
        'user defined' method. The iterable is copied into a list; its
        elements are not validated here.

    Returns
    -------
    dict
        A hypers dictionary with the keys ``expansion_by_species_method``
        and ``global_species`` set appropriately. ``global_species`` is an
        empty list for the two string options.

    Raises
    ------
    ValueError
        If ``species_list`` is a string other than the two accepted
        values, or is neither a string nor an iterable.
    """
    # Strings are themselves iterable, so they must be handled before the
    # generic Iterable check to avoid being split into characters.
    if isinstance(species_list, str):
        if species_list not in ("environment wise", "structure wise"):
            raise ValueError(
                "'species_list' must be one of: {'environment wise', "
                "'structure wise'}, or a list of int"
            )
        return {
            "expansion_by_species_method": species_list,
            "global_species": [],
        }

    if not isinstance(species_list, Iterable):
        raise ValueError(
            "'species_list' must be a list of int, if user-defined"
        )

    return {
        "expansion_by_species_method": "user defined",
        "global_species": list(species_list),
    }


class SphericalExpansion(BaseIO):
"""
Expand All @@ -26,16 +62,16 @@ class SphericalExpansion(BaseIO):
max_angular : int
Highest angular momentum number (l) in the expansion

gaussian_sigma_type : str
gaussian_sigma_type : str, default="Constant"
How the Gaussian atom sigmas (smearing widths) are allowed to
vary -- fixed ('Constant'), by species ('PerSpecies'), or by
distance from the central atom ('Radial').

gaussian_sigma_constant : float
Specifies the atomic Gaussian widths, in the case where they're
fixed.
gaussian_sigma_constant : float, default=0.3
Specifies the atomic Gaussian widths when
gaussian_sigma_type=="Constant"

cutoff_function_type : string
cutoff_function_type : string, default="ShiftedCosine"
Choose the type of smooth cutoff function used to define the local
environment. Can be either 'ShiftedCosine' or 'RadialScaling'.

Expand Down Expand Up @@ -72,18 +108,20 @@ class SphericalExpansion(BaseIO):
where :math:`c` is the rate, :math:`r_0` is the scale, :math:`m` is the
exponent.

radial_basis : string
radial_basis : string, default="GTO"
Specifies the type of radial basis R_n to be computed
("GTO" for Gaussian typed orbitals and "DVR" discrete variable representation using Gaussian quadrature rule)
("GTO" for Gaussian typed orbitals and "DVR" discrete variable
representation using Gaussian quadrature rule)

optimization_args : dict
optimization_args : dict, optional
Additional arguments for optimization.
Currently spline optimization for the radial basis function is available
Recommended settings if used {"type":"Spline", "accuracy": 1e-5}
Recommended settings if using: {"type":"Spline", "accuracy": 1e-5}

expansion_by_species_method : string
species_list : string or list(int), default="environment wise"
Specifies how the species keys of the invariants are set up.
Possible values: 'environment wise', 'user defined', 'structure wise'.
Possible values: 'environment wise', 'structure wise', or a user-defined
list of species.
The descriptor is computed for each atomic environment and it is indexed
using tuples of atomic species that are present within the environment.
This index is by definition sparse since a species tuple will be non
Expand All @@ -93,9 +131,16 @@ class SphericalExpansion(BaseIO):
atomic environment.
'structure wise' means that within a structure the species tuples
will be the same for each environment coefficients.
'user defined' uses global_species to set-up the species tuples.

These different settings correspond to different trade-off between
The user-defined option means that all the atomic numbers contained
in the supplied list will be considered. The user _must_ ensure that
all species contained in the structure are represented in the list,
otherwise an error will be raised. They are free to specify species
that do not occur in any structure, however; a typical use case of this
is to compare SOAP vectors between structure sets of possibly different
species composition.

These different settings correspond to different trade-offs between
the memory efficiency of the invariants and the computational
efficiency of the kernel computation.
When computing a kernel using 'environment wise' setting does not allow
Expand All @@ -105,18 +150,13 @@ class SphericalExpansion(BaseIO):

Note that the sparsity of the gradient coefficients and their use to
build kernels does not allow for clear efficiency gains so their
sparsity is kept irrespective of expansion_by_species_method.

global_species : list
list of species to use to set-up the species key of the invariant. It
should contain all the species present in the structure for which
invariants will be computed
sparsity is kept irrespective of the value of this parameter.

compute_gradients : bool
compute_gradients : bool, default False
control the computation of the representation's gradients w.r.t. atomic
positions.

cutoff_function_parameters : dict
cutoff_function_parameters : dict, optional
Additional parameters for the cutoff function.
if cutoff_function_type == 'RadialScaling' then it should have the form

Expand Down Expand Up @@ -146,15 +186,16 @@ def __init__(
cutoff_smooth_width,
max_radial,
max_angular,
gaussian_sigma_type,
gaussian_sigma_type="Constant",
gaussian_sigma_constant=0.3,
cutoff_function_type="ShiftedCosine",
radial_basis="GTO",
optimization_args={},
expansion_by_species_method="environment wise",
global_species=None,
species_list="environment wise",
compute_gradients=False,
cutoff_function_parameters=dict(),
expansion_by_species_method=None,
global_species=None,
):
"""Construct a SphericalExpansion representation

Expand All @@ -165,11 +206,6 @@ class documentation
self.name = "sphericalexpansion"
self.hypers = dict()

if global_species is None:
global_species = []
elif not isinstance(global_species, list):
global_species = list(global_species)

self.update_hyperparameters(
max_radial=max_radial,
max_angular=max_angular,
Expand All @@ -185,6 +221,34 @@ class documentation
cutoff_function = cutoff_function_dict_switch(
cutoff_function_type, **cutoff_function_parameters
)
# Soft backwards compatibility (remove these two if-clauses after 01.11.2021)
# Soft backwards compatibility (remove this whole if-statement after 01.11.2021)
if expansion_by_species_method is not None:
LOGGER.warning(
"Warning: The 'expansion_by_species_method' parameter is deprecated "
"(see 'species_list' parameter instead).\n"
"This message will become an error after 2021-11-01."
)
if expansion_by_species_method != "user defined":
species_list = expansion_by_species_method
elif global_species is not None:
species_list = global_species
else:
raise ValueError(
"Found deprecated 'expansion_by_species_method' parameter "
"set to 'user defined' without 'global_species' set")
elif global_species is not None:
LOGGER.warning(
"Warning: The 'global_species' parameter is deprecated "
"(see 'species_list' parameter instead).\n"
"This message will become an error after 2021-11-01."
"(Also, this needs to be set with "
"'expansion_by_species_method'=='user defined'; proceeding under "
"the assumption that this is what you wanted)"
)
species_list = global_species
species_list_hypers = _parse_species_list(species_list)
self.update_hyperparameters(**species_list_hypers)

gaussian_density = dict(
type=gaussian_sigma_type,
Expand All @@ -198,7 +262,8 @@ class documentation
else:
accuracy = 1e-5
print(
"No accuracy for spline optimization was given. Switching to default accuracy {:.0e}.".format(
"No accuracy for spline optimization was given. "
"Switching to default accuracy {:.0e}.".format(
accuracy
)
)
Expand Down Expand Up @@ -301,14 +366,19 @@ def _get_init_params(self):
gaussian_density = self.hypers["gaussian_density"]
cutoff_function = self.hypers["cutoff_function"]
radial_contribution = self.hypers["radial_contribution"]
global_species = self.hypers["global_species"]
species_list = (
global_species
if global_species
else self.hypers["expansion_by_species_method"]
)

init_params = dict(
interaction_cutoff=cutoff_function["cutoff"]["value"],
cutoff_smooth_width=cutoff_function["smooth_width"]["value"],
max_radial=self.hypers["max_radial"],
max_angular=self.hypers["max_angular"],
expansion_by_species_method=self.hypers["expansion_by_species_method"],
global_species=self.hypers["global_species"],
species_list=species_list,
compute_gradients=self.hypers["compute_gradients"],
gaussian_sigma_type=gaussian_density["type"],
gaussian_sigma_constant=gaussian_density["gaussian_sigma"]["value"],
Expand Down
Loading