Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions src/boltzgen/task/predict/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import torch
import gemmi
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import BasePredictionWriter
from torch import Tensor
Expand Down Expand Up @@ -431,7 +432,7 @@ def write_on_batch_end( # noqa: PLR0915
.squeeze()
)

pdbs = []
mmcifs = []
all_coords = []
ensemble = []
atom_idx = 0
Expand All @@ -447,7 +448,7 @@ def write_on_batch_end( # noqa: PLR0915
raise ValueError("Either atom14 or atom37 must be true")

str_frame, _, _ = Structure.from_feat(sample)
pdbs.append(to_pdb(str_frame))
mmcifs.append(to_mmcif(str_frame))
all_coords.append(str_frame.coords)
ensemble.append(
(
Expand All @@ -457,8 +458,8 @@ def write_on_batch_end( # noqa: PLR0915
)
atom_idx += len(str_frame.coords)

open(self.outdir / f"{file_name}_traj.pdb", "w").write(
self.combine_pdb_models(pdbs)
open(self.outdir / f"{file_name}_traj.cif", "w").write(
self.combine_mmcif_models(mmcifs)
)

# Write x0 trajectories
Expand All @@ -479,7 +480,7 @@ def write_on_batch_end( # noqa: PLR0915
.squeeze()
)

pdbs = []
mmcifs = []
all_coords = []
ensemble = []
atom_idx = 0
Expand All @@ -495,7 +496,7 @@ def write_on_batch_end( # noqa: PLR0915
raise ValueError("Either atom14 or atom37 must be true")

str_frame, _, _ = Structure.from_feat(sample)
pdbs.append(to_pdb(str_frame))
mmcifs.append(to_mmcif(str_frame))
all_coords.append(str_frame.coords)
ensemble.append(
(
Expand All @@ -505,8 +506,8 @@ def write_on_batch_end( # noqa: PLR0915
)
atom_idx += len(str_frame.coords)

open(self.outdir / f"{file_name}_x0_traj.pdb", "w").write(
self.combine_pdb_models(pdbs)
open(self.outdir / f"{file_name}_x0_traj.cif", "w").write(
self.combine_mmcif_models(mmcifs)
)

except Exception as e: # noqa: BLE001
Expand All @@ -516,18 +517,15 @@ def write_on_batch_end( # noqa: PLR0915
msg = f"predict/writer.py: Validation structure writing failed on {batch['id'][0]} with error {e}. Skipping."
print(msg)

def combine_pdb_models(self, pdb_strings):
combined_pdb = ""
model_number = 1
def combine_mmcif_models(self, mmcif_strings):
gemmi_structure = gemmi.Structure()
for model_number, mmcif_string in enumerate(mmcif_strings, start=1):
block = gemmi.cif.read_string(mmcif_string).sole_block()
gemmi_model = gemmi.make_structure_from_block(block)[0]
gemmi_model.name = str(model_number)
gemmi_structure.add_model(gemmi_model)

for pdb in pdb_strings:
# Add a model number at the start of each model
combined_pdb += f"MODEL {model_number}\n"
combined_pdb += pdb.split("\nEND")[0]
combined_pdb += "\nENDMDL\n" # End of model marker
model_number += 1

return combined_pdb
return gemmi_structure.make_mmcif_document().as_string()

def on_predict_epoch_end(
self,
Expand Down