diff --git a/nequip/data/misc.py b/nequip/data/misc.py new file mode 100644 index 00000000..24aaed6c --- /dev/null +++ b/nequip/data/misc.py @@ -0,0 +1,131 @@ +""" +Relevant chunk from ASE +https://gitlab.com/ase/ase/-/blob/master/ase/data/__init__.py +to reduce dependencies for nn modules +""" + +chemical_symbols = [ + "X", + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Nh", + "Fl", + "Mc", + "Lv", + "Ts", + "Og", +] + +chemical_symbols_to_atomic_numbers_dict = { + symbol: Z for Z, symbol in enumerate(chemical_symbols) +} diff --git a/nequip/nn/pair_potential.py b/nequip/nn/pair_potential.py index 5704b379..608f7df6 100644 --- a/nequip/nn/pair_potential.py +++ b/nequip/nn/pair_potential.py @@ -4,9 +4,8 @@ from e3nn.util.jit import compile_mode -import ase.data - from nequip.data import AtomicDataDict +from nequip.data.misc import chemical_symbols_to_atomic_numbers_dict from ._util import scatter from ._graph_mixin import GraphModuleMixin from nequip.utils import conditional_torchscript_jit @@ -269,7 +268,7 @@ def __init__( ) assert len(chemical_species) == num_types atomic_numbers: List[int] = [ - ase.data.atomic_numbers[chemical_species[type_i]] + chemical_symbols_to_atomic_numbers_dict[chemical_species[type_i]] for type_i in range(num_types) ] if min(atomic_numbers) < 1: