diff --git a/openfold3/core/data/primitives/caches/filtering.py b/openfold3/core/data/primitives/caches/filtering.py index 268cffbc9..a1aabb240 100644 --- a/openfold3/core/data/primitives/caches/filtering.py +++ b/openfold3/core/data/primitives/caches/filtering.py @@ -27,7 +27,6 @@ from pathlib import Path from typing import NamedTuple -import requests from tqdm import tqdm from openfold3.core.data.io.dataset_cache import read_datacache @@ -60,6 +59,7 @@ LIGAND_EXCLUSION_LIST, ) from openfold3.core.data.resources.residues import MoleculeType +from openfold3.core.data.tools.rscb import get_model_ranking_fit logger = logging.getLogger(__name__) @@ -855,83 +855,6 @@ def set_nan_fallback_conformer_flag( return None -# TODO: Do this in preprocessing instead to avoid it going out-of-sync with the data? -def get_model_ranking_fit(pdb_id): - """Fetches the model ranking fit entries for all ligands of a single PDB-ID. - - Uses the PDB GraphQL API to fetch the model ranking fit values for all ligands in a - single PDB entry. Note that this function will always fetch from the newest version - of the PDB and can therefore occasionally give incorrect results for old datasets - whose structures have been updated since. - """ - url = "https://data.rcsb.org/graphql" # RCSB PDB's GraphQL API endpoint - - query = """ - query GetRankingFit($pdb_id: String!) { - entry(entry_id: $pdb_id) { - nonpolymer_entities { - nonpolymer_entity_instances { - rcsb_id - rcsb_nonpolymer_instance_validation_score { - ranking_model_fit - } - } - } - } - } - """ - - # Prepare the request with the pdb_id as a variable - variables = {"pdb_id": pdb_id} - - # Make the request to the GraphQL endpoint using the variables - response = requests.post(url, json={"query": query, "variables": variables}) - - # Check if the request was successful - if response.status_code == 200: - try: - # Parse the JSON response - data = response.json() - - # Safely navigate through data - entry_data = data.get("data", {}).get("entry", {}) - if not entry_data: - return {} - - extracted_data = {} - - # Check for nonpolymer_entities - nonpolymer_entities = entry_data.get("nonpolymer_entities", []) - - if nonpolymer_entities: - for entity in nonpolymer_entities: - for instance in entity.get("nonpolymer_entity_instances", []): - rcsb_id = instance.get("rcsb_id") - validation_score = instance.get( - "rcsb_nonpolymer_instance_validation_score" - ) - - if ( - validation_score - and isinstance(validation_score, list) - and validation_score[0] - ): - ranking_model_fit = validation_score[0].get( - "ranking_model_fit" - ) - if ranking_model_fit is not None: - extracted_data[rcsb_id] = ranking_model_fit - - return extracted_data - - except (KeyError, TypeError, ValueError) as e: - print(f"Error processing response for {pdb_id}: {e}") - return {} - else: - print(f"Request failed with status code {response.status_code}") - return {} - - def assign_ligand_model_fits( structure_cache: ValidationDatasetCache, num_threads: int = 3 ) -> None: diff --git a/openfold3/core/data/primitives/structure/metadata.py b/openfold3/core/data/primitives/structure/metadata.py index e9f49418c..02aa609aa 100644 --- a/openfold3/core/data/primitives/structure/metadata.py +++ b/openfold3/core/data/primitives/structure/metadata.py @@ -323,6 +323,29 @@ def get_asym_id_to_canonical_seq_dict( } +def get_author_to_label_chain_ids( + label_to_author: dict[str, str], +) -> dict[str, list[str]]: + """Get a mapping from author (pdb_strand_id) chain ID to label asym_ids. + + Multiple label asym_ids can map to the same author chain ID for homomeric + chains. The returned lists are sorted by label asym_id for determinism. + + Args: + label_to_author: + Dictionary mapping label asym IDs to author chain IDs. + + Returns: + A dictionary mapping author chain IDs to sorted lists of label asym IDs. + """ + author_to_labels: dict[str, list[str]] = defaultdict(list) + for label, author in label_to_author.items(): + author_to_labels[author].append(label) + for labels in author_to_labels.values(): + labels.sort() + return dict(author_to_labels) + + def get_entity_to_three_letter_codes_dict(cif_data: CIFBlock) -> dict[int, list[str]]: """Get a dictionary mapping entity IDs to their three-letter-code sequences. diff --git a/openfold3/core/data/tools/colabfold_msa_server.py b/openfold3/core/data/tools/colabfold_msa_server.py index 907bdfd81..e4b12f295 100644 --- a/openfold3/core/data/tools/colabfold_msa_server.py +++ b/openfold3/core/data/tools/colabfold_msa_server.py @@ -37,7 +37,11 @@ from openfold3.core.config import config_utils from openfold3.core.data.io.sequence.msa import parse_a3m from openfold3.core.data.primitives.sequence.hash import get_sequence_hash +from openfold3.core.data.primitives.structure.metadata import ( + get_author_to_label_chain_ids, +) from openfold3.core.data.resources.residues import MoleculeType +from openfold3.core.data.tools.rscb import fetch_label_to_author_chain_ids from openfold3.projects.of3_all_atom.config.inference_query_format import ( InferenceQuerySet, ) @@ -642,6 +646,68 @@ def save_colabfold_mappings( ) +def remap_colabfold_template_chain_ids( + template_alignments: pd.DataFrame, + m_with_templates: set[int], + rep_ids: list[str], + rep_id_to_m: dict[str, int], +) -> dict[str, pd.DataFrame]: + """Remap author chain IDs to label (``label_asym_id``) chain IDs. + + ColabFold returns template IDs with author-assigned chain IDs + (``pdb_strand_id``), but the rest of the pipeline expects mmCIF + ``label_asym_id``. This function queries the RCSB PDB API to obtain + the mapping and rewrites the template IDs. + + Args: + template_alignments: Raw template alignment DataFrame from ``pdb70.m8``. + m_with_templates: Set of ColabFold M-indices that have template hits. + rep_ids: List of representative chain IDs. + rep_id_to_m: Mapping from representative ID to ColabFold M-index. + + Returns: + Dictionary mapping ``rep_id`` to a DataFrame with remapped template IDs. + """ + # Collect DataFrames per rep_id and accumulate unique PDB IDs + per_rep: dict[str, pd.DataFrame] = {} + unique_pdb_ids: set[str] = set() + for rep_id in rep_ids: + m_i = rep_id_to_m[rep_id] + if m_i not in m_with_templates: + continue + chain_alns = template_alignments[template_alignments[0] == m_i] + top_n = chain_alns.copy() + per_rep[rep_id] = top_n + unique_pdb_ids.update(top_n[1].str.split("_").str[0]) + + # Fetch label->author mappings in one API call, then invert + label_to_author_maps = fetch_label_to_author_chain_ids(unique_pdb_ids) + author_to_label_maps = { + entry_id: get_author_to_label_chain_ids(l2a) + for entry_id, l2a in label_to_author_maps.items() + } + + # Remap chain IDs + for top_n in per_rep.values(): + remapped_ids = [] + for template_id in top_n[1]: + entry_id, author_chain_id = template_id.split("_") + + author_to_label = author_to_label_maps.get(entry_id, {}) + if author_chain_id not in author_to_label: + raise RuntimeError( + f"Author chain {author_chain_id} not found in {entry_id}. " + f"Available author chains: {sorted(author_to_label.keys())}" + ) + label_chain_id = author_to_label[author_chain_id][0] + + remapped_ids.append(f"{entry_id}_{label_chain_id}") + + top_n[1] = remapped_ids + + return per_rep + + class ColabFoldQueryRunner: """Class to run queries on the ColabFold MSA server. @@ -684,7 +750,6 @@ def __init__( def query_format_main(self): """Submits queries and formats the outputs for main MSAs.""" # Submit query for main MSAs - # TODO: add template alignments fetching code here by setting use_templates=True # TODO: replace prints with proper logging if len(self.colabfold_mapper.seqs) == 0: print("No protein sequences found for main MSA generation. Skipping...") @@ -709,7 +774,25 @@ def query_format_main(self): template_alignments_path = self.output_directory / "template" template_alignments_path.mkdir(parents=True, exist_ok=True) - # Read template alignments if the file exists and has content + # 1) Save MSA a3m/npz files + for rep_id, aln in zip( + self.colabfold_mapper.rep_ids, a3m_lines_main, strict=True + ): + rep_dir = main_alignments_path / str(rep_id) + + if "a3m" in self.msa_file_format: + rep_dir.mkdir(parents=True, exist_ok=True) + a3m_file = rep_dir / "colabfold_main.a3m" + with open(a3m_file, "w") as f: + f.write(aln) + + if "npz" in self.msa_file_format: + npz_file = Path(f"{rep_dir}.npz") + msas = {"colabfold_main": parse_a3m(aln)} + msas_preparsed = {k: v.to_dict() for k, v in msas.items()} + np.savez_compressed(npz_file, **msas_preparsed) + + # 2) Read raw template alignments and collect unique PDB IDs template_alignments_file = self.output_directory / "raw/main/pdb70.m8" if ( template_alignments_file.exists() @@ -725,47 +808,33 @@ def query_format_main(self): # Create empty DataFrame with expected column structure (at least column 0) # to match the structure when file is read with header=None. logger.warning( - "Colabfold returned no templates. \ - Proceeding without template alignments for this batch." + "Colabfold returned no templates. " + "Proceeding without template alignments for this batch." ) template_alignments = pd.DataFrame() m_with_templates = set() - for rep_id, aln in zip( - self.colabfold_mapper.rep_ids, a3m_lines_main, strict=False - ): - rep_dir = main_alignments_path / str(rep_id) - template_rep_dir = template_alignments_path / str(rep_id) - - # TODO: add code for which format to save the MSA in - # If save as a3m... - if "a3m" in self.msa_file_format: - rep_dir.mkdir(parents=True, exist_ok=True) - a3m_file = rep_dir / "colabfold_main.a3m" - with open(a3m_file, "w") as f: - f.write(aln) + if len(template_alignments) == 0: + return - # If save as npz... - if "npz" in self.msa_file_format: - npz_file = Path(f"{rep_dir}.npz") - msas = {"colabfold_main": parse_a3m(aln)} - msas_preparsed = {} - for k, v in msas.items(): - msas_preparsed[k] = v.to_dict() - np.savez_compressed(npz_file, **msas_preparsed) + # 3) Remap author chain IDs -> label chain IDs via RCSB API + remapped = remap_colabfold_template_chain_ids( + template_alignments=template_alignments, + m_with_templates=m_with_templates, + rep_ids=self.colabfold_mapper.rep_ids, + rep_id_to_m=self.colabfold_mapper.rep_id_to_m, + ) - # Format template alignments - m_i = self.colabfold_mapper.rep_id_to_m[rep_id] - if m_i in m_with_templates and len(template_alignments) > 0: - template_rep_dir.mkdir(parents=True, exist_ok=True) - template_alignment_file = template_rep_dir / "colabfold_template.m8" - template_alignment = template_alignments[template_alignments[0] == m_i] - template_alignment.to_csv( - template_alignment_file, - sep="\t", - header=False, - index=False, - ) + # 4) Save remapped m8 files + for rep_id, df in remapped.items(): + template_rep_dir = template_alignments_path / str(rep_id) + template_rep_dir.mkdir(parents=True, exist_ok=True) + df.to_csv( + template_rep_dir / "colabfold_template.m8", + sep="\t", + header=False, + index=False, + ) def query_format_paired(self): """Submits queries and formats the outputs for paired MSAs.""" diff --git a/openfold3/core/data/tools/rscb.py b/openfold3/core/data/tools/rscb.py new file mode 100644 index 000000000..75a2b6017 --- /dev/null +++ b/openfold3/core/data/tools/rscb.py @@ -0,0 +1,163 @@ +# Copyright 2026 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import requests + +logger = logging.getLogger(__name__) + +_RCSB_GRAPHQL_URL = "https://data.rcsb.org/graphql" + +_CHAIN_MAPPING_QUERY = """ +query($ids: [String!]!) { + entries(entry_ids: $ids) { + rcsb_id + polymer_entities { + rcsb_polymer_entity_container_identifiers { + asym_ids + auth_asym_ids + } + } + } +} +""" + + +def fetch_label_to_author_chain_ids( + pdb_ids: set[str], +) -> dict[str, dict[str, str]]: + """Fetch label-to-author chain ID mappings from the RCSB PDB GraphQL API. + + Makes a single batched request for all PDB IDs and returns a nested dict + mapping ``entry_id`` → ``label_asym_id`` → ``author_chain_id``. + + Args: + pdb_ids: Set of PDB entry IDs (e.g. ``{"4pqx", "1rnb"}``). + + Returns: + Nested dict: ``entry_id`` (lower-case) → ``label_asym_id`` → + ``author_chain_id``. + + Raises: + RuntimeError: If the RCSB API request fails. + """ + if not pdb_ids: + return {} + + try: + resp = requests.post( + _RCSB_GRAPHQL_URL, + json={ + "query": _CHAIN_MAPPING_QUERY, + "variables": {"ids": sorted(pdb_ids)}, + }, + timeout=30, + ) + resp.raise_for_status() + except Exception as e: + raise RuntimeError( + f"Failed to fetch chain ID mappings from RCSB for " + f"{len(pdb_ids)} entries. Cannot proceed without chain ID " + f"re-mapping." + ) from e + + data = resp.json().get("data", {}) + entries = data.get("entries") or [] + + result: dict[str, dict[str, str]] = {} + for entry in entries: + entry_id = entry["rcsb_id"].lower() + label_to_author: dict[str, str] = {} + for entity in entry.get("polymer_entities") or []: + ids = entity["rcsb_polymer_entity_container_identifiers"] + for asym_id, auth_id in zip( + ids["asym_ids"], ids["auth_asym_ids"], strict=True + ): + label_to_author[asym_id] = auth_id + result[entry_id] = label_to_author + + return result + + +_MODEL_RANKING_FIT_QUERY = """ +query GetRankingFit($pdb_id: String!) { + entry(entry_id: $pdb_id) { + nonpolymer_entities { + nonpolymer_entity_instances { + rcsb_id + rcsb_nonpolymer_instance_validation_score { + ranking_model_fit + } + } + } + } +} +""" + + +# TODO: Do this in preprocessing instead to avoid it going out-of-sync with the data? +def get_model_ranking_fit(pdb_id: str) -> dict[str, float]: + """Fetch model ranking fit entries for all ligands of a single PDB entry. + + Uses the RCSB PDB GraphQL API to fetch the model ranking fit values for + all ligands in a single PDB entry. Note that this function will always + fetch from the newest version of the PDB and can therefore occasionally + give incorrect results for old datasets whose structures have been updated + since. + + Args: + pdb_id: PDB entry ID (e.g. ``"4pqx"``). + + Returns: + Dictionary mapping ``rcsb_id`` (e.g. ``"4PQX.C"``) to its + ``ranking_model_fit`` score. Returns an empty dict on failure. + """ + response = requests.post( + _RCSB_GRAPHQL_URL, + json={"query": _MODEL_RANKING_FIT_QUERY, "variables": {"pdb_id": pdb_id}}, + timeout=30, + ) + + if response.status_code != 200: + logger.warning("RCSB request failed with status code %d", response.status_code) + return {} + + try: + data = response.json() + entry_data = data.get("data", {}).get("entry", {}) + if not entry_data: + return {} + + extracted_data: dict[str, float] = {} + for entity in entry_data.get("nonpolymer_entities") or []: + for instance in entity.get("nonpolymer_entity_instances") or []: + rcsb_id = instance.get("rcsb_id") + validation_score = instance.get( + "rcsb_nonpolymer_instance_validation_score" + ) + if ( + validation_score + and isinstance(validation_score, list) + and validation_score[0] + ): + ranking_model_fit = validation_score[0].get("ranking_model_fit") + if ranking_model_fit is not None: + extracted_data[rcsb_id] = ranking_model_fit + + return extracted_data + + except (KeyError, TypeError, ValueError) as e: + logger.warning("Error processing response for %s: %s", pdb_id, e) + return {} diff --git a/openfold3/tests/test_template_parsers.py b/openfold3/tests/core/data/io/sequence/template/test_template_parsers.py similarity index 97% rename from openfold3/tests/test_template_parsers.py rename to openfold3/tests/core/data/io/sequence/template/test_template_parsers.py index 88ae25671..88b3969ff 100644 --- a/openfold3/tests/test_template_parsers.py +++ b/openfold3/tests/core/data/io/sequence/template/test_template_parsers.py @@ -18,6 +18,7 @@ import pandas as pd import pytest +import openfold3 from openfold3.core.data.io.sequence.template import ( A3mParser, M8Parser, @@ -25,7 +26,9 @@ TemplateData, ) -TEST_DIR = Path(__file__).parent / "test_data" / "template_alignments" +TEST_DIR = ( + Path(openfold3.__file__).parent / "tests" / "test_data" / "template_alignments" +) QUERY_SEQUENCE = """ MLNSFKLSLQYILPKLWLTRLAGWGASKRAGWLTKLVIDLFVKYYKVDMKEAQKPDTASYRTFNEFFVRPLRDEVRPIDTDPNVLV diff --git a/openfold3/tests/core/data/primitives/structure/test_metadata.py b/openfold3/tests/core/data/primitives/structure/test_metadata.py new file mode 100644 index 000000000..b0b29a5df --- /dev/null +++ b/openfold3/tests/core/data/primitives/structure/test_metadata.py @@ -0,0 +1,29 @@ +import pytest + +from openfold3.core.data.primitives.structure.metadata import ( + get_author_to_label_chain_ids, +) + + +class TestGetAuthorToLabelChainIds: + @pytest.mark.parametrize( + ("label_to_author", "expected"), + [ + pytest.param({"A": "X"}, {"X": ["A"]}, id="single_chain"), + pytest.param( + {"A": "X", "B": "Y", "C": "Z"}, + {"X": ["A"], "Y": ["B"], "Z": ["C"]}, + id="multiple_distinct_chains", + ), + pytest.param( + {"A": "X", "B": "X"}, {"X": ["A", "B"]}, id="homomeric_chains" + ), + pytest.param( + {"C": "X", "A": "X", "B": "X"}, + {"X": ["A", "B", "C"]}, + id="homomeric_chains_sorted", + ), + ], + ) + def test_author_to_labels(self, label_to_author, expected): + assert get_author_to_label_chain_ids(label_to_author) == expected diff --git a/openfold3/tests/core/data/tools/conftest.py b/openfold3/tests/core/data/tools/conftest.py new file mode 100644 index 000000000..7e45a2ae5 --- /dev/null +++ b/openfold3/tests/core/data/tools/conftest.py @@ -0,0 +1,18 @@ +"""Pytest configuration for tools tests -- VCR cassette directory.""" + +from pathlib import Path + +import pytest + +import openfold3 + +_CASSETTE_DIR = ( + Path(openfold3.__file__).parent / "tests" / "test_data" / "cassettes" / "test_rscb" +) + + +@pytest.fixture(scope="module") +def vcr_config(): + return { + "cassette_library_dir": str(_CASSETTE_DIR), + } diff --git a/openfold3/tests/test_colabfold_msa.py b/openfold3/tests/core/data/tools/test_colabfold_msa_server.py similarity index 81% rename from openfold3/tests/test_colabfold_msa.py rename to openfold3/tests/core/data/tools/test_colabfold_msa_server.py index 2b607854d..7b7347dcf 100644 --- a/openfold3/tests/test_colabfold_msa.py +++ b/openfold3/tests/core/data/tools/test_colabfold_msa_server.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests to check handling of colabofold MSA data.""" +"""Tests for the ColabFold MSA server module.""" import json import textwrap @@ -34,6 +34,7 @@ collect_colabfold_msa_data, get_sequence_hash, preprocess_colabfold_msas, + remap_colabfold_template_chain_ids, ) from openfold3.projects.of3_all_atom.config.dataset_config_components import MSASettings from openfold3.projects.of3_all_atom.config.dataset_configs import ( @@ -44,6 +45,52 @@ InferenceQuerySet, ) +_MOCK_FETCH_TARGET = ( + "openfold3.core.data.tools.colabfold_msa_server.fetch_label_to_author_chain_ids" +) +_MOCK_QUERY_TARGET = ( + "openfold3.core.data.tools.colabfold_msa_server.query_colabfold_msa_server" +) + +# Realistic label->author mappings for test PDB entries. +# 1RNB: label B -> author A (protein), label A -> author C (DNA) +# 4PQX: identity mapping +_MOCK_LABEL_TO_AUTHOR = { + "1rnb": {"A": "C", "B": "A"}, + "4pqx": {"A": "A"}, + "test": {"A": "A", "B": "B", "C": "C"}, +} + + +def _mock_fetch_label_to_author(pdb_ids): + """Return mock label->author mappings for known test PDB IDs.""" + return {pid: _MOCK_LABEL_TO_AUTHOR.get(pid, {}) for pid in pdb_ids} + + +def _make_m8_dataframe(template_ids: list[str], m_index: int = 101) -> pd.DataFrame: + """Build a minimal m8-format DataFrame for testing. + + See docs/source/template_how_to.md § 1.1.3 for the m8 column spec. + """ + n = len(template_ids) + return pd.DataFrame( + { + 0: [m_index] * n, + 1: template_ids, + 2: [0.98] * n, + 3: [100] * n, + 4: [1] * n, + 5: [0] * n, + 6: [1] * n, + 7: [100] * n, + 8: [1] * n, + 9: [100] * n, + 10: [1e-10] * n, + 11: [100] * n, + 12: ["100M"] * n, + } + ) + @pytest.fixture def multimer_query_set(): @@ -94,6 +141,47 @@ def test_complex_id_same_on_permutation_of_sequences(self): assert ComplexGroup(order1).rep_id == ComplexGroup(order2).rep_id +class TestRemapColabfoldTemplateChainIds: + """Tests for remap_colabfold_template_chain_ids (RCSB calls mocked).""" + + @patch(_MOCK_FETCH_TARGET, side_effect=_mock_fetch_label_to_author) + def test_remap_author_to_label(self, _mock_fetch): + """1rnb_A (author) should be remapped to 1rnb_B (label).""" + result = remap_colabfold_template_chain_ids( + template_alignments=_make_m8_dataframe(["1rnb_A", "4pqx_A"]), + m_with_templates={101}, + rep_ids=["rep1"], + rep_id_to_m={"rep1": 101}, + ) + + assert "rep1" in result + remapped_ids = result["rep1"][1].tolist() + assert remapped_ids[0] == "1rnb_B" + assert remapped_ids[1] == "4pqx_A" + + @patch(_MOCK_FETCH_TARGET, side_effect=_mock_fetch_label_to_author) + def test_unknown_author_chain_raises(self, _mock_fetch): + """When the author chain ID isn't in the API response, raise.""" + with pytest.raises(RuntimeError, match="Author chain Z not found in 1rnb"): + remap_colabfold_template_chain_ids( + template_alignments=_make_m8_dataframe(["1rnb_Z"]), + m_with_templates={101}, + rep_ids=["rep1"], + rep_id_to_m={"rep1": 101}, + ) + + def test_skips_rep_without_templates(self): + """Rep IDs not in m_with_templates should be skipped (no fetch needed).""" + result = remap_colabfold_template_chain_ids( + template_alignments=_make_m8_dataframe(["1rnb_A"]), + m_with_templates={999}, + rep_ids=["rep1"], + rep_id_to_m={"rep1": 101}, + ) + + assert len(result) == 0 + + class TestColabFoldQueryRunner: def _construct_monomer_query(self, sequence): return InferenceQuerySet.model_validate( @@ -145,10 +233,12 @@ def _make_empty_template_file(path: Path): # Create an empty file (0 bytes) (raw_main_dir / "pdb70.m8").touch() - @patch("openfold3.core.data.tools.colabfold_msa_server.query_colabfold_msa_server") + @patch(_MOCK_FETCH_TARGET, side_effect=_mock_fetch_label_to_author) + @patch(_MOCK_QUERY_TARGET) def test_runner_on_multimer_example( self, mock_query, + _mock_chain_map, tmp_path, multimer_query_set, multimer_sequences, @@ -180,16 +270,15 @@ def test_runner_on_multimer_example( assert (expected_unpaired_dir / f).exists() assert (expected_paired_dir / f).exists() - @patch( - "openfold3.core.data.tools.colabfold_msa_server.query_colabfold_msa_server", - side_effect=_construct_dummy_a3m, - ) + @patch(_MOCK_FETCH_TARGET, side_effect=_mock_fetch_label_to_author) + @patch(_MOCK_QUERY_TARGET, side_effect=_construct_dummy_a3m) @pytest.mark.parametrize( "msa_file_format", ["a3m", "npz"], ids=lambda fmt: f"format={fmt}" ) def test_msa_generation_on_multiple_queries_with_same_name( self, mock_query, + _mock_chain_map, tmp_path, msa_file_format, ): @@ -259,16 +348,15 @@ def test_augment_main_msa_with_query_sequence( f"Unexpected MSA path in augmented query set: {paths_in_augmented[0]}" ) - @patch( - "openfold3.core.data.tools.colabfold_msa_server.query_colabfold_msa_server", - side_effect=_construct_dummy_a3m, - ) + @patch(_MOCK_FETCH_TARGET, side_effect=_mock_fetch_label_to_author) + @patch(_MOCK_QUERY_TARGET, side_effect=_construct_dummy_a3m) @pytest.mark.parametrize( "msa_file_format", ["a3m", "npz"], ids=lambda fmt: f"{fmt}" ) def test_features_on_multiple_queries_with_same_name( self, mock_query, + _mock_chain_map, tmp_path, msa_file_format, ): @@ -325,10 +413,7 @@ def test_features_on_multiple_queries_with_same_name( "Expected all test sequences to be present in the mapping file" ) - @patch( - "openfold3.core.data.tools.colabfold_msa_server.query_colabfold_msa_server", - side_effect=_construct_dummy_a3m, - ) + @patch(_MOCK_QUERY_TARGET, side_effect=_construct_dummy_a3m) def test_empty_m8_file_handling( self, mock_query, @@ -404,7 +489,7 @@ class TestMsaComputationSettings: def test_cli_output_dir_overrides_config(self, tmp_path): """Test that CLI output directory overrides config file setting.""" test_yaml_str = textwrap.dedent("""\ - msa_file_format: a3m + msa_file_format: a3m server_user_agent: test-agent server_url: https://dummy.url """) @@ -423,7 +508,7 @@ def test_cli_output_dir_overrides_config(self, tmp_path): def test_cli_output_dir_conflict_raises(self, tmp_path): """Test that conflict between CLI and config output dirs raises ValueError.""" test_yaml_str = textwrap.dedent(f"""\ - msa_file_format: a3m + msa_file_format: a3m msa_output_directory: {tmp_path / "other_dir"} """) test_yaml_file = tmp_path / "runner.yml" diff --git a/openfold3/tests/core/data/tools/test_rscb.py b/openfold3/tests/core/data/tools/test_rscb.py new file mode 100644 index 000000000..6f4a022d3 --- /dev/null +++ b/openfold3/tests/core/data/tools/test_rscb.py @@ -0,0 +1,87 @@ +"""Tests for the RCSB GraphQL API helpers in ``openfold3.core.data.tools.rscb``. + +Tests marked ``@pytest.mark.vcr`` use *pytest-recording* (vcrpy) to replay +HTTP responses from YAML cassettes stored alongside this file in +``cassettes/``. + +Generating cassettes for the first time:: + + pytest openfold3/tests/core/data/tools/test_rscb.py --vcr-record=all + +Re-recording after RCSB schema changes or new test methods:: + + pytest openfold3/tests/core/data/tools/test_rscb.py --vcr-record=new_episodes + +In CI the cassettes are replayed without network access (the default +``--vcr-record=none`` mode). +""" + +import pytest + +from openfold3.core.data.tools.rscb import ( + fetch_label_to_author_chain_ids, + get_model_ranking_fit, +) + + +class TestFetchLabelToAuthorChainIds: + """Tests for fetch_label_to_author_chain_ids (recorded RCSB responses).""" + + @pytest.mark.vcr + def test_1rnb_label_to_author(self): + """1RNB: label chain B -> author chain A (protein).""" + result = fetch_label_to_author_chain_ids({"1rnb"}) + + assert "1rnb" in result + l2a = result["1rnb"] + assert l2a["B"] == "A" + assert l2a["A"] == "C" + + @pytest.mark.vcr + def test_identity_mapping(self): + """4PQX: label chain IDs match author chain IDs.""" + result = fetch_label_to_author_chain_ids({"4pqx"}) + + assert "4pqx" in result + assert result["4pqx"]["A"] == "A" + + @pytest.mark.vcr + def test_batch_query(self): + """Multiple PDB IDs are fetched in a single request.""" + result = fetch_label_to_author_chain_ids({"1rnb", "4pqx"}) + + assert "1rnb" in result + assert "4pqx" in result + + def test_empty_set(self): + """Empty input returns empty dict without API call.""" + assert fetch_label_to_author_chain_ids(set()) == {} + + +class TestGetModelRankingFit: + """Tests for get_model_ranking_fit (recorded RCSB responses).""" + + @pytest.mark.vcr + def test_entry_with_ligands(self): + """4PQX has ligands with ranking_model_fit scores.""" + result = get_model_ranking_fit("4pqx") + + assert isinstance(result, dict) + assert len(result) > 0 + for rcsb_id, score in result.items(): + assert rcsb_id.startswith("4PQX.") + assert isinstance(score, (int, float)) + + @pytest.mark.vcr + def test_entry_without_ligands(self): + """1RNB (protein-only) returns empty dict.""" + result = get_model_ranking_fit("1rnb") + + assert result == {} + + @pytest.mark.vcr + def test_nonexistent_entry(self): + """Invalid PDB ID returns empty dict without raising.""" + result = get_model_ranking_fit("0000") + + assert result == {} diff --git a/openfold3/tests/test_template.py b/openfold3/tests/core/model/latent/test_template_module.py similarity index 100% rename from openfold3/tests/test_template.py rename to openfold3/tests/core/model/latent/test_template_module.py diff --git a/openfold3/tests/test_data/cassettes/test_rscb/TestFetchLabelToAuthorChainIds.test_1rnb_label_to_author.yaml b/openfold3/tests/test_data/cassettes/test_rscb/TestFetchLabelToAuthorChainIds.test_1rnb_label_to_author.yaml new file mode 100644 index 000000000..3336d8564 --- /dev/null +++ b/openfold3/tests/test_data/cassettes/test_rscb/TestFetchLabelToAuthorChainIds.test_1rnb_label_to_author.yaml @@ -0,0 +1,44 @@ +interactions: +- request: + body: '{"query": "\nquery($ids: [String!]!) {\n entries(entry_ids: $ids) {\n rcsb_id\n polymer_entities + {\n rcsb_polymer_entity_container_identifiers {\n asym_ids\n auth_asym_ids\n }\n }\n }\n}\n", + "variables": {"ids": ["1rnb"]}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br, zstd + Connection: + - keep-alive + Content-Length: + - '257' + Content-Type: + - application/json + User-Agent: + - python-requests/2.32.5 + method: POST + uri: https://data.rcsb.org/graphql + response: + body: + string: '{"data":{"entries":[{"rcsb_id":"1RNB","polymer_entities":[{"rcsb_polymer_entity_container_identifiers":{"asym_ids":["B"],"auth_asym_ids":["A"]}},{"rcsb_polymer_entity_container_identifiers":{"asym_ids":["A"],"auth_asym_ids":["C"]}}]}]}}' + headers: + access-control-allow-headers: + - X-Requested-With,Content-Type,Authorization,rcsb-analytics-traffic-origin,rcsb-analytics-traffic-stage + access-control-allow-methods: + - GET, OPTIONS, HEAD, PUT, POST + access-control-allow-origin: + - '*' + alt-svc: + - h3=":443";ma=60; + content-length: + - '236' + content-type: + - application/json;charset=UTF-8 + date: + - Wed, 25 Mar 2026 12:23:16 GMT + strict-transport-security: + - max-age=16000000; includeSubDomains; preload; + status: + code: 200 + message: OK +version: 1 diff --git a/openfold3/tests/test_data/cassettes/test_rscb/TestFetchLabelToAuthorChainIds.test_batch_query.yaml b/openfold3/tests/test_data/cassettes/test_rscb/TestFetchLabelToAuthorChainIds.test_batch_query.yaml new file mode 100644 index 000000000..b6dadb312 --- /dev/null +++ b/openfold3/tests/test_data/cassettes/test_rscb/TestFetchLabelToAuthorChainIds.test_batch_query.yaml @@ -0,0 +1,44 @@ +interactions: +- request: + body: '{"query": "\nquery($ids: [String!]!) {\n entries(entry_ids: $ids) {\n rcsb_id\n polymer_entities + {\n rcsb_polymer_entity_container_identifiers {\n asym_ids\n auth_asym_ids\n }\n }\n }\n}\n", + "variables": {"ids": ["1rnb", "4pqx"]}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br, zstd + Connection: + - keep-alive + Content-Length: + - '265' + Content-Type: + - application/json + User-Agent: + - python-requests/2.32.5 + method: POST + uri: https://data.rcsb.org/graphql + response: + body: + string: '{"data":{"entries":[{"rcsb_id":"1RNB","polymer_entities":[{"rcsb_polymer_entity_container_identifiers":{"asym_ids":["A"],"auth_asym_ids":["C"]}},{"rcsb_polymer_entity_container_identifiers":{"asym_ids":["B"],"auth_asym_ids":["A"]}}]},{"rcsb_id":"4PQX","polymer_entities":[{"rcsb_polymer_entity_container_identifiers":{"asym_ids":["A","B","C","D"],"auth_asym_ids":["A","B","C","D"]}}]}]}}' + headers: + access-control-allow-headers: + - X-Requested-With,Content-Type,Authorization,rcsb-analytics-traffic-origin,rcsb-analytics-traffic-stage + access-control-allow-methods: + - GET, OPTIONS, HEAD, PUT, POST + access-control-allow-origin: + - '*' + alt-svc: + - h3=":443";ma=60; + content-length: + - '387' + content-type: + - application/json;charset=UTF-8 + date: + - Wed, 25 Mar 2026 12:23:16 GMT + strict-transport-security: + - max-age=16000000; includeSubDomains; preload; + status: + code: 200 + message: OK +version: 1 diff --git a/openfold3/tests/test_data/cassettes/test_rscb/TestFetchLabelToAuthorChainIds.test_identity_mapping.yaml b/openfold3/tests/test_data/cassettes/test_rscb/TestFetchLabelToAuthorChainIds.test_identity_mapping.yaml new file mode 100644 index 000000000..acf97c4b3 --- /dev/null +++ b/openfold3/tests/test_data/cassettes/test_rscb/TestFetchLabelToAuthorChainIds.test_identity_mapping.yaml @@ -0,0 +1,44 @@ +interactions: +- request: + body: '{"query": "\nquery($ids: [String!]!) {\n entries(entry_ids: $ids) {\n rcsb_id\n polymer_entities + {\n rcsb_polymer_entity_container_identifiers {\n asym_ids\n auth_asym_ids\n }\n }\n }\n}\n", + "variables": {"ids": ["4pqx"]}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br, zstd + Connection: + - keep-alive + Content-Length: + - '257' + Content-Type: + - application/json + User-Agent: + - python-requests/2.32.5 + method: POST + uri: https://data.rcsb.org/graphql + response: + body: + string: '{"data":{"entries":[{"rcsb_id":"4PQX","polymer_entities":[{"rcsb_polymer_entity_container_identifiers":{"asym_ids":["A","B","C","D"],"auth_asym_ids":["A","B","C","D"]}}]}]}}' + headers: + access-control-allow-headers: + - X-Requested-With,Content-Type,Authorization,rcsb-analytics-traffic-origin,rcsb-analytics-traffic-stage + access-control-allow-methods: + - GET, OPTIONS, HEAD, PUT, POST + access-control-allow-origin: + - '*' + alt-svc: + - h3=":443";ma=60; + content-length: + - '173' + content-type: + - application/json;charset=UTF-8 + date: + - Wed, 25 Mar 2026 12:23:16 GMT + strict-transport-security: + - max-age=16000000; includeSubDomains; preload; + status: + code: 200 + message: OK +version: 1 diff --git a/openfold3/tests/test_data/cassettes/test_rscb/TestGetModelRankingFit.test_entry_with_ligands.yaml b/openfold3/tests/test_data/cassettes/test_rscb/TestGetModelRankingFit.test_entry_with_ligands.yaml new file mode 100644 index 000000000..771e814ad --- /dev/null +++ b/openfold3/tests/test_data/cassettes/test_rscb/TestGetModelRankingFit.test_entry_with_ligands.yaml @@ -0,0 +1,50 @@ +interactions: +- request: + body: '{"query": "\nquery GetRankingFit($pdb_id: String!) {\n entry(entry_id: + $pdb_id) {\n nonpolymer_entities {\n nonpolymer_entity_instances + {\n rcsb_id\n rcsb_nonpolymer_instance_validation_score + {\n ranking_model_fit\n }\n }\n }\n }\n}\n", + "variables": {"pdb_id": "4pqx"}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br, zstd + Connection: + - keep-alive + Content-Length: + - '375' + Content-Type: + - application/json + User-Agent: + - python-requests/2.32.5 + method: POST + uri: https://data.rcsb.org/graphql + response: + body: + string: '{"data":{"entry":{"nonpolymer_entities":[{"nonpolymer_entity_instances":[{"rcsb_id":"4PQX.L","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.4034}]}]},{"nonpolymer_entity_instances":[{"rcsb_id":"4PQX.Z","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.334}]},{"rcsb_id":"4PQX.Y","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.2004}]},{"rcsb_id":"4PQX.W","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.3206}]},{"rcsb_id":"4PQX.V","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.5388}]},{"rcsb_id":"4PQX.K","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.197}]}]},{"nonpolymer_entity_instances":[{"rcsb_id":"4PQX.X","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.2365}]},{"rcsb_id":"4PQX.U","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.0197}]},{"rcsb_id":"4PQX.T","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.0619}]},{"rcsb_id":"4PQX.S","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.3467}]},{"rcsb_id":"4PQX.R","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.4149}]},{"rcsb_id":"4PQX.Q","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.2279}]},{"rcsb_id":"4PQX.P","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.2422}]},{"rcsb_id":"4PQX.O","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.4178}]},{"rcsb_id":"4PQX.N","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.4496}]},{"rcsb_id":"4PQX.M","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.4725}]},{"rcsb_id":"4PQX.J","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.079}]},{"rcsb_id":"4PQX.I","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.6692}]},{"rcsb_id":"4PQX.H","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.2255}]},{"rcsb_id":"4PQX.G","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.1442}]},{"rcsb_id":"4PQX.F","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.2186}]},{"rcsb_id":"4PQX.E","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":0.6289}]}]}]}}}' + headers: + access-control-allow-headers: + - X-Requested-With,Content-Type,Authorization,rcsb-analytics-traffic-origin,rcsb-analytics-traffic-stage + access-control-allow-methods: + - GET, OPTIONS, HEAD, PUT, POST + access-control-allow-origin: + - '*' + alt-svc: + - h3=":443";ma=60; + content-encoding: + - gzip + content-type: + - application/json;charset=UTF-8 + date: + - Wed, 25 Mar 2026 12:23:16 GMT + strict-transport-security: + - max-age=16000000; includeSubDomains; preload; + transfer-encoding: + - chunked + vary: + - accept-encoding + status: + code: 200 + message: OK +version: 1 diff --git a/openfold3/tests/test_data/cassettes/test_rscb/TestGetModelRankingFit.test_entry_without_ligands.yaml b/openfold3/tests/test_data/cassettes/test_rscb/TestGetModelRankingFit.test_entry_without_ligands.yaml new file mode 100644 index 000000000..570b642ae --- /dev/null +++ b/openfold3/tests/test_data/cassettes/test_rscb/TestGetModelRankingFit.test_entry_without_ligands.yaml @@ -0,0 +1,46 @@ +interactions: +- request: + body: '{"query": "\nquery GetRankingFit($pdb_id: String!) {\n entry(entry_id: + $pdb_id) {\n nonpolymer_entities {\n nonpolymer_entity_instances + {\n rcsb_id\n rcsb_nonpolymer_instance_validation_score + {\n ranking_model_fit\n }\n }\n }\n }\n}\n", + "variables": {"pdb_id": "1rnb"}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br, zstd + Connection: + - keep-alive + Content-Length: + - '375' + Content-Type: + - application/json + User-Agent: + - python-requests/2.32.5 + method: POST + uri: https://data.rcsb.org/graphql + response: + body: + string: '{"data":{"entry":{"nonpolymer_entities":[{"nonpolymer_entity_instances":[{"rcsb_id":"1RNB.C","rcsb_nonpolymer_instance_validation_score":[{"ranking_model_fit":null}]}]}]}}}' + headers: + access-control-allow-headers: + - X-Requested-With,Content-Type,Authorization,rcsb-analytics-traffic-origin,rcsb-analytics-traffic-stage + access-control-allow-methods: + - GET, OPTIONS, HEAD, PUT, POST + access-control-allow-origin: + - '*' + alt-svc: + - h3=":443";ma=60; + content-length: + - '172' + content-type: + - application/json;charset=UTF-8 + date: + - Wed, 25 Mar 2026 12:23:16 GMT + strict-transport-security: + - max-age=16000000; includeSubDomains; preload; + status: + code: 200 + message: OK +version: 1 diff --git a/openfold3/tests/test_data/cassettes/test_rscb/TestGetModelRankingFit.test_nonexistent_entry.yaml b/openfold3/tests/test_data/cassettes/test_rscb/TestGetModelRankingFit.test_nonexistent_entry.yaml new file mode 100644 index 000000000..6cdbf7b20 --- /dev/null +++ b/openfold3/tests/test_data/cassettes/test_rscb/TestGetModelRankingFit.test_nonexistent_entry.yaml @@ -0,0 +1,46 @@ +interactions: +- request: + body: '{"query": "\nquery GetRankingFit($pdb_id: String!) {\n entry(entry_id: + $pdb_id) {\n nonpolymer_entities {\n nonpolymer_entity_instances + {\n rcsb_id\n rcsb_nonpolymer_instance_validation_score + {\n ranking_model_fit\n }\n }\n }\n }\n}\n", + "variables": {"pdb_id": "0000"}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br, zstd + Connection: + - keep-alive + Content-Length: + - '375' + Content-Type: + - application/json + User-Agent: + - python-requests/2.32.5 + method: POST + uri: https://data.rcsb.org/graphql + response: + body: + string: '{"data":{"entry":null}}' + headers: + access-control-allow-headers: + - X-Requested-With,Content-Type,Authorization,rcsb-analytics-traffic-origin,rcsb-analytics-traffic-stage + access-control-allow-methods: + - GET, OPTIONS, HEAD, PUT, POST + access-control-allow-origin: + - '*' + alt-svc: + - h3=":443";ma=60; + content-length: + - '23' + content-type: + - application/json;charset=UTF-8 + date: + - Wed, 25 Mar 2026 12:23:16 GMT + strict-transport-security: + - max-age=16000000; includeSubDomains; preload; + status: + code: 200 + message: OK +version: 1 diff --git a/pyproject.toml b/pyproject.toml index 6837874bd..c428ef6cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ test = [ "pytest-cov", "pytest-benchmark", "debugpy", + "pytest-recording", ] [project.optional-dependencies]