Hi, I've tried to run training with the structurev2_small_cp.yaml example, but the training seems to get stuck at the very start of epoch 0 (the progress bar stays at 0%). Do you have any idea how to fix this issue?
Hardware: 4x NVIDIA A100-SXM4-40GB
(boltz) root@C.37872624:/workspace/boltz-cp$ OMP_NUM_THREADS=16 torchrun \
--nnodes 1 \
--nproc_per_node 4 \
src/boltz/distributed/train.py \
scripts/train/configs/datasets_cp.yaml \
parallel_size.size_cp=4 \
parallel_size.size_dp=1
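This launch starts 4 processes (`--nnodes 1 --nproc_per_node 4`) and requests `size_cp=4`, `size_dp=1`. The structurev2_cp config below states two constraints on that layout: `size_cp` must be a perfect square (2D CP mesh) and `size_dp * size_cp` must equal the world size. A minimal sketch of that check, assuming world size = nnodes × nproc_per_node as torchrun defines it; `check_cp_topology` is a hypothetical helper, not a boltz-cp function:

```python
import math


def check_cp_topology(nnodes: int, nproc_per_node: int, size_cp: int, size_dp: int) -> tuple:
    """Hypothetical helper: validate the CP/DP layout described in structurev2_cp.

    Returns (world_size, cp_mesh_side); raises ValueError on an invalid layout.
    """
    # torchrun launches nnodes * nproc_per_node worker processes in total.
    world_size = nnodes * nproc_per_node
    # The 2D CP mesh is side x side, so size_cp must be a perfect square.
    side = math.isqrt(size_cp)
    if side * side != size_cp:
        raise ValueError(f"size_cp={size_cp} is not a perfect square")
    # Every rank must belong to exactly one (CP group, DP replica) pair.
    if size_cp * size_dp != world_size:
        raise ValueError(
            f"size_cp * size_dp = {size_cp * size_dp} != world_size = {world_size}"
        )
    return world_size, side


# The command above: 1 node x 4 GPUs, size_cp=4, size_dp=1 -> a 2 x 2 CP mesh.
print(check_cp_topology(nnodes=1, nproc_per_node=4, size_cp=4, size_dp=1))  # (4, 2)
```

By this check the requested layout is consistent, so the topology arguments alone do not explain the stall.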
train_records before filter 216472
train_records before filter 216472
train_records before filter 216472
train_records before filter 216472
train_records after filter 203113
val_records after filter 398
train_records after filter 203113
val_records after filter 398
train_records after filter 203113
val_records after filter 398
train_records after filter 203113
val_records after filter 398
train_records before filter 268772
train_records before filter 268772
train_records before filter 268772
train_records before filter 268772
train_records after filter 268772
val_records after filter 0
train_records after filter 268772
val_records after filter 0
train_records after filter 268772
val_records after filter 0
train_records after filter 268772
val_records after filter 0
Training dataset size: 4366678
Training dataset size: 268772
Validation dataset size: 398
Training dataset size: 4366678
Training dataset size: 268772
Validation dataset size: 398
Training dataset size: 4366678
Training dataset size: 268772
Validation dataset size: 398
Training dataset size: 4366678
Training dataset size: 268772
Validation dataset size: 398
/workspace/boltz-cp/src/boltz/distributed/model/modules/diffusion.py:509: UserWarning: CPU offloading-based activation checkpointing by default is not used for Boltz-2 so we do not use it in DTensor DiffusionModule.
self.score_model = DiffusionModule(
/workspace/boltz-cp/src/boltz/distributed/model/modules/diffusion.py:509: UserWarning: CPU offloading-based activation checkpointing by default is not used for Boltz-2 so we do not use it in DTensor DiffusionModule.
self.score_model = DiffusionModule(
/workspace/boltz-cp/src/boltz/distributed/model/modules/diffusion.py:509: UserWarning: CPU offloading-based activation checkpointing by default is not used for Boltz-2 so we do not use it in DTensor DiffusionModule.
self.score_model = DiffusionModule(
/workspace/boltz-cp/src/boltz/distributed/model/modules/diffusion.py:509: UserWarning: CPU offloading-based activation checkpointing by default is not used for Boltz-2 so we do not use it in DTensor DiffusionModule.
self.score_model = DiffusionModule(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
   | Name | Type | Params | Mode
------------------------------------------------------------------------------
0 | validators | ModuleList | 0 | train
1 | s_init | LinearParamsReplicated | 147 K | train
2 | z_init_1 | LinearParamsReplicated | 49.2 K | train
3 | z_init_2 | LinearParamsReplicated | 49.2 K | train
4 | s_norm | LayerNormParamsReplicated | 768 | train
5 | z_norm | LayerNormParamsReplicated | 256 | train
6 | s_recycle | LinearParamsReplicated | 147 K | train
7 | z_recycle | LinearParamsReplicated | 16.4 K | train
8 | token_bonds | LinearParamsReplicated | 128 | train
9 | msa_module | MSAModule | 2.4 M | train
10 | pairformer_module | PairformerModule | 36.9 M | train
11 | distogram_module | DistogramModule | 8.3 K | train
12 | input_embedder | InputEmbedder | 1.1 M | train
13 | rel_pos | RelativePositionEncoder | 17.8 K | train
14 | contact_conditioning | ContactConditioning | 17.5 K | train
15 | token_bonds_type | EmbeddingParamsReplicated | 896 | train
16 | diffusion_conditioning | DiffusionConditioning | 393 K | train
17 | structure_module | AtomDiffusion | 279 M | train
18 | bfactor_module | BFactorModule | 24.6 K | train
------------------------------------------------------------------------------
320 M Trainable params
768 Non-trainable params
320 M Total params
1,283.330 Total estimated model params size (MB)
1988 Modules in train mode
0 Modules in eval mode
/workspace/boltz-cp/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/workspace/boltz-cp/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/workspace/boltz-cp/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/workspace/boltz-cp/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
Epoch 0: 0%| | 0/36096 [00:00<?, ?it/s]
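To see where each of the 4 ranks is blocked while the bar sits at 0/36096, one option is Python's standard-library `faulthandler`, which can periodically dump every thread's stack from a watchdog thread even when the main thread is stuck. A minimal sketch, assuming it is called once per rank early in the training script (the exact placement in `src/boltz/distributed/train.py` is an assumption); `enable_hang_dumps` and `disable_hang_dumps` are hypothetical helper names:

```python
import faulthandler
import os
import sys
from typing import Optional, TextIO


def enable_hang_dumps(interval_s: float = 600.0, log_dir: Optional[str] = None) -> TextIO:
    """Dump every thread's Python stack every `interval_s` seconds.

    Dumps go to stderr, or to <log_dir>/stacks_rank<RANK>.txt if log_dir is set.
    Returns the stream being written to; keep a reference so it stays open.
    """
    # torchrun exports RANK for every worker; fall back to 0 when run standalone.
    rank = int(os.environ.get("RANK", "0"))
    if log_dir is None:
        out: TextIO = sys.stderr
    else:
        os.makedirs(log_dir, exist_ok=True)
        # faulthandler writes straight to the file descriptor, so the file
        # object must stay open for as long as the watchdog is active.
        out = open(os.path.join(log_dir, f"stacks_rank{rank}.txt"), "w")
    # repeat=True: keep dumping every interval_s seconds instead of only once.
    faulthandler.dump_traceback_later(interval_s, repeat=True, file=out)
    return out


def disable_hang_dumps() -> None:
    """Stop the periodic dumps started by enable_hang_dumps()."""
    faulthandler.cancel_dump_traceback_later()
```

For example, `enable_hang_dumps(300, "hang_logs")` before `trainer.fit()`, then compare the four `stacks_rank*.txt` files: ranks waiting in different collectives, or one rank still inside data loading while the others wait, usually point at where the ranks diverge.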
config
dataset_cp.yaml
defaults:
  - structurev2_cp
  - _self_
trainer:
  accumulate_grad_batches: 4
data:
  checkpoint_monitor_val_group: "val/disto_lddt_protein_protein"
  max_tokens: 256
  max_atoms: 2048
  max_seqs: 1024
model:
  checkpoint_diffusion_conditioning: false
  msa_args:
    msa_blocks: 3
    activation_checkpointing: false
  template_args:
    activation_checkpointing: false
  pairformer_args:
    num_blocks: 12
    activation_checkpointing: false
  score_model_args:
    activation_checkpointing: false
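dataset_cp.yaml pulls in structurev2_cp, which in turn pulls in structurev2, and each file lists `_self_` last, so its own values override the ones it inherits. A minimal sketch of that layering, assuming the defaults list is resolved the way Hydra does it; `deep_merge` is a simplified stand-in for Hydra/OmegaConf merging (nested mappings merged key by key, lists and scalars replaced), and only a few keys from the three files are shown:

```python
from copy import deepcopy


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base`; `override` wins.

    Nested dicts are merged key by key; any other value (including lists,
    e.g. model.validators) is replaced outright.
    """
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# A few keys copied from the three files, in defaults order (base first).
structurev2 = {
    "trainer": {"accelerator": "cuda", "accumulate_grad_batches": 1},
    "data": {"max_tokens": 384, "num_workers": 2},
    "model": {"pairformer_args": {"num_blocks": 48, "activation_checkpointing": True}},
}
structurev2_cp = {"trainer": {"accelerator": "gpu"}, "data": {"num_workers": 0}}
dataset_cp = {
    "trainer": {"accumulate_grad_batches": 4},
    "data": {"max_tokens": 256},
    "model": {"pairformer_args": {"num_blocks": 12, "activation_checkpointing": False}},
}

cfg = deep_merge(deep_merge(structurev2, structurev2_cp), dataset_cp)
print(cfg["trainer"])  # {'accelerator': 'gpu', 'accumulate_grad_batches': 4}
print(cfg["data"])     # {'max_tokens': 256, 'num_workers': 0}
```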
structurev2_cp
defaults:
  - structurev2
  - _self_
# General DTensor context-parallel settings for normal-size model.
# For small-model CP training, use structurev2_small_cp.yaml instead.
# DTensor CP uses SingleDeviceStrategy; multi-device/node is managed by
# DistributedManager via parallel_size, not by Lightning.
trainer:
  accelerator: gpu # must be gpu instead of cuda because CP code manages devices
  devices: 1 # must be 1 for DTensor CP (one device per Lightning Trainer)
  num_nodes: 1 # must be 1; multi-node is handled by torchrun/SLURM
  precision: null # superseded by top-level precision below
# Context-parallelism and data-parallelism topology.
# Override via CLI: parallel_size.size_cp=4 parallel_size.size_dp=2
# Constraint: size_cp must be a perfect square (2D CP mesh); size_dp * size_cp == world_size.
parallel_size:
  size_cp: 1 # context-parallel group size (must be a perfect square, e.g. 1, 4, 9, 16)
  size_dp: 1 # data-parallel group size
  timeout_nccl: 30 # NCCL timeout in minutes (for CUDA)
  timeout_gloo: 30 # Gloo timeout in minutes (for CPU)
# Training precision. Values: FP32, TF32, BF16, BF16_MIXED, FP16, FP64
precision: BF16_MIXED
# Triangular attention kernel backend.
# Values: reference, cueq, trifast, cueq_fwd_trifast_bwd
# Note: cueq does not support FP32 precision.
triattn_backend: cueq
# Scaled dot-product attention with bias backend (ring-attention layers).
# Values: reference, torch_sdpa_efficient_attention, torch_flex_attn
sdpa_with_bias_backend: torch_flex_attn
# SDPA with bias backend for shardwise (window-batched) attention layers.
# Values: reference, torch_sdpa_efficient_attention, torch_flex_attn
sdpa_with_bias_shardwise_backend: torch_flex_attn
# CUDA memory profiling. Activated only when output_path_prefix is set.
# Each rank writes to: {output_path_prefix}_rank{global_rank}.pickle
# Additional kwargs are forwarded to torch.cuda.memory._record_memory_history().
CUDAMemoryProfile:
  output_path_prefix: null # set a path prefix to enable, e.g. "profiling/mem"
  max_entries: 300000 # max allocation/deallocation events to record
# CPU offloading of activation-checkpoint boundary tensors.
# Lists the distributed module types whose checkpoint-boundary activations
# should be offloaded to CPU during forward and restored on backward.
# Requires the corresponding activation_checkpointing to be enabled.
# Valid types: DiffusionTransformer, MSAModule, PairformerModule
# Set to null or omit to disable.
OffloadActvCkptToCPU: null
# DTensor CP requires num_workers=0 (main-process collation for distributed
# tensor construction) and pin_memory=false.
data:
  num_workers: 0
  pin_memory: false
model:
  validators:
    - _target_: boltz.distributed.model.validation.rcsb.DistributedRCSBValidator
      val_names: ["RCSB"]
      confidence_prediction: ${model.confidence_prediction}
      physicalism_metrics: False
      rmsd_metrics: True
      clash_score_metrics: True
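The comments in structurev2_cp spell out several hard requirements for DTensor CP (one device and one node per Lightning Trainer, `accelerator: gpu`, `num_workers: 0`, `pin_memory: false`, and no `cueq` with FP32). A minimal sketch that checks a merged config, given as a plain dict, against exactly those rules; `check_cp_config` is hypothetical, not part of boltz-cp, and it covers only what the comments state:

```python
VALID_PRECISIONS = {"FP32", "TF32", "BF16", "BF16_MIXED", "FP16", "FP64"}


def check_cp_config(cfg: dict) -> list:
    """Hypothetical checker for the rules written in structurev2_cp's comments.

    Returns a list of human-readable problems (empty if none were found).
    """
    problems = []
    trainer = cfg.get("trainer", {})
    data = cfg.get("data", {})
    # Lightning must see a single device; torchrun/DistributedManager do the rest.
    if trainer.get("devices") != 1:
        problems.append("trainer.devices must be 1 for DTensor CP")
    if trainer.get("num_nodes") != 1:
        problems.append("trainer.num_nodes must be 1; multi-node goes through torchrun/SLURM")
    if trainer.get("accelerator") != "gpu":
        problems.append("trainer.accelerator must be 'gpu'")
    # Distributed tensors are built during main-process collation.
    if data.get("num_workers") != 0:
        problems.append("data.num_workers must be 0")
    if data.get("pin_memory") is not False:
        problems.append("data.pin_memory must be false")
    precision = cfg.get("precision")
    if precision not in VALID_PRECISIONS:
        problems.append(f"unknown precision {precision!r}")
    if cfg.get("triattn_backend") == "cueq" and precision == "FP32":
        problems.append("cueq triattn backend does not support FP32")
    return problems


# Values from structurev2_cp (dataset_cp does not override these keys).
cp_cfg = {
    "trainer": {"accelerator": "gpu", "devices": 1, "num_nodes": 1},
    "data": {"num_workers": 0, "pin_memory": False},
    "precision": "BF16_MIXED",
    "triattn_backend": "cueq",
}
print(check_cp_config(cp_cfg))  # []
```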
structurev2
trainer:
  accelerator: cuda
  devices: 1
  num_nodes: 1
  precision: bf16-mixed
  gradient_clip_val: 10.0
  accumulate_grad_batches: 1
  max_epochs: -1
  num_sanity_val_steps: 0
# Optional set wandb here
# wandb:
#   name: boltz
#   project: boltz
#   entity: boltz
output: ./outputs
pretrained: null
resume: null
disable_checkpoint: false
matmul_precision: null
save_top_k: -1
v2: true
data:
  datasets:
    # RCSB Data
    - _target_: boltz.data.module.trainingv2.DatasetConfig
      target_dir: ./datasets/rcsb_processed_targets
      msa_dir: ./datasets/rcsb_processed_msa
      template_dir: null
      prob: 0.55
      filters:
        - _target_: boltz.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
        - _target_: boltz.data.filter.dynamic.date.DateFilter
          date: "2023-06-01"
          ref: released
        - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
          resolution: 9.0
      sampler:
        _target_: boltz.data.sample.v2.cluster.ClusterSampler
      cropper:
        _target_: boltz.data.crop.boltz.BoltzCropper
        min_neighborhood: 0
        max_neighborhood: 40
      split: ./scripts/train/assets/validation_ids_v2.txt
      symmetry_correction: true
      val_group: "RCSB"
    # AFDB Distillation Data
    - _target_: boltz.data.module.trainingv2.DatasetConfig
      target_dir: ./datasets/openfold_processed_targets
      msa_dir: ./datasets/openfold_processed_msa
      template_dir: null
      prob: 0.45
      filters:
        - _target_: boltz.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
      sampler:
        _target_: boltz.data.sample.v2.cluster.ClusterSampler
      cropper:
        _target_: boltz.data.crop.boltz.BoltzCropper
        min_neighborhood: 0
        max_neighborhood: 40
      symmetry_correction: true
      override_method: "AFDB"
      override_bfactor: true
  checkpoint_monitor_val_group: "val/lddt" # dataset __RCSB is turned to "" # which validation dataset group to use for checkpoint monitoring
  tokenizer:
    _target_: boltz.data.tokenize.boltz2.Boltz2TrainingTokenizer
  featurizer:
    _target_: boltz.data.feature.featurizerv2_train.Boltz2Featurizer
  moldir: #PATH_HERE
  max_tokens: 384 # 640 # NOTE: cuEq TriAttn backend on sm100f GPUs requires multiples of 8 token counts per CP shard
  max_atoms: 3456 # 5760
  max_seqs: 8192
  pad_to_max_tokens: true
  pad_to_max_atoms: true
  pad_to_max_seqs: true
  samples_per_epoch: 36096
  batch_size: 1
  num_workers: 2
  random_seed: 42
  pin_memory: false
  overfit: null
  return_train_symmetries: false
  return_val_symmetries: true
  train_binder_pocket_conditioned_prop: 0.15
  val_binder_pocket_conditioned_prop: 0.15
  train_contact_conditioned_prop: 0.15
  val_contact_conditioned_prop: 0.15
  binder_pocket_cutoff_val: 6.0
  binder_pocket_cutoff_min: 4.0
  binder_pocket_cutoff_max: 20.0
  binder_pocket_sampling_geometric_p: 0.3
  atoms_per_window_queries: 32
  min_dist: 2.0
  max_dist: 22.0
  num_bins: 64
  num_ensembles_train: 1
  num_ensembles_val: 1
  fix_single_ensemble: false
  disto_use_ensemble: true
  single_sequence_prop_training: 0.05
  max_templates_train: 4
  max_templates_val: 4
  no_template_prob_train: 0.6
  no_template_prob_val: 1.0
  use_templates: false
  msa_sampling_training: true
  bfactor_md_correction: true
model:
  _target_: boltz.model.models.boltz2.Boltz2
  atom_s: 128
  atom_z: 16
  token_s: 384
  token_z: 128
  num_bins: 64
  atom_feature_dim: 388
  atoms_per_window_queries: 32
  atoms_per_window_keys: 128
  compile_pairformer: false
  compile_templates: false
  compile_msa: false
  ema: true
  ema_decay: 0.999
  exclude_ions_from_lddt: true
  fix_sym_check: true
  cyclic_pos_enc: true
  num_val_datasets: 1
  bond_type_feature: true
  conditioning_cutoff_min: ${data.binder_pocket_cutoff_min}
  conditioning_cutoff_max: ${data.binder_pocket_cutoff_max}
  use_templates: ${data.use_templates}
  predict_bfactor: true
  checkpoint_diffusion_conditioning: true
  validators:
    - _target_: boltz.model.validation.rcsb.RCSBValidator
      val_names: ["RCSB"]
      confidence_prediction: ${model.confidence_prediction}
  embedder_args:
    atom_encoder_depth: 3
    atom_encoder_heads: 4
    add_mol_type_feat: true
    add_method_conditioning: true
    add_modified_flag: true
    add_cyclic_flag: true
  msa_args:
    msa_s: 64
    msa_blocks: 4
    msa_dropout: 0.15
    z_dropout: 0.25
    pairwise_head_width: 32
    pairwise_num_heads: 4
    use_paired_feature: true
    activation_checkpointing: true
  template_args:
    template_dim: 64
    template_blocks: 2
    activation_checkpointing: true
  pairformer_args:
    num_blocks: 48
    num_heads: 16
    dropout: 0.25
    post_layer_norm: false
    activation_checkpointing: true
    v2: true
  score_model_args:
    sigma_data: 16
    dim_fourier: 256
    atom_encoder_depth: 3
    atom_encoder_heads: 4
    token_transformer_depth: 24
    token_transformer_heads: 16
    atom_decoder_depth: 3
    atom_decoder_heads: 4
    conditioning_transition_layers: 2
    transformer_post_ln: false
    activation_checkpointing: true
  confidence_prediction: false
  affinity_prediction: false
  structure_prediction_training: true
  training_args:
    recycling_steps: 3
    sampling_steps: 20
    diffusion_multiplicity: 32
    diffusion_samples: 1
    affinity_loss_weight: 3e-3
    confidence_loss_weight: 1e-4
    diffusion_loss_weight: 4.0
    distogram_loss_weight: 3e-2
    bfactor_loss_weight: 1e-3
    adam_beta_1: 0.9
    adam_beta_2: 0.95
    adam_eps: 0.00000001
    lr_scheduler: af3
    base_lr: 0.0
    max_lr: 0.0005
    lr_warmup_no_steps: 1000
    lr_start_decay_after_n_steps: 50000
    lr_decay_every_n_steps: 50000
    lr_decay_factor: 0.95
    weight_decay: 0.003
    weight_decay_exclude: true
  validation_args:
    recycling_steps: 3
    sampling_steps: 200
    diffusion_samples: 5
    symmetry_correction: true
  diffusion_process_args:
    sigma_min: 0.0004
    sigma_max: 160.0
    sigma_data: 16.0
    rho: 7
    P_mean: -1.2
    P_std: 1.5
    gamma_0: 0.8
    gamma_min: 1.0
    noise_scale: 1.0
    step_scale: 1.0
    coordinate_augmentation: true
    alignment_reverse_diff: true
    synchronize_sigmas: false
  diffusion_loss_args:
    add_smooth_lddt_loss: true
    nucleotide_loss_weight: 5.0
    ligand_loss_weight: 10.0
    filter_by_plddt: 0.0
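The progress-bar total in the log (0/36096) can be reproduced from these settings: `samples_per_epoch: 36096`, `batch_size: 1`, `size_dp=1` from the command line, and `accumulate_grad_batches: 4` from dataset_cp.yaml. A worked sketch, assuming samples are split evenly across the `size_dp` data-parallel replicas and that all `size_cp` ranks in one CP group work on the same (sharded) sample; `epoch_steps` is a hypothetical helper:

```python
def epoch_steps(samples_per_epoch: int, batch_size: int, size_dp: int,
                accumulate_grad_batches: int) -> tuple:
    """Return (batches_per_epoch, optimizer_steps_per_epoch) for one DP replica.

    Assumption: each DP replica gets samples_per_epoch / size_dp samples, and
    the CP ranks inside a replica share each batch rather than splitting them.
    """
    batches = samples_per_epoch // (batch_size * size_dp)
    # Lightning steps the optimizer once every accumulate_grad_batches batches.
    steps = batches // accumulate_grad_batches
    return batches, steps


print(epoch_steps(36096, 1, 1, 4))  # (36096, 9024)
```

The first value matches the 36096 in the progress bar, so the epoch length is as configured; the problem is that the first batch never completes.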
FYI, inference works fine.