Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion benchmarks/bench_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@
"""

from bench_utils.loaders import load_pickle, load_sdf, load_smarts, load_smiles
from bench_utils.molprep import clone_mols_with_conformers, prep_mols
from bench_utils.molprep import clone_mols_with_conformers, embed_and_jitter, perturb_conformer, prep_mols
from bench_utils.timing import TimingResult, time_it

__all__ = [
"TimingResult",
"clone_mols_with_conformers",
"embed_and_jitter",
"load_pickle",
"load_sdf",
"load_smarts",
"load_smiles",
"perturb_conformer",
"prep_mols",
"time_it",
]
103 changes: 103 additions & 0 deletions benchmarks/bench_utils/molprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

"""Molecule preparation helpers shared across nvMolKit benchmarks."""

import random
from functools import partial

from rdkit import Chem
from rdkit.Chem import rdDistGeom
from rdkit.Geometry import Point3D
from tqdm.contrib.concurrent import process_map


def prep_mols(
Expand Down Expand Up @@ -63,3 +69,100 @@ def clone_mols_with_conformers(mols: list[Chem.Mol]) -> list[Chem.RWMol]:
pristine input.
"""
return [Chem.RWMol(mol) for mol in mols]


def perturb_conformer(conf: Chem.Conformer, delta: float, seed: int) -> None:
"""Apply per-atom uniform jitter to a conformer in place.

Each x/y/z coordinate is shifted by ``delta * U(-delta, delta)``, so
``delta=0.5`` produces displacements bounded by 0.25 A. Matches the
``perturbConformer`` helper used by the C++ FF bench.
"""
rng = random.Random(seed)
for atom_idx in range(conf.GetNumAtoms()):
pos = conf.GetAtomPosition(atom_idx)
conf.SetAtomPosition(
atom_idx,
Point3D(
pos.x + delta * rng.uniform(-delta, delta),
pos.y + delta * rng.uniform(-delta, delta),
pos.z + delta * rng.uniform(-delta, delta),
),
)


def _embed_one(args_tuple: tuple[int, bytes], seed: int, add_hs: bool, min_atoms: int) -> bytes | None:
"""Embed a single ETKDGv3 conformer for one mol payload (multiprocessing worker)."""
idx, mol_bytes = args_tuple
mol = Chem.Mol(mol_bytes)
if mol.GetNumAtoms() < min_atoms:
return None
if add_hs:
mol = Chem.AddHs(mol)
params = rdDistGeom.ETKDGv3()
params.useRandomCoords = True
params.randomSeed = seed + idx
try:
conf_id = rdDistGeom.EmbedMolecule(mol, params=params)
except Exception:
return None
if conf_id < 0 or mol.GetNumConformers() == 0:
return None
if add_hs:
mol = Chem.RemoveHs(mol)
return mol.ToBinary()


def embed_and_jitter(
mols: list[Chem.Mol],
confs_per_mol: int,
seed: int,
num_workers: int = 1,
add_hs: bool = False,
min_atoms: int = 1,
delta: float = 0.5,
desc: str = "Embedding base conformers",
) -> list[Chem.Mol]:
"""Embed one ETKDGv3 base conformer per mol in parallel, then jitter to ``confs_per_mol``.

The embed step runs across mols via ``process_map``; the jitter step is
in-process and serial (cheap). Mols whose base embedding fails are
dropped with a printed count. When ``add_hs`` is true, hydrogens are
added before embedding and stripped from the returned mol.
"""
if not mols:
return []
if confs_per_mol < 1:
raise ValueError(f"confs_per_mol must be >= 1, got {confs_per_mol}")

workers = max(1, num_workers)
binaries = [(i, mol.ToBinary()) for i, mol in enumerate(mols)]
embedded_binaries = process_map(
partial(_embed_one, seed=seed, add_hs=add_hs, min_atoms=min_atoms),
binaries,
max_workers=workers,
chunksize=max(1, len(binaries) // (workers * 8) or 1),
Comment thread
scal444 marked this conversation as resolved.
desc=desc,
)

out: list[Chem.Mol] = []
drop_count = 0
for raw in embedded_binaries:
if raw is None:
drop_count += 1
continue
out.append(Chem.Mol(raw))
if drop_count > 0:
print(f" Dropped {drop_count} molecules during embedding (no conformer generated)")

if confs_per_mol > 1:
for mol_idx, mol in enumerate(out):
base_conf_id = mol.GetConformer().GetId()
base_conf = mol.GetConformer(base_conf_id)
for conf_idx in range(1, confs_per_mol):
new_conf = Chem.Conformer(base_conf)
perturb_conformer(new_conf, delta, seed=seed + mol_idx * confs_per_mol + conf_idx)
mol.AddConformer(new_conf, assignId=True)
perturb_conformer(mol.GetConformer(base_conf_id), delta, seed=seed + mol_idx * confs_per_mol)

return out
13 changes: 9 additions & 4 deletions benchmarks/conformer_rmsd_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import numpy as np
import torch
from bench_utils import perturb_conformer
from benchmark_timing import time_it
from rdkit import Chem
from rdkit.Chem import AllChem, rdDistGeom
Expand Down Expand Up @@ -87,12 +88,16 @@ def run_benchmark(smiles, num_confs_list, seed=42):
params = rdDistGeom.ETKDGv3()
params.randomSeed = seed
params.useRandomCoords = True
rdDistGeom.EmbedMultipleConfs(mol, numConfs=num_confs, params=params)
actual_confs = mol.GetNumConformers()

if actual_confs < 2:
if rdDistGeom.EmbedMolecule(mol, params=params) < 0:
print(f"{num_confs:>8} {'skipped (embedding failed)':>50}")
continue
Comment thread
scal444 marked this conversation as resolved.
base_conf_id = mol.GetConformer().GetId()
for conf_idx in range(1, num_confs):
new_conf = Chem.Conformer(mol.GetConformer(base_conf_id))
perturb_conformer(new_conf, 0.5, seed=seed + conf_idx)
mol.AddConformer(new_conf, assignId=True)
perturb_conformer(mol.GetConformer(base_conf_id), 0.5, seed=seed)
actual_confs = mol.GetNumConformers()

no_h = Chem.RemoveHs(mol)
n_pairs = actual_confs * (actual_confs - 1) // 2
Expand Down
29 changes: 2 additions & 27 deletions benchmarks/ff_optimize_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import torch
from bench_utils import (
clone_mols_with_conformers,
embed_and_jitter,
load_pickle,
load_sdf,
load_smiles,
Expand All @@ -53,32 +54,6 @@
OPTUNA_AVAILABLE = nv_autotune.is_available()


def _embed_conformers(mols: list[Chem.Mol], confs_per_mol: int, seed: int) -> list[Chem.Mol]:
"""Generate ``confs_per_mol`` conformers per molecule using RDKit ETKDGv3.

Molecules where embedding fails to produce at least one conformer are
dropped; a count is printed.
"""
params = rdDistGeom.ETKDGv3()
params.useRandomCoords = True
params.randomSeed = seed

embedded: list[Chem.Mol] = []
drop_count = 0
for mol in mols:
try:
conf_ids = rdDistGeom.EmbedMultipleConfs(mol, numConfs=confs_per_mol, params=params)
if not conf_ids:
drop_count += 1
continue
embedded.append(mol)
except Exception:
drop_count += 1
if drop_count > 0:
print(f" Dropped {drop_count} molecules during embedding (no conformer generated)")
return embedded


def _flatten_energies(per_mol: list[list[float]]) -> list[float]:
"""Flatten ``[[e0, e1, ...], [e0, e1, ...], ...]`` returned by nvmolkit."""
flat: list[float] = []
Expand Down Expand Up @@ -398,7 +373,7 @@ def main() -> None:
print(f" {len(mols)} molecules ready")

print(f"\nEmbedding {args.confs_per_mol} conformer(s) per molecule with RDKit ETKDGv3...")
mols = _embed_conformers(mols, args.confs_per_mol, args.seed)
mols = embed_and_jitter(mols, args.confs_per_mol, seed=args.seed, num_workers=args.rdkit_threads)
if not mols:
print("Error: No molecules retained after embedding")
sys.exit(1)
Expand Down
Loading
Loading