Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions openfold3/core/data/primitives/structure/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,88 @@ def get_asym_id_to_canonical_seq_dict(
}


def get_label_to_author_chain_id_dict(
    cif_file: CIFFile | BinaryCIFFile,
) -> dict[str, str]:
    """Get a mapping from label asym_id to author (pdb_strand_id) chain ID.

    Reads from ``pdbx_poly_seq_scheme`` so no atom array is needed.

    Args:
        cif_file:
            Parsed mmCIF file containing the structure.

    Returns:
        A dictionary mapping label asym IDs to author chain IDs.
    """
    # pdbx_poly_seq_scheme has one row per polymer residue, so the same
    # (asym_id, pdb_strand_id) pair repeats; duplicates simply collapse
    # into a single dict entry.
    scheme = cif_file.block["pdbx_poly_seq_scheme"]
    label_ids = scheme["asym_id"].as_array().tolist()
    strand_ids = scheme["pdb_strand_id"].as_array().tolist()
    return {
        label: author
        for label, author in zip(label_ids, strand_ids, strict=True)
    }


def get_author_to_label_chain_ids(
    label_to_author: dict[str, str],
) -> dict[str, list[str]]:
    """Get a mapping from author (pdb_strand_id) chain ID to label asym_ids.

    Multiple label asym_ids can map to the same author chain ID for homomeric
    chains. The returned lists are sorted by label asym_id for determinism.

    Args:
        label_to_author:
            Mapping from label asym_id to author chain ID, as returned by
            :func:`get_label_to_author_chain_id_dict`.

    Returns:
        A dictionary mapping author chain IDs to sorted lists of label asym IDs.
    """
    author_to_labels: dict[str, list[str]] = defaultdict(list)
    for label, author in label_to_author.items():
        author_to_labels[author].append(label)
    # Sort so that downstream "take the first label" resolution is
    # deterministic regardless of input dict ordering.
    for labels in author_to_labels.values():
        labels.sort()
    return dict(author_to_labels)


def resolve_author_to_label_chain_id(
    matching_labels: list[str],
    chain_id_seq_map: dict[str, str],
) -> str:
    """Resolve an author (pdb_strand_id) chain ID to a single label asym_id.

    For homomeric chains, multiple label asym_ids share the same author chain
    ID. This function returns the first entry of *matching_labels*; callers
    are expected to pass a sorted list (e.g. a value produced by
    :func:`get_author_to_label_chain_ids`) so that this is the
    lexicographically smallest label asym_id. When more than one label
    matches, it verifies that all matching chains carry the same canonical
    sequence.

    Args:
        matching_labels:
            Label asym_ids that map to the same author chain ID, sorted
            lexicographically by the caller.
        chain_id_seq_map:
            Mapping from label asym_id to canonical sequence (as returned by
            :func:`get_asym_id_to_canonical_seq_dict`). Only consulted when
            *matching_labels* contains more than one entry.

    Returns:
        The label asym_id the author chain ID resolves to.

    Raises:
        ValueError: If homomeric chains have differing sequences.
        IndexError: If *matching_labels* is empty.
        KeyError: If a label in *matching_labels* is missing from
            *chain_id_seq_map* (homomeric case only).
    """
    if len(matching_labels) > 1:
        # Homomeric case: all label chains sharing this author ID must have
        # the same canonical sequence, otherwise picking the first one would
        # silently discard information.
        seqs = {chain_id_seq_map[label] for label in matching_labels}
        if len(seqs) != 1:
            raise ValueError(
                f"Expected identical sequences for homomeric chains, "
                f"got {len(seqs)} distinct sequences"
            )
    return matching_labels[0]


def get_entity_to_three_letter_codes_dict(cif_data: CIFBlock) -> dict[int, list[str]]:
"""Get a dictionary mapping entity IDs to their three-letter-code sequences.

Expand Down
146 changes: 110 additions & 36 deletions openfold3/core/data/tools/colabfold_msa_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import numpy as np
import pandas as pd
import requests
from biotite.database.rcsb import fetch
from biotite.structure.io.pdbx import CIFFile
from pydantic import BaseModel, model_validator
from pydantic import ConfigDict as PydanticConfigDict
from pydantic_core import Url
Expand All @@ -37,6 +39,10 @@
from openfold3.core.config import config_utils
from openfold3.core.data.io.sequence.msa import parse_a3m
from openfold3.core.data.primitives.sequence.hash import get_sequence_hash
from openfold3.core.data.primitives.structure.metadata import (
get_author_to_label_chain_ids,
get_label_to_author_chain_id_dict,
)
from openfold3.core.data.resources.residues import MoleculeType
from openfold3.projects.of3_all_atom.config.inference_query_format import (
InferenceQuerySet,
Expand Down Expand Up @@ -684,7 +690,6 @@ def __init__(
def query_format_main(self):
"""Submits queries and formats the outputs for main MSAs."""
# Submit query for main MSAs
# TODO: add template alignments fetching code here by setting use_templates=True
# TODO: replace prints with proper logging
if len(self.colabfold_mapper.seqs) == 0:
print("No protein sequences found for main MSA generation. Skipping...")
Expand All @@ -709,7 +714,25 @@ def query_format_main(self):
template_alignments_path = self.output_directory / "template"
template_alignments_path.mkdir(parents=True, exist_ok=True)

# Read template alignments if the file exists and has content
# 1) Save MSA a3m/npz files
for rep_id, aln in zip(
self.colabfold_mapper.rep_ids, a3m_lines_main, strict=False
):
rep_dir = main_alignments_path / str(rep_id)

if "a3m" in self.msa_file_format:
rep_dir.mkdir(parents=True, exist_ok=True)
a3m_file = rep_dir / "colabfold_main.a3m"
with open(a3m_file, "w") as f:
f.write(aln)

if "npz" in self.msa_file_format:
npz_file = Path(f"{rep_dir}.npz")
msas = {"colabfold_main": parse_a3m(aln)}
msas_preparsed = {k: v.to_dict() for k, v in msas.items()}
np.savez_compressed(npz_file, **msas_preparsed)

# 2) Read raw template alignments and collect unique PDB IDs
template_alignments_file = self.output_directory / "raw/main/pdb70.m8"
if (
template_alignments_file.exists()
Expand All @@ -725,47 +748,98 @@ def query_format_main(self):
# Create empty DataFrame with expected column structure (at least column 0)
# to match the structure when file is read with header=None.
logger.warning(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should error here

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could raise an error here instead.

My concern is that if a user is running a large batch of predictions, they may prefer to be notified later about the issue with missing templates, rather than have the workflow interrupted for a few broken examples. We could think of a better way to record this issue and bring attention to the missing template alignments?

Copy link
Copy Markdown
Collaborator Author

@jandom jandom Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard to tell – either way it's out of scope in a way, because it's not related to the bug-fix per se. Should we handle this in another PR? This PR is already 20 files, we're ballooning

"Colabfold returned no templates. \
Proceeding without template alignments for this batch."
"Colabfold returned no templates. "
"Proceeding without template alignments for this batch."
)
template_alignments = pd.DataFrame()
m_with_templates = set()

for rep_id, aln in zip(
self.colabfold_mapper.rep_ids, a3m_lines_main, strict=False
):
rep_dir = main_alignments_path / str(rep_id)
template_rep_dir = template_alignments_path / str(rep_id)

# TODO: add code for which format to save the MSA in
# If save as a3m...
if "a3m" in self.msa_file_format:
rep_dir.mkdir(parents=True, exist_ok=True)
a3m_file = rep_dir / "colabfold_main.a3m"
with open(a3m_file, "w") as f:
f.write(aln)
if len(template_alignments) == 0:
return

# If save as npz...
if "npz" in self.msa_file_format:
npz_file = Path(f"{rep_dir}.npz")
msas = {"colabfold_main": parse_a3m(aln)}
msas_preparsed = {}
for k, v in msas.items():
msas_preparsed[k] = v.to_dict()
np.savez_compressed(npz_file, **msas_preparsed)
max_templates_per_chain = 25

# Format template alignments
# Gather unique PDB entry IDs from the top-n rows per chain
unique_pdb_ids = set()
for rep_id in self.colabfold_mapper.rep_ids:
m_i = self.colabfold_mapper.rep_id_to_m[rep_id]
if m_i in m_with_templates and len(template_alignments) > 0:
template_rep_dir.mkdir(parents=True, exist_ok=True)
template_alignment_file = template_rep_dir / "colabfold_template.m8"
template_alignment = template_alignments[template_alignments[0] == m_i]
template_alignment.to_csv(
template_alignment_file,
sep="\t",
header=False,
index=False,
)
if m_i not in m_with_templates:
continue
chain_alns = template_alignments[template_alignments[0] == m_i]
top_n = chain_alns.head(max_templates_per_chain)
# Column 1 = template_id, e.g. "4pqx_A" -> entry_id = "4pqx"
entry_ids = top_n[1].str.split("_").str[0]
unique_pdb_ids.update(entry_ids)

# 3) Download CIF files for all unique PDB IDs
cif_dir = self.output_directory / "template_cifs_tmp"
cif_dir.mkdir(parents=True, exist_ok=True)

pdb_ids_to_download = [
pdb_id
for pdb_id in unique_pdb_ids
if not (cif_dir / f"{pdb_id}.cif").exists()
]
if pdb_ids_to_download:
for pdb_id in tqdm(pdb_ids_to_download, desc="Downloading template CIFs"):
fetch(pdb_id, format="cif", target_path=cif_dir)

# 4) Build author->label chain ID maps and remap + save templates
# Cache per entry_id so we don't re-parse CIFs
author_to_label_cache: dict[str, dict[str, list[str]]] = {}

for rep_id in self.colabfold_mapper.rep_ids:
m_i = self.colabfold_mapper.rep_id_to_m[rep_id]
if m_i not in m_with_templates:
continue

chain_alns = template_alignments[template_alignments[0] == m_i]
top_n = chain_alns.head(max_templates_per_chain).copy()

# Remap author chain IDs -> label chain IDs
remapped_template_ids = []
for template_id in top_n[1]:
entry_id, author_chain_id = template_id.split("_")

if entry_id not in author_to_label_cache:
cif_path = cif_dir / f"{entry_id}.cif"
if cif_path.exists():
cif_file = CIFFile.read(cif_path)
label_to_author = get_label_to_author_chain_id_dict(cif_file)
author_to_label_cache[entry_id] = get_author_to_label_chain_ids(
label_to_author
)
else:
logger.warning(
f"CIF file not found for {entry_id}, skipping remap"
)
author_to_label_cache[entry_id] = {}

a2l = author_to_label_cache[entry_id]
if author_chain_id in a2l:
# first (smallest) label ID
label_chain_id = a2l[author_chain_id][0]
else:
logger.warning(
f"Author chain {author_chain_id} not found in {entry_id}, "
"keeping as-is"
)
label_chain_id = author_chain_id

remapped_template_ids.append(f"{entry_id}_{label_chain_id}")

top_n[1] = remapped_template_ids

# 5) Save remapped m8
template_rep_dir = template_alignments_path / str(rep_id)
template_rep_dir.mkdir(parents=True, exist_ok=True)
template_alignment_file = template_rep_dir / "colabfold_template.m8"
top_n.to_csv(
template_alignment_file,
sep="\t",
header=False,
index=False,
)

def query_format_paired(self):
"""Submits queries and formats the outputs for paired MSAs."""
Expand Down
Loading
Loading