Skip to content

Commit

Permalink
farewell mkdssp
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 29, 2024
1 parent 8975786 commit c48b69b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 127 deletions.
124 changes: 3 additions & 121 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5291,8 +5291,6 @@ def __init__(
contact_mask_threshold: float = 8.0,
is_fine_tuning: bool = False,
weight_dict_config: dict = None,
dssp_path: str = "mkdssp",
use_inhouse_rsa_calculation: bool = False
):
super().__init__()
self.compute_confidence_score = ComputeConfidenceScore(eps=eps)
Expand All @@ -5306,10 +5304,6 @@ def __init__(
self.register_buffer("dist_breaks", dist_breaks)
self.register_buffer('lddt_thresholds', torch.tensor([0.5, 1.0, 2.0, 4.0]))

self.dssp_path = dssp_path

self.use_inhouse_rsa_calculation = use_inhouse_rsa_calculation

atom_type_radii = tensor([
1.65, # 0 - nitrogen
1.87, # 1 - carbon alpha
Expand All @@ -5328,22 +5322,6 @@ def __init__(

self.register_buffer('atom_radii', atom_type_radii, persistent = False)

@property
def is_mkdssp_available(self):
"""Check if `mkdssp` is available.
:return: True if `mkdssp` is available
"""
try:
sh.which(self.dssp_path)
return True
except sh.ErrorReturnCode_1:
return False

@property
def can_calculate_unresolved_protein_rasa(self):
return self.is_mkdssp_available or self.use_inhouse_rsa_calculation

@typecheck
def compute_gpde(
self,
Expand Down Expand Up @@ -5633,92 +5611,6 @@ def compute_weighted_lddt(

return weighted_lddt

@typecheck
def _compute_unresolved_rasa(
self,
unresolved_cid: int,
unresolved_residue_mask: Bool[" n"],
asym_id: Int[" n"],
molecule_ids: Int[" n"],
molecule_atom_lens: Int[" n"],
atom_pos: Float["m 3"],
atom_mask: Bool[" m"],
) -> Float[""]:
"""Compute the unresolved relative solvent accessible surface area (RASA) for proteins.
unresolved_cid: asym_id for protein chains with unresolved residues
unresolved_residue_mask: True for unresolved residues, False for resolved residues
asym_id: asym_id for each residue
molecule_ids: molecule_ids for each residue
molecule_atom_lens: number of atoms for each residue
atom_pos: [m 3] atom positions
atom_mask: True for valid atoms, False for missing/padding atoms
:return: unresolved RASA
"""

assert self.can_calculate_unresolved_protein_rasa, "`mkdssp` needs to be installed"

residue_constants = get_residue_constants(res_chem_index=IS_PROTEIN)

device = atom_pos.device
dtype = atom_pos.dtype
num_atom = atom_pos.shape[0]

chain_mask = asym_id == unresolved_cid
chain_unresolved_residue_mask = unresolved_residue_mask[chain_mask]
chain_asym_id = asym_id[chain_mask]
chain_molecule_ids = molecule_ids[chain_mask]
chain_molecule_atom_lens = molecule_atom_lens[chain_mask]

chain_mask_to_atom = torch.repeat_interleave(chain_mask, molecule_atom_lens)

# if there's padding in num atom
num_pad = num_atom - molecule_atom_lens.sum()
if num_pad > 0:
chain_mask_to_atom = F.pad(chain_mask_to_atom, (0, num_pad), value=False)

chain_atom_pos = atom_pos[chain_mask_to_atom]
chain_atom_mask = atom_mask[chain_mask_to_atom]

structure = protein_structure_from_feature(
chain_asym_id,
chain_molecule_ids,
chain_molecule_atom_lens,
chain_atom_pos,
chain_atom_mask,
)

with tempfile.NamedTemporaryFile(suffix=".pdb", delete=True) as temp_file:
temp_file_path = temp_file.name

pdb_writer = PDBIO()
pdb_writer.set_structure(structure)
pdb_writer.save(temp_file_path)
dssp = DSSP(structure[0], temp_file_path, dssp=self.dssp_path)
dssp_dict = dict(dssp)

rasa = []
aatypes = []
for residue in structure.get_residues():
rsa = float(dssp_dict.get((residue.get_full_id()[2], residue.id))[3])
rasa.append(rsa)

aatype = dssp_dict.get((residue.get_full_id()[2], residue.id))[1]
aatypes.append(residue_constants.restype_order[aatype])

rasa = torch.tensor(rasa, dtype=dtype, device=device)
aatypes = torch.tensor(aatypes, device=device).int()

unresolved_aatypes = aatypes[chain_unresolved_residue_mask]
unresolved_molecule_ids = chain_molecule_ids[chain_unresolved_residue_mask]

assert torch.equal(
unresolved_aatypes, unresolved_molecule_ids
), "aatype not match for input feature and structure"
unresolved_rasa = rasa[chain_unresolved_residue_mask]

return unresolved_rasa.mean()

@typecheck
def calc_atom_access_surface_score_from_structure(
self,
Expand Down Expand Up @@ -5758,7 +5650,7 @@ def calc_atom_access_surface_score(
atom_pos: Float['m 3'],
atom_type: Int['m'],
molecule_atom_lens: Int['n'] | None = None,
fibonacci_sphere_n = 200, # they use 200 in mkdssp, but can be tailored for efficiency
fibonacci_sphere_n = 200, # more points equal better approximation at cost of compute
atom_distance_min_thres = 1e-4
) -> Float['m'] | Float['n']:

Expand Down Expand Up @@ -5852,7 +5744,7 @@ def calc_atom_access_surface_score(
return rasa

@typecheck
def _inhouse_compute_unresolved_rasa(
def _compute_unresolved_rasa(
self,
unresolved_cid: int,
unresolved_residue_mask: Bool["n"],
Expand All @@ -5876,8 +5768,6 @@ def _inhouse_compute_unresolved_rasa(
:return: unresolved RASA
"""

assert self.can_calculate_unresolved_protein_rasa, "`mkdssp` needs to be installed"

num_atom = atom_pos.shape[0]

chain_mask = asym_id == unresolved_cid
Expand Down Expand Up @@ -5949,16 +5839,8 @@ def compute_unresolved_rasa(
weight = weight_dict.get("unresolved", {}).get("unresolved", None)
assert weight, "Weight not found for unresolved"

# for migrating to a rewritten computation of RSA
# to remove mkdssp dependency for model selection

if self.use_inhouse_rsa_calculation:
compute_unresolved_rasa_function = self._inhouse_compute_unresolved_rasa
else:
compute_unresolved_rasa_function = self._compute_unresolved_rasa

unresolved_rasa = [
compute_unresolved_rasa_function(*args)
self._compute_unresolved_rasa(*args)
for args in zip(
unresolved_cid,
unresolved_residue_mask,
Expand Down
7 changes: 1 addition & 6 deletions tests/test_af3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,12 +1501,7 @@ def test_unresolved_protein_rasa():

unresolved_residue_mask = torch.randint(0, 2, asym_id.shape).bool()

compute_model_selection_score = ComputeModelSelectionScore(
use_inhouse_rsa_calculation = True
)

if not compute_model_selection_score.can_calculate_unresolved_protein_rasa:
pytest.skip("mkdssp not available for calculating unresolved protein rasa")
compute_model_selection_score = ComputeModelSelectionScore()

# only test with protein

Expand Down

0 comments on commit c48b69b

Please sign in to comment.