diff --git a/torchmdnet/datasets/__init__.py b/torchmdnet/datasets/__init__.py index 92a70bd24..42dc7cf3f 100644 --- a/torchmdnet/datasets/__init__.py +++ b/torchmdnet/datasets/__init__.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from .ace import Ace +from .ace import Ace, AceHF from .ani import ANI1, ANI1CCX, ANI1X, ANI2X from .comp6 import ( ANIMD, @@ -28,6 +28,7 @@ __all__ = [ "Ace", + "AceHF", "ANIMD", "ANI1", "ANI1CCX", diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index ebb7da0b3..a5079c89f 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -7,7 +7,7 @@ import os import torch as pt from torchmdnet.datasets.memdataset import MemmappedDataset -from torch_geometric.data import Data +from torch_geometric.data import Data, Dataset from tqdm import tqdm @@ -291,3 +291,102 @@ def sample_iter(self, mol_ids=False): data = self.pre_transform(data) yield data + + +def download_gitea_dataset(path, tmpdir): + try: + from git import Repo + except ImportError: + raise ImportError( + "Could not import GitPython library. Please install it first with `pip install GitPython`" + ) + + assert path.startswith("ssh://") + + # Parse the gitea URL + pieces = path.split("/") + repo_url = "/".join(pieces[:5]) + user = pieces[3] + repo_name = pieces[4] + file_name = pieces[-1] + branch = "main" + commit = None + if "branch" in pieces: + branch = pieces[pieces.index("branch") + 1] + if "commit" in pieces: + commit = pieces[pieces.index("commit") + 1] + + outdir = os.path.join(tmpdir, f"{user}_{repo_name}") + if not os.path.exists(outdir): + repo = Repo.clone_from(repo_url, outdir, no_checkout=True) + else: + repo = Repo(outdir) + + origin = repo.remotes.origin + origin.pull() + if commit is not None: + repo.git.checkout(commit) + else: + repo.git.checkout(branch) + + return os.path.join(outdir, file_name) + + +class AceHF(Dataset): + def __init__( + self, root="parquet", paths=None, split="train", max_gradient=None + ) -> None: + from datasets import load_dataset + import numpy as np + + # Handle gitea parquet datasets + newpaths = paths.copy() + for i, path in enumerate(paths): + if "gitea" in path: + newpaths[i] = download_gitea_dataset(path, "/tmp") + + self.dataset = load_dataset(root, data_files=newpaths, split=split) + if max_gradient is not None: + + def _filter(x): + if np.isnan(x["forces"]).any() or np.isnan(x["formation_energy"]).any(): + return False + return np.max(np.linalg.norm(x["forces"], axis=1)) < max_gradient + + self.dataset = self.dataset.filter( + _filter, desc="Filtering", num_proc=os.cpu_count() // 2 + ) + self.dataset = self.dataset.with_format("torch") + + def __len__(self): + return self.dataset.num_rows + + def __getitem__(self, idx): + """Gets the data object at index :obj:`idx`. + + The data object contains the following attributes: + + - :obj:`z`: Atomic numbers of the atoms. + - :obj:`pos`: Positions of the atoms. + - :obj:`y`: Formation energy of the molecule. + - :obj:`neg_dy`: Forces on the atoms. + - :obj:`q`: Total charge of the molecule. + - :obj:`pq`: Partial charges of the atoms. + - :obj:`dp`: Dipole moment of the molecule. + + Args: + idx (int): Index of the data object. + + Returns: + :obj:`torch_geometric.data.Data`: The data object. + """ + data = self.dataset[int(idx)] + return Data( + z=data["atomic_numbers"], + pos=data["positions"], + y=data["formation_energy"].view(1, 1), + neg_dy=data["forces"], + q=sum(data["formal_charges"]), + pq=data["partial_charges"], + dp=data["dipole_moment"], + )