From db07ecf4c3adc2229295981de75f3f4aa03e67ad Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Fri, 14 Feb 2025 00:26:36 +0100 Subject: [PATCH 1/9] Add LGATr wrapper to networks/, and LGATr option in train_TopLandscape.sh --- networks/example_LGATr.py | 55 +++++++++++++++++++++++++++++++++++++++ train_TopLandscape.sh | 5 +++- 2 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 networks/example_LGATr.py diff --git a/networks/example_LGATr.py b/networks/example_LGATr.py new file mode 100644 index 0000000..6b56f35 --- /dev/null +++ b/networks/example_LGATr.py @@ -0,0 +1,55 @@ +import torch +from weaver.nn.model.LGATr import LGATrTagger +from weaver.utils.logger import _logger + + +class LGATrTaggerWrapper(torch.nn.Module): + def __init__(self, **kwargs) -> None: + super().__init__() + self.mod = LGATrTagger(**kwargs) + + @torch.jit.ignore + def no_weight_decay(self): + return { + "mod.cls_token", + } + + def forward(self, points, features, lorentz_vectors, mask): + f = self.mod(features, v=lorentz_vectors, mask=mask) + return f + + +def get_model(data_config, **kwargs): + + cfg = dict( + in_s_channels=len(data_config.input_dicts["pf_features"]), + num_classes=len(data_config.label_value), + # network configurations + hidden_mv_channels=16, + hidden_s_channels=32, + num_blocks=12, + num_heads=8, + ) + + cfg.update(**kwargs) + _logger.info("Model config: %s" % str(cfg)) + + model = LGATrTaggerWrapper(**cfg) + + model_info = { + "input_names": list(data_config.input_names), + "input_shapes": { + k: ((1,) + s[1:]) for k, s in data_config.input_shapes.items() + }, + "output_names": ["softmax"], + "dynamic_axes": { + **{k: {0: "N", 2: "n_" + k.split("_")[0]} for k in data_config.input_names}, + **{"softmax": {0: "N"}}, + }, + } + + return model, model_info + + +def get_loss(data_config, **kwargs): + return torch.nn.CrossEntropyLoss() diff --git a/train_TopLandscape.sh b/train_TopLandscape.sh index 8594c62..32f8772 100755 --- a/train_TopLandscape.sh +++ b/train_TopLandscape.sh @@ -12,7 +12,7 @@ DATADIR=${DATADIR_TopLandscape} # set a comment via `COMMENT` suffix=${COMMENT} -# PN, PFN, PCNN, ParT +# PN, PFN, PCNN, ParT, LGATr model=$1 extraopts="" if [[ "$model" == "ParT" ]]; then @@ -37,6 +37,9 @@ elif [[ "$model" == "PCNN" ]]; then modelopts="networks/example_PCNN.py" lr="2e-2" extraopts="--batch-size 4096" +elif [[ "$model" == "LGATr" ]]; then + modelopts="networks/example_LGATr.py --use-amp --optimizer-option weight_decay 0.01" + lr="1e-3" else echo "Invalid model $model!" exit 1 From 65329e1f9624c045ccfe896a4eb35987d3bdc189 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sun, 16 Feb 2025 20:24:03 +0100 Subject: [PATCH 2/9] Add LGATr option to JetClass and QuarkGluon training scripts --- train_JetClass.sh | 5 ++++- train_QuarkGluon.sh | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/train_JetClass.sh b/train_JetClass.sh index ce1e967..7b5ed1d 100755 --- a/train_JetClass.sh +++ b/train_JetClass.sh @@ -27,7 +27,7 @@ samples_per_epoch=$((10000 * 1024 / $NGPUS)) samples_per_epoch_val=$((10000 * 128)) dataopts="--num-workers 2 --fetch-step 0.01" -# PN, PFN, PCNN, ParT +# PN, PFN, PCNN, ParT, LGATr model=$1 if [[ "$model" == "ParT" ]]; then modelopts="networks/example_ParticleTransformer.py --use-amp" @@ -44,6 +44,9 @@ elif [[ "$model" == "PCNN" ]]; then else echo "Invalid model $model!" exit 1 +elif [[ "$model" == "LGATr" ]]; then + modelopts="networks/example_LGATr.py --use-amp" + batchopts="--batch-size 512 --start-lr 1e-3" fi # "kin", "kinpid", "full" diff --git a/train_QuarkGluon.sh b/train_QuarkGluon.sh index ee60ef0..b8aa6c2 100755 --- a/train_QuarkGluon.sh +++ b/train_QuarkGluon.sh @@ -13,7 +13,7 @@ DATADIR=${DATADIR_QuarkGluon} # set a comment via `COMMENT` suffix=${COMMENT} -# PN, PFN, PCNN, ParT +# PN, PFN, PCNN, ParT, LGATr model=$1 extraopts="" if [[ "$model" == "ParT" ]]; then @@ -41,6 +41,9 @@ elif [[ "$model" == "PCNN" ]]; then else echo "Invalid model $model!" exit 1 +elif [[ "$model" == "LGATr" ]]; then + modelopts="networks/example_LGATr.py --use-amp --optimizer-option weight_decay 0.01" + lr="1e-3" fi # "kin", "kinpid", "kinpidplus" From d8506878b20605c43032ab9d98ffc8035c2c1098 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sun, 16 Feb 2025 20:41:49 +0100 Subject: [PATCH 3/9] Add more hyperparameter handles to example_LGATr.py --- networks/example_LGATr.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/networks/example_LGATr.py b/networks/example_LGATr.py index 6b56f35..08f8502 100644 --- a/networks/example_LGATr.py +++ b/networks/example_LGATr.py @@ -24,11 +24,18 @@ def get_model(data_config, **kwargs): cfg = dict( in_s_channels=len(data_config.input_dicts["pf_features"]), num_classes=len(data_config.label_value), + # symmetry-breaking configurations + spurion_token=True, + beam_reference="xyplane", + add_time_reference=True, + two_beams=True, # network configurations hidden_mv_channels=16, hidden_s_channels=32, num_blocks=12, num_heads=8, + double_layernorm=True, + head_scale=True, ) cfg.update(**kwargs) From d0c5eaac9573d7022d1869b7cd79e054001d8282 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sun, 16 Feb 2025 20:42:40 +0100 Subject: [PATCH 4/9] Add handle full/subgroup to example_LGATr.py --- networks/example_LGATr.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/networks/example_LGATr.py b/networks/example_LGATr.py index 08f8502..377383a 100644 --- a/networks/example_LGATr.py +++ b/networks/example_LGATr.py @@ -20,6 +20,11 @@ def forward(self, points, features, lorentz_vectors, mask): def get_model(data_config, **kwargs): + use_fully_connected_subgroup = False + if not use_fully_connected_subgroup: + import weaver.nn.model.gatr.primitives.linear + + weaver.nn.model.gatr.primitives.linear.USE_FULLY_CONNECTED_SUBGROUP = False cfg = dict( in_s_channels=len(data_config.input_dicts["pf_features"]), From 21ca6a545f0ba20597f71da019296bd9378bf95d Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Mon, 17 Feb 2025 08:57:38 +0100 Subject: [PATCH 5/9] Support gradient checkpointing --- networks/example_LGATr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/networks/example_LGATr.py b/networks/example_LGATr.py index 377383a..c70f846 100644 --- a/networks/example_LGATr.py +++ b/networks/example_LGATr.py @@ -41,6 +41,7 @@ def get_model(data_config, **kwargs): num_heads=8, double_layernorm=True, head_scale=True, + checkpoint_blocks=False, ) cfg.update(**kwargs) From 272f1e9184ee37d8d44aaae79d150286988ea994 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sat, 15 Mar 2025 22:38:16 +0100 Subject: [PATCH 6/9] Update based on changes in weaver-core --- networks/example_LGATr.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/networks/example_LGATr.py b/networks/example_LGATr.py index c70f846..02f51f2 100644 --- a/networks/example_LGATr.py +++ b/networks/example_LGATr.py @@ -20,21 +20,17 @@ def forward(self, points, features, lorentz_vectors, mask): def get_model(data_config, **kwargs): - use_fully_connected_subgroup = False - if not use_fully_connected_subgroup: - import weaver.nn.model.gatr.primitives.linear - - weaver.nn.model.gatr.primitives.linear.USE_FULLY_CONNECTED_SUBGROUP = False cfg = dict( in_s_channels=len(data_config.input_dicts["pf_features"]), num_classes=len(data_config.label_value), # symmetry-breaking configurations spurion_token=True, - beam_reference="xyplane", - add_time_reference=True, - two_beams=True, + beam_spurion="xyplane", + add_time_spurion=True, + beam_mirror=True, # network configurations + global_token=True, hidden_mv_channels=16, hidden_s_channels=32, num_blocks=12, @@ -42,6 +38,11 @@ def get_model(data_config, **kwargs): double_layernorm=True, head_scale=True, checkpoint_blocks=False, + # gatr configurations + use_fully_connected_subgroup=True, + mix_pseudoscalar_into_scalar=True, + use_bivector=True, + use_geometric_product=True, ) cfg.update(**kwargs) From ec90790b29becb6ccbd5e730bfed51b13e985cdb Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Tue, 25 Mar 2025 17:52:52 +0100 Subject: [PATCH 7/9] Turn amp off by default --- train_TopLandscape.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_TopLandscape.sh b/train_TopLandscape.sh index 32f8772..37860f6 100755 --- a/train_TopLandscape.sh +++ b/train_TopLandscape.sh @@ -38,7 +38,7 @@ elif [[ "$model" == "PCNN" ]]; then lr="2e-2" extraopts="--batch-size 4096" elif [[ "$model" == "LGATr" ]]; then - modelopts="networks/example_LGATr.py --use-amp --optimizer-option weight_decay 0.01" + modelopts="networks/example_LGATr.py --optimizer-option weight_decay 0.01" lr="1e-3" else echo "Invalid model $model!" From 149b3c58dfe0725e402941f77e13e022a67d44be Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sun, 30 Mar 2025 22:55:03 +0200 Subject: [PATCH 8/9] Update GATr default configs --- train_JetClass.sh | 6 +++--- train_QuarkGluon.sh | 10 +++++----- train_TopLandscape.sh | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/train_JetClass.sh b/train_JetClass.sh index 7b5ed1d..4fa83d9 100755 --- a/train_JetClass.sh +++ b/train_JetClass.sh @@ -41,12 +41,12 @@ elif [[ "$model" == "PFN" ]]; then elif [[ "$model" == "PCNN" ]]; then modelopts="networks/example_PCNN.py" batchopts="--batch-size 4096 --start-lr 2e-2" +elif [[ "$model" == "LGATr" ]]; then + modelopts="networks/example_LGATr.py" + batchopts="--batch-size 512 --start-lr 3e-4 --optimizer lion" else echo "Invalid model $model!" exit 1 -elif [[ "$model" == "LGATr" ]]; then - modelopts="networks/example_LGATr.py --use-amp" - batchopts="--batch-size 512 --start-lr 1e-3" fi # "kin", "kinpid", "full" diff --git a/train_QuarkGluon.sh b/train_QuarkGluon.sh index b8aa6c2..e7c4059 100755 --- a/train_QuarkGluon.sh +++ b/train_QuarkGluon.sh @@ -38,12 +38,12 @@ elif [[ "$model" == "PCNN" ]]; then modelopts="networks/example_PCNN.py" lr="2e-2" extraopts="--batch-size 4096" +elif [[ "$model" == "LGATr" ]]; then + modelopts="networks/example_LGATr.py --batch-size 128 --optimizer lion --optimizer-option weight_decay 0.2" + lr="3e-4" else echo "Invalid model $model!" exit 1 -elif [[ "$model" == "LGATr" ]]; then - modelopts="networks/example_LGATr.py --use-amp --optimizer-option weight_decay 0.01" - lr="1e-3" fi # "kin", "kinpid", "kinpidplus" @@ -69,10 +69,10 @@ fi weaver \ --data-train "${DATADIR}/train_file_*.parquet" \ --data-test "${DATADIR}/test_file_*.parquet" \ - --data-config data/QuarkGluon/qg_${FEATURE_TYPE}.yaml --network-config $modelopts \ + --data-config data/QuarkGluon/qg_${FEATURE_TYPE}.yaml \ --model-prefix training/QuarkGluon/${model}/{auto}${suffix}/net \ --num-workers 1 --fetch-step 1 --in-memory --train-val-split 0.8889 \ --batch-size 512 --samples-per-epoch 1600000 --samples-per-epoch-val 200000 --num-epochs 20 --gpus 0 \ --start-lr $lr --optimizer ranger --log logs/QuarkGluon_${model}_{auto}${suffix}.log --predict-output pred.root \ --tensorboard QuarkGluon_${FEATURE_TYPE}_${model}${suffix} \ - ${extraopts} "${@:3}" + --network-config $modelopts ${extraopts} "${@:3}" diff --git a/train_TopLandscape.sh b/train_TopLandscape.sh index 37860f6..6bfe7ed 100755 --- a/train_TopLandscape.sh +++ b/train_TopLandscape.sh @@ -38,8 +38,8 @@ elif [[ "$model" == "PCNN" ]]; then lr="2e-2" extraopts="--batch-size 4096" elif [[ "$model" == "LGATr" ]]; then - modelopts="networks/example_LGATr.py --optimizer-option weight_decay 0.01" - lr="1e-3" + modelopts="networks/example_LGATr.py --batch-size 128 --optimizer lion --optimizer-option weight_decay 0.2" + lr="3e-4" else echo "Invalid model $model!" exit 1 @@ -57,10 +57,10 @@ weaver \ --data-train "${DATADIR}/train_file.parquet" \ --data-val "${DATADIR}/val_file.parquet" \ --data-test "${DATADIR}/test_file.parquet" \ - --data-config data/TopLandscape/top_${FEATURE_TYPE}.yaml --network-config $modelopts \ + --data-config data/TopLandscape/top_${FEATURE_TYPE}.yaml \ --model-prefix training/TopLandscape/${model}/{auto}${suffix}/net \ --num-workers 1 --fetch-step 1 --in-memory \ --batch-size 512 --samples-per-epoch $((2400 * 512)) --samples-per-epoch-val $((800 * 512)) --num-epochs 20 --gpus 0 \ --start-lr $lr --optimizer ranger --log logs/TopLandscape_${model}_{auto}${suffix}.log --predict-output pred.root \ --tensorboard TopLandscape_${FEATURE_TYPE}_${model}${suffix} \ - ${extraopts} "${@:3}" + --network-config $modelopts ${extraopts} "${@:3}" From c3bf0adc775fa8ae011715708e71df030bb3610d Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sun, 1 Jun 2025 13:21:11 +0200 Subject: [PATCH 9/9] Changes based on recent lgatr rework --- networks/example_LGATr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/networks/example_LGATr.py b/networks/example_LGATr.py index 02f51f2..0f05de6 100644 --- a/networks/example_LGATr.py +++ b/networks/example_LGATr.py @@ -35,7 +35,6 @@ def get_model(data_config, **kwargs): hidden_s_channels=32, num_blocks=12, num_heads=8, - double_layernorm=True, head_scale=True, checkpoint_blocks=False, # gatr configurations