Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 31 additions & 33 deletions openfold3/core/data/tools/colabfold_msa_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __str__(self) -> str:

def query_colabfold_msa_server(
x: list[str],
*,
prefix: Path,
user_agent: str,
use_templates: bool = False,
Expand All @@ -75,7 +76,7 @@ def query_colabfold_msa_server(
use_filter: bool = True,
filter: bool | None = None,
host_url: str = "https://api.colabfold.com",
) -> list[str] | tuple[list[str], list[str]]:
) -> list[str] | tuple[list[str], list[str | None]]:
"""Submints a single query to the colabfold MSA server.

Adapted from Colabfold run_mmseqs2 https://github.com/sokrypton/ColabFold/blob/main/colabfold/colabfold.py#L69
Expand Down Expand Up @@ -124,7 +125,7 @@ def query_colabfold_msa_server(
"in the future."
)

def submit(seqs, mode, N=101):
def submit(seqs: list[str], mode: str, N: int = 101) -> dict[str, str]:
n, query = N, ""
for seq in seqs:
query += f">{n}\n{seq}\n"
Expand Down Expand Up @@ -159,13 +160,13 @@ def submit(seqs, mode, N=101):
break

try:
out = res.json()
out: dict[str, str] = res.json()
except ValueError:
logger.error(f"Server didn't reply with json: {res.text}")
out = {"status": "ERROR"}
return out

def status(ID):
def status(ID: str) -> dict[str, str]:
while True:
error_count = 0
try:
Expand All @@ -190,13 +191,13 @@ def status(ID):
continue
break
try:
out = res.json()
out: dict[str, str] = res.json()
except ValueError:
logger.error(f"Server didn't reply with json: {res.text}")
out = {"status": "ERROR"}
return out

def download(ID, path):
def download(ID: str, path: str) -> None:
error_count = 0
while True:
try:
Expand Down Expand Up @@ -235,14 +236,14 @@ def download(ID, path):
else:
mode = "env-nofilter" if use_env else "nofilter"
# TODO move to config construction
pairing_strategy = MsaServerPairingStrategy[pairing_strategy.upper()]
pairing_mode = MsaServerPairingStrategy[pairing_strategy.upper()]
if use_pairing:
use_templates = False
mode = ""
# greedy is default, complete was the previous behavior
if pairing_strategy == MsaServerPairingStrategy.GREEDY:
if pairing_mode == MsaServerPairingStrategy.GREEDY:
mode = "pairgreedy"
elif pairing_strategy == MsaServerPairingStrategy.COMPLETE:
elif pairing_mode == MsaServerPairingStrategy.COMPLETE:
mode = "paircomplete"
if use_env:
mode = mode + "-env"
Expand All @@ -260,7 +261,9 @@ def download(ID, path):
seqs_unique = []
# TODO this might be slow for large sets - see main MSA deduplication code for a
# faster option
[seqs_unique.append(x) for x in seqs if x not in seqs_unique]
for s in seqs:
if s not in seqs_unique:
seqs_unique.append(s)
Ms = [N + seqs_unique.index(seq) for seq in seqs]

# Run query
Expand Down Expand Up @@ -336,17 +339,16 @@ def download(ID, path):

# Process templates
if use_templates:
templates = {}
templates: dict[int, list[str]] = {}
with open(f"{path}/pdb70.m8") as f:
for line in f:
p = line.rstrip().split()
M, pdb, _, _ = p[0], p[1], p[2], p[10] # M, pdb, qid, e_value
M = int(M)
if M not in templates:
templates[M] = []
templates[M].append(pdb)
m_idx, pdb = int(p[0]), p[1]
if m_idx not in templates:
templates[m_idx] = []
templates[m_idx].append(pdb)

template_paths = {}
template_paths_by_m: dict[int, str] = {}
for k, TMPL in templates.items():
TMPL_PATH = f"{prefix}/templates_{k}"
if not os.path.isdir(TMPL_PATH):
Expand Down Expand Up @@ -386,36 +388,32 @@ def download(ID, path):
os.symlink("pdb70_a3m.ffindex", f"{TMPL_PATH}/pdb70_cs219.ffindex")
with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f:
f.write("")
template_paths[k] = TMPL_PATH
template_paths_by_m[k] = TMPL_PATH

template_paths_ = []
for n in Ms:
if n not in template_paths:
template_paths_.append(None)
else:
template_paths_.append(template_paths[n])
template_paths = template_paths_
template_paths_list: list[str | None] = [template_paths_by_m.get(n) for n in Ms]

# Gather a3m lines
a3m_lines = {}
a3m_by_m: dict[int, list[str]] = {}
for a3m_file in a3m_files:
update_M, M = True, None
update_M = True
current_m: int | None = None
with open(a3m_file) as f:
for line in f:
if len(line) > 0:
if "\x00" in line:
line = line.replace("\x00", "")
update_M = True
if line.startswith(">") and update_M:
M = int(line[1:].rstrip())
current_m = int(line[1:].rstrip())
update_M = False
if M not in a3m_lines:
a3m_lines[M] = []
a3m_lines[M].append(line)
if current_m not in a3m_by_m:
a3m_by_m[current_m] = []
if current_m is not None:
a3m_by_m[current_m].append(line)

a3m_lines = ["".join(a3m_lines[n]) for n in Ms]
a3m_lines_out = ["".join(a3m_by_m[n]) for n in Ms]

return (a3m_lines, template_paths) if use_templates else a3m_lines
return (a3m_lines_out, template_paths_list) if use_templates else a3m_lines_out


class ChainInput(NamedTuple):
Expand Down
75 changes: 75 additions & 0 deletions openfold3/tests/core/data/pipelines/preprocessing/test_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from pathlib import Path

import openfold3
from openfold3.core.data.io.sequence.template import (
A3mParser,
parse_template_alignment,
)
from openfold3.core.data.io.structure.cif import _load_ciffile
from openfold3.core.data.primitives.structure.metadata import (
get_asym_id_to_canonical_seq_dict,
get_author_to_label_chain_ids,
)

_TEST_DATA_DIR = Path(openfold3.__file__).parent / "tests" / "test_data"


class TestTemplatePreprocessor:
    """Regression tests for template chain-ID handling in preprocessing."""

    def test_template_has_author_chain_id(self):
        """Verify author->label chain ID resolution for 1RNB.

        https://github.com/aqlaboratory/openfold-3/issues/101

        In 1RNB, author chain "A" is label chain "B" (the protein barnase).
        The ColabFold alignment reports "1rnb_A" which must be resolved to
        label chain "B" before the sequence can be looked up.
        """
        alignment_file = (
            _TEST_DATA_DIR / "template_alignments" / "colabfold_template.m8"
        )
        query_seq_str = "AQVINTFDGVADYLQTYHKLPDNYITKSEAQALGWVASKGNLADVAPGKSIGGDIFSNREGKLPGKSGRTWREADINYTSGFRNSDRILYSSDWLIYKTTDHYQTFTKIR"
        templates = parse_template_alignment(
            aln_path=alignment_file,
            query_seq_str=query_seq_str,
            max_sequences=200,
        )

        # Entry 16 of the fixture alignment is expected to be the offending
        # "1rnb_A" hit. The two assertions below are separate so that a
        # reordered fixture reports exactly which field no longer matches.
        template = templates[16]
        assert template.entry_id == "1rnb", (
            f"expected entry 1rnb at index 16, got {template.entry_id!r}"
        )
        assert template.chain_id == "A", (
            f"expected author chain A at index 16, got {template.chain_id!r}"
        )

        template_structure_file = _TEST_DATA_DIR / "mmcifs" / f"{template.entry_id}.cif"
        cif_file = _load_ciffile(template_structure_file)

        chain_id_seq_map = get_asym_id_to_canonical_seq_dict(cif_file)

        # Build the label -> author chain mapping from the per-residue
        # pdbx_poly_seq_scheme table (asym_id = label, pdb_strand_id = author).
        poly_scheme = cif_file.block["pdbx_poly_seq_scheme"]
        label_to_author = dict(
            zip(
                poly_scheme["asym_id"].as_array().tolist(),
                poly_scheme["pdb_strand_id"].as_array().tolist(),
                strict=True,
            )
        )
        author_to_label_chain_ids = get_author_to_label_chain_ids(label_to_author)
        label_chain_id = author_to_label_chain_ids[template.chain_id][0]

        # Author "A" -> label "B" (the protein chain)
        assert label_chain_id == "B"

        # Fail with a clear message instead of a TypeError from len(None)
        # if the resolved label chain has no canonical sequence.
        template_sequence = chain_id_seq_map.get(label_chain_id)
        assert template_sequence is not None, (
            f"no canonical sequence found for label chain {label_chain_id!r}"
        )

        # Realign the template sequence against the query to confirm that
        # the resolved chain yields a usable alignment.
        parser = A3mParser(max_sequences=None)
        parsed = parser(
            (
                f">query_X/1-{len(query_seq_str)}\n"
                f"{query_seq_str}\n"
                f">{template.entry_id}_{label_chain_id}/1-{len(template_sequence)}\n"
                f"{template_sequence}\n"
            ),
            query_seq_str,
            realign=True,
        )

        assert len(parsed) == 2
        assert parsed[0].seq_id == 1
        assert parsed[1].seq_id < 1
54 changes: 54 additions & 0 deletions openfold3/tests/core/data/tools/test_colabfold_msa_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
collect_colabfold_msa_data,
get_sequence_hash,
preprocess_colabfold_msas,
query_colabfold_msa_server,
remap_colabfold_template_chain_ids,
)
from openfold3.projects.of3_all_atom.config.dataset_config_components import MSASettings
Expand Down Expand Up @@ -524,3 +525,56 @@ def test_cli_output_dir_conflict_raises(self, tmp_path):
assert "Output directory mismatch" in str(exc_info.value), (
"Expected ValueError on output directory conflict"
)


# Barnase sequence — 1RNB author chain A = label chain B.
# Passed as the single query sequence in TestQueryColabfoldMsaServer below.
# The value is written as two adjacent string literals (implicitly
# concatenated into one string) only to keep the source lines short.
_BARNASE_SEQ = (
    "AQVINTFDGVADYLQTYHKLPDNYITKSEAQALGWVASKGNLADVAPGKSIGGDIFSNREGKLPGK"
    "SGRTWREADINYTSGFRNSDRILYSSDWLIYKTTDHYQTFTKIR"
)


class TestQueryColabfoldMsaServer:
"""Functional test — hits real ColabFold API (~30-60s)."""

def test_barnase_with_templates(self, tmp_path):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I bet we could reuse the ColabFold mmseqs mock itself:
Mock definition: https://github.com/sokrypton/ColabFold/blob/main/tests/mock.py#L109
Usage: https://github.com/sokrypton/ColabFold/blob/9712f2ff262d3977d571919317e06cc96c29cd95/tests/test_msa.py#L7

That would probably be the cleanest way to handle it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would have to vendor it though :/ Not ideal – ColabFold is a very bulky dependency to add; I'd rather avoid it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've also modified their code a fair bit, I think.

Copy link
Copy Markdown
Collaborator Author

@jandom jandom Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there @jnwei, let me revive this PR – I think we can use the pytest-VCR approach for these tests.

Ah, scratch that – we need the full ColabFold job to complete. They appear to be using some sort of caching; subsequent calls are basically instant.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just reviewed the mock code you provided – yeah, let's vendor that; it'll be the easiest.

raw_dir = tmp_path / "raw"
a3m_lines, template_paths = query_colabfold_msa_server(
x=[_BARNASE_SEQ],
prefix=raw_dir,
user_agent="openfold-test/1.0",
use_templates=True,
use_pairing=False,
use_env=True,
use_filter=True,
)

# -- Returned values --
assert len(a3m_lines) == 1
assert a3m_lines[0].startswith(f">101\n{_BARNASE_SEQ}\n")
n_seqs = sum(1 for l in a3m_lines[0].strip().split("\n") if l.startswith(">"))
assert n_seqs > 10 # barnase is well-represented

assert len(template_paths) == 1
tpl_dir = Path(template_paths[0])
assert tpl_dir.name == "templates_101"
assert tpl_dir.is_dir()
cif_files = list(tpl_dir.glob("*.cif"))
assert len(cif_files) > 0

# -- Files on disk --
assert (raw_dir / "uniref.a3m").stat().st_size > 0
assert (raw_dir / "bfd.mgnify30.metaeuk30.smag30.a3m").stat().st_size > 0
assert (raw_dir / "out.tar.gz").exists()

pdb70 = raw_dir / "pdb70.m8"
assert pdb70.stat().st_size > 0
m8_lines = pdb70.read_text().strip().split("\n")
template_ids = [l.split("\t")[1] for l in m8_lines]

# All hits should be for M-index 101
assert all(l.split("\t")[0] == "101" for l in m8_lines)
# Template IDs are pdb_chain format
assert all("_" in tid for tid in template_ids)
# 1rnb_A must appear (the bug case — author chain A != label chain B)
assert "1rnb_A" in template_ids
Loading
Loading