Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
[![PyPI version](https://img.shields.io/pypi/v/shepherd-score.svg)](https://pypi.org/project/shepherd-score/)
[![Python versions](https://img.shields.io/pypi/pyversions/shepherd-score.svg)](https://pypi.org/project/shepherd-score/)
[![Documentation Status](https://readthedocs.org/projects/shepherd-score/badge/?version=latest)](https://shepherd-score.readthedocs.io/en/latest/?badge=latest)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/coleygroup/shepherd-score)


📄 **[Paper](https://arxiv.org/abs/2411.04130)** | 📚 **[Documentation](https://shepherd-score.readthedocs.io/en/latest/)** | 📦 **[PyPI](https://pypi.org/project/shepherd-score/)**

Expand Down
Binary file modified docs/performance/benchmark_results_linear.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/performance/benchmark_results_log.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 11 additions & 11 deletions docs/performance/timings.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ Multiprocessed `MoleculePairBatch` was evaluated for >100 pairs.
| Jax Volume | 0.30 | 2.05 | - | - |
| PyTorch (autodiff) | 0.39 | 6.35 | - | - |
| PyTorch (analytical) | 1.00 | 12.54 | - | - |
| Jax Batch (1 cpus, 1 buckets) | **15.54** | **59.57** | 103.69 | 110.59 |
| Jax Batch (1 cpus, 4 buckets) | <u>5.63</u> | <u>42.63</u> | 101.42 | 103.18 |
| Jax Batch (1 cpus, 8 buckets) | 3.08 | 25.01 | 93.00 | 120.69 |
| Jax Batch (1 bucket) | **15.54** | **59.57** | 103.69 | 110.59 |
| Jax Batch (4 buckets) | <u>5.63</u> | <u>42.63</u> | 101.42 | 103.18 |
| Jax Batch (8 buckets) | 3.08 | 25.01 | 93.00 | 120.69 |
| Jax Batch (4 cpus, 4 buckets) | - | 41.51 | **255.56** | <u>484.05</u> |
| Jax Batch (8 cpus, 8 buckets) | - | 31.99 | <u>225.60</u> | **506.64** |

Expand All @@ -56,9 +56,9 @@ Multiprocessed `MoleculePairBatch` was evaluated for >100 pairs.
| Jax Volume+ESP | <u>6.52</u> | 1.58 | - | - |
| PyTorch (autodiff) | 0.19 | 5.57 | - | - |
| PyTorch (analytical) | 0.29 | 9.84 | - | - |
| Jax Batch (1 cpus, 1 buckets) | **14.35** | **72.49** | 81.25 | 123.70 |
| Jax Batch (1 cpus, 4 buckets) | 4.54 | 37.55 | 81.90 | 121.04 |
| Jax Batch (1 cpus, 8 buckets) | 3.02 | 23.59 | 115.76 | 143.04 |
| Jax Batch (1 bucket) | **14.35** | **72.49** | 81.25 | 123.70 |
| Jax Batch (4 buckets) | 4.54 | 37.55 | 81.90 | 121.04 |
| Jax Batch (8 buckets) | 3.02 | 23.59 | 115.76 | 143.04 |
| Jax Batch (4 cpus, 4 buckets) | - | <u>40.92</u> | **299.10** | <u>468.43</u> |
| Jax Batch (8 cpus, 8 buckets) | - | 26.68 | <u>212.16</u> | **506.33** |

Expand All @@ -70,9 +70,9 @@ Multiprocessed `MoleculePairBatch` was evaluated for >100 pairs.
| Jax Pharmacophore (vectorized) | 3.73 | 2.00 | - | - |
| PyTorch (autodiff) | 3.77 | 3.68 | - | - |
| PyTorch (analytical) | <u>11.90</u> | 9.45 | - | - |
| Jax Batch (1 cpus, 1 buckets) | **14.15** | **47.61** | 68.13 | 67.26 |
| Jax Batch (1 cpus, 4 buckets) | 5.17 | <u>44.84</u> | 89.49 | 106.80 |
| Jax Batch (1 cpus, 8 buckets) | 4.25 | 2.29 | 95.04 | 164.53 |
| Jax Batch (1 bucket) | **14.15** | **47.61** | 68.13 | 67.26 |
| Jax Batch (4 buckets) | 5.17 | <u>44.84</u> | 89.49 | 106.80 |
| Jax Batch (8 buckets) | 4.25 | 2.29 | 95.04 | 164.53 |
| Jax Batch (4 cpus, 4 buckets) | - | 44.08 | **294.72** | <u>492.94</u> |
| Jax Batch (8 cpus, 8 buckets) | - | 26.46 | <u>216.67</u> | **532.53** |

Expand All @@ -83,7 +83,7 @@ Multiprocessed `MoleculePairBatch` was evaluated for >100 pairs.
| Jax Surface | <u>3.72</u> | <u>4.12</u> | - | - |
| PyTorch (autodiff) | 0.88 | 1.83 | - | - |
| PyTorch (analytical) | 1.78 | 3.70 | - | - |
| Jax Batch (1 cpus, 1 buckets) | **4.76** | **4.88** | **6.63** | <u>4.34</u> |
| Jax Batch (1 bucket) | **4.76** | **4.88** | **6.63** | <u>4.34</u> |
| Jax Batch (4 cpus, 4 buckets) | - | - | - | **6.68** |

### Surface + ESP alignment — alignments/s
Expand All @@ -93,7 +93,7 @@ Multiprocessed `MoleculePairBatch` was evaluated for >100 pairs.
| Jax Surface+ESP | 2.65 | <u>3.80</u> | - | - |
| PyTorch (autodiff) | 0.65 | 2.76 | - | - |
| PyTorch (analytical) | **3.02** | **4.57** | - | - |
| Jax Batch (1 cpus, 1 buckets) | <u>2.85</u> | 2.30 | **2.71** | - |
| Jax Batch (1 bucket) | <u>2.85</u> | 2.30 | **2.71** | - |

## Experiment results (semi-log plot)

Expand Down
12 changes: 8 additions & 4 deletions shepherd_score/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ def chimera_from_mol(mol: Chem.Mol,
ev_pos = None,
ev_vecs = None,
save_dir: str = './',
surf_point_size: float = 0.05,
verbose: bool = True,
) -> None:
"""
Expand All @@ -397,7 +398,7 @@ def chimera_from_mol(mol: Chem.Mol,
pharm_types = np.concatenate([np.zeros((len(ev_pos)), dtype=int) + 10, pharm_types], axis=0)

if surf_pos is not None and surf_esp is not None:
surf_bild = _chimera_shape_esp_file(surf_pos, surf_esp)
surf_bild = _chimera_shape_esp_file(surf_pos, surf_esp, surf_point_size=surf_point_size)
with open(save_dir_ / f'{mol_id}_x3.bild', 'w') as f:
f.write(surf_bild)
if verbose:
Expand Down Expand Up @@ -453,6 +454,7 @@ def _chimera_pharmacophore_file(pharm_types: np.ndarray, pharm_pos: np.ndarray,
def _chimera_shape_esp_file(surf_pos: np.ndarray,
surf_esp: np.ndarray,
norm_factor: float = 2.0,
surf_point_size: float = 0.05,
) -> str:
esp = surf_esp * 4.0
esp_pos = surf_pos
Expand All @@ -470,7 +472,7 @@ def _chimera_shape_esp_file(surf_pos: np.ndarray,
bild += f'.transparency {0.9}\n'
else:
bild += f'.transparency {0.0}\n'
bild += f'.sphere {p[0]} {p[1]} {p[2]} 0.05\n'
bild += f'.sphere {p[0]} {p[1]} {p[2]} {surf_point_size}\n'

return bild

Expand All @@ -480,6 +482,7 @@ def chimera_from_sample(generated_sample: dict,
save_dir: str,
model_type: Literal['all', 'x2', 'x3', 'x4'] = 'all',
esp_norm_factor: float = 2.0,
surf_point_size: float = 0.05,
verbose: bool = True,
) -> None:
"""
Expand Down Expand Up @@ -513,7 +516,7 @@ def chimera_from_sample(generated_sample: dict,
pharm_types = np.zeros((len(dummy_atom_pos))) + 9

if surf_pos is not None and surf_esp is not None:
esp_bild = _chimera_shape_esp_file(surf_pos, surf_esp, norm_factor=esp_norm_factor)
esp_bild = _chimera_shape_esp_file(surf_pos, surf_esp, norm_factor=esp_norm_factor, surf_point_size=surf_point_size)
with open(path_ / f'{mol_id}_x3.bild', 'w') as f:
f.write(esp_bild)
if verbose:
Expand All @@ -532,6 +535,7 @@ def chimera_from_molecule(molec: Molecule,
mol_id: str | int,
save_dir: str,
esp_norm_factor: float = 2.0,
surf_point_size: float = 0.05,
verbose: bool = True,
) -> None:
"""
Expand Down Expand Up @@ -565,7 +569,7 @@ def chimera_from_molecule(molec: Molecule,
molec.pharm_vecs = np.zeros((len(dummy_atom_pos), 3))

if molec.surf_pos is not None and molec.surf_esp is not None:
esp_bild = _chimera_shape_esp_file(molec.surf_pos, molec.surf_esp, norm_factor=esp_norm_factor)
esp_bild = _chimera_shape_esp_file(molec.surf_pos, molec.surf_esp, norm_factor=esp_norm_factor, surf_point_size=surf_point_size)
with open(path_ / f'{mol_id}_x3.bild', 'w') as f:
f.write(esp_bild)
if verbose:
Expand Down
81 changes: 52 additions & 29 deletions tests/test_batch_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,49 +240,72 @@ def test_masked_pharm_scoring_matches_unmasked():
@pytest.mark.jax
def test_masked_pharm_alignment_matches_unmasked():
"""optimize_pharm_overlay_jax_vectorized_mask should reach a score similar to the unmasked version."""
import rdkit.Chem as Chem
from rdkit.Chem import AllChem
from shepherd_score.container import Molecule
rng = np.random.default_rng(42)
n_types = 8 # real types (0-7); 8 = Dummy
# Acceptor=0, Donor=1, Aromatic=2, Halogen=4 are directional (unit vectors)
# Hydrophobe=3, Cation=5, Anion=6, ZnBinder=7 are non-directional (zero vectors)
DIRECTIONAL = np.array([0, 1, 2, 4])

m = Chem.MolFromSmiles("CC(=O)Nc1ccc(cc1)OCCCC")
m = Chem.AddHs(m)
AllChem.EmbedMolecule(m, AllChem.ETKDGv3())
mol = Molecule(m, pharm_multi_vector=False)
# Each array covers all 8 pharmacophore types exactly once
ptypes_1 = np.arange(n_types, dtype=np.int32)
ptypes_2 = np.arange(n_types, dtype=np.int32)
rng.shuffle(ptypes_1)
rng.shuffle(ptypes_2)
n1 = n2 = n_types

pt = jnp.array(mol.pharm_types)
ancs = jnp.array(mol.pharm_ancs)
vecs = jnp.array(mol.pharm_vecs)
ancs_1 = rng.standard_normal((n1, 3)).astype(np.float32)
ancs_2 = rng.standard_normal((n2, 3)).astype(np.float32)
raw_vecs_1 = rng.standard_normal((n1, 3)).astype(np.float32)
raw_vecs_2 = rng.standard_normal((n2, 3)).astype(np.float32)
# Unit vectors for directional types; zero vectors for non-directional
dir_mask_1 = np.isin(ptypes_1, DIRECTIONAL)[:, None]
dir_mask_2 = np.isin(ptypes_2, DIRECTIONAL)[:, None]
vecs_1 = np.where(dir_mask_1, raw_vecs_1 / np.linalg.norm(raw_vecs_1, axis=1, keepdims=True), 0.0).astype(np.float32)
vecs_2 = np.where(dir_mask_2, raw_vecs_2 / np.linalg.norm(raw_vecs_2, axis=1, keepdims=True), 0.0).astype(np.float32)

# Unmasked
_, _, _, score_unmasked = optimize_pharm_overlay_jax_vectorized(
pt, pt, ancs, ancs, vecs, vecs, num_repeats=10, max_num_steps=50
jnp.array(ptypes_1), jnp.array(ptypes_2),
jnp.array(ancs_1), jnp.array(ancs_2),
jnp.array(vecs_1), jnp.array(vecs_2),
num_repeats=10, max_num_steps=50,
)

# Pad
n = mol.pharm_types.shape[0]
pad = n + 4
# Pad to common size
pad = 12
DUMMY = 8

pt_pad = np.full(pad, DUMMY, dtype=np.int32)
pt_pad[:n] = np.array(pt)
ancs_pad = np.zeros((pad, 3), dtype=np.float32)
ancs_pad[:n] = np.array(ancs)
vecs_pad = np.zeros((pad, 3), dtype=np.float32)
vecs_pad[:n] = np.array(vecs)
mask = np.zeros(pad, dtype=np.float32)
mask[:n] = 1.0
pt1_pad = np.full(pad, DUMMY, dtype=np.int32)
pt1_pad[:n1] = ptypes_1
pt2_pad = np.full(pad, DUMMY, dtype=np.int32)
pt2_pad[:n2] = ptypes_2

a1_pad = np.zeros((pad, 3), dtype=np.float32)
a1_pad[:n1] = ancs_1
a2_pad = np.zeros((pad, 3), dtype=np.float32)
a2_pad[:n2] = ancs_2

v1_pad = np.zeros((pad, 3), dtype=np.float32)
v1_pad[:n1] = vecs_1
v2_pad = np.zeros((pad, 3), dtype=np.float32)
v2_pad[:n2] = vecs_2

mask1 = np.zeros(pad, dtype=np.float32)
mask1[:n1] = 1.0
mask2 = np.zeros(pad, dtype=np.float32)
mask2[:n2] = 1.0

_, _, _, score_masked = optimize_pharm_overlay_jax_vectorized_mask(
jnp.array(pt_pad), jnp.array(pt_pad),
jnp.array(ancs_pad), jnp.array(ancs_pad),
jnp.array(vecs_pad), jnp.array(vecs_pad),
jnp.array(mask), jnp.array(mask),
jnp.array(pt1_pad), jnp.array(pt2_pad),
jnp.array(a1_pad), jnp.array(a2_pad),
jnp.array(v1_pad), jnp.array(v2_pad),
jnp.array(mask1), jnp.array(mask2),
num_repeats=10, max_num_steps=50,
init_ref_anchors=np.array(ancs),
init_fit_anchors=np.array(ancs),
init_ref_anchors=ancs_1,
init_fit_anchors=ancs_2,
)

assert abs(float(score_unmasked) - float(score_masked)) < 1e-5, (
assert abs(float(score_unmasked) - float(score_masked)) < 1e-4, (
f"Masked pharm alignment score {float(score_masked):.4f} differs too much "
f"from unmasked {float(score_unmasked):.4f}"
)
Expand Down
Loading