Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions networks/example_LGATr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch
from weaver.nn.model.LGATr import LGATrTagger
from weaver.utils.logger import _logger


class LGATrTaggerWrapper(torch.nn.Module):
    """Thin adapter exposing ``LGATrTagger`` through weaver's 4-input interface.

    Weaver calls models as ``model(points, features, lorentz_vectors, mask)``;
    this wrapper forwards the relevant tensors to the underlying tagger.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        # Every configuration option is forwarded verbatim to the tagger.
        self.mod = LGATrTagger(**kwargs)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names returned here are excluded from weight decay
        # by the optimizer setup; ignored by TorchScript.
        return {"mod.cls_token"}

    def forward(self, points, features, lorentz_vectors, mask):
        # `points` is accepted only for interface compatibility; the tagger
        # consumes the per-particle features, Lorentz vectors, and mask.
        return self.mod(features, v=lorentz_vectors, mask=mask)


def get_model(data_config, **kwargs):
    """Build an LGATr tagger plus the export metadata weaver expects.

    Parameters
    ----------
    data_config : weaver data-config object providing ``input_dicts``,
        ``label_value``, ``input_names`` and ``input_shapes``.
    **kwargs : overrides applied on top of the default model configuration.

    Returns
    -------
    (model, model_info) : the wrapped model and a dict describing input
        names/shapes, output names and ONNX dynamic axes.
    """
    cfg = {
        "in_s_channels": len(data_config.input_dicts["pf_features"]),
        "num_classes": len(data_config.label_value),
        # symmetry-breaking configurations
        "spurion_token": True,
        "beam_spurion": "xyplane",
        "add_time_spurion": True,
        "beam_mirror": True,
        # network configurations
        "global_token": True,
        "hidden_mv_channels": 16,
        "hidden_s_channels": 32,
        "num_blocks": 12,
        "num_heads": 8,
        "head_scale": True,
        "checkpoint_blocks": False,
        # gatr configurations
        "use_fully_connected_subgroup": True,
        "mix_pseudoscalar_into_scalar": True,
        "use_bivector": True,
        "use_geometric_product": True,
    }
    # Caller-supplied options take precedence over the defaults above.
    cfg.update(kwargs)
    _logger.info("Model config: %s" % str(cfg))

    model = LGATrTaggerWrapper(**cfg)

    # Batch dimension is fixed to 1 for export; axis 0 ("N") and the
    # per-input particle axis ("n_<prefix>") remain dynamic at runtime.
    input_names = list(data_config.input_names)
    dynamic_axes = {
        k: {0: "N", 2: "n_" + k.split("_")[0]} for k in input_names
    }
    dynamic_axes["softmax"] = {0: "N"}

    model_info = {
        "input_names": input_names,
        "input_shapes": {
            k: (1,) + s[1:] for k, s in data_config.input_shapes.items()
        },
        "output_names": ["softmax"],
        "dynamic_axes": dynamic_axes,
    }

    return model, model_info


def get_loss(data_config, **kwargs):
    """Return the classification criterion (plain cross-entropy).

    ``data_config`` and any extra keyword options are accepted for
    interface compatibility with weaver but are not used here.
    """
    criterion = torch.nn.CrossEntropyLoss()
    return criterion
5 changes: 4 additions & 1 deletion train_JetClass.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ samples_per_epoch=$((10000 * 1024 / $NGPUS))
samples_per_epoch_val=$((10000 * 128))
dataopts="--num-workers 2 --fetch-step 0.01"

# PN, PFN, PCNN, ParT
# PN, PFN, PCNN, ParT, LGATr
model=$1
if [[ "$model" == "ParT" ]]; then
modelopts="networks/example_ParticleTransformer.py --use-amp"
Expand All @@ -41,6 +41,9 @@ elif [[ "$model" == "PFN" ]]; then
elif [[ "$model" == "PCNN" ]]; then
modelopts="networks/example_PCNN.py"
batchopts="--batch-size 4096 --start-lr 2e-2"
elif [[ "$model" == "LGATr" ]]; then
modelopts="networks/example_LGATr.py"
batchopts="--batch-size 512 --start-lr 3e-4 --optimizer lion"
else
echo "Invalid model $model!"
exit 1
Expand Down
9 changes: 6 additions & 3 deletions train_QuarkGluon.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ DATADIR=${DATADIR_QuarkGluon}
# set a comment via `COMMENT`
suffix=${COMMENT}

# PN, PFN, PCNN, ParT
# PN, PFN, PCNN, ParT, LGATr
model=$1
extraopts=""
if [[ "$model" == "ParT" ]]; then
Expand All @@ -38,6 +38,9 @@ elif [[ "$model" == "PCNN" ]]; then
modelopts="networks/example_PCNN.py"
lr="2e-2"
extraopts="--batch-size 4096"
elif [[ "$model" == "LGATr" ]]; then
modelopts="networks/example_LGATr.py --batch-size 128 --optimizer lion --optimizer-option weight_decay 0.2"
lr="3e-4"
else
echo "Invalid model $model!"
exit 1
Expand Down Expand Up @@ -66,10 +69,10 @@ fi
weaver \
--data-train "${DATADIR}/train_file_*.parquet" \
--data-test "${DATADIR}/test_file_*.parquet" \
--data-config data/QuarkGluon/qg_${FEATURE_TYPE}.yaml --network-config $modelopts \
--data-config data/QuarkGluon/qg_${FEATURE_TYPE}.yaml \
--model-prefix training/QuarkGluon/${model}/{auto}${suffix}/net \
--num-workers 1 --fetch-step 1 --in-memory --train-val-split 0.8889 \
--batch-size 512 --samples-per-epoch 1600000 --samples-per-epoch-val 200000 --num-epochs 20 --gpus 0 \
--start-lr $lr --optimizer ranger --log logs/QuarkGluon_${model}_{auto}${suffix}.log --predict-output pred.root \
--tensorboard QuarkGluon_${FEATURE_TYPE}_${model}${suffix} \
${extraopts} "${@:3}"
--network-config $modelopts ${extraopts} "${@:3}"
9 changes: 6 additions & 3 deletions train_TopLandscape.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ DATADIR=${DATADIR_TopLandscape}
# set a comment via `COMMENT`
suffix=${COMMENT}

# PN, PFN, PCNN, ParT
# PN, PFN, PCNN, ParT, LGATr
model=$1
extraopts=""
if [[ "$model" == "ParT" ]]; then
Expand All @@ -37,6 +37,9 @@ elif [[ "$model" == "PCNN" ]]; then
modelopts="networks/example_PCNN.py"
lr="2e-2"
extraopts="--batch-size 4096"
elif [[ "$model" == "LGATr" ]]; then
modelopts="networks/example_LGATr.py --batch-size 128 --optimizer lion --optimizer-option weight_decay 0.2"
lr="3e-4"
else
echo "Invalid model $model!"
exit 1
Expand All @@ -54,10 +57,10 @@ weaver \
--data-train "${DATADIR}/train_file.parquet" \
--data-val "${DATADIR}/val_file.parquet" \
--data-test "${DATADIR}/test_file.parquet" \
--data-config data/TopLandscape/top_${FEATURE_TYPE}.yaml --network-config $modelopts \
--data-config data/TopLandscape/top_${FEATURE_TYPE}.yaml \
--model-prefix training/TopLandscape/${model}/{auto}${suffix}/net \
--num-workers 1 --fetch-step 1 --in-memory \
--batch-size 512 --samples-per-epoch $((2400 * 512)) --samples-per-epoch-val $((800 * 512)) --num-epochs 20 --gpus 0 \
--start-lr $lr --optimizer ranger --log logs/TopLandscape_${model}_{auto}${suffix}.log --predict-output pred.root \
--tensorboard TopLandscape_${FEATURE_TYPE}_${model}${suffix} \
${extraopts} "${@:3}"
--network-config $modelopts ${extraopts} "${@:3}"