diff --git a/openfold/utils/multiprocessing/__init__.py b/openfold/utils/multiprocessing/__init__.py new file mode 100644 index 000000000..4764e463a --- /dev/null +++ b/openfold/utils/multiprocessing/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from openfold.utils.multiprocessing.base_pipeline import BasePipeline, BaseTask +from openfold.utils.multiprocessing.msa_task import MsaTask +from openfold.utils.multiprocessing.feature_gen_task import FeatureGenTask +from openfold.utils.multiprocessing.inference_task import InferenceTask +from openfold.utils.multiprocessing.relaxation_task import RelaxationTask diff --git a/openfold/utils/multiprocessing/base_pipeline.py b/openfold/utils/multiprocessing/base_pipeline.py new file mode 100644 index 000000000..aecb5e333 --- /dev/null +++ b/openfold/utils/multiprocessing/base_pipeline.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import time +import logging +logging.basicConfig() +from torch.multiprocessing import get_context + +class BaseTask: + def __init__(self, name): + self.name = name + + def _process_start(self, ctx, params, in_q, out_q, exit_event): + self.process = ctx.Process(target=self._process_loop, args=(params, in_q, out_q, exit_event)) + self.process.start() + + def _process_join(self): + self.process.join() + + def _process_loop(self, params, in_q, out_q, exit_event): + logging.info(f"{self.name}: process_func started") + self.setup(params) + while True: + item = in_q.get() + if item is None: # None signals end of processing; forward it through the chain (truthiness would misfire on an empty item) + out_q.put(item) + break + else: + out_q.put(self._process_one(item)) + exit_event.wait() + self.teardown() + logging.info(f"{self.name}: process_func completed") + + def _process_one(self, item): + logging.info(f"{self.name}: {item['tag']} started") + start_time = time.time() + try: + item = self.process_one(item) + except Exception as e: + logging.error(f"{self.name}: {item['tag']}: ") + logging.error(e) + end_time = time.time() + logging.info(f"{self.name}: {item['tag']} completed in {(end_time - start_time):0.2f} seconds") + return item + + # Should be implemented by derived task + def setup(self, params): + return + + # Should be implemented by derived task + def teardown(self): + return + + # Should be implemented by derived task + def process_one(self, item): + return item + +class BasePipeline: + def __new__(cls, multiprocessing=False, tasks=[]): + if cls is BasePipeline: + if multiprocessing: + return BasePipelineMultiprocessing(multiprocessing, tasks) + else: + return BasePipelineSequential(multiprocessing, tasks) + else: + return super().__new__(cls) + +class BasePipelineMultiprocessing(BasePipeline): + def __init__(self, multiprocessing=True, tasks=[]): + self.ctx = get_context("spawn") + self.tasks = tasks # a list of BaseTask instances, to be connected in series + 
self.queues = [] # tasks[i] reads from queues[i] and writes to queues[i+1] + self.exit_events = [] # Each task should terminate only after the corresponding exit_event is signalled + for i in range(len(self.tasks)+1): + self.queues.append(self.ctx.Queue()) + for i in range(len(self.tasks)): + self.exit_events.append(self.ctx.Event()) + + def start(self, params): + logging.info(f"pipeline: Starting in multiprocessing mode") + for i in range(len(self.tasks)): + self.tasks[i]._process_start(self.ctx, params, self.queues[i], self.queues[i+1], self.exit_events[i]) + + def run(self, inputs): + for item in inputs: + logging.info(f"pipeline: Sending {item['tag']}") + self.queues[0].put(item) + outputs = [] + for item in inputs: + outputs.append(self.queues[-1].get()) + logging.info(f"pipeline: Received {item['tag']}") + return outputs + + def end(self): + for event in self.exit_events: + event.set() + self.queues[0].put(None) # Signals exit to all tasks + for task in self.tasks: + task._process_join() + for queue in self.queues: + queue.close() + logging.info(f"pipeline: Completing run") + +class BasePipelineSequential(BasePipeline): + def __init__(self, multiprocessing=False, tasks=[]): + self.tasks = tasks + + def start(self, params): + logging.info(f"pipeline: Starting in sequential mode") + for task in self.tasks: + task.setup(params) + + def run(self, inputs): + for item in inputs: + logging.info(f"pipeline: Sending {item['tag']}") + for task in self.tasks: + item = task._process_one(item) + logging.info(f"pipeline: Received {item['tag']}") + + def end(self): + for task in self.tasks: + task.teardown() + logging.info(f"pipeline: Completing run") diff --git a/openfold/utils/multiprocessing/feature_gen_task.py b/openfold/utils/multiprocessing/feature_gen_task.py new file mode 100644 index 000000000..14fae24ae --- /dev/null +++ b/openfold/utils/multiprocessing/feature_gen_task.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import math +import torch +from openfold.data import feature_pipeline, templates, data_pipeline +from openfold.utils.trace_utils import pad_feature_dict_seq +from openfold.utils.multiprocessing.base_pipeline import BaseTask + +TRACING_INTERVAL = 50 + +def round_up_seqlen(seqlen): + return int(math.ceil(seqlen / TRACING_INTERVAL)) * TRACING_INTERVAL + +class FeatureGenTask(BaseTask): + def __init__(self): + super().__init__('feature_gen') + + def setup(self, params): + self.params = params + args = self.params['args'] + config = self.params['config'] + is_multimer = "multimer" in args.config_preset + is_custom_template = "use_custom_template" in args and args.use_custom_template + if is_custom_template: + template_featurizer = templates.CustomHitFeaturizer( + mmcif_dir=args.template_mmcif_dir, + max_template_date="9999-12-31", # just dummy, not used + max_hits=-1, # just dummy, not used + kalign_binary_path=args.kalign_binary_path + ) + elif is_multimer: + template_featurizer = templates.HmmsearchHitFeaturizer( + mmcif_dir=args.template_mmcif_dir, + max_template_date=args.max_template_date, + max_hits=config.data.predict.max_templates, + kalign_binary_path=args.kalign_binary_path, + release_dates_path=args.release_dates_path, + obsolete_pdbs_path=args.obsolete_pdbs_path + ) + else: + template_featurizer = templates.HhsearchHitFeaturizer( + mmcif_dir=args.template_mmcif_dir, + max_template_date=args.max_template_date, + max_hits=config.data.predict.max_templates, + kalign_binary_path=args.kalign_binary_path, + release_dates_path=args.release_dates_path, + obsolete_pdbs_path=args.obsolete_pdbs_path + ) + self.data_processor = data_pipeline.DataPipeline( + template_featurizer=template_featurizer, + ) + if is_multimer: + self.data_processor = data_pipeline.DataPipelineMultimer( + monomer_data_pipeline=self.data_processor, + ) + self.feature_processor = 
feature_pipeline.FeaturePipeline(config.data) + self.feature_dicts = {} + + def teardown(self): + pass + + def process_one(self, item): # item is a tuple of (tag, tags, seqs) + tag = item['tag'] + tags = item['tags'] + seqs = item['seqs'] + alignment_dir = self.params['alignment_dir'] + args = self.params['args'] + is_multimer = "multimer" in args.config_preset + + feature_dict = self.feature_dicts.get(tag, None) + if feature_dict is None: + # Start of generate_feature_dict + tmp_fasta_path = os.path.join(self.params['args'].output_dir, f"tmp_{os.getpid()}_{tag}.fasta") + + if "multimer" in args.config_preset: + with open(tmp_fasta_path, "w") as fp: + fp.write( + '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)]) + ) + feature_dict = self.data_processor.process_fasta( + fasta_path=tmp_fasta_path, alignment_dir=alignment_dir, + ) + elif len(seqs) == 1: + tag = tags[0] + seq = seqs[0] + with open(tmp_fasta_path, "w") as fp: + fp.write(f">{tag}\n{seq}") + + local_alignment_dir = os.path.join(alignment_dir, tag) + feature_dict = self.data_processor.process_fasta( + fasta_path=tmp_fasta_path, + alignment_dir=local_alignment_dir, + seqemb_mode=args.use_single_seq_mode, + ) + else: + with open(tmp_fasta_path, "w") as fp: + fp.write( + '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)]) + ) + feature_dict = self.data_processor.process_multiseq_fasta( + fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir, + ) + os.remove(tmp_fasta_path) + # End of generate_feature_dict + + if args.trace_model: + n = feature_dict["aatype"].shape[-2] + rounded_seqlen = round_up_seqlen(n) + item['rounded_seqlen'] = rounded_seqlen + feature_dict = pad_feature_dict_seq( + feature_dict, rounded_seqlen, + ) + + self.feature_dicts[tag] = feature_dict + + processed_feature_dict = self.feature_processor.process_features( + feature_dict, mode='predict', is_multimer=is_multimer + ) + + processed_feature_dict = { + k: torch.as_tensor(v, device=args.model_device) + for k, v 
in processed_feature_dict.items() + } + + # Move outputs to shared memory for zero-copy forwarding to inference process + for k, v in processed_feature_dict.items(): + v.share_memory_() + item['feature_dict'] = feature_dict + item['processed_feature_dict'] = processed_feature_dict + return item + diff --git a/openfold/utils/multiprocessing/inference_task.py b/openfold/utils/multiprocessing/inference_task.py new file mode 100644 index 000000000..3fce3ed35 --- /dev/null +++ b/openfold/utils/multiprocessing/inference_task.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import time +import logging +import numpy as np +from openfold.utils.tensor_utils import tensor_tree_map +from openfold.utils.script_utils import (load_models_from_command_line, run_model, prep_output) +from openfold.utils.trace_utils import trace_model_ +from openfold.np import protein +from openfold.utils.multiprocessing.base_pipeline import BaseTask + +logging.basicConfig() +logger = logging.getLogger(__file__) +logger.setLevel(level=logging.INFO) + +class InferenceTask(BaseTask): + def __init__(self): + super().__init__('infer') + + def setup(self, params): + self.params = params + args = params['args'] + self.model_generator = list(load_models_from_command_line( + params['config'], + args.model_device, + args.openfold_checkpoint_path, + args.jax_param_path, + args.output_dir) + ) + self.cur_tracing_interval = 0 + + def process_one(self, item): + tag = item['tag'] + args = self.params['args'] + config = self.params['config'] + feature_dict = item['feature_dict'] + processed_feature_dict = item['processed_feature_dict'] + item['unrelaxed_protein_list'] = [] + item['output_directory_list'] = [] + item['out'] = [] + for model, output_directory in self.model_generator: + if args.trace_model: + rounded_seqlen = item['rounded_seqlen'] + if rounded_seqlen > 
self.cur_tracing_interval: + logger.info( + f"Tracing model at {rounded_seqlen} residues..." + ) + t = time.perf_counter() + trace_model_(model, processed_feature_dict) + tracing_time = time.perf_counter() - t + logger.info( + f"Tracing time: {tracing_time}" + ) + self.cur_tracing_interval = rounded_seqlen + + out = run_model(model, processed_feature_dict, tag, args.output_dir) + + # Toss out the recycling dimensions in a separate dict --- keep the tensor batch intact for the next model in the loop + batch_np = tensor_tree_map( + lambda x: np.array(x[..., -1].cpu()), + processed_feature_dict + ) + out = tensor_tree_map(lambda x: np.array(x.cpu()), out) + + unrelaxed_protein = prep_output( + out, + batch_np, + feature_dict, + config, + args.config_preset, + args.multimer_ri_gap, + args.subtract_plddt + ) + unrelaxed_file_suffix = "_unrelaxed.pdb" + if args.cif_output: + unrelaxed_file_suffix = "_unrelaxed.cif" + unrelaxed_output_path = os.path.join( + output_directory, f'{item["output_name"]}{unrelaxed_file_suffix}' + ) + + with open(unrelaxed_output_path, 'w') as fp: + if args.cif_output: + fp.write(protein.to_modelcif(unrelaxed_protein)) + else: + fp.write(protein.to_pdb(unrelaxed_protein)) + + logger.info(f"Output written to {unrelaxed_output_path}...") + + item['unrelaxed_protein_list'].append(unrelaxed_protein) + item['output_directory_list'].append(output_directory) + if args.save_outputs: + item['out'].append(out) + else: + item['out'].append(None) # Just to match counts with other lists in the item + del item['feature_dict'] + del item['processed_feature_dict'] + return item \ No newline at end of file diff --git a/openfold/utils/multiprocessing/msa_task.py b/openfold/utils/multiprocessing/msa_task.py new file mode 100644 index 000000000..1da03d8c4 --- /dev/null +++ b/openfold/utils/multiprocessing/msa_task.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import os +import logging +import shutil +from openfold.data.tools import hhsearch, hmmsearch +from openfold.data import data_pipeline +from scripts.precompute_embeddings import EmbeddingGenerator +from openfold.utils.multiprocessing.base_pipeline import BaseTask + +logging.basicConfig() +logger = logging.getLogger(__file__) +logger.setLevel(level=logging.INFO) + + +class MsaTask(BaseTask): + def __init__(self, *args, **kwargs): + super().__init__('msa', *args, **kwargs) + + def setup(self, params): + self.params = params + args = self.params['args'] + if args.use_precomputed_alignments is None: + if "multimer" in args.config_preset: + template_searcher = hmmsearch.Hmmsearch( + binary_path=args.hmmsearch_binary_path, + hmmbuild_binary_path=args.hmmbuild_binary_path, + database_path=args.pdb_seqres_database_path, + ) + else: + template_searcher = hhsearch.HHSearch( + binary_path=args.hhsearch_binary_path, + databases=[args.pdb70_database_path], + ) + + # In seqemb mode, use AlignmentRunner only to generate templates + if args.use_single_seq_mode: + self.alignment_runner = data_pipeline.AlignmentRunner( + jackhmmer_binary_path=args.jackhmmer_binary_path, + uniref90_database_path=args.uniref90_database_path, + template_searcher=template_searcher, + no_cpus=args.cpus, + ) + self.embedding_generator = EmbeddingGenerator() + else: + self.alignment_runner = data_pipeline.AlignmentRunner( + jackhmmer_binary_path=args.jackhmmer_binary_path, + hhblits_binary_path=args.hhblits_binary_path, + uniref90_database_path=args.uniref90_database_path, + mgnify_database_path=args.mgnify_database_path, + bfd_database_path=args.bfd_database_path, + uniref30_database_path=args.uniref30_database_path, + uniclust30_database_path=args.uniclust30_database_path, + uniprot_database_path=args.uniprot_database_path, + template_searcher=template_searcher, + use_small_bfd=args.bfd_database_path is None, + no_cpus=args.cpus + ) + + def teardown(self): + pass 
+ + def process_one(self, item): + tag = item['tag'] + tags = item['tags'] + seqs = item['seqs'] + args = self.params['args'] + alignment_dir = self.params['alignment_dir'] + for tag, seq in zip(tags, seqs): + tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") + with open(tmp_fasta_path, "w") as fp: + fp.write(f">{tag}\n{seq}") + + local_alignment_dir = os.path.join(alignment_dir, tag) + + if args.use_precomputed_alignments is None: + logger.info(f"Generating alignments for {tag}...") + + os.makedirs(local_alignment_dir, exist_ok=True) + + # In seqemb mode, use AlignmentRunner only to generate templates + if args.use_single_seq_mode: + self.embedding_generator.run(tmp_fasta_path, alignment_dir) + + self.alignment_runner.run( + tmp_fasta_path, local_alignment_dir + ) + else: + logger.info( + f"Using precomputed alignments for {tag} at {alignment_dir}..." + ) + + # Remove temporary FASTA file + os.remove(tmp_fasta_path) + + return item diff --git a/openfold/utils/multiprocessing/relaxation_task.py b/openfold/utils/multiprocessing/relaxation_task.py new file mode 100644 index 000000000..ede38754a --- /dev/null +++ b/openfold/utils/multiprocessing/relaxation_task.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import os +import pickle +import logging +from openfold.utils.script_utils import relax_protein +from openfold.utils.multiprocessing.base_pipeline import BaseTask + +logging.basicConfig() +logger = logging.getLogger(__file__) +logger.setLevel(level=logging.INFO) + +class RelaxationTask(BaseTask): + def __init__(self): + super().__init__('relax') + + def setup(self, params): + self.params = params + + def teardown(self): + pass + + def process_one(self, item): + args = self.params['args'] + config = self.params['config'] + output_name = item['output_name'] + if (args.relaxation_device != 'skip') and (not args.skip_relaxation): # relax only when neither skip switch is set + for unrelaxed_protein, output_directory, out in zip(item['unrelaxed_protein_list'], item['output_directory_list'], item['out']): + relax_protein(config, + args.relaxation_device, + unrelaxed_protein, + output_directory, + output_name, + args.cif_output) + + if args.save_outputs: + output_dict_path = os.path.join( + output_directory, f'{output_name}_output_dict.pkl' + ) + with open(output_dict_path, "wb") as fp: + pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL) + + logger.info(f"Model output written to {output_dict_path}...") + + return item diff --git a/openfold/utils/script_utils.py b/openfold/utils/script_utils.py index facc103d9..c4c2f9cbf 100644 --- a/openfold/utils/script_utils.py +++ b/openfold/utils/script_utils.py @@ -1,3 +1,7 @@ +# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# Modified by: NVIDIA +# Modified on: 25-Aug-2025 + import json import logging import os @@ -167,7 +171,7 @@ def run_model(model, batch, tag, output_dir): return out -def prep_output(out, batch, feature_dict, feature_processor, config_preset, multimer_ri_gap, subtract_plddt): +def prep_output(out, batch, feature_dict, feature_processor_config, config_preset, multimer_ri_gap, subtract_plddt): plddt = out["plddt"] plddt_b_factors = numpy.repeat( @@ -180,26 +184,26 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult # Prep protein metadata template_domain_names = [] template_chain_index = None - if feature_processor.config.common.use_templates and "template_domain_names" in feature_dict: + if feature_processor_config.data.common.use_templates and "template_domain_names" in feature_dict: template_domain_names = [ t.decode("utf-8") for t in feature_dict["template_domain_names"] ] # This works because templates are not shuffled during inference template_domain_names = template_domain_names[ - :feature_processor.config.predict.max_templates + :feature_processor_config.data.predict.max_templates ] if "template_chain_index" in feature_dict: template_chain_index = feature_dict["template_chain_index"] template_chain_index = template_chain_index[ - :feature_processor.config.predict.max_templates + :feature_processor_config.data.predict.max_templates ] - no_recycling = feature_processor.config.common.max_recycling_iters + no_recycling = feature_processor_config.data.common.max_recycling_iters remark = ', '.join([ f"no_recycling={no_recycling}", - f"max_templates={feature_processor.config.predict.max_templates}", + f"max_templates={feature_processor_config.data.predict.max_templates}", f"config_preset={config_preset}", ]) diff --git a/run_pretrained_openfold.py b/run_pretrained_openfold.py index 510610493..ef1adde79 100644 --- a/run_pretrained_openfold.py +++ b/run_pretrained_openfold.py @@ -12,9 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF 
ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Modified by: NVIDIA +# Modified on: 25-Aug-2025 + import argparse import logging -import math import numpy as np import os import pickle @@ -40,8 +44,8 @@ torch.set_grad_enabled(False) from openfold.config import model_config +from openfold.utils.multiprocessing import BasePipeline, MsaTask, FeatureGenTask, InferenceTask, RelaxationTask from openfold.data import templates, feature_pipeline, data_pipeline -from openfold.data.tools import hhsearch, hmmsearch from openfold.np import protein from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model, prep_output, relax_protein) @@ -51,127 +55,11 @@ trace_model_, ) -from scripts.precompute_embeddings import EmbeddingGenerator from scripts.utils import add_data_args - -TRACING_INTERVAL = 50 - - -def precompute_alignments(tags, seqs, alignment_dir, args): - for tag, seq in zip(tags, seqs): - tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") - with open(tmp_fasta_path, "w") as fp: - fp.write(f">{tag}\n{seq}") - - local_alignment_dir = os.path.join(alignment_dir, tag) - - if args.use_precomputed_alignments is None: - logger.info(f"Generating alignments for {tag}...") - - os.makedirs(local_alignment_dir, exist_ok=True) - - if "multimer" in args.config_preset: - template_searcher = hmmsearch.Hmmsearch( - binary_path=args.hmmsearch_binary_path, - hmmbuild_binary_path=args.hmmbuild_binary_path, - database_path=args.pdb_seqres_database_path, - ) - else: - template_searcher = hhsearch.HHSearch( - binary_path=args.hhsearch_binary_path, - databases=[args.pdb70_database_path], - ) - - # In seqemb mode, use AlignmentRunner only to generate templates - if args.use_single_seq_mode: - alignment_runner = data_pipeline.AlignmentRunner( - 
jackhmmer_binary_path=args.jackhmmer_binary_path, - uniref90_database_path=args.uniref90_database_path, - template_searcher=template_searcher, - no_cpus=args.cpus, - ) - embedding_generator = EmbeddingGenerator() - embedding_generator.run(tmp_fasta_path, alignment_dir) - else: - alignment_runner = data_pipeline.AlignmentRunner( - jackhmmer_binary_path=args.jackhmmer_binary_path, - hhblits_binary_path=args.hhblits_binary_path, - uniref90_database_path=args.uniref90_database_path, - mgnify_database_path=args.mgnify_database_path, - bfd_database_path=args.bfd_database_path, - uniref30_database_path=args.uniref30_database_path, - uniclust30_database_path=args.uniclust30_database_path, - uniprot_database_path=args.uniprot_database_path, - template_searcher=template_searcher, - use_small_bfd=args.bfd_database_path is None, - no_cpus=args.cpus - ) - - alignment_runner.run( - tmp_fasta_path, local_alignment_dir - ) - else: - logger.info( - f"Using precomputed alignments for {tag} at {alignment_dir}..." 
- ) - - # Remove temporary FASTA file - os.remove(tmp_fasta_path) - - -def round_up_seqlen(seqlen): - return int(math.ceil(seqlen / TRACING_INTERVAL)) * TRACING_INTERVAL - - -def generate_feature_dict( - tags, - seqs, - alignment_dir, - data_processor, - args, -): - tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta") - - if "multimer" in args.config_preset: - with open(tmp_fasta_path, "w") as fp: - fp.write( - '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)]) - ) - feature_dict = data_processor.process_fasta( - fasta_path=tmp_fasta_path, alignment_dir=alignment_dir, - ) - elif len(seqs) == 1: - tag = tags[0] - seq = seqs[0] - with open(tmp_fasta_path, "w") as fp: - fp.write(f">{tag}\n{seq}") - - local_alignment_dir = os.path.join(alignment_dir, tag) - feature_dict = data_processor.process_fasta( - fasta_path=tmp_fasta_path, - alignment_dir=local_alignment_dir, - seqemb_mode=args.use_single_seq_mode, - ) - else: - with open(tmp_fasta_path, "w") as fp: - fp.write( - '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)]) - ) - feature_dict = data_processor.process_multiseq_fasta( - fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir, - ) - - # Remove temporary FASTA file - os.remove(tmp_fasta_path) - - return feature_dict - - def list_files_with_extensions(dir, extensions): return [f for f in os.listdir(dir) if f.endswith(extensions)] - def main(args): # Create the output directory os.makedirs(args.output_dir, exist_ok=True) @@ -198,38 +86,6 @@ def main(args): is_multimer = "multimer" in args.config_preset is_custom_template = "use_custom_template" in args and args.use_custom_template - if is_custom_template: - template_featurizer = templates.CustomHitFeaturizer( - mmcif_dir=args.template_mmcif_dir, - max_template_date="9999-12-31", # just dummy, not used - max_hits=-1, # just dummy, not used - kalign_binary_path=args.kalign_binary_path - ) - elif is_multimer: - template_featurizer = templates.HmmsearchHitFeaturizer( 
- mmcif_dir=args.template_mmcif_dir, - max_template_date=args.max_template_date, - max_hits=config.data.predict.max_templates, - kalign_binary_path=args.kalign_binary_path, - release_dates_path=args.release_dates_path, - obsolete_pdbs_path=args.obsolete_pdbs_path - ) - else: - template_featurizer = templates.HhsearchHitFeaturizer( - mmcif_dir=args.template_mmcif_dir, - max_template_date=args.max_template_date, - max_hits=config.data.predict.max_templates, - kalign_binary_path=args.kalign_binary_path, - release_dates_path=args.release_dates_path, - obsolete_pdbs_path=args.obsolete_pdbs_path - ) - data_processor = data_pipeline.DataPipeline( - template_featurizer=template_featurizer, - ) - if is_multimer: - data_processor = data_pipeline.DataPipelineMultimer( - monomer_data_pipeline=data_processor, - ) output_dir_base = args.output_dir random_seed = args.data_random_seed @@ -238,7 +94,6 @@ def main(args): np.random.seed(random_seed) torch.manual_seed(random_seed + 1) - feature_processor = feature_pipeline.FeaturePipeline(config.data) if not os.path.exists(output_dir_base): os.makedirs(output_dir_base) if args.use_precomputed_alignments is None: @@ -271,117 +126,40 @@ def main(args): seq_sort_fn = lambda target: sum([len(s) for s in target[1]]) sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn) - feature_dicts = {} if is_multimer and args.openfold_checkpoint_path: raise ValueError( '`openfold_checkpoint_path` was specified, but no OpenFold checkpoints are available for multimer mode') - model_generator = load_models_from_command_line( - config, - args.model_device, - args.openfold_checkpoint_path, - args.jax_param_path, - args.output_dir) - - for model, output_directory in model_generator: - cur_tracing_interval = 0 - for (tag, tags), seqs in sorted_targets: - output_name = f'{tag}_{args.config_preset}' - if args.output_postfix is not None: - output_name = f'{output_name}_{args.output_postfix}' - - # Does nothing if the alignments have already been 
computed - precompute_alignments(tags, seqs, alignment_dir, args) - - feature_dict = feature_dicts.get(tag, None) - if feature_dict is None: - feature_dict = generate_feature_dict( - tags, - seqs, - alignment_dir, - data_processor, - args, - ) - - if args.trace_model: - n = feature_dict["aatype"].shape[-2] - rounded_seqlen = round_up_seqlen(n) - feature_dict = pad_feature_dict_seq( - feature_dict, rounded_seqlen, - ) - - feature_dicts[tag] = feature_dict - processed_feature_dict = feature_processor.process_features( - feature_dict, mode='predict', is_multimer=is_multimer - ) - - processed_feature_dict = { - k: torch.as_tensor(v, device=args.model_device) - for k, v in processed_feature_dict.items() - } - - if args.trace_model: - if rounded_seqlen > cur_tracing_interval: - logger.info( - f"Tracing model at {rounded_seqlen} residues..." - ) - t = time.perf_counter() - trace_model_(model, processed_feature_dict) - tracing_time = time.perf_counter() - t - logger.info( - f"Tracing time: {tracing_time}" - ) - cur_tracing_interval = rounded_seqlen - - out = run_model(model, processed_feature_dict, tag, args.output_dir) - - # Toss out the recycling dimensions --- we don't need them anymore - processed_feature_dict = tensor_tree_map( - lambda x: np.array(x[..., -1].cpu()), - processed_feature_dict - ) - out = tensor_tree_map(lambda x: np.array(x.cpu()), out) - - unrelaxed_protein = prep_output( - out, - processed_feature_dict, - feature_dict, - feature_processor, - args.config_preset, - args.multimer_ri_gap, - args.subtract_plddt - ) - - unrelaxed_file_suffix = "_unrelaxed.pdb" - if args.cif_output: - unrelaxed_file_suffix = "_unrelaxed.cif" - unrelaxed_output_path = os.path.join( - output_directory, f'{output_name}{unrelaxed_file_suffix}' - ) - - with open(unrelaxed_output_path, 'w') as fp: - if args.cif_output: - fp.write(protein.to_modelcif(unrelaxed_protein)) - else: - fp.write(protein.to_pdb(unrelaxed_protein)) - - logger.info(f"Output written to 
{unrelaxed_output_path}...") - - if not args.skip_relaxation: - # Relax the prediction. - logger.info(f"Running relaxation on {unrelaxed_output_path}...") - relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, - args.cif_output) - - if args.save_outputs: - output_dict_path = os.path.join( - output_directory, f'{output_name}_output_dict.pkl' - ) - with open(output_dict_path, "wb") as fp: - pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL) - - logger.info(f"Model output written to {output_dict_path}...") + targets = [] + for (tag, tags), seqs in sorted_targets: + output_name = f'{tag}_{args.config_preset}' + if args.output_postfix is not None: + output_name = f'{output_name}_{args.output_postfix}' + targets.append({ + 'tag': tag, + 'tags': tags, + 'seqs': seqs, + 'output_name': output_name, + }) + params = { + 'args': args, + 'config': config, + 'alignment_dir': alignment_dir, + 'is_multimer': is_multimer, + } + inferencePipeline = BasePipeline( + multiprocessing=args.multiprocessing, + tasks=[ + MsaTask(), + FeatureGenTask(), + InferenceTask(), + RelaxationTask(), + ] + ) + inferencePipeline.start(params) + inferencePipeline.run(targets) + inferencePipeline.end() if __name__ == "__main__": @@ -393,6 +171,10 @@ def main(args): parser.add_argument( "template_mmcif_dir", type=str, ) + parser.add_argument( + "--multiprocessing", action="store_true", default=False, + help="""Enable multiprocessing to overlap stages (alignments, generation, inference, relaxation) across the input batch.""" + ) parser.add_argument( "--use_precomputed_alignments", type=str, default=None, help="""Path to alignment directory. If provided, alignment computation @@ -415,6 +197,12 @@ def main(args): help="""Name of the device on which to run the model. Any valid torch device name is accepted (e.g. "cpu", "cuda:0")""" ) + parser.add_argument( + "--relaxation_device", type=str, default="cpu", + help="""Name of the device on which to run the relaxation. 
A valid torch + device name is accepted (e.g. "cpu", "cuda:0"). + Alternately; setting this to 'skip' would skip relaxation entirely""" + ) parser.add_argument( "--config_preset", type=str, default="model_1", help="""Name of a model config preset defined in openfold/config.py"""