Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions refold/refold_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def run(self, action: str, **kwargs) -> dict[str, Any]:
"make_af3_json_pbp_design_only": self.make_af3_json_pbp_design_only,
"make_chai1_fasta_multi_process": self.make_chai1_fasta_multi_process,
"prepare_chai1_fasta": self.make_chai1_fasta_multi_process,
"make_chai1_fasta_from_backbone_dir": self.make_chai1_fasta_from_backbone_dir,
}
if action not in dispatch:
known = ", ".join(sorted(dispatch.keys()))
Expand Down Expand Up @@ -1512,3 +1513,179 @@ def generate_chai1_input_fasta(backbone_path: Path, output_path: Path) -> bool:
"fail_count": fail_count,
},
)

def make_chai1_fasta_from_backbone_dir(
self,
backbone_dir: str,
output_dir: str,
ccd_path: str = None,
inverse_fold_dir: str = None,
):
"""
Generate chai1 input FASTA files for LBP/interface tasks.
Unlike make_chai1_fasta_multi_process (AME-only), this method reads ligand
SMILES directly from each backbone PDB using the CCD database, without
requiring an AME CSV or task-name lookup.

Args:
backbone_dir: Directory containing backbone PDB files from LigandMPNN.
output_dir: Output directory for FASTA files.
ccd_path: Path to CCD components.cif file. Searches common locations if None.
inverse_fold_dir: Root inversefold dir (parent of backbones/); used to find
seqs/ for LigandMPNN-generated sequences.
"""
import re
import concurrent.futures
from rdkit import Chem

backbone_path_list = list(Path(backbone_dir).glob("*.pdb"))
os.makedirs(output_dir, exist_ok=True)

# Locate CCD components.cif
if ccd_path is None:
candidates = [
Path(__file__).resolve().parents[2] / "ODesign-pipeline" / "ODesign" / "data" / "components.v20240608.cif",
Path(__file__).resolve().parents[1] / "data" / "components.cif",
Path("/mnt/shared-storage-user/liuxinping/ODesign-pipeline/ODesign/data/components.v20240608.cif"),
]
for c in candidates:
if c.exists():
ccd_path = str(c)
break
if ccd_path is None or not Path(ccd_path).exists():
raise FileNotFoundError(
f"CCD components.cif not found. Tried: {candidates}. "
"Pass ccd_path= explicitly or set it in config."
)

ccd_parser = LocalCcdParser(ccd_path)

# Locate seqs/ directory for LigandMPNN-generated sequences
if inverse_fold_dir is None:
inverse_fold_dir = str(Path(backbone_dir).parent.parent / "inverse_fold")
seqs_dir = Path(inverse_fold_dir) / "seqs"

def get_sequence_from_seqs(backbone_stem: str) -> str:
"""Read the LigandMPNN sequence for this backbone from seqs/."""
seed_idx = None
if "-" in backbone_stem:
tail = backbone_stem.rsplit("-", 1)[1]
if tail.isdigit():
seed_idx = int(tail)

def _first_seq_in_fasta(fasta_file: Path, idx: int | None) -> str:
try:
with open(fasta_file) as f:
lines = [l.rstrip("\n") for l in f]
except Exception:
return ""
seqs = []
hdr, cur = None, None
for line in lines:
if line.startswith(">"):
if hdr is not None and cur is not None:
seqs.append((hdr, cur))
hdr, cur = line[1:].strip(), ""
elif hdr is not None:
cur += line.strip()
if hdr is not None and cur is not None:
seqs.append((hdr, cur))
if not seqs:
return ""
if idx is None:
return seqs[0][1].replace(":", "").replace(";", "").upper()
for header, seq in seqs:
m = re.search(r"id=(\d+)", header)
if m and int(m.group(1)) == idx:
return seq.replace(":", "").replace(";", "").upper()
i = idx - 1
if 0 <= i < len(seqs):
return seqs[i][1].replace(":", "").replace(";", "").upper()
return seqs[0][1].replace(":", "").replace(";", "").upper()

if not seqs_dir.exists():
return None
for fa in seqs_dir.glob(f"{backbone_stem}*.fa"):
s = _first_seq_in_fasta(fa, seed_idx)
if s:
return s
design_name = backbone_stem.rsplit("-", 1)[0] if "-" in backbone_stem else backbone_stem
for fa in seqs_dir.glob(f"{design_name}*.fa"):
s = _first_seq_in_fasta(fa, seed_idx)
if s:
return s
return None

def generate_fasta(backbone_path: Path, output_path: Path) -> bool:
"""Generate chai1 FASTA for one backbone PDB."""
# 1. Read backbone structure
try:
pdb_file = pdb.PDBFile.read(backbone_path)
atom_array = pdb.get_structure(pdb_file, model=1)
except Exception as e:
print(f"Warning: Failed to read {backbone_path}: {e}")
return False

# 2. Get protein sequence from LigandMPNN seqs/
protein_seq = get_sequence_from_seqs(backbone_path.stem)
if protein_seq is None:
# Fallback: derive from backbone atom array
try:
protein_atoms = atom_array[~atom_array.hetero]
if len(protein_atoms) > 0:
protein_seq = str(struc.to_sequence(protein_atoms)[0][0])
else:
print(f"Warning: No protein atoms in {backbone_path.name}")
return False
except Exception as e:
print(f"Warning: Cannot derive sequence from {backbone_path.name}: {e}")
return False

fasta_lines = [f">protein|name=protein\n{protein_seq}\n"]

# 3. Extract ligand SMILES from HETATM records
processed = set()
ligand_atoms = atom_array[atom_array.hetero]
for res_name in np.unique(ligand_atoms.res_name):
code = res_name.upper().strip()
if code in ("HOH", "WAT", "") or code in processed:
continue
smiles_result = ccd_parser.get_smiles(code)
if smiles_result is None:
print(f"Warning: No CCD SMILES for {code} in {backbone_path.name}")
continue
smiles = smiles_result[0] if isinstance(smiles_result, list) else smiles_result
smiles = smiles.strip("\"'")
if Chem.MolFromSmiles(smiles) is None:
print(f"Warning: Invalid SMILES for {code}: {smiles}")
continue
fasta_lines.append(f">ligand|name={code}\n{smiles}\n")
processed.add(code)

if len(fasta_lines) == 1:
print(f"Warning: No valid ligands found in {backbone_path.name}, writing protein-only FASTA")

output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
f.writelines(fasta_lines)
return True

with concurrent.futures.ThreadPoolExecutor(max_workers=self.config.refold.num_workers) as executor:
futures = {
executor.submit(generate_fasta, bp, Path(output_dir) / f"{bp.stem}.fasta"): bp
for bp in backbone_path_list
}
success_count = sum(1 for f in concurrent.futures.as_completed(futures) if f.result())
fail_count = len(backbone_path_list) - success_count

print(f"\n完成!成功: {success_count}, 失败: {fail_count}")
print(f"输出目录: {output_dir}")
return self._make_result(
stage="refold.make_chai1_fasta_from_backbone_dir",
outputs={"output_dir": str(output_dir)},
details={
"num_backbones": len(backbone_path_list),
"success_count": success_count,
"fail_count": fail_count,
},
)
4 changes: 2 additions & 2 deletions scripts/pipeline_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,10 +674,10 @@ def _run_refold_af3_pbp_dual(ctx: PipelineContext) -> None:

def _prepare_refold_chai1(ctx: PipelineContext) -> None:
    """Build chai1 input FASTA files and record the result on the context.

    Dispatches the ``make_chai1_fasta_from_backbone_dir`` refold action on the
    backbones under ``<pipeline_dir>/inverse_fold/backbones`` and stores the
    returned value in ``ctx.runtime["refold_prepare"]``.
    """
    inverse_fold_dir = ctx.pipeline_dir / "inverse_fold"
    # Each keyword is passed exactly once; ``origin_cwd`` is dropped because
    # make_chai1_fasta_from_backbone_dir does not declare that parameter.
    ctx.runtime["refold_prepare"] = ctx.refold_model.run(
        action="make_chai1_fasta_from_backbone_dir",
        backbone_dir=str(inverse_fold_dir / "backbones"),
        output_dir=str(ctx.pipeline_dir / "refold" / "chai1_inputs"),
        inverse_fold_dir=str(inverse_fold_dir),
    )


Expand Down