diff --git a/refold/refold_api.py b/refold/refold_api.py index bb17762..93b1cc7 100644 --- a/refold/refold_api.py +++ b/refold/refold_api.py @@ -171,6 +171,7 @@ def run(self, action: str, **kwargs) -> dict[str, Any]: "make_af3_json_pbp_design_only": self.make_af3_json_pbp_design_only, "make_chai1_fasta_multi_process": self.make_chai1_fasta_multi_process, "prepare_chai1_fasta": self.make_chai1_fasta_multi_process, + "make_chai1_fasta_from_backbone_dir": self.make_chai1_fasta_from_backbone_dir, } if action not in dispatch: known = ", ".join(sorted(dispatch.keys())) @@ -1512,3 +1513,179 @@ def generate_chai1_input_fasta(backbone_path: Path, output_path: Path) -> bool: "fail_count": fail_count, }, ) + + def make_chai1_fasta_from_backbone_dir( + self, + backbone_dir: str, + output_dir: str, + ccd_path: str = None, + inverse_fold_dir: str = None, + ): + """ + Generate chai1 input FASTA files for LBP/interface tasks. + Unlike make_chai1_fasta_multi_process (AME-only), this method reads ligand + SMILES directly from each backbone PDB using the CCD database, without + requiring an AME CSV or task-name lookup. + + Args: + backbone_dir: Directory containing backbone PDB files from LigandMPNN. + output_dir: Output directory for FASTA files. + ccd_path: Path to CCD components.cif file. Searches common locations if None. + inverse_fold_dir: Root inversefold dir (parent of backbones/); used to find + seqs/ for LigandMPNN-generated sequences. + """ + import re + import concurrent.futures + from rdkit import Chem + + backbone_path_list = list(Path(backbone_dir).glob("*.pdb")) + os.makedirs(output_dir, exist_ok=True) + + # Locate CCD components.cif + if ccd_path is None: + candidates = [ + Path(__file__).resolve().parents[2] / "ODesign-pipeline" / "ODesign" / "data" / "components.v20240608.cif", + Path(__file__).resolve().parents[1] / "data" / "components.cif", + Path("/mnt/shared-storage-user/liuxinping/ODesign-pipeline/ODesign/data/components.v20240608.cif"), + ] + for c in candidates: + if c.exists(): + ccd_path = str(c) + break + if ccd_path is None or not Path(ccd_path).exists(): + raise FileNotFoundError( + f"CCD components.cif not found. Tried: {candidates}. " + "Pass ccd_path= explicitly or set it in config." + ) + + ccd_parser = LocalCcdParser(ccd_path) + + # Locate seqs/ directory for LigandMPNN-generated sequences + if inverse_fold_dir is None: + inverse_fold_dir = str(Path(backbone_dir).parent.parent / "inverse_fold") + seqs_dir = Path(inverse_fold_dir) / "seqs" + + def get_sequence_from_seqs(backbone_stem: str) -> str: + """Read the LigandMPNN sequence for this backbone from seqs/.""" + seed_idx = None + if "-" in backbone_stem: + tail = backbone_stem.rsplit("-", 1)[1] + if tail.isdigit(): + seed_idx = int(tail) + + def _first_seq_in_fasta(fasta_file: Path, idx: int | None) -> str: + try: + with open(fasta_file) as f: + lines = [l.rstrip("\n") for l in f] + except Exception: + return "" + seqs = [] + hdr, cur = None, None + for line in lines: + if line.startswith(">"): + if hdr is not None and cur is not None: + seqs.append((hdr, cur)) + hdr, cur = line[1:].strip(), "" + elif hdr is not None: + cur += line.strip() + if hdr is not None and cur is not None: + seqs.append((hdr, cur)) + if not seqs: + return "" + if idx is None: + return seqs[0][1].replace(":", "").replace(";", "").upper() + for header, seq in seqs: + m = re.search(r"id=(\d+)", header) + if m and int(m.group(1)) == idx: + return seq.replace(":", "").replace(";", "").upper() + i = idx - 1 + if 0 <= i < len(seqs): + return seqs[i][1].replace(":", "").replace(";", "").upper() + return seqs[0][1].replace(":", "").replace(";", "").upper() + + if not seqs_dir.exists(): + return None + for fa in seqs_dir.glob(f"{backbone_stem}*.fa"): + s = _first_seq_in_fasta(fa, seed_idx) + if s: + return s + design_name = backbone_stem.rsplit("-", 1)[0] if "-" in backbone_stem else backbone_stem + for fa in seqs_dir.glob(f"{design_name}*.fa"): + s = _first_seq_in_fasta(fa, seed_idx) + if s: + return s + return None + + def generate_fasta(backbone_path: Path, output_path: Path) -> bool: + """Generate chai1 FASTA for one backbone PDB.""" + # 1. Read backbone structure + try: + pdb_file = pdb.PDBFile.read(backbone_path) + atom_array = pdb.get_structure(pdb_file, model=1) + except Exception as e: + print(f"Warning: Failed to read {backbone_path}: {e}") + return False + + # 2. Get protein sequence from LigandMPNN seqs/ + protein_seq = get_sequence_from_seqs(backbone_path.stem) + if protein_seq is None: + # Fallback: derive from backbone atom array + try: + protein_atoms = atom_array[~atom_array.hetero] + if len(protein_atoms) > 0: + protein_seq = str(struc.to_sequence(protein_atoms)[0][0]) + else: + print(f"Warning: No protein atoms in {backbone_path.name}") + return False + except Exception as e: + print(f"Warning: Cannot derive sequence from {backbone_path.name}: {e}") + return False + + fasta_lines = [f">protein|name=protein\n{protein_seq}\n"] + + # 3. Extract ligand SMILES from HETATM records + processed = set() + ligand_atoms = atom_array[atom_array.hetero] + for res_name in np.unique(ligand_atoms.res_name): + code = res_name.upper().strip() + if code in ("HOH", "WAT", "") or code in processed: + continue + smiles_result = ccd_parser.get_smiles(code) + if smiles_result is None: + print(f"Warning: No CCD SMILES for {code} in {backbone_path.name}") + continue + smiles = smiles_result[0] if isinstance(smiles_result, list) else smiles_result + smiles = smiles.strip("\"'") + if Chem.MolFromSmiles(smiles) is None: + print(f"Warning: Invalid SMILES for {code}: {smiles}") + continue + fasta_lines.append(f">ligand|name={code}\n{smiles}\n") + processed.add(code) + + if len(fasta_lines) == 1: + print(f"Warning: No valid ligands found in {backbone_path.name}, writing protein-only FASTA") + + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + f.writelines(fasta_lines) + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=self.config.refold.num_workers) as executor: + futures = { + executor.submit(generate_fasta, bp, Path(output_dir) / f"{bp.stem}.fasta"): bp + for bp in backbone_path_list + } + success_count = sum(1 for f in concurrent.futures.as_completed(futures) if f.result()) + fail_count = len(backbone_path_list) - success_count + + print(f"\n完成!成功: {success_count}, 失败: {fail_count}") + print(f"输出目录: {output_dir}") + return self._make_result( + stage="refold.make_chai1_fasta_from_backbone_dir", + outputs={"output_dir": str(output_dir)}, + details={ + "num_backbones": len(backbone_path_list), + "success_count": success_count, + "fail_count": fail_count, + }, + ) diff --git a/scripts/pipeline_framework.py b/scripts/pipeline_framework.py index 63b8db3..bcbfdb3 100644 --- a/scripts/pipeline_framework.py +++ b/scripts/pipeline_framework.py @@ -674,10 +674,10 @@ def _run_refold_af3_pbp_dual(ctx: PipelineContext) -> None: def _prepare_refold_chai1(ctx: PipelineContext) -> None: ctx.runtime["refold_prepare"] = ctx.refold_model.run( - action="make_chai1_fasta_multi_process", + action="make_chai1_fasta_from_backbone_dir", backbone_dir=str(ctx.pipeline_dir / "inverse_fold" / "backbones"), output_dir=str(ctx.pipeline_dir / "refold" / "chai1_inputs"), - origin_cwd=ctx.origin_cwd, + inverse_fold_dir=str(ctx.pipeline_dir / "inverse_fold"), )