Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 1 addition & 78 deletions openfold3/core/data/primitives/caches/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from pathlib import Path
from typing import NamedTuple

import requests
from tqdm import tqdm

from openfold3.core.data.io.dataset_cache import read_datacache
Expand Down Expand Up @@ -60,6 +59,7 @@
LIGAND_EXCLUSION_LIST,
)
from openfold3.core.data.resources.residues import MoleculeType
from openfold3.core.data.tools.rscb import get_model_ranking_fit

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -855,83 +855,6 @@ def set_nan_fallback_conformer_flag(
return None


# TODO: Do this in preprocessing instead to avoid it going out-of-sync with the data?
def get_model_ranking_fit(pdb_id: str) -> dict[str, float]:
    """Fetches the model ranking fit entries for all ligands of a single PDB-ID.

    Uses the PDB GraphQL API to fetch the model ranking fit values for all ligands in a
    single PDB entry. Note that this function will always fetch from the newest version
    of the PDB and can therefore occasionally give incorrect results for old datasets
    whose structures have been updated since.

    Args:
        pdb_id: PDB entry identifier (e.g. "1abc").

    Returns:
        Mapping from ligand instance ``rcsb_id`` to its ``ranking_model_fit``
        value. An empty dict is returned on request failure, malformed
        responses, or when the entry has no nonpolymer validation scores.
    """
    url = "https://data.rcsb.org/graphql"  # RCSB PDB's GraphQL API endpoint

    query = """
    query GetRankingFit($pdb_id: String!) {
      entry(entry_id: $pdb_id) {
        nonpolymer_entities {
          nonpolymer_entity_instances {
            rcsb_id
            rcsb_nonpolymer_instance_validation_score {
              ranking_model_fit
            }
          }
        }
      }
    }
    """

    # Prepare the request with the pdb_id as a variable
    variables = {"pdb_id": pdb_id}

    # Make the request to the GraphQL endpoint using the variables.
    # A timeout is required so a stalled connection cannot hang the caller.
    response = requests.post(
        url, json={"query": query, "variables": variables}, timeout=30
    )

    # Check if the request was successful
    if response.status_code != 200:
        logger.warning(
            "Request failed with status code %s for %s", response.status_code, pdb_id
        )
        return {}

    try:
        # Parse the JSON response
        data = response.json()

        # Safely navigate through data; "entry" is null for unknown PDB IDs.
        entry_data = data.get("data", {}).get("entry", {})
        if not entry_data:
            return {}

        extracted_data: dict[str, float] = {}

        # "nonpolymer_entities" may be missing or null for protein-only entries.
        nonpolymer_entities = entry_data.get("nonpolymer_entities", []) or []

        for entity in nonpolymer_entities:
            for instance in entity.get("nonpolymer_entity_instances", []):
                rcsb_id = instance.get("rcsb_id")
                validation_score = instance.get(
                    "rcsb_nonpolymer_instance_validation_score"
                )

                # The score is a (possibly null/empty) list of score dicts;
                # only the first element is used.
                if (
                    validation_score
                    and isinstance(validation_score, list)
                    and validation_score[0]
                ):
                    ranking_model_fit = validation_score[0].get("ranking_model_fit")
                    if ranking_model_fit is not None:
                        extracted_data[rcsb_id] = ranking_model_fit

        return extracted_data

    except (KeyError, TypeError, ValueError) as e:
        logger.warning("Error processing response for %s: %s", pdb_id, e)
        return {}


def assign_ligand_model_fits(
structure_cache: ValidationDatasetCache, num_threads: int = 3
) -> None:
Expand Down
23 changes: 23 additions & 0 deletions openfold3/core/data/primitives/structure/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,29 @@ def get_asym_id_to_canonical_seq_dict(
}


def get_author_to_label_chain_ids(
    label_to_author: dict[str, str],
) -> dict[str, list[str]]:
    """Get a mapping from author (pdb_strand_id) chain ID to label asym_ids.

    Several label asym_ids can share one author chain ID (homomeric chains),
    so each author ID maps to a list. Lists are ordered by label asym_id so
    the result is deterministic.

    Args:
        label_to_author:
            Dictionary mapping label asym IDs to author chain IDs.

    Returns:
        A dictionary mapping author chain IDs to sorted lists of label asym IDs.
    """
    inverted: dict[str, list[str]] = {}
    # Visiting labels in sorted order yields each list already sorted.
    for label in sorted(label_to_author):
        inverted.setdefault(label_to_author[label], []).append(label)
    return inverted


def get_entity_to_three_letter_codes_dict(cif_data: CIFBlock) -> dict[int, list[str]]:
"""Get a dictionary mapping entity IDs to their three-letter-code sequences.

Expand Down
143 changes: 106 additions & 37 deletions openfold3/core/data/tools/colabfold_msa_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
from openfold3.core.config import config_utils
from openfold3.core.data.io.sequence.msa import parse_a3m
from openfold3.core.data.primitives.sequence.hash import get_sequence_hash
from openfold3.core.data.primitives.structure.metadata import (
get_author_to_label_chain_ids,
)
from openfold3.core.data.resources.residues import MoleculeType
from openfold3.core.data.tools.rscb import fetch_label_to_author_chain_ids
from openfold3.projects.of3_all_atom.config.inference_query_format import (
InferenceQuerySet,
)
Expand Down Expand Up @@ -642,6 +646,68 @@ def save_colabfold_mappings(
)


def remap_colabfold_template_chain_ids(
    template_alignments: pd.DataFrame,
    m_with_templates: set[int],
    rep_ids: list[str],
    rep_id_to_m: dict[str, int],
) -> dict[str, pd.DataFrame]:
    """Remap author chain IDs to label (``label_asym_id``) chain IDs.

    ColabFold returns template IDs with author-assigned chain IDs
    (``pdb_strand_id``), but the rest of the pipeline expects mmCIF
    ``label_asym_id``. This function queries the RCSB PDB API to obtain
    the mapping and rewrites the template IDs.

    Args:
        template_alignments: Raw template alignment DataFrame from ``pdb70.m8``.
        m_with_templates: Set of ColabFold M-indices that have template hits.
        rep_ids: List of representative chain IDs.
        rep_id_to_m: Mapping from representative ID to ColabFold M-index.

    Returns:
        Dictionary mapping ``rep_id`` to a DataFrame with remapped template IDs.
    """
    # Select each representative's alignment rows and gather the set of PDB
    # entries referenced, so the RCSB API is queried only once overall.
    frames_by_rep: dict[str, pd.DataFrame] = {}
    referenced_pdb_ids: set[str] = set()
    for rep_id in rep_ids:
        m_index = rep_id_to_m[rep_id]
        if m_index not in m_with_templates:
            continue
        frame = template_alignments[template_alignments[0] == m_index].copy()
        frames_by_rep[rep_id] = frame
        referenced_pdb_ids.update(frame[1].str.split("_").str[0])

    # One batched API call for label->author, then invert per entry.
    label_maps = fetch_label_to_author_chain_ids(referenced_pdb_ids)
    inverted_maps = {
        entry: get_author_to_label_chain_ids(mapping)
        for entry, mapping in label_maps.items()
    }

    def _remap_one(template_id: str) -> str:
        # Translate a single "<entry>_<author_chain>" template ID.
        entry_id, author_chain_id = template_id.split("_")
        author_to_label = inverted_maps.get(entry_id, {})
        if author_chain_id not in author_to_label:
            raise RuntimeError(
                f"Author chain {author_chain_id} not found in {entry_id}. "
                f"Available author chains: {sorted(author_to_label.keys())}"
            )
        # Several labels can share one author chain; take the first (sorted).
        return f"{entry_id}_{author_to_label[author_chain_id][0]}"

    # Rewrite column 1 of every per-representative frame in place.
    for frame in frames_by_rep.values():
        frame[1] = [_remap_one(template_id) for template_id in frame[1]]

    return frames_by_rep


class ColabFoldQueryRunner:
"""Class to run queries on the ColabFold MSA server.

Expand Down Expand Up @@ -684,7 +750,6 @@ def __init__(
def query_format_main(self):
"""Submits queries and formats the outputs for main MSAs."""
# Submit query for main MSAs
# TODO: add template alignments fetching code here by setting use_templates=True
# TODO: replace prints with proper logging
if len(self.colabfold_mapper.seqs) == 0:
print("No protein sequences found for main MSA generation. Skipping...")
Expand All @@ -709,7 +774,25 @@ def query_format_main(self):
template_alignments_path = self.output_directory / "template"
template_alignments_path.mkdir(parents=True, exist_ok=True)

# Read template alignments if the file exists and has content
# 1) Save MSA a3m/npz files
for rep_id, aln in zip(
self.colabfold_mapper.rep_ids, a3m_lines_main, strict=True
):
rep_dir = main_alignments_path / str(rep_id)

if "a3m" in self.msa_file_format:
rep_dir.mkdir(parents=True, exist_ok=True)
a3m_file = rep_dir / "colabfold_main.a3m"
with open(a3m_file, "w") as f:
f.write(aln)

if "npz" in self.msa_file_format:
npz_file = Path(f"{rep_dir}.npz")
msas = {"colabfold_main": parse_a3m(aln)}
msas_preparsed = {k: v.to_dict() for k, v in msas.items()}
np.savez_compressed(npz_file, **msas_preparsed)

# 2) Read raw template alignments and collect unique PDB IDs
template_alignments_file = self.output_directory / "raw/main/pdb70.m8"
if (
template_alignments_file.exists()
Expand All @@ -725,47 +808,33 @@ def query_format_main(self):
# Create empty DataFrame with expected column structure (at least column 0)
# to match the structure when file is read with header=None.
logger.warning(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should error here

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could raise an error here instead.

My concern is that if a user is running a large batch of predictions, they may prefer to be notified later about the issue with missing templates, rather than have the workflow interrupted for a few broken examples. We could think of a better way to record this issue and bring attention to the missing template alignments?

Copy link
Copy Markdown
Collaborator Author

@jandom jandom Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard to tell – either way it's out of scope in a way, because it's not related to the bug-fix per se. Should we handle this in another PR? This PR is already 20 files, we're ballooning

"Colabfold returned no templates. \
Proceeding without template alignments for this batch."
"Colabfold returned no templates. "
"Proceeding without template alignments for this batch."
)
template_alignments = pd.DataFrame()
m_with_templates = set()

for rep_id, aln in zip(
self.colabfold_mapper.rep_ids, a3m_lines_main, strict=False
):
rep_dir = main_alignments_path / str(rep_id)
template_rep_dir = template_alignments_path / str(rep_id)

# TODO: add code for which format to save the MSA in
# If save as a3m...
if "a3m" in self.msa_file_format:
rep_dir.mkdir(parents=True, exist_ok=True)
a3m_file = rep_dir / "colabfold_main.a3m"
with open(a3m_file, "w") as f:
f.write(aln)
if len(template_alignments) == 0:
return

# If save as npz...
if "npz" in self.msa_file_format:
npz_file = Path(f"{rep_dir}.npz")
msas = {"colabfold_main": parse_a3m(aln)}
msas_preparsed = {}
for k, v in msas.items():
msas_preparsed[k] = v.to_dict()
np.savez_compressed(npz_file, **msas_preparsed)
# 3) Remap author chain IDs -> label chain IDs via RCSB API
remapped = remap_colabfold_template_chain_ids(
template_alignments=template_alignments,
m_with_templates=m_with_templates,
rep_ids=self.colabfold_mapper.rep_ids,
rep_id_to_m=self.colabfold_mapper.rep_id_to_m,
)

# Format template alignments
m_i = self.colabfold_mapper.rep_id_to_m[rep_id]
if m_i in m_with_templates and len(template_alignments) > 0:
template_rep_dir.mkdir(parents=True, exist_ok=True)
template_alignment_file = template_rep_dir / "colabfold_template.m8"
template_alignment = template_alignments[template_alignments[0] == m_i]
template_alignment.to_csv(
template_alignment_file,
sep="\t",
header=False,
index=False,
)
# 4) Save remapped m8 files
for rep_id, df in remapped.items():
template_rep_dir = template_alignments_path / str(rep_id)
template_rep_dir.mkdir(parents=True, exist_ok=True)
df.to_csv(
template_rep_dir / "colabfold_template.m8",
sep="\t",
header=False,
index=False,
)

def query_format_paired(self):
"""Submits queries and formats the outputs for paired MSAs."""
Expand Down
Loading
Loading