Changes from all commits (28 commits)
7335215
update: include cyclic offset option
ioannisa92 Feb 25, 2025
709e8b1
chore: typing
ioannisa92 Feb 25, 2025
92aa420
update: create cyclic mask based on tag
ioannisa92 Feb 25, 2025
a939968
update: include cyclic mask in requried features
ioannisa92 Feb 25, 2025
f35b903
update: add cyclic offset to sequence features
ioannisa92 Feb 25, 2025
344c7b2
update: cyclic offset method
ioannisa92 Feb 25, 2025
5fe4379
update: cyclic offset config dimensions
ioannisa92 Feb 25, 2025
dc6a25c
update: cyclic offset option
ioannisa92 Feb 25, 2025
a1684e1
update: cyclic offset option in process_fasta
ioannisa92 Feb 25, 2025
c3661f5
update: cyclic-mask in feats config
ioannisa92 Mar 3, 2025
e51d589
update: cyclic offset option in DataPipeline
ioannisa92 Mar 3, 2025
597d3c0
update: cyclic mask in required feats names
ioannisa92 Mar 3, 2025
883184d
update: cyclic offset calculation in InputEmbedder
ioannisa92 Mar 3, 2025
d3f0a25
update: optional arg cycliuc mask
ioannisa92 Mar 3, 2025
d33b6db
fix: list input of fasta tags for cyclic offset
ioannisa92 Mar 5, 2025
cd0aa76
chore: remove abs from cyclic offset
ioannisa92 Mar 19, 2025
9c2785a
update: instruction on using cyclic offset
ioannisa92 Mar 21, 2025
9ac0dd7
refactor: argparse native list input
ioannisa92 Mar 21, 2025
901cf71
chore: format
ioannisa92 Mar 21, 2025
a2af528
chore: format
ioannisa92 Mar 21, 2025
cb54c8d
chore: typo
ioannisa92 Mar 21, 2025
92d8172
update: reference to AfCycDesign
ioannisa92 Mar 21, 2025
a4d8290
chore: fasta tags
ioannisa92 Mar 21, 2025
b8b3cf7
style: format of add arg parse
mrJeppard Mar 21, 2025
97dd28d
update: optional flags with cyclic_offset
ioannisa92 Mar 21, 2025
6ad1647
chore: cleaner cyclic mask
ioannisa92 Mar 21, 2025
19c4066
chore: cleanup markdown instruction
ioannisa92 Mar 22, 2025
a40e233
update: cyclic offset note about relaxed/unrelaxed output
ioannisa92 Apr 17, 2025
9 changes: 7 additions & 2 deletions docs/source/Inference.md
@@ -67,7 +67,8 @@ python3 run_pretrained_openfold.py \
--pdb70_database_path $BASE_DATA_DIR/pdb70 \
--uniclust30_database_path $BASE_DATA_DIR/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--bfd_database_path $BASE_DATA_DIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--model_device "cuda:0"
--model_device "cuda:0" \
--cyclic_offset FASTA-tag1 FASTA-tag2 ... FASTA-tagN
```

**Required arguments:**
@@ -138,7 +139,11 @@ Some commonly used command line flags are here. A full list of flags can be view
- `--data_random_seed`: Specifies a random seed to use.
- `--save_outputs`: Saves a copy of all outputs from the model, e.g. the output of the msa track, ptm heads.
- `--experiment_config_json`: Specify configuration settings using a json file. For example, passing a json with `{globals.relax.max_iterations = 10}` specifies 10 as the maximum number of relaxation iterations. See [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L283) for the full dictionary of configuration settings. Any parameters that are not manually set in these configuration settings will refer to the defaults specified by your `config_preset`.

- `--cyclic_offset`: Specifies a list of FASTA tags for cyclic peptides, e.g. `--cyclic_offset FASTA-tag1 FASTA-tag2 ... FASTA-tagN`. When the list is not empty, OpenFold applies a cyclic end-to-end offset to the listed sequences instead of the default linear offset, so each sequence is treated as a cyclic species rather than a linear one. It is recommended to use the unrelaxed output with this option, as we have observed worse cyclization performance with the relaxed output. Refer to the [AfCycDesign preprint](https://www.biorxiv.org/content/10.1101/2023.02.25.529956v1.full) for the original implementation and an explanation of the cyclic offset.
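The effect of the flag can be sketched in a few lines of NumPy. This is an illustrative reconstruction of the end-to-end offset described in the AfCycDesign preprint, not the exact OpenFold code (which operates on torch tensors inside the `InputEmbedder`); `cyclic_offset_matrix` is a name invented for this sketch.

```python
import numpy as np

def cyclic_offset_matrix(n: int) -> np.ndarray:
    """Signed residue-pair offsets for a cyclic peptide of length n.

    Each row is a rolled copy of a base row whose second half counts back
    down, so the first and last residues end up one step apart.
    """
    base = -np.arange(n)                       # 0, -1, -2, ...
    centre = int(round(n / 2))
    base[centre + 1:] = np.arange(n - centre - 1, 0, -1)
    return np.stack([np.roll(base, i) for i in range(n)])

n = 8
ri = np.arange(n)
linear = ri[:, None] - ri[None, :]             # default linear offset
cyclic = cyclic_offset_matrix(n)
print(linear[0, n - 1])  # -7: termini are far apart under the linear offset
print(cyclic[0, n - 1])  # 1: termini are adjacent under the cyclic offset
```

The pair representation therefore "sees" residue 1 and residue N as neighbors, which is what encourages head-to-tail cyclization.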

### Advanced Options for Increasing Efficiency

12 changes: 10 additions & 2 deletions docs/source/original_readme.md
@@ -162,7 +162,8 @@ python3 run_pretrained_openfold.py \
--config_preset "model_1_ptm" \
--model_device "cuda:0" \
--output_dir ./ \
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt \
--cyclic_offset FASTA-tag1 FASTA-tag2 ... FASTA-tagN
```

where `data` is the same directory as in the previous step. If `jackhmmer`,
@@ -182,6 +183,12 @@ OpenFold was trained under a newer training schedule than the one from which the
`*_ptm` checkpoints must be run with `*_ptm` config presets and that `_no_templ_`
checkpoints are only compatible with template-less presets (`model_3` and above).

`--cyclic_offset` accepts a list of sequence FASTA tags. When the list is not empty, OpenFold applies a cyclic end-to-end offset to the listed sequences instead of the default linear offset, so each sequence is treated as a cyclic species rather than a linear one.
Note: the cyclization bond is not reliably retained through the relaxation step, so we recommend using the unrelaxed structure output. Refer to the [AfCycDesign preprint](https://www.biorxiv.org/content/10.1101/2023.02.25.529956v1.full) for the original implementation and an explanation of the cyclic offset.
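For concreteness, here is a hypothetical FASTA and the matching tag (the file name and tags are invented for illustration, and the `run_pretrained_openfold.py` invocation is abbreviated and shown only as a comment):

```shell
# Two-record FASTA; only the sequence tagged "cyclic_pep" should be cyclized.
cat > example.fasta <<'EOF'
>cyclic_pep
GNSDCASFLE
>linear_pep
MKTAYIAKQR
EOF

# The tags passed to --cyclic_offset must match the description lines exactly:
#   python3 run_pretrained_openfold.py example.fasta ... --cyclic_offset cyclic_pep
grep '^>' example.fasta
```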

Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
is enabled by default in inference mode. To disable it, set `globals.chunk_size`
to `None` in the config. If a value is specified, OpenFold will attempt to
@@ -263,7 +270,8 @@ python3 run_pretrained_openfold.py \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
--config_preset "model_1_multimer_v3" \
--model_device "cuda:0" \
--output_dir ./
--output_dir ./ \
--cyclic_offset FASTA-tag1 FASTA-tag2 ... FASTA-tagN
```

As with monomer inference, if you've already computed alignments for the query, you can use
4 changes: 4 additions & 0 deletions openfold/config.py
@@ -290,6 +290,7 @@ def model_config(
"common": {
"feat": {
"aatype": [NUM_RES],
"cyclic_mask": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
"alt_chi_angles": [NUM_RES, None],
@@ -383,6 +384,7 @@
"between_segment_residues",
"deletion_matrix",
"no_recycling_iters",
"cyclic_mask",
],
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
@@ -748,6 +750,7 @@ def model_config(
"common": {
"feat": {
"aatype": [NUM_RES],
"cyclic_mask": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
# "all_chains_entity_ids": [], # TODO: Resolve missing features, remove processed msa feats
@@ -808,6 +811,7 @@ def model_config(
"asym_id",
"entity_id",
"sym_id",
"cyclic_mask"
]
},
"supervised": {
11 changes: 9 additions & 2 deletions openfold/data/data_pipeline.py
@@ -20,7 +20,7 @@
import dataclasses
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
from typing import List, Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np
import torch
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
@@ -854,6 +854,7 @@ def process_fasta(
alignment_dir: str,
alignment_index: Optional[Any] = None,
seqemb_mode: bool = False,
cyclic_offset: Optional[List[str]] = []
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
@@ -885,6 +886,9 @@
num_res=num_res,
)

n_residue_index = sequence_features['residue_index'].shape[0]
sequence_features['cyclic_mask'] = (np.ones(n_residue_index)*(input_description in cyclic_offset)).astype(np.bool_)

sequence_embedding_features = {}
# If using seqemb mode, generate a dummy MSA features using just the sequence
if seqemb_mode:
@@ -1228,7 +1232,8 @@ def read_msa(start, size):
def process_fasta(self,
fasta_path: str,
alignment_dir: str,
alignment_index: Optional[Any] = None
alignment_index: Optional[Any] = None,
cyclic_offset: Optional[List[str]] = []
) -> FeatureDict:
"""Creates features."""
with open(fasta_path) as f:
@@ -1266,6 +1271,8 @@ def process_fasta(self,
chain_features,
chain_id=desc
)

chain_features['cyclic_mask'] = (np.ones(chain_features['seq_length'])*(desc in cyclic_offset)).astype(np.bool_)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features

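The `cyclic_mask` feature added in this file is just a per-residue boolean vector that is all-True exactly when the record's description appears in the `--cyclic_offset` list. A standalone sketch of the same construction (the function name is ours, not OpenFold's):

```python
import numpy as np

def make_cyclic_mask(desc: str, num_res: int, cyclic_tags) -> np.ndarray:
    """All-True per-residue mask iff this sequence's FASTA tag was listed."""
    # (desc in cyclic_tags) is 0 or 1, broadcast over every residue position.
    return (np.ones(num_res) * (desc in cyclic_tags)).astype(np.bool_)

print(make_cyclic_mask("cyclic_pep", 4, ["cyclic_pep"]))  # [ True  True  True  True]
print(make_cyclic_mask("linear_pep", 4, ["cyclic_pep"]))  # [False False False False]
```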
1 change: 1 addition & 0 deletions openfold/data/data_transforms.py
@@ -144,6 +144,7 @@ def squeeze_features(protein):
"between_segment_residues",
"residue_index",
"template_all_atom_mask",
"cyclic_mask",
]:
if k in protein:
final_dim = protein[k].shape[-1]
2 changes: 1 addition & 1 deletion openfold/data/feature_processing_multimer.py
@@ -31,7 +31,7 @@
'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
'num_templates', 'queue_size', 'residue_index', 'resolution',
'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
'template_all_atom_mask', 'template_all_atom_positions'
'template_all_atom_mask', 'template_all_atom_positions', 'cyclic_mask'
})

MAX_TEMPLATES = 4
2 changes: 1 addition & 1 deletion openfold/data/msa_pairing.py
@@ -47,7 +47,7 @@
'sym_id', 'entity_mask', 'deletion_mean',
'prediction_atom_mask',
'literature_positions', 'atom_indices_to_group_indices',
'rigid_group_default_frame')
'rigid_group_default_frame', 'cyclic_mask')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')
65 changes: 62 additions & 3 deletions openfold/model/embedders.py
@@ -82,7 +82,30 @@ def __init__(
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)

def relpos(self, ri: torch.Tensor):
def cyclic_offset(self, residue_index: torch.Tensor) -> torch.Tensor:
"""Calculate the cyclic offset for the given residue index.

Parameters
----------
residue_index : torch.Tensor
The residue index tensor.

Returns
-------
torch.Tensor
The cyclic offset tensor.
"""
peptide_length = residue_index.shape[0]
cyclic_offset_array = torch.zeros((peptide_length, peptide_length))
cyc_row = torch.arange(0, -peptide_length, -1)
pc = int(torch.round(torch.tensor(peptide_length / 2))) # Get centre
cyc_row[pc + 1 :] = torch.arange(len(cyc_row[pc + 1 :]), 0, -1)
for i in range(len(cyclic_offset_array)):
cyclic_offset_array[i] = torch.roll(cyc_row, i)
return cyclic_offset_array


def relpos(self, ri: torch.Tensor, cyclic_mask: Optional[torch.Tensor] = None):
"""
Computes relative positional encodings

@@ -93,6 +116,9 @@ def relpos(self, ri: torch.Tensor):
"residue_index" features of shape [*, N]
"""
d = ri[..., None] - ri[..., None, :]
if cyclic_mask is not None and sum(cyclic_mask)!=0:
d = self.cyclic_offset(ri).type(torch.long).to(d.device)

boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
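Downstream of the swapped offset matrix, `relpos` still clamps offsets into `[-relpos_k, relpos_k]` and one-hot encodes them before the `linear_relpos` projection. A simplified sketch of that binning (not the exact OpenFold implementation):

```python
import torch
import torch.nn.functional as F

relpos_k = 32
ri = torch.arange(10)
d = ri[:, None] - ri[None, :]                       # signed offsets (linear case)

boundaries = torch.arange(-relpos_k, relpos_k + 1)  # 2 * relpos_k + 1 bins
# Map each offset to its nearest bin (equivalent to clamping for these values),
# then one-hot encode for the linear_relpos layer.
bin_index = torch.argmin((d[..., None] - boundaries).abs(), dim=-1)
one_hot = F.one_hot(bin_index, num_classes=2 * relpos_k + 1)
print(one_hot.shape)  # torch.Size([10, 10, 65])
```

With a cyclic offset substituted for `d`, the same binning simply encodes the wrapped distances instead.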
@@ -110,6 +136,7 @@ def forward(
ri: torch.Tensor,
msa: torch.Tensor,
inplace_safe: bool = False,
cyclic_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -132,7 +159,7 @@
tf_emb_j = self.linear_tf_z_j(tf)

# [*, N_res, N_res, c_z]
pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
pair_emb = self.relpos(ri.type(tf_emb_i.dtype), cyclic_mask=cyclic_mask)
pair_emb = add(pair_emb,
tf_emb_i[..., None, :],
inplace=inplace_safe
@@ -211,12 +238,44 @@ def __init__(
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)


def cyclic_offset(self, residue_index: torch.Tensor) -> torch.Tensor:
"""Calculate the cyclic offset for the given residue index.

Parameters
----------
residue_index : torch.Tensor
The residue index tensor.

Returns
-------
torch.Tensor
The cyclic offset tensor.
"""
peptide_length = residue_index.shape[0]
cyclic_offset_array = torch.zeros((peptide_length, peptide_length))
cyc_row = torch.arange(0, -peptide_length, -1)
pc = int(torch.round(torch.tensor(peptide_length / 2))) # Get centre
cyc_row[pc + 1 :] = torch.arange(len(cyc_row[pc + 1 :]), 0, -1)
for i in range(len(cyclic_offset_array)):
cyclic_offset_array[i] = torch.roll(cyc_row, i)
return cyclic_offset_array

def relpos(self, batch):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = (asym_id[..., None] == asym_id[..., None, :])
offset = pos[..., None] - pos[..., None, :]

if sum(batch['cyclic_mask'])!=0:
cyclic_entities = torch.unique(batch['entity_id'][batch['cyclic_mask']])
for cyclic_entity in cyclic_entities:
entity_mask = batch['entity_id'] == cyclic_entity
entity_idx = torch.where(batch['entity_id']==cyclic_entity)[0]
cyclic_pos = pos[entity_mask]
cyclic_offset = self.cyclic_offset(cyclic_pos).type(torch.long)
offset[entity_idx,entity_idx.view(-1,1)] = cyclic_offset.to(offset.device)


clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
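In the multimer embedder, only the residues belonging to a cyclic entity receive the wrapped offsets; every other chain keeps the ordinary position difference. A simplified, self-contained sketch of that per-entity overwrite (the indexing here is the plain row/column form, and the helper mirrors the `cyclic_offset` method above; it is not the production code):

```python
import torch

def cyclic_offset_matrix(n: int) -> torch.Tensor:
    base = torch.arange(0, -n, -1)
    centre = int(round(n / 2))
    base[centre + 1:] = torch.arange(n - centre - 1, 0, -1)
    return torch.stack([torch.roll(base, i) for i in range(n)])

# Toy complex: entity 1 (5 residues, cyclic) + entity 2 (4 residues, linear).
entity_id = torch.tensor([1, 1, 1, 1, 1, 2, 2, 2, 2])
cyclic_mask = entity_id == 1
pos = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3])

offset = pos[:, None] - pos[None, :]
for ent in torch.unique(entity_id[cyclic_mask]):
    idx = torch.where(entity_id == ent)[0]
    # Overwrite only this entity's block of the pairwise offset matrix.
    offset[idx[:, None], idx[None, :]] = cyclic_offset_matrix(len(idx))
print(offset[0, 4].item())  # 1: first/last residues of the cycle are one apart
```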
1 change: 1 addition & 0 deletions openfold/model/model.py
@@ -255,6 +255,7 @@ def iteration(self, feats, prevs, _recycle=True):
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
cyclic_mask = feats["cyclic_mask"]
)

# Unpack the recycling embeddings. Removing them from the list allows
10 changes: 7 additions & 3 deletions run_pretrained_openfold.py
@@ -139,7 +139,7 @@ def generate_feature_dict(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir, cyclic_offset=args.cyclic_offset
)
elif len(seqs) == 1:
tag = tags[0]
@@ -151,7 +151,7 @@
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path,
alignment_dir=local_alignment_dir,
seqemb_mode=args.use_single_seq_mode,
seqemb_mode=args.use_single_seq_mode, cyclic_offset=args.cyclic_offset
)
else:
with open(tmp_fasta_path, "w") as fp:
@@ -175,7 +175,6 @@ def list_files_with_extensions(dir, extensions):
def main(args):
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)

if args.config_preset.startswith("seq"):
args.use_single_seq_mode = True

@@ -475,6 +474,11 @@ def main(args):
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
)
parser.add_argument(
'--cyclic_offset', metavar='N', type=str, nargs='*', default=[],
help="Space-separated list of sequence tags to apply cyclic offset to"
)

add_data_args(parser)
args = parser.parse_args()

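The new flag relies on argparse's native list handling (`nargs='*'` with an empty-list default), so the tags arrive as a ready-made Python list with no string splitting. A minimal stand-in parser illustrating the behavior (this is not the full OpenFold argument set):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cyclic_offset", metavar="N", type=str, nargs="*", default=[],
    help="Space-separated list of sequence tags to apply cyclic offset to",
)

# Space-separated values after the flag are collected into one list.
args = parser.parse_args(["--cyclic_offset", "pep1", "pep2"])
print(args.cyclic_offset)                    # ['pep1', 'pep2']
print(parser.parse_args([]).cyclic_offset)   # []  (flag omitted -> default)
```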