Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 149 additions & 112 deletions boltzdesign.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ def parse_arguments():
default='protein', help='Type of target molecule')
parser.add_argument('--input_type', type=str, choices=['pdb', 'custom'], default='pdb',
help='Input type: pdb code or custom input')

# --pdb_path takes one local PDB path, or several separated by ':' (the value is split on ':' in generate_yaml_config)
parser.add_argument('--pdb_path', type=str, default='',
help='Path to a local PDB file (if specify use custom pdb, else fetch from RCSB)')
help='Path to local PDB file(s). For multiple inputs, separate with colon (e.g. path1.pdb:path2.pdb)')

parser.add_argument('--pdb_target_ids', type=str, default='',
help='Target PDB IDs (comma-separated, e.g., "C,D")')
parser.add_argument('--target_mols', type=str, default='',
Expand Down Expand Up @@ -209,25 +212,14 @@ def parse_arguments():
parser.add_argument('--ccd_path', type=str,
default='~/.boltz/ccd.pkl',
help='Path to CCD file')
parser.add_argument('--alphafold_dir', type=str,
default='~/alphafold3',
help='AlphaFold directory')
parser.add_argument('--af3_docker_name', type=str,
default='alphafold3',
help='Docker name')
parser.add_argument('--af3_database_settings', type=str,
default='~/alphafold3/alphafold3_data_save',
help='AlphaFold3 database settings')
parser.add_argument('--af3_hmmer_path', type=str,
default='/home/jupyter-yehlin/.conda/envs/alphafold3_venv',
help='AlphaFold3 hmmer path, required for RNA MSA generation')

# Control flags
parser.add_argument('--run_boltz_design', type=str2bool, default=True,
help='Run Boltz design step')
parser.add_argument('--run_ligandmpnn', type=str2bool, default=True,
help='Run LigandMPNN redesign step')
parser.add_argument('--run_alphafold', type=str2bool, default=True,
help='Run AlphaFold validation step')
parser.add_argument('--run_validation', type=str2bool, default=True,
help='Run validation step (Boltz)')
parser.add_argument('--run_rosetta', type=str2bool, default=True,
help='Run Rosetta energy calculation (protein targets only)')
parser.add_argument('--redo_boltz_predict', type=str2bool, default=False,
Expand Down Expand Up @@ -513,93 +505,100 @@ def calculate_holo_apo_rmsd(af_pdb_dir, af_pdb_dir_apo, binder_chain):
for pdb_name in os.listdir(af_pdb_dir):
if pdb_name.endswith('.pdb'):
pdb_path = os.path.join(af_pdb_dir, pdb_name)
pdb_path_apo = os.path.join(af_pdb_dir_apo, pdb_name)
xyz_holo, _ = get_CA_and_sequence(pdb_path, chain_id=binder_chain)
xyz_apo, _ = get_CA_and_sequence(pdb_path_apo, chain_id='A')
rmsd = np_rmsd(np.array(xyz_holo), np.array(xyz_apo))
df_confidence_csv.loc[df_confidence_csv['file'] == pdb_name.split('.pdb')[0]+'.cif', 'rmsd'] = rmsd
print(f"{pdb_path} rmsd: {rmsd}")
# If apo directory is None or empty, skip RMSD calculation
if af_pdb_dir_apo and os.path.exists(os.path.join(af_pdb_dir_apo, pdb_name)):
pdb_path_apo = os.path.join(af_pdb_dir_apo, pdb_name)
xyz_holo, _ = get_CA_and_sequence(pdb_path, chain_id=binder_chain)
xyz_apo, _ = get_CA_and_sequence(pdb_path_apo, chain_id='A')
rmsd = np_rmsd(np.array(xyz_holo), np.array(xyz_apo))
df_confidence_csv.loc[df_confidence_csv['file'] == pdb_name.split('.pdb')[0]+'.cif', 'rmsd'] = rmsd
print(f"{pdb_path} rmsd: {rmsd}")
df_confidence_csv.to_csv(confidence_csv_path, index=False)


def run_boltz_validation_step(args, ligandmpnn_dir, work_dir, mod_to_wt_aa):
    """Re-predict the LigandMPNN-redesigned complexes with the ``boltz`` CLI.

    Copies every ``.yaml`` found in
    ``<ligandmpnn_dir>/01_lmpnn_redesigned_high_iptm/yaml`` into a validation
    input folder, runs ``boltz predict`` on that folder, converts the
    predictions to PDB with ``convert_cif_files_to_pdb`` and finally calls
    ``calculate_holo_apo_rmsd`` without an apo directory.

    Args:
        args: Parsed command-line namespace. Reads ``gpu_id``, ``ccd_path``,
            ``high_iptm`` and ``binder_id``.
        ligandmpnn_dir: LigandMPNN output directory; the validation input,
            output and PDB folders are created beneath it.
        work_dir: Not used by this step; kept so the call site in
            ``run_pipeline_steps`` does not change.
        mod_to_wt_aa: Not used by this step; kept for the same reason.

    Returns:
        ``(val_output_dir, None, val_pdb_dir, None)``. The two ``None`` slots
        stand for the apo output/PDB directories: no apo prediction is run.

    Raises:
        FileNotFoundError: If the ``boltz`` executable is not on ``PATH``.

    The process exits with status 1 when there are no input YAMLs, when
    ``boltz predict`` returns a non-zero status, or when no PDB is produced.
    """
    print("Starting Boltz validation step...")

    boltz_path = shutil.which("boltz")
    if boltz_path is None:
        raise FileNotFoundError("The 'boltz' command was not found in the system PATH.")

    # Create validation directories
    val_input_dir = f'{ligandmpnn_dir}/02_design_boltz_val_input'
    val_output_dir = f'{ligandmpnn_dir}/02_design_boltz_val_output'
    val_pdb_dir = f'{ligandmpnn_dir}/03_boltz_val_pdb_success'
    for dir_path in (val_input_dir, val_output_dir, val_pdb_dir):
        os.makedirs(dir_path, exist_ok=True)

    # Input YAMLs written by the LigandMPNN step
    yaml_dir_success_boltz_yaml = os.path.join(ligandmpnn_dir, '01_lmpnn_redesigned_high_iptm', 'yaml')

    # Only ``.yaml`` files count as input: a folder holding nothing but other
    # files must be reported as empty instead of being handed to boltz.
    yaml_files = []
    if os.path.isdir(yaml_dir_success_boltz_yaml):
        yaml_files = sorted(
            name for name in os.listdir(yaml_dir_success_boltz_yaml)
            if name.endswith('.yaml')
        )
    if not yaml_files:
        print("No input YAMLs found for validation.")
        sys.exit(1)

    for yaml_file in yaml_files:
        shutil.copy(os.path.join(yaml_dir_success_boltz_yaml, yaml_file),
                    os.path.join(val_input_dir, yaml_file))

    print(f"Running Boltz prediction on {val_input_dir}...")

    # Usage: boltz predict <dir_or_yaml> --out_dir <out_dir> [options].
    # ``--devices`` is a device COUNT, not a GPU index, so request a single
    # device and select the GPU through CUDA_VISIBLE_DEVICES instead.
    cmd = [
        boltz_path, 'predict',
        val_input_dir,
        '--out_dir', val_output_dir,
        '--devices', '1',
        '--override',  # overwrite predictions left by an earlier run
    ]

    # ``--cache`` expects the directory holding the downloaded Boltz data,
    # not the CCD pickle itself, so pass the folder that contains ccd_path.
    # NOTE(review): assumes ccd.pkl sits directly inside the cache dir — confirm.
    if args.ccd_path:
        cache_dir = os.path.dirname(os.path.expanduser(args.ccd_path))
        if os.path.isdir(cache_dir):
            cmd.extend(['--cache', cache_dir])

    # NOTE(review): treats args.gpu_id as a CUDA device index — confirm.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(args.gpu_id))

    try:
        subprocess.run(cmd, check=True, env=env)
    except subprocess.CalledProcessError as e:
        print(f"Boltz validation failed: {e}")
        sys.exit(1)

    print("Boltz validation step completed!")

    # Convert results to PDB
    convert_cif_files_to_pdb(val_output_dir, val_pdb_dir, high_iptm=args.high_iptm)

    if not any(f.endswith('.pdb') for f in os.listdir(val_pdb_dir)):
        print("No successful designs from Boltz Validation")
        sys.exit(1)

    # No apo prediction exists for this step, so no apo directory is passed.
    calculate_holo_apo_rmsd(val_pdb_dir, None, args.binder_id)

    return val_output_dir, None, val_pdb_dir, None

def run_rosetta_step(args, ligandmpnn_dir, af_output_dir, af_output_apo_dir, af_pdb_dir, af_pdb_dir_apo):

def run_rosetta_step(args, ligandmpnn_dir, val_output_dir, val_output_apo_dir, val_pdb_dir, val_pdb_dir_apo):
"""Run Rosetta energy calculation (protein targets only)"""
if args.target_type != 'protein':
print("Skipping Rosetta step (not a protein target)")
return

if val_pdb_dir_apo is None:
print("Skipping Rosetta step (Apo structure not available from Boltz validation)")
return

print("Starting Rosetta energy calculation...")
af_pdb_rosetta_success_dir = f"{ligandmpnn_dir}/af_pdb_rosetta_success"
val_pdb_rosetta_success_dir = f"{ligandmpnn_dir}/val_pdb_rosetta_success"
from pyrosetta_utils import measure_rosetta_energy
measure_rosetta_energy(
af_pdb_dir, af_pdb_dir_apo, af_pdb_rosetta_success_dir,
val_pdb_dir, val_pdb_dir_apo, val_pdb_rosetta_success_dir,
binder_holo_chain=args.binder_id, binder_apo_chain='A'
)

Expand Down Expand Up @@ -650,43 +649,81 @@ def generate_yaml_config(args, config_obj):
constraints, modifications = process_design_constraints(target_id_map, args.modifications, args.modifications_positions, args.modification_target, args.contact_residues, args.constraint_target, args.binder_id)
else:
constraints, modifications = None, None
target = []

# Initialize list to hold target data for all inputs
# If single PDB, this will be a list of length 1 containing the sequence data
# If multiple PDBs (separated by colon), it contains data for each state
all_targets_data = []

if args.input_type == "pdb":
pdb_target_ids = [str(x.strip()) for x in args.pdb_target_ids.split(",")] if args.pdb_target_ids else None
target_mols = [str(x.strip()) for x in args.target_mols.split(",")] if args.target_mols else None

# Logic handles list splitting by colon
pdb_paths = []
if args.pdb_path:
pdb_path = Path(args.pdb_path)
print("load local pdb from", pdb_path)
if not pdb_path.is_file():
raise FileNotFoundError(f"Could not find local PDB: {args.pdb_path}")
# Check for colon delimiter for multiple PDBs
if ':' in args.pdb_path:
pdb_paths = [Path(p.strip()) for p in args.pdb_path.split(':')]
print(f"Detected multiple input PDBs: {pdb_paths}")
else:
pdb_paths = [Path(args.pdb_path)]

for p_path in pdb_paths:
print("load local pdb from", p_path)
if not p_path.is_file():
raise FileNotFoundError(f"Could not find local PDB: {p_path}")
else:
print("fetch pdb from RCSB")
download_pdb(args.target_name, config_obj.PDB_DIR)
pdb_path = config_obj.PDB_DIR / f"{args.target_name}.pdb"

if args.target_type in ['rna', 'dna']:
nucleotide_dict = get_nucleotide_from_pdb(pdb_path)
for target_id in pdb_target_ids:
target.append(nucleotide_dict[target_id]['seq'])
elif args.target_type == 'small_molecule':
ligand_dict = get_ligand_from_pdb(args.target_name)
for target_mol in target_mols:
print(target_mol, ligand_dict.keys())
target.append(ligand_dict[target_mol])
elif args.target_type == 'protein':
chain_sequences = get_chains_sequence(pdb_path)
for target_id in pdb_target_ids:
target.append(chain_sequences[target_id])
else:
raise ValueError(f"Unsupported target type: {args.target_type}")
# Default to single path if fetching
pdb_paths = [config_obj.PDB_DIR / f"{args.target_name}.pdb"]

# Loop through each PDB path to extract sequences/structures
for pdb_path in pdb_paths:
current_target_seqs = []

if args.target_type in ['rna', 'dna']:
# Sequence extraction inside loop
nucleotide_dict = get_nucleotide_from_pdb(pdb_path)
for target_id in pdb_target_ids:
current_target_seqs.append(nucleotide_dict[target_id]['seq'])

elif args.target_type == 'small_molecule':
# Small molecule extraction
# Using pdb_path here assuming local file support, otherwise fallback to args.target_name if needed
ligand_dict = get_ligand_from_pdb(str(pdb_path) if args.pdb_path else args.target_name)
for target_mol in target_mols:
print(target_mol, ligand_dict.keys())
current_target_seqs.append(ligand_dict[target_mol])

elif args.target_type == 'protein':
# Protein chain extraction inside loop
chain_sequences = get_chains_sequence(pdb_path)
for target_id in pdb_target_ids:
current_target_seqs.append(chain_sequences[target_id])
else:
raise ValueError(f"Unsupported target type: {args.target_type}")

# Add the extracted sequences for this PDB to the master list
all_targets_data.append(current_target_seqs)

else:
# Custom input usually means direct sequence strings, treating as single state
target_inputs = [str(x.strip()) for x in args.custom_target_input.split(",")] if args.custom_target_input else []
target = target_inputs or [args.target_name]
# Wrap in a list to maintain consistency with multi-state structure above
all_targets_data.append(target_inputs or [args.target_name])

# Hand the collected target data to the YAML generator.
# One input (a single PDB, or custom input): unwrap to the plain list of sequences, the same shape that was passed before multi-PDB support.
# Several PDBs: pass the list of per-PDB sequence lists; generate_yaml_for_target_binder must accept a list of lists in that case.
# Note: this unwraps the single-element outer list (it does not flatten nested lists) so utilities expecting a simple list keep working.
final_target_payload = all_targets_data if len(all_targets_data) > 1 else all_targets_data[0]

return generate_yaml_for_target_binder(
args.target_name,
args.target_type,
target,
final_target_payload,
config=config_obj,
binder_id=args.binder_id,
constraints=constraints,
Expand Down Expand Up @@ -721,7 +758,7 @@ def modification_to_wt_aa(modifications, modifications_wt):

def run_pipeline_steps(args, config, boltz_model, yaml_dir, output_dir):
"""Run the pipeline steps based on arguments"""
results = {'ligandmpnn_dir': f"{output_dir['main_dir']}/{output_dir['version']}/ligandmpnn_cutoff_{args.cutoff}", 'af_output_dir': None, 'af_output_apo_dir': None, 'af_pdb_dir': None, 'af_pdb_dir_apo': None}
results = {'ligandmpnn_dir': f"{output_dir['main_dir']}/{output_dir['version']}/ligandmpnn_cutoff_{args.cutoff}", 'val_output_dir': None, 'val_output_apo_dir': None, 'val_pdb_dir': None, 'val_pdb_dir_apo': None}

if args.run_boltz_design:
run_boltz_design_step(args, config, boltz_model, yaml_dir,
Expand All @@ -732,15 +769,15 @@ def run_pipeline_steps(args, config, boltz_model, yaml_dir, output_dir):
args, output_dir['main_dir'], output_dir['version'],
results['ligandmpnn_dir'], yaml_dir, args.work_dir or os.getcwd()
)
if args.run_alphafold:
if args.run_validation:
mod_to_wt_aa = modification_to_wt_aa(args.modifications, args.modifications_wt)
results['af_output_dir'], results['af_output_apo_dir'], results['af_pdb_dir'], results['af_pdb_dir_apo'] = run_alphafold_step(
results['val_output_dir'], results['val_output_apo_dir'], results['val_pdb_dir'], results['val_pdb_dir_apo'] = run_boltz_validation_step(
args, results['ligandmpnn_dir'], args.work_dir or os.getcwd(), mod_to_wt_aa
)

if args.run_rosetta:
run_rosetta_step(args, results['ligandmpnn_dir'],
results['af_output_dir'], results['af_output_apo_dir'], results['af_pdb_dir'], results['af_pdb_dir_apo'])
results['val_output_dir'], results['val_output_apo_dir'], results['val_pdb_dir'], results['val_pdb_dir_apo'])

return results

Expand Down
Loading