From 8af9bf86782dd13c98b52b2050621fef52e5609e Mon Sep 17 00:00:00 2001 From: MrGreyfun <72959013+MrGreyfun@users.noreply.github.com> Date: Fri, 21 Nov 2025 15:55:44 +0800 Subject: [PATCH 1/2] Save diffusion trajectory in mmCIF format. --- src/boltzgen/task/predict/writer.py | 35 ++++++++++++++--------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/boltzgen/task/predict/writer.py b/src/boltzgen/task/predict/writer.py index 00279589..5dbb848e 100755 --- a/src/boltzgen/task/predict/writer.py +++ b/src/boltzgen/task/predict/writer.py @@ -431,7 +431,7 @@ def write_on_batch_end( # noqa: PLR0915 .squeeze() ) - pdbs = [] + mmcifs = [] all_coords = [] ensemble = [] atom_idx = 0 @@ -447,7 +447,7 @@ def write_on_batch_end( # noqa: PLR0915 raise ValueError("Either atom14 or atom37 must be true") str_frame, _, _ = Structure.from_feat(sample) - pdbs.append(to_pdb(str_frame)) + mmcifs.append(to_mmcif(str_frame)) all_coords.append(str_frame.coords) ensemble.append( ( @@ -457,8 +457,8 @@ def write_on_batch_end( # noqa: PLR0915 ) atom_idx += len(str_frame.coords) - open(self.outdir / f"{file_name}_traj.pdb", "w").write( - self.combine_pdb_models(pdbs) + open(self.outdir / f"{file_name}_traj.cif", "w").write( + self.combine_mmcif_models(mmcifs) ) # Write x0 trajectories @@ -479,7 +479,7 @@ def write_on_batch_end( # noqa: PLR0915 .squeeze() ) - pdbs = [] + mmcifs = [] all_coords = [] ensemble = [] atom_idx = 0 @@ -495,7 +495,7 @@ def write_on_batch_end( # noqa: PLR0915 raise ValueError("Either atom14 or atom37 must be true") str_frame, _, _ = Structure.from_feat(sample) - pdbs.append(to_pdb(str_frame)) + mmcifs.append(to_mmcif(str_frame)) all_coords.append(str_frame.coords) ensemble.append( ( @@ -505,8 +505,8 @@ def write_on_batch_end( # noqa: PLR0915 ) atom_idx += len(str_frame.coords) - open(self.outdir / f"{file_name}_x0_traj.pdb", "w").write( - self.combine_pdb_models(pdbs) + open(self.outdir / f"{file_name}_x0_traj.cif", "w").write( + self.combine_mmcif_models(mmcifs) ) except Exception as e: # noqa: BLE001 @@ -516,18 +516,17 @@ def write_on_batch_end( # noqa: PLR0915 msg = f"predict/writer.py: Validation structure writing failed on {batch['id'][0]} with error {e}. Skipping." print(msg) - def combine_pdb_models(self, pdb_strings): - combined_pdb = "" - model_number = 1 + def combine_mmcif_models(self, mmcif_strings): + import gemmi - for pdb in pdb_strings: - # Add a model number at the start of each model - combined_pdb += f"MODEL {model_number}\n" - combined_pdb += pdb.split("\nEND")[0] - combined_pdb += "\nENDMDL\n" # End of model marker - model_number += 1 + gemmi_structure = gemmi.Structure() + for model_number, mmcif_string in enumerate(mmcif_strings, start=1): + block = gemmi.cif.read_string(mmcif_string).sole_block() + gemmi_model = gemmi.make_structure_from_block(block)[0] + gemmi_model.name = str(model_number) + gemmi_structure.add_model(gemmi_model) - return combined_pdb + return gemmi_structure.make_mmcif_document().as_string() def on_predict_epoch_end( self, From 9f03d3b2682f476f6105546d60c17bdb058c80c0 Mon Sep 17 00:00:00 2001 From: MrGreyfun <72959013+MrGreyfun@users.noreply.github.com> Date: Fri, 21 Nov 2025 16:30:39 +0800 Subject: [PATCH 2/2] Move gemmi import to head --- src/boltzgen/task/predict/writer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/boltzgen/task/predict/writer.py b/src/boltzgen/task/predict/writer.py index 5dbb848e..c31946a6 100755 --- a/src/boltzgen/task/predict/writer.py +++ b/src/boltzgen/task/predict/writer.py @@ -4,6 +4,7 @@ import numpy as np import torch +import gemmi from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import BasePredictionWriter from torch import Tensor @@ -517,8 +518,6 @@ def write_on_batch_end( # noqa: PLR0915 print(msg) def combine_mmcif_models(self, mmcif_strings): - import gemmi - gemmi_structure = gemmi.Structure() for model_number, mmcif_string in enumerate(mmcif_strings, start=1): block = gemmi.cif.read_string(mmcif_string).sole_block()