Changes from all commits
29 commits
cf33749
Copy-paste gatr folder from https://github.com/heidelberg-hepml/loren…
spinjo Feb 13, 2025
b6f49d5
Remove unused features: ConditionalGATr (and cross-attention), GAP, A…
spinjo Feb 13, 2025
1bf728e
Make gatr imports relative
spinjo Feb 13, 2025
2d020d1
Fix things: Load geometric_product and linear_basis instead of constr…
spinjo Feb 13, 2025
e821495
Add extra requirements for GATr
spinjo Feb 13, 2025
8bd0e97
First working version of LGATr wrapper
spinjo Feb 13, 2025
04a5794
Add for_segmentation option (not sure about this)
spinjo Feb 14, 2025
234a9a0
Use SequenceTrimmer from ParticleTransformer also in LGATr
spinjo Feb 14, 2025
6b3b569
Move spurions to device (oopsie)
spinjo Feb 16, 2025
361f3f3
Enforce autocast to float32 for equi_layer_norm (required for mixed-p…
spinjo Feb 16, 2025
bbb5dcf
Add support for full Lorentz group (default is fully connected subgroup)
spinjo Feb 16, 2025
e0d0c96
Support gradient checkpointing
spinjo Feb 17, 2025
dcda048
Remove unused scalar output layer
spinjo Feb 17, 2025
04e5437
Prepend spurions (instead of appending them) -> easier to keep track
spinjo Feb 18, 2025
4832057
Fix mean-aggregation (set output of padded particles to zero before e…
spinjo Mar 3, 2025
d0c2c3a
Add global_token option
spinjo Mar 3, 2025
e6018a9
Clean up documentation etc in LGATr.py
spinjo Mar 3, 2025
469542a
Add stabilizing LayerNorm
spinjo Mar 5, 2025
0a60774
Delete gatr/ folder again (will soon be imported)
spinjo Mar 15, 2025
0fb23cc
Take GATr code from lgatr repo
spinjo Mar 15, 2025
fc9371e
Add documentation
spinjo Mar 15, 2025
b9970b5
Change docu
spinjo Mar 15, 2025
5f99b50
Add 'use_flex_attention' option (uses torch_geometric)
spinjo Mar 17, 2025
83ff1fc
Minor change
spinjo Mar 18, 2025
52f848a
Import lgatr package; remove sparse-tensor-business
spinjo Mar 19, 2025
1c3aa5f
Add Lion optimizer option
spinjo Mar 19, 2025
e6d4e45
Turn off flop-counter and add comment on why
spinjo Mar 19, 2025
7093ab2
Changes based on recent lgatr rework
spinjo Jun 1, 2025
6bc9a58
Add pyarrow to weaver requirements.txt
spinjo Jun 1, 2025
3 changes: 3 additions & 0 deletions requirements.txt
@@ -13,3 +13,6 @@ lz4>=3.1.0
xxhash>=1.4.4
tables>=3.6.1
tensorboard>=2.2.0
lgatr
pytorch_optimizer>=3.0.0 # Lion optimizer
pyarrow
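For orientation, a minimal sketch of where the three new requirements are used; the import names are taken from the diffs below, nothing else is assumed:

from lgatr import LGATr, embed_vector, extract_scalar  # L-GATr backbone and geometric-algebra interface (LGATr.py)
from pytorch_optimizer import Lion                      # optional 'lion' optimizer branch (train.py)
import pyarrow                                          # new runtime dependency (commit 6bc9a58)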
286 changes: 286 additions & 0 deletions weaver/nn/model/LGATr.py
@@ -0,0 +1,286 @@
import torch
from torch import nn

from lgatr import (
LGATr,
embed_vector,
extract_scalar,
get_num_spurions,
get_spurions,
gatr_config,
)
from .ParticleTransformer import SequenceTrimmer


class LGATrWrapper(nn.Module):
"""Interface to the LGATr class"""

def __init__(
self,
in_s_channels: int,
hidden_mv_channels: int,
hidden_s_channels: int,
num_classes: int,
num_blocks: int,
num_heads: int,
# symmetry-breaking configurations
spurion_token: bool = True,
beam_spurion: str = "xyplane",
add_time_spurion: bool = True,
beam_mirror: bool = True,
# network configurations
global_token: bool = True,
activation: str = "gelu",
multi_query: bool = False,
increase_hidden_channels_attention: int = 2,
increase_hidden_channels_mlp: int = 2,
num_hidden_layers_mlp: int = 1,
head_scale: bool = False,
dropout_prob: float = None,
# time/memory configurations
checkpoint_blocks: bool = False,
# gatr configurations
use_fully_connected_subgroup: bool = True,
mix_pseudoscalar_into_scalar: bool = True,
use_bivector: bool = True,
use_geometric_product: bool = True,
):
"""
Parameters
----------
in_s_channels : int
Number of scalar input channels.
Examples are PID, trajectory displacements, and kinematic features
like log(pT), delta_phi, etc. that are invariant under z-rotations.
hidden_mv_channels : int
Number of hidden multivector channels, defines width of L-GATr.
hidden_s_channels : int
Number of hidden scalar channels. We find best performance with
roughly hidden_s_channels ~ 2 * hidden_mv_channels.
num_classes : int
Number of classification scores to predict.
num_blocks : int
Number of L-GATr blocks.
num_heads : int
Number of attention heads in L-GATr.
spurion_token : bool
If True, prepend spurions as extra particles (tokens) in the list.
If False, append spurions as extra mv channels.
beam_spurion : str
How the beam spurion is embedded, see lgatr/interface/spurions.py
add_time_spurion : bool
If True, add a time spurion.
beam_mirror : bool
If True and beam_spurion in ["timelike", "lightlike", "spacelike"],
add a mirrored beam_spurion, i.e. with opposite p_z.
global_token : bool
If True, prepend a global token as first particle in the list.
If False, fall back to mean aggregation.
activation : {"relu", "sigmoid", "gelu"}
Activation function in the MLP layers.
multi_query : bool
If True, use the same query for each head in attention.
increase_hidden_channels_attention : int
Factor by which hidden_mv_channels is increased in attention.
increase_hidden_channels_mlp : int
Factor by which hidden_mv_channels is increased in the MLP.
num_hidden_layers_mlp : int
Number of hidden layers in the MLP.
head_scale : bool
If True, scale the attention heads with a learnable factor.
Inspired by the NormFormer (https://arxiv.org/pdf/2110.09456)
dropout_prob : float
Residual dropout after attention and MLP.
checkpoint_blocks : bool
If True, use torch.utils.checkpoint.checkpoint to save memory
at the cost of a slower backward pass.
use_fully_connected_subgroup : bool
If True, the model is only equivariant with respect to
the fully connected subgroup of the Lorentz group,
the proper orthochronous Lorentz group SO^+(1,3),
which does not include parity and time reversal.
This setting affects how the EquiLinear maps work:
for SO^+(1,3), they include transitions between scalars and
pseudoscalars, between vectors and axial vectors, and among
bivectors, effectively treating the pseudoscalar/axialvector
representations like another scalar/vector.
Defaults to True, because parity-odd representations
are usually not important in high-energy physics applications.
mix_pseudoscalar_into_scalar : bool
If True, the pseudoscalar part of the multivector mixes
with the pure-scalar channels in the EquiLinear layer.
This is a technical aspect of how EquiLinear maps work,
and only makes sense if use_fully_connected_subgroup=True.
Attention: the combination use_fully_connected_subgroup=False
and mix_pseudoscalar_into_scalar=True does not make sense;
with these settings the model is effectively only equivariant
w.r.t. the fully connected subgroup anyway.
use_bivector : bool
If False, the bivector components are set to zero after they
are created in the GeometricBilinear layer.
This is a toy switch to explore the effect of higher-order
representations.
use_geometric_product : bool
If False, the GeometricBilinear layer is replaced
by an EquiLinear + ScalarGatedNonlinearity layer.
This is a toy switch to explore the effect of the geometric product.
"""
super().__init__()

# spurion business
in_mv_channels = 1
self.global_token = global_token
self.spurion_token = spurion_token

num_spurions = get_num_spurions(
beam_spurion, add_time_spurion, beam_mirror=beam_mirror
)
if not self.spurion_token:
in_mv_channels += num_spurions
self.spurion_kwargs = {
"beam_spurion": beam_spurion,
"add_time_spurion": add_time_spurion,
"beam_mirror": beam_mirror,
}

gatr_config.use_fully_connected_subgroup = use_fully_connected_subgroup
gatr_config.mix_pseudoscalar_into_scalar = mix_pseudoscalar_into_scalar
gatr_config.use_bivector = use_bivector
gatr_config.use_geometric_product = use_geometric_product

attention = dict(
multi_query=multi_query,
num_heads=num_heads,
increase_hidden_channels=increase_hidden_channels_attention,
head_scale=head_scale,
)
mlp = dict(
activation=activation,
increase_hidden_channels=increase_hidden_channels_mlp,
num_hidden_layers=num_hidden_layers_mlp,
)

self.net = LGATr(
num_blocks=num_blocks,
in_mv_channels=in_mv_channels,
out_mv_channels=num_classes,
hidden_mv_channels=hidden_mv_channels,
in_s_channels=in_s_channels,
out_s_channels=None,
hidden_s_channels=hidden_s_channels,
attention=attention,
mlp=mlp,
dropout_prob=dropout_prob,
checkpoint_blocks=checkpoint_blocks,
)

def forward(self, x, v, mask):
"""
Parameters
----------
x : torch.Tensor with shape (batch_size, num_fts, seq_len)
Scalar features, i.e. features that are invariant under z-rotations.
Examples: PID, trajectory displacements, kinematic features like
log(pT), delta_phi, delta_eta
v : torch.Tensor with shape (batch_size, 4, seq_len)
Lorentz vectors in format (px, py, pz, E)
mask : torch.Tensor with shape (batch_size, 1, seq_len)
Boolean mask that contains 'False' for padded jet constituents

Returns
-------
output : torch.Tensor with shape (batch_size, num_classes)
Tagging scores for each class
"""
# reshape input
x = x.transpose(1, 2) # (batch_size, seq_len, num_fts)
v = v.transpose(1, 2) # (batch_size, seq_len, 4)
mask = mask.transpose(1, 2) # (batch_size, seq_len, 1)

# embed data into geometric algebra
fourmomenta = v[:, :, None, [3, 0, 1, 2]] # (px, py, pz, E) -> (E, px, py, pz)
mv = embed_vector(fourmomenta) # (batch_size, seq_len, 1, 16)
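# the 16 components form the Cl(1,3) multivector basis: 1 scalar + 4 vector + 6 bivector + 4 axial-vector + 1 pseudoscalar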
s = x # (batch_size, seq_len, num_fts)

# symmetry breaking with spurions
spurions = get_spurions(**self.spurion_kwargs).to(
device=s.device, dtype=s.dtype
)
if self.spurion_token:
# prepend spurions as extra particles in the list
mask_ones = torch.ones_like(mask[:, [0]]).repeat(1, spurions.shape[0], 1)
mask = torch.cat([mask_ones, mask], dim=1)
s_zeros = torch.zeros_like(s[:, [0]]).repeat(1, spurions.shape[0], 1)
s = torch.cat([s_zeros, s], dim=1)
spurions = spurions[None, :, None, :].repeat(mv.shape[0], 1, 1, 1)
mv = torch.cat([spurions, mv], dim=1)
else:
# append spurions as extra mv channels
spurions = spurions[None, None, :, :].repeat(mv.shape[0], mv.shape[1], 1, 1)
mv = torch.cat([mv, spurions], dim=2)

if self.global_token:
# prepend global token as first particle in the list
global_token = torch.zeros_like(mv[:, [0], :, :])
mv = torch.cat((global_token, mv), dim=1)
mask_ones = torch.ones_like(mask[:, [0]])
mask = torch.cat((mask_ones, mask), dim=1)
s_zeros = torch.zeros_like(s[:, [0]])
s = torch.cat((s_zeros, s), dim=1)
is_global = torch.zeros_like(s[:, :, 0], dtype=torch.bool)
is_global[:, 0] = True

# reshape mask to broadcast correctly
mask = mask.bool()
attn_mask = mask[:, None, None, :, 0] # (batch_size, 1, 1, seq_len)
attn_kwargs = {"attn_mask": attn_mask}

# call network
out_mv, _ = self.net(mv, s, **attn_kwargs)
output = extract_scalar(out_mv)[..., 0] # (batch_size, seq_len, num_classes)

# aggregation
if self.global_token:
output = output[is_global]
else:
# mean aggregation
output[~mask[:, :, 0]] = 0.0 # zero the outputs of padded particles before averaging
output = output.mean(dim=1)
return output


class LGATrTagger(nn.Module):
"""Mimic other weaver wrappers"""

def __init__(
self,
use_amp=False,
trim=True,
for_inference=False,
for_segmentation=False,
**kwargs,
):
super().__init__()

self.use_amp = use_amp
self.for_inference = for_inference
self.for_segmentation = for_segmentation
self.trimmer = SequenceTrimmer(enabled=trim and not for_inference)
self.net = LGATrWrapper(**kwargs)

def forward(self, x, v=None, mask=None):
with torch.no_grad():
x, v, mask, _ = self.trimmer(x, v, mask)

with torch.autocast("cuda", enabled=self.use_amp):
output = self.net(x, v, mask)

if self.for_segmentation:
output = output.transpose(1, 2).contiguous()
if self.for_inference:
output = torch.softmax(output, dim=1)
return output

if self.for_inference:
output = torch.softmax(output, dim=-1)
return output
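For reference, a minimal usage sketch of the wrapper added above, assuming the lgatr package from requirements.txt is installed; the channel sizes, feature count, and sequence length below are illustrative placeholders, not tuned defaults:

import torch
from weaver.nn.model.LGATr import LGATrTagger

# hypothetical hyperparameters, for illustration only
model = LGATrTagger(
    in_s_channels=7,        # scalar features per constituent
    hidden_mv_channels=16,
    hidden_s_channels=32,   # roughly 2 * hidden_mv_channels, as suggested in the docstring
    num_classes=10,
    num_blocks=6,
    num_heads=8,
)

batch_size, num_fts, seq_len = 2, 7, 50
x = torch.randn(batch_size, num_fts, seq_len)                # scalar features
v = torch.randn(batch_size, 4, seq_len)                      # four-momenta in (px, py, pz, E) format
mask = torch.ones(batch_size, 1, seq_len, dtype=torch.bool)  # False marks padded constituents
scores = model(x, v, mask)                                   # (batch_size, num_classes)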
6 changes: 5 additions & 1 deletion weaver/train.py
@@ -11,6 +11,7 @@
import math
import copy
import torch
from pytorch_optimizer import Lion

from torch.utils.data import DataLoader
from weaver.utils.logger import _logger, _configLogger
@@ -516,6 +517,8 @@ def init_opt(args, model, **optimizer_options):
opt = torch.optim.AdamW(parameters, lr=args.start_lr, **optimizer_options)
elif args.optimizer == 'radam':
opt = torch.optim.RAdam(parameters, lr=args.start_lr, **optimizer_options)
elif args.optimizer == 'lion':
opt = Lion(parameters, lr=args.start_lr, **optimizer_options)
else:
opt = getattr(torch.optim, args.optimizer)(parameters, lr=args.start_lr, **optimizer_options)

@@ -643,7 +646,8 @@ def model_setup(args, data_config, device='cpu'):
_logger.info('The following weights has been frozen:\n - %s',
'\n - '.join([name for name, p in model.named_parameters() if not p.requires_grad]))
# _logger.info(model)
flops(model, model_info, device=device)
# don't use the flop-counting tool from 2019 (causes issues in modern torch, specifically with einops; works with e.g. einops==0.6.1)
# flops(model, model_info, device=device)
# loss function
try:
loss_func = network_module.get_loss(data_config, **network_options)
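For completeness, a minimal standalone sketch of the new optimizer branch above (in weaver it is selected via the existing args.optimizer option with value 'lion'); the learning rate and weight decay are placeholders, not weaver defaults:

import torch
from pytorch_optimizer import Lion

model = torch.nn.Linear(4, 2)   # toy model
opt = Lion(model.parameters(), lr=1e-4, weight_decay=0.01)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()                      # sign-based Lion update
opt.zero_grad()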