-
Notifications
You must be signed in to change notification settings - Fork 93
fix: chain template alignments auth labelling (inference) #117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
561018c
0787889
c3c14a5
080d0d0
e57d303
865ec86
b29f1f6
a78196f
be1fc4e
6f7d487
126bcca
11420df
65c292f
fbc389a
8450fa8
4a26bc6
78570ff
f16b160
bf286be
f5cb6b9
1364f45
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,7 +37,11 @@ | |
| from openfold3.core.config import config_utils | ||
| from openfold3.core.data.io.sequence.msa import parse_a3m | ||
| from openfold3.core.data.primitives.sequence.hash import get_sequence_hash | ||
| from openfold3.core.data.primitives.structure.metadata import ( | ||
| get_author_to_label_chain_ids, | ||
| ) | ||
| from openfold3.core.data.resources.residues import MoleculeType | ||
| from openfold3.core.data.tools.rscb import fetch_label_to_author_chain_ids | ||
| from openfold3.projects.of3_all_atom.config.inference_query_format import ( | ||
| InferenceQuerySet, | ||
| ) | ||
|
|
@@ -642,6 +646,68 @@ def save_colabfold_mappings( | |
| ) | ||
|
|
||
|
|
||
| def remap_colabfold_template_chain_ids( | ||
| template_alignments: pd.DataFrame, | ||
| m_with_templates: set[int], | ||
| rep_ids: list[str], | ||
| rep_id_to_m: dict[str, int], | ||
| ) -> dict[str, pd.DataFrame]: | ||
| """Remap author chain IDs to label (``label_asym_id``) chain IDs. | ||
|
|
||
| ColabFold returns template IDs with author-assigned chain IDs | ||
| (``pdb_strand_id``), but the rest of the pipeline expects mmCIF | ||
| ``label_asym_id``. This function queries the RCSB PDB API to obtain | ||
| the mapping and rewrites the template IDs. | ||
|
|
||
| Args: | ||
| template_alignments: Raw template alignment DataFrame from ``pdb70.m8``. | ||
| m_with_templates: Set of ColabFold M-indices that have template hits. | ||
| rep_ids: List of representative chain IDs. | ||
| rep_id_to_m: Mapping from representative ID to ColabFold M-index. | ||
|
|
||
| Returns: | ||
| Dictionary mapping ``rep_id`` to a DataFrame with remapped template IDs. | ||
| """ | ||
| # Collect DataFrames per rep_id and accumulate unique PDB IDs | ||
| per_rep: dict[str, pd.DataFrame] = {} | ||
| unique_pdb_ids: set[str] = set() | ||
| for rep_id in rep_ids: | ||
| m_i = rep_id_to_m[rep_id] | ||
| if m_i not in m_with_templates: | ||
| continue | ||
| chain_alns = template_alignments[template_alignments[0] == m_i] | ||
| top_n = chain_alns.copy() | ||
| per_rep[rep_id] = top_n | ||
| unique_pdb_ids.update(top_n[1].str.split("_").str[0]) | ||
|
Comment on lines
+672
to
+681
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't love this code – but it's just refactored into a function... |
||
|
|
||
| # Fetch label->author mappings in one API call, then invert | ||
| label_to_author_maps = fetch_label_to_author_chain_ids(unique_pdb_ids) | ||
| author_to_label_maps = { | ||
| entry_id: get_author_to_label_chain_ids(l2a) | ||
| for entry_id, l2a in label_to_author_maps.items() | ||
| } | ||
|
|
||
| # Remap chain IDs | ||
| for top_n in per_rep.values(): | ||
| remapped_ids = [] | ||
| for template_id in top_n[1]: | ||
| entry_id, author_chain_id = template_id.split("_") | ||
|
|
||
| author_to_label = author_to_label_maps.get(entry_id, {}) | ||
| if author_chain_id not in author_to_label: | ||
| raise RuntimeError( | ||
| f"Author chain {author_chain_id} not found in {entry_id}. " | ||
| f"Available author chains: {sorted(author_to_label.keys())}" | ||
| ) | ||
| label_chain_id = author_to_label[author_chain_id][0] | ||
|
|
||
| remapped_ids.append(f"{entry_id}_{label_chain_id}") | ||
|
|
||
| top_n[1] = remapped_ids | ||
|
|
||
| return per_rep | ||
|
|
||
|
|
||
| class ColabFoldQueryRunner: | ||
| """Class to run queries on the ColabFold MSA server. | ||
|
|
||
|
|
@@ -684,7 +750,6 @@ def __init__( | |
| def query_format_main(self): | ||
| """Submits queries and formats the outputs for main MSAs.""" | ||
| # Submit query for main MSAs | ||
| # TODO: add template alignments fetching code here by setting use_templates=True | ||
| # TODO: replace prints with proper logging | ||
| if len(self.colabfold_mapper.seqs) == 0: | ||
| print("No protein sequences found for main MSA generation. Skipping...") | ||
|
|
@@ -709,7 +774,25 @@ def query_format_main(self): | |
| template_alignments_path = self.output_directory / "template" | ||
| template_alignments_path.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| # Read template alignments if the file exists and has content | ||
| # 1) Save MSA a3m/npz files | ||
| for rep_id, aln in zip( | ||
| self.colabfold_mapper.rep_ids, a3m_lines_main, strict=True | ||
| ): | ||
| rep_dir = main_alignments_path / str(rep_id) | ||
|
|
||
| if "a3m" in self.msa_file_format: | ||
| rep_dir.mkdir(parents=True, exist_ok=True) | ||
| a3m_file = rep_dir / "colabfold_main.a3m" | ||
| with open(a3m_file, "w") as f: | ||
| f.write(aln) | ||
|
|
||
| if "npz" in self.msa_file_format: | ||
| npz_file = Path(f"{rep_dir}.npz") | ||
| msas = {"colabfold_main": parse_a3m(aln)} | ||
| msas_preparsed = {k: v.to_dict() for k, v in msas.items()} | ||
| np.savez_compressed(npz_file, **msas_preparsed) | ||
|
|
||
| # 2) Read raw template alignments and collect unique PDB IDs | ||
| template_alignments_file = self.output_directory / "raw/main/pdb70.m8" | ||
| if ( | ||
| template_alignments_file.exists() | ||
|
|
@@ -725,47 +808,33 @@ def query_format_main(self): | |
| # Create empty DataFrame with expected column structure (at least column 0) | ||
| # to match the structure when file is read with header=None. | ||
| logger.warning( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we should error here
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could raise an error here instead. My concern is if a user is running a large batch of predictions, they may prefer to be notified later about the issue with missing templates, rather than have the workflow interrupted for a few broken examples. We could think of a better way to record this issue and bring attention to the missing template alginments?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hard to tell – either way it's out of scope in a way, because it's not related to the bug-fix per se. Should we handle this in another PR? This PR is already 20 files, we're ballooning |
||
| "Colabfold returned no templates. \ | ||
| Proceeding without template alignments for this batch." | ||
| "Colabfold returned no templates. " | ||
| "Proceeding without template alignments for this batch." | ||
| ) | ||
| template_alignments = pd.DataFrame() | ||
| m_with_templates = set() | ||
|
|
||
| for rep_id, aln in zip( | ||
| self.colabfold_mapper.rep_ids, a3m_lines_main, strict=False | ||
| ): | ||
| rep_dir = main_alignments_path / str(rep_id) | ||
| template_rep_dir = template_alignments_path / str(rep_id) | ||
|
|
||
| # TODO: add code for which format to save the MSA in | ||
| # If save as a3m... | ||
| if "a3m" in self.msa_file_format: | ||
| rep_dir.mkdir(parents=True, exist_ok=True) | ||
| a3m_file = rep_dir / "colabfold_main.a3m" | ||
| with open(a3m_file, "w") as f: | ||
| f.write(aln) | ||
| if len(template_alignments) == 0: | ||
| return | ||
|
|
||
| # If save as npz... | ||
| if "npz" in self.msa_file_format: | ||
| npz_file = Path(f"{rep_dir}.npz") | ||
| msas = {"colabfold_main": parse_a3m(aln)} | ||
| msas_preparsed = {} | ||
| for k, v in msas.items(): | ||
| msas_preparsed[k] = v.to_dict() | ||
| np.savez_compressed(npz_file, **msas_preparsed) | ||
| # 3) Remap author chain IDs -> label chain IDs via RCSB API | ||
| remapped = remap_colabfold_template_chain_ids( | ||
| template_alignments=template_alignments, | ||
| m_with_templates=m_with_templates, | ||
| rep_ids=self.colabfold_mapper.rep_ids, | ||
| rep_id_to_m=self.colabfold_mapper.rep_id_to_m, | ||
| ) | ||
|
|
||
| # Format template alignments | ||
| m_i = self.colabfold_mapper.rep_id_to_m[rep_id] | ||
| if m_i in m_with_templates and len(template_alignments) > 0: | ||
| template_rep_dir.mkdir(parents=True, exist_ok=True) | ||
| template_alignment_file = template_rep_dir / "colabfold_template.m8" | ||
| template_alignment = template_alignments[template_alignments[0] == m_i] | ||
| template_alignment.to_csv( | ||
| template_alignment_file, | ||
| sep="\t", | ||
| header=False, | ||
| index=False, | ||
| ) | ||
| # 4) Save remapped m8 files | ||
| for rep_id, df in remapped.items(): | ||
| template_rep_dir = template_alignments_path / str(rep_id) | ||
| template_rep_dir.mkdir(parents=True, exist_ok=True) | ||
| df.to_csv( | ||
| template_rep_dir / "colabfold_template.m8", | ||
| sep="\t", | ||
| header=False, | ||
| index=False, | ||
| ) | ||
|
|
||
| def query_format_paired(self): | ||
| """Submits queries and formats the outputs for paired MSAs.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved to a new rscb.py module