
Cannot test training with the provided example #2

@RTae

Description


Hi, I've tried to run training with the structurev2_small_cp.yaml example, but training seems to get stuck at the very start of epoch 0 (the progress bar never moves past 0/36096). Do you have any idea how to fix this?

Hardware: 4xNVIDIA A100-SXM4-40GB

(boltz) root@C.37872624:/workspace/boltz-cp$ OMP_NUM_THREADS=16 torchrun \
  --nnodes 1 \
  --nproc_per_node 4 \
  src/boltz/distributed/train.py \
  scripts/train/configs/datasets_cp.yaml \
  parallel_size.size_cp=4 \
  parallel_size.size_dp=1
train_records before filter 216472
train_records before filter 216472
train_records before filter 216472
train_records before filter 216472
train_records after filter 203113
val_records after filter 398
train_records after filter 203113
val_records after filter 398
train_records after filter 203113
val_records after filter 398
train_records after filter 203113
val_records after filter 398
train_records before filter 268772
train_records before filter 268772
train_records before filter 268772
train_records before filter 268772
train_records after filter 268772
val_records after filter 0
train_records after filter 268772
val_records after filter 0
train_records after filter 268772
val_records after filter 0
train_records after filter 268772
val_records after filter 0
Training dataset size: 4366678
Training dataset size: 268772
Validation dataset size: 398
Training dataset size: 4366678
Training dataset size: 268772
Validation dataset size: 398
Training dataset size: 4366678
Training dataset size: 268772
Validation dataset size: 398
Training dataset size: 4366678
Training dataset size: 268772
Validation dataset size: 398
/workspace/boltz-cp/src/boltz/distributed/model/modules/diffusion.py:509: UserWarning: CPU offloading-based activation checkpointing by default is not used for Boltz-2 so we do not use it in DTensor DiffusionModule.
  self.score_model = DiffusionModule(
/workspace/boltz-cp/src/boltz/distributed/model/modules/diffusion.py:509: UserWarning: CPU offloading-based activation checkpointing by default is not used for Boltz-2 so we do not use it in DTensor DiffusionModule.
  self.score_model = DiffusionModule(
/workspace/boltz-cp/src/boltz/distributed/model/modules/diffusion.py:509: UserWarning: CPU offloading-based activation checkpointing by default is not used for Boltz-2 so we do not use it in DTensor DiffusionModule.
  self.score_model = DiffusionModule(
/workspace/boltz-cp/src/boltz/distributed/model/modules/diffusion.py:509: UserWarning: CPU offloading-based activation checkpointing by default is not used for Boltz-2 so we do not use it in DTensor DiffusionModule.
  self.score_model = DiffusionModule(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
   | Name                   | Type                      | Params | Mode
------------------------------------------------------------------------------
0  | validators             | ModuleList                | 0      | train
1  | s_init                 | LinearParamsReplicated    | 147 K  | train
2  | z_init_1               | LinearParamsReplicated    | 49.2 K | train
3  | z_init_2               | LinearParamsReplicated    | 49.2 K | train
4  | s_norm                 | LayerNormParamsReplicated | 768    | train
5  | z_norm                 | LayerNormParamsReplicated | 256    | train
6  | s_recycle              | LinearParamsReplicated    | 147 K  | train
7  | z_recycle              | LinearParamsReplicated    | 16.4 K | train
8  | token_bonds            | LinearParamsReplicated    | 128    | train
9  | msa_module             | MSAModule                 | 2.4 M  | train
10 | pairformer_module      | PairformerModule          | 36.9 M | train
11 | distogram_module       | DistogramModule           | 8.3 K  | train
12 | input_embedder         | InputEmbedder             | 1.1 M  | train
13 | rel_pos                | RelativePositionEncoder   | 17.8 K | train
14 | contact_conditioning   | ContactConditioning       | 17.5 K | train
15 | token_bonds_type       | EmbeddingParamsReplicated | 896    | train
16 | diffusion_conditioning | DiffusionConditioning     | 393 K  | train
17 | structure_module       | AtomDiffusion             | 279 M  | train
18 | bfactor_module         | BFactorModule             | 24.6 K | train
------------------------------------------------------------------------------
320 M     Trainable params
768       Non-trainable params
320 M     Total params
1,283.330 Total estimated model params size (MB)
1988      Modules in train mode
0         Modules in eval mode
/workspace/boltz-cp/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/workspace/boltz-cp/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/workspace/boltz-cp/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/workspace/boltz-cp/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
Epoch 0:   0%|          | 0/36096 [00:00<?, ?it/s]
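The `36096` total in the progress bar matches `samples_per_epoch: 36096` with `batch_size: 1` in the structurev2 config below. A minimal sketch of that arithmetic, assuming (not confirmed by the source) that samples are split evenly across data-parallel replicas, that all context-parallel ranks in a group share each sample, and that Lightning counts every micro-batch in the bar, so `accumulate_grad_batches: 4` from dataset_cp.yaml only changes how many optimizer steps happen. `epoch_lengths` is a hypothetical helper.

```python
def epoch_lengths(samples_per_epoch: int, batch_size: int, size_dp: int,
                  accumulate_grad_batches: int) -> tuple[int, int]:
    """Return (micro-batches per epoch, optimizer steps per epoch) for one rank.

    Assumption: samples are split evenly across data-parallel replicas;
    context-parallel ranks in the same group share each sample.
    """
    batches = samples_per_epoch // (batch_size * size_dp)
    optimizer_steps = batches // accumulate_grad_batches
    return batches, optimizer_steps


# Values from the reported run: size_dp=1 on the CLI,
# accumulate_grad_batches=4 from dataset_cp.yaml.
print(epoch_lengths(36096, 1, 1, 4))  # (36096, 9024)
```

Under these assumptions the epoch length is consistent with the config, so a bar that stays at 0/36096 suggests the first batch never completes rather than a misconfigured epoch size.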

config

dataset_cp.yaml

defaults:
  - structurev2_cp
  - _self_

trainer:
  accumulate_grad_batches: 4
  
data:
  checkpoint_monitor_val_group: "val/disto_lddt_protein_protein"
  max_tokens: 256
  max_atoms: 2048
  max_seqs: 1024

model:
  checkpoint_diffusion_conditioning: false
  msa_args:
    msa_blocks: 3
    activation_checkpointing: false
  template_args:
    activation_checkpointing: false
  pairformer_args:
    num_blocks: 12
    activation_checkpointing: false
  score_model_args:
    activation_checkpointing: false
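`dataset_cp.yaml` sits on top of `structurev2_cp`, which in turn sits on `structurev2`; with `_self_` listed last in each defaults list, a file's own keys override the defaults it pulls in. The sketch below approximates that composition with a plain recursive dict merge. It is a simplification of what Hydra/OmegaConf actually do and uses only a handful of keys copied from the three files, but it shows the effective values for the reported run.

```python
from copy import deepcopy


def merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base` (override wins).

    Nested dicts are merged key by key; anything else (scalars, lists)
    is replaced wholesale, roughly like Hydra's config composition.
    """
    out = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = deepcopy(value)
    return out


# A few keys from each file, in defaults order (structurev2 first).
structurev2 = {
    "trainer": {"accelerator": "cuda", "devices": 1, "accumulate_grad_batches": 1},
    "data": {"max_tokens": 384, "max_atoms": 3456, "max_seqs": 8192, "num_workers": 2},
    "model": {"pairformer_args": {"num_blocks": 48, "activation_checkpointing": True}},
}
structurev2_cp = {
    "trainer": {"accelerator": "gpu", "devices": 1},
    "data": {"num_workers": 0, "pin_memory": False},
}
dataset_cp = {
    "trainer": {"accumulate_grad_batches": 4},
    "data": {"max_tokens": 256, "max_atoms": 2048, "max_seqs": 1024},
    "model": {"pairformer_args": {"num_blocks": 12, "activation_checkpointing": False}},
}

effective = merge(merge(structurev2, structurev2_cp), dataset_cp)
print(effective["data"])
# {'max_tokens': 256, 'max_atoms': 2048, 'max_seqs': 1024, 'num_workers': 0, 'pin_memory': False}
```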

structurev2_cp

defaults:
  - structurev2
  - _self_

# General DTensor context-parallel settings for normal-size model.
# For small-model CP training, use structurev2_small_cp.yaml instead.

# DTensor CP uses SingleDeviceStrategy; multi-device/node is managed by
# DistributedManager via parallel_size, not by Lightning.
trainer:
  accelerator: gpu    # must be gpu instead of cuda because CP code manages devices
  devices: 1          # must be 1 for DTensor CP (one device per Lightning Trainer)
  num_nodes: 1        # must be 1; multi-node is handled by torchrun/SLURM
  precision: null     # superseded by top-level precision below

# Context-parallelism and data-parallelism topology.
# Override via CLI: parallel_size.size_cp=4 parallel_size.size_dp=2
# Constraint: size_cp must be a perfect square (2D CP mesh); size_dp * size_cp == world_size.
parallel_size:
  size_cp: 1          # context-parallel group size (must be a perfect square, e.g. 1, 4, 9, 16)
  size_dp: 1          # data-parallel group size
  timeout_nccl: 30    # NCCL timeout in minutes (for CUDA)
  timeout_gloo: 30    # Gloo timeout in minutes (for CPU)

# Training precision. Values: FP32, TF32, BF16, BF16_MIXED, FP16, FP64
precision: BF16_MIXED

# Triangular attention kernel backend.
# Values: reference, cueq, trifast, cueq_fwd_trifast_bwd
# Note: cueq does not support FP32 precision.
triattn_backend: cueq

# Scaled dot-product attention with bias backend (ring-attention layers).
# Values: reference, torch_sdpa_efficient_attention, torch_flex_attn
sdpa_with_bias_backend: torch_flex_attn

# SDPA with bias backend for shardwise (window-batched) attention layers.
# Values: reference, torch_sdpa_efficient_attention, torch_flex_attn
sdpa_with_bias_shardwise_backend: torch_flex_attn

# CUDA memory profiling. Activated only when output_path_prefix is set.
# Each rank writes to: {output_path_prefix}_rank{global_rank}.pickle
# Additional kwargs are forwarded to torch.cuda.memory._record_memory_history().
CUDAMemoryProfile:
  output_path_prefix: null   # set a path prefix to enable, e.g. "profiling/mem"
  max_entries: 300000         # max allocation/deallocation events to record

# CPU offloading of activation-checkpoint boundary tensors.
# Lists the distributed module types whose checkpoint-boundary activations
# should be offloaded to CPU during forward and restored on backward.
# Requires the corresponding activation_checkpointing to be enabled.
# Valid types: DiffusionTransformer, MSAModule, PairformerModule
# Set to null or omit to disable.
OffloadActvCkptToCPU: null

# DTensor CP requires num_workers=0 (main-process collation for distributed
# tensor construction) and pin_memory=false.
data:
  num_workers: 0
  pin_memory: false

model:
  validators:
    - _target_: boltz.distributed.model.validation.rcsb.DistributedRCSBValidator
      val_names: ["RCSB"]
      confidence_prediction: ${model.confidence_prediction}
      physicalism_metrics: False
      rmsd_metrics: True
      clash_score_metrics: True
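The `parallel_size` comments above state two constraints: `size_cp` must be a perfect square (the CP ranks form a 2D mesh), and `size_dp * size_cp` must equal the world size. A small pre-flight check in that spirit is sketched below; `check_topology` is a hypothetical helper, not part of the repository, and the `(size_dp, side, side)` mesh shape it returns is an assumption about how the ranks are laid out.

```python
import math


def check_topology(world_size: int, size_cp: int, size_dp: int) -> tuple[int, int, int]:
    """Validate the CP/DP constraints documented in structurev2_cp.yaml.

    Returns an assumed (size_dp, cp_rows, cp_cols) mesh shape.
    """
    side = math.isqrt(size_cp)
    if side * side != size_cp:
        raise ValueError(f"size_cp={size_cp} is not a perfect square")
    if size_dp * size_cp != world_size:
        raise ValueError(
            f"size_dp * size_cp = {size_dp * size_cp} != world_size = {world_size}"
        )
    return size_dp, side, side


# The reported launch: torchrun --nproc_per_node 4, size_cp=4, size_dp=1.
print(check_topology(world_size=4, size_cp=4, size_dp=1))  # (1, 2, 2)
```

The reported launch passes both checks, so the topology settings alone do not look like an obvious cause of the stall.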

structurev2

trainer:
  accelerator: cuda
  devices: 1
  num_nodes: 1
  precision: bf16-mixed
  gradient_clip_val: 10.0
  accumulate_grad_batches: 1
  max_epochs: -1
  num_sanity_val_steps: 0

# Optional set wandb here
# wandb:
#   name: boltz
#   project: boltz
#   entity: boltz

output: ./outputs
pretrained: null
resume: null
disable_checkpoint: false
matmul_precision: null
save_top_k: -1
v2: true

data:
  datasets:
    # RCSB Data
    - _target_: boltz.data.module.trainingv2.DatasetConfig
      target_dir: ./datasets/rcsb_processed_targets
      msa_dir: ./datasets/rcsb_processed_msa
      template_dir: null
      prob: 0.55
      filters:
        - _target_: boltz.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
        - _target_: boltz.data.filter.dynamic.date.DateFilter
          date: "2023-06-01"
          ref: released
        - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
          resolution: 9.0
      sampler:
        _target_: boltz.data.sample.v2.cluster.ClusterSampler
      cropper:
        _target_: boltz.data.crop.boltz.BoltzCropper
        min_neighborhood: 0
        max_neighborhood: 40
      split: ./scripts/train/assets/validation_ids_v2.txt
      symmetry_correction: true
      val_group: "RCSB"

    # AFDB Distillation Data
    - _target_: boltz.data.module.trainingv2.DatasetConfig
      target_dir: ./datasets/openfold_processed_targets
      msa_dir: ./datasets/openfold_processed_msa
      template_dir: null
      prob: 0.45
      filters:
        - _target_: boltz.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
      sampler:
        _target_: boltz.data.sample.v2.cluster.ClusterSampler
      cropper:
        _target_: boltz.data.crop.boltz.BoltzCropper
        min_neighborhood: 0
        max_neighborhood: 40
      symmetry_correction: true
      override_method: "AFDB"
      override_bfactor: true

  checkpoint_monitor_val_group: "val/lddt"  # dataset __RCSB is turned to ""   # which validation dataset group to use for checkpoint monitoring
  tokenizer:
    _target_: boltz.data.tokenize.boltz2.Boltz2TrainingTokenizer
  featurizer:
    _target_: boltz.data.feature.featurizerv2_train.Boltz2Featurizer

  moldir: #PATH_HERE
  max_tokens: 384  # 640 # NOTE: cuEq TriAttn backend on sm100f GPUs requires multiples of 8 token counts per CP shard
  max_atoms: 3456 # 5760
  max_seqs: 8192
  pad_to_max_tokens: true
  pad_to_max_atoms: true
  pad_to_max_seqs: true
  samples_per_epoch: 36096
  batch_size: 1
  num_workers: 2
  random_seed: 42
  pin_memory: false
  overfit: null
  return_train_symmetries: false
  return_val_symmetries: true
  train_binder_pocket_conditioned_prop: 0.15
  val_binder_pocket_conditioned_prop: 0.15
  train_contact_conditioned_prop: 0.15
  val_contact_conditioned_prop: 0.15
  binder_pocket_cutoff_val: 6.0
  binder_pocket_cutoff_min: 4.0
  binder_pocket_cutoff_max: 20.0
  binder_pocket_sampling_geometric_p: 0.3
  atoms_per_window_queries: 32
  min_dist: 2.0
  max_dist: 22.0
  num_bins: 64
  num_ensembles_train: 1
  num_ensembles_val: 1
  fix_single_ensemble: false
  disto_use_ensemble: true
  single_sequence_prop_training: 0.05
  max_templates_train: 4
  max_templates_val: 4
  no_template_prob_train: 0.6
  no_template_prob_val: 1.0
  use_templates: false
  msa_sampling_training: true
  bfactor_md_correction: true

model:
  _target_: boltz.model.models.boltz2.Boltz2
  atom_s: 128
  atom_z: 16
  token_s: 384
  token_z: 128
  num_bins: 64
  atom_feature_dim: 388
  atoms_per_window_queries: 32
  atoms_per_window_keys: 128
  compile_pairformer: false
  compile_templates: false
  compile_msa: false
  ema: true
  ema_decay: 0.999
  exclude_ions_from_lddt: true
  fix_sym_check: true
  cyclic_pos_enc: true
  num_val_datasets: 1
  bond_type_feature: true
  conditioning_cutoff_min: ${data.binder_pocket_cutoff_min}
  conditioning_cutoff_max: ${data.binder_pocket_cutoff_max}
  use_templates: ${data.use_templates}
  predict_bfactor: true
  checkpoint_diffusion_conditioning: true

  validators:
    - _target_: boltz.model.validation.rcsb.RCSBValidator
      val_names: ["RCSB"]
      confidence_prediction: ${model.confidence_prediction}

  embedder_args:
    atom_encoder_depth: 3
    atom_encoder_heads: 4
    add_mol_type_feat: true
    add_method_conditioning: true
    add_modified_flag: true
    add_cyclic_flag: true

  msa_args:
    msa_s: 64
    msa_blocks: 4
    msa_dropout: 0.15
    z_dropout: 0.25
    pairwise_head_width: 32
    pairwise_num_heads: 4
    use_paired_feature: true
    activation_checkpointing: true

  template_args:
    template_dim: 64
    template_blocks: 2
    activation_checkpointing: true

  pairformer_args:
    num_blocks: 48
    num_heads: 16
    dropout: 0.25
    post_layer_norm: false
    activation_checkpointing: true
    v2: true

  score_model_args:
    sigma_data: 16
    dim_fourier: 256
    atom_encoder_depth: 3
    atom_encoder_heads: 4
    token_transformer_depth: 24
    token_transformer_heads: 16
    atom_decoder_depth: 3
    atom_decoder_heads: 4
    conditioning_transition_layers: 2
    transformer_post_ln: false
    activation_checkpointing: true

  confidence_prediction: false
  affinity_prediction: false
  structure_prediction_training: true

  training_args:
    recycling_steps: 3
    sampling_steps: 20
    diffusion_multiplicity: 32
    diffusion_samples: 1
    affinity_loss_weight: 3e-3
    confidence_loss_weight: 1e-4
    diffusion_loss_weight: 4.0
    distogram_loss_weight: 3e-2
    bfactor_loss_weight: 1e-3
    adam_beta_1: 0.9
    adam_beta_2: 0.95
    adam_eps: 0.00000001
    lr_scheduler: af3
    base_lr: 0.0
    max_lr: 0.0005
    lr_warmup_no_steps: 1000
    lr_start_decay_after_n_steps: 50000
    lr_decay_every_n_steps: 50000
    lr_decay_factor: 0.95
    weight_decay: 0.003
    weight_decay_exclude: true

  validation_args:
    recycling_steps: 3
    sampling_steps: 200
    diffusion_samples: 5
    symmetry_correction: true

  diffusion_process_args:
    sigma_min: 0.0004
    sigma_max: 160.0
    sigma_data: 16.0
    rho: 7
    P_mean: -1.2
    P_std: 1.5
    gamma_0: 0.8
    gamma_min: 1.0
    noise_scale: 1.0
    step_scale: 1.0
    coordinate_augmentation: true
    alignment_reverse_diff: true
    synchronize_sigmas: false

  diffusion_loss_args:
    add_smooth_lddt_loss: true
    nucleotide_loss_weight: 5.0
    ligand_loss_weight: 10.0
    filter_by_plddt: 0.0
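`lr_scheduler: af3` together with the `training_args` values above describes a warmup-then-step-decay learning-rate schedule. The sketch below assumes an OpenFold-style AlphaFold scheduler shape: linear warmup from `base_lr` to `max_lr` over `lr_warmup_no_steps`, a constant plateau, then one multiplication by `lr_decay_factor` per `lr_decay_every_n_steps` once `lr_start_decay_after_n_steps` is passed. `af3_lr` is illustrative only, and the Boltz implementation may differ in boundary details.

```python
def af3_lr(step: int,
           base_lr: float = 0.0,
           max_lr: float = 0.0005,
           warmup_steps: int = 1000,
           start_decay_after: int = 50000,
           decay_every: int = 50000,
           decay_factor: float = 0.95) -> float:
    """Learning rate at optimizer step `step` (assumed OpenFold-style shape).

    Defaults are copied from training_args in the structurev2 config.
    """
    if step <= warmup_steps:
        # Linear warmup from base_lr to max_lr.
        return base_lr + (step / warmup_steps) * (max_lr - base_lr)
    if step > start_decay_after:
        # Step decay: one factor as soon as decay starts, plus one more
        # for every completed decay_every window after that.
        exponent = (step - start_decay_after) // decay_every + 1
        return max_lr * decay_factor ** exponent
    # Plateau between the end of warmup and the start of decay.
    return max_lr


for s in (0, 500, 1000, 50000, 50001, 100001):
    print(s, af3_lr(s))
```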

FYI, inference works fine.
