diff --git a/scripts/inference.py b/scripts/inference.py index 22d604d..cc6f90d 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -32,10 +32,14 @@ def main(configs : DictConfig) -> None: print_configs(configs) os.environ["DATA_ROOT_DIR"] = configs.data_root_dir os.environ["CKPT_ROOT_DIR"] = configs.ckpt_root_dir + hydra_output_dir = HydraConfig.get().runtime.output_dir configs = ConfigDict( OmegaConf.to_container(configs.exp, resolve=True) ) - dump_dir = HydraConfig.get().runtime.output_dir + # exp.dump_dir allows the caller to separate data output from hydra's + # working directory, which is needed when parallel seed workers each + # require a distinct hydra.run.dir but share one output directory. + dump_dir = configs.get("dump_dir", None) or hydra_output_dir configs.dump_dir = dump_dir error_dir = Path(dump_dir) / "errors" if DIST_WRAPPER.rank == 0: @@ -45,7 +49,7 @@ def main(configs : DictConfig) -> None: logger.info( f"Distributed environment: world size: {DIST_WRAPPER.world_size}, " + f"global rank: {DIST_WRAPPER.rank}, local rank: {DIST_WRAPPER.local_rank}" - ) + ) device = torch.device("cuda:{}".format(DIST_WRAPPER.local_rank)) torch.cuda.set_device(device) if DIST_WRAPPER.world_size > 1: @@ -72,7 +76,7 @@ def main(configs : DictConfig) -> None: ) infer_runner.run() - + if __name__ == "__main__": main() diff --git a/src/utils/inference/dumper.py b/src/utils/inference/dumper.py index e543c41..5948ac7 100644 --- a/src/utils/inference/dumper.py +++ b/src/utils/inference/dumper.py @@ -127,25 +127,25 @@ def _save_structure_sequence( new_atom_array, per_chain_edits=per_chain_edits, design_modality=design_modality, - ) - - output_fpath = os.path.join( - prediction_save_dir, - f"{sample_name}_seed_{seed}_bb_{rank}_seq_{seq_var_idx}.cif", - ) - - if b_factor is not None: - # b_factor.shape == [N_sample, N_atom] - new_atom_array.set_annotation("b_factor", np.round(b_factor[idx], 2)) - - save_structure_cif( - atom_array=new_atom_array, - pred_coordinate=pred_coordinates[idx], - output_fpath=output_fpath, - entity_poly_type=entity_poly_type, - pdb_id=sample_name, ) + output_fpath = os.path.join( + prediction_save_dir, + f"{sample_name}_seed_{seed}_bb_{rank}_seq_{seq_var_idx}.cif", + ) + + if b_factor is not None: + # b_factor.shape == [N_sample, N_atom] + new_atom_array.set_annotation("b_factor", np.round(b_factor[idx], 2)) + + save_structure_cif( + atom_array=new_atom_array, + pred_coordinate=pred_coordinates[idx], + output_fpath=output_fpath, + entity_poly_type=entity_poly_type, + pdb_id=sample_name, + ) + def _apply_sequence_variant_to_atom_array( self, atom_array: AtomArray,