Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions openfold/utils/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from openfold.utils.multiprocessing.base_pipeline import BasePipeline, BaseTask
from openfold.utils.multiprocessing.msa_task import MsaTask
from openfold.utils.multiprocessing.feature_gen_task import FeatureGenTask
from openfold.utils.multiprocessing.inference_task import InferenceTask
from openfold.utils.multiprocessing.relaxation_task import RelaxationTask
123 changes: 123 additions & 0 deletions openfold/utils/multiprocessing/base_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import time
import logging
logging.basicConfig()
from torch.multiprocessing import get_context

class BaseTask:
    """Base class for a pipeline stage.

    A task either runs inline (sequential pipeline) via ``setup`` /
    ``process_one`` / ``teardown``, or in its own process (multiprocessing
    pipeline) via ``_process_start`` / ``_process_join``, reading items from
    an input queue and writing results to an output queue.
    """

    def __init__(self, name):
        # Human-readable task name, used for log prefixes.
        self.name = name

    def _process_start(self, ctx, params, in_q, out_q, exit_event):
        """Spawn the worker process running this task's queue loop."""
        self.process = ctx.Process(target=self._process_loop, args=(params, in_q, out_q, exit_event))
        self.process.start()

    def _process_join(self):
        """Block until the worker process exits."""
        self.process.join()

    def _process_loop(self, params, in_q, out_q, exit_event):
        """Worker-process main loop: setup, drain queue, wait for exit, teardown."""
        logging.info(f"{self.name}: process_func started")
        self.setup(params)
        while True:
            item = in_q.get()
            # A None sentinel signals end of processing; forward it so the
            # next task in the chain also terminates. (Fix: the previous
            # `if not item` check would also treat any falsy-but-valid item,
            # e.g. an empty dict, as the sentinel.)
            if item is None:
                out_q.put(item)
                break
            out_q.put(self._process_one(item))
        # Hold the process (and any shared-memory tensors it owns) alive
        # until the pipeline signals it is safe to exit.
        exit_event.wait()
        self.teardown()
        logging.info(f"{self.name}: process_func completed")

    def _process_one(self, item):
        """Run ``process_one`` on ``item`` with timing and error logging.

        On exception the item is passed through unmodified so the pipeline
        keeps the same item count (best-effort semantics).
        """
        logging.info(f"{self.name}: {item['tag']} started")
        start_time = time.time()
        try:
            item = self.process_one(item)
        except Exception:
            # logging.exception preserves the traceback, unlike logging the
            # bare exception object.
            logging.exception(f"{self.name}: {item['tag']}: ")
        end_time = time.time()
        logging.info(f"{self.name}: {item['tag']} completed in {(end_time - start_time):0.2f} seconds")
        return item

    # Should be implemented by derived task
    def setup(self, params):
        return

    # Should be implemented by derived task
    def teardown(self):
        return

    # Should be implemented by derived task
    def process_one(self, item):
        return item

class BasePipeline:
    """Factory base class for task pipelines.

    Instantiating ``BasePipeline`` directly returns one of the concrete
    subclasses, chosen by the ``multiprocessing`` flag; instantiating a
    subclass constructs it normally.
    """

    def __new__(cls, multiprocessing=False, tasks=None):
        # Fix: avoid a shared mutable default argument (`tasks=[]`).
        # NOTE(review): the `multiprocessing` parameter shadows the stdlib
        # module name; kept unchanged for interface compatibility.
        tasks = [] if tasks is None else tasks
        if cls is BasePipeline:
            if multiprocessing:
                return BasePipelineMultiprocessing(multiprocessing, tasks)
            return BasePipelineSequential(multiprocessing, tasks)
        return super().__new__(cls)

class BasePipelineMultiprocessing(BasePipeline):
    """Pipeline that runs each task in its own spawned process.

    tasks[i] reads from queues[i] and writes to queues[i+1]; items flow
    through the chain, and results are collected from the last queue.
    """

    def __init__(self, multiprocessing=True, tasks=None):
        # "spawn" is required for CUDA tensors shared across processes.
        self.ctx = get_context("spawn")
        # Fix: avoid a shared mutable default argument (`tasks=[]`).
        self.tasks = [] if tasks is None else tasks
        # tasks[i] reads from queues[i] and writes to queues[i+1].
        self.queues = [self.ctx.Queue() for _ in range(len(self.tasks) + 1)]
        # Each task terminates only after its corresponding exit_event is set.
        self.exit_events = [self.ctx.Event() for _ in range(len(self.tasks))]

    def start(self, params):
        """Launch one worker process per task, wired up in series."""
        logging.info("pipeline: Starting in multiprocessing mode")
        for i, task in enumerate(self.tasks):
            task._process_start(self.ctx, params, self.queues[i], self.queues[i + 1], self.exit_events[i])

    def run(self, inputs):
        """Feed all inputs into the first queue and collect one output each."""
        for item in inputs:
            logging.info(f"pipeline: Sending {item['tag']}")
            self.queues[0].put(item)
        outputs = []
        for _ in inputs:
            received = self.queues[-1].get()
            outputs.append(received)
            # Fix: log the tag of the item actually received, not the tag of
            # the corresponding input item (these match only because the
            # chain preserves order).
            logging.info(f"pipeline: Received {received['tag']}")
        return outputs

    def end(self):
        """Signal shutdown, propagate the sentinel, and join all workers."""
        for event in self.exit_events:
            event.set()
        self.queues[0].put(None)  # Sentinel; forwarded through every task
        for task in self.tasks:
            task._process_join()
        for queue in self.queues:
            queue.close()
        logging.info("pipeline: Completing run")

class BasePipelineSequential(BasePipeline):
    """Pipeline that runs every task inline, in the current process."""

    def __init__(self, multiprocessing=False, tasks=None):
        # Fix: avoid a shared mutable default argument (`tasks=[]`).
        self.tasks = [] if tasks is None else tasks

    def start(self, params):
        """Set up every task in order."""
        logging.info("pipeline: Starting in sequential mode")
        for task in self.tasks:
            task.setup(params)

    def run(self, inputs):
        """Run each input through every task in series and return the results.

        Fix: previously the processed items were discarded and ``run``
        returned None, while ``BasePipelineMultiprocessing.run`` returns a
        list of outputs — callers now get the same contract from both.
        """
        outputs = []
        for item in inputs:
            logging.info(f"pipeline: Sending {item['tag']}")
            for task in self.tasks:
                item = task._process_one(item)
            logging.info(f"pipeline: Received {item['tag']}")
            outputs.append(item)
        return outputs

    def end(self):
        """Tear down every task in order."""
        for task in self.tasks:
            task.teardown()
        logging.info("pipeline: Completing run")
133 changes: 133 additions & 0 deletions openfold/utils/multiprocessing/feature_gen_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import os
import math
import torch
from openfold.data import feature_pipeline, templates, data_pipeline
from openfold.utils.trace_utils import pad_feature_dict_seq
from openfold.utils.multiprocessing.base_pipeline import BaseTask

TRACING_INTERVAL = 50

def round_up_seqlen(seqlen):
    """Round ``seqlen`` up to the next multiple of ``TRACING_INTERVAL``."""
    quotient, remainder = divmod(seqlen, TRACING_INTERVAL)
    if remainder:
        quotient += 1
    return quotient * TRACING_INTERVAL

class FeatureGenTask(BaseTask):
    """Pipeline task that turns sequence items into model-ready feature dicts.

    Expects items shaped like ``{'tag': ..., 'tags': [...], 'seqs': [...]}``
    and adds ``'feature_dict'`` and ``'processed_feature_dict'`` keys.
    """

    def __init__(self):
        super().__init__('feature_gen')

    def setup(self, params):
        """Build the template featurizer, data pipeline and feature pipeline.

        ``params`` must provide 'args' (command-line namespace) and 'config'.
        """
        self.params = params
        args = self.params['args']
        config = self.params['config']
        is_multimer = "multimer" in args.config_preset
        is_custom_template = "use_custom_template" in args and args.use_custom_template
        if is_custom_template:
            # Custom templates bypass date/hit filtering, hence the dummies.
            template_featurizer = templates.CustomHitFeaturizer(
                mmcif_dir=args.template_mmcif_dir,
                max_template_date="9999-12-31",  # just dummy, not used
                max_hits=-1,  # just dummy, not used
                kalign_binary_path=args.kalign_binary_path
            )
        elif is_multimer:
            template_featurizer = templates.HmmsearchHitFeaturizer(
                mmcif_dir=args.template_mmcif_dir,
                max_template_date=args.max_template_date,
                max_hits=config.data.predict.max_templates,
                kalign_binary_path=args.kalign_binary_path,
                release_dates_path=args.release_dates_path,
                obsolete_pdbs_path=args.obsolete_pdbs_path
            )
        else:
            template_featurizer = templates.HhsearchHitFeaturizer(
                mmcif_dir=args.template_mmcif_dir,
                max_template_date=args.max_template_date,
                max_hits=config.data.predict.max_templates,
                kalign_binary_path=args.kalign_binary_path,
                release_dates_path=args.release_dates_path,
                obsolete_pdbs_path=args.obsolete_pdbs_path
            )
        self.data_processor = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
        )
        if is_multimer:
            # The multimer pipeline wraps the monomer pipeline.
            self.data_processor = data_pipeline.DataPipelineMultimer(
                monomer_data_pipeline=self.data_processor,
            )
        self.feature_processor = feature_pipeline.FeaturePipeline(config.data)
        # Raw feature dicts cached per item tag, so repeated items (e.g. for
        # several models) skip the expensive alignment/template parsing.
        self.feature_dicts = {}

    def teardown(self):
        pass

    def process_one(self, item):  # item is a dict with keys 'tag', 'tags', 'seqs'
        """Generate (or fetch cached) features for one item and tensorize them."""
        tag = item['tag']
        tags = item['tags']
        seqs = item['seqs']
        alignment_dir = self.params['alignment_dir']
        args = self.params['args']
        is_multimer = "multimer" in args.config_preset

        feature_dict = self.feature_dicts.get(tag, None)
        if feature_dict is None:
            # Start of generate_feature_dict
            # PID in the temp name avoids collisions between worker processes.
            tmp_fasta_path = os.path.join(self.params['args'].output_dir, f"tmp_{os.getpid()}_{tag}.fasta")

            if "multimer" in args.config_preset:
                with open(tmp_fasta_path, "w") as fp:
                    fp.write(
                        '\n'.join([f">{t}\n{s}" for t, s in zip(tags, seqs)])
                    )
                feature_dict = self.data_processor.process_fasta(
                    fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
                )
            elif len(seqs) == 1:
                # Fix: use a distinct local name instead of rebinding `tag`.
                # Previously `tag = tags[0]` clobbered the cache key, so the
                # feature dict was stored under tags[0] but looked up under
                # item['tag'] — a permanent cache miss whenever they differ.
                seq_tag = tags[0]
                seq = seqs[0]
                with open(tmp_fasta_path, "w") as fp:
                    fp.write(f">{seq_tag}\n{seq}")

                local_alignment_dir = os.path.join(alignment_dir, seq_tag)
                feature_dict = self.data_processor.process_fasta(
                    fasta_path=tmp_fasta_path,
                    alignment_dir=local_alignment_dir,
                    seqemb_mode=args.use_single_seq_mode,
                )
            else:
                with open(tmp_fasta_path, "w") as fp:
                    fp.write(
                        '\n'.join([f">{t}\n{s}" for t, s in zip(tags, seqs)])
                    )
                feature_dict = self.data_processor.process_multiseq_fasta(
                    fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
                )
            os.remove(tmp_fasta_path)
            # End of generate_feature_dict

            if args.trace_model:
                # Pad to the tracing bucket so traced models are reused
                # across nearby sequence lengths.
                n = feature_dict["aatype"].shape[-2]
                rounded_seqlen = round_up_seqlen(n)
                item['rounded_seqlen'] = rounded_seqlen
                feature_dict = pad_feature_dict_seq(
                    feature_dict, rounded_seqlen,
                )

            self.feature_dicts[tag] = feature_dict

        processed_feature_dict = self.feature_processor.process_features(
            feature_dict, mode='predict', is_multimer=is_multimer
        )

        processed_feature_dict = {
            k: torch.as_tensor(v, device=args.model_device)
            for k, v in processed_feature_dict.items()
        }

        # Move outputs to shared memory for zero-copy forwarding to inference process
        for k, v in processed_feature_dict.items():
            v.share_memory_()
        item['feature_dict'] = feature_dict
        item['processed_feature_dict'] = processed_feature_dict
        return item

99 changes: 99 additions & 0 deletions openfold/utils/multiprocessing/inference_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import os
import time
import logging
import numpy as np
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.script_utils import (load_models_from_command_line, run_model, prep_output)
from openfold.utils.trace_utils import trace_model_
from openfold.np import protein
from openfold.utils.multiprocessing.base_pipeline import BaseTask

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

class InferenceTask(BaseTask):
    """Pipeline task that runs model inference and writes unrelaxed structures.

    Consumes items carrying 'feature_dict' / 'processed_feature_dict' and
    appends per-model outputs ('unrelaxed_protein_list',
    'output_directory_list', 'out') to the item.
    """

    def __init__(self):
        super().__init__('infer')

    def setup(self, params):
        """Load all requested models up front so inference can stream items."""
        self.params = params
        args = params['args']
        # Materialize the generator so the models can be iterated once per item.
        self.model_generator = list(load_models_from_command_line(
            params['config'],
            args.model_device,
            args.openfold_checkpoint_path,
            args.jax_param_path,
            args.output_dir)
        )
        # Largest sequence-length bucket traced so far; retrace only when a
        # longer bucket arrives.
        self.cur_tracing_interval = 0

    def process_one(self, item):
        """Run every loaded model on one item and write unrelaxed outputs."""
        tag = item['tag']
        args = self.params['args']
        config = self.params['config']
        feature_dict = item['feature_dict']
        processed_feature_dict = item['processed_feature_dict']
        item['unrelaxed_protein_list'] = []
        item['output_directory_list'] = []
        item['out'] = []
        for model, output_directory in self.model_generator:
            if args.trace_model:
                rounded_seqlen = item['rounded_seqlen']
                if rounded_seqlen > self.cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
                    t = time.perf_counter()
                    trace_model_(model, processed_feature_dict)
                    tracing_time = time.perf_counter() - t
                    logger.info(
                        f"Tracing time: {tracing_time}"
                    )
                    self.cur_tracing_interval = rounded_seqlen

            out = run_model(model, processed_feature_dict, tag, args.output_dir)

            # Toss out the recycling dimensions --- we don't need them anymore.
            # Fix: convert into a *new* dict instead of overwriting
            # processed_feature_dict; previously a second model from
            # model_generator would receive numpy arrays with the recycling
            # dimension already stripped instead of the original tensors.
            np_feature_dict = tensor_tree_map(
                lambda x: np.array(x[..., -1].cpu()),
                processed_feature_dict
            )
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

            unrelaxed_protein = prep_output(
                out,
                np_feature_dict,
                feature_dict,
                config,
                args.config_preset,
                args.multimer_ri_gap,
                args.subtract_plddt
            )
            unrelaxed_file_suffix = "_unrelaxed.pdb"
            if args.cif_output:
                unrelaxed_file_suffix = "_unrelaxed.cif"
            unrelaxed_output_path = os.path.join(
                output_directory, f'{item["output_name"]}{unrelaxed_file_suffix}'
            )

            with open(unrelaxed_output_path, 'w') as fp:
                if args.cif_output:
                    fp.write(protein.to_modelcif(unrelaxed_protein))
                else:
                    fp.write(protein.to_pdb(unrelaxed_protein))

            logger.info(f"Output written to {unrelaxed_output_path}...")

            item['unrelaxed_protein_list'].append(unrelaxed_protein)
            item['output_directory_list'].append(output_directory)
            if args.save_outputs:
                item['out'].append(out)
            else:
                item['out'].append(None)  # Just to match counts with other lists in the item
        # Drop the (large) feature tensors before forwarding the item.
        del item['feature_dict']
        del item['processed_feature_dict']
        return item
Loading
Loading