diff --git a/torchmdnet/datasets/__init__.py b/torchmdnet/datasets/__init__.py index 92a70bd24..42dc7cf3f 100644 --- a/torchmdnet/datasets/__init__.py +++ b/torchmdnet/datasets/__init__.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from .ace import Ace +from .ace import Ace, AceHF from .ani import ANI1, ANI1CCX, ANI1X, ANI2X from .comp6 import ( ANIMD, @@ -28,6 +28,7 @@ __all__ = [ "Ace", + "AceHF", "ANIMD", "ANI1", "ANI1CCX", diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index ebb7da0b3..06496db7c 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -7,7 +7,7 @@ import os import torch as pt from torchmdnet.datasets.memdataset import MemmappedDataset -from torch_geometric.data import Data +from torch_geometric.data import Data, Dataset from tqdm import tqdm @@ -291,3 +291,44 @@ def sample_iter(self, mol_ids=False): data = self.pre_transform(data) yield data + + +class AceHF(Dataset): + def __init__(self, root="parquet", paths=None, split="train") -> None: + from datasets import load_dataset + + self.dataset = load_dataset(root, data_files=paths, split=split) + self.dataset = self.dataset.with_format("torch") + + def __len__(self): + return self.dataset.num_rows + + def __getitem__(self, idx): + """Gets the data object at index :obj:`idx`. + + The data object contains the following attributes: + + - :obj:`z`: Atomic numbers of the atoms. + - :obj:`pos`: Positions of the atoms. + - :obj:`y`: Formation energy of the molecule. + - :obj:`neg_dy`: Forces on the atoms. + - :obj:`q`: Total charge of the molecule. + - :obj:`pq`: Partial charges of the atoms. + - :obj:`dp`: Dipole moment of the molecule. + + Args: + idx (int): Index of the data object. + + Returns: + :obj:`torch_geometric.data.Data`: The data object. + """ + data = self.dataset[int(idx)] + return Data( + z=data["atomic_numbers"], + pos=data["positions"], + y=data["formation_energy"].view(1, 1), + neg_dy=data["forces"], + q=sum(data["formal_charges"]), + pq=data["partial_charges"], + dp=data["dipole_moment"], + )