23 changes: 15 additions & 8 deletions mlpf/model/mlpf.py
@@ -265,6 +265,8 @@ def __init__(
):
super(MLPF, self).__init__()

self.m_pion = 0.13957021 # GeV

self.conv_type = conv_type

self.act = get_activation(activation)
@@ -365,7 +367,6 @@ def __init__(
self.nn_eta = RegressionOutput(eta_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
self.nn_sin_phi = RegressionOutput(sin_phi_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
self.nn_cos_phi = RegressionOutput(cos_phi_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
self.nn_energy = RegressionOutput(energy_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)

if self.use_pre_layernorm: # add final norm after last attention block as per https://arxiv.org/abs/2002.04745
self.final_norm_id = torch.nn.LayerNorm(decoding_dim)
@@ -427,15 +428,21 @@ def forward(self, X_features, mask):
preds_sin_phi = self.nn_sin_phi(X_features, final_embedding_reg, X_features[..., 3:4])
preds_cos_phi = self.nn_cos_phi(X_features, final_embedding_reg, X_features[..., 4:5])

# ensure created particle has positive mass^2 by computing energy from pt and adding a positive-only correction
pt_real = torch.exp(preds_pt.detach()) * X_features[..., 1:2]
# Assume mass of a charged pion and compute energy from pt and eta
pt_real = torch.exp(preds_pt.detach()) * X_features[..., 1:2] # transform pt back into physical space
pz_real = pt_real * torch.sinh(preds_eta.detach())
e_real = torch.log(torch.sqrt(pt_real**2 + pz_real**2) / X_features[..., 5:6])
e_real[~mask] = 0
e_real[torch.isinf(e_real)] = 0
e_real[torch.isnan(e_real)] = 0
preds_energy = e_real + torch.nn.functional.relu(self.nn_energy(X_features, final_embedding_reg, X_features[..., 5:6]))

# E^2 = pt^2 + pz^2 + m^2 in natural units
e_real = torch.sqrt(pt_real**2 + pz_real**2 + self.m_pion**2)
preds_energy = torch.log(e_real / X_features[..., 5:6]) # transform E back into normalized space, this is undone at inference time

# Handle NaNs and infs
preds_energy[~mask] = 0
preds_energy[torch.isinf(preds_energy)] = 0
preds_energy[torch.isnan(preds_energy)] = 0

preds_momentum = torch.cat([preds_pt, preds_eta, preds_sin_phi, preds_cos_phi, preds_energy], axis=-1)

return preds_binary_particle, preds_pid, preds_momentum


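Note: a minimal, self-contained sketch of the pion-mass energy assignment introduced in this hunk. The helper name and the pt_ref / e_ref arguments are illustrative stand-ins for the normalization references X_features[..., 1:2] and X_features[..., 5:6]; the model's actual logic is the forward() code above.

```python
import torch

M_PION = 0.13957021  # charged pion mass in GeV

def energy_from_pt_eta(preds_pt_log, preds_eta, pt_ref, e_ref, mask):
    # pt is predicted in log-normalized space, so exp() maps it back to physical pt
    pt_real = torch.exp(preds_pt_log) * pt_ref
    # pz = pt * sinh(eta)
    pz_real = pt_real * torch.sinh(preds_eta)
    # E^2 = pt^2 + pz^2 + m^2 in natural units, so E >= pt and mass^2 is positive by construction
    e_real = torch.sqrt(pt_real**2 + pz_real**2 + M_PION**2)
    # transform E back into normalized space; undone at inference time
    preds_energy = torch.log(e_real / e_ref)
    # zero out padded elements and any non-finite values
    return torch.where(mask & torch.isfinite(preds_energy), preds_energy, torch.zeros_like(preds_energy))

# toy check with pt_ref = e_ref = 1 (no normalization): pt = 1 GeV, eta = 1
print(energy_from_pt_eta(torch.zeros(1, 1), torch.ones(1, 1), torch.ones(1, 1),
                         torch.ones(1, 1), torch.ones(1, 1, dtype=torch.bool)))
```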
2 changes: 2 additions & 0 deletions mlpf/model/training.py
@@ -1086,6 +1086,8 @@ def run(rank, world_size, config, args, outdir, logfile):
jet_match_dr=0.1,
dir_name=testdir_name,
)
if world_size > 1:
dist.barrier() # block until all workers finished executing run_predictions()

if (rank == 0) or (rank == "cpu"): # make plots only on a single machine
if args.make_plots:
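Note: the added dist.barrier() guards against rank 0 starting to plot before slower workers have written their prediction files. An illustrative, self-contained version of that pattern (the file layout and function name are hypothetical; a torch.distributed process group must already be initialized whenever world_size > 1):

```python
import os
import torch
import torch.distributed as dist

def write_shards_then_aggregate(rank, world_size, outdir="/tmp/preds"):
    os.makedirs(outdir, exist_ok=True)
    # every rank writes its own shard, standing in for run_predictions()
    torch.save(torch.randn(10), os.path.join(outdir, f"shard_{rank}.pt"))

    if world_size > 1:
        dist.barrier()  # block until all workers finished writing

    if rank == 0:  # aggregate/plot only on a single worker
        shards = [torch.load(os.path.join(outdir, f"shard_{r}.pt")) for r in range(world_size)]
        print(torch.cat(shards).shape)

write_shards_then_aggregate(rank=0, world_size=1)  # single-process run skips the barrier
```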
2 changes: 1 addition & 1 deletion requirements.txt
@@ -26,7 +26,7 @@ plotly
pre-commit
protobuf
pyarrow
ray[tune]
ray[train,tune]
scikit-learn
scikit-optimize
scipy
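Note: the ray[train,tune] extra pulls in the Ray Train dependencies on top of Tune. A minimal sketch of what that enables, assuming the Ray 2.x API (the per-worker function here is a dummy, not the MLPF training loop):

```python
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    import torch
    # dummy per-worker body; the real pipeline runs the MLPF training loop here
    print("worker sees", torch.cuda.device_count(), "GPU(s)")

trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
trainer.fit()
```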
36 changes: 30 additions & 6 deletions scripts/flatiron/pt_raytrain_a100.slurm
@@ -6,6 +6,7 @@
#SBATCH --exclusive
#SBATCH --tasks-per-node=1
#SBATCH -p gpu
#SBATCH --gpus=4
#SBATCH --gpus-per-task=4
#SBATCH --cpus-per-task=64
#SBATCH --constraint=a100-80gb&sxm4
@@ -27,12 +28,36 @@ module --force purge; module load modules/2.2-20230808
module load slurm gcc cmake cuda/12.1.1 cudnn/8.9.2.26-12.x nccl openmpi apptainer

nvidia-smi
source ~/miniconda3/bin/activate pytorch
# source ~/miniconda3/bin/activate pytorch
source ~/miniforge3/bin/activate mlpf
which python3
python3 --version

export CUDA_VISIBLE_DEVICES=0,1,2,3
num_gpus=${SLURM_GPUS_PER_TASK} # gpus per compute node

num_gpus=$((SLURM_GPUS/SLURM_NNODES)) # gpus per compute node

export SRUN_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK} # necessary on JURECA for Ray to work

## Disable Ray Usage Stats
export RAY_USAGE_STATS_DISABLE=1


echo "DEBUG: SLURM_JOB_ID: $SLURM_JOB_ID"
echo "DEBUG: SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
echo "DEBUG: SLURM_NNODES: $SLURM_NNODES"
echo "DEBUG: SLURM_NTASKS: $SLURM_NTASKS"
echo "DEBUG: SLURM_TASKS_PER_NODE: $SLURM_TASKS_PER_NODE"
echo "DEBUG: SLURM_SUBMIT_HOST: $SLURM_SUBMIT_HOST"
echo "DEBUG: SLURMD_NODENAME: $SLURMD_NODENAME"
echo "DEBUG: SLURM_NODEID: $SLURM_NODEID"
echo "DEBUG: SLURM_LOCALID: $SLURM_LOCALID"
echo "DEBUG: SLURM_PROCID: $SLURM_PROCID"
echo "DEBUG: CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "DEBUG: SLURM_JOB_NUM_NODES: $SLURM_JOB_NUM_NODES"
echo "DEBUG: SLURM_CPUS_PER_TASK: $SLURM_CPUS_PER_TASK"
echo "DEBUG: SLURM_GPUS_PER_TASK: $SLURM_GPUS_PER_TASK"
echo "DEBUG: SLURM_GPUS: $SLURM_GPUS"
echo "DEBUG: num_gpus: $num_gpus"


if [ "$SLURM_JOB_NUM_NODES" -gt 1 ]; then
@@ -81,12 +106,11 @@ python3 -u mlpf/pipeline.py --train --ray-train \
--config $1 \
--prefix $2 \
--ray-cpus $((SLURM_CPUS_PER_TASK*SLURM_JOB_NUM_NODES)) \
--gpus $((SLURM_GPUS_PER_TASK*SLURM_JOB_NUM_NODES)) \
--gpus $((SLURM_GPUS)) \
--gpu-batch-multiplier 8 \
--num-workers 8 \
--prefetch-factor 16 \
--experiments-dir /mnt/ceph/users/ewulff/particleflow/experiments \
--local \
--comet
--local

echo 'Training done.'
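Note: the resource arithmetic used in this script (GPUs per node from SLURM_GPUS / SLURM_NNODES, total Ray CPUs from SLURM_CPUS_PER_TASK * number of nodes), restated as a small illustrative Python helper. The function name and fallback defaults are hypothetical; inside the job the values come from the #SBATCH directives above.

```python
import os

def ray_resources_from_slurm():
    n_nodes = int(os.environ.get("SLURM_NNODES", "1"))
    total_gpus = int(os.environ.get("SLURM_GPUS", "0"))            # from '#SBATCH --gpus=4'
    cpus_per_task = int(os.environ.get("SLURM_CPUS_PER_TASK", "1"))

    gpus_per_node = total_gpus // n_nodes   # num_gpus in the script
    ray_cpus = cpus_per_task * n_nodes      # value passed as --ray-cpus
    return gpus_per_node, ray_cpus

print(ray_resources_from_slurm())
```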
32 changes: 28 additions & 4 deletions scripts/flatiron/pt_raytrain_h100.slurm
@@ -6,6 +6,7 @@
#SBATCH --exclusive
#SBATCH --tasks-per-node=1
#SBATCH -p gpu
#SBATCH --gpus=8
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=64
#SBATCH --constraint=ib-h100p
@@ -27,12 +28,35 @@ module --force purge; module load modules/2.2-20230808
module load slurm gcc cmake cuda/12.1.1 cudnn/8.9.2.26-12.x nccl openmpi apptainer

nvidia-smi
source ~/miniconda3/bin/activate pytorch
# source ~/miniconda3/bin/activate pytorch
source ~/miniforge3/bin/activate mlpf
which python3
python3 --version

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
num_gpus=${SLURM_GPUS_PER_TASK} # gpus per compute node
num_gpus=$((SLURM_GPUS/SLURM_NNODES)) # gpus per compute node

export SRUN_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK} # necessary on JURECA for Ray to work

## Disable Ray Usage Stats
export RAY_USAGE_STATS_DISABLE=1

echo "DEBUG: SLURM_JOB_ID: $SLURM_JOB_ID"
echo "DEBUG: SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
echo "DEBUG: SLURM_NNODES: $SLURM_NNODES"
echo "DEBUG: SLURM_NTASKS: $SLURM_NTASKS"
echo "DEBUG: SLURM_TASKS_PER_NODE: $SLURM_TASKS_PER_NODE"
echo "DEBUG: SLURM_SUBMIT_HOST: $SLURM_SUBMIT_HOST"
echo "DEBUG: SLURMD_NODENAME: $SLURMD_NODENAME"
echo "DEBUG: SLURM_NODEID: $SLURM_NODEID"
echo "DEBUG: SLURM_LOCALID: $SLURM_LOCALID"
echo "DEBUG: SLURM_PROCID: $SLURM_PROCID"
echo "DEBUG: CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "DEBUG: SLURM_JOB_NUM_NODES: $SLURM_JOB_NUM_NODES"
echo "DEBUG: SLURM_CPUS_PER_TASK: $SLURM_CPUS_PER_TASK"
echo "DEBUG: SLURM_GPUS_PER_TASK: $SLURM_GPUS_PER_TASK"
echo "DEBUG: SLURM_GPUS: $SLURM_GPUS"
echo "DEBUG: num_gpus: $num_gpus"


if [ "$SLURM_JOB_NUM_NODES" -gt 1 ]; then
################# DO NOT CHANGE THINGS HERE UNLESS YOU KNOW WHAT YOU ARE DOING ###############
@@ -80,7 +104,7 @@ python3 -u mlpf/pipeline.py --train --ray-train \
--config $1 \
--prefix $2 \
--ray-cpus $((SLURM_CPUS_PER_TASK*SLURM_JOB_NUM_NODES)) \
--gpus $((SLURM_GPUS_PER_TASK*SLURM_JOB_NUM_NODES)) \
--gpus $((SLURM_GPUS*SLURM_JOB_NUM_NODES)) \
--gpu-batch-multiplier 8 \
--num-workers 4 \
--prefetch-factor 8 \
3 changes: 2 additions & 1 deletion scripts/flatiron/pt_test.slurm
@@ -27,7 +27,8 @@ module --force purge; module load modules/2.2-20230808
module load slurm gcc cmake cuda/12.1.1 cudnn/8.9.2.26-12.x nccl openmpi apptainer

nvidia-smi
source ~/miniconda3/bin/activate pytorch
# source ~/miniconda3/bin/activate pytorch
source ~/miniforge3/bin/activate mlpf
which python3
python3 --version
