diff --git a/src/boltzgen/task/predict/writer.py b/src/boltzgen/task/predict/writer.py index 00279589..c31946a6 100755 --- a/src/boltzgen/task/predict/writer.py +++ b/src/boltzgen/task/predict/writer.py @@ -4,6 +4,7 @@ import numpy as np import torch +import gemmi from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import BasePredictionWriter from torch import Tensor @@ -431,7 +432,7 @@ def write_on_batch_end( # noqa: PLR0915 .squeeze() ) - pdbs = [] + mmcifs = [] all_coords = [] ensemble = [] atom_idx = 0 @@ -447,7 +448,7 @@ def write_on_batch_end( # noqa: PLR0915 raise ValueError("Either atom14 or atom37 must be true") str_frame, _, _ = Structure.from_feat(sample) - pdbs.append(to_pdb(str_frame)) + mmcifs.append(to_mmcif(str_frame)) all_coords.append(str_frame.coords) ensemble.append( ( @@ -457,8 +458,8 @@ def write_on_batch_end( # noqa: PLR0915 ) atom_idx += len(str_frame.coords) - open(self.outdir / f"{file_name}_traj.pdb", "w").write( - self.combine_pdb_models(pdbs) + open(self.outdir / f"{file_name}_traj.cif", "w").write( + self.combine_mmcif_models(mmcifs) ) # Write x0 trajectories @@ -479,7 +480,7 @@ def write_on_batch_end( # noqa: PLR0915 .squeeze() ) - pdbs = [] + mmcifs = [] all_coords = [] ensemble = [] atom_idx = 0 @@ -495,7 +496,7 @@ def write_on_batch_end( # noqa: PLR0915 raise ValueError("Either atom14 or atom37 must be true") str_frame, _, _ = Structure.from_feat(sample) - pdbs.append(to_pdb(str_frame)) + mmcifs.append(to_mmcif(str_frame)) all_coords.append(str_frame.coords) ensemble.append( ( @@ -505,8 +506,8 @@ def write_on_batch_end( # noqa: PLR0915 ) atom_idx += len(str_frame.coords) - open(self.outdir / f"{file_name}_x0_traj.pdb", "w").write( - self.combine_pdb_models(pdbs) + open(self.outdir / f"{file_name}_x0_traj.cif", "w").write( + self.combine_mmcif_models(mmcifs) ) except Exception as e: # noqa: BLE001 @@ -516,18 +517,15 @@ def write_on_batch_end( # noqa: PLR0915 msg = f"predict/writer.py: Validation structure writing failed on {batch['id'][0]} with error {e}. Skipping." print(msg) - def combine_pdb_models(self, pdb_strings): - combined_pdb = "" - model_number = 1 + def combine_mmcif_models(self, mmcif_strings): + gemmi_structure = gemmi.Structure() + for model_number, mmcif_string in enumerate(mmcif_strings, start=1): + block = gemmi.cif.read_string(mmcif_string).sole_block() + gemmi_model = gemmi.make_structure_from_block(block)[0] + gemmi_model.name = str(model_number) + gemmi_structure.add_model(gemmi_model) - for pdb in pdb_strings: - # Add a model number at the start of each model - combined_pdb += f"MODEL {model_number}\n" - combined_pdb += pdb.split("\nEND")[0] - combined_pdb += "\nENDMDL\n" # End of model marker - model_number += 1 - - return combined_pdb + return gemmi_structure.make_mmcif_document().as_string() def on_predict_epoch_end( self,