diff --git a/networks/example_LGATr.py b/networks/example_LGATr.py
new file mode 100644
index 0000000..0f05de6
--- /dev/null
+++ b/networks/example_LGATr.py
@@ -0,0 +1,68 @@
+import torch
+from weaver.nn.model.LGATr import LGATrTagger
+from weaver.utils.logger import _logger
+
+
+class LGATrTaggerWrapper(torch.nn.Module):
+    def __init__(self, **kwargs) -> None:
+        super().__init__()
+        self.mod = LGATrTagger(**kwargs)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {
+            "mod.cls_token",
+        }
+
+    def forward(self, points, features, lorentz_vectors, mask):
+        f = self.mod(features, v=lorentz_vectors, mask=mask)
+        return f
+
+
+def get_model(data_config, **kwargs):
+
+    cfg = dict(
+        in_s_channels=len(data_config.input_dicts["pf_features"]),
+        num_classes=len(data_config.label_value),
+        # symmetry-breaking configurations
+        spurion_token=True,
+        beam_spurion="xyplane",
+        add_time_spurion=True,
+        beam_mirror=True,
+        # network configurations
+        global_token=True,
+        hidden_mv_channels=16,
+        hidden_s_channels=32,
+        num_blocks=12,
+        num_heads=8,
+        head_scale=True,
+        checkpoint_blocks=False,
+        # gatr configurations
+        use_fully_connected_subgroup=True,
+        mix_pseudoscalar_into_scalar=True,
+        use_bivector=True,
+        use_geometric_product=True,
+    )
+
+    cfg.update(**kwargs)
+    _logger.info("Model config: %s" % str(cfg))
+
+    model = LGATrTaggerWrapper(**cfg)
+
+    model_info = {
+        "input_names": list(data_config.input_names),
+        "input_shapes": {
+            k: ((1,) + s[1:]) for k, s in data_config.input_shapes.items()
+        },
+        "output_names": ["softmax"],
+        "dynamic_axes": {
+            **{k: {0: "N", 2: "n_" + k.split("_")[0]} for k in data_config.input_names},
+            **{"softmax": {0: "N"}},
+        },
+    }
+
+    return model, model_info
+
+
+def get_loss(data_config, **kwargs):
+    return torch.nn.CrossEntropyLoss()
diff --git a/train_JetClass.sh b/train_JetClass.sh
index ce1e967..4fa83d9 100755
--- a/train_JetClass.sh
+++ b/train_JetClass.sh
@@ -27,7 +27,7 @@ samples_per_epoch=$((10000 * 1024 / $NGPUS))
 samples_per_epoch_val=$((10000 * 128))
 dataopts="--num-workers 2 --fetch-step 0.01"
 
-# PN, PFN, PCNN, ParT
+# PN, PFN, PCNN, ParT, LGATr
 model=$1
 if [[ "$model" == "ParT" ]]; then
     modelopts="networks/example_ParticleTransformer.py --use-amp"
@@ -41,6 +41,9 @@ elif [[ "$model" == "PFN" ]]; then
 elif [[ "$model" == "PCNN" ]]; then
     modelopts="networks/example_PCNN.py"
     batchopts="--batch-size 4096 --start-lr 2e-2"
+elif [[ "$model" == "LGATr" ]]; then
+    modelopts="networks/example_LGATr.py"
+    batchopts="--batch-size 512 --start-lr 3e-4 --optimizer lion"
 else
     echo "Invalid model $model!"
     exit 1
diff --git a/train_QuarkGluon.sh b/train_QuarkGluon.sh
index ee60ef0..e7c4059 100755
--- a/train_QuarkGluon.sh
+++ b/train_QuarkGluon.sh
@@ -13,7 +13,7 @@ DATADIR=${DATADIR_QuarkGluon}
 # set a comment via `COMMENT`
 suffix=${COMMENT}
 
-# PN, PFN, PCNN, ParT
+# PN, PFN, PCNN, ParT, LGATr
 model=$1
 extraopts=""
 if [[ "$model" == "ParT" ]]; then
@@ -38,6 +38,9 @@ elif [[ "$model" == "PCNN" ]]; then
     modelopts="networks/example_PCNN.py"
     lr="2e-2"
     extraopts="--batch-size 4096"
+elif [[ "$model" == "LGATr" ]]; then
+    modelopts="networks/example_LGATr.py --batch-size 128 --optimizer lion --optimizer-option weight_decay 0.2"
+    lr="3e-4"
 else
     echo "Invalid model $model!"
     exit 1
@@ -66,10 +69,10 @@ fi
 weaver \
     --data-train "${DATADIR}/train_file_*.parquet" \
     --data-test "${DATADIR}/test_file_*.parquet" \
-    --data-config data/QuarkGluon/qg_${FEATURE_TYPE}.yaml --network-config $modelopts \
+    --data-config data/QuarkGluon/qg_${FEATURE_TYPE}.yaml \
     --model-prefix training/QuarkGluon/${model}/{auto}${suffix}/net \
     --num-workers 1 --fetch-step 1 --in-memory --train-val-split 0.8889 \
     --batch-size 512 --samples-per-epoch 1600000 --samples-per-epoch-val 200000 --num-epochs 20 --gpus 0 \
     --start-lr $lr --optimizer ranger --log logs/QuarkGluon_${model}_{auto}${suffix}.log --predict-output pred.root \
     --tensorboard QuarkGluon_${FEATURE_TYPE}_${model}${suffix} \
-    ${extraopts} "${@:3}"
+    --network-config $modelopts ${extraopts} "${@:3}"
diff --git a/train_TopLandscape.sh b/train_TopLandscape.sh
index 8594c62..6bfe7ed 100755
--- a/train_TopLandscape.sh
+++ b/train_TopLandscape.sh
@@ -12,7 +12,7 @@ DATADIR=${DATADIR_TopLandscape}
 # set a comment via `COMMENT`
 suffix=${COMMENT}
 
-# PN, PFN, PCNN, ParT
+# PN, PFN, PCNN, ParT, LGATr
 model=$1
 extraopts=""
 if [[ "$model" == "ParT" ]]; then
@@ -37,6 +37,9 @@ elif [[ "$model" == "PCNN" ]]; then
     modelopts="networks/example_PCNN.py"
     lr="2e-2"
     extraopts="--batch-size 4096"
+elif [[ "$model" == "LGATr" ]]; then
+    modelopts="networks/example_LGATr.py --batch-size 128 --optimizer lion --optimizer-option weight_decay 0.2"
+    lr="3e-4"
 else
     echo "Invalid model $model!"
     exit 1
@@ -54,10 +57,10 @@ weaver \
     --data-train "${DATADIR}/train_file.parquet" \
     --data-val "${DATADIR}/val_file.parquet" \
     --data-test "${DATADIR}/test_file.parquet" \
-    --data-config data/TopLandscape/top_${FEATURE_TYPE}.yaml --network-config $modelopts \
+    --data-config data/TopLandscape/top_${FEATURE_TYPE}.yaml \
     --model-prefix training/TopLandscape/${model}/{auto}${suffix}/net \
     --num-workers 1 --fetch-step 1 --in-memory \
     --batch-size 512 --samples-per-epoch $((2400 * 512)) --samples-per-epoch-val $((800 * 512)) --num-epochs 20 --gpus 0 \
     --start-lr $lr --optimizer ranger --log logs/TopLandscape_${model}_{auto}${suffix}.log --predict-output pred.root \
     --tensorboard TopLandscape_${FEATURE_TYPE}_${model}${suffix} \
-    ${extraopts} "${@:3}"
+    --network-config $modelopts ${extraopts} "${@:3}"