From cf33749d2b7e6c86fd8bdd2ed7ba4342fdb7c8fa Mon Sep 17 00:00:00 2001
From: Jonas Spinner
Date: Thu, 13 Feb 2025 20:48:18 +0100
Subject: [PATCH 01/29] Copy-paste gatr folder from https://github.com/heidelberg-hepml/lorentz-gatr

---
 weaver/nn/model/gatr/__init__.py               |   7 +
 weaver/nn/model/gatr/interface/__init__.py     |   3 +
 weaver/nn/model/gatr/interface/scalar.py       |  42 ++
 weaver/nn/model/gatr/interface/spurions.py     | 132 ++++
 weaver/nn/model/gatr/interface/vector.py       |  46 +++
 weaver/nn/model/gatr/layers/__init__.py        |  13 +
 .../model/gatr/layers/attention/__init__.py    |   3 +
 .../model/gatr/layers/attention/attention.py   |  76 ++++
 .../nn/model/gatr/layers/attention/config.py   | 204 +++++++++
 .../gatr/layers/attention/cross_attention.py   | 220 ++++++++++
 .../layers/attention/positional_encoding.py    | 135 ++++++
 weaver/nn/model/gatr/layers/attention/qkv.py   | 227 ++++++++++
 .../gatr/layers/attention/self_attention.py    | 161 ++++++++
 .../gatr/layers/conditional_gatr_block.py      | 191 +++++++++
 weaver/nn/model/gatr/layers/dropout.py         |  51 +++
 weaver/nn/model/gatr/layers/gatr_block.py      | 143 +++++++
 weaver/nn/model/gatr/layers/layer_norm.py      |  70 ++++
 weaver/nn/model/gatr/layers/linear.py          | 388 ++++++++++++++++++
 weaver/nn/model/gatr/layers/mlp/__init__.py    |   4 +
 weaver/nn/model/gatr/layers/mlp/config.py      |  44 ++
 .../gatr/layers/mlp/geometric_bilinears.py     | 105 +++++
 weaver/nn/model/gatr/layers/mlp/mlp.py         | 106 +++++
 .../model/gatr/layers/mlp/nonlinearities.py    |  65 +++
 weaver/nn/model/gatr/nets/__init__.py          |   4 +
 weaver/nn/model/gatr/nets/axial_gatr.py        | 214 ++++++++++
 weaver/nn/model/gatr/nets/conditional_gatr.py  | 227 ++++++++++
 weaver/nn/model/gatr/nets/gap.py               | 120 ++++++
 weaver/nn/model/gatr/nets/gatr.py              | 182 ++++++++
 weaver/nn/model/gatr/primitives/__init__.py    |  17 +
 weaver/nn/model/gatr/primitives/attention.py   | 142 +++++++
 weaver/nn/model/gatr/primitives/bilinear.py    |  67 +++
 weaver/nn/model/gatr/primitives/dropout.py     |  36 ++
 weaver/nn/model/gatr/primitives/invariants.py  | 171 ++++++++
 weaver/nn/model/gatr/primitives/linear.py      | 196 +++++++++
 .../model/gatr/primitives/nonlinearities.py    |  79 ++++
 .../nn/model/gatr/primitives/normalization.py  |  47 +++
 weaver/nn/model/gatr/utils/__init__.py         |   0
 weaver/nn/model/gatr/utils/clifford.py         | 136 ++++++
 weaver/nn/model/gatr/utils/einsum.py           |  44 ++
 weaver/nn/model/gatr/utils/tensors.py          |  34 ++
 40 files changed, 4152 insertions(+)
 create mode 100644 weaver/nn/model/gatr/__init__.py
 create mode 100644 weaver/nn/model/gatr/interface/__init__.py
 create mode 100644 weaver/nn/model/gatr/interface/scalar.py
 create mode 100644 weaver/nn/model/gatr/interface/spurions.py
 create mode 100644 weaver/nn/model/gatr/interface/vector.py
 create mode 100644 weaver/nn/model/gatr/layers/__init__.py
 create mode 100644 weaver/nn/model/gatr/layers/attention/__init__.py
 create mode 100644 weaver/nn/model/gatr/layers/attention/attention.py
 create mode 100644 weaver/nn/model/gatr/layers/attention/config.py
 create mode 100644 weaver/nn/model/gatr/layers/attention/cross_attention.py
 create mode 100644 weaver/nn/model/gatr/layers/attention/positional_encoding.py
 create mode 100644 weaver/nn/model/gatr/layers/attention/qkv.py
 create mode 100644 weaver/nn/model/gatr/layers/attention/self_attention.py
 create mode 100644 weaver/nn/model/gatr/layers/conditional_gatr_block.py
 create mode 100644 weaver/nn/model/gatr/layers/dropout.py
 create mode 100644 weaver/nn/model/gatr/layers/gatr_block.py
 create mode 100644 weaver/nn/model/gatr/layers/layer_norm.py
 create mode 100644 weaver/nn/model/gatr/layers/linear.py
create mode 100644 weaver/nn/model/gatr/layers/mlp/__init__.py create mode 100644 weaver/nn/model/gatr/layers/mlp/config.py create mode 100644 weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py create mode 100644 weaver/nn/model/gatr/layers/mlp/mlp.py create mode 100644 weaver/nn/model/gatr/layers/mlp/nonlinearities.py create mode 100644 weaver/nn/model/gatr/nets/__init__.py create mode 100644 weaver/nn/model/gatr/nets/axial_gatr.py create mode 100644 weaver/nn/model/gatr/nets/conditional_gatr.py create mode 100644 weaver/nn/model/gatr/nets/gap.py create mode 100644 weaver/nn/model/gatr/nets/gatr.py create mode 100644 weaver/nn/model/gatr/primitives/__init__.py create mode 100644 weaver/nn/model/gatr/primitives/attention.py create mode 100644 weaver/nn/model/gatr/primitives/bilinear.py create mode 100644 weaver/nn/model/gatr/primitives/dropout.py create mode 100644 weaver/nn/model/gatr/primitives/invariants.py create mode 100644 weaver/nn/model/gatr/primitives/linear.py create mode 100644 weaver/nn/model/gatr/primitives/nonlinearities.py create mode 100644 weaver/nn/model/gatr/primitives/normalization.py create mode 100644 weaver/nn/model/gatr/utils/__init__.py create mode 100644 weaver/nn/model/gatr/utils/clifford.py create mode 100644 weaver/nn/model/gatr/utils/einsum.py create mode 100644 weaver/nn/model/gatr/utils/tensors.py diff --git a/weaver/nn/model/gatr/__init__.py b/weaver/nn/model/gatr/__init__.py new file mode 100644 index 00000000..bf2c70d9 --- /dev/null +++ b/weaver/nn/model/gatr/__init__.py @@ -0,0 +1,7 @@ +from .layers.attention.config import SelfAttentionConfig, CrossAttentionConfig +from .layers.mlp.config import MLPConfig +from .nets.axial_gatr import AxialGATr +from .nets.gatr import GATr +from .nets.conditional_gatr import ConditionalGATr + +__version__ = "1.0.0" diff --git a/weaver/nn/model/gatr/interface/__init__.py b/weaver/nn/model/gatr/interface/__init__.py new file mode 100644 index 00000000..995d0750 --- /dev/null +++ b/weaver/nn/model/gatr/interface/__init__.py @@ -0,0 +1,3 @@ +from .vector import embed_vector, extract_vector +from .scalar import embed_scalar, extract_scalar +from .spurions import embed_spurions, get_num_spurions diff --git a/weaver/nn/model/gatr/interface/scalar.py b/weaver/nn/model/gatr/interface/scalar.py new file mode 100644 index 00000000..b248a513 --- /dev/null +++ b/weaver/nn/model/gatr/interface/scalar.py @@ -0,0 +1,42 @@ +import torch + + +def embed_scalar(scalars: torch.Tensor) -> torch.Tensor: + """Embeds a scalar tensor into multivectors. + + Parameters + ---------- + scalars: torch.Tensor with shape (..., 1) + Scalar inputs. + + Returns + ------- + multivectors: torch.Tensor with shape (..., 16) + Multivector outputs. `multivectors[..., [0]]` is the same as `scalars`. The other components + are zero. + """ + + non_scalar_shape = list(scalars.shape[:-1]) + [15] + non_scalar_components = torch.zeros( + non_scalar_shape, device=scalars.device, dtype=scalars.dtype + ) + embedding = torch.cat((scalars, non_scalar_components), dim=-1) + + return embedding + + +def extract_scalar(multivectors: torch.Tensor) -> torch.Tensor: + """Extracts scalar components from multivectors. + + Parameters + ---------- + multivectors: torch.Tensor with shape (..., 16) + Multivector inputs. + + Returns + ------- + scalars: torch.Tensor with shape (..., 1) + Scalar component of multivectors. 
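A minimal usage sketch for the scalar interface above, assuming the copied code resolves as the `gatr` package (the files keep their upstream absolute `gatr.*` imports at this point in the patch series); shapes are illustrative:

```python
import torch
from gatr.interface import embed_scalar, extract_scalar

s = torch.randn(32, 128, 1)            # one auxiliary scalar per item
mv = embed_scalar(s)                   # (32, 128, 16); the scalar sits in component 0
assert torch.allclose(extract_scalar(mv), s)
assert mv[..., 1:].abs().sum() == 0    # all non-scalar components are zero
```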
+ """ + + return multivectors[..., [0]] diff --git a/weaver/nn/model/gatr/interface/spurions.py b/weaver/nn/model/gatr/interface/spurions.py new file mode 100644 index 00000000..58537a21 --- /dev/null +++ b/weaver/nn/model/gatr/interface/spurions.py @@ -0,0 +1,132 @@ +import torch +from gatr.interface import embed_vector + + +def get_num_spurions( + beam_reference, + add_time_reference, + two_beams=True, + add_xzplane=False, + add_yzplane=False, +): + """ + Compute how many reference multivectors/spurions a given configuration will have + + Parameters + ---------- + beam_reference: str + Different options for adding a beam_reference + Options: "lightlike", "spacelike", "timelike", "xyplane" + add_time_reference: bool + Whether to add the time direction as a reference to the network + two_beams: bool + Whether we only want (x, 0, 0, 1) or both (x, 0, 0, +/- 1) for the beam + add_xzplane: bool + Whether to add the x-z-plane as a reference to the network + add_yzplane: bool + Whether to add the y-z-plane as a reference to the network + + Returns + ------- + num_spurions: int + Number of spurions + """ + num_spurions = 0 + if beam_reference in ["lightlike", "spacelike", "timelike"]: + num_spurions += 2 if two_beams else 1 + elif beam_reference == "xyplane": + num_spurions += 1 + if add_xzplane: + num_spurions += 1 + if add_yzplane: + num_spurions += 1 + if add_time_reference: + num_spurions += 1 + return num_spurions + + +def embed_spurions( + beam_reference, + add_time_reference, + two_beams=True, + add_xzplane=False, + add_yzplane=False, + device="cpu", + dtype=torch.float32, +): + """ + Construct a list of reference multivectors/spurions for symmetry breaking + + Parameters + ---------- + beam_reference: str + Different options for adding a beam_reference + Options: "lightlike", "spacelike", "timelike", "xyplane" + add_time_reference: bool + Whether to add the time direction as a reference to the network + two_beams: bool + Whether we only want (x, 0, 0, 1) or both (x, 0, 0, +/- 1) for the beam + add_xzplane: bool + Whether to add the x-z-plane as a reference to the network + add_yzplane: bool + Whether to add the y-z-plane as a reference to the network + device + dtype + + Returns + ------- + spurions: torch.tensor with shape (n_spurions, 16) + spurion embedded as multivector object + """ + kwargs = {"device": device, "dtype": dtype} + + if beam_reference in ["lightlike", "spacelike", "timelike"]: + # add another 4-momentum + if beam_reference == "lightlike": + beam = [1, 0, 0, 1] + elif beam_reference == "timelike": + beam = [2**0.5, 0, 0, 1] + elif beam_reference == "spacelike": + beam = [0, 0, 0, 1] + beam = torch.tensor(beam, **kwargs).reshape(1, 4) + beam = embed_vector(beam) + if two_beams: + beam2 = beam.clone() + beam2[..., 4] = -1 # flip pz + beam = torch.cat((beam, beam2), dim=0) + + elif beam_reference == "xyplane": + # add the x-y-plane, embedded as a bivector + # convention for bivector components: [tx, ty, tz, xy, xz, yz] + beam = torch.zeros(1, 16, **kwargs) + beam[..., 8] = 1 + + elif beam_reference is None: + beam = torch.empty(0, 16, **kwargs) + + else: + raise ValueError(f"beam_reference {beam_reference} not implemented") + + if add_xzplane: + # add the x-z-plane, embedded as a bivector + xzplane = torch.zeros(1, 16, **kwargs) + xzplane[..., 10] = 1 + else: + xzplane = torch.empty(0, 16, **kwargs) + + if add_yzplane: + # add the y-z-plane, embedded as a bivector + yzplane = torch.zeros(1, 16, **kwargs) + yzplane[..., 9] = 1 + else: + yzplane = torch.empty(0, 16, 
**kwargs) + + if add_time_reference: + time = [1, 0, 0, 0] + time = torch.tensor(time, **kwargs).reshape(1, 4) + time = embed_vector(time) + else: + time = torch.empty(0, 16, **kwargs) + + spurions = torch.cat((beam, xzplane, yzplane, time), dim=-2) + return spurions diff --git a/weaver/nn/model/gatr/interface/vector.py b/weaver/nn/model/gatr/interface/vector.py new file mode 100644 index 00000000..54206cb6 --- /dev/null +++ b/weaver/nn/model/gatr/interface/vector.py @@ -0,0 +1,46 @@ +import torch + + +def embed_vector(vector: torch.Tensor) -> torch.Tensor: + """Embeds Lorentz vectors in multivectors. + + Parameters + ---------- + vector : torch.Tensor with shape (..., 4) + Lorentz vector + + Returns + ------- + multivector : torch.Tensor with shape (..., 16) + Embedding into multivector. + """ + + # Create multivector tensor with same batch shape, same device, same dtype as input + batch_shape = vector.shape[:-1] + multivector = torch.zeros( + *batch_shape, 16, dtype=vector.dtype, device=vector.device + ) + + # Embedding into Lorentz vectors + multivector[..., 1:5] = vector + + return multivector + + +def extract_vector(multivector: torch.Tensor) -> torch.Tensor: + """Given a multivector, extract a Lorentz vector. + + Parameters + ---------- + multivector : torch.Tensor with shape (..., 16) + Multivector. + + Returns + ------- + vector : torch.Tensor with shape (..., 4) + Lorentz vector + """ + + vector = multivector[..., 1:5] + + return vector diff --git a/weaver/nn/model/gatr/layers/__init__.py b/weaver/nn/model/gatr/layers/__init__.py new file mode 100644 index 00000000..6c82d47b --- /dev/null +++ b/weaver/nn/model/gatr/layers/__init__.py @@ -0,0 +1,13 @@ +from .attention.config import SelfAttentionConfig, CrossAttentionConfig +from .attention.positional_encoding import ApplyRotaryPositionalEncoding +from .attention.self_attention import SelfAttention +from .attention.cross_attention import CrossAttention +from .dropout import GradeDropout +from .layer_norm import EquiLayerNorm +from .linear import EquiLinear +from .mlp.geometric_bilinears import GeometricBilinear +from .mlp.mlp import GeoMLP +from .mlp.config import MLPConfig +from .mlp.nonlinearities import ScalarGatedNonlinearity +from .gatr_block import GATrBlock +from .conditional_gatr_block import ConditionalGATrBlock diff --git a/weaver/nn/model/gatr/layers/attention/__init__.py b/weaver/nn/model/gatr/layers/attention/__init__.py new file mode 100644 index 00000000..9addf7e6 --- /dev/null +++ b/weaver/nn/model/gatr/layers/attention/__init__.py @@ -0,0 +1,3 @@ +from .config import SelfAttentionConfig, CrossAttentionConfig +from .self_attention import SelfAttention +from .cross_attention import CrossAttention diff --git a/weaver/nn/model/gatr/layers/attention/attention.py b/weaver/nn/model/gatr/layers/attention/attention.py new file mode 100644 index 00000000..73b7f5cb --- /dev/null +++ b/weaver/nn/model/gatr/layers/attention/attention.py @@ -0,0 +1,76 @@ +"""Self-attention layers.""" + +from torch import nn + +from gatr.layers.attention.config import SelfAttentionConfig +from gatr.primitives.attention import sdp_attention + + +class GeometricAttention(nn.Module): + """Geometric attention layer. + + This is the main attention mechanism used in L-GATr. 
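A hedged sketch combining the vector and spurion interfaces above (import path assumed as in the upstream `gatr` package; shapes are illustrative):

```python
import torch
from gatr.interface import embed_vector, extract_vector, embed_spurions, get_num_spurions

p = torch.randn(32, 128, 4)            # four-momenta (E, px, py, pz)
mv = embed_vector(p)                   # (32, 128, 16); the vector sits in components 1:5
assert torch.allclose(extract_vector(mv), p)

# Two lightlike beam spurions plus a time reference -> 3 reference multivectors
n = get_num_spurions("lightlike", add_time_reference=True, two_beams=True)
spurions = embed_spurions("lightlike", add_time_reference=True, two_beams=True)
assert n == 3 and spurions.shape == (n, 16)
```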
+ + Given multivector and scalar queries, keys, and values, this layer computes: + + ``` + attn_weights[..., i, j] = softmax_j[ + ga_inner_product(q_mv[..., i, :, :], k_mv[..., j, :, :]) + + euclidean_inner_product(q_s[..., i, :], k_s[..., j, :]) + ] + out_mv[..., i, c, :] = sum_j attn_weights[..., i, j] v_mv[..., j, c, :] / norm + out_s[..., i, c] = sum_j attn_weights[..., i, j] v_s[..., j, c] / norm + ``` + + Parameters + ---------- + config : SelfAttentionConfig + Attention configuration. + """ + + def __init__(self, config: SelfAttentionConfig) -> None: + super().__init__() + + def forward(self, q_mv, k_mv, v_mv, q_s, k_s, v_s, attention_mask=None): + """Forward pass through geometric attention. + + Given multivector and scalar queries, keys, and values, this forward pass computes: + + ``` + attn_weights[..., i, j] = softmax_j[ + ga_inner_product(q_mv[..., i, :, :], k_mv[..., j, :, :]) + + euclidean_inner_product(q_s[..., i, :], k_s[..., j, :]) + ] + out_mv[..., i, c, :] = sum_j attn_weights[..., i, j] v_mv[..., j, c, :] / norm + out_s[..., i, c] = sum_j attn_weights[..., i, j] v_s[..., j, c] / norm + ``` + + Parameters + ---------- + q_mv : Tensor with shape (..., num_items_out, num_mv_channels_in, 16) + Queries, multivector part. + k_mv : Tensor with shape (..., num_items_in, num_mv_channels_in, 16) + Keys, multivector part. + v_mv : Tensor with shape (..., num_items_in, num_mv_channels_out, 16) + Values, multivector part. + q_s : Tensor with shape (..., heads, num_items_out, num_s_channels_in) + Queries, scalar part. + k_s : Tensor with shape (..., heads, num_items_in, num_s_channels_in) + Keys, scalar part. + v_s : Tensor with shape (..., heads, num_items_in, num_s_channels_out) + Values, scalar part. + attention_mask: None or Tensor or AttentionBias + Optional attention mask. + """ + + h_mv, h_s = sdp_attention( + q_mv, + k_mv, + v_mv, + q_s, + k_s, + v_s, + attn_mask=attention_mask, + ) + + return h_mv, h_s diff --git a/weaver/nn/model/gatr/layers/attention/config.py b/weaver/nn/model/gatr/layers/attention/config.py new file mode 100644 index 00000000..950c45c7 --- /dev/null +++ b/weaver/nn/model/gatr/layers/attention/config.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Mapping, Optional + + +@dataclass +class SelfAttentionConfig: + """Configuration for attention. + + Parameters + ---------- + in_mv_channels : int + Number of input multivector channels. + out_mv_channels : int + Number of output multivector channels. + num_heads : int + Number of attention heads. + in_s_channels : int + Input scalar channels. If None, no scalars are expected nor returned. + out_s_channels : int + Output scalar channels. If None, no scalars are expected nor returned. + additional_qk_mv_channels : int + Whether additional multivector features for the keys and queries will be provided. + additional_qk_s_channels : int + Whether additional scalar features for the keys and queries will be provided. + multi_query: bool + Whether to do multi-query attention + pos_encoding : bool + Whether to apply rotary positional embeddings along the item dimension to the scalar keys + and queries. + pos_encoding_base : int + Base for the frequencies in the positional encoding. 
+ output_init : str + Initialization scheme for final linear layer + increase_hidden_channels : int + Factor by which to increase the number of hidden channels (both multivectors and scalars) + dropout_prob : float or None + Dropout probability + head_scale: bool + Whether to use HeadScaleMHA following the NormFormer (https://arxiv.org/pdf/2110.09456) + """ + + multi_query: bool = True + in_mv_channels: Optional[int] = None + out_mv_channels: Optional[int] = None + in_s_channels: Optional[int] = None + out_s_channels: Optional[int] = None + num_heads: int = 8 + additional_qk_mv_channels: int = 0 + additional_qk_s_channels: int = 0 + pos_encoding: bool = False + pos_encoding_base: int = 4096 + output_init: str = "default" + checkpoint: bool = True + increase_hidden_channels: int = 2 + dropout_prob: Optional[float] = None + head_scale: bool = False + + def __post_init__(self): + """Type checking / conversion.""" + if isinstance(self.dropout_prob, str) and self.dropout_prob.lower() in [ + "null", + "none", + ]: + self.dropout_prob = None + + @property + def hidden_mv_channels(self) -> Optional[int]: + """Returns the number of hidden multivector channels.""" + + if self.in_mv_channels is None: + return None + + return max( + self.increase_hidden_channels * self.in_mv_channels // self.num_heads, 1 + ) + + @property + def hidden_s_channels(self) -> Optional[int]: + """Returns the number of hidden scalar channels.""" + + if self.in_s_channels is None: + return None + + hidden_s_channels = max( + self.increase_hidden_channels * self.in_s_channels // self.num_heads, 4 + ) + + # When using positional encoding, the number of scalar hidden channels needs to be even. + # It also should not be too small. + if self.pos_encoding: + hidden_s_channels = (hidden_s_channels + 1) // 2 * 2 + hidden_s_channels = max(hidden_s_channels, 8) + + return hidden_s_channels + + @classmethod + def cast(cls, config: Any) -> SelfAttentionConfig: + """Casts an object as SelfAttentionConfig.""" + if isinstance(config, SelfAttentionConfig): + return config + if isinstance(config, Mapping): + return cls(**config) + raise ValueError(f"Can not cast {config} to {cls}") + + +@dataclass +class CrossAttentionConfig: + """Configuration for cross-attention. + + Parameters + ---------- + in_q_mv_channels : int + Number of input query multivector channels. + in_kv_mv_channels : int + Number of input key/value multivector channels. + out_mv_channels : int + Number of output multivector channels. + num_heads : int + Number of attention heads. + in_q_s_channels : int + Input query scalar channels. If None, no scalars are expected nor returned. + in_kv_s_channels : int + Input key/value scalar channels. If None, no scalars are expected nor returned. + out_s_channels : int + Output scalar channels. If None, no scalars are expected nor returned. + additional_q_mv_channels : int + Whether additional multivector features for the queries will be provided. + additional_q_s_channels : int + Whether additional scalar features for the queries will be provided. + additional_k_mv_channels : int + Whether additional multivector features for the keys will be provided. + additional_k_s_channels : int + Whether additional scalar features for the keys will be provided. 
+ multi_query: bool + Whether to do multi-query attention + output_init : str + Initialization scheme for final linear layer + increase_hidden_channels : int + Factor by which to increase the number of hidden channels (both multivectors and scalars) + dropout_prob : float or None + Dropout probability + head_scale: bool + Whether to use HeadScaleMHA following the NormFormer (https://arxiv.org/pdf/2110.09456) + """ + + in_q_mv_channels: Optional[int] = None + in_kv_mv_channels: Optional[int] = None + out_mv_channels: Optional[int] = None + out_s_channels: Optional[int] = None + in_q_s_channels: Optional[int] = None + in_kv_s_channels: Optional[int] = None + num_heads: int = 8 + additional_q_mv_channels: int = 0 + additional_q_s_channels: int = 0 + additional_k_mv_channels: int = 0 + additional_k_s_channels: int = 0 + multi_query: bool = True + output_init: str = "default" + checkpoint: bool = True + increase_hidden_channels: int = 2 + dropout_prob: Optional[float] = None + head_scale: bool = False + + def __post_init__(self): + """Type checking / conversion.""" + if isinstance(self.dropout_prob, str) and self.dropout_prob.lower() in [ + "null", + "none", + ]: + self.dropout_prob = None + + @property + def hidden_mv_channels(self) -> Optional[int]: + """Returns the number of hidden multivector channels.""" + + if self.in_q_mv_channels is None: + return None + + return max( + self.increase_hidden_channels * self.in_q_mv_channels // self.num_heads, 1 + ) + + @property + def hidden_s_channels(self) -> Optional[int]: + """Returns the number of hidden scalar channels.""" + + if self.in_q_s_channels is None: + assert self.in_kv_s_channels is None + return None + + return max( + self.increase_hidden_channels * self.in_q_s_channels // self.num_heads, 4 + ) + + @classmethod + def cast(cls, config: Any) -> CrossAttentionConfig: + """Casts an object as CrossAttentionConfig.""" + if isinstance(config, CrossAttentionConfig): + return config + if isinstance(config, Mapping): + return cls(**config) + raise ValueError(f"Can not cast {config} to {cls}") diff --git a/weaver/nn/model/gatr/layers/attention/cross_attention.py b/weaver/nn/model/gatr/layers/attention/cross_attention.py new file mode 100644 index 00000000..6fb1024b --- /dev/null +++ b/weaver/nn/model/gatr/layers/attention/cross_attention.py @@ -0,0 +1,220 @@ +"""Cross-attention layer.""" + +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn + +from gatr.layers.attention.attention import GeometricAttention +from gatr.layers.attention.config import SelfAttentionConfig +from gatr.layers.dropout import GradeDropout +from gatr.layers.linear import EquiLinear + + +class CrossAttention(nn.Module): + """Geometric cross-attention layer. + + Constructs queries, keys, and values, computes attention, and projects linearly to outputs. + + Parameters + ---------- + config : SelfAttentionConfig + Attention configuration. 
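A worked example (illustrative values, not taken from the patch) of the per-head hidden-channel bookkeeping implemented by the `hidden_mv_channels` / `hidden_s_channels` properties of `SelfAttentionConfig` above; `CrossAttentionConfig` follows the same pattern based on the query channels:

```python
from gatr.layers import SelfAttentionConfig

cfg = SelfAttentionConfig(
    in_mv_channels=16, out_mv_channels=16,
    in_s_channels=64, out_s_channels=64,
    num_heads=8, increase_hidden_channels=2,
)
assert cfg.hidden_mv_channels == max(2 * 16 // 8, 1)  # 4 multivector channels per head
assert cfg.hidden_s_channels == max(2 * 64 // 8, 4)   # 16 scalar channels per head
```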
+ """ + + def __init__( + self, + config: SelfAttentionConfig, + ) -> None: + super().__init__() + + if ( + config.additional_q_mv_channels > 0 + or config.additional_q_s_channels > 0 + or config.additional_k_mv_channels > 0 + or config.additional_k_s_channels > 0 + ): + raise NotImplementedError( + "Cross attention is not implemented with additional channels" + ) + + # Store settings + self.config = config + + self.q_linear = EquiLinear( + in_mv_channels=config.in_q_mv_channels, + out_mv_channels=config.hidden_mv_channels * config.num_heads, + in_s_channels=config.in_q_s_channels, + out_s_channels=config.hidden_s_channels * config.num_heads, + ) + self.kv_linear = EquiLinear( + in_mv_channels=config.in_kv_mv_channels, + out_mv_channels=2 + * config.hidden_mv_channels + * (1 if config.multi_query else config.num_heads), + in_s_channels=config.in_kv_s_channels, + out_s_channels=2 + * config.hidden_s_channels + * (1 if config.multi_query else config.num_heads), + ) + + # Output projection + self.out_linear = EquiLinear( + in_mv_channels=config.hidden_mv_channels * config.num_heads, + out_mv_channels=config.out_mv_channels, + in_s_channels=( + None + if config.in_kv_s_channels is None + else config.hidden_s_channels * config.num_heads + ), + out_s_channels=config.out_s_channels, + initialization=config.output_init, + ) + + # Attention + self.attention = GeometricAttention(config) + + # Dropout + self.dropout: Optional[nn.Module] + if config.dropout_prob is not None: + raise ValueError( + "Dropout violates equivariance for cross_attention, " + "thats definitely a bug but didn't find the reason yet." + ) + self.dropout = GradeDropout(config.dropout_prob) + else: + self.dropout = None + + # HeadScaleMHA + self.use_head_scale = config.head_scale + if self.use_head_scale: + self.head_scale = nn.Parameter(torch.ones(config.num_heads)) + + def forward( + self, + multivectors_kv: torch.Tensor, + multivectors_q: torch.Tensor, + scalars_kv: Optional[torch.Tensor] = None, + scalars_q: Optional[torch.Tensor] = None, + attention_mask=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute cross attention. + + Parameters + ---------- + multivectors_kv : torch.Tensor with shape (..., num_items_kv, channels_in, 16) + Input multivectors for key and value. + multivectors_q : torch.Tensor with shape (..., num_items_q, channels_in_q, 16) + Input multivectors for query. + scalars_kv : None or torch.Tensor with shape (..., num_items_kv, in_scalars) + Optional input scalars + scalars_q : None or torch.Tensor with shape (..., num_items_q, in_scalars_q) + Optional input scalars for query + attention_mask: torch.Tensor with shape (..., num_items_q, num_items_kv) or xformers mask. + Attention mask + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., num_items_q, channels_out, 16) + Output multivectors. + output_scalars : torch.Tensor with shape (..., num_items_q, channels_out, out_scalars) + Output scalars, if scalars are provided. Otherwise None. + """ + q_mv, q_s = self.q_linear( + multivectors_q, scalars_q + ) # (..., num_items, hidden_channels, 16) + kv_mv, kv_s = self.kv_linear( + multivectors_kv, scalars_kv + ) # (..., num_items, 2*hidden_channels, 16) + k_mv, v_mv = torch.tensor_split(kv_mv, 2, dim=-2) + k_s, v_s = torch.tensor_split(kv_s, 2, dim=-1) + + # Rearrange to (..., heads, items, channels, 16) shape + q_mv = rearrange( + q_mv, + "... items (hidden_channels num_heads) x -> ... 
num_heads items hidden_channels x", + num_heads=self.config.num_heads, + hidden_channels=self.config.hidden_mv_channels, + ) + if self.config.multi_query: + k_mv = rearrange( + k_mv, "... items hidden_channels x -> ... 1 items hidden_channels x" + ) + v_mv = rearrange( + v_mv, "... items hidden_channels x -> ... 1 items hidden_channels x" + ) + else: + k_mv = rearrange( + k_mv, + "... items (hidden_channels num_heads) x -> ... num_heads items hidden_channels x", + num_heads=self.config.num_heads, + hidden_channels=self.config.hidden_mv_channels, + ) + v_mv = rearrange( + v_mv, + "... items (hidden_channels num_heads) x -> ... num_heads items hidden_channels x", + num_heads=self.config.num_heads, + hidden_channels=self.config.hidden_mv_channels, + ) + + # Same for scalars + if q_s is not None: + q_s = rearrange( + q_s, + "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels", + num_heads=self.config.num_heads, + hidden_channels=self.config.hidden_s_channels, + ) + if self.config.multi_query: + k_s = rearrange( + k_s, "... items hidden_channels -> ... 1 items hidden_channels" + ) + v_s = rearrange( + v_s, "... items hidden_channels -> ... 1 items hidden_channels" + ) + else: + k_s = rearrange( + k_s, + "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels", + num_heads=self.config.num_heads, + hidden_channels=self.config.hidden_s_channels, + ) + v_s = rearrange( + v_s, + "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels", + num_heads=self.config.num_heads, + hidden_channels=self.config.hidden_s_channels, + ) + else: + q_s, k_s, v_s = None, None, None + + # Attention layer + h_mv, h_s = self.attention( + q_mv, k_mv, v_mv, q_s, k_s, v_s, attention_mask=attention_mask + ) + if self.use_head_scale: + h_mv = h_mv * self.head_scale.view( + *[1] * len(h_mv.shape[:-5]), len(self.head_scale), 1, 1, 1 + ) + h_s = h_s * self.head_scale.view( + *[1] * len(h_s.shape[:-4]), len(self.head_scale), 1, 1 + ) + + h_mv = rearrange( + h_mv, + "... n_heads n_items hidden_channels x -> ... n_items (n_heads hidden_channels) x", + ) + h_s = rearrange( + h_s, + "... n_heads n_items hidden_channels -> ... n_items (n_heads hidden_channels)", + ) + + # Transform linearly one more time + outputs_mv, outputs_s = self.out_linear(h_mv, scalars=h_s) + + # Dropout + if self.dropout is not None: + outputs_mv, outputs_s = self.dropout(outputs_mv, outputs_s) + + return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/attention/positional_encoding.py b/weaver/nn/model/gatr/layers/attention/positional_encoding.py new file mode 100644 index 00000000..3fe976c4 --- /dev/null +++ b/weaver/nn/model/gatr/layers/attention/positional_encoding.py @@ -0,0 +1,135 @@ +"""Adapted from the below. + +https://github.com/EleutherAI/gpt-neox/blob/737c9134bfaff7b58217d61f6619f1dcca6c484f/megatron/model/positional_embeddings.py +by EleutherAI at https://github.com/EleutherAI/gpt-neox + +Copyright (c) 2021, EleutherAI + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
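A short sketch of how the rotary encoding defined below is applied to scalar queries/keys with the item dimension at `-2`; the channel count must be even, and the shapes are illustrative only:

```python
import torch
from gatr.layers import ApplyRotaryPositionalEncoding

rope = ApplyRotaryPositionalEncoding(num_channels=16, item_dim=-2, base=4096)
q_s = torch.randn(2, 8, 100, 16)   # (batch, heads, items, scalar channels)
q_rot = rope(q_s)                  # same shape; channel pairs rotated by an item-dependent angle
assert q_rot.shape == q_s.shape
```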
+""" + +import torch + +from gatr.utils.einsum import cached_einsum + + +class ApplyRotaryPositionalEncoding(torch.nn.Module): + """Applies rotary position encodings (RoPE) to scalar tensors. + + References + ---------- + Jianlin Su et al, "RoFormer: Enhanced Transformer with Rotary Position Embedding", + arXiv:2104.09864 + + Parameters + ---------- + num_channels : int + Number of channels (key and query size). + item_dim : int + Embedding dimension. Should be even. + base : int + Determines the frequencies. + """ + + def __init__(self, num_channels, item_dim, base=4096): + super().__init__() + + assert ( + num_channels % 2 == 0 + ), "Number of channels needs to be even for rotary position embeddings" + + inv_freq = 1.0 / ( + base ** (torch.arange(0, num_channels, 2).float() / num_channels) + ) + self.register_buffer("inv_freq", inv_freq) + self.seq_len_cached = None + self.device_cached = None + self.cos_cached = None + self.sin_cached = None + self.item_dim = item_dim + self.num_channels = num_channels + + def forward(self, scalars: torch.Tensor) -> torch.Tensor: + """Computes rotary embeddings along `self.item_dim` and applies them to inputs. + + The inputs are usually scalar queries and keys. + + Assumes that the last dimension is the feature dimension (and is thus not suited + for multivector data!). + + Parameters + ---------- + scalars : torch.Tensor of shape (..., num_channels) + Input data. The last dimension is assumed to be the channel / feature dimension + (NOT the 16 dimensions of a multivector). + + Returns + ------- + outputs : torch.Tensor of shape (..., num_channels) + Output data. Rotary positional embeddings applied to the input tensor. + """ + + # Check inputs + assert scalars.shape[-1] == self.num_channels + + # Compute embeddings, if not already cached + self._compute_embeddings(scalars) + + # Apply embeddings + outputs = ( + scalars * self.cos_cached + self._rotate_half(scalars) * self.sin_cached + ) + + return outputs + + def _compute_embeddings(self, inputs): + """Computes position embeddings and stores them. + + The position embedding is computed along dimension `item_dim` of tensor `inputs` + and is stored in `self.sin_cached` and `self.cos_cached`. + + Parameters + ---------- + inputs : torch.Tensor + Input data. 
+ """ + seq_len = inputs.shape[self.item_dim] + if seq_len != self.seq_len_cached or inputs.device != self.device_cached: + self.seq_len_cached = seq_len + self.device_cached = inputs.device + t = torch.arange(inputs.shape[self.item_dim], device=inputs.device).type_as( + self.inv_freq + ) + freqs = cached_einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(inputs.device) + + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() + + # Insert appropriate amount of dimensions such that the embedding correctly enumerates + # along the item dim + item_dim = ( + self.item_dim if self.item_dim >= 0 else inputs.ndim + self.item_dim + ) # Deal with item_dim < 0 + for _ in range(item_dim + 1, inputs.ndim - 1): + self.cos_cached = self.cos_cached.unsqueeze(1) + self.sin_cached = self.sin_cached.unsqueeze(1) + + @staticmethod + def _rotate_half(inputs): + """Utility function that "rotates" a tensor, as required for rotary embeddings.""" + x1, x2 = ( + inputs[..., : inputs.shape[-1] // 2], + inputs[..., inputs.shape[-1] // 2 :], + ) + return torch.cat((-x2, x1), dim=-1) diff --git a/weaver/nn/model/gatr/layers/attention/qkv.py b/weaver/nn/model/gatr/layers/attention/qkv.py new file mode 100644 index 00000000..70a9698c --- /dev/null +++ b/weaver/nn/model/gatr/layers/attention/qkv.py @@ -0,0 +1,227 @@ +import torch +from einops import rearrange +from torch import nn + +from gatr.layers.attention.config import SelfAttentionConfig +from gatr.layers.linear import EquiLinear + + +class QKVModule(nn.Module): + """Compute (multivector and scalar) queries, keys, and values via multi-head attention. + + Parameters + ---------- + config: SelfAttentionConfig + Attention configuration + """ + + def __init__(self, config: SelfAttentionConfig): + super().__init__() + self.in_linear = EquiLinear( + in_mv_channels=config.in_mv_channels + config.additional_qk_mv_channels, + out_mv_channels=3 * config.hidden_mv_channels * config.num_heads, + in_s_channels=config.in_s_channels + config.additional_qk_s_channels, + out_s_channels=None + if config.in_s_channels is None + else 3 * config.hidden_s_channels * config.num_heads, + ) + self.config = config + + def forward( + self, + inputs, + scalars, + additional_qk_features_mv=None, + additional_qk_features_s=None, + ): + """Forward pass. + + Parameters + ---------- + inputs : torch.Tensor + Multivector inputs + scalars : torch.Tensor + Scalar inputs + additional_qk_features_mv : None or torch.Tensor + Additional multivector features that should be provided for the Q/K computation (e.g. + positions of objects) + additional_qk_features_s : None or torch.Tensor + Additional scalar features that should be provided for the Q/K computation (e.g. + object types) + + Returns + ------- + q_mv : Tensor with shape (..., num_items_out, num_mv_channels_in, 16) + Queries, multivector part. + k_mv : Tensor with shape (..., num_items_in, num_mv_channels_in, 16) + Keys, multivector part. + v_mv : Tensor with shape (..., num_items_in, num_mv_channels_out, 16) + Values, multivector part. + q_s : Tensor with shape (..., heads, num_items_out, num_s_channels_in) + Queries, scalar part. + k_s : Tensor with shape (..., heads, num_items_in, num_s_channels_in) + Keys, scalar part. + v_s : Tensor with shape (..., heads, num_items_in, num_s_channels_out) + Values, scalar part. 
+ """ + + # Additional inputs + if additional_qk_features_mv is not None: + inputs = torch.cat((inputs, additional_qk_features_mv), dim=-2) + if additional_qk_features_s is not None: + scalars = torch.cat((scalars, additional_qk_features_s), dim=-1) + + qkv_mv, qkv_s = self.in_linear( + inputs, scalars + ) # (..., num_items, 3 * hidden_channels * num_heads, 16) + qkv_mv = rearrange( + qkv_mv, + "... items (qkv hidden num_heads) x -> qkv ... num_heads items hidden x", + num_heads=self.config.num_heads, + hidden=self.config.hidden_mv_channels, + qkv=3, + ) + q_mv, k_mv, v_mv = qkv_mv # each: (..., num_heads, num_items, num_channels, 16) + + # Same, for optional scalar components + if qkv_s is not None: + qkv_s = rearrange( + qkv_s, + "... items (qkv hidden num_heads) -> qkv ... num_heads items hidden", + num_heads=self.config.num_heads, + hidden=self.config.hidden_s_channels, + qkv=3, + ) + q_s, k_s, v_s = qkv_s # each: (..., num_heads, num_items, num_channels) + else: + q_s, k_s, v_s = None, None, None + + return q_mv, k_mv, v_mv, q_s, k_s, v_s + + +class MultiQueryQKVModule(nn.Module): + """Compute (multivector and scalar) queries, keys, and values via multi-query attention. + + Parameters + ---------- + config: SelfAttentionConfig + Attention configuration + """ + + def __init__(self, config: SelfAttentionConfig): + super().__init__() + + # Q projection + self.q_linear = EquiLinear( + in_mv_channels=config.in_mv_channels + config.additional_qk_mv_channels, + out_mv_channels=config.hidden_mv_channels * config.num_heads, + in_s_channels=config.in_s_channels + config.additional_qk_s_channels, + out_s_channels=config.hidden_s_channels * config.num_heads, + ) + + # Key and value projections (shared between heads) + self.k_linear = EquiLinear( + in_mv_channels=config.in_mv_channels + config.additional_qk_mv_channels, + out_mv_channels=config.hidden_mv_channels, + in_s_channels=config.in_s_channels + config.additional_qk_s_channels, + out_s_channels=config.hidden_s_channels, + ) + self.v_linear = EquiLinear( + in_mv_channels=config.in_mv_channels, + out_mv_channels=config.hidden_mv_channels, + in_s_channels=config.in_s_channels, + out_s_channels=config.hidden_s_channels, + ) + self.config = config + + def forward( + self, + inputs, + scalars, + additional_qk_features_mv=None, + additional_qk_features_s=None, + ): + """Forward pass. + + Parameters + ---------- + inputs : torch.Tensor + Multivector inputs + scalars : torch.Tensor + Scalar inputs + additional_qk_features_mv : None or torch.Tensor + Additional multivector features that should be provided for the Q/K computation (e.g. + positions of objects) + additional_qk_features_s : None or torch.Tensor + Additional scalar features that should be provided for the Q/K computation (e.g. + object types) + + Returns + ------- + q_mv : Tensor with shape (..., num_items_out, num_mv_channels_in, 16) + Queries, multivector part. + k_mv : Tensor with shape (..., num_items_in, num_mv_channels_in, 16) + Keys, multivector part. + v_mv : Tensor with shape (..., num_items_in, num_mv_channels_out, 16) + Values, multivector part. + q_s : Tensor with shape (..., heads, num_items_out, num_s_channels_in) + Queries, scalar part. + k_s : Tensor with shape (..., heads, num_items_in, num_s_channels_in) + Keys, scalar part. + v_s : Tensor with shape (..., heads, num_items_in, num_s_channels_out) + Values, scalar part. 
+ """ + + # Additional inputs + if additional_qk_features_mv is not None: + qk_inputs = torch.cat((inputs, additional_qk_features_mv), dim=-2) + else: + qk_inputs = inputs + if scalars is not None and additional_qk_features_s is not None: + qk_scalars = torch.cat((scalars, additional_qk_features_s), dim=-1) + else: + qk_scalars = scalars + + # Project to queries, keys, and values (multivector reps) + q_mv, q_s = self.q_linear( + qk_inputs, qk_scalars + ) # (..., num_items, hidden_channels * num_heads, 16) + k_mv, k_s = self.k_linear( + qk_inputs, qk_scalars + ) # (..., num_items, hidden_channels, 16) + v_mv, v_s = self.v_linear( + inputs, scalars + ) # (..., num_items, hidden_channels, 16) + + # Rearrange to (..., heads, items, channels, 16) shape + q_mv = rearrange( + q_mv, + "... items (hidden_channels num_heads) x -> ... num_heads items hidden_channels x", + num_heads=self.config.num_heads, + hidden_channels=self.config.hidden_mv_channels, + ) + k_mv = rearrange( + k_mv, "... items hidden_channels x -> ... 1 items hidden_channels x" + ) + v_mv = rearrange( + v_mv, "... items hidden_channels x -> ... 1 items hidden_channels x" + ) + + # Same for scalars + if q_s is not None: + q_s = rearrange( + q_s, + "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels", + num_heads=self.config.num_heads, + hidden_channels=self.config.hidden_s_channels, + ) + k_s = rearrange( + k_s, "... items hidden_channels -> ... 1 items hidden_channels" + ) + v_s = rearrange( + v_s, "... items hidden_channels -> ... 1 items hidden_channels" + ) + else: + q_s, k_s, v_s = None, None, None + + return q_mv, k_mv, v_mv, q_s, k_s, v_s diff --git a/weaver/nn/model/gatr/layers/attention/self_attention.py b/weaver/nn/model/gatr/layers/attention/self_attention.py new file mode 100644 index 00000000..ab8c10a5 --- /dev/null +++ b/weaver/nn/model/gatr/layers/attention/self_attention.py @@ -0,0 +1,161 @@ +"""Self-attention layers.""" + +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn + +from gatr.layers.attention.attention import GeometricAttention +from gatr.layers.attention.config import SelfAttentionConfig +from gatr.layers.attention.positional_encoding import ApplyRotaryPositionalEncoding +from gatr.layers.attention.qkv import MultiQueryQKVModule, QKVModule +from gatr.layers.dropout import GradeDropout +from gatr.layers.linear import EquiLinear + + +class SelfAttention(nn.Module): + """Geometric self-attention layer. + + Constructs queries, keys, and values, computes attention, and projects linearly to outputs. + + Parameters + ---------- + config : SelfAttentionConfig + Attention configuration. 
+ """ + + def __init__(self, config: SelfAttentionConfig) -> None: + super().__init__() + + # Store settings + self.config = config + + # QKV computation + self.qkv_module = ( + MultiQueryQKVModule(config) if config.multi_query else QKVModule(config) + ) + + # Output projection + self.out_linear = EquiLinear( + in_mv_channels=config.hidden_mv_channels * config.num_heads, + out_mv_channels=config.out_mv_channels, + in_s_channels=( + None + if config.in_s_channels is None + else config.hidden_s_channels * config.num_heads + ), + out_s_channels=config.out_s_channels, + initialization=config.output_init, + ) + + # Optional positional encoding + self.pos_encoding: nn.Module + if config.pos_encoding: + self.pos_encoding = ApplyRotaryPositionalEncoding( + config.hidden_s_channels, item_dim=-2, base=config.pos_encoding_base + ) + else: + self.pos_encoding = nn.Identity() + + # Attention + self.attention = GeometricAttention(config) + + # Dropout + self.dropout: Optional[nn.Module] + if config.dropout_prob is not None: + self.dropout = GradeDropout(config.dropout_prob) + else: + self.dropout = None + + # HeadScaleMHA + self.use_head_scale = config.head_scale + if self.use_head_scale: + self.head_scale = nn.Parameter(torch.ones(config.num_heads)) + + def forward( + self, + multivectors: torch.Tensor, + additional_qk_features_mv: Optional[torch.Tensor] = None, + scalars: Optional[torch.Tensor] = None, + additional_qk_features_s: Optional[torch.Tensor] = None, + attention_mask=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes forward pass on inputs with shape `(..., items, channels, 16)`. + + The result is the following: + + ``` + # For each head + queries = linear_channels(inputs) + keys = linear_channels(inputs) + values = linear_channels(inputs) + hidden = attention_items(queries, keys, values, biases=biases) + head_output = linear_channels(hidden) + + # Combine results + output = concatenate_heads head_output + ``` + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., num_items, channels_in, 16) + Input multivectors. + additional_qk_features_mv : None or torch.Tensor with shape + (..., num_items, add_qk_mv_channels, 16) + Additional Q/K features, multivector part. + scalars : None or torch.Tensor with shape (..., num_items, num_items, in_scalars) + Optional input scalars + additional_qk_features_s : None or torch.Tensor with shape + (..., num_items, add_qk_mv_channels, 16) + Additional Q/K features, scalar part. + scalars : None or torch.Tensor with shape (..., num_items, num_items, in_scalars) + Optional input scalars + attention_mask: None or torch.Tensor with shape (..., num_items, num_items) + Optional attention mask + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., num_items, channels_out, 16) + Output multivectors. + output_scalars : torch.Tensor with shape (..., num_items, channels_out, out_scalars) + Output scalars, if scalars are provided. Otherwise None. 
+ """ + # Compute Q, K, V + q_mv, k_mv, v_mv, q_s, k_s, v_s = self.qkv_module( + multivectors, scalars, additional_qk_features_mv, additional_qk_features_s + ) + + # Rotary positional encoding + q_s = self.pos_encoding(q_s) + k_s = self.pos_encoding(k_s) + + # Attention layer + h_mv, h_s = self.attention( + q_mv, k_mv, v_mv, q_s, k_s, v_s, attention_mask=attention_mask + ) + if self.use_head_scale: + h_mv = h_mv * self.head_scale.view( + *[1] * len(h_mv.shape[:-5]), len(self.head_scale), 1, 1, 1 + ) + h_s = h_s * self.head_scale.view( + *[1] * len(h_s.shape[:-4]), len(self.head_scale), 1, 1 + ) + + h_mv = rearrange( + h_mv, + "... n_heads n_items hidden_channels x -> ... n_items (n_heads hidden_channels) x", + ) + h_s = rearrange( + h_s, + "... n_heads n_items hidden_channels -> ... n_items (n_heads hidden_channels)", + ) + + # Transform linearly one more time + outputs_mv, outputs_s = self.out_linear(h_mv, scalars=h_s) + + # Dropout + if self.dropout is not None: + outputs_mv, outputs_s = self.dropout(outputs_mv, outputs_s) + + return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/conditional_gatr_block.py b/weaver/nn/model/gatr/layers/conditional_gatr_block.py new file mode 100644 index 00000000..073b8bba --- /dev/null +++ b/weaver/nn/model/gatr/layers/conditional_gatr_block.py @@ -0,0 +1,191 @@ +from dataclasses import replace +from typing import Optional, Tuple + +import torch +from torch import nn + +from gatr.layers import ( + SelfAttention, + CrossAttention, + SelfAttentionConfig, + CrossAttentionConfig, +) +from gatr.layers.layer_norm import EquiLayerNorm +from gatr.layers.mlp.config import MLPConfig +from gatr.layers.mlp.mlp import GeoMLP + + +class ConditionalGATrBlock(nn.Module): + """Equivariant transformer decoder block for L-GATr. + + Inputs are first processed by a block consisting of LayerNorm, multi-head geometric + self-attention, and a residual connection. Then the conditions are included with + cross-attention using the same overhead as in the self-attention part. + Then the data is processed by a block consisting of + another LayerNorm, an item-wise two-layer geometric MLP with GeLU activations, and another + residual connection. 
+ + Parameters + ---------- + mv_channels : int + Number of input and output multivector channels + s_channels: int + Number of input and output scalar channels + condition_mv_channels: int + Number of condition multivector channels + condition_s_channels: int + Number of condition scalar channels + attention: SelfAttentionConfig + Attention configuration + crossattention: CrossAttentionConfig + Cross-attention configuration + mlp: MLPConfig + MLP configuration + dropout_prob : float or None + Dropout probability + double_layernorm : bool + Whether to use double layer normalization + """ + + def __init__( + self, + mv_channels: int, + s_channels: int, + condition_mv_channels: int, + condition_s_channels: int, + attention: SelfAttentionConfig, + crossattention: CrossAttentionConfig, + mlp: MLPConfig, + dropout_prob: Optional[float] = None, + double_layernorm: bool = False, + ) -> None: + super().__init__() + + # Normalization layer (stateless, so we can use the same layer for both normalization + # instances) + self.norm = EquiLayerNorm() + self.double_layernorm = double_layernorm + + # Self-attention layer + attention = replace( + attention, + in_mv_channels=mv_channels, + out_mv_channels=mv_channels, + in_s_channels=s_channels, + out_s_channels=s_channels, + output_init="small", + dropout_prob=dropout_prob, + ) + self.attention = SelfAttention(attention) + + # Cross-attention layer + crossattention = replace( + crossattention, + in_q_mv_channels=mv_channels, + in_q_s_channels=s_channels, + in_kv_mv_channels=condition_mv_channels, + in_kv_s_channels=condition_s_channels, + out_mv_channels=mv_channels, + out_s_channels=s_channels, + output_init="small", + dropout_prob=dropout_prob, + ) + self.crossattention = CrossAttention(crossattention) + + # MLP block + mlp = replace( + mlp, + mv_channels=(mv_channels, 2 * mv_channels, mv_channels), + s_channels=(s_channels, 2 * s_channels, s_channels), + dropout_prob=dropout_prob, + ) + self.mlp = GeoMLP(mlp) + + def forward( + self, + multivectors: torch.Tensor, + multivectors_condition: torch.Tensor, + scalars: torch.Tensor = None, + scalars_condition: torch.Tensor = None, + attention_mask=None, + crossattention_mask=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the transformer decoder block. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., items, channels, 16) + Input multivectors. + scalars : torch.Tensor with shape (..., s_channels) + Input scalars. + multivectors_condition : torch.Tensor with shape (..., items, channels, 16) + Input condition multivectors. + scalars_condition : torch.Tensor with shape (..., s_channels) + Input condition scalars. + attention_mask: None or torch.Tensor or AttentionBias + Optional attention mask. + crossattention_mask: None or torch.Tensor or AttentionBias + Optional attention mask for the condition. + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., items, channels, 16). 
+ Output multivectors + output_scalars : torch.Tensor with shape (..., s_channels) + Output scalars + """ + + # Self-attention block: pre layer norm + h_mv, h_s = self.norm(multivectors, scalars=scalars) + + # Self-attention block: self attention + h_mv, h_s = self.attention( + h_mv, + scalars=h_s, + attention_mask=attention_mask, + ) + + # Self-attention block: post layer norm + if self.double_layernorm: + h_mv, h_s = self.norm(h_mv, scalars=h_s) + + # Self-attention block: skip connection + multivectors = multivectors + h_mv + scalars = scalars + h_s + + # Cross-attention block: pre layer norm + h_mv, h_s = self.norm(multivectors, scalars=scalars) + c_mv, c_s = self.norm(multivectors_condition, scalars=scalars_condition) + + # Cross-attention block: cross attention + h_mv, h_s = self.crossattention( + multivectors_q=h_mv, + multivectors_kv=c_mv, + scalars_q=h_s, + scalars_kv=c_s, + attention_mask=crossattention_mask, + ) + + # Cross-attention block: post layer norm + if self.double_layernorm: + h_mv, h_s = self.norm(h_mv, scalars=h_s) + + # Cross-attention block: skip connection + outputs_mv = multivectors + h_mv + outputs_s = scalars + h_s + + # MLP block: pre layer norm + h_mv, h_s = self.norm(outputs_mv, scalars=outputs_s) + + # MLP block: MLP + h_mv, h_s = self.mlp(h_mv, scalars=h_s) + + # MLP block: post layer norm + if self.double_layernorm: + h_mv, h_s = self.norm(h_mv, scalars=h_s) + + # MLP block: skip connection + outputs_mv = outputs_mv + h_mv + outputs_s = outputs_s + h_s + + return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/dropout.py b/weaver/nn/model/gatr/layers/dropout.py new file mode 100644 index 00000000..15180a66 --- /dev/null +++ b/weaver/nn/model/gatr/layers/dropout.py @@ -0,0 +1,51 @@ +"""Equivariant dropout layer.""" + +from typing import Tuple + +import torch +from torch import nn + +from gatr.primitives import grade_dropout + + +class GradeDropout(nn.Module): + """Grade dropout for multivectors (and regular dropout for auxiliary scalars). + + Parameters + ---------- + p : float + Dropout probability. + """ + + def __init__(self, p: float = 0.0): + super().__init__() + self._dropout_prob = p + + def forward( + self, multivectors: torch.Tensor, scalars: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass. Applies dropout. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., 16) + Multivector inputs. + scalars : torch.Tensor + Scalar inputs. + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., 16) + Multivector inputs with dropout applied. + output_scalars : torch.Tensor + Scalar inputs with dropout applied. + """ + + out_mv = grade_dropout( + multivectors, p=self._dropout_prob, training=self.training + ) + out_s = torch.nn.functional.dropout( + scalars, p=self._dropout_prob, training=self.training + ) + + return out_mv, out_s diff --git a/weaver/nn/model/gatr/layers/gatr_block.py b/weaver/nn/model/gatr/layers/gatr_block.py new file mode 100644 index 00000000..fe37ca40 --- /dev/null +++ b/weaver/nn/model/gatr/layers/gatr_block.py @@ -0,0 +1,143 @@ +from dataclasses import replace +from typing import Optional, Tuple + +import torch +from torch import nn + +from gatr.layers import SelfAttention, SelfAttentionConfig +from gatr.layers.layer_norm import EquiLayerNorm +from gatr.layers.mlp.config import MLPConfig +from gatr.layers.mlp.mlp import GeoMLP + + +class GATrBlock(nn.Module): + """Equivariant transformer encoder block for L-GATr. + + This is the biggest building block of L-GATr. 
+ + Inputs are first processed by a block consisting of LayerNorm, multi-head geometric + self-attention, and a residual connection. Then the data is processed by a block consisting of + another LayerNorm, an item-wise two-layer geometric MLP with GeLU activations, and another + residual connection. + + Parameters + ---------- + mv_channels : int + Number of input and output multivector channels + s_channels: int + Number of input and output scalar channels + attention: SelfAttentionConfig + Attention configuration + mlp: MLPConfig + MLP configuration + dropout_prob : float or None + Dropout probability + double_layernorm : bool + Whether to use double layer normalization + """ + + def __init__( + self, + mv_channels: int, + s_channels: int, + attention: SelfAttentionConfig, + mlp: MLPConfig, + dropout_prob: Optional[float] = None, + double_layernorm: bool = False, + ) -> None: + super().__init__() + + # Normalization layer (stateless, so we can use the same layer for both normalization + # instances) + self.norm = EquiLayerNorm() + self.double_layernorm = double_layernorm + + # Self-attention layer + attention = replace( + attention, + in_mv_channels=mv_channels, + out_mv_channels=mv_channels, + in_s_channels=s_channels, + out_s_channels=s_channels, + output_init="small", + dropout_prob=dropout_prob, + ) + self.attention = SelfAttention(attention) + + # MLP block + mlp = replace( + mlp, + mv_channels=(mv_channels, 2 * mv_channels, mv_channels), + s_channels=(s_channels, 2 * s_channels, s_channels), + dropout_prob=dropout_prob, + ) + self.mlp = GeoMLP(mlp) + + def forward( + self, + multivectors: torch.Tensor, + scalars: torch.Tensor, + additional_qk_features_mv=None, + additional_qk_features_s=None, + attention_mask=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the transformer encoder block. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., items, channels, 16) + Input multivectors. + scalars : torch.Tensor with shape (..., s_channels) + Input scalars. + additional_qk_features_mv : None or torch.Tensor with shape + (..., num_items, add_qk_mv_channels, 16) + Additional Q/K features, multivector part. + additional_qk_features_s : None or torch.Tensor with shape + (..., num_items, add_qk_mv_channels, 16) + Additional Q/K features, scalar part. + attention_mask: None or torch.Tensor or AttentionBias + Optional attention mask. + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., items, channels, 16). 
+ Output multivectors + output_scalars : torch.Tensor with shape (..., s_channels) + Output scalars + """ + + # Attention block: pre layer norm + h_mv, h_s = self.norm(multivectors, scalars=scalars) + + # Attention block: self attention + h_mv, h_s = self.attention( + h_mv, + scalars=h_s, + additional_qk_features_mv=additional_qk_features_mv, + additional_qk_features_s=additional_qk_features_s, + attention_mask=attention_mask, + ) + + # Attention block: post layer norm + if self.double_layernorm: + h_mv, h_s = self.norm(h_mv, scalars=h_s) + + # Attention block: skip connection + outputs_mv = multivectors + h_mv + outputs_s = scalars + h_s + + # MLP block: pre layer norm + h_mv, h_s = self.norm(outputs_mv, scalars=outputs_s) + + # MLP block: MLP + h_mv, h_s = self.mlp(h_mv, scalars=h_s) + + # MLP block: post layer norm + if self.double_layernorm: + h_mv, h_s = self.norm(h_mv, scalars=h_s) + + # MLP block: skip connection + outputs_mv = outputs_mv + h_mv + outputs_s = outputs_s + h_s + + return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/layer_norm.py b/weaver/nn/model/gatr/layers/layer_norm.py new file mode 100644 index 00000000..358b8300 --- /dev/null +++ b/weaver/nn/model/gatr/layers/layer_norm.py @@ -0,0 +1,70 @@ +"""Equivariant normalization layers.""" + +from typing import Tuple + +import torch +from torch import nn + +from gatr.primitives import equi_layer_norm + + +class EquiLayerNorm(nn.Module): + """Equivariant LayerNorm for multivectors. + + Rescales input such that `mean_channels |inputs|^2 = 1`, where the norm is the GA norm and the + mean goes over the channel dimensions. + + In addition, the layer performs a regular LayerNorm operation on auxiliary scalar inputs. + + Parameters + ---------- + mv_channel_dim : int + Channel dimension index for multivector inputs. Defaults to the second-last entry (last are + the multivector components). + scalar_channel_dim : int + Channel dimension index for scalar inputs. Defaults to the last entry. + epsilon : float + Small numerical factor to avoid instabilities. We use a reasonably large number to balance + issues that arise from some multivector components not contributing to the norm. + """ + + def __init__(self, mv_channel_dim=-2, scalar_channel_dim=-1, epsilon: float = 0.01): + super().__init__() + self.mv_channel_dim = mv_channel_dim + self.epsilon = epsilon + + if scalar_channel_dim != -1: + raise NotImplementedError( + "Currently, only scalar_channel_dim = -1 is implemented, but found" + f" {scalar_channel_dim}" + ) + + def forward( + self, multivectors: torch.Tensor, scalars: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass. Computes equivariant LayerNorm for multivectors. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., 16) + Multivector inputs + scalars : torch.Tensor with shape (..., self.in_channels, self.in_scalars) + Scalar inputs + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., 16) + Normalized multivectors + output_scalars : torch.Tensor with shape (..., self.out_channels, self.in_scalars) + Normalized scalars. 
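A minimal sketch for the `EquiLayerNorm` layer defined here (illustrative shapes; import path assumed as in the upstream `gatr` package):

```python
import torch
from gatr.layers import EquiLayerNorm

norm = EquiLayerNorm()
mv = torch.randn(2, 100, 8, 16)    # (batch, items, multivector channels, 16)
s = torch.randn(2, 100, 16)        # (batch, items, scalar channels)
mv_n, s_n = norm(mv, scalars=s)    # GA-norm rescaling for multivectors, LayerNorm for scalars
assert mv_n.shape == mv.shape and s_n.shape == s.shape
```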
+        """
+
+        outputs_mv = equi_layer_norm(
+            multivectors, channel_dim=self.mv_channel_dim, epsilon=self.epsilon
+        )
+        normalized_shape = scalars.shape[-1:]
+        outputs_s = torch.nn.functional.layer_norm(
+            scalars, normalized_shape=normalized_shape
+        )
+
+        return outputs_mv, outputs_s
diff --git a/weaver/nn/model/gatr/layers/linear.py b/weaver/nn/model/gatr/layers/linear.py
new file mode 100644
index 00000000..6233191b
--- /dev/null
+++ b/weaver/nn/model/gatr/layers/linear.py
@@ -0,0 +1,388 @@
+"""Pin-equivariant linear layers between multivector tensors (torch.nn.Modules)."""
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from gatr.interface import embed_scalar
+from gatr.primitives.linear import equi_linear, USE_FULLY_CONNECTED_SUBGROUP
+
+# Switch to mix pseudoscalar multivector components directly into scalar components.
+# This only makes sense when working with the special orthochronous Lorentz group.
+# Note: This is an efficiency boost, the same action can be achieved with an extra linear layer.
+MIX_MVPSEUDOSCALAR_INTO_SCALAR = True
+NUM_PIN_LINEAR_BASIS_ELEMENTS = 10 if USE_FULLY_CONNECTED_SUBGROUP else 5
+
+
+class EquiLinear(nn.Module):
+    """Pin-equivariant linear layer.
+
+    The forward pass maps multivector inputs with shape (..., in_channels, 16) to multivector
+    outputs with shape (..., out_channels, 16) as
+
+    ```
+    outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] basis_map[b, x, y] inputs[..., i, x]
+    ```
+
+    plus an optional bias term for outputs[..., :, 0] (biases in other multivector components would
+    break equivariance).
+
+    Here basis_map are precomputed (see gatr.primitives.linear) and weights are the
+    learnable weights of this layer.
+
+    If there are auxiliary input scalars, they transform under a linear layer and mix with the
+    scalar components of the multivector data. Note that in this layer (and only here) the
+    auxiliary scalars are optional.
+
+    This layer supports four initialization schemes:
+    - "default": preserves (or, more precisely, slightly reduces) the variance of the data in
+      the forward pass
+    - "small": variance of outputs is approximately one order of magnitude smaller
+      than for "default"
+    - "unit_scalar": outputs will be close to (1, 0, 0, ..., 0)
+    - "almost_unit_scalar": similar to "unit_scalar", but with more stochasticity
+
+    Parameters
+    ----------
+    in_mv_channels : int
+        Input multivector channels
+    out_mv_channels : int
+        Output multivector channels
+    bias : bool
+        Whether a bias term is added to the scalar component of the multivector outputs
+    in_s_channels : int or None
+        Input scalar channels. If None, no scalars are expected nor returned.
+    out_s_channels : int or None
+        Output scalar channels. If None, no scalars are expected nor returned.
+    initialization : {"default", "small", "unit_scalar", "almost_unit_scalar"}
+        Initialization scheme. For "default", initialize with the same philosophy as most
+        networks do: preserve variance (approximately) in the forward pass. For "small",
+        initialize the network such that the variance of the output data is approximately one
+        order of magnitude smaller than that of the input data. For "unit_scalar", initialize
+        the layer such that the output multivectors will be closer to (1, 0, 0, ..., 0).
+        "almost_unit_scalar" is similar, but with more randomness.
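The contraction in the docstring above can be written as a single einsum. The sketch below uses a random stand-in for the precomputed equivariant basis (in the real layer it is loaded from `gatr.primitives.linear`), so only the shapes and the index pattern carry meaning:

```
import torch

out_ch, in_ch, num_basis = 4, 3, 10      # 10 basis maps for the fully connected subgroup, 5 otherwise
weights = torch.randn(out_ch, in_ch, num_basis)
basis = torch.randn(num_basis, 16, 16)   # stand-in for the precomputed Pin-equivariant basis maps
mv_in = torch.randn(8, in_ch, 16)        # (batch, in_channels, 16)

# outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] basis[b, x, y] inputs[..., i, x]
mv_out = torch.einsum("jib,bxy,nix->njy", weights, basis, mv_in)
print(mv_out.shape)                      # torch.Size([8, 4, 16])
```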
+ """ + + def __init__( + self, + in_mv_channels: int, + out_mv_channels: int, + in_s_channels: Optional[int] = None, + out_s_channels: Optional[int] = None, + bias: bool = True, + initialization: str = "default", + ) -> None: + super().__init__() + + # Check inputs + if initialization in ["unit_scalar", "almost_unit_scalar"]: + assert bias, "unit_scalar initialization requires bias" + if in_s_channels is None: + raise NotImplementedError( + "unit_scalar initialization is currently only implemented for scalar inputs" + ) + + self._in_mv_channels = in_mv_channels + + # MV -> MV + self.weight = nn.Parameter( + torch.empty( + (out_mv_channels, in_mv_channels, NUM_PIN_LINEAR_BASIS_ELEMENTS) + ) + ) + + # We only need a separate bias here if that isn't already covered by the linear map from + # scalar inputs + self.bias = ( + nn.Parameter(torch.zeros((out_mv_channels, 1))) + if bias and in_s_channels is None + else None + ) + + # Scalars -> MV scalars + self.s2mvs: Optional[nn.Linear] + mix_factor = 2 if MIX_MVPSEUDOSCALAR_INTO_SCALAR else 1 + if in_s_channels: + self.s2mvs = nn.Linear( + in_s_channels, mix_factor * out_mv_channels, bias=bias + ) + else: + self.s2mvs = None + + # MV scalars -> scalars + if out_s_channels: + self.mvs2s = nn.Linear( + mix_factor * in_mv_channels, out_s_channels, bias=bias + ) + else: + self.mvs2s = None + + # Scalars -> scalars + if in_s_channels is not None and out_s_channels is not None: + self.s2s = nn.Linear( + in_s_channels, out_s_channels, bias=False + ) # Bias would be duplicate + else: + self.s2s = None + + # Initialization + self.reset_parameters(initialization) + + def forward( + self, multivectors: torch.Tensor, scalars: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """Maps input multivectors and scalars using the most general equivariant linear map. + + The result is again multivectors and scalars. + + For multivectors we have: + ``` + outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] basis_map[b, x, y] inputs[..., i, x] + = sum_i linear(inputs[..., i, :], weights[j, i, :]) + ``` + + Here basis_map are precomputed (see gatr.primitives.linear) and weights are the + learnable weights of this layer. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., in_mv_channels, 16) + Input multivectors + scalars : None or torch.Tensor with shape (..., in_s_channels) + Optional input scalars + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., out_mv_channels, 16) + Output multivectors + outputs_s : None or torch.Tensor with shape (..., out_s_channels) + Output scalars, if scalars are provided. Otherwise None. 
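A hypothetical usage sketch with made-up channel sizes; it assumes the package is importable as `gatr`, which is how the internal imports in this patch refer to it:

```
import torch
from gatr.layers.linear import EquiLinear

layer = EquiLinear(in_mv_channels=8, out_mv_channels=16, in_s_channels=4, out_s_channels=6)
mv = torch.randn(32, 100, 8, 16)    # (batch, items, in_mv_channels, 16)
s = torch.randn(32, 100, 4)         # (batch, items, in_s_channels)
out_mv, out_s = layer(mv, scalars=s)
print(out_mv.shape, out_s.shape)    # (32, 100, 16, 16) and (32, 100, 6)
```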
+ """ + + outputs_mv = equi_linear(multivectors, self.weight) # (..., out_channels, 16) + + if self.bias is not None: + bias = embed_scalar(self.bias) + outputs_mv = outputs_mv + bias + + if self.s2mvs is not None and scalars is not None: + if MIX_MVPSEUDOSCALAR_INTO_SCALAR: + outputs_mv[..., [0, -1]] += self.s2mvs(scalars).view( + *outputs_mv.shape[:-2], outputs_mv.shape[-2], 2 + ) + else: + outputs_mv[..., 0] += self.s2mvs(scalars) + + if self.mvs2s is not None: + if MIX_MVPSEUDOSCALAR_INTO_SCALAR: + outputs_s = self.mvs2s(multivectors[..., [0, -1]].flatten(start_dim=-2)) + else: + outputs_s = self.mvs2s(multivectors[..., 0]) + if self.s2s is not None and scalars is not None: + outputs_s = outputs_s + self.s2s(scalars) + else: + outputs_s = None + + return outputs_mv, outputs_s + + def reset_parameters( + self, + initialization: str, + gain: float = 1.0, + additional_factor=1.0 / np.sqrt(3.0), + ) -> None: + """Initializes the weights of the layer. + + Parameters + ---------- + initialization : {"default", "small", "unit_scalar", "almost_unit_scalar"} + Initialization scheme. For "default", initialize with the same philosophy as most + networks do: preserve variance (approximately) in the forward pass. For "small", + initalize the network such that the variance of the output data is approximately one + order of magnitude smaller than that of the input data. For "unit_scalar", initialize + the layer such that the output multivectors will be closer to (1, 0, 0, ..., 0). + "almost_unit_scalar" is similar, but with more randomness. + gain : float + Gain factor for the activations. Should be 1.0 if previous layer has no activation, + sqrt(2) if it has a ReLU activation, and so on. Can be computed with + `torch.nn.init.calculate_gain()`. + additional_factor : float + Empirically, it has been found that slightly *decreasing* the data variance at each + layer gives a better performance. In particular, the PyTorch default initialization uses + an additional factor of 1/sqrt(3) (cancelling the factor of sqrt(3) that naturally + arises when computing the bounds of a uniform initialization). A discussion of this was + (to the best of our knowledge) never published, but see + https://github.com/pytorch/pytorch/issues/57109 and + https://soumith.ch/files/20141213_gplus_nninit_discussion.htm. + """ + + # Prefactors depending on initialization scheme + ( + mv_component_factors, + mv_factor, + mvs_bias_shift, + s_factor, + ) = self._compute_init_factors( + initialization, + gain, + additional_factor, + ) + + # Following He et al, 1502.01852, we aim to preserve the variance in the forward pass. + # A sufficient criterion for this is that the variance of the weights is given by + # `Var[w] = gain^2 / fan`. + # Here `gain^2` is 2 if the previous layer has a ReLU nonlinearity, 1 for the initial layer, + # and some other value in other situations (we may not care about this too much). + # More importantly, `fan` is the number of connections: the number of input elements that + # get summed over to compute each output element. + + # Let us fist consider the multivector outputs. + self._init_multivectors(mv_component_factors, mv_factor, mvs_bias_shift) + + # Then let's consider the maps to scalars. + self._init_scalars(s_factor) + + @staticmethod + def _compute_init_factors( + initialization, + gain, + additional_factor, + ): + """Computes prefactors for the initialization. + + See self.reset_parameters(). 
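The bounds computed below follow from the two facts quoted in the comments of `reset_parameters`: the target `Var[w] = gain^2 / fan` and `Var[Uniform(-a, a)] = a^2 / 3`, which together give `a = gain * sqrt(3 / fan)`. A quick numerical check of that identity, with arbitrarily chosen values:

```
import numpy as np

gain, fan_in = np.sqrt(2.0), 64     # e.g. ReLU gain and 64 incoming channels
a = gain * np.sqrt(3.0 / fan_in)    # bound of the uniform distribution

w = np.random.uniform(-a, a, size=1_000_000)
print(w.var(), gain**2 / fan_in)    # both approximately 0.03125
```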
+ """ + + if initialization not in { + "default", + "small", + "unit_scalar", + "almost_unit_scalar", + }: + raise ValueError(f"Unknown initialization scheme {initialization}") + + if initialization == "default": + mv_factor = gain * additional_factor * np.sqrt(3) + s_factor = gain * additional_factor * np.sqrt(3) + mvs_bias_shift = 0.0 + elif initialization == "small": + # Change scale by a factor of 0.1 in this layer + mv_factor = 0.1 * gain * additional_factor * np.sqrt(3) + s_factor = 0.1 * gain * additional_factor * np.sqrt(3) + mvs_bias_shift = 0.0 + elif initialization == "unit_scalar": + # Change scale by a factor of 0.1 for MV outputs, and initialize bias around 1 + mv_factor = 0.1 * gain * additional_factor * np.sqrt(3) + s_factor = gain * additional_factor * np.sqrt(3) + mvs_bias_shift = 1.0 + elif initialization == "almost_unit_scalar": + # Change scale by a factor of 0.5 for MV outputs, and initialize bias around 1 + mv_factor = 0.5 * gain * additional_factor * np.sqrt(3) + s_factor = gain * additional_factor * np.sqrt(3) + mvs_bias_shift = 1.0 + else: + raise ValueError( + f"Unknown initialization scheme {initialization}, expected" + ' "default", "small", "unit_scalar" or "almost_unit_scalar".' + ) + + # Individual factors for each multivector component (could be tuned for performance) + mv_component_factors = torch.ones(NUM_PIN_LINEAR_BASIS_ELEMENTS) + return mv_component_factors, mv_factor, mvs_bias_shift, s_factor + + def _init_multivectors(self, mv_component_factors, mv_factor, mvs_bias_shift): + """Weight initialization for maps to multivector outputs.""" + + # We have + # `outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] basis_map[b, x, y] inputs[..., i, x]` + # The basis maps are more or less grade projections, summing over all basis elements + # corresponds to (almost) an identity map in the GA space. The sum over `b` and `x` thus + # does not contribute to `fan` substantially. (We may add a small ad-hoc factor later to + # make up for this approximation.) However, there is still the sum over incoming channels, + # and thus `fan ~ mv_in_channels`. Assuming (for now) that the previous layer contained a + # ReLU activation, we finally have the condition `Var[w] = 2 / mv_in_channels`. + # Since the variance of a uniform distribution between -a and a is given by + # `Var[Uniform(-a, a)] = a^2/3`, we should set `a = gain * sqrt(3 / mv_in_channels)`. + # In theory (see docstring). + fan_in = self._in_mv_channels + bound = mv_factor / np.sqrt(fan_in) + for i, factor in enumerate(mv_component_factors): + nn.init.uniform_(self.weight[..., i], a=-factor * bound, b=factor * bound) + + # Now let's focus on the scalar components of the multivector outputs. + # If there are only multivector inputs, all is good. But if scalar inputs contribute them as + # well, they contribute to the output variance as well. + # In this case, we initialize such that the multivector inputs and the scalar inputs each + # contribute half to the output variance. + # We can achieve this by inspecting the basis maps and seeing that only basis element 0 + # contributes to the scalar output. Thus, we can reduce the variance of the correponding + # weights to give a variance of 0.5, not 1. 
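+        # (If two independent contributions each carry variance 0.5, their sum has variance 1,
+        # which is why the uniform bounds below pick up an extra factor of 1/sqrt(2).)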
+ if self.s2mvs is not None: + # contribution from scalar -> mv scalar + bound = mv_component_factors[0] * mv_factor / np.sqrt(fan_in) / np.sqrt(2) + nn.init.uniform_(self.weight[..., [0]], a=-bound, b=bound) + if MIX_MVPSEUDOSCALAR_INTO_SCALAR: + # contribution from scalar -> mv pseudoscalar + bound = ( + mv_component_factors[-1] * mv_factor / np.sqrt(fan_in) / np.sqrt(2) + ) + nn.init.uniform_(self.weight[..., [-1]], a=-bound, b=bound) + + # The same holds for the scalar-to-MV map, where we also just want a variance of 0.5. + # Note: This is not properly extended to scalar and pseudoscalar outputs yet + if self.s2mvs is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self.s2mvs.weight + ) # pylint:disable=protected-access + fan_in = max( + fan_in, 1 + ) # Since in theory we could have 0-channel scalar "data" + bound = mv_component_factors[0] * mv_factor / np.sqrt(fan_in) / np.sqrt(2) + nn.init.uniform_(self.s2mvs.weight, a=-bound, b=bound) + + # Bias needs to be adapted, as the overall fan in is different (need to account for MV + # and s inputs) and we may need to account for the unit_scalar initialization scheme + if self.s2mvs.bias is not None: + fan_in = ( + nn.init._calculate_fan_in_and_fan_out(self.s2mvs.weight)[0] + + self._in_mv_channels + ) + bound = mv_component_factors[0] / np.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_( + self.s2mvs.bias, mvs_bias_shift - bound, mvs_bias_shift + bound + ) + + def _init_scalars(self, s_factor): + """Weight initialization for maps to multivector outputs.""" + + # If both exist, we need to account for overcounting again, and assign each a target a + # variance of 0.5. + # Note: This is not properly extended to scalar and pseudoscalar outputs yet + models = [] + if self.s2s: + models.append(self.s2s) + if self.mvs2s: + models.append(self.mvs2s) + for model in models: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + model.weight + ) # pylint:disable=protected-access + fan_in = max( + fan_in, 1 + ) # Since in theory we could have 0-channel scalar "data" + bound = s_factor / np.sqrt(fan_in) / np.sqrt(len(models)) + nn.init.uniform_(model.weight, a=-bound, b=bound) + # Bias needs to be adapted, as the overall fan in is different (need to account for MV and + # s inputs) + if self.mvs2s and self.mvs2s.bias is not None: + fan_in = nn.init._calculate_fan_in_and_fan_out(self.mvs2s.weight)[ + 0 + ] # pylint:disable=protected-access + if self.s2s: + fan_in += nn.init._calculate_fan_in_and_fan_out(self.s2s.weight)[ + 0 + ] # pylint:disable=protected-access + bound = s_factor / np.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.mvs2s.bias, -bound, bound) diff --git a/weaver/nn/model/gatr/layers/mlp/__init__.py b/weaver/nn/model/gatr/layers/mlp/__init__.py new file mode 100644 index 00000000..2423e991 --- /dev/null +++ b/weaver/nn/model/gatr/layers/mlp/__init__.py @@ -0,0 +1,4 @@ +from .config import MLPConfig +from .geometric_bilinears import GeometricBilinear +from .mlp import GeoMLP +from .nonlinearities import ScalarGatedNonlinearity diff --git a/weaver/nn/model/gatr/layers/mlp/config.py b/weaver/nn/model/gatr/layers/mlp/config.py new file mode 100644 index 00000000..959ab2f0 --- /dev/null +++ b/weaver/nn/model/gatr/layers/mlp/config.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, List, Mapping, Optional + + +@dataclass +class MLPConfig: + """Geometric MLP configuration. 
+ + Parameters + ---------- + mv_channels : iterable of int + Number of multivector channels at each layer, from input to output + s_channels : None or iterable of int + If not None, sets the number of scalar channels at each layer, from input to output. Length + needs to match mv_channels + activation : {"relu", "sigmoid", "gelu"} + Which (gated) activation function to use + dropout_prob : float or None + Dropout probability + """ + + mv_channels: Optional[List[int]] = None + s_channels: Optional[List[int]] = None + activation: str = "gelu" + dropout_prob: Optional[float] = None + + def __post_init__(self): + """Type checking / conversion.""" + if isinstance(self.dropout_prob, str) and self.dropout_prob.lower() in [ + "null", + "none", + ]: + self.dropout_prob = None + + @classmethod + def cast(cls, config: Any) -> MLPConfig: + """Casts an object as MLPConfig.""" + if isinstance(config, MLPConfig): + return config + if isinstance(config, Mapping): + return cls(**config) + raise ValueError(f"Can not cast {config} to {cls}") diff --git a/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py b/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py new file mode 100644 index 00000000..061f53f2 --- /dev/null +++ b/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py @@ -0,0 +1,105 @@ +"""Pin-equivariant geometric product layer between multivector tensors (torch.nn.Modules).""" + +from typing import Optional, Tuple + +import torch +from torch import nn + +from gatr.layers.linear import EquiLinear +from gatr.primitives import geometric_product +from gatr.layers.layer_norm import EquiLayerNorm + +# switch to set bivector components to zero, +# after they are generated by the geometric product +ZERO_BIVECTOR = False + + +class GeometricBilinear(nn.Module): + """Geometric bilinear layer. + + Pin-equivariant map between multivector tensors that constructs new geometric features via + geometric products. + + Parameters + ---------- + in_mv_channels : int + Input multivector channels of `x` + out_mv_channels : int + Output multivector channels + hidden_mv_channels : int or None + Hidden MV channels. If None, uses out_mv_channels. + in_s_channels : int or None + Input scalar channels of `x`. If None, no scalars are expected nor returned. + out_s_channels : int or None + Output scalar channels. If None, no scalars are expected nor returned. + """ + + def __init__( + self, + in_mv_channels: int, + out_mv_channels: int, + hidden_mv_channels: Optional[int] = None, + in_s_channels: Optional[int] = None, + out_s_channels: Optional[int] = None, + ) -> None: + super().__init__() + + # Default options + if hidden_mv_channels is None: + hidden_mv_channels = out_mv_channels + + # Linear projections for GP + self.linear_left = EquiLinear( + in_mv_channels, + hidden_mv_channels, + in_s_channels=in_s_channels, + out_s_channels=None, + ) + self.linear_right = EquiLinear( + in_mv_channels, + hidden_mv_channels, + in_s_channels=in_s_channels, + out_s_channels=None, + initialization="almost_unit_scalar", + ) + + # Output linear projection + self.linear_out = EquiLinear( + hidden_mv_channels, out_mv_channels, in_s_channels, out_s_channels + ) + self.norm = EquiLayerNorm() + + def forward( + self, + multivectors: torch.Tensor, + scalars: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass. 
+ + Parameters + ---------- + multivectors : torch.Tensor with shape (..., in_mv_channels, 16) + Input multivectors + scalars : torch.Tensor with shape (..., in_s_channels) + Input scalars + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., self.out_mv_channels, 16) + Output multivectors + output_s : None or torch.Tensor with shape (..., out_s_channels) + Output scalars. + """ + + # GP + left, _ = self.linear_left(multivectors, scalars=scalars) + right, _ = self.linear_right(multivectors, scalars=scalars) + gp_outputs = geometric_product(left, right) + if ZERO_BIVECTOR: + gp_outputs[..., 5:11] = 0.0 + + # Output linear + outputs_mv, outputs_s = self.linear_out(gp_outputs, scalars=scalars) + + outputs_mv, outputs_s = self.norm(outputs_mv, outputs_s) + return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/mlp/mlp.py b/weaver/nn/model/gatr/layers/mlp/mlp.py new file mode 100644 index 00000000..fbfed722 --- /dev/null +++ b/weaver/nn/model/gatr/layers/mlp/mlp.py @@ -0,0 +1,106 @@ +"""Factory functions for simple MLPs for multivector data.""" + +from typing import List, Tuple, Union + +import torch +from torch import nn + +from gatr.layers.dropout import GradeDropout +from gatr.layers.linear import EquiLinear +from gatr.layers.mlp.config import MLPConfig +from gatr.layers.mlp.geometric_bilinears import GeometricBilinear +from gatr.layers.mlp.nonlinearities import ScalarGatedNonlinearity + +USE_GEOMETRIC_PRODUCT = True + + +class GeoMLP(nn.Module): + """Geometric MLP. + + This is a core component of GATr's transformer blocks. It is similar to a regular MLP, except + that it uses geometric bilinears (the geometric product) in place of the first linear layer. + + Assumes input has shape `(..., channels[0], 16)`, output has shape `(..., channels[-1], 16)`, + will create hidden layers with shape `(..., channel, 16)` for each additional entry in + `channels`. + + Parameters + ---------- + config: MLPConfig + Configuration object + """ + + def __init__( + self, + config: MLPConfig, + ) -> None: + super().__init__() + + # Store settings + self.config = config + + assert config.mv_channels is not None + s_channels = ( + [None for _ in config.mv_channels] + if config.s_channels is None + else config.s_channels + ) + + layers: List[nn.Module] = [] + + if len(config.mv_channels) >= 2: + kwargs = dict( + in_mv_channels=config.mv_channels[0], + out_mv_channels=config.mv_channels[1], + in_s_channels=s_channels[0], + out_s_channels=s_channels[1], + ) + if USE_GEOMETRIC_PRODUCT: + layers.append(GeometricBilinear(**kwargs)) + else: + layers.append(ScalarGatedNonlinearity(config.activation)) + layers.append(EquiLinear(**kwargs)) + if config.dropout_prob is not None: + layers.append(GradeDropout(config.dropout_prob)) + + for in_, out, in_s, out_s in zip( + config.mv_channels[1:-1], + config.mv_channels[2:], + s_channels[1:-1], + s_channels[2:], + ): + layers.append(ScalarGatedNonlinearity(config.activation)) + layers.append( + EquiLinear(in_, out, in_s_channels=in_s, out_s_channels=out_s) + ) + if config.dropout_prob is not None: + layers.append(GradeDropout(config.dropout_prob)) + + self.layers = nn.ModuleList(layers) + + def forward( + self, multivectors: torch.Tensor, scalars: torch.Tensor + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """Forward pass. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., in_mv_channels, 16) + Input multivectors. + scalars : None or torch.Tensor with shape (..., in_s_channels) + Optional input scalars. 
+ + Returns + ------- + outputs_mv : torch.Tensor with shape (..., out_mv_channels, 16) + Output multivectors. + outputs_s : None or torch.Tensor with shape (..., out_s_channels) + Output scalars, if scalars are provided. Otherwise None. + """ + + mv, s = multivectors, scalars + + for i, layer in enumerate(self.layers): + mv, s = layer(mv, scalars=s) + + return mv, s diff --git a/weaver/nn/model/gatr/layers/mlp/nonlinearities.py b/weaver/nn/model/gatr/layers/mlp/nonlinearities.py new file mode 100644 index 00000000..94c871e4 --- /dev/null +++ b/weaver/nn/model/gatr/layers/mlp/nonlinearities.py @@ -0,0 +1,65 @@ +from typing import Tuple + +import torch +from torch import nn + +from gatr.primitives.nonlinearities import gated_gelu, gated_relu, gated_sigmoid + + +class ScalarGatedNonlinearity(nn.Module): + """Gated nonlinearity, where the gate is simply given by the scalar component of the input. + + Given multivector input x, computes f(x_0) * x, where f can either be ReLU, sigmoid, or GeLU. + + Auxiliary scalar inputs are simply processed with ReLU, sigmoid, or GeLU, without gating. + + Parameters + ---------- + nonlinearity : {"relu", "sigmoid", "gelu"} + Non-linearity type + """ + + def __init__(self, nonlinearity: str = "relu", **kwargs) -> None: + super().__init__() + + gated_fn_dict = dict(relu=gated_relu, gelu=gated_gelu, sigmoid=gated_sigmoid) + scalar_fn_dict = dict( + relu=nn.functional.relu, + gelu=nn.functional.gelu, + sigmoid=nn.functional.sigmoid, + ) + try: + self.gated_nonlinearity = gated_fn_dict[nonlinearity] + self.scalar_nonlinearity = scalar_fn_dict[nonlinearity] + except KeyError as exc: + raise ValueError( + f"Unknown nonlinearity {nonlinearity} for options {list(gated_fn_dict.keys())}" + ) from exc + + def forward( + self, multivectors: torch.Tensor, scalars: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes f(x_0) * x for multivector x, where f is GELU, ReLU, or sigmoid. + + f is chosen depending on self.nonlinearity. 
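For the "gelu" choice, the gating amounts to the following short sketch (made-up shapes; the actual primitive lives in `gatr.primitives.nonlinearities`):

```
import torch

mv = torch.randn(2, 5, 3, 16)                  # (batch, items, channels, 16)
gates = mv[..., [0]]                           # scalar component of each multivector channel
out_mv = torch.nn.functional.gelu(gates) * mv  # gated GeLU: f(x_0) * x
```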
+ + Parameters + ---------- + multivectors : torch.Tensor with shape (..., self.in_channels, 16) + Input multivectors + scalars : None or torch.Tensor with shape (..., self.in_channels, self.in_scalars) + Input scalars + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., self.out_channels, 16) + Output multivectors + output_scalars : torch.Tensor with shape (..., self.out_channels, self.in_scalars) + Output scalars + """ + + gates = multivectors[..., [0]] + outputs_mv = self.gated_nonlinearity(multivectors, gates=gates) + outputs_s = self.scalar_nonlinearity(scalars) + + return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/nets/__init__.py b/weaver/nn/model/gatr/nets/__init__.py new file mode 100644 index 00000000..2489425b --- /dev/null +++ b/weaver/nn/model/gatr/nets/__init__.py @@ -0,0 +1,4 @@ +from .axial_gatr import AxialGATr +from .gatr import GATr +from .gap import GAP +from .conditional_gatr import ConditionalGATr diff --git a/weaver/nn/model/gatr/nets/axial_gatr.py b/weaver/nn/model/gatr/nets/axial_gatr.py new file mode 100644 index 00000000..dc42f989 --- /dev/null +++ b/weaver/nn/model/gatr/nets/axial_gatr.py @@ -0,0 +1,214 @@ +from dataclasses import replace +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import nn +from torch.utils.checkpoint import checkpoint + +from gatr.layers.attention.config import SelfAttentionConfig +from gatr.layers.gatr_block import GATrBlock +from gatr.layers.linear import EquiLinear +from gatr.layers.mlp.config import MLPConfig + +# Default rearrange patterns +_MV_REARRANGE_PATTERN = "... i j c x -> ... j i c x" +_S_REARRANGE_PATTERN = "... i j c -> ... j i c" + + +class AxialGATr(nn.Module): # pylint: disable=duplicate-code + """Axial L-GATr network for two token dimensions. + + It combines `num_blocks` L-GATr transformer blocks, each consisting of geometric self-attention + layers, a geometric MLP, residual connections, and normalization layers. In addition, there + are initial and final equivariant linear layers. + + Assumes input data with shape `(..., num_items_1, num_items_2, num_channels, 16)`. + + The first, third, fifth, ... block computes attention over the `items_2` axis. The other blocks + compute attention over the `items_1` axis. Positional encoding can be specified separately for + both axes. + + Parameters + ---------- + in_mv_channels : int + Number of input multivector channels. + out_mv_channels : int + Number of output multivector channels. + hidden_mv_channels : int + Number of hidden multivector channels. + in_s_channels : None or int + If not None, sets the number of scalar input channels. + out_s_channels : None or int + If not None, sets the number of scalar output channels. + hidden_s_channels : None or int + If not None, sets the number of scalar hidden channels. + attention: Dict + Data for SelfAttentionConfig + mlp: Dict + Data for MLPConfig + num_blocks : int + Number of transformer blocks. + pos_encodings : tuple of bool + Whether to apply rotary positional embeddings along the item dimensions to the scalar keys + and queries. The first element in the tuple determines whether positional embeddings + are applied to the first item dimension, the second element the same for the second item + dimension. 
+ collapse_dims_for_odd_blocks : bool + Whether the batch dimensions will be collapsed in odd blocks (to support xformers block + attention) + """ + + def __init__( + self, + in_mv_channels: int, + out_mv_channels: int, + hidden_mv_channels: int, + in_s_channels: Optional[int], + out_s_channels: Optional[int], + hidden_s_channels: Optional[int], + attention: SelfAttentionConfig, + mlp: MLPConfig, + num_blocks: int = 20, + checkpoint_blocks: bool = False, + pos_encodings: Tuple[bool, bool] = (False, False), + collapse_dims_for_odd_blocks=False, + **kwargs, + ) -> None: + super().__init__() + self.linear_in = EquiLinear( + in_mv_channels, + hidden_mv_channels, + in_s_channels=in_s_channels, + out_s_channels=hidden_s_channels, + ) + attention = SelfAttentionConfig.cast(attention) + mlp = MLPConfig.cast(mlp) + self.blocks = nn.ModuleList( + [ + GATrBlock( + mv_channels=hidden_mv_channels, + s_channels=hidden_s_channels, + attention=replace( + attention, + pos_encoding=pos_encodings[(block + 1) % 2], + ), + mlp=mlp, + ) + for block in range(num_blocks) + ] + ) + self.linear_out = EquiLinear( + hidden_mv_channels, + out_mv_channels, + in_s_channels=hidden_s_channels, + out_s_channels=out_s_channels, + ) + self._checkpoint_blocks = checkpoint_blocks + self._collapse_dims_for_odd_blocks = collapse_dims_for_odd_blocks + + def forward( + self, + multivectors: torch.Tensor, + scalars: Optional[torch.Tensor] = None, + attention_mask: Optional[Tuple] = None, + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """Forward pass of the network. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., num_items_1, num_items_2, in_mv_channels, 16) + Input multivectors. + scalars : None or torch.Tensor with shape (..., num_items_1, num_items_2, in_s_channels) + Optional input scalars. + attention_mask : None or tuple of torch.Tensor + Optional attention masks + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., num_items_1, num_items_2, out_mv_channels, 16) + Output multivectors. + outputs_s : None or torch.Tensor with shape + (..., num_items_1, num_items_2, out_mv_channels, 16) + Output scalars, if scalars are provided. Otherwise None. + """ + + # Pass through the blocks + h_mv, h_s = self.linear_in(multivectors, scalars=scalars) + + for i, block in enumerate(self.blocks): + # For first, third, ... block, we want to perform attention over the first token + # dimension. We implement this by transposing the two item dimensions. 
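+            # (In the code below, blocks with odd index i = 1, 3, ... swap the two item axes
+            # before and after the block, so successive blocks attend over alternating item
+            # dimensions.)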
+ if i % 2 == 1: + h_mv, h_s, input_batch_dims = self._reshape_data_before_odd_blocks( + h_mv, h_s + ) + else: + input_batch_dims = None + + # Attention masks will also be different + if attention_mask is None: + this_attention_mask = None + else: + this_attention_mask = attention_mask[(i + 1) % 2] + + if self._checkpoint_blocks: + h_mv, h_s = checkpoint( + block, + h_mv, + use_reentrant=False, + scalars=h_s, + attention_mask=this_attention_mask, + ) + else: + h_mv, h_s = block( + h_mv, + scalars=h_s, + attention_mask=this_attention_mask, + ) + + # Transposing back to standard axis order + if i % 2 == 1: + h_mv, h_s = self._reshape_data_after_odd_blocks( + h_mv, h_s, input_batch_dims + ) + + outputs_mv, outputs_s = self.linear_out(h_mv, scalars=h_s) + + return outputs_mv, outputs_s + + def _reshape_data_before_odd_blocks(self, multivector, scalar): + # Prepare reshuffling between axial layers + input_batch_dims = multivector.shape[:2] + assert scalar.shape[:2] == input_batch_dims + + multivector = rearrange( + multivector, _MV_REARRANGE_PATTERN + ) # (axis2, axis1, ...) + scalar = rearrange(scalar, _S_REARRANGE_PATTERN) # (axis2, axis1, ...) + + if self._collapse_dims_for_odd_blocks: + multivector = multivector.reshape( + -1, *multivector.shape[2:] + ) # (axis2 * axis1, ...) + scalar = scalar.reshape(-1, *scalar.shape[2:]) # (axis2 * axis1, ...) + + return multivector, scalar, input_batch_dims + + def _reshape_data_after_odd_blocks(self, multivector, scalar, input_batch_dims): + # Transposing back to standard axis order + + if self._collapse_dims_for_odd_blocks: + multivector = multivector.reshape( + *input_batch_dims, *multivector.shape[1:] + ) # (axis2, axis1, ...) + scalar = scalar.reshape( + *input_batch_dims, *scalar.shape[1:] + ) # (axis2, axis1, ...) + + multivector = rearrange( + multivector, _MV_REARRANGE_PATTERN + ) # (axis1, axis2, ...) + scalar = rearrange(scalar, _S_REARRANGE_PATTERN) # (axis1, axis2, ...) + + return multivector, scalar diff --git a/weaver/nn/model/gatr/nets/conditional_gatr.py b/weaver/nn/model/gatr/nets/conditional_gatr.py new file mode 100644 index 00000000..e23d6617 --- /dev/null +++ b/weaver/nn/model/gatr/nets/conditional_gatr.py @@ -0,0 +1,227 @@ +"""Equivariant transformer for multivector data.""" + +from dataclasses import replace +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from gatr.layers import ( + CrossAttentionConfig, + SelfAttentionConfig, + GATrBlock, + ConditionalGATrBlock, + EquiLinear, +) +from gatr.layers.mlp.config import MLPConfig + + +class ConditionalGATr(nn.Module): + """L-GATr network for a data with a single token dimension. + + It combines `num_blocks` L-GATr transformer blocks, each consisting of geometric self-attention + layers, a geometric MLP, residual connections, and normalization layers. In addition, there + are initial and final equivariant linear layers. + + Assumes input has shape `(..., items, in_channels, 16)`, output has shape + `(..., items, out_channels, 16)`, will create hidden representations with shape + `(..., items, hidden_channels, 16)`. + + Parameters + ---------- + in_mv_channels : int + Number of input multivector channels. + condition_mv_channels : int + Number of condition multivector channels. + out_mv_channels : int + Number of output multivector channels. + hidden_mv_channels : int + Number of hidden multivector channels. + in_s_channels : None or int + If not None, sets the number of scalar input channels. 
+ condition_s_channels : None or int + If not None, sets the number of scalar condition channels. + out_s_channels : None or int + If not None, sets the number of scalar output channels. + hidden_s_channels : None or int + If not None, sets the number of scalar hidden channels. + attention: Dict + Data for SelfAttentionConfig + crossattention: Dict + Data for CrossAttentionConfig + attention_condition: Dict + Data for SelfAttentionConfig + mlp: Dict + Data for MLPConfig + num_blocks : int + Number of transformer blocks. + dropout_prob : float or None + Dropout probability + double_layernorm : bool + Whether to use double layer normalization + """ + + def __init__( + self, + in_mv_channels: int, + condition_mv_channels: int, + out_mv_channels: int, + hidden_mv_channels: int, + in_s_channels: Optional[int], + condition_s_channels: Optional[int], + out_s_channels: Optional[int], + hidden_s_channels: Optional[int], + attention: SelfAttentionConfig, + crossattention: CrossAttentionConfig, + attention_condition: SelfAttentionConfig, + mlp: MLPConfig, + num_blocks: int = 10, + checkpoint_blocks: bool = False, + dropout_prob: Optional[float] = None, + double_layernorm: bool = False, + **kwargs, + ) -> None: + super().__init__() + self.linear_in_condition = EquiLinear( + condition_mv_channels, + hidden_mv_channels, + in_s_channels=condition_s_channels, + out_s_channels=hidden_s_channels, + ) + self.linear_in = EquiLinear( + in_mv_channels, + hidden_mv_channels, + in_s_channels=in_s_channels, + out_s_channels=hidden_s_channels, + ) + self.linear_condition = EquiLinear( + condition_mv_channels, + hidden_mv_channels, + in_s_channels=condition_s_channels, + out_s_channels=hidden_s_channels, + ) + attention = SelfAttentionConfig.cast(attention) + crossattention = CrossAttentionConfig.cast(crossattention) + attention_condition = SelfAttentionConfig.cast(attention_condition) + mlp = MLPConfig.cast(mlp) + self.condition_blocks = nn.ModuleList( + [ + GATrBlock( + mv_channels=hidden_mv_channels, + s_channels=hidden_s_channels, + attention=attention_condition, + mlp=mlp, + dropout_prob=dropout_prob, + double_layernorm=double_layernorm, + ) + for _ in range(num_blocks) + ] + ) + self.blocks = nn.ModuleList( + [ + ConditionalGATrBlock( + mv_channels=hidden_mv_channels, + s_channels=hidden_s_channels, + condition_mv_channels=hidden_mv_channels, + condition_s_channels=hidden_s_channels, + attention=attention, + crossattention=crossattention, + mlp=mlp, + dropout_prob=dropout_prob, + double_layernorm=double_layernorm, + ) + ] + ) + self.linear_out = EquiLinear( + hidden_mv_channels, + out_mv_channels, + in_s_channels=hidden_s_channels, + out_s_channels=out_s_channels, + ) + self._checkpoint_blocks = checkpoint_blocks + + def forward( + self, + multivectors: torch.Tensor, + multivectors_condition: torch.Tensor, + scalars: Optional[torch.Tensor] = None, + scalars_condition: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_condition: Optional[torch.Tensor] = None, + crossattention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """Forward pass of the network. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., num_items, in_mv_channels, 16) + Input multivectors. + multivectors_condition : torch.Tensor with shape (..., num_items_condition, in_mv_channels, 16) + Input multivectors. + scalars : None or torch.Tensor with shape (..., num_items, in_s_channels) + Optional input scalars. 
+ scalars_condition : None or torch.Tensor with shape (..., num_items_condition, in_s_channels) + Optional input scalars. + attention_mask: None or torch.Tensor with shape (..., num_items, num_items) + Optional attention mask + attention_mask_condition: None or torch.Tensor with shape (..., num_items_condition, num_items_condition) + Optional attention mask for condition + crossattention_mask: None or torch.Tensor with shape (..., num_items, num_items_condition) + Optional mask for cross-attention + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., num_items, out_mv_channels, 16) + Output multivectors. + outputs_s : None or torch.Tensor with shape (..., num_items, out_s_channels) + Output scalars, if scalars are provided. Otherwise None. + """ + + # Encode condition with GATr blocks + c_mv, c_s = self.linear_in_condition( + multivectors_condition, scalars=scalars_condition + ) + for block in self.condition_blocks: + if self._checkpoint_blocks: + c_mv, c_s = checkpoint( + block, + c_mv, + use_reentrant=False, + scalars=c_s, + attention_mask=attention_mask_condition, + ) + else: + c_mv, c_s = block( + c_mv, + scalars=c_s, + attention_mask=attention_mask_condition, + ) + + # Decode condition into main track with + h_mv, h_s = self.linear_in(multivectors, scalars=scalars) + for block in self.blocks: + if self._checkpoint_blocks: + h_mv, h_s = checkpoint( + block, + h_mv, + use_reentrant=False, + scalars=h_s, + multivectors_condition=c_mv, + scalars_condition=c_s, + attention_mask=attention_mask, + crossattention_mask=crossattention_mask, + ) + else: + h_mv, h_s = block( + h_mv, + scalars=h_s, + multivectors_condition=c_mv, + scalars_condition=c_s, + attention_mask=attention_mask, + crossattention_mask=crossattention_mask, + ) + + outputs_mv, outputs_s = self.linear_out(h_mv, scalars=h_s) + + return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/nets/gap.py b/weaver/nn/model/gatr/nets/gap.py new file mode 100644 index 00000000..d8e75234 --- /dev/null +++ b/weaver/nn/model/gatr/nets/gap.py @@ -0,0 +1,120 @@ +from dataclasses import replace +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from gatr.layers.linear import EquiLinear +from gatr.layers.mlp import MLPConfig, GeoMLP + + +class GAP(nn.Module): + """Geometric Algebra Perceptron network for a data with a single token dimension. + It combines `num_blocks` GeoMLP blocks. + + Assumes input has shape `(..., in_channels, 16)`, output has shape + `(..., out_channels, 16)`, will create hidden representations with shape + `(..., hidden_channels, 16)`. + + Parameters + ---------- + in_mv_channels : int + Number of input multivector channels. + out_mv_channels : int + Number of output multivector channels. + hidden_mv_channels : int + Number of hidden multivector channels. + in_s_channels : None or int + If not None, sets the number of scalar input channels. + out_s_channels : None or int + If not None, sets the number of scalar output channels. + hidden_s_channels : None or int + If not None, sets the number of scalar hidden channels. + num_blocks : int + Number of resnet blocks. 
+ dropout_prob : float or None + Dropout probability + """ + + def __init__( + self, + in_mv_channels: int, + out_mv_channels: int, + hidden_mv_channels: int, + in_s_channels: Optional[int], + out_s_channels: Optional[int], + hidden_s_channels: Optional[int], + mlp: MLPConfig, + num_blocks: int = 10, + num_layers: int = 3, + checkpoint_blocks: bool = False, + dropout_prob: Optional[float] = None, + **kwargs, + ) -> None: + super().__init__() + + self.linear_in = EquiLinear( + in_mv_channels, + hidden_mv_channels, + in_s_channels=in_s_channels, + out_s_channels=hidden_s_channels, + ) + + mlp = MLPConfig.cast(mlp) + mlp = replace( + mlp, + mv_channels=[hidden_mv_channels for _ in range(num_layers)], + s_channels=[hidden_s_channels for _ in range(num_layers)], + dropout_prob=dropout_prob, + ) + self.blocks = nn.ModuleList([GeoMLP(mlp) for _ in range(num_blocks)]) + + self.linear_out = EquiLinear( + hidden_mv_channels, + out_mv_channels, + in_s_channels=hidden_s_channels, + out_s_channels=out_s_channels, + ) + self._checkpoint_blocks = checkpoint_blocks + + def forward( + self, + multivectors: torch.Tensor, + scalars: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """Forward pass of the network. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., in_mv_channels, 16) + Input multivectors. + scalars : None or torch.Tensor with shape (..., in_s_channels) + Optional input scalars. + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., out_mv_channels, 16) + Output multivectors. + outputs_s : None or torch.Tensor with shape (..., out_s_channels) + Output scalars, if scalars are provided. Otherwise None. + """ + + # Pass through the blocks + h_mv, h_s = self.linear_in(multivectors, scalars=scalars) + for block in self.blocks: + if self._checkpoint_blocks: + h_mv, h_s = checkpoint( + block, + h_mv, + use_reentrant=False, + scalars=h_s, + ) + else: + h_mv, h_s = block( + h_mv, + scalars=h_s, + ) + outputs_mv, outputs_s = self.linear_out(h_mv, scalars=h_s) + + return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/nets/gatr.py b/weaver/nn/model/gatr/nets/gatr.py new file mode 100644 index 00000000..3b3d059c --- /dev/null +++ b/weaver/nn/model/gatr/nets/gatr.py @@ -0,0 +1,182 @@ +"""Equivariant transformer for multivector data.""" + +from dataclasses import replace +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from gatr.layers.attention.config import SelfAttentionConfig +from gatr.layers.gatr_block import GATrBlock +from gatr.layers.linear import EquiLinear +from gatr.layers.mlp.config import MLPConfig + + +class GATr(nn.Module): + """L-GATr network for a data with a single token dimension. + + It combines `num_blocks` L-GATr transformer blocks, each consisting of geometric self-attention + layers, a geometric MLP, residual connections, and normalization layers. In addition, there + are initial and final equivariant linear layers. + + Assumes input has shape `(..., items, in_channels, 16)`, output has shape + `(..., items, out_channels, 16)`, will create hidden representations with shape + `(..., items, hidden_channels, 16)`. + + Parameters + ---------- + in_mv_channels : int + Number of input multivector channels. + out_mv_channels : int + Number of output multivector channels. + hidden_mv_channels : int + Number of hidden multivector channels. + in_s_channels : None or int + If not None, sets the number of scalar input channels. 
+ out_s_channels : None or int + If not None, sets the number of scalar output channels. + hidden_s_channels : None or int + If not None, sets the number of scalar hidden channels. + attention: Dict + Data for SelfAttentionConfig + mlp: Dict + Data for MLPConfig + num_blocks : int + Number of transformer blocks. + dropout_prob : float or None + Dropout probability + double_layernorm : bool + Whether to use double layer normalization + """ + + def __init__( + self, + in_mv_channels: int, + out_mv_channels: int, + hidden_mv_channels: int, + in_s_channels: Optional[int], + out_s_channels: Optional[int], + hidden_s_channels: Optional[int], + attention: SelfAttentionConfig, + mlp: MLPConfig, + num_blocks: int = 10, + reinsert_mv_channels: Optional[Tuple[int]] = None, + reinsert_s_channels: Optional[Tuple[int]] = None, + checkpoint_blocks: bool = False, + dropout_prob: Optional[float] = None, + double_layernorm: bool = False, + **kwargs, + ) -> None: + super().__init__() + self.linear_in = EquiLinear( + in_mv_channels, + hidden_mv_channels, + in_s_channels=in_s_channels, + out_s_channels=hidden_s_channels, + ) + attention = replace( + SelfAttentionConfig.cast(attention), + additional_qk_mv_channels=0 + if reinsert_mv_channels is None + else len(reinsert_mv_channels), + additional_qk_s_channels=0 + if reinsert_s_channels is None + else len(reinsert_s_channels), + ) + mlp = MLPConfig.cast(mlp) + self.blocks = nn.ModuleList( + [ + GATrBlock( + mv_channels=hidden_mv_channels, + s_channels=hidden_s_channels, + attention=attention, + mlp=mlp, + dropout_prob=dropout_prob, + double_layernorm=double_layernorm, + ) + for _ in range(num_blocks) + ] + ) + self.linear_out = EquiLinear( + hidden_mv_channels, + out_mv_channels, + in_s_channels=hidden_s_channels, + out_s_channels=out_s_channels, + ) + self._reinsert_s_channels = reinsert_s_channels + self._reinsert_mv_channels = reinsert_mv_channels + self._checkpoint_blocks = checkpoint_blocks + + def forward( + self, + multivectors: torch.Tensor, + scalars: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """Forward pass of the network. + + Parameters + ---------- + multivectors : torch.Tensor with shape (..., in_mv_channels, 16) + Input multivectors. + scalars : None or torch.Tensor with shape (..., in_s_channels) + Optional input scalars. + attention_mask: None or torch.Tensor with shape (..., num_items, num_items) + Optional attention mask + + Returns + ------- + outputs_mv : torch.Tensor with shape (..., out_mv_channels, 16) + Output multivectors. + outputs_s : None or torch.Tensor with shape (..., out_s_channels) + Output scalars, if scalars are provided. Otherwise None. 
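A hypothetical end-to-end sketch of calling this network; the attention and mlp dicts are cast to SelfAttentionConfig and MLPConfig internally, and the specific field names used here (such as num_heads) are assumptions rather than something fixed by this patch:

```
import torch
from gatr.nets.gatr import GATr

net = GATr(
    in_mv_channels=1, out_mv_channels=1, hidden_mv_channels=16,
    in_s_channels=2, out_s_channels=1, hidden_s_channels=32,
    attention={"num_heads": 8}, mlp={}, num_blocks=4,
)
mv = torch.randn(8, 128, 1, 16)    # (batch, items, in_mv_channels, 16)
s = torch.randn(8, 128, 2)         # (batch, items, in_s_channels)
out_mv, out_s = net(mv, scalars=s)
print(out_mv.shape, out_s.shape)   # (8, 128, 1, 16) and (8, 128, 1)
```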
+ """ + + # Channels that will be re-inserted in any query / key computation + ( + additional_qk_features_mv, + additional_qk_features_s, + ) = self._construct_reinserted_channels(multivectors, scalars) + + # Pass through the blocks + h_mv, h_s = self.linear_in(multivectors, scalars=scalars) + for block in self.blocks: + if self._checkpoint_blocks: + h_mv, h_s = checkpoint( + block, + h_mv, + use_reentrant=False, + scalars=h_s, + additional_qk_features_mv=additional_qk_features_mv, + additional_qk_features_s=additional_qk_features_s, + attention_mask=attention_mask, + ) + else: + h_mv, h_s = block( + h_mv, + scalars=h_s, + additional_qk_features_mv=additional_qk_features_mv, + additional_qk_features_s=additional_qk_features_s, + attention_mask=attention_mask, + ) + + outputs_mv, outputs_s = self.linear_out(h_mv, scalars=h_s) + + return outputs_mv, outputs_s + + def _construct_reinserted_channels(self, multivectors, scalars): + """Constructs input features that will be reinserted in every attention layer.""" + + if self._reinsert_mv_channels is None: + additional_qk_features_mv = None + else: + additional_qk_features_mv = multivectors[..., self._reinsert_mv_channels, :] + + if self._reinsert_s_channels is None: + additional_qk_features_s = None + else: + assert scalars is not None + additional_qk_features_s = scalars[..., self._reinsert_s_channels] + + return additional_qk_features_mv, additional_qk_features_s diff --git a/weaver/nn/model/gatr/primitives/__init__.py b/weaver/nn/model/gatr/primitives/__init__.py new file mode 100644 index 00000000..e1a3d4a5 --- /dev/null +++ b/weaver/nn/model/gatr/primitives/__init__.py @@ -0,0 +1,17 @@ +from .attention import sdp_attention +from .bilinear import geometric_product +from .dropout import grade_dropout +from .invariants import ( + inner_product, + squared_norm, + abs_squared_norm, + pin_invariants, +) +from .linear import ( + equi_linear, + grade_involute, + grade_project, + reverse, +) +from .nonlinearities import gated_gelu, gated_relu, gated_sigmoid +from .normalization import equi_layer_norm diff --git a/weaver/nn/model/gatr/primitives/attention.py b/weaver/nn/model/gatr/primitives/attention.py new file mode 100644 index 00000000..9344647f --- /dev/null +++ b/weaver/nn/model/gatr/primitives/attention.py @@ -0,0 +1,142 @@ +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import Tensor +from torch.nn.functional import scaled_dot_product_attention as torch_sdpa +from xformers.ops import AttentionBias, memory_efficient_attention + +from gatr.primitives.invariants import _load_inner_product_factors + +# Masked out attention logits are set to this constant (a finite replacement for -inf): +_MASKED_OUT = float("-inf") + +# Force the use of xformers attention, even when no xformers attention mask is provided: +FORCE_XFORMERS = False + + +def sdp_attention( + q_mv: Tensor, + k_mv: Tensor, + v_mv: Tensor, + q_s: Tensor, + k_s: Tensor, + v_s: Tensor, + attn_mask: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: + """Equivariant geometric attention based on scaled dot products. + + Expects both multivector and scalar queries, keys, and values as inputs. 
+ Then this function computes multivector and scalar outputs in the following way: + + ``` + attn_weights[..., i, j] = softmax_j[ + ga_inner_product(q_mv[..., i, :, :], k_mv[..., j, :, :]) + + euclidean_inner_product(q_s[..., i, :], k_s[..., j, :]) + ] + out_mv[..., i, c, :] = sum_j attn_weights[..., i, j] v_mv[..., j, c, :] / norm + out_s[..., i, c] = sum_j attn_weights[..., i, j] v_s[..., j, c] / norm + ``` + + Parameters + ---------- + q_mv : Tensor with shape (..., num_items_out, num_mv_channels_in, 16) + Queries, multivector part. + k_mv : Tensor with shape (..., num_items_in, num_mv_channels_in, 16) + Keys, multivector part. + v_mv : Tensor with shape (..., num_items_in, num_mv_channels_out, 16) + Values, multivector part. + q_s : Tensor with shape (..., num_items_out, num_s_channels_in) + Queries, scalar part. + k_s : Tensor with shape (..., num_items_in, num_s_channels_in) + Keys, scalar part. + v_s : Tensor with shape (..., num_items_in, num_s_channels_out) + Values, scalar part. + attn_mask : None or Tensor with shape (..., num_items, num_items) + Optional attention mask + + Returns + ------- + outputs_mv : Tensor with shape (..., num_items_out, num_mv_channels_out, 16) + Result, multivector part + outputs_s : Tensor with shape (..., num_items_out, num_s_channels_out) + Result, scalar part + """ + + # Construct queries and keys by concatenating relevant MV components and aux scalars + q = torch.cat( + [ + rearrange( + q_mv + * _load_inner_product_factors(device=q_mv.device, dtype=q_mv.dtype), + "... c x -> ... (c x)", + ), + q_s, + ], + -1, + ) + k = torch.cat([rearrange(k_mv, "... c x -> ... (c x)"), k_s], -1) + + num_channels_out = v_mv.shape[-2] + v = torch.cat([rearrange(v_mv, "... c x -> ... (c x)"), v_s], -1) + + v_out = scaled_dot_product_attention(q, k, v, attn_mask) + + v_out_mv = rearrange( + v_out[..., : num_channels_out * 16], "... (c x) -> ... c x", x=16 + ) + v_out_s = v_out[..., num_channels_out * 16 :] + + return v_out_mv, v_out_s + + +def scaled_dot_product_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Optional[Union[AttentionBias, Tensor]] = None, + is_causal=False, +) -> Tensor: + """Execute (vanilla) scaled dot-product attention. + + Dynamically dispatch to xFormers if attn_mask is an instance of xformers.ops.AttentionBias + or FORCE_XFORMERS is set, use torch otherwise. 
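The reason `sdp_attention` above can delegate to an off-the-shelf scaled dot-product kernel is that weighting one side of the multivector queries and keys by the metric factors turns an ordinary dot product into the GA inner product plus the Euclidean product of the auxiliary scalars. A small self-contained check of that identity (up to the softmax and the 1/sqrt(d) scaling applied by the attention kernel):

```
import torch

metric = torch.tensor([1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1], dtype=torch.float32)

q_mv, k_mv = torch.randn(2, 3, 16), torch.randn(2, 3, 16)  # (items, channels, 16)
q_s, k_s = torch.randn(2, 4), torch.randn(2, 4)            # (items, s_channels)

q = torch.cat([(q_mv * metric).flatten(-2), q_s], dim=-1)
k = torch.cat([k_mv.flatten(-2), k_s], dim=-1)

logits = q @ k.transpose(-1, -2)
reference = torch.einsum("icx,jcx->ij", q_mv * metric, k_mv) + q_s @ k_s.transpose(-1, -2)
assert torch.allclose(logits, reference, atol=1e-5)
```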
+ + Parameters + ---------- + query : Tensor + of shape [batch, head, item, d] + key : Tensor + of shape [batch, head, item, d] + value : Tensor + of shape [batch, head, item, d] + attn_mask : Optional[Union[AttentionBias, Tensor]] + Attention mask + is_causal: bool + + Returns + ------- + Tensor + of shape [batch, head, item, d] + """ + if FORCE_XFORMERS or isinstance(attn_mask, AttentionBias): + assert ( + not is_causal + ), "is_causal=True not implemented yet for xformers attention" + if key.shape[1] != query.shape[1]: # required to make multi_query work + key = key.expand(key.shape[0], query.shape[1], *key.shape[2:]) + value = value.expand(value.shape[0], query.shape[1], *value.shape[2:]) + query = query.transpose( + 1, 2 + ) # [batch, head, item, d] -> [batch, item, head, d] + key = key.transpose(1, 2) + value = value.transpose(1, 2) + out = memory_efficient_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_bias=attn_mask, + ) + out = out.transpose(1, 2) # [batch, item, head, d] -> [batch, head, item, d] + return out + return torch_sdpa(query, key, value, attn_mask=attn_mask, is_causal=is_causal) diff --git a/weaver/nn/model/gatr/primitives/bilinear.py b/weaver/nn/model/gatr/primitives/bilinear.py new file mode 100644 index 00000000..0fcaa19f --- /dev/null +++ b/weaver/nn/model/gatr/primitives/bilinear.py @@ -0,0 +1,67 @@ +from functools import lru_cache + +import torch +import clifford + +from gatr.utils.einsum import cached_einsum + + +@lru_cache() +def _load_geometric_product_tensor( + device=torch.device("cpu"), dtype=torch.float32 +) -> torch.Tensor: + """Loads geometric product tensor for geometric product between multivectors. + + This function is cached. + + Parameters + ---------- + device : torch.Device or str + Device + dtype : torch.Dtype + Data type + + Returns + ------- + basis : torch.Tensor with shape (16, 16, 16) + Geometric product tensor + """ + + # To avoid duplicate loading, base everything on float32 CPU version + if device not in [torch.device("cpu"), "cpu"] and dtype != torch.float32: + gmt = _load_geometric_product_tensor() + else: + layout, _ = clifford.Cl(1, 3) + gmt = torch.tensor(layout.gmt, dtype=torch.float32) + gmt = torch.transpose(gmt, 1, 0) + + # Convert to dense tensor + # The reason we do that is that einsum is not defined for sparse tensors + gmt = gmt.to_dense() + + return gmt.to(device=device, dtype=dtype) + + +def geometric_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes the geometric product f(x,y) = xy. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + First input multivector. Batch dimensions must be broadcastable between x and y. + y : torch.Tensor with shape (..., 16) + Second input multivector. Batch dimensions must be broadcastable between x and y. + + Returns + ------- + outputs : torch.Tensor with shape (..., 16) + Result. Batch dimensions are result of broadcasting between x, y, and coeffs. + """ + + # Select kernel on correct device + gp = _load_geometric_product_tensor(device=x.device, dtype=x.dtype) + + # Compute geometric product + outputs = cached_einsum("i j k, ... j, ... k -> ... 
i", gp, x, y) + + return outputs diff --git a/weaver/nn/model/gatr/primitives/dropout.py b/weaver/nn/model/gatr/primitives/dropout.py new file mode 100644 index 00000000..45d010ac --- /dev/null +++ b/weaver/nn/model/gatr/primitives/dropout.py @@ -0,0 +1,36 @@ +import torch + +from gatr.primitives.linear import grade_project + + +def grade_dropout(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor: + """Multivector dropout, dropping out grades independently. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Input data. + p : float + Dropout probability (assumed the same for each grade). + training : bool + Switches between train-time and test-time behaviour. + + Returns + ------- + outputs : torch.Tensor with shape (..., 16) + Inputs with dropout applied. + """ + + # Project to grades + x = grade_project(x) + + # Apply standard 1D dropout + # For whatever reason, that only works with a single batch dimension, so let's reshape a bit + h = x.view(-1, 5, 16) + h = torch.nn.functional.dropout1d(h, p=p, training=training, inplace=False) + h = h.view(x.shape) + + # Combine grades again + h = torch.sum(h, dim=-2) + + return h diff --git a/weaver/nn/model/gatr/primitives/invariants.py b/weaver/nn/model/gatr/primitives/invariants.py new file mode 100644 index 00000000..fed9805d --- /dev/null +++ b/weaver/nn/model/gatr/primitives/invariants.py @@ -0,0 +1,171 @@ +from functools import lru_cache + +import torch +import math + +from gatr.primitives.linear import grade_project +from gatr.utils.einsum import cached_einsum + + +@lru_cache() +def _load_inner_product_factors( + device=torch.device("cpu"), dtype=torch.float32 +) -> torch.Tensor: + """Constructs an array of 1's and -1's for the metric of the space, + used to compute the inner product. + + Parameters + ---------- + device : torch.device + Device + dtype : torch.dtype + Dtype + + Returns + ------- + ip_factors : torch.Tensor with shape (16,) + Inner product factors + """ + + _INNER_PRODUCT_FACTORS = [1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1] + factors = torch.tensor( + _INNER_PRODUCT_FACTORS, dtype=torch.float32, device=torch.device("cpu") + ).to_dense() + return factors.to(device=device, dtype=dtype) + + +@lru_cache() +def _load_metric_grades( + device=torch.device("cpu"), dtype=torch.float32 +) -> torch.Tensor: + """Generate tensor of the diagonal of the GA metric, combined with a grade projection. + + Parameters + ---------- + device : torch.device + Device + dtype : torch.dtype + Dtype + + Returns + ------- + torch.Tensor of shape [5, 16] + """ + m = _load_inner_product_factors(device=torch.device("cpu"), dtype=torch.float32) + m_grades = torch.zeros(5, 16, device=torch.device("cpu"), dtype=torch.float32) + offset = 0 + for k in range(4 + 1): + d = math.comb(4, k) + m_grades[k, offset : offset + d] = m[offset : offset + d] + offset += d + return m_grades.to(device=device, dtype=dtype) + + +def inner_product( + x: torch.Tensor, y: torch.Tensor, channel_sum: bool = False +) -> torch.Tensor: + """Computes the inner product of multivectors f(x,y) = = <~x y>_0. + + In addition to summing over the 16 multivector dimensions, this function also sums + over an additional channel dimension if channel_sum == True. + + Equal to `geometric_product(reverse(x), y)[..., [0]]` (but faster). + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) or (..., channels, 16) + First input multivector. Batch dimensions must be broadcastable between x and y. 
+ y : torch.Tensor with shape (..., 16) or (..., channels, 16) + Second input multivector. Batch dimensions must be broadcastable between x and y. + channel_sum: bool + Whether to sum over the second-to-last axis (channels) + + Returns + ------- + outputs : torch.Tensor with shape (..., 1) + Result. Batch dimensions are result of broadcasting between x and y. + """ + + x = x * _load_inner_product_factors(device=x.device, dtype=x.dtype) + + if channel_sum: + outputs = cached_einsum("... c i, ... c i -> ...", x, y) + else: + outputs = cached_einsum("... i, ... i -> ...", x, y) + + # We want the output to have shape (..., 1) + outputs = outputs.unsqueeze(-1) + + return outputs + + +def squared_norm(x: torch.Tensor) -> torch.Tensor: + """Computes the squared GA norm of an input multivector. + + Equal to inner_product(x, x). + + NOTE: this primitive is not used widely in our architectures. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Input multivector. + + Returns + ------- + outputs : torch.Tensor with shape (..., 1) + Geometric algebra norm of x. + """ + + return inner_product(x, x) + + +def pin_invariants(x: torch.Tensor, epsilon: float = 0.01) -> torch.Tensor: + """Computes five invariants from multivectors: scalar component, norms of the four other grades. + + NOTE: this primitive is not used widely in our architectures. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Input multivector. + epsilon : float + Epsilon parameter that regularizes the norm in case it is lower or equal to zero to avoid infinite gradients. + + + Returns + ------- + outputs : torch.Tensor with shape (..., 5) + Invariants computed from input multivectors + """ + + # Project to grades + projections = grade_project(x) # (..., 5, 16) + + # Compute norms + squared_norms = inner_product(projections, projections)[..., 0] # (..., 5) + norms = torch.sqrt(torch.clamp(squared_norms, epsilon)) + + # Outputs: scalar component of input and norms of four other grades + return torch.cat((x[..., [0]], norms[..., 1:]), dim=-1) # (..., 5) + + +def abs_squared_norm(x: torch.Tensor) -> torch.Tensor: + """Computes a modified version of the squared norm that is positive semidefinite and can + therefore be used in layer normalization. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Input multivector. + + Returns + ------- + outputs : torch.Tensor with shape (..., 1) + Geometric algebra norm of x. + """ + m = _load_metric_grades(device=x.device, dtype=x.dtype) + abs_squared_norms = ( + cached_einsum("... i, ... i, g i -> ... g", x, x, m).abs().sum(-1, keepdim=True) + ) + return abs_squared_norms diff --git a/weaver/nn/model/gatr/primitives/linear.py b/weaver/nn/model/gatr/primitives/linear.py new file mode 100644 index 00000000..1e3b279d --- /dev/null +++ b/weaver/nn/model/gatr/primitives/linear.py @@ -0,0 +1,196 @@ +from functools import lru_cache + +import torch +import clifford +import numpy as np + +from gatr.utils.einsum import cached_einsum, custom_einsum + +# switch to decide whether to use the full Lorentz group ('False') +# or the special orthochronous Lorentz group ('True') +# They only differ in the construction of linear maps in _compute_pin_equi_linear_basis +USE_FULLY_CONNECTED_SUBGROUP = True + + +@lru_cache() +def _compute_pin_equi_linear_basis( + device=torch.device("cpu"), dtype=torch.float32, normalize=True +) -> torch.Tensor: + """Constructs basis elements for Pin(1,3)-equivariant linear maps between multivectors. + + This function is cached. 
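[Editorial sketch of the invariants defined above, assuming the package is importable as `gatr`; channel counts are illustrative.]

import torch
from gatr.primitives.invariants import inner_product, pin_invariants

x = torch.randn(2, 4, 16)
y = torch.randn(2, 4, 16)

ip = inner_product(x, y)                        # (2, 4, 1), signed Minkowski inner product
ip_sum = inner_product(x, y, channel_sum=True)  # (2, 1), additionally summed over the 4 channels
inv = pin_invariants(x)                         # (2, 4, 5): scalar component plus norms of grades 1-4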
+ + Parameters + ---------- + device : torch.device + Device + dtype : torch.dtype + Dtype + normalize : bool + Whether to normalize the basis elements + + Returns + ------- + basis : torch.Tensor with shape (NUM_PIN_LINEAR_BASIS_ELEMENTS, 16, 16) + Basis elements for equivariant linear maps. + """ + + if device not in [torch.device("cpu"), "cpu"] and dtype != torch.float32: + basis = _compute_pin_equi_linear_basis(normalize=normalize) + else: + # explicit construction of a pin-equilinear basis for the Lorentz group + layout, blades = clifford.Cl(1, 3) + linear_basis = [] + mults = [1, layout.pseudoScalar] if USE_FULLY_CONNECTED_SUBGROUP else [1] + for mult in mults: + for grade in range(5): + w = np.stack([(x(grade) * mult).value for x in blades.values()], 1) + w = w.astype(np.float32) + if normalize: + # w /= np.linalg.norm(w) # straight-forward normalization + w /= np.linalg.svd(w)[1].max() # alternative normalization + linear_basis.append(w) + linear_basis = np.stack(linear_basis) + + basis = torch.tensor(linear_basis, dtype=torch.float32).to_dense() + return basis.to(device=device, dtype=dtype) + + +@lru_cache() +def _compute_reversal(device=torch.device("cpu"), dtype=torch.float32) -> torch.Tensor: + """Constructs a matrix that computes multivector reversal. + + Parameters + ---------- + device : torch.device + Device + dtype : torch.dtype + Dtype + + Returns + ------- + reversal_diag : torch.Tensor with shape (16,) + The diagonal of the reversal matrix, consisting of +1 and -1 entries. + """ + reversal_flat = torch.ones(16, device=device, dtype=dtype) + reversal_flat[5:15] = -1 + return reversal_flat + + +@lru_cache() +def _compute_grade_involution( + device=torch.device("cpu"), dtype=torch.float32 +) -> torch.Tensor: + """Constructs a matrix that computes multivector grade involution. + + Parameters + ---------- + device : torch.device + Device + dtype : torch.dtype + Dtype + + Returns + ------- + involution_diag : torch.Tensor with shape (16,) + The diagonal of the involution matrix, consisting of +1 and -1 entries. + """ + involution_flat = torch.ones(16, device=device, dtype=dtype) + involution_flat[1:5] = -1 + involution_flat[11:15] = -1 + return involution_flat + + +def equi_linear(x: torch.Tensor, coeffs: torch.Tensor) -> torch.Tensor: + """Pin-equivariant linear map f(x) = sum_{a,j} coeffs_a W^a_ij x_j. + + The W^a are seven pre-defined basis elements. + + Parameters + ---------- + x : torch.Tensor with shape (..., in_channels, 16) + Input multivector. Batch dimensions must be broadcastable between x and coeffs. + coeffs : torch.Tensor with shape (out_channels, in_channels, 10) + Coefficients for the basis elements. Batch dimensions must be broadcastable between x and + coeffs. + + Returns + ------- + outputs : torch.Tensor with shape (..., 16) + Result. Batch dimensions are result of broadcasting between x and coeffs. + """ + basis = _compute_pin_equi_linear_basis(device=x.device, dtype=x.dtype) + return custom_einsum( + "y x a, a i j, ... x j -> ... y i", coeffs, basis, x, path=[0, 1, 0, 1] + ) + + +def grade_project(x: torch.Tensor) -> torch.Tensor: + """Projects an input tensor to the individual grades. + + The return value is a single tensor with a new grade dimension. + + NOTE: this primitive is not used widely in our architectures. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Input multivector. + + Returns + ------- + outputs : torch.Tensor with shape (..., 5, 16) + Output multivector. The second-to-last dimension indexes the grades. 
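[Editorial sketch to make the coefficient shape concrete: with USE_FULLY_CONNECTED_SUBGROUP = True, _compute_pin_equi_linear_basis stacks 2 x 5 = 10 basis maps (five grade projections, each optionally multiplied by the pseudoscalar), so the coefficient tensor carries a trailing dimension of 10. Channel counts below are illustrative.]

import torch
from gatr.primitives.linear import equi_linear

in_channels, out_channels = 4, 6
coeffs = torch.randn(out_channels, in_channels, 10)  # one weight per equivariant basis map
x = torch.randn(32, in_channels, 16)                 # (batch, in_channels, 16)
y = equi_linear(x, coeffs)                           # (batch, out_channels, 16)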
+ """ + + # Select kernel on correct device + basis = _compute_pin_equi_linear_basis( + device=x.device, dtype=x.dtype, normalize=False + ) + + # First five basis elements are grade projections + basis = basis[:5] + + # Project to grades + projections = cached_einsum("g i j, ... j -> ... g i", basis, x) + + return projections + + +def reverse(x: torch.Tensor) -> torch.Tensor: + """Computes the reversal of a multivector. + + The reversal has the same scalar, vector, and pseudoscalar components, but flips sign in the + bivector and trivector components. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Input multivector. + + Returns + ------- + outputs : torch.Tensor with shape (..., 16) + Output multivector. + """ + return _compute_reversal(device=x.device, dtype=x.dtype) * x + + +def grade_involute(x: torch.Tensor) -> torch.Tensor: + """Computes the grade involution of a multivector. + + The reversal has the same scalar, bivector, and pseudoscalar components, but flips sign in the + vector and trivector components. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Input multivector. + + Returns + ------- + outputs : torch.Tensor with shape (..., 16) + Output multivector. + """ + + return _compute_grade_involution(device=x.device, dtype=x.dtype) * x diff --git a/weaver/nn/model/gatr/primitives/nonlinearities.py b/weaver/nn/model/gatr/primitives/nonlinearities.py new file mode 100644 index 00000000..ad285301 --- /dev/null +++ b/weaver/nn/model/gatr/primitives/nonlinearities.py @@ -0,0 +1,79 @@ +import math + +import torch + + +def gated_relu(x: torch.Tensor, gates: torch.Tensor) -> torch.Tensor: + """Pin-equivariant gated ReLU nonlinearity. + + Given multivector input x and scalar input gates (with matching batch dimensions), computes + ReLU(gates) * x. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Multivector input + gates : torch.Tensor with shape (..., 1) + Pin-invariant gates. + + Returns + ------- + outputs : torch.Tensor with shape (..., 16) + Computes ReLU(gates) * x, with broadcasting along the last dimension. + """ + + weights = torch.nn.functional.relu(gates) + outputs = weights * x + return outputs + + +def gated_sigmoid(x: torch.Tensor, gates: torch.Tensor): + """Pin-equivariant gated sigmoid nonlinearity. + + Given multivector input x and scalar input gates (with matching batch dimensions), computes + sigmoid(gates) * x. + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Multivector input + gates : torch.Tensor with shape (..., 1) + Pin-invariant gates. + + Returns + ------- + outputs : torch.Tensor with shape (..., 16) + Computes sigmoid(gates) * x, with broadcasting along the last dimension. + """ + + weights = torch.nn.functional.sigmoid(gates) + outputs = weights * x + return outputs + + +def gated_gelu(x: torch.Tensor, gates: torch.Tensor) -> torch.Tensor: + """Pin-equivariant gated GeLU nonlinearity without division. + + Given multivector input x and scalar input gates (with matching batch dimensions), computes + GeLU(gates) * x. + + References + ---------- + Dan Hendrycks, Kevin Gimpel, "Gaussian Error Linear Units (GELUs)", arXiv:1606.08415 + + Parameters + ---------- + x : torch.Tensor with shape (..., 16) + Multivector input + gates : torch.Tensor with shape (..., 1) + Pin-invariant gates. + + Returns + ------- + outputs : torch.Tensor with shape (..., 16) + Computes GeLU(gates) * x, with broadcasting along the last dimension. 
+ """ + + weights = torch.nn.functional.gelu(gates, approximate="tanh") + outputs = weights * x + return outputs diff --git a/weaver/nn/model/gatr/primitives/normalization.py b/weaver/nn/model/gatr/primitives/normalization.py new file mode 100644 index 00000000..7ac64e55 --- /dev/null +++ b/weaver/nn/model/gatr/primitives/normalization.py @@ -0,0 +1,47 @@ +import torch + +from gatr.primitives.invariants import abs_squared_norm + + +def equi_layer_norm( + x: torch.Tensor, channel_dim: int = -2, gain: float = 1.0, epsilon: float = 0.01 +) -> torch.Tensor: + """Equivariant LayerNorm for multivectors. + + Rescales input such that `mean_channels |inputs|^2 = 1`, where the norm is the GA norm and the + mean goes over the channel dimensions. + + Using a factor `gain > 1` makes up for the fact that the GP norm overestimates the actual + standard deviation of the input data. + + Parameters + ---------- + x : torch.Tensor with shape `(batch_dim, *channel_dims, 16)` + Input multivectors. + channel_dim : int + Channel dimension index. Defaults to the second-last entry (last are the multivector + components). + gain : float + Target output scale. + epsilon : float + Small numerical factor to avoid instabilities. By default, we use a reasonably large number + to balance issues that arise from some multivector components not contributing to the norm. + + Returns + ------- + outputs : torch.Tensor with shape `(batch_dim, *channel_dims, 16)` + Normalized inputs. + """ + + # Compute mean_channels |inputs|^2 + abs_squared_norms = abs_squared_norm(x) + abs_squared_norms = torch.mean(abs_squared_norms, dim=channel_dim, keepdim=True) + + # Insure against low-norm tensors (which can arise even when `x.var(dim=-1)` is high b/c some + # entries don't contribute to the inner product / GP norm!) 
+ abs_squared_norms = torch.clamp(abs_squared_norms, epsilon) + + # Rescale inputs + outputs = gain * x / torch.sqrt(abs_squared_norms) + + return outputs diff --git a/weaver/nn/model/gatr/utils/__init__.py b/weaver/nn/model/gatr/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/weaver/nn/model/gatr/utils/clifford.py b/weaver/nn/model/gatr/utils/clifford.py new file mode 100644 index 00000000..983f3031 --- /dev/null +++ b/weaver/nn/model/gatr/utils/clifford.py @@ -0,0 +1,136 @@ +"""Geometric algebra operations based on the clifford library.""" + +from typing import Optional + +import clifford +import numpy as np +import torch + +LAYOUT, BLADES = clifford.Cl(1, 3) + + +def np_to_mv(array): + """Shorthand to transform a numpy array to a Pin(1,3) multivector.""" + return clifford.MultiVector(LAYOUT, value=array) + + +def tensor_to_mv(tensor): + """Shorthand to transform a numpy array to a Pin(1,3) multivector.""" + return np_to_mv(tensor.detach().cpu().numpy()) + + +def tensor_to_mv_list(tensor): + """Transforms a torch.Tensor to a list of multivector objects.""" + + tensor = tensor.reshape((-1, 16)) + mv_list = [tensor_to_mv(x) for x in tensor] + + return mv_list + + +def mv_list_to_tensor(multivectors, batch_shape=None): + """Transforms a list of multivector objects to a torch.Tensor.""" + + tensor = torch.from_numpy(np.array([mv.value for mv in multivectors])).to( + torch.float32 + ) + if batch_shape is not None: + tensor = tensor.reshape(*batch_shape, 16) + + return tensor + + +def sample_pin_multivector( + spin: bool = False, rng: Optional[np.random.Generator] = None +): + """Samples from the Pin(1,3) group as a product of reflections.""" + + if rng is None: + rng = np.random.default_rng() + + # Sample number of reflections we want to multiply + if spin: + i = np.random.randint(3) * 2 + else: + i = np.random.randint(5) + + # If no reflections, just return unit scalar + if i == 0: + return BLADES[""] + + multivector = 1.0 + for _ in range(i): + # Sample reflection vector + vector = np.zeros(16) + vector[2:5] = rng.normal(size=3) * 2 + norm = np.linalg.norm(vector[2:5]) + vector[1] = (rng.uniform(size=1) - 0.5) * norm + + vector_mv = np_to_mv(vector) + vector_mv = vector_mv / abs(vector_mv.mag2()) ** 0.5 + + # Multiply together (geometric product) + multivector = multivector * vector_mv + + return multivector + + +def get_parity(mv): + """Gets parity of a clifford multivector. + + Given a clifford multivector, returns True if it is pure-odd-grade, False if it is pure-even + grade, and raises a RuntimeError if it is mixed. + """ + if mv == mv.even: + return False + if mv == mv.odd: + return True + raise RuntimeError(f"Mixed-grade multivector: {mv}") + + +def sandwich(u, x): + """Given clifford multivectors, computes their sandwich product. + + Specifically, given a Pin element u and a PGA element x, both given as clifford multivectors, + computes the sandwich product + ``` + sandwich(x, u) = (-1)^(grade(u) * grade(x)) u x u^{-1} . + ``` + + If `u` is of odd grades, then this is equal to `u * grade_involute(x) * u^{-1}`. + If `u` is of even grades, then this is equal to `u * x * u^{-1}`. + """ + + if get_parity(u): + return u * x.gradeInvol() * u.shirokov_inverse() + + return u * x * u.shirokov_inverse() + + +class SlowRandomPinTransform: + """Random Pin transform on a multivector torch.Tensor. + + Slow, only used for testing purposes. Breaks computational graph. 
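[Editorial sketch combining the pointwise pieces above: a gated nonlinearity followed by the equivariant layer norm. Assumes the package is importable as `gatr`; shapes are illustrative.]

import torch
from gatr.primitives.nonlinearities import gated_gelu
from gatr.primitives.normalization import equi_layer_norm

mv = torch.randn(8, 10, 16)    # (batch, channels, 16)
gates = torch.randn(8, 10, 1)  # one Pin-invariant gate per channel
h = gated_gelu(mv, gates)      # GeLU(gates) * mv, broadcast over the 16 components
h = equi_layer_norm(h)         # mean GA norm over the channel dimension rescaled to ~1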
+ """ + + def __init__(self, spin=False, rng=None): + super().__init__() + self._u = sample_pin_multivector(spin, rng) + self._u_inverse = self._u.shirokov_inverse() + + def __call__(self, inputs: torch.Tensor) -> torch.Tensor: + """Apply Pin transformation to multivector inputs.""" + # Input shape + assert inputs.shape[-1] == 16 + batch_dims = inputs.shape[:-1] + + # Convert inputs to list of multivectors + inputs_mv = tensor_to_mv_list(inputs) + + # Transform + outputs_mv = [sandwich(self._u, x) for x in inputs_mv] + + # Back to tensor + outputs = mv_list_to_tensor(outputs_mv, batch_shape=batch_dims) + + return outputs diff --git a/weaver/nn/model/gatr/utils/einsum.py b/weaver/nn/model/gatr/utils/einsum.py new file mode 100644 index 00000000..1e256035 --- /dev/null +++ b/weaver/nn/model/gatr/utils/einsum.py @@ -0,0 +1,44 @@ +"""This module provides efficiency improvements over torch's einsum through caching.""" + +import functools +from typing import List, Sequence + +import opt_einsum +import torch + + +def custom_einsum( + equation: str, *operands: torch.Tensor, path: List[int] +) -> torch.Tensor: + """Computes einsum with a custom contraction order.""" + + # Justification: For the sake of performance, we need direct access to torch's private methods. + + # pylint:disable-next=protected-access + return torch._VF.einsum(equation, operands, path=path) # type: ignore[attr-defined] + + +def cached_einsum(equation: str, *operands: torch.Tensor) -> torch.Tensor: + """Computes einsum with a cached optimal contraction. + + Inspired by upstream + https://github.com/pytorch/pytorch/blob/v1.13.0/torch/functional.py#L381. + """ + op_shape = tuple(op.shape for op in operands) + path = _get_cached_path_for_equation_and_shapes( + equation=equation, op_shape=op_shape + ) + + return custom_einsum(equation, *operands, path=path) + + +@functools.lru_cache(maxsize=None) +def _get_cached_path_for_equation_and_shapes( + equation: str, op_shape: Sequence[torch.Tensor] +) -> List[int]: + """Provides caching of optimal path.""" + tupled_path = opt_einsum.contract_path( + equation, *op_shape, optimize="optimal", shapes=True + )[0] + + return [item for pair in tupled_path for item in pair] diff --git a/weaver/nn/model/gatr/utils/tensors.py b/weaver/nn/model/gatr/utils/tensors.py new file mode 100644 index 00000000..3373e20c --- /dev/null +++ b/weaver/nn/model/gatr/utils/tensors.py @@ -0,0 +1,34 @@ +import torch + + +def assert_equal(vals): + """Assert all values in sequence are equal.""" + for v in vals: + assert v == vals[0] + + +def block_stack(tensors, dim1, dim2): + """Block diagonally stack tensors along dimensions dim1 and dim2.""" + assert_equal([t.dim() for t in tensors]) + shapes = [t.shape for t in tensors] + shapes_t = list(map(list, zip(*shapes))) + for i, ss in enumerate(shapes_t): + if i not in (dim1, dim2): + assert_equal(ss) + + dim2_len = sum(shapes_t[dim2]) + opts = dict(device=tensors[0].device, dtype=tensors[0].dtype) + + padded_tensors = [] + offset = 0 + for tensor in tensors: + before_shape = list(tensor.shape) + before_shape[dim2] = offset + after_shape = list(tensor.shape) + after_shape[dim2] = dim2_len - tensor.shape[dim2] - offset + before = torch.zeros(*before_shape, **opts) + after = torch.zeros(*after_shape, **opts) + padded = torch.cat([before, tensor, after], dim2) + padded_tensors.append(padded) + offset += tensor.shape[dim2] + return torch.cat(padded_tensors, dim1) From b6f49d56547e80e7b24eb02eb7c695ee45237d57 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Thu, 13 Feb 2025 
20:53:30 +0100 Subject: [PATCH 02/29] Remove unused features: ConditionalGATr (and cross-attention), GAP, AxialGATr, positional encodings, xformers attention, utils/clifford.py (used in tests) and utils/tensors.py (used in experiment code) --- weaver/nn/model/gatr/layers/__init__.py | 5 +- .../model/gatr/layers/attention/__init__.py | 3 +- .../nn/model/gatr/layers/attention/config.py | 112 --------- .../gatr/layers/attention/cross_attention.py | 220 ----------------- .../layers/attention/positional_encoding.py | 135 ----------- .../gatr/layers/attention/self_attention.py | 14 -- .../gatr/layers/conditional_gatr_block.py | 191 --------------- weaver/nn/model/gatr/nets/__init__.py | 3 - weaver/nn/model/gatr/nets/axial_gatr.py | 214 ----------------- weaver/nn/model/gatr/nets/conditional_gatr.py | 227 ------------------ weaver/nn/model/gatr/nets/gap.py | 120 --------- weaver/nn/model/gatr/primitives/attention.py | 37 +-- weaver/nn/model/gatr/utils/clifford.py | 136 ----------- weaver/nn/model/gatr/utils/tensors.py | 34 --- 14 files changed, 5 insertions(+), 1446 deletions(-) delete mode 100644 weaver/nn/model/gatr/layers/attention/cross_attention.py delete mode 100644 weaver/nn/model/gatr/layers/attention/positional_encoding.py delete mode 100644 weaver/nn/model/gatr/layers/conditional_gatr_block.py delete mode 100644 weaver/nn/model/gatr/nets/axial_gatr.py delete mode 100644 weaver/nn/model/gatr/nets/conditional_gatr.py delete mode 100644 weaver/nn/model/gatr/nets/gap.py delete mode 100644 weaver/nn/model/gatr/utils/clifford.py delete mode 100644 weaver/nn/model/gatr/utils/tensors.py diff --git a/weaver/nn/model/gatr/layers/__init__.py b/weaver/nn/model/gatr/layers/__init__.py index 6c82d47b..7757d6fa 100644 --- a/weaver/nn/model/gatr/layers/__init__.py +++ b/weaver/nn/model/gatr/layers/__init__.py @@ -1,7 +1,5 @@ -from .attention.config import SelfAttentionConfig, CrossAttentionConfig -from .attention.positional_encoding import ApplyRotaryPositionalEncoding +from .attention.config import SelfAttentionConfig from .attention.self_attention import SelfAttention -from .attention.cross_attention import CrossAttention from .dropout import GradeDropout from .layer_norm import EquiLayerNorm from .linear import EquiLinear @@ -10,4 +8,3 @@ from .mlp.config import MLPConfig from .mlp.nonlinearities import ScalarGatedNonlinearity from .gatr_block import GATrBlock -from .conditional_gatr_block import ConditionalGATrBlock diff --git a/weaver/nn/model/gatr/layers/attention/__init__.py b/weaver/nn/model/gatr/layers/attention/__init__.py index 9addf7e6..8c76bbb4 100644 --- a/weaver/nn/model/gatr/layers/attention/__init__.py +++ b/weaver/nn/model/gatr/layers/attention/__init__.py @@ -1,3 +1,2 @@ -from .config import SelfAttentionConfig, CrossAttentionConfig +from .config import SelfAttentionConfig from .self_attention import SelfAttention -from .cross_attention import CrossAttention diff --git a/weaver/nn/model/gatr/layers/attention/config.py b/weaver/nn/model/gatr/layers/attention/config.py index 950c45c7..76a12e9a 100644 --- a/weaver/nn/model/gatr/layers/attention/config.py +++ b/weaver/nn/model/gatr/layers/attention/config.py @@ -26,11 +26,6 @@ class SelfAttentionConfig: Whether additional scalar features for the keys and queries will be provided. multi_query: bool Whether to do multi-query attention - pos_encoding : bool - Whether to apply rotary positional embeddings along the item dimension to the scalar keys - and queries. 
- pos_encoding_base : int - Base for the frequencies in the positional encoding. output_init : str Initialization scheme for final linear layer increase_hidden_channels : int @@ -49,8 +44,6 @@ class SelfAttentionConfig: num_heads: int = 8 additional_qk_mv_channels: int = 0 additional_qk_s_channels: int = 0 - pos_encoding: bool = False - pos_encoding_base: int = 4096 output_init: str = "default" checkpoint: bool = True increase_hidden_channels: int = 2 @@ -87,12 +80,6 @@ def hidden_s_channels(self) -> Optional[int]: self.increase_hidden_channels * self.in_s_channels // self.num_heads, 4 ) - # When using positional encoding, the number of scalar hidden channels needs to be even. - # It also should not be too small. - if self.pos_encoding: - hidden_s_channels = (hidden_s_channels + 1) // 2 * 2 - hidden_s_channels = max(hidden_s_channels, 8) - return hidden_s_channels @classmethod @@ -103,102 +90,3 @@ def cast(cls, config: Any) -> SelfAttentionConfig: if isinstance(config, Mapping): return cls(**config) raise ValueError(f"Can not cast {config} to {cls}") - - -@dataclass -class CrossAttentionConfig: - """Configuration for cross-attention. - - Parameters - ---------- - in_q_mv_channels : int - Number of input query multivector channels. - in_kv_mv_channels : int - Number of input key/value multivector channels. - out_mv_channels : int - Number of output multivector channels. - num_heads : int - Number of attention heads. - in_q_s_channels : int - Input query scalar channels. If None, no scalars are expected nor returned. - in_kv_s_channels : int - Input key/value scalar channels. If None, no scalars are expected nor returned. - out_s_channels : int - Output scalar channels. If None, no scalars are expected nor returned. - additional_q_mv_channels : int - Whether additional multivector features for the queries will be provided. - additional_q_s_channels : int - Whether additional scalar features for the queries will be provided. - additional_k_mv_channels : int - Whether additional multivector features for the keys will be provided. - additional_k_s_channels : int - Whether additional scalar features for the keys will be provided. 
- multi_query: bool - Whether to do multi-query attention - output_init : str - Initialization scheme for final linear layer - increase_hidden_channels : int - Factor by which to increase the number of hidden channels (both multivectors and scalars) - dropout_prob : float or None - Dropout probability - head_scale: bool - Whether to use HeadScaleMHA following the NormFormer (https://arxiv.org/pdf/2110.09456) - """ - - in_q_mv_channels: Optional[int] = None - in_kv_mv_channels: Optional[int] = None - out_mv_channels: Optional[int] = None - out_s_channels: Optional[int] = None - in_q_s_channels: Optional[int] = None - in_kv_s_channels: Optional[int] = None - num_heads: int = 8 - additional_q_mv_channels: int = 0 - additional_q_s_channels: int = 0 - additional_k_mv_channels: int = 0 - additional_k_s_channels: int = 0 - multi_query: bool = True - output_init: str = "default" - checkpoint: bool = True - increase_hidden_channels: int = 2 - dropout_prob: Optional[float] = None - head_scale: bool = False - - def __post_init__(self): - """Type checking / conversion.""" - if isinstance(self.dropout_prob, str) and self.dropout_prob.lower() in [ - "null", - "none", - ]: - self.dropout_prob = None - - @property - def hidden_mv_channels(self) -> Optional[int]: - """Returns the number of hidden multivector channels.""" - - if self.in_q_mv_channels is None: - return None - - return max( - self.increase_hidden_channels * self.in_q_mv_channels // self.num_heads, 1 - ) - - @property - def hidden_s_channels(self) -> Optional[int]: - """Returns the number of hidden scalar channels.""" - - if self.in_q_s_channels is None: - assert self.in_kv_s_channels is None - return None - - return max( - self.increase_hidden_channels * self.in_q_s_channels // self.num_heads, 4 - ) - - @classmethod - def cast(cls, config: Any) -> CrossAttentionConfig: - """Casts an object as CrossAttentionConfig.""" - if isinstance(config, CrossAttentionConfig): - return config - if isinstance(config, Mapping): - return cls(**config) - raise ValueError(f"Can not cast {config} to {cls}") diff --git a/weaver/nn/model/gatr/layers/attention/cross_attention.py b/weaver/nn/model/gatr/layers/attention/cross_attention.py deleted file mode 100644 index 6fb1024b..00000000 --- a/weaver/nn/model/gatr/layers/attention/cross_attention.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Cross-attention layer.""" - -from typing import Optional, Tuple - -import torch -from einops import rearrange -from torch import nn - -from gatr.layers.attention.attention import GeometricAttention -from gatr.layers.attention.config import SelfAttentionConfig -from gatr.layers.dropout import GradeDropout -from gatr.layers.linear import EquiLinear - - -class CrossAttention(nn.Module): - """Geometric cross-attention layer. - - Constructs queries, keys, and values, computes attention, and projects linearly to outputs. - - Parameters - ---------- - config : SelfAttentionConfig - Attention configuration. 
- """ - - def __init__( - self, - config: SelfAttentionConfig, - ) -> None: - super().__init__() - - if ( - config.additional_q_mv_channels > 0 - or config.additional_q_s_channels > 0 - or config.additional_k_mv_channels > 0 - or config.additional_k_s_channels > 0 - ): - raise NotImplementedError( - "Cross attention is not implemented with additional channels" - ) - - # Store settings - self.config = config - - self.q_linear = EquiLinear( - in_mv_channels=config.in_q_mv_channels, - out_mv_channels=config.hidden_mv_channels * config.num_heads, - in_s_channels=config.in_q_s_channels, - out_s_channels=config.hidden_s_channels * config.num_heads, - ) - self.kv_linear = EquiLinear( - in_mv_channels=config.in_kv_mv_channels, - out_mv_channels=2 - * config.hidden_mv_channels - * (1 if config.multi_query else config.num_heads), - in_s_channels=config.in_kv_s_channels, - out_s_channels=2 - * config.hidden_s_channels - * (1 if config.multi_query else config.num_heads), - ) - - # Output projection - self.out_linear = EquiLinear( - in_mv_channels=config.hidden_mv_channels * config.num_heads, - out_mv_channels=config.out_mv_channels, - in_s_channels=( - None - if config.in_kv_s_channels is None - else config.hidden_s_channels * config.num_heads - ), - out_s_channels=config.out_s_channels, - initialization=config.output_init, - ) - - # Attention - self.attention = GeometricAttention(config) - - # Dropout - self.dropout: Optional[nn.Module] - if config.dropout_prob is not None: - raise ValueError( - "Dropout violates equivariance for cross_attention, " - "thats definitely a bug but didn't find the reason yet." - ) - self.dropout = GradeDropout(config.dropout_prob) - else: - self.dropout = None - - # HeadScaleMHA - self.use_head_scale = config.head_scale - if self.use_head_scale: - self.head_scale = nn.Parameter(torch.ones(config.num_heads)) - - def forward( - self, - multivectors_kv: torch.Tensor, - multivectors_q: torch.Tensor, - scalars_kv: Optional[torch.Tensor] = None, - scalars_q: Optional[torch.Tensor] = None, - attention_mask=None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute cross attention. - - Parameters - ---------- - multivectors_kv : torch.Tensor with shape (..., num_items_kv, channels_in, 16) - Input multivectors for key and value. - multivectors_q : torch.Tensor with shape (..., num_items_q, channels_in_q, 16) - Input multivectors for query. - scalars_kv : None or torch.Tensor with shape (..., num_items_kv, in_scalars) - Optional input scalars - scalars_q : None or torch.Tensor with shape (..., num_items_q, in_scalars_q) - Optional input scalars for query - attention_mask: torch.Tensor with shape (..., num_items_q, num_items_kv) or xformers mask. - Attention mask - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., num_items_q, channels_out, 16) - Output multivectors. - output_scalars : torch.Tensor with shape (..., num_items_q, channels_out, out_scalars) - Output scalars, if scalars are provided. Otherwise None. - """ - q_mv, q_s = self.q_linear( - multivectors_q, scalars_q - ) # (..., num_items, hidden_channels, 16) - kv_mv, kv_s = self.kv_linear( - multivectors_kv, scalars_kv - ) # (..., num_items, 2*hidden_channels, 16) - k_mv, v_mv = torch.tensor_split(kv_mv, 2, dim=-2) - k_s, v_s = torch.tensor_split(kv_s, 2, dim=-1) - - # Rearrange to (..., heads, items, channels, 16) shape - q_mv = rearrange( - q_mv, - "... items (hidden_channels num_heads) x -> ... 
num_heads items hidden_channels x", - num_heads=self.config.num_heads, - hidden_channels=self.config.hidden_mv_channels, - ) - if self.config.multi_query: - k_mv = rearrange( - k_mv, "... items hidden_channels x -> ... 1 items hidden_channels x" - ) - v_mv = rearrange( - v_mv, "... items hidden_channels x -> ... 1 items hidden_channels x" - ) - else: - k_mv = rearrange( - k_mv, - "... items (hidden_channels num_heads) x -> ... num_heads items hidden_channels x", - num_heads=self.config.num_heads, - hidden_channels=self.config.hidden_mv_channels, - ) - v_mv = rearrange( - v_mv, - "... items (hidden_channels num_heads) x -> ... num_heads items hidden_channels x", - num_heads=self.config.num_heads, - hidden_channels=self.config.hidden_mv_channels, - ) - - # Same for scalars - if q_s is not None: - q_s = rearrange( - q_s, - "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels", - num_heads=self.config.num_heads, - hidden_channels=self.config.hidden_s_channels, - ) - if self.config.multi_query: - k_s = rearrange( - k_s, "... items hidden_channels -> ... 1 items hidden_channels" - ) - v_s = rearrange( - v_s, "... items hidden_channels -> ... 1 items hidden_channels" - ) - else: - k_s = rearrange( - k_s, - "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels", - num_heads=self.config.num_heads, - hidden_channels=self.config.hidden_s_channels, - ) - v_s = rearrange( - v_s, - "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels", - num_heads=self.config.num_heads, - hidden_channels=self.config.hidden_s_channels, - ) - else: - q_s, k_s, v_s = None, None, None - - # Attention layer - h_mv, h_s = self.attention( - q_mv, k_mv, v_mv, q_s, k_s, v_s, attention_mask=attention_mask - ) - if self.use_head_scale: - h_mv = h_mv * self.head_scale.view( - *[1] * len(h_mv.shape[:-5]), len(self.head_scale), 1, 1, 1 - ) - h_s = h_s * self.head_scale.view( - *[1] * len(h_s.shape[:-4]), len(self.head_scale), 1, 1 - ) - - h_mv = rearrange( - h_mv, - "... n_heads n_items hidden_channels x -> ... n_items (n_heads hidden_channels) x", - ) - h_s = rearrange( - h_s, - "... n_heads n_items hidden_channels -> ... n_items (n_heads hidden_channels)", - ) - - # Transform linearly one more time - outputs_mv, outputs_s = self.out_linear(h_mv, scalars=h_s) - - # Dropout - if self.dropout is not None: - outputs_mv, outputs_s = self.dropout(outputs_mv, outputs_s) - - return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/attention/positional_encoding.py b/weaver/nn/model/gatr/layers/attention/positional_encoding.py deleted file mode 100644 index 3fe976c4..00000000 --- a/weaver/nn/model/gatr/layers/attention/positional_encoding.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Adapted from the below. - -https://github.com/EleutherAI/gpt-neox/blob/737c9134bfaff7b58217d61f6619f1dcca6c484f/megatron/model/positional_embeddings.py -by EleutherAI at https://github.com/EleutherAI/gpt-neox - -Copyright (c) 2021, EleutherAI - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -import torch - -from gatr.utils.einsum import cached_einsum - - -class ApplyRotaryPositionalEncoding(torch.nn.Module): - """Applies rotary position encodings (RoPE) to scalar tensors. - - References - ---------- - Jianlin Su et al, "RoFormer: Enhanced Transformer with Rotary Position Embedding", - arXiv:2104.09864 - - Parameters - ---------- - num_channels : int - Number of channels (key and query size). - item_dim : int - Embedding dimension. Should be even. - base : int - Determines the frequencies. - """ - - def __init__(self, num_channels, item_dim, base=4096): - super().__init__() - - assert ( - num_channels % 2 == 0 - ), "Number of channels needs to be even for rotary position embeddings" - - inv_freq = 1.0 / ( - base ** (torch.arange(0, num_channels, 2).float() / num_channels) - ) - self.register_buffer("inv_freq", inv_freq) - self.seq_len_cached = None - self.device_cached = None - self.cos_cached = None - self.sin_cached = None - self.item_dim = item_dim - self.num_channels = num_channels - - def forward(self, scalars: torch.Tensor) -> torch.Tensor: - """Computes rotary embeddings along `self.item_dim` and applies them to inputs. - - The inputs are usually scalar queries and keys. - - Assumes that the last dimension is the feature dimension (and is thus not suited - for multivector data!). - - Parameters - ---------- - scalars : torch.Tensor of shape (..., num_channels) - Input data. The last dimension is assumed to be the channel / feature dimension - (NOT the 16 dimensions of a multivector). - - Returns - ------- - outputs : torch.Tensor of shape (..., num_channels) - Output data. Rotary positional embeddings applied to the input tensor. - """ - - # Check inputs - assert scalars.shape[-1] == self.num_channels - - # Compute embeddings, if not already cached - self._compute_embeddings(scalars) - - # Apply embeddings - outputs = ( - scalars * self.cos_cached + self._rotate_half(scalars) * self.sin_cached - ) - - return outputs - - def _compute_embeddings(self, inputs): - """Computes position embeddings and stores them. - - The position embedding is computed along dimension `item_dim` of tensor `inputs` - and is stored in `self.sin_cached` and `self.cos_cached`. - - Parameters - ---------- - inputs : torch.Tensor - Input data. 
- """ - seq_len = inputs.shape[self.item_dim] - if seq_len != self.seq_len_cached or inputs.device != self.device_cached: - self.seq_len_cached = seq_len - self.device_cached = inputs.device - t = torch.arange(inputs.shape[self.item_dim], device=inputs.device).type_as( - self.inv_freq - ) - freqs = cached_einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(inputs.device) - - self.cos_cached = emb.cos() - self.sin_cached = emb.sin() - - # Insert appropriate amount of dimensions such that the embedding correctly enumerates - # along the item dim - item_dim = ( - self.item_dim if self.item_dim >= 0 else inputs.ndim + self.item_dim - ) # Deal with item_dim < 0 - for _ in range(item_dim + 1, inputs.ndim - 1): - self.cos_cached = self.cos_cached.unsqueeze(1) - self.sin_cached = self.sin_cached.unsqueeze(1) - - @staticmethod - def _rotate_half(inputs): - """Utility function that "rotates" a tensor, as required for rotary embeddings.""" - x1, x2 = ( - inputs[..., : inputs.shape[-1] // 2], - inputs[..., inputs.shape[-1] // 2 :], - ) - return torch.cat((-x2, x1), dim=-1) diff --git a/weaver/nn/model/gatr/layers/attention/self_attention.py b/weaver/nn/model/gatr/layers/attention/self_attention.py index ab8c10a5..6b148532 100644 --- a/weaver/nn/model/gatr/layers/attention/self_attention.py +++ b/weaver/nn/model/gatr/layers/attention/self_attention.py @@ -8,7 +8,6 @@ from gatr.layers.attention.attention import GeometricAttention from gatr.layers.attention.config import SelfAttentionConfig -from gatr.layers.attention.positional_encoding import ApplyRotaryPositionalEncoding from gatr.layers.attention.qkv import MultiQueryQKVModule, QKVModule from gatr.layers.dropout import GradeDropout from gatr.layers.linear import EquiLinear @@ -49,15 +48,6 @@ def __init__(self, config: SelfAttentionConfig) -> None: initialization=config.output_init, ) - # Optional positional encoding - self.pos_encoding: nn.Module - if config.pos_encoding: - self.pos_encoding = ApplyRotaryPositionalEncoding( - config.hidden_s_channels, item_dim=-2, base=config.pos_encoding_base - ) - else: - self.pos_encoding = nn.Identity() - # Attention self.attention = GeometricAttention(config) @@ -126,10 +116,6 @@ def forward( multivectors, scalars, additional_qk_features_mv, additional_qk_features_s ) - # Rotary positional encoding - q_s = self.pos_encoding(q_s) - k_s = self.pos_encoding(k_s) - # Attention layer h_mv, h_s = self.attention( q_mv, k_mv, v_mv, q_s, k_s, v_s, attention_mask=attention_mask diff --git a/weaver/nn/model/gatr/layers/conditional_gatr_block.py b/weaver/nn/model/gatr/layers/conditional_gatr_block.py deleted file mode 100644 index 073b8bba..00000000 --- a/weaver/nn/model/gatr/layers/conditional_gatr_block.py +++ /dev/null @@ -1,191 +0,0 @@ -from dataclasses import replace -from typing import Optional, Tuple - -import torch -from torch import nn - -from gatr.layers import ( - SelfAttention, - CrossAttention, - SelfAttentionConfig, - CrossAttentionConfig, -) -from gatr.layers.layer_norm import EquiLayerNorm -from gatr.layers.mlp.config import MLPConfig -from gatr.layers.mlp.mlp import GeoMLP - - -class ConditionalGATrBlock(nn.Module): - """Equivariant transformer decoder block for L-GATr. - - Inputs are first processed by a block consisting of LayerNorm, multi-head geometric - self-attention, and a residual connection. Then the conditions are included with - cross-attention using the same overhead as in the self-attention part. 
- Then the data is processed by a block consisting of - another LayerNorm, an item-wise two-layer geometric MLP with GeLU activations, and another - residual connection. - - Parameters - ---------- - mv_channels : int - Number of input and output multivector channels - s_channels: int - Number of input and output scalar channels - condition_mv_channels: int - Number of condition multivector channels - condition_s_channels: int - Number of condition scalar channels - attention: SelfAttentionConfig - Attention configuration - crossattention: CrossAttentionConfig - Cross-attention configuration - mlp: MLPConfig - MLP configuration - dropout_prob : float or None - Dropout probability - double_layernorm : bool - Whether to use double layer normalization - """ - - def __init__( - self, - mv_channels: int, - s_channels: int, - condition_mv_channels: int, - condition_s_channels: int, - attention: SelfAttentionConfig, - crossattention: CrossAttentionConfig, - mlp: MLPConfig, - dropout_prob: Optional[float] = None, - double_layernorm: bool = False, - ) -> None: - super().__init__() - - # Normalization layer (stateless, so we can use the same layer for both normalization - # instances) - self.norm = EquiLayerNorm() - self.double_layernorm = double_layernorm - - # Self-attention layer - attention = replace( - attention, - in_mv_channels=mv_channels, - out_mv_channels=mv_channels, - in_s_channels=s_channels, - out_s_channels=s_channels, - output_init="small", - dropout_prob=dropout_prob, - ) - self.attention = SelfAttention(attention) - - # Cross-attention layer - crossattention = replace( - crossattention, - in_q_mv_channels=mv_channels, - in_q_s_channels=s_channels, - in_kv_mv_channels=condition_mv_channels, - in_kv_s_channels=condition_s_channels, - out_mv_channels=mv_channels, - out_s_channels=s_channels, - output_init="small", - dropout_prob=dropout_prob, - ) - self.crossattention = CrossAttention(crossattention) - - # MLP block - mlp = replace( - mlp, - mv_channels=(mv_channels, 2 * mv_channels, mv_channels), - s_channels=(s_channels, 2 * s_channels, s_channels), - dropout_prob=dropout_prob, - ) - self.mlp = GeoMLP(mlp) - - def forward( - self, - multivectors: torch.Tensor, - multivectors_condition: torch.Tensor, - scalars: torch.Tensor = None, - scalars_condition: torch.Tensor = None, - attention_mask=None, - crossattention_mask=None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward pass of the transformer decoder block. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., items, channels, 16) - Input multivectors. - scalars : torch.Tensor with shape (..., s_channels) - Input scalars. - multivectors_condition : torch.Tensor with shape (..., items, channels, 16) - Input condition multivectors. - scalars_condition : torch.Tensor with shape (..., s_channels) - Input condition scalars. - attention_mask: None or torch.Tensor or AttentionBias - Optional attention mask. - crossattention_mask: None or torch.Tensor or AttentionBias - Optional attention mask for the condition. - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., items, channels, 16). 
- Output multivectors - output_scalars : torch.Tensor with shape (..., s_channels) - Output scalars - """ - - # Self-attention block: pre layer norm - h_mv, h_s = self.norm(multivectors, scalars=scalars) - - # Self-attention block: self attention - h_mv, h_s = self.attention( - h_mv, - scalars=h_s, - attention_mask=attention_mask, - ) - - # Self-attention block: post layer norm - if self.double_layernorm: - h_mv, h_s = self.norm(h_mv, scalars=h_s) - - # Self-attention block: skip connection - multivectors = multivectors + h_mv - scalars = scalars + h_s - - # Cross-attention block: pre layer norm - h_mv, h_s = self.norm(multivectors, scalars=scalars) - c_mv, c_s = self.norm(multivectors_condition, scalars=scalars_condition) - - # Cross-attention block: cross attention - h_mv, h_s = self.crossattention( - multivectors_q=h_mv, - multivectors_kv=c_mv, - scalars_q=h_s, - scalars_kv=c_s, - attention_mask=crossattention_mask, - ) - - # Cross-attention block: post layer norm - if self.double_layernorm: - h_mv, h_s = self.norm(h_mv, scalars=h_s) - - # Cross-attention block: skip connection - outputs_mv = multivectors + h_mv - outputs_s = scalars + h_s - - # MLP block: pre layer norm - h_mv, h_s = self.norm(outputs_mv, scalars=outputs_s) - - # MLP block: MLP - h_mv, h_s = self.mlp(h_mv, scalars=h_s) - - # MLP block: post layer norm - if self.double_layernorm: - h_mv, h_s = self.norm(h_mv, scalars=h_s) - - # MLP block: skip connection - outputs_mv = outputs_mv + h_mv - outputs_s = outputs_s + h_s - - return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/nets/__init__.py b/weaver/nn/model/gatr/nets/__init__.py index 2489425b..2c9e6e9e 100644 --- a/weaver/nn/model/gatr/nets/__init__.py +++ b/weaver/nn/model/gatr/nets/__init__.py @@ -1,4 +1 @@ -from .axial_gatr import AxialGATr from .gatr import GATr -from .gap import GAP -from .conditional_gatr import ConditionalGATr diff --git a/weaver/nn/model/gatr/nets/axial_gatr.py b/weaver/nn/model/gatr/nets/axial_gatr.py deleted file mode 100644 index dc42f989..00000000 --- a/weaver/nn/model/gatr/nets/axial_gatr.py +++ /dev/null @@ -1,214 +0,0 @@ -from dataclasses import replace -from typing import Optional, Tuple, Union - -import torch -from einops import rearrange -from torch import nn -from torch.utils.checkpoint import checkpoint - -from gatr.layers.attention.config import SelfAttentionConfig -from gatr.layers.gatr_block import GATrBlock -from gatr.layers.linear import EquiLinear -from gatr.layers.mlp.config import MLPConfig - -# Default rearrange patterns -_MV_REARRANGE_PATTERN = "... i j c x -> ... j i c x" -_S_REARRANGE_PATTERN = "... i j c -> ... j i c" - - -class AxialGATr(nn.Module): # pylint: disable=duplicate-code - """Axial L-GATr network for two token dimensions. - - It combines `num_blocks` L-GATr transformer blocks, each consisting of geometric self-attention - layers, a geometric MLP, residual connections, and normalization layers. In addition, there - are initial and final equivariant linear layers. - - Assumes input data with shape `(..., num_items_1, num_items_2, num_channels, 16)`. - - The first, third, fifth, ... block computes attention over the `items_2` axis. The other blocks - compute attention over the `items_1` axis. Positional encoding can be specified separately for - both axes. - - Parameters - ---------- - in_mv_channels : int - Number of input multivector channels. - out_mv_channels : int - Number of output multivector channels. - hidden_mv_channels : int - Number of hidden multivector channels. 
- in_s_channels : None or int - If not None, sets the number of scalar input channels. - out_s_channels : None or int - If not None, sets the number of scalar output channels. - hidden_s_channels : None or int - If not None, sets the number of scalar hidden channels. - attention: Dict - Data for SelfAttentionConfig - mlp: Dict - Data for MLPConfig - num_blocks : int - Number of transformer blocks. - pos_encodings : tuple of bool - Whether to apply rotary positional embeddings along the item dimensions to the scalar keys - and queries. The first element in the tuple determines whether positional embeddings - are applied to the first item dimension, the second element the same for the second item - dimension. - collapse_dims_for_odd_blocks : bool - Whether the batch dimensions will be collapsed in odd blocks (to support xformers block - attention) - """ - - def __init__( - self, - in_mv_channels: int, - out_mv_channels: int, - hidden_mv_channels: int, - in_s_channels: Optional[int], - out_s_channels: Optional[int], - hidden_s_channels: Optional[int], - attention: SelfAttentionConfig, - mlp: MLPConfig, - num_blocks: int = 20, - checkpoint_blocks: bool = False, - pos_encodings: Tuple[bool, bool] = (False, False), - collapse_dims_for_odd_blocks=False, - **kwargs, - ) -> None: - super().__init__() - self.linear_in = EquiLinear( - in_mv_channels, - hidden_mv_channels, - in_s_channels=in_s_channels, - out_s_channels=hidden_s_channels, - ) - attention = SelfAttentionConfig.cast(attention) - mlp = MLPConfig.cast(mlp) - self.blocks = nn.ModuleList( - [ - GATrBlock( - mv_channels=hidden_mv_channels, - s_channels=hidden_s_channels, - attention=replace( - attention, - pos_encoding=pos_encodings[(block + 1) % 2], - ), - mlp=mlp, - ) - for block in range(num_blocks) - ] - ) - self.linear_out = EquiLinear( - hidden_mv_channels, - out_mv_channels, - in_s_channels=hidden_s_channels, - out_s_channels=out_s_channels, - ) - self._checkpoint_blocks = checkpoint_blocks - self._collapse_dims_for_odd_blocks = collapse_dims_for_odd_blocks - - def forward( - self, - multivectors: torch.Tensor, - scalars: Optional[torch.Tensor] = None, - attention_mask: Optional[Tuple] = None, - ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: - """Forward pass of the network. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., num_items_1, num_items_2, in_mv_channels, 16) - Input multivectors. - scalars : None or torch.Tensor with shape (..., num_items_1, num_items_2, in_s_channels) - Optional input scalars. - attention_mask : None or tuple of torch.Tensor - Optional attention masks - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., num_items_1, num_items_2, out_mv_channels, 16) - Output multivectors. - outputs_s : None or torch.Tensor with shape - (..., num_items_1, num_items_2, out_mv_channels, 16) - Output scalars, if scalars are provided. Otherwise None. - """ - - # Pass through the blocks - h_mv, h_s = self.linear_in(multivectors, scalars=scalars) - - for i, block in enumerate(self.blocks): - # For first, third, ... block, we want to perform attention over the first token - # dimension. We implement this by transposing the two item dimensions. 
- if i % 2 == 1: - h_mv, h_s, input_batch_dims = self._reshape_data_before_odd_blocks( - h_mv, h_s - ) - else: - input_batch_dims = None - - # Attention masks will also be different - if attention_mask is None: - this_attention_mask = None - else: - this_attention_mask = attention_mask[(i + 1) % 2] - - if self._checkpoint_blocks: - h_mv, h_s = checkpoint( - block, - h_mv, - use_reentrant=False, - scalars=h_s, - attention_mask=this_attention_mask, - ) - else: - h_mv, h_s = block( - h_mv, - scalars=h_s, - attention_mask=this_attention_mask, - ) - - # Transposing back to standard axis order - if i % 2 == 1: - h_mv, h_s = self._reshape_data_after_odd_blocks( - h_mv, h_s, input_batch_dims - ) - - outputs_mv, outputs_s = self.linear_out(h_mv, scalars=h_s) - - return outputs_mv, outputs_s - - def _reshape_data_before_odd_blocks(self, multivector, scalar): - # Prepare reshuffling between axial layers - input_batch_dims = multivector.shape[:2] - assert scalar.shape[:2] == input_batch_dims - - multivector = rearrange( - multivector, _MV_REARRANGE_PATTERN - ) # (axis2, axis1, ...) - scalar = rearrange(scalar, _S_REARRANGE_PATTERN) # (axis2, axis1, ...) - - if self._collapse_dims_for_odd_blocks: - multivector = multivector.reshape( - -1, *multivector.shape[2:] - ) # (axis2 * axis1, ...) - scalar = scalar.reshape(-1, *scalar.shape[2:]) # (axis2 * axis1, ...) - - return multivector, scalar, input_batch_dims - - def _reshape_data_after_odd_blocks(self, multivector, scalar, input_batch_dims): - # Transposing back to standard axis order - - if self._collapse_dims_for_odd_blocks: - multivector = multivector.reshape( - *input_batch_dims, *multivector.shape[1:] - ) # (axis2, axis1, ...) - scalar = scalar.reshape( - *input_batch_dims, *scalar.shape[1:] - ) # (axis2, axis1, ...) - - multivector = rearrange( - multivector, _MV_REARRANGE_PATTERN - ) # (axis1, axis2, ...) - scalar = rearrange(scalar, _S_REARRANGE_PATTERN) # (axis1, axis2, ...) - - return multivector, scalar diff --git a/weaver/nn/model/gatr/nets/conditional_gatr.py b/weaver/nn/model/gatr/nets/conditional_gatr.py deleted file mode 100644 index e23d6617..00000000 --- a/weaver/nn/model/gatr/nets/conditional_gatr.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Equivariant transformer for multivector data.""" - -from dataclasses import replace -from typing import Optional, Tuple, Union - -import torch -from torch import nn -from torch.utils.checkpoint import checkpoint - -from gatr.layers import ( - CrossAttentionConfig, - SelfAttentionConfig, - GATrBlock, - ConditionalGATrBlock, - EquiLinear, -) -from gatr.layers.mlp.config import MLPConfig - - -class ConditionalGATr(nn.Module): - """L-GATr network for a data with a single token dimension. - - It combines `num_blocks` L-GATr transformer blocks, each consisting of geometric self-attention - layers, a geometric MLP, residual connections, and normalization layers. In addition, there - are initial and final equivariant linear layers. - - Assumes input has shape `(..., items, in_channels, 16)`, output has shape - `(..., items, out_channels, 16)`, will create hidden representations with shape - `(..., items, hidden_channels, 16)`. - - Parameters - ---------- - in_mv_channels : int - Number of input multivector channels. - condition_mv_channels : int - Number of condition multivector channels. - out_mv_channels : int - Number of output multivector channels. - hidden_mv_channels : int - Number of hidden multivector channels. - in_s_channels : None or int - If not None, sets the number of scalar input channels. 
- condition_s_channels : None or int - If not None, sets the number of scalar condition channels. - out_s_channels : None or int - If not None, sets the number of scalar output channels. - hidden_s_channels : None or int - If not None, sets the number of scalar hidden channels. - attention: Dict - Data for SelfAttentionConfig - crossattention: Dict - Data for CrossAttentionConfig - attention_condition: Dict - Data for SelfAttentionConfig - mlp: Dict - Data for MLPConfig - num_blocks : int - Number of transformer blocks. - dropout_prob : float or None - Dropout probability - double_layernorm : bool - Whether to use double layer normalization - """ - - def __init__( - self, - in_mv_channels: int, - condition_mv_channels: int, - out_mv_channels: int, - hidden_mv_channels: int, - in_s_channels: Optional[int], - condition_s_channels: Optional[int], - out_s_channels: Optional[int], - hidden_s_channels: Optional[int], - attention: SelfAttentionConfig, - crossattention: CrossAttentionConfig, - attention_condition: SelfAttentionConfig, - mlp: MLPConfig, - num_blocks: int = 10, - checkpoint_blocks: bool = False, - dropout_prob: Optional[float] = None, - double_layernorm: bool = False, - **kwargs, - ) -> None: - super().__init__() - self.linear_in_condition = EquiLinear( - condition_mv_channels, - hidden_mv_channels, - in_s_channels=condition_s_channels, - out_s_channels=hidden_s_channels, - ) - self.linear_in = EquiLinear( - in_mv_channels, - hidden_mv_channels, - in_s_channels=in_s_channels, - out_s_channels=hidden_s_channels, - ) - self.linear_condition = EquiLinear( - condition_mv_channels, - hidden_mv_channels, - in_s_channels=condition_s_channels, - out_s_channels=hidden_s_channels, - ) - attention = SelfAttentionConfig.cast(attention) - crossattention = CrossAttentionConfig.cast(crossattention) - attention_condition = SelfAttentionConfig.cast(attention_condition) - mlp = MLPConfig.cast(mlp) - self.condition_blocks = nn.ModuleList( - [ - GATrBlock( - mv_channels=hidden_mv_channels, - s_channels=hidden_s_channels, - attention=attention_condition, - mlp=mlp, - dropout_prob=dropout_prob, - double_layernorm=double_layernorm, - ) - for _ in range(num_blocks) - ] - ) - self.blocks = nn.ModuleList( - [ - ConditionalGATrBlock( - mv_channels=hidden_mv_channels, - s_channels=hidden_s_channels, - condition_mv_channels=hidden_mv_channels, - condition_s_channels=hidden_s_channels, - attention=attention, - crossattention=crossattention, - mlp=mlp, - dropout_prob=dropout_prob, - double_layernorm=double_layernorm, - ) - ] - ) - self.linear_out = EquiLinear( - hidden_mv_channels, - out_mv_channels, - in_s_channels=hidden_s_channels, - out_s_channels=out_s_channels, - ) - self._checkpoint_blocks = checkpoint_blocks - - def forward( - self, - multivectors: torch.Tensor, - multivectors_condition: torch.Tensor, - scalars: Optional[torch.Tensor] = None, - scalars_condition: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - attention_mask_condition: Optional[torch.Tensor] = None, - crossattention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: - """Forward pass of the network. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., num_items, in_mv_channels, 16) - Input multivectors. - multivectors_condition : torch.Tensor with shape (..., num_items_condition, in_mv_channels, 16) - Input multivectors. - scalars : None or torch.Tensor with shape (..., num_items, in_s_channels) - Optional input scalars. 
- scalars_condition : None or torch.Tensor with shape (..., num_items_condition, in_s_channels) - Optional input scalars. - attention_mask: None or torch.Tensor with shape (..., num_items, num_items) - Optional attention mask - attention_mask_condition: None or torch.Tensor with shape (..., num_items_condition, num_items_condition) - Optional attention mask for condition - crossattention_mask: None or torch.Tensor with shape (..., num_items, num_items_condition) - Optional mask for cross-attention - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., num_items, out_mv_channels, 16) - Output multivectors. - outputs_s : None or torch.Tensor with shape (..., num_items, out_s_channels) - Output scalars, if scalars are provided. Otherwise None. - """ - - # Encode condition with GATr blocks - c_mv, c_s = self.linear_in_condition( - multivectors_condition, scalars=scalars_condition - ) - for block in self.condition_blocks: - if self._checkpoint_blocks: - c_mv, c_s = checkpoint( - block, - c_mv, - use_reentrant=False, - scalars=c_s, - attention_mask=attention_mask_condition, - ) - else: - c_mv, c_s = block( - c_mv, - scalars=c_s, - attention_mask=attention_mask_condition, - ) - - # Decode condition into main track with - h_mv, h_s = self.linear_in(multivectors, scalars=scalars) - for block in self.blocks: - if self._checkpoint_blocks: - h_mv, h_s = checkpoint( - block, - h_mv, - use_reentrant=False, - scalars=h_s, - multivectors_condition=c_mv, - scalars_condition=c_s, - attention_mask=attention_mask, - crossattention_mask=crossattention_mask, - ) - else: - h_mv, h_s = block( - h_mv, - scalars=h_s, - multivectors_condition=c_mv, - scalars_condition=c_s, - attention_mask=attention_mask, - crossattention_mask=crossattention_mask, - ) - - outputs_mv, outputs_s = self.linear_out(h_mv, scalars=h_s) - - return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/nets/gap.py b/weaver/nn/model/gatr/nets/gap.py deleted file mode 100644 index d8e75234..00000000 --- a/weaver/nn/model/gatr/nets/gap.py +++ /dev/null @@ -1,120 +0,0 @@ -from dataclasses import replace -from typing import Optional, Tuple, Union - -import torch -from torch import nn -from torch.utils.checkpoint import checkpoint - -from gatr.layers.linear import EquiLinear -from gatr.layers.mlp import MLPConfig, GeoMLP - - -class GAP(nn.Module): - """Geometric Algebra Perceptron network for a data with a single token dimension. - It combines `num_blocks` GeoMLP blocks. - - Assumes input has shape `(..., in_channels, 16)`, output has shape - `(..., out_channels, 16)`, will create hidden representations with shape - `(..., hidden_channels, 16)`. - - Parameters - ---------- - in_mv_channels : int - Number of input multivector channels. - out_mv_channels : int - Number of output multivector channels. - hidden_mv_channels : int - Number of hidden multivector channels. - in_s_channels : None or int - If not None, sets the number of scalar input channels. - out_s_channels : None or int - If not None, sets the number of scalar output channels. - hidden_s_channels : None or int - If not None, sets the number of scalar hidden channels. - num_blocks : int - Number of resnet blocks. 
- dropout_prob : float or None - Dropout probability - """ - - def __init__( - self, - in_mv_channels: int, - out_mv_channels: int, - hidden_mv_channels: int, - in_s_channels: Optional[int], - out_s_channels: Optional[int], - hidden_s_channels: Optional[int], - mlp: MLPConfig, - num_blocks: int = 10, - num_layers: int = 3, - checkpoint_blocks: bool = False, - dropout_prob: Optional[float] = None, - **kwargs, - ) -> None: - super().__init__() - - self.linear_in = EquiLinear( - in_mv_channels, - hidden_mv_channels, - in_s_channels=in_s_channels, - out_s_channels=hidden_s_channels, - ) - - mlp = MLPConfig.cast(mlp) - mlp = replace( - mlp, - mv_channels=[hidden_mv_channels for _ in range(num_layers)], - s_channels=[hidden_s_channels for _ in range(num_layers)], - dropout_prob=dropout_prob, - ) - self.blocks = nn.ModuleList([GeoMLP(mlp) for _ in range(num_blocks)]) - - self.linear_out = EquiLinear( - hidden_mv_channels, - out_mv_channels, - in_s_channels=hidden_s_channels, - out_s_channels=out_s_channels, - ) - self._checkpoint_blocks = checkpoint_blocks - - def forward( - self, - multivectors: torch.Tensor, - scalars: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: - """Forward pass of the network. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., in_mv_channels, 16) - Input multivectors. - scalars : None or torch.Tensor with shape (..., in_s_channels) - Optional input scalars. - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., out_mv_channels, 16) - Output multivectors. - outputs_s : None or torch.Tensor with shape (..., out_s_channels) - Output scalars, if scalars are provided. Otherwise None. - """ - - # Pass through the blocks - h_mv, h_s = self.linear_in(multivectors, scalars=scalars) - for block in self.blocks: - if self._checkpoint_blocks: - h_mv, h_s = checkpoint( - block, - h_mv, - use_reentrant=False, - scalars=h_s, - ) - else: - h_mv, h_s = block( - h_mv, - scalars=h_s, - ) - outputs_mv, outputs_s = self.linear_out(h_mv, scalars=h_s) - - return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/primitives/attention.py b/weaver/nn/model/gatr/primitives/attention.py index 9344647f..c69e60d7 100644 --- a/weaver/nn/model/gatr/primitives/attention.py +++ b/weaver/nn/model/gatr/primitives/attention.py @@ -1,20 +1,12 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch from einops import rearrange from torch import Tensor from torch.nn.functional import scaled_dot_product_attention as torch_sdpa -from xformers.ops import AttentionBias, memory_efficient_attention from gatr.primitives.invariants import _load_inner_product_factors -# Masked out attention logits are set to this constant (a finite replacement for -inf): -_MASKED_OUT = float("-inf") - -# Force the use of xformers attention, even when no xformers attention mask is provided: -FORCE_XFORMERS = False - - def sdp_attention( q_mv: Tensor, k_mv: Tensor, @@ -94,14 +86,11 @@ def scaled_dot_product_attention( query: Tensor, key: Tensor, value: Tensor, - attn_mask: Optional[Union[AttentionBias, Tensor]] = None, + attn_mask: Optional[Tensor] = None, is_causal=False, ) -> Tensor: """Execute (vanilla) scaled dot-product attention. - Dynamically dispatch to xFormers if attn_mask is an instance of xformers.ops.AttentionBias - or FORCE_XFORMERS is set, use torch otherwise. 
- Parameters ---------- query : Tensor @@ -110,7 +99,7 @@ def scaled_dot_product_attention( of shape [batch, head, item, d] value : Tensor of shape [batch, head, item, d] - attn_mask : Optional[Union[AttentionBias, Tensor]] + attn_mask : Optional[Tensor] Attention mask is_causal: bool @@ -119,24 +108,4 @@ def scaled_dot_product_attention( Tensor of shape [batch, head, item, d] """ - if FORCE_XFORMERS or isinstance(attn_mask, AttentionBias): - assert ( - not is_causal - ), "is_causal=True not implemented yet for xformers attention" - if key.shape[1] != query.shape[1]: # required to make multi_query work - key = key.expand(key.shape[0], query.shape[1], *key.shape[2:]) - value = value.expand(value.shape[0], query.shape[1], *value.shape[2:]) - query = query.transpose( - 1, 2 - ) # [batch, head, item, d] -> [batch, item, head, d] - key = key.transpose(1, 2) - value = value.transpose(1, 2) - out = memory_efficient_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), - attn_bias=attn_mask, - ) - out = out.transpose(1, 2) # [batch, item, head, d] -> [batch, head, item, d] - return out return torch_sdpa(query, key, value, attn_mask=attn_mask, is_causal=is_causal) diff --git a/weaver/nn/model/gatr/utils/clifford.py b/weaver/nn/model/gatr/utils/clifford.py deleted file mode 100644 index 983f3031..00000000 --- a/weaver/nn/model/gatr/utils/clifford.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Geometric algebra operations based on the clifford library.""" - -from typing import Optional - -import clifford -import numpy as np -import torch - -LAYOUT, BLADES = clifford.Cl(1, 3) - - -def np_to_mv(array): - """Shorthand to transform a numpy array to a Pin(1,3) multivector.""" - return clifford.MultiVector(LAYOUT, value=array) - - -def tensor_to_mv(tensor): - """Shorthand to transform a numpy array to a Pin(1,3) multivector.""" - return np_to_mv(tensor.detach().cpu().numpy()) - - -def tensor_to_mv_list(tensor): - """Transforms a torch.Tensor to a list of multivector objects.""" - - tensor = tensor.reshape((-1, 16)) - mv_list = [tensor_to_mv(x) for x in tensor] - - return mv_list - - -def mv_list_to_tensor(multivectors, batch_shape=None): - """Transforms a list of multivector objects to a torch.Tensor.""" - - tensor = torch.from_numpy(np.array([mv.value for mv in multivectors])).to( - torch.float32 - ) - if batch_shape is not None: - tensor = tensor.reshape(*batch_shape, 16) - - return tensor - - -def sample_pin_multivector( - spin: bool = False, rng: Optional[np.random.Generator] = None -): - """Samples from the Pin(1,3) group as a product of reflections.""" - - if rng is None: - rng = np.random.default_rng() - - # Sample number of reflections we want to multiply - if spin: - i = np.random.randint(3) * 2 - else: - i = np.random.randint(5) - - # If no reflections, just return unit scalar - if i == 0: - return BLADES[""] - - multivector = 1.0 - for _ in range(i): - # Sample reflection vector - vector = np.zeros(16) - vector[2:5] = rng.normal(size=3) * 2 - norm = np.linalg.norm(vector[2:5]) - vector[1] = (rng.uniform(size=1) - 0.5) * norm - - vector_mv = np_to_mv(vector) - vector_mv = vector_mv / abs(vector_mv.mag2()) ** 0.5 - - # Multiply together (geometric product) - multivector = multivector * vector_mv - - return multivector - - -def get_parity(mv): - """Gets parity of a clifford multivector. - - Given a clifford multivector, returns True if it is pure-odd-grade, False if it is pure-even - grade, and raises a RuntimeError if it is mixed. 
- """ - if mv == mv.even: - return False - if mv == mv.odd: - return True - raise RuntimeError(f"Mixed-grade multivector: {mv}") - - -def sandwich(u, x): - """Given clifford multivectors, computes their sandwich product. - - Specifically, given a Pin element u and a PGA element x, both given as clifford multivectors, - computes the sandwich product - ``` - sandwich(x, u) = (-1)^(grade(u) * grade(x)) u x u^{-1} . - ``` - - If `u` is of odd grades, then this is equal to `u * grade_involute(x) * u^{-1}`. - If `u` is of even grades, then this is equal to `u * x * u^{-1}`. - """ - - if get_parity(u): - return u * x.gradeInvol() * u.shirokov_inverse() - - return u * x * u.shirokov_inverse() - - -class SlowRandomPinTransform: - """Random Pin transform on a multivector torch.Tensor. - - Slow, only used for testing purposes. Breaks computational graph. - """ - - def __init__(self, spin=False, rng=None): - super().__init__() - self._u = sample_pin_multivector(spin, rng) - self._u_inverse = self._u.shirokov_inverse() - - def __call__(self, inputs: torch.Tensor) -> torch.Tensor: - """Apply Pin transformation to multivector inputs.""" - # Input shape - assert inputs.shape[-1] == 16 - batch_dims = inputs.shape[:-1] - - # Convert inputs to list of multivectors - inputs_mv = tensor_to_mv_list(inputs) - - # Transform - outputs_mv = [sandwich(self._u, x) for x in inputs_mv] - - # Back to tensor - outputs = mv_list_to_tensor(outputs_mv, batch_shape=batch_dims) - - return outputs diff --git a/weaver/nn/model/gatr/utils/tensors.py b/weaver/nn/model/gatr/utils/tensors.py deleted file mode 100644 index 3373e20c..00000000 --- a/weaver/nn/model/gatr/utils/tensors.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch - - -def assert_equal(vals): - """Assert all values in sequence are equal.""" - for v in vals: - assert v == vals[0] - - -def block_stack(tensors, dim1, dim2): - """Block diagonally stack tensors along dimensions dim1 and dim2.""" - assert_equal([t.dim() for t in tensors]) - shapes = [t.shape for t in tensors] - shapes_t = list(map(list, zip(*shapes))) - for i, ss in enumerate(shapes_t): - if i not in (dim1, dim2): - assert_equal(ss) - - dim2_len = sum(shapes_t[dim2]) - opts = dict(device=tensors[0].device, dtype=tensors[0].dtype) - - padded_tensors = [] - offset = 0 - for tensor in tensors: - before_shape = list(tensor.shape) - before_shape[dim2] = offset - after_shape = list(tensor.shape) - after_shape[dim2] = dim2_len - tensor.shape[dim2] - offset - before = torch.zeros(*before_shape, **opts) - after = torch.zeros(*after_shape, **opts) - padded = torch.cat([before, tensor, after], dim2) - padded_tensors.append(padded) - offset += tensor.shape[dim2] - return torch.cat(padded_tensors, dim1) From 1bf728e55333ada686e42ee3d85fac1e284e23a0 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Thu, 13 Feb 2025 21:01:48 +0100 Subject: [PATCH 03/29] Make gatr imports relative --- weaver/nn/model/gatr/interface/spurions.py | 2 +- weaver/nn/model/gatr/layers/attention/attention.py | 4 ++-- weaver/nn/model/gatr/layers/attention/qkv.py | 4 ++-- .../nn/model/gatr/layers/attention/self_attention.py | 10 +++++----- weaver/nn/model/gatr/layers/dropout.py | 2 +- weaver/nn/model/gatr/layers/gatr_block.py | 8 ++++---- weaver/nn/model/gatr/layers/layer_norm.py | 2 +- weaver/nn/model/gatr/layers/linear.py | 4 ++-- weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py | 6 +++--- weaver/nn/model/gatr/layers/mlp/mlp.py | 10 +++++----- weaver/nn/model/gatr/layers/mlp/nonlinearities.py | 2 +- weaver/nn/model/gatr/nets/gatr.py | 8 
++++---- weaver/nn/model/gatr/primitives/attention.py | 2 +- weaver/nn/model/gatr/primitives/bilinear.py | 2 +- weaver/nn/model/gatr/primitives/dropout.py | 2 +- weaver/nn/model/gatr/primitives/invariants.py | 4 ++-- weaver/nn/model/gatr/primitives/linear.py | 2 +- weaver/nn/model/gatr/primitives/normalization.py | 2 +- 18 files changed, 38 insertions(+), 38 deletions(-) diff --git a/weaver/nn/model/gatr/interface/spurions.py b/weaver/nn/model/gatr/interface/spurions.py index 58537a21..69fde990 100644 --- a/weaver/nn/model/gatr/interface/spurions.py +++ b/weaver/nn/model/gatr/interface/spurions.py @@ -1,5 +1,5 @@ import torch -from gatr.interface import embed_vector +from .vector import embed_vector def get_num_spurions( diff --git a/weaver/nn/model/gatr/layers/attention/attention.py b/weaver/nn/model/gatr/layers/attention/attention.py index 73b7f5cb..94a654ad 100644 --- a/weaver/nn/model/gatr/layers/attention/attention.py +++ b/weaver/nn/model/gatr/layers/attention/attention.py @@ -2,8 +2,8 @@ from torch import nn -from gatr.layers.attention.config import SelfAttentionConfig -from gatr.primitives.attention import sdp_attention +from .config import SelfAttentionConfig +from ...primitives.attention import sdp_attention class GeometricAttention(nn.Module): diff --git a/weaver/nn/model/gatr/layers/attention/qkv.py b/weaver/nn/model/gatr/layers/attention/qkv.py index 70a9698c..17179a33 100644 --- a/weaver/nn/model/gatr/layers/attention/qkv.py +++ b/weaver/nn/model/gatr/layers/attention/qkv.py @@ -2,8 +2,8 @@ from einops import rearrange from torch import nn -from gatr.layers.attention.config import SelfAttentionConfig -from gatr.layers.linear import EquiLinear +from .config import SelfAttentionConfig +from ..linear import EquiLinear class QKVModule(nn.Module): diff --git a/weaver/nn/model/gatr/layers/attention/self_attention.py b/weaver/nn/model/gatr/layers/attention/self_attention.py index 6b148532..b34141a3 100644 --- a/weaver/nn/model/gatr/layers/attention/self_attention.py +++ b/weaver/nn/model/gatr/layers/attention/self_attention.py @@ -6,11 +6,11 @@ from einops import rearrange from torch import nn -from gatr.layers.attention.attention import GeometricAttention -from gatr.layers.attention.config import SelfAttentionConfig -from gatr.layers.attention.qkv import MultiQueryQKVModule, QKVModule -from gatr.layers.dropout import GradeDropout -from gatr.layers.linear import EquiLinear +from .attention import GeometricAttention +from .config import SelfAttentionConfig +from .qkv import MultiQueryQKVModule, QKVModule +from ..dropout import GradeDropout +from ..linear import EquiLinear class SelfAttention(nn.Module): diff --git a/weaver/nn/model/gatr/layers/dropout.py b/weaver/nn/model/gatr/layers/dropout.py index 15180a66..ec32d041 100644 --- a/weaver/nn/model/gatr/layers/dropout.py +++ b/weaver/nn/model/gatr/layers/dropout.py @@ -5,7 +5,7 @@ import torch from torch import nn -from gatr.primitives import grade_dropout +from ..primitives import grade_dropout class GradeDropout(nn.Module): diff --git a/weaver/nn/model/gatr/layers/gatr_block.py b/weaver/nn/model/gatr/layers/gatr_block.py index fe37ca40..f7cf651c 100644 --- a/weaver/nn/model/gatr/layers/gatr_block.py +++ b/weaver/nn/model/gatr/layers/gatr_block.py @@ -4,10 +4,10 @@ import torch from torch import nn -from gatr.layers import SelfAttention, SelfAttentionConfig -from gatr.layers.layer_norm import EquiLayerNorm -from gatr.layers.mlp.config import MLPConfig -from gatr.layers.mlp.mlp import GeoMLP +from .attention import SelfAttention, 
SelfAttentionConfig +from .layer_norm import EquiLayerNorm +from .mlp.config import MLPConfig +from .mlp.mlp import GeoMLP class GATrBlock(nn.Module): diff --git a/weaver/nn/model/gatr/layers/layer_norm.py b/weaver/nn/model/gatr/layers/layer_norm.py index 358b8300..ed1456ec 100644 --- a/weaver/nn/model/gatr/layers/layer_norm.py +++ b/weaver/nn/model/gatr/layers/layer_norm.py @@ -5,7 +5,7 @@ import torch from torch import nn -from gatr.primitives import equi_layer_norm +from ..primitives import equi_layer_norm class EquiLayerNorm(nn.Module): diff --git a/weaver/nn/model/gatr/layers/linear.py b/weaver/nn/model/gatr/layers/linear.py index 6233191b..8033d715 100644 --- a/weaver/nn/model/gatr/layers/linear.py +++ b/weaver/nn/model/gatr/layers/linear.py @@ -6,8 +6,8 @@ import torch from torch import nn -from gatr.interface import embed_scalar -from gatr.primitives.linear import equi_linear, USE_FULLY_CONNECTED_SUBGROUP +from ..interface import embed_scalar +from ..primitives.linear import equi_linear, USE_FULLY_CONNECTED_SUBGROUP # switch to mix pseudoscalar multivector components directly into scalar components # this only makes sense when working with the special orthochronous Lorentz group, diff --git a/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py b/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py index 061f53f2..287bf87a 100644 --- a/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py +++ b/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py @@ -5,9 +5,9 @@ import torch from torch import nn -from gatr.layers.linear import EquiLinear -from gatr.primitives import geometric_product -from gatr.layers.layer_norm import EquiLayerNorm +from ..linear import EquiLinear +from ...primitives import geometric_product +from ..layer_norm import EquiLayerNorm # switch to set bivector components to zero, # after they are generated by the geometric product diff --git a/weaver/nn/model/gatr/layers/mlp/mlp.py b/weaver/nn/model/gatr/layers/mlp/mlp.py index fbfed722..05f1421c 100644 --- a/weaver/nn/model/gatr/layers/mlp/mlp.py +++ b/weaver/nn/model/gatr/layers/mlp/mlp.py @@ -5,11 +5,11 @@ import torch from torch import nn -from gatr.layers.dropout import GradeDropout -from gatr.layers.linear import EquiLinear -from gatr.layers.mlp.config import MLPConfig -from gatr.layers.mlp.geometric_bilinears import GeometricBilinear -from gatr.layers.mlp.nonlinearities import ScalarGatedNonlinearity +from ..dropout import GradeDropout +from ..linear import EquiLinear +from .config import MLPConfig +from .geometric_bilinears import GeometricBilinear +from .nonlinearities import ScalarGatedNonlinearity USE_GEOMETRIC_PRODUCT = True diff --git a/weaver/nn/model/gatr/layers/mlp/nonlinearities.py b/weaver/nn/model/gatr/layers/mlp/nonlinearities.py index 94c871e4..789015a3 100644 --- a/weaver/nn/model/gatr/layers/mlp/nonlinearities.py +++ b/weaver/nn/model/gatr/layers/mlp/nonlinearities.py @@ -3,7 +3,7 @@ import torch from torch import nn -from gatr.primitives.nonlinearities import gated_gelu, gated_relu, gated_sigmoid +from ...primitives.nonlinearities import gated_gelu, gated_relu, gated_sigmoid class ScalarGatedNonlinearity(nn.Module): diff --git a/weaver/nn/model/gatr/nets/gatr.py b/weaver/nn/model/gatr/nets/gatr.py index 3b3d059c..25f1ee87 100644 --- a/weaver/nn/model/gatr/nets/gatr.py +++ b/weaver/nn/model/gatr/nets/gatr.py @@ -7,10 +7,10 @@ from torch import nn from torch.utils.checkpoint import checkpoint -from gatr.layers.attention.config import SelfAttentionConfig -from gatr.layers.gatr_block import 
GATrBlock -from gatr.layers.linear import EquiLinear -from gatr.layers.mlp.config import MLPConfig +from ..layers.attention.config import SelfAttentionConfig +from ..layers.gatr_block import GATrBlock +from ..layers.linear import EquiLinear +from ..layers.mlp.config import MLPConfig class GATr(nn.Module): diff --git a/weaver/nn/model/gatr/primitives/attention.py b/weaver/nn/model/gatr/primitives/attention.py index c69e60d7..926f5f15 100644 --- a/weaver/nn/model/gatr/primitives/attention.py +++ b/weaver/nn/model/gatr/primitives/attention.py @@ -5,7 +5,7 @@ from torch import Tensor from torch.nn.functional import scaled_dot_product_attention as torch_sdpa -from gatr.primitives.invariants import _load_inner_product_factors +from .invariants import _load_inner_product_factors def sdp_attention( q_mv: Tensor, diff --git a/weaver/nn/model/gatr/primitives/bilinear.py b/weaver/nn/model/gatr/primitives/bilinear.py index 0fcaa19f..9115f2ff 100644 --- a/weaver/nn/model/gatr/primitives/bilinear.py +++ b/weaver/nn/model/gatr/primitives/bilinear.py @@ -3,7 +3,7 @@ import torch import clifford -from gatr.utils.einsum import cached_einsum +from ..utils.einsum import cached_einsum @lru_cache() diff --git a/weaver/nn/model/gatr/primitives/dropout.py b/weaver/nn/model/gatr/primitives/dropout.py index 45d010ac..3af07a0d 100644 --- a/weaver/nn/model/gatr/primitives/dropout.py +++ b/weaver/nn/model/gatr/primitives/dropout.py @@ -1,6 +1,6 @@ import torch -from gatr.primitives.linear import grade_project +from .linear import grade_project def grade_dropout(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor: diff --git a/weaver/nn/model/gatr/primitives/invariants.py b/weaver/nn/model/gatr/primitives/invariants.py index fed9805d..2910a37f 100644 --- a/weaver/nn/model/gatr/primitives/invariants.py +++ b/weaver/nn/model/gatr/primitives/invariants.py @@ -3,8 +3,8 @@ import torch import math -from gatr.primitives.linear import grade_project -from gatr.utils.einsum import cached_einsum +from .linear import grade_project +from ..utils.einsum import cached_einsum @lru_cache() diff --git a/weaver/nn/model/gatr/primitives/linear.py b/weaver/nn/model/gatr/primitives/linear.py index 1e3b279d..140bfb4d 100644 --- a/weaver/nn/model/gatr/primitives/linear.py +++ b/weaver/nn/model/gatr/primitives/linear.py @@ -4,7 +4,7 @@ import clifford import numpy as np -from gatr.utils.einsum import cached_einsum, custom_einsum +from ..utils.einsum import cached_einsum, custom_einsum # switch to decide whether to use the full Lorentz group ('False') # or the special orthochronous Lorentz group ('True') diff --git a/weaver/nn/model/gatr/primitives/normalization.py b/weaver/nn/model/gatr/primitives/normalization.py index 7ac64e55..092217b0 100644 --- a/weaver/nn/model/gatr/primitives/normalization.py +++ b/weaver/nn/model/gatr/primitives/normalization.py @@ -1,6 +1,6 @@ import torch -from gatr.primitives.invariants import abs_squared_norm +from .invariants import abs_squared_norm def equi_layer_norm( From 2d020d11b14028ac20a297c7bad929907680c32c Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Fri, 14 Feb 2025 00:21:54 +0100 Subject: [PATCH 04/29] Fix things: Load geometric_product and linear_basis instead of constructing them using the clifford package, remove refs to removed code --- weaver/nn/model/gatr/__init__.py | 4 +--- weaver/nn/model/gatr/primitives/attention.py | 1 + weaver/nn/model/gatr/primitives/bilinear.py | 11 +++------- .../gatr/primitives/geometric_product.pt | Bin 0 -> 17614 bytes 
weaver/nn/model/gatr/primitives/linear.py | 20 +++--------------- .../nn/model/gatr/primitives/linear_basis.pt | Bin 0 -> 11445 bytes 6 files changed, 8 insertions(+), 28 deletions(-) create mode 100644 weaver/nn/model/gatr/primitives/geometric_product.pt create mode 100644 weaver/nn/model/gatr/primitives/linear_basis.pt diff --git a/weaver/nn/model/gatr/__init__.py b/weaver/nn/model/gatr/__init__.py index bf2c70d9..8e9578a9 100644 --- a/weaver/nn/model/gatr/__init__.py +++ b/weaver/nn/model/gatr/__init__.py @@ -1,7 +1,5 @@ -from .layers.attention.config import SelfAttentionConfig, CrossAttentionConfig +from .layers.attention.config import SelfAttentionConfig from .layers.mlp.config import MLPConfig -from .nets.axial_gatr import AxialGATr from .nets.gatr import GATr -from .nets.conditional_gatr import ConditionalGATr __version__ = "1.0.0" diff --git a/weaver/nn/model/gatr/primitives/attention.py b/weaver/nn/model/gatr/primitives/attention.py index 926f5f15..d521c28c 100644 --- a/weaver/nn/model/gatr/primitives/attention.py +++ b/weaver/nn/model/gatr/primitives/attention.py @@ -7,6 +7,7 @@ from .invariants import _load_inner_product_factors + def sdp_attention( q_mv: Tensor, k_mv: Tensor, diff --git a/weaver/nn/model/gatr/primitives/bilinear.py b/weaver/nn/model/gatr/primitives/bilinear.py index 9115f2ff..df7e6825 100644 --- a/weaver/nn/model/gatr/primitives/bilinear.py +++ b/weaver/nn/model/gatr/primitives/bilinear.py @@ -1,7 +1,7 @@ from functools import lru_cache +from pathlib import Path import torch -import clifford from ..utils.einsum import cached_einsum @@ -31,13 +31,8 @@ def _load_geometric_product_tensor( if device not in [torch.device("cpu"), "cpu"] and dtype != torch.float32: gmt = _load_geometric_product_tensor() else: - layout, _ = clifford.Cl(1, 3) - gmt = torch.tensor(layout.gmt, dtype=torch.float32) - gmt = torch.transpose(gmt, 1, 0) - - # Convert to dense tensor - # The reason we do that is that einsum is not defined for sparse tensors - gmt = gmt.to_dense() + filename = Path(__file__).parent.resolve() / "geometric_product.pt" + gmt = torch.load(filename).to(torch.float32).to_dense() return gmt.to(device=device, dtype=dtype) diff --git a/weaver/nn/model/gatr/primitives/geometric_product.pt b/weaver/nn/model/gatr/primitives/geometric_product.pt new file mode 100644 index 0000000000000000000000000000000000000000..e74a618f2980a2e90d3f9675b1b661318aa843bd GIT binary patch literal 17614 zcmeI4%Wfk@6o$*WIL2&7LWs!>LM&Jq2{N6`ICd5YutwfkBTTp`Vq2_x% zSY?AoJOV;Oya}%WZ-B%ev1Cb=oeI^*b-HfHmLp9^p1Pg-{{Nh^lS~%X?``ieR;#h4 z=N_xE^LYAgoM)5K;3AujFGu-<@h~4ATzr#sJ3H(v{#)G2JaTv*rpj3!egRz`)r%Hql?RQ z=Lx%$r@No_SnrPbdXZM2unO;0UZi_Jjz-fYiAVWl`fOJJBpb(BJbpYG<@L1o?rHi; zp4Ly(R}cEVz1sNC=7(!g zq5RB+utN405>f^f%ylz1=6esp=R`S*LiaZ2cpE(zie54Z$qjx!;J8HAxa7iei8Tsq z0x0+t2p_hG<1p)Z>;1*1#{^5M8_FSbU(z}V`1h|W6B z(@TTs=e=4dwU2oUsin%d?YroiVzHcTz1YUN=(xn%>g(c6I;PsGOcP6urRZaNE}ngS z`0z2WCViiMK6^|))O89zWcbun^am1X7Wm?C$b2h=>2+H_y%~V&@I^?P2kUFRwQ(x$;Q_JKBTQ}8mT{%d9A$rDK z2;avcK4q?=rQ+DID7DlaU?1w3&IRtb8Lxt=>SBDDo4UuhzfLXH<2rI#g@&4g?3;8< z>yl&BV&`SER~JL-jo|MDSbiZsF{Z8_6mur8A%0A37yltFQw||)mmFklh#rRSrmsiczE$`O zVd^}>xe&gorXg(TQ;vO@CZCc`)%tjZ+Aa*2*d-n0=RAzljYD*GJTVt?Rv~_DbXUzs zYvTDkBF3B7Ie=+XD@@y^hAy$FUcmbsW7^Dzcn{heknc-fo0DC{+{#+rg3mlWfpvYa zfx6_l_(!1Fkp@f#Gnou>2E1M|xW9@d3}KYls?Vv8FP zFWfX*N yyWaZN_WL)G-Rqugb8!&}fBxlc@$h}(`Vf{p7XLCPmixV3anT0b7kWFAY+tyVgge(*UYsc z${qrV+(|=}d+Rv9=RpU%em9gK2OSu|&Lh{sGRha5XcZ&5iLR9zW-kB!$EVUqf=Nds zvED6ki>cDcsk22 zqm~(I6mEh6!2kygh%lo=IU6!tmoPf3tC(AS&R_@ZvKf@l#E&5iT-$R`Z7Suo 
z+>TK%{(nMcwN}+sO;c+XP1W_PQqz<&HN9H9rznc9tCfANqR}kheZGMrnn;!sO?IMu zoz{qM(ie=KtAmY0ObbA_824Y~iGi3%OXC>Rtz2g?o)L8{gE110#sQ`~vB3a*E5-n# z{>KrfS=wO`ns;Id=l8E+X0vxA$J^=M2}CBk8(&{gy)Uuf1eO}{qbgx=6gkt^4KqXa Mv>-!%bpB}WAA#;8hyVZp literal 0 HcmV?d00001 From e821495bbc78edeec05f77db38a96fd6be0154a4 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Fri, 14 Feb 2025 00:22:18 +0100 Subject: [PATCH 05/29] Add extra requirements for GATr --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index c1381826..df8f5768 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,5 @@ lz4>=3.1.0 xxhash>=1.4.4 tables>=3.6.1 tensorboard>=2.2.0 +einops==0.6.1 # for LGATr (newer einops version has issues with __name__, can be resolved with careful version picking) +opt_einsum>=3.3.0 # for LGATr From 8bd0e9702bf8ac855e0da0356a730e0688218365 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Fri, 14 Feb 2025 00:23:19 +0100 Subject: [PATCH 06/29] First working version of LGATr wrapper --- weaver/nn/model/LGATr.py | 162 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 weaver/nn/model/LGATr.py diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py new file mode 100644 index 00000000..565149e4 --- /dev/null +++ b/weaver/nn/model/LGATr.py @@ -0,0 +1,162 @@ +import torch +from torch import nn + +from .gatr import GATr, SelfAttentionConfig, MLPConfig +from .gatr.interface import ( + embed_vector, + extract_scalar, + get_num_spurions, + embed_spurions, +) + + +class LGATrWrapper(nn.Module): + """ + Wrapper that handles interface to the GATr code + - create dataclasses for attention and mlp + - append spurions (symmetry-breaking) + - interface to geometric algebra + - mean aggregation to extract tagging score + """ + + def __init__( + self, + in_s_channels, + hidden_mv_channels, + hidden_s_channels, + num_classes, + num_blocks, + num_heads, + spurion_token=True, + beam_reference="xyplane", + add_time_reference=True, + two_beams=True, + activation="gelu", + multi_query=False, + increase_hidden_channels=2, + head_scale=False, + double_layernorm=False, + dropout_prob=None, + ): + super().__init__() + + # spurion business + in_mv_channels = 1 + num_spurions = get_num_spurions( + beam_reference, add_time_reference, two_beams=two_beams + ) + self.spurion_token = spurion_token + self.spurion_kwargs = { + "beam_reference": beam_reference, + "add_time_reference": add_time_reference, + "two_beams": two_beams, + } + if not self.spurion_token: + in_mv_channels += num_spurions + + attention = SelfAttentionConfig( + multi_query=multi_query, + num_heads=num_heads, + increase_hidden_channels=increase_hidden_channels, + dropout_prob=dropout_prob, + head_scale=head_scale, + ) + mlp = MLPConfig( + activation=activation, + dropout_prob=dropout_prob, + ) + + self.net = GATr( + in_mv_channels=in_mv_channels, + out_mv_channels=num_classes, + hidden_mv_channels=hidden_mv_channels, + in_s_channels=in_s_channels, + out_s_channels=1, + hidden_s_channels=hidden_s_channels, + attention=attention, + mlp=mlp, + num_blocks=num_blocks, + double_layernorm=double_layernorm, + dropout_prob=dropout_prob, + ) + + def forward(self, x, v, mask): + # reshape input + x = x.transpose(1, 2) # (batch_size, seq_len, num_fts) + v = v.transpose(1, 2) # (batch_size, seq_len, 4) + mask = mask.transpose(1, 2) # (batchsize, seq_len, 1) + + # geometric algebra embedding + fourmomenta = v[:, :, None, [3, 0, 1, 2]] # (px, py, pz, E) -> (E, px, py, pz) + mv = 
embed_vector(fourmomenta) + s = x + + # spurion business + spurions = embed_spurions(**self.spurion_kwargs) + if self.spurion_token: + # add spurions as extra tokens + # (have to also extend mask and zero-pad scalars) + mask_ones = torch.ones( + mask.shape[0], + spurions.shape[0], + mask.shape[2], + device=mask.device, + dtype=mask.dtype, + ) + mask = torch.cat([mask, mask_ones], dim=1) + + s_zeros = torch.zeros( + s.shape[0], + spurions.shape[0], + s.shape[2], + device=s.device, + dtype=s.dtype, + ) + s = torch.cat([s, s_zeros], dim=1) + + spurions = spurions[None, :, None, :].repeat(mv.shape[0], 1, 1, 1) + mv = torch.cat([mv, spurions], dim=1) + else: + # add spurions as extra mv channels + spurions = spurions[None, None, :, :].repeat(mv.shape[0], mv.shape[1], 1, 1) + mv = torch.cat([mv, spurions], dim=2) + + # reshape mask to broadcast correctly + mask = mask.bool() + mask = mask[:, None, None, :, 0] # (batch_size, 1, 1, seq_len) + + # call network + out_mv, _ = self.net(mv, s, mask) + output = extract_scalar(out_mv)[..., 0] + + # mean aggregation + output = output.mean(dim=1) + return output + + +class LGATrTagger(nn.Module): + """Mimic weaver features""" + + def __init__( + self, + use_amp=False, + for_inference=False, + for_segmentation=False, + **kwargs, + ): + super().__init__() # not support this kind of **kwargs for now + + self.use_amp = use_amp + self.for_inference = for_inference + self.for_segmentation = for_segmentation + self.net = LGATrWrapper(**kwargs) + + def forward(self, x, v=None, mask=None): + with torch.autocast("cuda", enabled=self.use_amp): + # TODO: implement for_segmentation + + output = self.net(x, v, mask) + + if self.for_inference: + output = torch.softmax(output, dim=-1) + return output From 04a5794fc803465aad738821c9e5e700c498f622 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Fri, 14 Feb 2025 13:02:09 +0100 Subject: [PATCH 07/29] Add for_segmentation option (not sure about this) --- weaver/nn/model/LGATr.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 565149e4..5b0689f0 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -153,10 +153,14 @@ def __init__( def forward(self, x, v=None, mask=None): with torch.autocast("cuda", enabled=self.use_amp): - # TODO: implement for_segmentation - output = self.net(x, v, mask) + if self.for_segmentation: + output = output.transpose(1, 2).contiguous() + if self.for_inference: + output = torch.softmax(output, dim=1) + return output + if self.for_inference: output = torch.softmax(output, dim=-1) return output From 234a9a021dbd4967f7800691ee8de5f67688ddf3 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Fri, 14 Feb 2025 13:09:06 +0100 Subject: [PATCH 08/29] Use SequenceTrimmer from ParticleTransformer also in LGATr --- weaver/nn/model/LGATr.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 5b0689f0..23af48c3 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -8,6 +8,7 @@ get_num_spurions, embed_spurions, ) +from .ParticleTransformer import SequenceTrimmer class LGATrWrapper(nn.Module): @@ -140,6 +141,7 @@ class LGATrTagger(nn.Module): def __init__( self, use_amp=False, + trim=True, for_inference=False, for_segmentation=False, **kwargs, @@ -149,9 +151,13 @@ def __init__( self.use_amp = use_amp self.for_inference = for_inference self.for_segmentation = for_segmentation + self.trimmer = SequenceTrimmer(enabled=trim and not 
for_inference) self.net = LGATrWrapper(**kwargs) def forward(self, x, v=None, mask=None): + with torch.no_grad(): + x, v, mask, _ = self.trimmer(x, v, mask) + with torch.autocast("cuda", enabled=self.use_amp): output = self.net(x, v, mask) From 6b3b5691b015d22c488ce2dff8916936a8695aa6 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sun, 16 Feb 2025 19:37:24 +0100 Subject: [PATCH 09/29] Move spurions to device (oopsie) --- weaver/nn/model/LGATr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 23af48c3..434c6b5d 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -93,7 +93,9 @@ def forward(self, x, v, mask): s = x # spurion business - spurions = embed_spurions(**self.spurion_kwargs) + spurions = embed_spurions(**self.spurion_kwargs).to( + device=s.device, dtype=s.dtype + ) if self.spurion_token: # add spurions as extra tokens # (have to also extend mask and zero-pad scalars) From 361f3f38b11293b67fa8e50755ab6ce0e08e4704 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sun, 16 Feb 2025 19:38:01 +0100 Subject: [PATCH 10/29] Enforce autocast to float32 for equi_layer_norm (required for mixed-precision training) --- weaver/nn/model/gatr/primitives/invariants.py | 2 + weaver/nn/model/gatr/utils/misc.py | 115 ++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 weaver/nn/model/gatr/utils/misc.py diff --git a/weaver/nn/model/gatr/primitives/invariants.py b/weaver/nn/model/gatr/primitives/invariants.py index 2910a37f..4839bdfe 100644 --- a/weaver/nn/model/gatr/primitives/invariants.py +++ b/weaver/nn/model/gatr/primitives/invariants.py @@ -5,6 +5,7 @@ from .linear import grade_project from ..utils.einsum import cached_einsum +from ..utils.misc import minimum_autocast_precision @lru_cache() @@ -150,6 +151,7 @@ def pin_invariants(x: torch.Tensor, epsilon: float = 0.01) -> torch.Tensor: return torch.cat((x[..., [0]], norms[..., 1:]), dim=-1) # (..., 5) +@minimum_autocast_precision(torch.float32) def abs_squared_norm(x: torch.Tensor) -> torch.Tensor: """Computes a modified version of the squared norm that is positive semidefinite and can therefore be used in layer normalization. diff --git a/weaver/nn/model/gatr/utils/misc.py b/weaver/nn/model/gatr/utils/misc.py new file mode 100644 index 00000000..98be6e5b --- /dev/null +++ b/weaver/nn/model/gatr/utils/misc.py @@ -0,0 +1,115 @@ +from functools import wraps +from itertools import chain +from typing import Any, Callable, List, Literal, Optional, Union + +import torch +from torch import Tensor + + +def minimum_autocast_precision( + min_dtype: torch.dtype = torch.float32, + output: Optional[Union[Literal["low", "high"], torch.dtype]] = None, + which_args: Optional[List[int]] = None, + which_kwargs: Optional[List[str]] = None, +): + """Decorator that ensures input tensors are autocast to a minimum precision. + Only has an effect in autocast-enabled regions. Otherwise, does not change the function. + Only floating-point inputs are modified. Non-tensors, integer tensors, and boolean tensors are + untouched. + Note: AMP is turned on and off separately for CPU and CUDA. This decorator may fail in + the case where both devices are used, with only one of them on AMP. + Parameters + ---------- + min_dtype : dtype + Minimum dtype. Default: float32. + output: None or "low" or "high" or dtype + Specifies which dtypes the outputs should be cast to. Only floating-point Tensor outputs + are affected. If None, the outputs are not modified. 
If "low", the lowest-precision input + dtype is used. If "high", `min_dtype` or the highest-precision input dtype is used + (whichever is higher). + which_args : None or list of int + If not None, specifies which positional arguments are to be modified. If None (the default), + all positional arguments are modified (if they are Tensors and of a floating-point dtype). + which_kwargs : bool + If not None, specifies which keyword arguments are to be modified. If None (the default), + all keyword arguments are modified (if they are Tensors and of a floating-point dtype). + Returns + ------- + decorator : Callable + Decorator. + """ + + def decorator(func: Callable): + """Decorator that casts input tensors to minimum precision.""" + + def _cast_in(var: Any): + """Casts a single input to at least 32-bit precision.""" + if not isinstance(var, Tensor): + # We don't want to modify non-Tensors + return var + if not var.dtype.is_floating_point: + # Integer / boolean tensors are also not touched + return var + dtype = max(var.dtype, min_dtype, key=lambda dt: torch.finfo(dt).bits) + return var.to(dtype) + + def _cast_out(var: Any, dtype: torch.dtype): + """Casts a single output to desired precision.""" + if not isinstance(var, Tensor): + # We don't want to modify non-Tensors + return var + if not var.dtype.is_floating_point: + # Integer / boolean tensors are also not touched + return var + return var.to(dtype) + + @wraps(func) + def decorated_func(*args: Any, **kwargs: Any): + """Decorated func.""" + # Only change dtypes in autocast-enabled regions + if not (torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled()): + # NB: torch.is_autocast_enabled() only checks for GPU autocast + # See https://github.com/pytorch/pytorch/issues/110966 + return func(*args, **kwargs) + # Cast inputs to at least 32 bit + mod_args = [ + _cast_in(arg) + for i, arg in enumerate(args) + if which_args is None or i in which_args + ] + mod_kwargs = { + key: _cast_in(val) + for key, val in kwargs.items() + if which_kwargs is None or key in which_kwargs + } + # Call function w/o autocast enabled + with torch.autocast(device_type="cuda", enabled=False), torch.autocast( + device_type="cpu", enabled=False + ): + outputs = func(*mod_args, **mod_kwargs) + # Cast outputs to correct dtype + if output is None: + return outputs + if output in ["low", "high"]: + in_dtypes = [ + arg.dtype + for arg in chain(args, kwargs.values()) + if isinstance(arg, Tensor) and arg.dtype.is_floating_point + ] + assert len(in_dtypes) + if output == "low": + out_dtype = min( + [min_dtype] + in_dtypes, key=lambda dt: torch.finfo(dt).bits + ) + else: + out_dtype = max(in_dtypes, key=lambda dt: torch.finfo(dt).bits) + else: + out_dtype = output + if isinstance(outputs, tuple): + return (_cast_out(val, out_dtype) for val in outputs) + else: + return _cast_out(outputs, out_dtype) + + return decorated_func + + return decorator From bbb5dcfe02da517c2a5db578277f4ac86b2fb8a5 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sun, 16 Feb 2025 20:40:36 +0100 Subject: [PATCH 11/29] Add support for full Lorentz group (default is fully connected subgroup) --- weaver/nn/model/gatr/layers/linear.py | 9 +++++++-- weaver/nn/model/gatr/primitives/linear.py | 7 ++++++- .../model/gatr/primitives/linear_basis_full.pt | Bin 0 -> 6335 bytes ...{linear_basis.pt => linear_basis_subgroup.pt} | Bin 4 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 weaver/nn/model/gatr/primitives/linear_basis_full.pt rename weaver/nn/model/gatr/primitives/{linear_basis.pt => 
linear_basis_subgroup.pt} (100%) diff --git a/weaver/nn/model/gatr/layers/linear.py b/weaver/nn/model/gatr/layers/linear.py index 8033d715..226c1134 100644 --- a/weaver/nn/model/gatr/layers/linear.py +++ b/weaver/nn/model/gatr/layers/linear.py @@ -7,13 +7,12 @@ from torch import nn from ..interface import embed_scalar -from ..primitives.linear import equi_linear, USE_FULLY_CONNECTED_SUBGROUP +from ..primitives.linear import equi_linear # switch to mix pseudoscalar multivector components directly into scalar components # this only makes sense when working with the special orthochronous Lorentz group, # Note: This is an efficiency boost, the same action can be achieved with an extra linear layer MIX_MVPSEUDOSCALAR_INTO_SCALAR = True -NUM_PIN_LINEAR_BASIS_ELEMENTS = 10 if USE_FULLY_CONNECTED_SUBGROUP else 5 class EquiLinear(nn.Module): @@ -87,6 +86,9 @@ def __init__( self._in_mv_channels = in_mv_channels # MV -> MV + from ..primitives.linear import USE_FULLY_CONNECTED_SUBGROUP + + NUM_PIN_LINEAR_BASIS_ELEMENTS = 10 if USE_FULLY_CONNECTED_SUBGROUP else 5 self.weight = nn.Parameter( torch.empty( (out_mv_channels, in_mv_channels, NUM_PIN_LINEAR_BASIS_ELEMENTS) @@ -289,6 +291,9 @@ def _compute_init_factors( ) # Individual factors for each multivector component (could be tuned for performance) + from ..primitives.linear import USE_FULLY_CONNECTED_SUBGROUP + + NUM_PIN_LINEAR_BASIS_ELEMENTS = 10 if USE_FULLY_CONNECTED_SUBGROUP else 5 mv_component_factors = torch.ones(NUM_PIN_LINEAR_BASIS_ELEMENTS) return mv_component_factors, mv_factor, mvs_bias_shift, s_factor diff --git a/weaver/nn/model/gatr/primitives/linear.py b/weaver/nn/model/gatr/primitives/linear.py index d11b4352..234be38b 100644 --- a/weaver/nn/model/gatr/primitives/linear.py +++ b/weaver/nn/model/gatr/primitives/linear.py @@ -37,7 +37,12 @@ def _compute_pin_equi_linear_basis( if device not in [torch.device("cpu"), "cpu"] and dtype != torch.float32: basis = _compute_pin_equi_linear_basis(normalize=normalize) else: - filename = Path(__file__).parent.resolve() / "linear_basis.pt" + file = ( + "linear_basis_subgroup.pt" + if USE_FULLY_CONNECTED_SUBGROUP + else "linear_basis_full.pt" + ) + filename = Path(__file__).parent.resolve() / file basis = torch.load(filename).to(torch.float32).to_dense() return basis.to(device=device, dtype=dtype) diff --git a/weaver/nn/model/gatr/primitives/linear_basis_full.pt b/weaver/nn/model/gatr/primitives/linear_basis_full.pt new file mode 100644 index 0000000000000000000000000000000000000000..8dd7f63fde93d7c0165591f96a6c9e514805ba38 GIT binary patch literal 6335 zcmeHM&2G~`5MC#B65LA@DbsafzrJ_m}vJk5Z69kGt$a3sev9RKa@Q@6#KMYJIxic5%I-;O|er0Va31u-|QTTkLoE zk&%QWJ9cx{aNhT9h$F^g!n1sO{OeW?GH@@>vWa?P7S8If?%17z6th%UjtET-A)z$M zfMh^2up$H3v4@liV?`Qem}EdQPyz1zU0o;o# z*g+4(3fh Date: Mon, 17 Feb 2025 08:58:21 +0100 Subject: [PATCH 12/29] Support gradient checkpointing --- weaver/nn/model/LGATr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 434c6b5d..d34596df 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -38,6 +38,7 @@ def __init__( head_scale=False, double_layernorm=False, dropout_prob=None, + checkpoint_blocks=False, ): super().__init__() @@ -79,6 +80,7 @@ def __init__( num_blocks=num_blocks, double_layernorm=double_layernorm, dropout_prob=dropout_prob, + checkpoint_blocks=checkpoint_blocks, ) def forward(self, x, v, mask): From dcda0483140a21d49a40ee2ec6c7ebeaa9a58915 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Mon, 17 Feb 
2025 18:51:50 +0100 Subject: [PATCH 13/29] Remove unused scalar output layer --- weaver/nn/model/LGATr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index d34596df..61d7c274 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -73,7 +73,7 @@ def __init__( out_mv_channels=num_classes, hidden_mv_channels=hidden_mv_channels, in_s_channels=in_s_channels, - out_s_channels=1, + out_s_channels=None, hidden_s_channels=hidden_s_channels, attention=attention, mlp=mlp, From 04e5437daa09ec4c047a88cbd914de75fe89c8e8 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Tue, 18 Feb 2025 10:49:08 +0100 Subject: [PATCH 14/29] Prepend spurions (instead of appending them) -> easier to keep track --- weaver/nn/model/LGATr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 61d7c274..0d02b16a 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -99,7 +99,7 @@ def forward(self, x, v, mask): device=s.device, dtype=s.dtype ) if self.spurion_token: - # add spurions as extra tokens + # prepend spurions as extra tokens # (have to also extend mask and zero-pad scalars) mask_ones = torch.ones( mask.shape[0], @@ -108,7 +108,7 @@ def forward(self, x, v, mask): device=mask.device, dtype=mask.dtype, ) - mask = torch.cat([mask, mask_ones], dim=1) + mask = torch.cat([mask_ones, mask], dim=1) s_zeros = torch.zeros( s.shape[0], @@ -117,10 +117,10 @@ def forward(self, x, v, mask): device=s.device, dtype=s.dtype, ) - s = torch.cat([s, s_zeros], dim=1) + s = torch.cat([s_zeros, s], dim=1) spurions = spurions[None, :, None, :].repeat(mv.shape[0], 1, 1, 1) - mv = torch.cat([mv, spurions], dim=1) + mv = torch.cat([spurions, mv], dim=1) else: # add spurions as extra mv channels spurions = spurions[None, None, :, :].repeat(mv.shape[0], mv.shape[1], 1, 1) From 4832057a748f2d117f4bd528c9e763765c783286 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Mon, 3 Mar 2025 12:59:21 +0100 Subject: [PATCH 15/29] Fix mean-aggregation (set output of padded particles to zero before evaluating the mean) --- weaver/nn/model/LGATr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 0d02b16a..2754ebe4 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -135,6 +135,7 @@ def forward(self, x, v, mask): output = extract_scalar(out_mv)[..., 0] # mean aggregation + output[~mask[:, 0, 0]] = 0.0 output = output.mean(dim=1) return output From d0c2c3a12dd70604bce7a403eeeab83b32f8971d Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Mon, 3 Mar 2025 13:20:38 +0100 Subject: [PATCH 16/29] Add global_token option --- weaver/nn/model/LGATr.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 2754ebe4..f7912b0a 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -28,6 +28,7 @@ def __init__( num_classes, num_blocks, num_heads, + global_token=True, spurion_token=True, beam_reference="xyplane", add_time_reference=True, @@ -44,17 +45,19 @@ def __init__( # spurion business in_mv_channels = 1 + self.global_token = global_token + self.spurion_token = spurion_token + num_spurions = get_num_spurions( beam_reference, add_time_reference, two_beams=two_beams ) - self.spurion_token = spurion_token + if not self.spurion_token: + in_mv_channels += num_spurions self.spurion_kwargs = { 
"beam_reference": beam_reference, "add_time_reference": add_time_reference, "two_beams": two_beams, } - if not self.spurion_token: - in_mv_channels += num_spurions attention = SelfAttentionConfig( multi_query=multi_query, @@ -126,6 +129,16 @@ def forward(self, x, v, mask): spurions = spurions[None, None, :, :].repeat(mv.shape[0], mv.shape[1], 1, 1) mv = torch.cat([mv, spurions], dim=2) + # global token business + if self.global_token: + # prepend global token as first particle in the list + global_token = torch.zeros_like(mv[:, [0], :, :]) + mv = torch.cat((global_token, mv), dim=1) + mask_ones = torch.ones_like(mask[:, [0]]) + mask = torch.cat((mask_ones, mask), dim=1) + s_zeros = torch.zeros_like(s[:, [0]]) + s = torch.cat((s_zeros, s), dim=1) + # reshape mask to broadcast correctly mask = mask.bool() mask = mask[:, None, None, :, 0] # (batch_size, 1, 1, seq_len) @@ -134,9 +147,13 @@ def forward(self, x, v, mask): out_mv, _ = self.net(mv, s, mask) output = extract_scalar(out_mv)[..., 0] - # mean aggregation - output[~mask[:, 0, 0]] = 0.0 - output = output.mean(dim=1) + # aggregation + if self.global_token: + output = output[:, 0] + else: + # mean aggregation + output[~mask[:, 0, 0]] = 0.0 + output = output.mean(dim=1) return output From e6018a903565889d4ecd8aa105eb4ad4c1ffd1d9 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Mon, 3 Mar 2025 13:30:29 +0100 Subject: [PATCH 17/29] Clean up documentation etc in LGATr.py --- weaver/nn/model/LGATr.py | 36 ++++++++++-------------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index f7912b0a..8da15aaf 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -17,7 +17,7 @@ class LGATrWrapper(nn.Module): - create dataclasses for attention and mlp - append spurions (symmetry-breaking) - interface to geometric algebra - - mean aggregation to extract tagging score + - extract tagging score with global token or mean-aggregation """ def __init__( @@ -92,44 +92,28 @@ def forward(self, x, v, mask): v = v.transpose(1, 2) # (batch_size, seq_len, 4) mask = mask.transpose(1, 2) # (batchsize, seq_len, 1) - # geometric algebra embedding + # embed data into geometric algebra fourmomenta = v[:, :, None, [3, 0, 1, 2]] # (px, py, pz, E) -> (E, px, py, pz) - mv = embed_vector(fourmomenta) - s = x + mv = embed_vector(fourmomenta) # (batch_size, seq_len, 1, 16) + s = x # (batch_size, seq_len, num_fts) - # spurion business + # symmetry breaking with spurions spurions = embed_spurions(**self.spurion_kwargs).to( device=s.device, dtype=s.dtype ) if self.spurion_token: - # prepend spurions as extra tokens - # (have to also extend mask and zero-pad scalars) - mask_ones = torch.ones( - mask.shape[0], - spurions.shape[0], - mask.shape[2], - device=mask.device, - dtype=mask.dtype, - ) + # prepend spurions as extra particles in the list + mask_ones = torch.ones_like(mask[:, [0]]).repeat(1, spurions.shape[0], 1) mask = torch.cat([mask_ones, mask], dim=1) - - s_zeros = torch.zeros( - s.shape[0], - spurions.shape[0], - s.shape[2], - device=s.device, - dtype=s.dtype, - ) + s_zeros = torch.zeros_like(s[:, [0]]).repeat(1, spurions.shape[0], 1) s = torch.cat([s_zeros, s], dim=1) - spurions = spurions[None, :, None, :].repeat(mv.shape[0], 1, 1, 1) mv = torch.cat([spurions, mv], dim=1) else: - # add spurions as extra mv channels + # append spurions as extra mv channels spurions = spurions[None, None, :, :].repeat(mv.shape[0], mv.shape[1], 1, 1) mv = torch.cat([mv, spurions], dim=2) - # 
global token business if self.global_token: # prepend global token as first particle in the list global_token = torch.zeros_like(mv[:, [0], :, :]) @@ -145,7 +129,7 @@ def forward(self, x, v, mask): # call network out_mv, _ = self.net(mv, s, mask) - output = extract_scalar(out_mv)[..., 0] + output = extract_scalar(out_mv)[..., 0] # (batch_size, seq_len, num_classes) # aggregation if self.global_token: From 469542a8a9d863c16dba4c0d13b1431e322bd4bf Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Wed, 5 Mar 2025 11:53:29 +0100 Subject: [PATCH 18/29] Add stabilizing LayerNorm --- weaver/nn/model/gatr/layers/attention/qkv.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/weaver/nn/model/gatr/layers/attention/qkv.py b/weaver/nn/model/gatr/layers/attention/qkv.py index 17179a33..9dc077fc 100644 --- a/weaver/nn/model/gatr/layers/attention/qkv.py +++ b/weaver/nn/model/gatr/layers/attention/qkv.py @@ -2,8 +2,9 @@ from einops import rearrange from torch import nn -from .config import SelfAttentionConfig -from ..linear import EquiLinear +from gatr.layers.attention.config import SelfAttentionConfig +from gatr.layers.linear import EquiLinear +from gatr.layers.layer_norm import EquiLayerNorm class QKVModule(nn.Module): @@ -25,6 +26,7 @@ def __init__(self, config: SelfAttentionConfig): if config.in_s_channels is None else 3 * config.hidden_s_channels * config.num_heads, ) + self.norm_qkv = EquiLayerNorm() self.config = config def forward( @@ -96,6 +98,10 @@ def forward( else: q_s, k_s, v_s = None, None, None + q_mv, q_s = self.norm_qkv(q_mv, scalars=q_s) + k_mv, k_s = self.norm_qkv(k_mv, scalars=k_s) + v_mv, v_s = self.norm_qkv(v_mv, scalars=v_s) + return q_mv, k_mv, v_mv, q_s, k_s, v_s @@ -132,6 +138,7 @@ def __init__(self, config: SelfAttentionConfig): in_s_channels=config.in_s_channels, out_s_channels=config.hidden_s_channels, ) + self.norm_qkv = EquiLayerNorm() self.config = config def forward( @@ -224,4 +231,8 @@ def forward( else: q_s, k_s, v_s = None, None, None + q_mv, q_s = self.norm_qkv(q_mv, scalars=q_s) + k_mv, k_s = self.norm_qkv(k_mv, scalars=k_s) + v_mv, v_s = self.norm_qkv(v_mv, scalars=v_s) + return q_mv, k_mv, v_mv, q_s, k_s, v_s From 0a60774ee746b63e89d5b7426eef2b25b61a16bd Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sat, 15 Mar 2025 22:11:24 +0100 Subject: [PATCH 19/29] Delete gatr/ folder again (will soon be imported) --- weaver/nn/model/gatr/__init__.py | 5 - weaver/nn/model/gatr/interface/__init__.py | 3 - weaver/nn/model/gatr/interface/scalar.py | 42 -- weaver/nn/model/gatr/interface/spurions.py | 132 ------ weaver/nn/model/gatr/interface/vector.py | 46 -- weaver/nn/model/gatr/layers/__init__.py | 10 - .../model/gatr/layers/attention/__init__.py | 2 - .../model/gatr/layers/attention/attention.py | 76 ---- .../nn/model/gatr/layers/attention/config.py | 92 ---- weaver/nn/model/gatr/layers/attention/qkv.py | 238 ----------- .../gatr/layers/attention/self_attention.py | 147 ------- weaver/nn/model/gatr/layers/dropout.py | 51 --- weaver/nn/model/gatr/layers/gatr_block.py | 143 ------- weaver/nn/model/gatr/layers/layer_norm.py | 70 ---- weaver/nn/model/gatr/layers/linear.py | 393 ------------------ weaver/nn/model/gatr/layers/mlp/__init__.py | 4 - weaver/nn/model/gatr/layers/mlp/config.py | 44 -- .../gatr/layers/mlp/geometric_bilinears.py | 105 ----- weaver/nn/model/gatr/layers/mlp/mlp.py | 106 ----- .../model/gatr/layers/mlp/nonlinearities.py | 65 --- weaver/nn/model/gatr/nets/__init__.py | 1 - weaver/nn/model/gatr/nets/gatr.py | 182 -------- 
weaver/nn/model/gatr/primitives/__init__.py | 17 - weaver/nn/model/gatr/primitives/attention.py | 112 ----- weaver/nn/model/gatr/primitives/bilinear.py | 62 --- weaver/nn/model/gatr/primitives/dropout.py | 36 -- .../gatr/primitives/geometric_product.pt | Bin 17614 -> 0 bytes weaver/nn/model/gatr/primitives/invariants.py | 173 -------- weaver/nn/model/gatr/primitives/linear.py | 187 --------- .../gatr/primitives/linear_basis_full.pt | Bin 6335 -> 0 bytes .../gatr/primitives/linear_basis_subgroup.pt | Bin 11445 -> 0 bytes .../model/gatr/primitives/nonlinearities.py | 79 ---- .../nn/model/gatr/primitives/normalization.py | 47 --- weaver/nn/model/gatr/utils/__init__.py | 0 weaver/nn/model/gatr/utils/einsum.py | 44 -- weaver/nn/model/gatr/utils/misc.py | 115 ----- 36 files changed, 2829 deletions(-) delete mode 100644 weaver/nn/model/gatr/__init__.py delete mode 100644 weaver/nn/model/gatr/interface/__init__.py delete mode 100644 weaver/nn/model/gatr/interface/scalar.py delete mode 100644 weaver/nn/model/gatr/interface/spurions.py delete mode 100644 weaver/nn/model/gatr/interface/vector.py delete mode 100644 weaver/nn/model/gatr/layers/__init__.py delete mode 100644 weaver/nn/model/gatr/layers/attention/__init__.py delete mode 100644 weaver/nn/model/gatr/layers/attention/attention.py delete mode 100644 weaver/nn/model/gatr/layers/attention/config.py delete mode 100644 weaver/nn/model/gatr/layers/attention/qkv.py delete mode 100644 weaver/nn/model/gatr/layers/attention/self_attention.py delete mode 100644 weaver/nn/model/gatr/layers/dropout.py delete mode 100644 weaver/nn/model/gatr/layers/gatr_block.py delete mode 100644 weaver/nn/model/gatr/layers/layer_norm.py delete mode 100644 weaver/nn/model/gatr/layers/linear.py delete mode 100644 weaver/nn/model/gatr/layers/mlp/__init__.py delete mode 100644 weaver/nn/model/gatr/layers/mlp/config.py delete mode 100644 weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py delete mode 100644 weaver/nn/model/gatr/layers/mlp/mlp.py delete mode 100644 weaver/nn/model/gatr/layers/mlp/nonlinearities.py delete mode 100644 weaver/nn/model/gatr/nets/__init__.py delete mode 100644 weaver/nn/model/gatr/nets/gatr.py delete mode 100644 weaver/nn/model/gatr/primitives/__init__.py delete mode 100644 weaver/nn/model/gatr/primitives/attention.py delete mode 100644 weaver/nn/model/gatr/primitives/bilinear.py delete mode 100644 weaver/nn/model/gatr/primitives/dropout.py delete mode 100644 weaver/nn/model/gatr/primitives/geometric_product.pt delete mode 100644 weaver/nn/model/gatr/primitives/invariants.py delete mode 100644 weaver/nn/model/gatr/primitives/linear.py delete mode 100644 weaver/nn/model/gatr/primitives/linear_basis_full.pt delete mode 100644 weaver/nn/model/gatr/primitives/linear_basis_subgroup.pt delete mode 100644 weaver/nn/model/gatr/primitives/nonlinearities.py delete mode 100644 weaver/nn/model/gatr/primitives/normalization.py delete mode 100644 weaver/nn/model/gatr/utils/__init__.py delete mode 100644 weaver/nn/model/gatr/utils/einsum.py delete mode 100644 weaver/nn/model/gatr/utils/misc.py diff --git a/weaver/nn/model/gatr/__init__.py b/weaver/nn/model/gatr/__init__.py deleted file mode 100644 index 8e9578a9..00000000 --- a/weaver/nn/model/gatr/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .layers.attention.config import SelfAttentionConfig -from .layers.mlp.config import MLPConfig -from .nets.gatr import GATr - -__version__ = "1.0.0" diff --git a/weaver/nn/model/gatr/interface/__init__.py b/weaver/nn/model/gatr/interface/__init__.py deleted file mode 
100644 index 995d0750..00000000 --- a/weaver/nn/model/gatr/interface/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .vector import embed_vector, extract_vector -from .scalar import embed_scalar, extract_scalar -from .spurions import embed_spurions, get_num_spurions diff --git a/weaver/nn/model/gatr/interface/scalar.py b/weaver/nn/model/gatr/interface/scalar.py deleted file mode 100644 index b248a513..00000000 --- a/weaver/nn/model/gatr/interface/scalar.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch - - -def embed_scalar(scalars: torch.Tensor) -> torch.Tensor: - """Embeds a scalar tensor into multivectors. - - Parameters - ---------- - scalars: torch.Tensor with shape (..., 1) - Scalar inputs. - - Returns - ------- - multivectors: torch.Tensor with shape (..., 16) - Multivector outputs. `multivectors[..., [0]]` is the same as `scalars`. The other components - are zero. - """ - - non_scalar_shape = list(scalars.shape[:-1]) + [15] - non_scalar_components = torch.zeros( - non_scalar_shape, device=scalars.device, dtype=scalars.dtype - ) - embedding = torch.cat((scalars, non_scalar_components), dim=-1) - - return embedding - - -def extract_scalar(multivectors: torch.Tensor) -> torch.Tensor: - """Extracts scalar components from multivectors. - - Parameters - ---------- - multivectors: torch.Tensor with shape (..., 16) - Multivector inputs. - - Returns - ------- - scalars: torch.Tensor with shape (..., 1) - Scalar component of multivectors. - """ - - return multivectors[..., [0]] diff --git a/weaver/nn/model/gatr/interface/spurions.py b/weaver/nn/model/gatr/interface/spurions.py deleted file mode 100644 index 69fde990..00000000 --- a/weaver/nn/model/gatr/interface/spurions.py +++ /dev/null @@ -1,132 +0,0 @@ -import torch -from .vector import embed_vector - - -def get_num_spurions( - beam_reference, - add_time_reference, - two_beams=True, - add_xzplane=False, - add_yzplane=False, -): - """ - Compute how many reference multivectors/spurions a given configuration will have - - Parameters - ---------- - beam_reference: str - Different options for adding a beam_reference - Options: "lightlike", "spacelike", "timelike", "xyplane" - add_time_reference: bool - Whether to add the time direction as a reference to the network - two_beams: bool - Whether we only want (x, 0, 0, 1) or both (x, 0, 0, +/- 1) for the beam - add_xzplane: bool - Whether to add the x-z-plane as a reference to the network - add_yzplane: bool - Whether to add the y-z-plane as a reference to the network - - Returns - ------- - num_spurions: int - Number of spurions - """ - num_spurions = 0 - if beam_reference in ["lightlike", "spacelike", "timelike"]: - num_spurions += 2 if two_beams else 1 - elif beam_reference == "xyplane": - num_spurions += 1 - if add_xzplane: - num_spurions += 1 - if add_yzplane: - num_spurions += 1 - if add_time_reference: - num_spurions += 1 - return num_spurions - - -def embed_spurions( - beam_reference, - add_time_reference, - two_beams=True, - add_xzplane=False, - add_yzplane=False, - device="cpu", - dtype=torch.float32, -): - """ - Construct a list of reference multivectors/spurions for symmetry breaking - - Parameters - ---------- - beam_reference: str - Different options for adding a beam_reference - Options: "lightlike", "spacelike", "timelike", "xyplane" - add_time_reference: bool - Whether to add the time direction as a reference to the network - two_beams: bool - Whether we only want (x, 0, 0, 1) or both (x, 0, 0, +/- 1) for the beam - add_xzplane: bool - Whether to add the x-z-plane as a reference to 
the network - add_yzplane: bool - Whether to add the y-z-plane as a reference to the network - device - dtype - - Returns - ------- - spurions: torch.tensor with shape (n_spurions, 16) - spurion embedded as multivector object - """ - kwargs = {"device": device, "dtype": dtype} - - if beam_reference in ["lightlike", "spacelike", "timelike"]: - # add another 4-momentum - if beam_reference == "lightlike": - beam = [1, 0, 0, 1] - elif beam_reference == "timelike": - beam = [2**0.5, 0, 0, 1] - elif beam_reference == "spacelike": - beam = [0, 0, 0, 1] - beam = torch.tensor(beam, **kwargs).reshape(1, 4) - beam = embed_vector(beam) - if two_beams: - beam2 = beam.clone() - beam2[..., 4] = -1 # flip pz - beam = torch.cat((beam, beam2), dim=0) - - elif beam_reference == "xyplane": - # add the x-y-plane, embedded as a bivector - # convention for bivector components: [tx, ty, tz, xy, xz, yz] - beam = torch.zeros(1, 16, **kwargs) - beam[..., 8] = 1 - - elif beam_reference is None: - beam = torch.empty(0, 16, **kwargs) - - else: - raise ValueError(f"beam_reference {beam_reference} not implemented") - - if add_xzplane: - # add the x-z-plane, embedded as a bivector - xzplane = torch.zeros(1, 16, **kwargs) - xzplane[..., 10] = 1 - else: - xzplane = torch.empty(0, 16, **kwargs) - - if add_yzplane: - # add the y-z-plane, embedded as a bivector - yzplane = torch.zeros(1, 16, **kwargs) - yzplane[..., 9] = 1 - else: - yzplane = torch.empty(0, 16, **kwargs) - - if add_time_reference: - time = [1, 0, 0, 0] - time = torch.tensor(time, **kwargs).reshape(1, 4) - time = embed_vector(time) - else: - time = torch.empty(0, 16, **kwargs) - - spurions = torch.cat((beam, xzplane, yzplane, time), dim=-2) - return spurions diff --git a/weaver/nn/model/gatr/interface/vector.py b/weaver/nn/model/gatr/interface/vector.py deleted file mode 100644 index 54206cb6..00000000 --- a/weaver/nn/model/gatr/interface/vector.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - - -def embed_vector(vector: torch.Tensor) -> torch.Tensor: - """Embeds Lorentz vectors in multivectors. - - Parameters - ---------- - vector : torch.Tensor with shape (..., 4) - Lorentz vector - - Returns - ------- - multivector : torch.Tensor with shape (..., 16) - Embedding into multivector. - """ - - # Create multivector tensor with same batch shape, same device, same dtype as input - batch_shape = vector.shape[:-1] - multivector = torch.zeros( - *batch_shape, 16, dtype=vector.dtype, device=vector.device - ) - - # Embedding into Lorentz vectors - multivector[..., 1:5] = vector - - return multivector - - -def extract_vector(multivector: torch.Tensor) -> torch.Tensor: - """Given a multivector, extract a Lorentz vector. - - Parameters - ---------- - multivector : torch.Tensor with shape (..., 16) - Multivector. 
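For reference, a minimal usage sketch of the interface removed above, assuming the `weaver.nn.model.gatr.interface` package from the first patch in this series is importable. The component convention (Lorentz vector at indices 1:5) and the spurion count follow the docstrings and code shown in this hunk.

```
import torch
from weaver.nn.model.gatr.interface import (
    embed_vector, extract_vector, get_num_spurions,
)

# Embed a single 4-momentum (E, px, py, pz) into a 16-component multivector.
p = torch.tensor([[100.0, 3.0, 4.0, 99.0]])
mv = embed_vector(p)
assert mv.shape == (1, 16)
assert torch.equal(mv[..., 1:5], p)        # vector components live at indices 1:5
assert torch.equal(extract_vector(mv), p)  # round trip

# Two lightlike beam spurions plus one time reference -> 3 reference multivectors.
assert get_num_spurions("lightlike", add_time_reference=True, two_beams=True) == 3
```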
- - Returns - ------- - vector : torch.Tensor with shape (..., 4) - Lorentz vector - """ - - vector = multivector[..., 1:5] - - return vector diff --git a/weaver/nn/model/gatr/layers/__init__.py b/weaver/nn/model/gatr/layers/__init__.py deleted file mode 100644 index 7757d6fa..00000000 --- a/weaver/nn/model/gatr/layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .attention.config import SelfAttentionConfig -from .attention.self_attention import SelfAttention -from .dropout import GradeDropout -from .layer_norm import EquiLayerNorm -from .linear import EquiLinear -from .mlp.geometric_bilinears import GeometricBilinear -from .mlp.mlp import GeoMLP -from .mlp.config import MLPConfig -from .mlp.nonlinearities import ScalarGatedNonlinearity -from .gatr_block import GATrBlock diff --git a/weaver/nn/model/gatr/layers/attention/__init__.py b/weaver/nn/model/gatr/layers/attention/__init__.py deleted file mode 100644 index 8c76bbb4..00000000 --- a/weaver/nn/model/gatr/layers/attention/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .config import SelfAttentionConfig -from .self_attention import SelfAttention diff --git a/weaver/nn/model/gatr/layers/attention/attention.py b/weaver/nn/model/gatr/layers/attention/attention.py deleted file mode 100644 index 94a654ad..00000000 --- a/weaver/nn/model/gatr/layers/attention/attention.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Self-attention layers.""" - -from torch import nn - -from .config import SelfAttentionConfig -from ...primitives.attention import sdp_attention - - -class GeometricAttention(nn.Module): - """Geometric attention layer. - - This is the main attention mechanism used in L-GATr. - - Given multivector and scalar queries, keys, and values, this layer computes: - - ``` - attn_weights[..., i, j] = softmax_j[ - ga_inner_product(q_mv[..., i, :, :], k_mv[..., j, :, :]) - + euclidean_inner_product(q_s[..., i, :], k_s[..., j, :]) - ] - out_mv[..., i, c, :] = sum_j attn_weights[..., i, j] v_mv[..., j, c, :] / norm - out_s[..., i, c] = sum_j attn_weights[..., i, j] v_s[..., j, c] / norm - ``` - - Parameters - ---------- - config : SelfAttentionConfig - Attention configuration. - """ - - def __init__(self, config: SelfAttentionConfig) -> None: - super().__init__() - - def forward(self, q_mv, k_mv, v_mv, q_s, k_s, v_s, attention_mask=None): - """Forward pass through geometric attention. - - Given multivector and scalar queries, keys, and values, this forward pass computes: - - ``` - attn_weights[..., i, j] = softmax_j[ - ga_inner_product(q_mv[..., i, :, :], k_mv[..., j, :, :]) - + euclidean_inner_product(q_s[..., i, :], k_s[..., j, :]) - ] - out_mv[..., i, c, :] = sum_j attn_weights[..., i, j] v_mv[..., j, c, :] / norm - out_s[..., i, c] = sum_j attn_weights[..., i, j] v_s[..., j, c] / norm - ``` - - Parameters - ---------- - q_mv : Tensor with shape (..., num_items_out, num_mv_channels_in, 16) - Queries, multivector part. - k_mv : Tensor with shape (..., num_items_in, num_mv_channels_in, 16) - Keys, multivector part. - v_mv : Tensor with shape (..., num_items_in, num_mv_channels_out, 16) - Values, multivector part. - q_s : Tensor with shape (..., heads, num_items_out, num_s_channels_in) - Queries, scalar part. - k_s : Tensor with shape (..., heads, num_items_in, num_s_channels_in) - Keys, scalar part. - v_s : Tensor with shape (..., heads, num_items_in, num_s_channels_out) - Values, scalar part. - attention_mask: None or Tensor or AttentionBias - Optional attention mask. 
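The attention weights described in the docstring above combine a GA inner product of the multivector queries and keys with a Euclidean inner product of the scalar queries and keys before a single softmax. The toy sketch below illustrates only that combination; the `metric` vector of 16 signs and the scaling factor are assumptions standing in for the precomputed factors that the actual primitive loads.

```
import torch

def toy_attention_weights(q_mv, k_mv, q_s, k_s, metric):
    # q_mv, k_mv: (..., items, channels, 16); q_s, k_s: (..., items, channels)
    # metric: (16,) signs standing in for the precomputed GA inner-product factors.
    mv_logits = torch.einsum("...icx,...jcx,x->...ij", q_mv, k_mv, metric)
    s_logits = torch.einsum("...ic,...jc->...ij", q_s, k_s)
    scale = (q_mv.shape[-2] * 16 + q_s.shape[-1]) ** -0.5  # assumed normalization
    return torch.softmax(scale * (mv_logits + s_logits), dim=-1)
```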
- """ - - h_mv, h_s = sdp_attention( - q_mv, - k_mv, - v_mv, - q_s, - k_s, - v_s, - attn_mask=attention_mask, - ) - - return h_mv, h_s diff --git a/weaver/nn/model/gatr/layers/attention/config.py b/weaver/nn/model/gatr/layers/attention/config.py deleted file mode 100644 index 76a12e9a..00000000 --- a/weaver/nn/model/gatr/layers/attention/config.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Mapping, Optional - - -@dataclass -class SelfAttentionConfig: - """Configuration for attention. - - Parameters - ---------- - in_mv_channels : int - Number of input multivector channels. - out_mv_channels : int - Number of output multivector channels. - num_heads : int - Number of attention heads. - in_s_channels : int - Input scalar channels. If None, no scalars are expected nor returned. - out_s_channels : int - Output scalar channels. If None, no scalars are expected nor returned. - additional_qk_mv_channels : int - Whether additional multivector features for the keys and queries will be provided. - additional_qk_s_channels : int - Whether additional scalar features for the keys and queries will be provided. - multi_query: bool - Whether to do multi-query attention - output_init : str - Initialization scheme for final linear layer - increase_hidden_channels : int - Factor by which to increase the number of hidden channels (both multivectors and scalars) - dropout_prob : float or None - Dropout probability - head_scale: bool - Whether to use HeadScaleMHA following the NormFormer (https://arxiv.org/pdf/2110.09456) - """ - - multi_query: bool = True - in_mv_channels: Optional[int] = None - out_mv_channels: Optional[int] = None - in_s_channels: Optional[int] = None - out_s_channels: Optional[int] = None - num_heads: int = 8 - additional_qk_mv_channels: int = 0 - additional_qk_s_channels: int = 0 - output_init: str = "default" - checkpoint: bool = True - increase_hidden_channels: int = 2 - dropout_prob: Optional[float] = None - head_scale: bool = False - - def __post_init__(self): - """Type checking / conversion.""" - if isinstance(self.dropout_prob, str) and self.dropout_prob.lower() in [ - "null", - "none", - ]: - self.dropout_prob = None - - @property - def hidden_mv_channels(self) -> Optional[int]: - """Returns the number of hidden multivector channels.""" - - if self.in_mv_channels is None: - return None - - return max( - self.increase_hidden_channels * self.in_mv_channels // self.num_heads, 1 - ) - - @property - def hidden_s_channels(self) -> Optional[int]: - """Returns the number of hidden scalar channels.""" - - if self.in_s_channels is None: - return None - - hidden_s_channels = max( - self.increase_hidden_channels * self.in_s_channels // self.num_heads, 4 - ) - - return hidden_s_channels - - @classmethod - def cast(cls, config: Any) -> SelfAttentionConfig: - """Casts an object as SelfAttentionConfig.""" - if isinstance(config, SelfAttentionConfig): - return config - if isinstance(config, Mapping): - return cls(**config) - raise ValueError(f"Can not cast {config} to {cls}") diff --git a/weaver/nn/model/gatr/layers/attention/qkv.py b/weaver/nn/model/gatr/layers/attention/qkv.py deleted file mode 100644 index 9dc077fc..00000000 --- a/weaver/nn/model/gatr/layers/attention/qkv.py +++ /dev/null @@ -1,238 +0,0 @@ -import torch -from einops import rearrange -from torch import nn - -from gatr.layers.attention.config import SelfAttentionConfig -from gatr.layers.linear import EquiLinear -from gatr.layers.layer_norm import 
EquiLayerNorm - - -class QKVModule(nn.Module): - """Compute (multivector and scalar) queries, keys, and values via multi-head attention. - - Parameters - ---------- - config: SelfAttentionConfig - Attention configuration - """ - - def __init__(self, config: SelfAttentionConfig): - super().__init__() - self.in_linear = EquiLinear( - in_mv_channels=config.in_mv_channels + config.additional_qk_mv_channels, - out_mv_channels=3 * config.hidden_mv_channels * config.num_heads, - in_s_channels=config.in_s_channels + config.additional_qk_s_channels, - out_s_channels=None - if config.in_s_channels is None - else 3 * config.hidden_s_channels * config.num_heads, - ) - self.norm_qkv = EquiLayerNorm() - self.config = config - - def forward( - self, - inputs, - scalars, - additional_qk_features_mv=None, - additional_qk_features_s=None, - ): - """Forward pass. - - Parameters - ---------- - inputs : torch.Tensor - Multivector inputs - scalars : torch.Tensor - Scalar inputs - additional_qk_features_mv : None or torch.Tensor - Additional multivector features that should be provided for the Q/K computation (e.g. - positions of objects) - additional_qk_features_s : None or torch.Tensor - Additional scalar features that should be provided for the Q/K computation (e.g. - object types) - - Returns - ------- - q_mv : Tensor with shape (..., num_items_out, num_mv_channels_in, 16) - Queries, multivector part. - k_mv : Tensor with shape (..., num_items_in, num_mv_channels_in, 16) - Keys, multivector part. - v_mv : Tensor with shape (..., num_items_in, num_mv_channels_out, 16) - Values, multivector part. - q_s : Tensor with shape (..., heads, num_items_out, num_s_channels_in) - Queries, scalar part. - k_s : Tensor with shape (..., heads, num_items_in, num_s_channels_in) - Keys, scalar part. - v_s : Tensor with shape (..., heads, num_items_in, num_s_channels_out) - Values, scalar part. - """ - - # Additional inputs - if additional_qk_features_mv is not None: - inputs = torch.cat((inputs, additional_qk_features_mv), dim=-2) - if additional_qk_features_s is not None: - scalars = torch.cat((scalars, additional_qk_features_s), dim=-1) - - qkv_mv, qkv_s = self.in_linear( - inputs, scalars - ) # (..., num_items, 3 * hidden_channels * num_heads, 16) - qkv_mv = rearrange( - qkv_mv, - "... items (qkv hidden num_heads) x -> qkv ... num_heads items hidden x", - num_heads=self.config.num_heads, - hidden=self.config.hidden_mv_channels, - qkv=3, - ) - q_mv, k_mv, v_mv = qkv_mv # each: (..., num_heads, num_items, num_channels, 16) - - # Same, for optional scalar components - if qkv_s is not None: - qkv_s = rearrange( - qkv_s, - "... items (qkv hidden num_heads) -> qkv ... num_heads items hidden", - num_heads=self.config.num_heads, - hidden=self.config.hidden_s_channels, - qkv=3, - ) - q_s, k_s, v_s = qkv_s # each: (..., num_heads, num_items, num_channels) - else: - q_s, k_s, v_s = None, None, None - - q_mv, q_s = self.norm_qkv(q_mv, scalars=q_s) - k_mv, k_s = self.norm_qkv(k_mv, scalars=k_s) - v_mv, v_s = self.norm_qkv(v_mv, scalars=v_s) - - return q_mv, k_mv, v_mv, q_s, k_s, v_s - - -class MultiQueryQKVModule(nn.Module): - """Compute (multivector and scalar) queries, keys, and values via multi-query attention. 
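Both QKV modules size their hidden representations through `SelfAttentionConfig` (deleted earlier in this hunk). A hypothetical construction via `cast`, with the expected values taken from the `hidden_mv_channels` and `hidden_s_channels` properties, assuming the module is importable:

```
from weaver.nn.model.gatr.layers.attention.config import SelfAttentionConfig

cfg = SelfAttentionConfig.cast(
    dict(in_mv_channels=16, out_mv_channels=16, in_s_channels=32,
         out_s_channels=32, num_heads=8, increase_hidden_channels=2)
)
# hidden_mv_channels = max(increase_hidden_channels * in_mv_channels // num_heads, 1)
assert cfg.hidden_mv_channels == 4
# hidden_s_channels = max(increase_hidden_channels * in_s_channels // num_heads, 4)
assert cfg.hidden_s_channels == 8
```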
- - Parameters - ---------- - config: SelfAttentionConfig - Attention configuration - """ - - def __init__(self, config: SelfAttentionConfig): - super().__init__() - - # Q projection - self.q_linear = EquiLinear( - in_mv_channels=config.in_mv_channels + config.additional_qk_mv_channels, - out_mv_channels=config.hidden_mv_channels * config.num_heads, - in_s_channels=config.in_s_channels + config.additional_qk_s_channels, - out_s_channels=config.hidden_s_channels * config.num_heads, - ) - - # Key and value projections (shared between heads) - self.k_linear = EquiLinear( - in_mv_channels=config.in_mv_channels + config.additional_qk_mv_channels, - out_mv_channels=config.hidden_mv_channels, - in_s_channels=config.in_s_channels + config.additional_qk_s_channels, - out_s_channels=config.hidden_s_channels, - ) - self.v_linear = EquiLinear( - in_mv_channels=config.in_mv_channels, - out_mv_channels=config.hidden_mv_channels, - in_s_channels=config.in_s_channels, - out_s_channels=config.hidden_s_channels, - ) - self.norm_qkv = EquiLayerNorm() - self.config = config - - def forward( - self, - inputs, - scalars, - additional_qk_features_mv=None, - additional_qk_features_s=None, - ): - """Forward pass. - - Parameters - ---------- - inputs : torch.Tensor - Multivector inputs - scalars : torch.Tensor - Scalar inputs - additional_qk_features_mv : None or torch.Tensor - Additional multivector features that should be provided for the Q/K computation (e.g. - positions of objects) - additional_qk_features_s : None or torch.Tensor - Additional scalar features that should be provided for the Q/K computation (e.g. - object types) - - Returns - ------- - q_mv : Tensor with shape (..., num_items_out, num_mv_channels_in, 16) - Queries, multivector part. - k_mv : Tensor with shape (..., num_items_in, num_mv_channels_in, 16) - Keys, multivector part. - v_mv : Tensor with shape (..., num_items_in, num_mv_channels_out, 16) - Values, multivector part. - q_s : Tensor with shape (..., heads, num_items_out, num_s_channels_in) - Queries, scalar part. - k_s : Tensor with shape (..., heads, num_items_in, num_s_channels_in) - Keys, scalar part. - v_s : Tensor with shape (..., heads, num_items_in, num_s_channels_out) - Values, scalar part. - """ - - # Additional inputs - if additional_qk_features_mv is not None: - qk_inputs = torch.cat((inputs, additional_qk_features_mv), dim=-2) - else: - qk_inputs = inputs - if scalars is not None and additional_qk_features_s is not None: - qk_scalars = torch.cat((scalars, additional_qk_features_s), dim=-1) - else: - qk_scalars = scalars - - # Project to queries, keys, and values (multivector reps) - q_mv, q_s = self.q_linear( - qk_inputs, qk_scalars - ) # (..., num_items, hidden_channels * num_heads, 16) - k_mv, k_s = self.k_linear( - qk_inputs, qk_scalars - ) # (..., num_items, hidden_channels, 16) - v_mv, v_s = self.v_linear( - inputs, scalars - ) # (..., num_items, hidden_channels, 16) - - # Rearrange to (..., heads, items, channels, 16) shape - q_mv = rearrange( - q_mv, - "... items (hidden_channels num_heads) x -> ... num_heads items hidden_channels x", - num_heads=self.config.num_heads, - hidden_channels=self.config.hidden_mv_channels, - ) - k_mv = rearrange( - k_mv, "... items hidden_channels x -> ... 1 items hidden_channels x" - ) - v_mv = rearrange( - v_mv, "... items hidden_channels x -> ... 1 items hidden_channels x" - ) - - # Same for scalars - if q_s is not None: - q_s = rearrange( - q_s, - "... items (hidden_channels num_heads) -> ... 
num_heads items hidden_channels", - num_heads=self.config.num_heads, - hidden_channels=self.config.hidden_s_channels, - ) - k_s = rearrange( - k_s, "... items hidden_channels -> ... 1 items hidden_channels" - ) - v_s = rearrange( - v_s, "... items hidden_channels -> ... 1 items hidden_channels" - ) - else: - q_s, k_s, v_s = None, None, None - - q_mv, q_s = self.norm_qkv(q_mv, scalars=q_s) - k_mv, k_s = self.norm_qkv(k_mv, scalars=k_s) - v_mv, v_s = self.norm_qkv(v_mv, scalars=v_s) - - return q_mv, k_mv, v_mv, q_s, k_s, v_s diff --git a/weaver/nn/model/gatr/layers/attention/self_attention.py b/weaver/nn/model/gatr/layers/attention/self_attention.py deleted file mode 100644 index b34141a3..00000000 --- a/weaver/nn/model/gatr/layers/attention/self_attention.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Self-attention layers.""" - -from typing import Optional, Tuple - -import torch -from einops import rearrange -from torch import nn - -from .attention import GeometricAttention -from .config import SelfAttentionConfig -from .qkv import MultiQueryQKVModule, QKVModule -from ..dropout import GradeDropout -from ..linear import EquiLinear - - -class SelfAttention(nn.Module): - """Geometric self-attention layer. - - Constructs queries, keys, and values, computes attention, and projects linearly to outputs. - - Parameters - ---------- - config : SelfAttentionConfig - Attention configuration. - """ - - def __init__(self, config: SelfAttentionConfig) -> None: - super().__init__() - - # Store settings - self.config = config - - # QKV computation - self.qkv_module = ( - MultiQueryQKVModule(config) if config.multi_query else QKVModule(config) - ) - - # Output projection - self.out_linear = EquiLinear( - in_mv_channels=config.hidden_mv_channels * config.num_heads, - out_mv_channels=config.out_mv_channels, - in_s_channels=( - None - if config.in_s_channels is None - else config.hidden_s_channels * config.num_heads - ), - out_s_channels=config.out_s_channels, - initialization=config.output_init, - ) - - # Attention - self.attention = GeometricAttention(config) - - # Dropout - self.dropout: Optional[nn.Module] - if config.dropout_prob is not None: - self.dropout = GradeDropout(config.dropout_prob) - else: - self.dropout = None - - # HeadScaleMHA - self.use_head_scale = config.head_scale - if self.use_head_scale: - self.head_scale = nn.Parameter(torch.ones(config.num_heads)) - - def forward( - self, - multivectors: torch.Tensor, - additional_qk_features_mv: Optional[torch.Tensor] = None, - scalars: Optional[torch.Tensor] = None, - additional_qk_features_s: Optional[torch.Tensor] = None, - attention_mask=None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes forward pass on inputs with shape `(..., items, channels, 16)`. - - The result is the following: - - ``` - # For each head - queries = linear_channels(inputs) - keys = linear_channels(inputs) - values = linear_channels(inputs) - hidden = attention_items(queries, keys, values, biases=biases) - head_output = linear_channels(hidden) - - # Combine results - output = concatenate_heads head_output - ``` - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., num_items, channels_in, 16) - Input multivectors. - additional_qk_features_mv : None or torch.Tensor with shape - (..., num_items, add_qk_mv_channels, 16) - Additional Q/K features, multivector part. 
- scalars : None or torch.Tensor with shape (..., num_items, num_items, in_scalars) - Optional input scalars - additional_qk_features_s : None or torch.Tensor with shape - (..., num_items, add_qk_mv_channels, 16) - Additional Q/K features, scalar part. - scalars : None or torch.Tensor with shape (..., num_items, num_items, in_scalars) - Optional input scalars - attention_mask: None or torch.Tensor with shape (..., num_items, num_items) - Optional attention mask - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., num_items, channels_out, 16) - Output multivectors. - output_scalars : torch.Tensor with shape (..., num_items, channels_out, out_scalars) - Output scalars, if scalars are provided. Otherwise None. - """ - # Compute Q, K, V - q_mv, k_mv, v_mv, q_s, k_s, v_s = self.qkv_module( - multivectors, scalars, additional_qk_features_mv, additional_qk_features_s - ) - - # Attention layer - h_mv, h_s = self.attention( - q_mv, k_mv, v_mv, q_s, k_s, v_s, attention_mask=attention_mask - ) - if self.use_head_scale: - h_mv = h_mv * self.head_scale.view( - *[1] * len(h_mv.shape[:-5]), len(self.head_scale), 1, 1, 1 - ) - h_s = h_s * self.head_scale.view( - *[1] * len(h_s.shape[:-4]), len(self.head_scale), 1, 1 - ) - - h_mv = rearrange( - h_mv, - "... n_heads n_items hidden_channels x -> ... n_items (n_heads hidden_channels) x", - ) - h_s = rearrange( - h_s, - "... n_heads n_items hidden_channels -> ... n_items (n_heads hidden_channels)", - ) - - # Transform linearly one more time - outputs_mv, outputs_s = self.out_linear(h_mv, scalars=h_s) - - # Dropout - if self.dropout is not None: - outputs_mv, outputs_s = self.dropout(outputs_mv, outputs_s) - - return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/dropout.py b/weaver/nn/model/gatr/layers/dropout.py deleted file mode 100644 index ec32d041..00000000 --- a/weaver/nn/model/gatr/layers/dropout.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Equivariant dropout layer.""" - -from typing import Tuple - -import torch -from torch import nn - -from ..primitives import grade_dropout - - -class GradeDropout(nn.Module): - """Grade dropout for multivectors (and regular dropout for auxiliary scalars). - - Parameters - ---------- - p : float - Dropout probability. - """ - - def __init__(self, p: float = 0.0): - super().__init__() - self._dropout_prob = p - - def forward( - self, multivectors: torch.Tensor, scalars: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward pass. Applies dropout. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., 16) - Multivector inputs. - scalars : torch.Tensor - Scalar inputs. - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., 16) - Multivector inputs with dropout applied. - output_scalars : torch.Tensor - Scalar inputs with dropout applied. 
- """ - - out_mv = grade_dropout( - multivectors, p=self._dropout_prob, training=self.training - ) - out_s = torch.nn.functional.dropout( - scalars, p=self._dropout_prob, training=self.training - ) - - return out_mv, out_s diff --git a/weaver/nn/model/gatr/layers/gatr_block.py b/weaver/nn/model/gatr/layers/gatr_block.py deleted file mode 100644 index f7cf651c..00000000 --- a/weaver/nn/model/gatr/layers/gatr_block.py +++ /dev/null @@ -1,143 +0,0 @@ -from dataclasses import replace -from typing import Optional, Tuple - -import torch -from torch import nn - -from .attention import SelfAttention, SelfAttentionConfig -from .layer_norm import EquiLayerNorm -from .mlp.config import MLPConfig -from .mlp.mlp import GeoMLP - - -class GATrBlock(nn.Module): - """Equivariant transformer encoder block for L-GATr. - - This is the biggest building block of L-GATr. - - Inputs are first processed by a block consisting of LayerNorm, multi-head geometric - self-attention, and a residual connection. Then the data is processed by a block consisting of - another LayerNorm, an item-wise two-layer geometric MLP with GeLU activations, and another - residual connection. - - Parameters - ---------- - mv_channels : int - Number of input and output multivector channels - s_channels: int - Number of input and output scalar channels - attention: SelfAttentionConfig - Attention configuration - mlp: MLPConfig - MLP configuration - dropout_prob : float or None - Dropout probability - double_layernorm : bool - Whether to use double layer normalization - """ - - def __init__( - self, - mv_channels: int, - s_channels: int, - attention: SelfAttentionConfig, - mlp: MLPConfig, - dropout_prob: Optional[float] = None, - double_layernorm: bool = False, - ) -> None: - super().__init__() - - # Normalization layer (stateless, so we can use the same layer for both normalization - # instances) - self.norm = EquiLayerNorm() - self.double_layernorm = double_layernorm - - # Self-attention layer - attention = replace( - attention, - in_mv_channels=mv_channels, - out_mv_channels=mv_channels, - in_s_channels=s_channels, - out_s_channels=s_channels, - output_init="small", - dropout_prob=dropout_prob, - ) - self.attention = SelfAttention(attention) - - # MLP block - mlp = replace( - mlp, - mv_channels=(mv_channels, 2 * mv_channels, mv_channels), - s_channels=(s_channels, 2 * s_channels, s_channels), - dropout_prob=dropout_prob, - ) - self.mlp = GeoMLP(mlp) - - def forward( - self, - multivectors: torch.Tensor, - scalars: torch.Tensor, - additional_qk_features_mv=None, - additional_qk_features_s=None, - attention_mask=None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward pass of the transformer encoder block. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., items, channels, 16) - Input multivectors. - scalars : torch.Tensor with shape (..., s_channels) - Input scalars. - additional_qk_features_mv : None or torch.Tensor with shape - (..., num_items, add_qk_mv_channels, 16) - Additional Q/K features, multivector part. - additional_qk_features_s : None or torch.Tensor with shape - (..., num_items, add_qk_mv_channels, 16) - Additional Q/K features, scalar part. - attention_mask: None or torch.Tensor or AttentionBias - Optional attention mask. - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., items, channels, 16). 
- Output multivectors - output_scalars : torch.Tensor with shape (..., s_channels) - Output scalars - """ - - # Attention block: pre layer norm - h_mv, h_s = self.norm(multivectors, scalars=scalars) - - # Attention block: self attention - h_mv, h_s = self.attention( - h_mv, - scalars=h_s, - additional_qk_features_mv=additional_qk_features_mv, - additional_qk_features_s=additional_qk_features_s, - attention_mask=attention_mask, - ) - - # Attention block: post layer norm - if self.double_layernorm: - h_mv, h_s = self.norm(h_mv, scalars=h_s) - - # Attention block: skip connection - outputs_mv = multivectors + h_mv - outputs_s = scalars + h_s - - # MLP block: pre layer norm - h_mv, h_s = self.norm(outputs_mv, scalars=outputs_s) - - # MLP block: MLP - h_mv, h_s = self.mlp(h_mv, scalars=h_s) - - # MLP block: post layer norm - if self.double_layernorm: - h_mv, h_s = self.norm(h_mv, scalars=h_s) - - # MLP block: skip connection - outputs_mv = outputs_mv + h_mv - outputs_s = outputs_s + h_s - - return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/layer_norm.py b/weaver/nn/model/gatr/layers/layer_norm.py deleted file mode 100644 index ed1456ec..00000000 --- a/weaver/nn/model/gatr/layers/layer_norm.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Equivariant normalization layers.""" - -from typing import Tuple - -import torch -from torch import nn - -from ..primitives import equi_layer_norm - - -class EquiLayerNorm(nn.Module): - """Equivariant LayerNorm for multivectors. - - Rescales input such that `mean_channels |inputs|^2 = 1`, where the norm is the GA norm and the - mean goes over the channel dimensions. - - In addition, the layer performs a regular LayerNorm operation on auxiliary scalar inputs. - - Parameters - ---------- - mv_channel_dim : int - Channel dimension index for multivector inputs. Defaults to the second-last entry (last are - the multivector components). - scalar_channel_dim : int - Channel dimension index for scalar inputs. Defaults to the last entry. - epsilon : float - Small numerical factor to avoid instabilities. We use a reasonably large number to balance - issues that arise from some multivector components not contributing to the norm. - """ - - def __init__(self, mv_channel_dim=-2, scalar_channel_dim=-1, epsilon: float = 0.01): - super().__init__() - self.mv_channel_dim = mv_channel_dim - self.epsilon = epsilon - - if scalar_channel_dim != -1: - raise NotImplementedError( - "Currently, only scalar_channel_dim = -1 is implemented, but found" - f" {scalar_channel_dim}" - ) - - def forward( - self, multivectors: torch.Tensor, scalars: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward pass. Computes equivariant LayerNorm for multivectors. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., 16) - Multivector inputs - scalars : torch.Tensor with shape (..., self.in_channels, self.in_scalars) - Scalar inputs - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., 16) - Normalized multivectors - output_scalars : torch.Tensor with shape (..., self.out_channels, self.in_scalars) - Normalized scalars. 
- """ - - outputs_mv = equi_layer_norm( - multivectors, channel_dim=self.mv_channel_dim, epsilon=self.epsilon - ) - normalized_shape = scalars.shape[-1:] - outputs_s = torch.nn.functional.layer_norm( - scalars, normalized_shape=normalized_shape - ) - - return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/linear.py b/weaver/nn/model/gatr/layers/linear.py deleted file mode 100644 index 226c1134..00000000 --- a/weaver/nn/model/gatr/layers/linear.py +++ /dev/null @@ -1,393 +0,0 @@ -"""Pin-equivariant linear layers between multivector tensors (torch.nn.Modules).""" - -from typing import Optional, Tuple, Union - -import numpy as np -import torch -from torch import nn - -from ..interface import embed_scalar -from ..primitives.linear import equi_linear - -# switch to mix pseudoscalar multivector components directly into scalar components -# this only makes sense when working with the special orthochronous Lorentz group, -# Note: This is an efficiency boost, the same action can be achieved with an extra linear layer -MIX_MVPSEUDOSCALAR_INTO_SCALAR = True - - -class EquiLinear(nn.Module): - """Pin-equivariant linear layer. - - The forward pass maps multivector inputs with shape (..., in_channels, 16) to multivector - outputs with shape (..., out_channels, 16) as - - ``` - outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] basis_map[b, x, y] inputs[..., i, x] - ``` - - plus an optional bias term for outputs[..., :, 0] (biases in other multivector components would - break equivariance). - - Here basis_map are precomputed (see gatr.primitives.linear) and weights are the - learnable weights of this layer. - - If there are auxiliary input scalars, they transform under a linear layer, and mix with the - scalar components the multivector data. Note that in this layer (and only here) the auxiliary - scalars are optional. - - This layer supports four initialization schemes: - - "default": preserves (or actually slightly reducing) the variance of the data in - the forward pass - - "small": variance of outputs is approximately one order of magnitude smaller - than for "default" - - "unit_scalar": outputs will be close to (1, 0, 0, ..., 0) - - "almost_unit_scalar": similar to "unit_scalar", but with more stochasticity - - Parameters - ---------- - in_mv_channels : int - Input multivector channels - out_mv_channels : int - Output multivector channels - bias : bool - Whether a bias term is added to the scalar component of the multivector outputs - in_s_channels : int or None - Input scalar channels. If None, no scalars are expected nor returned. - out_s_channels : int or None - Output scalar channels. If None, no scalars are expected nor returned. - initialization : {"default", "small", "unit_scalar", "almost_unit_scalar"} - Initialization scheme. For "default", initialize with the same philosophy as most - networks do: preserve variance (approximately) in the forward pass. For "small", - initalize the network such that the variance of the output data is approximately one - order of magnitude smaller than that of the input data. For "unit_scalar", initialize - the layer such that the output multivectors will be closer to (1, 0, 0, ..., 0). - "almost_unit_scalar" is similar, but with more randomness. 
- """ - - def __init__( - self, - in_mv_channels: int, - out_mv_channels: int, - in_s_channels: Optional[int] = None, - out_s_channels: Optional[int] = None, - bias: bool = True, - initialization: str = "default", - ) -> None: - super().__init__() - - # Check inputs - if initialization in ["unit_scalar", "almost_unit_scalar"]: - assert bias, "unit_scalar initialization requires bias" - if in_s_channels is None: - raise NotImplementedError( - "unit_scalar initialization is currently only implemented for scalar inputs" - ) - - self._in_mv_channels = in_mv_channels - - # MV -> MV - from ..primitives.linear import USE_FULLY_CONNECTED_SUBGROUP - - NUM_PIN_LINEAR_BASIS_ELEMENTS = 10 if USE_FULLY_CONNECTED_SUBGROUP else 5 - self.weight = nn.Parameter( - torch.empty( - (out_mv_channels, in_mv_channels, NUM_PIN_LINEAR_BASIS_ELEMENTS) - ) - ) - - # We only need a separate bias here if that isn't already covered by the linear map from - # scalar inputs - self.bias = ( - nn.Parameter(torch.zeros((out_mv_channels, 1))) - if bias and in_s_channels is None - else None - ) - - # Scalars -> MV scalars - self.s2mvs: Optional[nn.Linear] - mix_factor = 2 if MIX_MVPSEUDOSCALAR_INTO_SCALAR else 1 - if in_s_channels: - self.s2mvs = nn.Linear( - in_s_channels, mix_factor * out_mv_channels, bias=bias - ) - else: - self.s2mvs = None - - # MV scalars -> scalars - if out_s_channels: - self.mvs2s = nn.Linear( - mix_factor * in_mv_channels, out_s_channels, bias=bias - ) - else: - self.mvs2s = None - - # Scalars -> scalars - if in_s_channels is not None and out_s_channels is not None: - self.s2s = nn.Linear( - in_s_channels, out_s_channels, bias=False - ) # Bias would be duplicate - else: - self.s2s = None - - # Initialization - self.reset_parameters(initialization) - - def forward( - self, multivectors: torch.Tensor, scalars: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: - """Maps input multivectors and scalars using the most general equivariant linear map. - - The result is again multivectors and scalars. - - For multivectors we have: - ``` - outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] basis_map[b, x, y] inputs[..., i, x] - = sum_i linear(inputs[..., i, :], weights[j, i, :]) - ``` - - Here basis_map are precomputed (see gatr.primitives.linear) and weights are the - learnable weights of this layer. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., in_mv_channels, 16) - Input multivectors - scalars : None or torch.Tensor with shape (..., in_s_channels) - Optional input scalars - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., out_mv_channels, 16) - Output multivectors - outputs_s : None or torch.Tensor with shape (..., out_s_channels) - Output scalars, if scalars are provided. Otherwise None. 
- """ - - outputs_mv = equi_linear(multivectors, self.weight) # (..., out_channels, 16) - - if self.bias is not None: - bias = embed_scalar(self.bias) - outputs_mv = outputs_mv + bias - - if self.s2mvs is not None and scalars is not None: - if MIX_MVPSEUDOSCALAR_INTO_SCALAR: - outputs_mv[..., [0, -1]] += self.s2mvs(scalars).view( - *outputs_mv.shape[:-2], outputs_mv.shape[-2], 2 - ) - else: - outputs_mv[..., 0] += self.s2mvs(scalars) - - if self.mvs2s is not None: - if MIX_MVPSEUDOSCALAR_INTO_SCALAR: - outputs_s = self.mvs2s(multivectors[..., [0, -1]].flatten(start_dim=-2)) - else: - outputs_s = self.mvs2s(multivectors[..., 0]) - if self.s2s is not None and scalars is not None: - outputs_s = outputs_s + self.s2s(scalars) - else: - outputs_s = None - - return outputs_mv, outputs_s - - def reset_parameters( - self, - initialization: str, - gain: float = 1.0, - additional_factor=1.0 / np.sqrt(3.0), - ) -> None: - """Initializes the weights of the layer. - - Parameters - ---------- - initialization : {"default", "small", "unit_scalar", "almost_unit_scalar"} - Initialization scheme. For "default", initialize with the same philosophy as most - networks do: preserve variance (approximately) in the forward pass. For "small", - initalize the network such that the variance of the output data is approximately one - order of magnitude smaller than that of the input data. For "unit_scalar", initialize - the layer such that the output multivectors will be closer to (1, 0, 0, ..., 0). - "almost_unit_scalar" is similar, but with more randomness. - gain : float - Gain factor for the activations. Should be 1.0 if previous layer has no activation, - sqrt(2) if it has a ReLU activation, and so on. Can be computed with - `torch.nn.init.calculate_gain()`. - additional_factor : float - Empirically, it has been found that slightly *decreasing* the data variance at each - layer gives a better performance. In particular, the PyTorch default initialization uses - an additional factor of 1/sqrt(3) (cancelling the factor of sqrt(3) that naturally - arises when computing the bounds of a uniform initialization). A discussion of this was - (to the best of our knowledge) never published, but see - https://github.com/pytorch/pytorch/issues/57109 and - https://soumith.ch/files/20141213_gplus_nninit_discussion.htm. - """ - - # Prefactors depending on initialization scheme - ( - mv_component_factors, - mv_factor, - mvs_bias_shift, - s_factor, - ) = self._compute_init_factors( - initialization, - gain, - additional_factor, - ) - - # Following He et al, 1502.01852, we aim to preserve the variance in the forward pass. - # A sufficient criterion for this is that the variance of the weights is given by - # `Var[w] = gain^2 / fan`. - # Here `gain^2` is 2 if the previous layer has a ReLU nonlinearity, 1 for the initial layer, - # and some other value in other situations (we may not care about this too much). - # More importantly, `fan` is the number of connections: the number of input elements that - # get summed over to compute each output element. - - # Let us fist consider the multivector outputs. - self._init_multivectors(mv_component_factors, mv_factor, mvs_bias_shift) - - # Then let's consider the maps to scalars. - self._init_scalars(s_factor) - - @staticmethod - def _compute_init_factors( - initialization, - gain, - additional_factor, - ): - """Computes prefactors for the initialization. - - See self.reset_parameters(). 
- """ - - if initialization not in { - "default", - "small", - "unit_scalar", - "almost_unit_scalar", - }: - raise ValueError(f"Unknown initialization scheme {initialization}") - - if initialization == "default": - mv_factor = gain * additional_factor * np.sqrt(3) - s_factor = gain * additional_factor * np.sqrt(3) - mvs_bias_shift = 0.0 - elif initialization == "small": - # Change scale by a factor of 0.1 in this layer - mv_factor = 0.1 * gain * additional_factor * np.sqrt(3) - s_factor = 0.1 * gain * additional_factor * np.sqrt(3) - mvs_bias_shift = 0.0 - elif initialization == "unit_scalar": - # Change scale by a factor of 0.1 for MV outputs, and initialize bias around 1 - mv_factor = 0.1 * gain * additional_factor * np.sqrt(3) - s_factor = gain * additional_factor * np.sqrt(3) - mvs_bias_shift = 1.0 - elif initialization == "almost_unit_scalar": - # Change scale by a factor of 0.5 for MV outputs, and initialize bias around 1 - mv_factor = 0.5 * gain * additional_factor * np.sqrt(3) - s_factor = gain * additional_factor * np.sqrt(3) - mvs_bias_shift = 1.0 - else: - raise ValueError( - f"Unknown initialization scheme {initialization}, expected" - ' "default", "small", "unit_scalar" or "almost_unit_scalar".' - ) - - # Individual factors for each multivector component (could be tuned for performance) - from ..primitives.linear import USE_FULLY_CONNECTED_SUBGROUP - - NUM_PIN_LINEAR_BASIS_ELEMENTS = 10 if USE_FULLY_CONNECTED_SUBGROUP else 5 - mv_component_factors = torch.ones(NUM_PIN_LINEAR_BASIS_ELEMENTS) - return mv_component_factors, mv_factor, mvs_bias_shift, s_factor - - def _init_multivectors(self, mv_component_factors, mv_factor, mvs_bias_shift): - """Weight initialization for maps to multivector outputs.""" - - # We have - # `outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] basis_map[b, x, y] inputs[..., i, x]` - # The basis maps are more or less grade projections, summing over all basis elements - # corresponds to (almost) an identity map in the GA space. The sum over `b` and `x` thus - # does not contribute to `fan` substantially. (We may add a small ad-hoc factor later to - # make up for this approximation.) However, there is still the sum over incoming channels, - # and thus `fan ~ mv_in_channels`. Assuming (for now) that the previous layer contained a - # ReLU activation, we finally have the condition `Var[w] = 2 / mv_in_channels`. - # Since the variance of a uniform distribution between -a and a is given by - # `Var[Uniform(-a, a)] = a^2/3`, we should set `a = gain * sqrt(3 / mv_in_channels)`. - # In theory (see docstring). - fan_in = self._in_mv_channels - bound = mv_factor / np.sqrt(fan_in) - for i, factor in enumerate(mv_component_factors): - nn.init.uniform_(self.weight[..., i], a=-factor * bound, b=factor * bound) - - # Now let's focus on the scalar components of the multivector outputs. - # If there are only multivector inputs, all is good. But if scalar inputs contribute them as - # well, they contribute to the output variance as well. - # In this case, we initialize such that the multivector inputs and the scalar inputs each - # contribute half to the output variance. - # We can achieve this by inspecting the basis maps and seeing that only basis element 0 - # contributes to the scalar output. Thus, we can reduce the variance of the correponding - # weights to give a variance of 0.5, not 1. 
- if self.s2mvs is not None: - # contribution from scalar -> mv scalar - bound = mv_component_factors[0] * mv_factor / np.sqrt(fan_in) / np.sqrt(2) - nn.init.uniform_(self.weight[..., [0]], a=-bound, b=bound) - if MIX_MVPSEUDOSCALAR_INTO_SCALAR: - # contribution from scalar -> mv pseudoscalar - bound = ( - mv_component_factors[-1] * mv_factor / np.sqrt(fan_in) / np.sqrt(2) - ) - nn.init.uniform_(self.weight[..., [-1]], a=-bound, b=bound) - - # The same holds for the scalar-to-MV map, where we also just want a variance of 0.5. - # Note: This is not properly extended to scalar and pseudoscalar outputs yet - if self.s2mvs is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out( - self.s2mvs.weight - ) # pylint:disable=protected-access - fan_in = max( - fan_in, 1 - ) # Since in theory we could have 0-channel scalar "data" - bound = mv_component_factors[0] * mv_factor / np.sqrt(fan_in) / np.sqrt(2) - nn.init.uniform_(self.s2mvs.weight, a=-bound, b=bound) - - # Bias needs to be adapted, as the overall fan in is different (need to account for MV - # and s inputs) and we may need to account for the unit_scalar initialization scheme - if self.s2mvs.bias is not None: - fan_in = ( - nn.init._calculate_fan_in_and_fan_out(self.s2mvs.weight)[0] - + self._in_mv_channels - ) - bound = mv_component_factors[0] / np.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_( - self.s2mvs.bias, mvs_bias_shift - bound, mvs_bias_shift + bound - ) - - def _init_scalars(self, s_factor): - """Weight initialization for maps to multivector outputs.""" - - # If both exist, we need to account for overcounting again, and assign each a target a - # variance of 0.5. - # Note: This is not properly extended to scalar and pseudoscalar outputs yet - models = [] - if self.s2s: - models.append(self.s2s) - if self.mvs2s: - models.append(self.mvs2s) - for model in models: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out( - model.weight - ) # pylint:disable=protected-access - fan_in = max( - fan_in, 1 - ) # Since in theory we could have 0-channel scalar "data" - bound = s_factor / np.sqrt(fan_in) / np.sqrt(len(models)) - nn.init.uniform_(model.weight, a=-bound, b=bound) - # Bias needs to be adapted, as the overall fan in is different (need to account for MV and - # s inputs) - if self.mvs2s and self.mvs2s.bias is not None: - fan_in = nn.init._calculate_fan_in_and_fan_out(self.mvs2s.weight)[ - 0 - ] # pylint:disable=protected-access - if self.s2s: - fan_in += nn.init._calculate_fan_in_and_fan_out(self.s2s.weight)[ - 0 - ] # pylint:disable=protected-access - bound = s_factor / np.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(self.mvs2s.bias, -bound, bound) diff --git a/weaver/nn/model/gatr/layers/mlp/__init__.py b/weaver/nn/model/gatr/layers/mlp/__init__.py deleted file mode 100644 index 2423e991..00000000 --- a/weaver/nn/model/gatr/layers/mlp/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .config import MLPConfig -from .geometric_bilinears import GeometricBilinear -from .mlp import GeoMLP -from .nonlinearities import ScalarGatedNonlinearity diff --git a/weaver/nn/model/gatr/layers/mlp/config.py b/weaver/nn/model/gatr/layers/mlp/config.py deleted file mode 100644 index 959ab2f0..00000000 --- a/weaver/nn/model/gatr/layers/mlp/config.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, List, Mapping, Optional - - -@dataclass -class MLPConfig: - """Geometric MLP configuration. 
- - Parameters - ---------- - mv_channels : iterable of int - Number of multivector channels at each layer, from input to output - s_channels : None or iterable of int - If not None, sets the number of scalar channels at each layer, from input to output. Length - needs to match mv_channels - activation : {"relu", "sigmoid", "gelu"} - Which (gated) activation function to use - dropout_prob : float or None - Dropout probability - """ - - mv_channels: Optional[List[int]] = None - s_channels: Optional[List[int]] = None - activation: str = "gelu" - dropout_prob: Optional[float] = None - - def __post_init__(self): - """Type checking / conversion.""" - if isinstance(self.dropout_prob, str) and self.dropout_prob.lower() in [ - "null", - "none", - ]: - self.dropout_prob = None - - @classmethod - def cast(cls, config: Any) -> MLPConfig: - """Casts an object as MLPConfig.""" - if isinstance(config, MLPConfig): - return config - if isinstance(config, Mapping): - return cls(**config) - raise ValueError(f"Can not cast {config} to {cls}") diff --git a/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py b/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py deleted file mode 100644 index 287bf87a..00000000 --- a/weaver/nn/model/gatr/layers/mlp/geometric_bilinears.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Pin-equivariant geometric product layer between multivector tensors (torch.nn.Modules).""" - -from typing import Optional, Tuple - -import torch -from torch import nn - -from ..linear import EquiLinear -from ...primitives import geometric_product -from ..layer_norm import EquiLayerNorm - -# switch to set bivector components to zero, -# after they are generated by the geometric product -ZERO_BIVECTOR = False - - -class GeometricBilinear(nn.Module): - """Geometric bilinear layer. - - Pin-equivariant map between multivector tensors that constructs new geometric features via - geometric products. - - Parameters - ---------- - in_mv_channels : int - Input multivector channels of `x` - out_mv_channels : int - Output multivector channels - hidden_mv_channels : int or None - Hidden MV channels. If None, uses out_mv_channels. - in_s_channels : int or None - Input scalar channels of `x`. If None, no scalars are expected nor returned. - out_s_channels : int or None - Output scalar channels. If None, no scalars are expected nor returned. - """ - - def __init__( - self, - in_mv_channels: int, - out_mv_channels: int, - hidden_mv_channels: Optional[int] = None, - in_s_channels: Optional[int] = None, - out_s_channels: Optional[int] = None, - ) -> None: - super().__init__() - - # Default options - if hidden_mv_channels is None: - hidden_mv_channels = out_mv_channels - - # Linear projections for GP - self.linear_left = EquiLinear( - in_mv_channels, - hidden_mv_channels, - in_s_channels=in_s_channels, - out_s_channels=None, - ) - self.linear_right = EquiLinear( - in_mv_channels, - hidden_mv_channels, - in_s_channels=in_s_channels, - out_s_channels=None, - initialization="almost_unit_scalar", - ) - - # Output linear projection - self.linear_out = EquiLinear( - hidden_mv_channels, out_mv_channels, in_s_channels, out_s_channels - ) - self.norm = EquiLayerNorm() - - def forward( - self, - multivectors: torch.Tensor, - scalars: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward pass. 
- - Parameters - ---------- - multivectors : torch.Tensor with shape (..., in_mv_channels, 16) - Input multivectors - scalars : torch.Tensor with shape (..., in_s_channels) - Input scalars - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., self.out_mv_channels, 16) - Output multivectors - output_s : None or torch.Tensor with shape (..., out_s_channels) - Output scalars. - """ - - # GP - left, _ = self.linear_left(multivectors, scalars=scalars) - right, _ = self.linear_right(multivectors, scalars=scalars) - gp_outputs = geometric_product(left, right) - if ZERO_BIVECTOR: - gp_outputs[..., 5:11] = 0.0 - - # Output linear - outputs_mv, outputs_s = self.linear_out(gp_outputs, scalars=scalars) - - outputs_mv, outputs_s = self.norm(outputs_mv, outputs_s) - return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/layers/mlp/mlp.py b/weaver/nn/model/gatr/layers/mlp/mlp.py deleted file mode 100644 index 05f1421c..00000000 --- a/weaver/nn/model/gatr/layers/mlp/mlp.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Factory functions for simple MLPs for multivector data.""" - -from typing import List, Tuple, Union - -import torch -from torch import nn - -from ..dropout import GradeDropout -from ..linear import EquiLinear -from .config import MLPConfig -from .geometric_bilinears import GeometricBilinear -from .nonlinearities import ScalarGatedNonlinearity - -USE_GEOMETRIC_PRODUCT = True - - -class GeoMLP(nn.Module): - """Geometric MLP. - - This is a core component of GATr's transformer blocks. It is similar to a regular MLP, except - that it uses geometric bilinears (the geometric product) in place of the first linear layer. - - Assumes input has shape `(..., channels[0], 16)`, output has shape `(..., channels[-1], 16)`, - will create hidden layers with shape `(..., channel, 16)` for each additional entry in - `channels`. - - Parameters - ---------- - config: MLPConfig - Configuration object - """ - - def __init__( - self, - config: MLPConfig, - ) -> None: - super().__init__() - - # Store settings - self.config = config - - assert config.mv_channels is not None - s_channels = ( - [None for _ in config.mv_channels] - if config.s_channels is None - else config.s_channels - ) - - layers: List[nn.Module] = [] - - if len(config.mv_channels) >= 2: - kwargs = dict( - in_mv_channels=config.mv_channels[0], - out_mv_channels=config.mv_channels[1], - in_s_channels=s_channels[0], - out_s_channels=s_channels[1], - ) - if USE_GEOMETRIC_PRODUCT: - layers.append(GeometricBilinear(**kwargs)) - else: - layers.append(ScalarGatedNonlinearity(config.activation)) - layers.append(EquiLinear(**kwargs)) - if config.dropout_prob is not None: - layers.append(GradeDropout(config.dropout_prob)) - - for in_, out, in_s, out_s in zip( - config.mv_channels[1:-1], - config.mv_channels[2:], - s_channels[1:-1], - s_channels[2:], - ): - layers.append(ScalarGatedNonlinearity(config.activation)) - layers.append( - EquiLinear(in_, out, in_s_channels=in_s, out_s_channels=out_s) - ) - if config.dropout_prob is not None: - layers.append(GradeDropout(config.dropout_prob)) - - self.layers = nn.ModuleList(layers) - - def forward( - self, multivectors: torch.Tensor, scalars: torch.Tensor - ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: - """Forward pass. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., in_mv_channels, 16) - Input multivectors. - scalars : None or torch.Tensor with shape (..., in_s_channels) - Optional input scalars. 
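A hypothetical instantiation of the MLP described above, assuming the modules are importable; the channel lists follow the (input, hidden, output) convention from the `GeoMLP` docstring.

```
import torch
from weaver.nn.model.gatr.layers.mlp.config import MLPConfig
from weaver.nn.model.gatr.layers.mlp.mlp import GeoMLP

config = MLPConfig(mv_channels=[8, 16, 8], s_channels=[4, 8, 4], activation="gelu")
mlp = GeoMLP(config)
mv = torch.randn(2, 10, 8, 16)  # (..., mv_channels[0], 16)
s = torch.randn(2, 10, 4)       # (..., s_channels[0])
out_mv, out_s = mlp(mv, scalars=s)
assert out_mv.shape == (2, 10, 8, 16)
assert out_s.shape == (2, 10, 4)
```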
- - Returns - ------- - outputs_mv : torch.Tensor with shape (..., out_mv_channels, 16) - Output multivectors. - outputs_s : None or torch.Tensor with shape (..., out_s_channels) - Output scalars, if scalars are provided. Otherwise None. - """ - - mv, s = multivectors, scalars - - for i, layer in enumerate(self.layers): - mv, s = layer(mv, scalars=s) - - return mv, s diff --git a/weaver/nn/model/gatr/layers/mlp/nonlinearities.py b/weaver/nn/model/gatr/layers/mlp/nonlinearities.py deleted file mode 100644 index 789015a3..00000000 --- a/weaver/nn/model/gatr/layers/mlp/nonlinearities.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Tuple - -import torch -from torch import nn - -from ...primitives.nonlinearities import gated_gelu, gated_relu, gated_sigmoid - - -class ScalarGatedNonlinearity(nn.Module): - """Gated nonlinearity, where the gate is simply given by the scalar component of the input. - - Given multivector input x, computes f(x_0) * x, where f can either be ReLU, sigmoid, or GeLU. - - Auxiliary scalar inputs are simply processed with ReLU, sigmoid, or GeLU, without gating. - - Parameters - ---------- - nonlinearity : {"relu", "sigmoid", "gelu"} - Non-linearity type - """ - - def __init__(self, nonlinearity: str = "relu", **kwargs) -> None: - super().__init__() - - gated_fn_dict = dict(relu=gated_relu, gelu=gated_gelu, sigmoid=gated_sigmoid) - scalar_fn_dict = dict( - relu=nn.functional.relu, - gelu=nn.functional.gelu, - sigmoid=nn.functional.sigmoid, - ) - try: - self.gated_nonlinearity = gated_fn_dict[nonlinearity] - self.scalar_nonlinearity = scalar_fn_dict[nonlinearity] - except KeyError as exc: - raise ValueError( - f"Unknown nonlinearity {nonlinearity} for options {list(gated_fn_dict.keys())}" - ) from exc - - def forward( - self, multivectors: torch.Tensor, scalars: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes f(x_0) * x for multivector x, where f is GELU, ReLU, or sigmoid. - - f is chosen depending on self.nonlinearity. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., self.in_channels, 16) - Input multivectors - scalars : None or torch.Tensor with shape (..., self.in_channels, self.in_scalars) - Input scalars - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., self.out_channels, 16) - Output multivectors - output_scalars : torch.Tensor with shape (..., self.out_channels, self.in_scalars) - Output scalars - """ - - gates = multivectors[..., [0]] - outputs_mv = self.gated_nonlinearity(multivectors, gates=gates) - outputs_s = self.scalar_nonlinearity(scalars) - - return outputs_mv, outputs_s diff --git a/weaver/nn/model/gatr/nets/__init__.py b/weaver/nn/model/gatr/nets/__init__.py deleted file mode 100644 index 2c9e6e9e..00000000 --- a/weaver/nn/model/gatr/nets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .gatr import GATr diff --git a/weaver/nn/model/gatr/nets/gatr.py b/weaver/nn/model/gatr/nets/gatr.py deleted file mode 100644 index 25f1ee87..00000000 --- a/weaver/nn/model/gatr/nets/gatr.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Equivariant transformer for multivector data.""" - -from dataclasses import replace -from typing import Optional, Tuple, Union - -import torch -from torch import nn -from torch.utils.checkpoint import checkpoint - -from ..layers.attention.config import SelfAttentionConfig -from ..layers.gatr_block import GATrBlock -from ..layers.linear import EquiLinear -from ..layers.mlp.config import MLPConfig - - -class GATr(nn.Module): - """L-GATr network for a data with a single token dimension. 
- - It combines `num_blocks` L-GATr transformer blocks, each consisting of geometric self-attention - layers, a geometric MLP, residual connections, and normalization layers. In addition, there - are initial and final equivariant linear layers. - - Assumes input has shape `(..., items, in_channels, 16)`, output has shape - `(..., items, out_channels, 16)`, will create hidden representations with shape - `(..., items, hidden_channels, 16)`. - - Parameters - ---------- - in_mv_channels : int - Number of input multivector channels. - out_mv_channels : int - Number of output multivector channels. - hidden_mv_channels : int - Number of hidden multivector channels. - in_s_channels : None or int - If not None, sets the number of scalar input channels. - out_s_channels : None or int - If not None, sets the number of scalar output channels. - hidden_s_channels : None or int - If not None, sets the number of scalar hidden channels. - attention: Dict - Data for SelfAttentionConfig - mlp: Dict - Data for MLPConfig - num_blocks : int - Number of transformer blocks. - dropout_prob : float or None - Dropout probability - double_layernorm : bool - Whether to use double layer normalization - """ - - def __init__( - self, - in_mv_channels: int, - out_mv_channels: int, - hidden_mv_channels: int, - in_s_channels: Optional[int], - out_s_channels: Optional[int], - hidden_s_channels: Optional[int], - attention: SelfAttentionConfig, - mlp: MLPConfig, - num_blocks: int = 10, - reinsert_mv_channels: Optional[Tuple[int]] = None, - reinsert_s_channels: Optional[Tuple[int]] = None, - checkpoint_blocks: bool = False, - dropout_prob: Optional[float] = None, - double_layernorm: bool = False, - **kwargs, - ) -> None: - super().__init__() - self.linear_in = EquiLinear( - in_mv_channels, - hidden_mv_channels, - in_s_channels=in_s_channels, - out_s_channels=hidden_s_channels, - ) - attention = replace( - SelfAttentionConfig.cast(attention), - additional_qk_mv_channels=0 - if reinsert_mv_channels is None - else len(reinsert_mv_channels), - additional_qk_s_channels=0 - if reinsert_s_channels is None - else len(reinsert_s_channels), - ) - mlp = MLPConfig.cast(mlp) - self.blocks = nn.ModuleList( - [ - GATrBlock( - mv_channels=hidden_mv_channels, - s_channels=hidden_s_channels, - attention=attention, - mlp=mlp, - dropout_prob=dropout_prob, - double_layernorm=double_layernorm, - ) - for _ in range(num_blocks) - ] - ) - self.linear_out = EquiLinear( - hidden_mv_channels, - out_mv_channels, - in_s_channels=hidden_s_channels, - out_s_channels=out_s_channels, - ) - self._reinsert_s_channels = reinsert_s_channels - self._reinsert_mv_channels = reinsert_mv_channels - self._checkpoint_blocks = checkpoint_blocks - - def forward( - self, - multivectors: torch.Tensor, - scalars: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: - """Forward pass of the network. - - Parameters - ---------- - multivectors : torch.Tensor with shape (..., in_mv_channels, 16) - Input multivectors. - scalars : None or torch.Tensor with shape (..., in_s_channels) - Optional input scalars. - attention_mask: None or torch.Tensor with shape (..., num_items, num_items) - Optional attention mask - - Returns - ------- - outputs_mv : torch.Tensor with shape (..., out_mv_channels, 16) - Output multivectors. - outputs_s : None or torch.Tensor with shape (..., out_s_channels) - Output scalars, if scalars are provided. Otherwise None. 
- """ - - # Channels that will be re-inserted in any query / key computation - ( - additional_qk_features_mv, - additional_qk_features_s, - ) = self._construct_reinserted_channels(multivectors, scalars) - - # Pass through the blocks - h_mv, h_s = self.linear_in(multivectors, scalars=scalars) - for block in self.blocks: - if self._checkpoint_blocks: - h_mv, h_s = checkpoint( - block, - h_mv, - use_reentrant=False, - scalars=h_s, - additional_qk_features_mv=additional_qk_features_mv, - additional_qk_features_s=additional_qk_features_s, - attention_mask=attention_mask, - ) - else: - h_mv, h_s = block( - h_mv, - scalars=h_s, - additional_qk_features_mv=additional_qk_features_mv, - additional_qk_features_s=additional_qk_features_s, - attention_mask=attention_mask, - ) - - outputs_mv, outputs_s = self.linear_out(h_mv, scalars=h_s) - - return outputs_mv, outputs_s - - def _construct_reinserted_channels(self, multivectors, scalars): - """Constructs input features that will be reinserted in every attention layer.""" - - if self._reinsert_mv_channels is None: - additional_qk_features_mv = None - else: - additional_qk_features_mv = multivectors[..., self._reinsert_mv_channels, :] - - if self._reinsert_s_channels is None: - additional_qk_features_s = None - else: - assert scalars is not None - additional_qk_features_s = scalars[..., self._reinsert_s_channels] - - return additional_qk_features_mv, additional_qk_features_s diff --git a/weaver/nn/model/gatr/primitives/__init__.py b/weaver/nn/model/gatr/primitives/__init__.py deleted file mode 100644 index e1a3d4a5..00000000 --- a/weaver/nn/model/gatr/primitives/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .attention import sdp_attention -from .bilinear import geometric_product -from .dropout import grade_dropout -from .invariants import ( - inner_product, - squared_norm, - abs_squared_norm, - pin_invariants, -) -from .linear import ( - equi_linear, - grade_involute, - grade_project, - reverse, -) -from .nonlinearities import gated_gelu, gated_relu, gated_sigmoid -from .normalization import equi_layer_norm diff --git a/weaver/nn/model/gatr/primitives/attention.py b/weaver/nn/model/gatr/primitives/attention.py deleted file mode 100644 index d521c28c..00000000 --- a/weaver/nn/model/gatr/primitives/attention.py +++ /dev/null @@ -1,112 +0,0 @@ -from typing import Optional, Tuple - -import torch -from einops import rearrange -from torch import Tensor -from torch.nn.functional import scaled_dot_product_attention as torch_sdpa - -from .invariants import _load_inner_product_factors - - -def sdp_attention( - q_mv: Tensor, - k_mv: Tensor, - v_mv: Tensor, - q_s: Tensor, - k_s: Tensor, - v_s: Tensor, - attn_mask: Optional[Tensor] = None, -) -> Tuple[Tensor, Tensor]: - """Equivariant geometric attention based on scaled dot products. - - Expects both multivector and scalar queries, keys, and values as inputs. - Then this function computes multivector and scalar outputs in the following way: - - ``` - attn_weights[..., i, j] = softmax_j[ - ga_inner_product(q_mv[..., i, :, :], k_mv[..., j, :, :]) - + euclidean_inner_product(q_s[..., i, :], k_s[..., j, :]) - ] - out_mv[..., i, c, :] = sum_j attn_weights[..., i, j] v_mv[..., j, c, :] / norm - out_s[..., i, c] = sum_j attn_weights[..., i, j] v_s[..., j, c] / norm - ``` - - Parameters - ---------- - q_mv : Tensor with shape (..., num_items_out, num_mv_channels_in, 16) - Queries, multivector part. - k_mv : Tensor with shape (..., num_items_in, num_mv_channels_in, 16) - Keys, multivector part. 
- v_mv : Tensor with shape (..., num_items_in, num_mv_channels_out, 16) - Values, multivector part. - q_s : Tensor with shape (..., num_items_out, num_s_channels_in) - Queries, scalar part. - k_s : Tensor with shape (..., num_items_in, num_s_channels_in) - Keys, scalar part. - v_s : Tensor with shape (..., num_items_in, num_s_channels_out) - Values, scalar part. - attn_mask : None or Tensor with shape (..., num_items, num_items) - Optional attention mask - - Returns - ------- - outputs_mv : Tensor with shape (..., num_items_out, num_mv_channels_out, 16) - Result, multivector part - outputs_s : Tensor with shape (..., num_items_out, num_s_channels_out) - Result, scalar part - """ - - # Construct queries and keys by concatenating relevant MV components and aux scalars - q = torch.cat( - [ - rearrange( - q_mv - * _load_inner_product_factors(device=q_mv.device, dtype=q_mv.dtype), - "... c x -> ... (c x)", - ), - q_s, - ], - -1, - ) - k = torch.cat([rearrange(k_mv, "... c x -> ... (c x)"), k_s], -1) - - num_channels_out = v_mv.shape[-2] - v = torch.cat([rearrange(v_mv, "... c x -> ... (c x)"), v_s], -1) - - v_out = scaled_dot_product_attention(q, k, v, attn_mask) - - v_out_mv = rearrange( - v_out[..., : num_channels_out * 16], "... (c x) -> ... c x", x=16 - ) - v_out_s = v_out[..., num_channels_out * 16 :] - - return v_out_mv, v_out_s - - -def scaled_dot_product_attention( - query: Tensor, - key: Tensor, - value: Tensor, - attn_mask: Optional[Tensor] = None, - is_causal=False, -) -> Tensor: - """Execute (vanilla) scaled dot-product attention. - - Parameters - ---------- - query : Tensor - of shape [batch, head, item, d] - key : Tensor - of shape [batch, head, item, d] - value : Tensor - of shape [batch, head, item, d] - attn_mask : Optional[Tensor] - Attention mask - is_causal: bool - - Returns - ------- - Tensor - of shape [batch, head, item, d] - """ - return torch_sdpa(query, key, value, attn_mask=attn_mask, is_causal=is_causal) diff --git a/weaver/nn/model/gatr/primitives/bilinear.py b/weaver/nn/model/gatr/primitives/bilinear.py deleted file mode 100644 index df7e6825..00000000 --- a/weaver/nn/model/gatr/primitives/bilinear.py +++ /dev/null @@ -1,62 +0,0 @@ -from functools import lru_cache -from pathlib import Path - -import torch - -from ..utils.einsum import cached_einsum - - -@lru_cache() -def _load_geometric_product_tensor( - device=torch.device("cpu"), dtype=torch.float32 -) -> torch.Tensor: - """Loads geometric product tensor for geometric product between multivectors. - - This function is cached. - - Parameters - ---------- - device : torch.Device or str - Device - dtype : torch.Dtype - Data type - - Returns - ------- - basis : torch.Tensor with shape (16, 16, 16) - Geometric product tensor - """ - - # To avoid duplicate loading, base everything on float32 CPU version - if device not in [torch.device("cpu"), "cpu"] and dtype != torch.float32: - gmt = _load_geometric_product_tensor() - else: - filename = Path(__file__).parent.resolve() / "geometric_product.pt" - gmt = torch.load(filename).to(torch.float32).to_dense() - - return gmt.to(device=device, dtype=dtype) - - -def geometric_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Computes the geometric product f(x,y) = xy. - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - First input multivector. Batch dimensions must be broadcastable between x and y. - y : torch.Tensor with shape (..., 16) - Second input multivector. Batch dimensions must be broadcastable between x and y. 
- - Returns - ------- - outputs : torch.Tensor with shape (..., 16) - Result. Batch dimensions are result of broadcasting between x, y, and coeffs. - """ - - # Select kernel on correct device - gp = _load_geometric_product_tensor(device=x.device, dtype=x.dtype) - - # Compute geometric product - outputs = cached_einsum("i j k, ... j, ... k -> ... i", gp, x, y) - - return outputs diff --git a/weaver/nn/model/gatr/primitives/dropout.py b/weaver/nn/model/gatr/primitives/dropout.py deleted file mode 100644 index 3af07a0d..00000000 --- a/weaver/nn/model/gatr/primitives/dropout.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch - -from .linear import grade_project - - -def grade_dropout(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor: - """Multivector dropout, dropping out grades independently. - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Input data. - p : float - Dropout probability (assumed the same for each grade). - training : bool - Switches between train-time and test-time behaviour. - - Returns - ------- - outputs : torch.Tensor with shape (..., 16) - Inputs with dropout applied. - """ - - # Project to grades - x = grade_project(x) - - # Apply standard 1D dropout - # For whatever reason, that only works with a single batch dimension, so let's reshape a bit - h = x.view(-1, 5, 16) - h = torch.nn.functional.dropout1d(h, p=p, training=training, inplace=False) - h = h.view(x.shape) - - # Combine grades again - h = torch.sum(h, dim=-2) - - return h diff --git a/weaver/nn/model/gatr/primitives/geometric_product.pt b/weaver/nn/model/gatr/primitives/geometric_product.pt deleted file mode 100644 index e74a618f2980a2e90d3f9675b1b661318aa843bd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 17614 zcmeI4%Wfk@6o$*WIL2&7LWs!>LM&Jq2{N6`ICd5YutwfkBTTp`Vq2_x% zSY?AoJOV;Oya}%WZ-B%ev1Cb=oeI^*b-HfHmLp9^p1Pg-{{Nh^lS~%X?``ieR;#h4 z=N_xE^LYAgoM)5K;3AujFGu-<@h~4ATzr#sJ3H(v{#)G2JaTv*rpj3!egRz`)r%Hql?RQ z=Lx%$r@No_SnrPbdXZM2unO;0UZi_Jjz-fYiAVWl`fOJJBpb(BJbpYG<@L1o?rHi; zp4Ly(R}cEVz1sNC=7(!g zq5RB+utN405>f^f%ylz1=6esp=R`S*LiaZ2cpE(zie54Z$qjx!;J8HAxa7iei8Tsq z0x0+t2p_hG<1p)Z>;1*1#{^5M8_FSbU(z}V`1h|W6B z(@TTs=e=4dwU2oUsin%d?YroiVzHcTz1YUN=(xn%>g(c6I;PsGOcP6urRZaNE}ngS z`0z2WCViiMK6^|))O89zWcbun^am1X7Wm?C$b2h=>2+H_y%~V&@I^?P2kUFRwQ(x$;Q_JKBTQ}8mT{%d9A$rDK z2;avcK4q?=rQ+DID7DlaU?1w3&IRtb8Lxt=>SBDDo4UuhzfLXH<2rI#g@&4g?3;8< z>yl&BV&`SER~JL-jo|MDSbiZsF{Z8_6mur8A%0A37yltFQw||)mmFklh#rRSrmsiczE$`O zVd^}>xe&gorXg(TQ;vO@CZCc`)%tjZ+Aa*2*d-n0=RAzljYD*GJTVt?Rv~_DbXUzs zYvTDkBF3B7Ie=+XD@@y^hAy$FUcmbsW7^Dzcn{heknc-fo0DC{+{#+rg3mlWfpvYa zfx6_l_(!1Fkp@f#Gnou>2E1M|xW9@d3}KYls?Vv8FP zFWfX*N yyWaZN_WL)G-Rqugb8!&}fBxlc@$h}(`Vf{p7XLCPmixV3anT0b torch.Tensor: - """Constructs an array of 1's and -1's for the metric of the space, - used to compute the inner product. - - Parameters - ---------- - device : torch.device - Device - dtype : torch.dtype - Dtype - - Returns - ------- - ip_factors : torch.Tensor with shape (16,) - Inner product factors - """ - - _INNER_PRODUCT_FACTORS = [1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1] - factors = torch.tensor( - _INNER_PRODUCT_FACTORS, dtype=torch.float32, device=torch.device("cpu") - ).to_dense() - return factors.to(device=device, dtype=dtype) - - -@lru_cache() -def _load_metric_grades( - device=torch.device("cpu"), dtype=torch.float32 -) -> torch.Tensor: - """Generate tensor of the diagonal of the GA metric, combined with a grade projection. 
- - Parameters - ---------- - device : torch.device - Device - dtype : torch.dtype - Dtype - - Returns - ------- - torch.Tensor of shape [5, 16] - """ - m = _load_inner_product_factors(device=torch.device("cpu"), dtype=torch.float32) - m_grades = torch.zeros(5, 16, device=torch.device("cpu"), dtype=torch.float32) - offset = 0 - for k in range(4 + 1): - d = math.comb(4, k) - m_grades[k, offset : offset + d] = m[offset : offset + d] - offset += d - return m_grades.to(device=device, dtype=dtype) - - -def inner_product( - x: torch.Tensor, y: torch.Tensor, channel_sum: bool = False -) -> torch.Tensor: - """Computes the inner product of multivectors f(x,y) = = <~x y>_0. - - In addition to summing over the 16 multivector dimensions, this function also sums - over an additional channel dimension if channel_sum == True. - - Equal to `geometric_product(reverse(x), y)[..., [0]]` (but faster). - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) or (..., channels, 16) - First input multivector. Batch dimensions must be broadcastable between x and y. - y : torch.Tensor with shape (..., 16) or (..., channels, 16) - Second input multivector. Batch dimensions must be broadcastable between x and y. - channel_sum: bool - Whether to sum over the second-to-last axis (channels) - - Returns - ------- - outputs : torch.Tensor with shape (..., 1) - Result. Batch dimensions are result of broadcasting between x and y. - """ - - x = x * _load_inner_product_factors(device=x.device, dtype=x.dtype) - - if channel_sum: - outputs = cached_einsum("... c i, ... c i -> ...", x, y) - else: - outputs = cached_einsum("... i, ... i -> ...", x, y) - - # We want the output to have shape (..., 1) - outputs = outputs.unsqueeze(-1) - - return outputs - - -def squared_norm(x: torch.Tensor) -> torch.Tensor: - """Computes the squared GA norm of an input multivector. - - Equal to inner_product(x, x). - - NOTE: this primitive is not used widely in our architectures. - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Input multivector. - - Returns - ------- - outputs : torch.Tensor with shape (..., 1) - Geometric algebra norm of x. - """ - - return inner_product(x, x) - - -def pin_invariants(x: torch.Tensor, epsilon: float = 0.01) -> torch.Tensor: - """Computes five invariants from multivectors: scalar component, norms of the four other grades. - - NOTE: this primitive is not used widely in our architectures. - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Input multivector. - epsilon : float - Epsilon parameter that regularizes the norm in case it is lower or equal to zero to avoid infinite gradients. - - - Returns - ------- - outputs : torch.Tensor with shape (..., 5) - Invariants computed from input multivectors - """ - - # Project to grades - projections = grade_project(x) # (..., 5, 16) - - # Compute norms - squared_norms = inner_product(projections, projections)[..., 0] # (..., 5) - norms = torch.sqrt(torch.clamp(squared_norms, epsilon)) - - # Outputs: scalar component of input and norms of four other grades - return torch.cat((x[..., [0]], norms[..., 1:]), dim=-1) # (..., 5) - - -@minimum_autocast_precision(torch.float32) -def abs_squared_norm(x: torch.Tensor) -> torch.Tensor: - """Computes a modified version of the squared norm that is positive semidefinite and can - therefore be used in layer normalization. - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Input multivector. 
- - Returns - ------- - outputs : torch.Tensor with shape (..., 1) - Geometric algebra norm of x. - """ - m = _load_metric_grades(device=x.device, dtype=x.dtype) - abs_squared_norms = ( - cached_einsum("... i, ... i, g i -> ... g", x, x, m).abs().sum(-1, keepdim=True) - ) - return abs_squared_norms diff --git a/weaver/nn/model/gatr/primitives/linear.py b/weaver/nn/model/gatr/primitives/linear.py deleted file mode 100644 index 234be38b..00000000 --- a/weaver/nn/model/gatr/primitives/linear.py +++ /dev/null @@ -1,187 +0,0 @@ -from functools import lru_cache -from pathlib import Path - -import torch - -from ..utils.einsum import cached_einsum, custom_einsum - -# switch to decide whether to use the full Lorentz group ('False') -# or the special orthochronous Lorentz group ('True') -# They only differ in the construction of linear maps in _compute_pin_equi_linear_basis -USE_FULLY_CONNECTED_SUBGROUP = True - - -@lru_cache() -def _compute_pin_equi_linear_basis( - device=torch.device("cpu"), dtype=torch.float32, normalize=True -) -> torch.Tensor: - """Constructs basis elements for Pin(1,3)-equivariant linear maps between multivectors. - - This function is cached. - - Parameters - ---------- - device : torch.device - Device - dtype : torch.dtype - Dtype - normalize : bool - Whether to normalize the basis elements - - Returns - ------- - basis : torch.Tensor with shape (NUM_PIN_LINEAR_BASIS_ELEMENTS, 16, 16) - Basis elements for equivariant linear maps. - """ - - if device not in [torch.device("cpu"), "cpu"] and dtype != torch.float32: - basis = _compute_pin_equi_linear_basis(normalize=normalize) - else: - file = ( - "linear_basis_subgroup.pt" - if USE_FULLY_CONNECTED_SUBGROUP - else "linear_basis_full.pt" - ) - filename = Path(__file__).parent.resolve() / file - basis = torch.load(filename).to(torch.float32).to_dense() - return basis.to(device=device, dtype=dtype) - - -@lru_cache() -def _compute_reversal(device=torch.device("cpu"), dtype=torch.float32) -> torch.Tensor: - """Constructs a matrix that computes multivector reversal. - - Parameters - ---------- - device : torch.device - Device - dtype : torch.dtype - Dtype - - Returns - ------- - reversal_diag : torch.Tensor with shape (16,) - The diagonal of the reversal matrix, consisting of +1 and -1 entries. - """ - reversal_flat = torch.ones(16, device=device, dtype=dtype) - reversal_flat[5:15] = -1 - return reversal_flat - - -@lru_cache() -def _compute_grade_involution( - device=torch.device("cpu"), dtype=torch.float32 -) -> torch.Tensor: - """Constructs a matrix that computes multivector grade involution. - - Parameters - ---------- - device : torch.device - Device - dtype : torch.dtype - Dtype - - Returns - ------- - involution_diag : torch.Tensor with shape (16,) - The diagonal of the involution matrix, consisting of +1 and -1 entries. - """ - involution_flat = torch.ones(16, device=device, dtype=dtype) - involution_flat[1:5] = -1 - involution_flat[11:15] = -1 - return involution_flat - - -def equi_linear(x: torch.Tensor, coeffs: torch.Tensor) -> torch.Tensor: - """Pin-equivariant linear map f(x) = sum_{a,j} coeffs_a W^a_ij x_j. - - The W^a are seven pre-defined basis elements. - - Parameters - ---------- - x : torch.Tensor with shape (..., in_channels, 16) - Input multivector. Batch dimensions must be broadcastable between x and coeffs. - coeffs : torch.Tensor with shape (out_channels, in_channels, 10) - Coefficients for the basis elements. Batch dimensions must be broadcastable between x and - coeffs. 
- - Returns - ------- - outputs : torch.Tensor with shape (..., 16) - Result. Batch dimensions are result of broadcasting between x and coeffs. - """ - basis = _compute_pin_equi_linear_basis(device=x.device, dtype=x.dtype) - return custom_einsum( - "y x a, a i j, ... x j -> ... y i", coeffs, basis, x, path=[0, 1, 0, 1] - ) - - -def grade_project(x: torch.Tensor) -> torch.Tensor: - """Projects an input tensor to the individual grades. - - The return value is a single tensor with a new grade dimension. - - NOTE: this primitive is not used widely in our architectures. - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Input multivector. - - Returns - ------- - outputs : torch.Tensor with shape (..., 5, 16) - Output multivector. The second-to-last dimension indexes the grades. - """ - - # Select kernel on correct device - basis = _compute_pin_equi_linear_basis( - device=x.device, dtype=x.dtype, normalize=False - ) - - # First five basis elements are grade projections - basis = basis[:5] - - # Project to grades - projections = cached_einsum("g i j, ... j -> ... g i", basis, x) - - return projections - - -def reverse(x: torch.Tensor) -> torch.Tensor: - """Computes the reversal of a multivector. - - The reversal has the same scalar, vector, and pseudoscalar components, but flips sign in the - bivector and trivector components. - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Input multivector. - - Returns - ------- - outputs : torch.Tensor with shape (..., 16) - Output multivector. - """ - return _compute_reversal(device=x.device, dtype=x.dtype) * x - - -def grade_involute(x: torch.Tensor) -> torch.Tensor: - """Computes the grade involution of a multivector. - - The reversal has the same scalar, bivector, and pseudoscalar components, but flips sign in the - vector and trivector components. - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Input multivector. - - Returns - ------- - outputs : torch.Tensor with shape (..., 16) - Output multivector. - """ - - return _compute_grade_involution(device=x.device, dtype=x.dtype) * x diff --git a/weaver/nn/model/gatr/primitives/linear_basis_full.pt b/weaver/nn/model/gatr/primitives/linear_basis_full.pt deleted file mode 100644 index 8dd7f63fde93d7c0165591f96a6c9e514805ba38..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6335 zcmeHM&2G~`5MC#B65LA@DbsafzrJ_m}vJk5Z69kGt$a3sev9RKa@Q@6#KMYJIxic5%I-;O|er0Va31u-|QTTkLoE zk&%QWJ9cx{aNhT9h$F^g!n1sO{OeW?GH@@>vWa?P7S8If?%17z6th%UjtET-A)z$M zfMh^2up$H3v4@liV?`Qem}EdQPyz1zU0o;o# z*g+4(3fh7kWFAY+tyVgge(*UYsc z${qrV+(|=}d+Rv9=RpU%em9gK2OSu|&Lh{sGRha5XcZ&5iLR9zW-kB!$EVUqf=Nds zvED6ki>cDcsk22 zqm~(I6mEh6!2kygh%lo=IU6!tmoPf3tC(AS&R_@ZvKf@l#E&5iT-$R`Z7Suo z+>TK%{(nMcwN}+sO;c+XP1W_PQqz<&HN9H9rznc9tCfANqR}kheZGMrnn;!sO?IMu zoz{qM(ie=KtAmY0ObbA_824Y~iGi3%OXC>Rtz2g?o)L8{gE110#sQ`~vB3a*E5-n# z{>KrfS=wO`ns;Id=l8E+X0vxA$J^=M2}CBk8(&{gy)Uuf1eO}{qbgx=6gkt^4KqXa Mv>-!%bpB}WAA#;8hyVZp diff --git a/weaver/nn/model/gatr/primitives/nonlinearities.py b/weaver/nn/model/gatr/primitives/nonlinearities.py deleted file mode 100644 index ad285301..00000000 --- a/weaver/nn/model/gatr/primitives/nonlinearities.py +++ /dev/null @@ -1,79 +0,0 @@ -import math - -import torch - - -def gated_relu(x: torch.Tensor, gates: torch.Tensor) -> torch.Tensor: - """Pin-equivariant gated ReLU nonlinearity. - - Given multivector input x and scalar input gates (with matching batch dimensions), computes - ReLU(gates) * x. 
- - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Multivector input - gates : torch.Tensor with shape (..., 1) - Pin-invariant gates. - - Returns - ------- - outputs : torch.Tensor with shape (..., 16) - Computes ReLU(gates) * x, with broadcasting along the last dimension. - """ - - weights = torch.nn.functional.relu(gates) - outputs = weights * x - return outputs - - -def gated_sigmoid(x: torch.Tensor, gates: torch.Tensor): - """Pin-equivariant gated sigmoid nonlinearity. - - Given multivector input x and scalar input gates (with matching batch dimensions), computes - sigmoid(gates) * x. - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Multivector input - gates : torch.Tensor with shape (..., 1) - Pin-invariant gates. - - Returns - ------- - outputs : torch.Tensor with shape (..., 16) - Computes sigmoid(gates) * x, with broadcasting along the last dimension. - """ - - weights = torch.nn.functional.sigmoid(gates) - outputs = weights * x - return outputs - - -def gated_gelu(x: torch.Tensor, gates: torch.Tensor) -> torch.Tensor: - """Pin-equivariant gated GeLU nonlinearity without division. - - Given multivector input x and scalar input gates (with matching batch dimensions), computes - GeLU(gates) * x. - - References - ---------- - Dan Hendrycks, Kevin Gimpel, "Gaussian Error Linear Units (GELUs)", arXiv:1606.08415 - - Parameters - ---------- - x : torch.Tensor with shape (..., 16) - Multivector input - gates : torch.Tensor with shape (..., 1) - Pin-invariant gates. - - Returns - ------- - outputs : torch.Tensor with shape (..., 16) - Computes GeLU(gates) * x, with broadcasting along the last dimension. - """ - - weights = torch.nn.functional.gelu(gates, approximate="tanh") - outputs = weights * x - return outputs diff --git a/weaver/nn/model/gatr/primitives/normalization.py b/weaver/nn/model/gatr/primitives/normalization.py deleted file mode 100644 index 092217b0..00000000 --- a/weaver/nn/model/gatr/primitives/normalization.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from .invariants import abs_squared_norm - - -def equi_layer_norm( - x: torch.Tensor, channel_dim: int = -2, gain: float = 1.0, epsilon: float = 0.01 -) -> torch.Tensor: - """Equivariant LayerNorm for multivectors. - - Rescales input such that `mean_channels |inputs|^2 = 1`, where the norm is the GA norm and the - mean goes over the channel dimensions. - - Using a factor `gain > 1` makes up for the fact that the GP norm overestimates the actual - standard deviation of the input data. - - Parameters - ---------- - x : torch.Tensor with shape `(batch_dim, *channel_dims, 16)` - Input multivectors. - channel_dim : int - Channel dimension index. Defaults to the second-last entry (last are the multivector - components). - gain : float - Target output scale. - epsilon : float - Small numerical factor to avoid instabilities. By default, we use a reasonably large number - to balance issues that arise from some multivector components not contributing to the norm. - - Returns - ------- - outputs : torch.Tensor with shape `(batch_dim, *channel_dims, 16)` - Normalized inputs. - """ - - # Compute mean_channels |inputs|^2 - abs_squared_norms = abs_squared_norm(x) - abs_squared_norms = torch.mean(abs_squared_norms, dim=channel_dim, keepdim=True) - - # Insure against low-norm tensors (which can arise even when `x.var(dim=-1)` is high b/c some - # entries don't contribute to the inner product / GP norm!) 
- abs_squared_norms = torch.clamp(abs_squared_norms, epsilon) - - # Rescale inputs - outputs = gain * x / torch.sqrt(abs_squared_norms) - - return outputs diff --git a/weaver/nn/model/gatr/utils/__init__.py b/weaver/nn/model/gatr/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/weaver/nn/model/gatr/utils/einsum.py b/weaver/nn/model/gatr/utils/einsum.py deleted file mode 100644 index 1e256035..00000000 --- a/weaver/nn/model/gatr/utils/einsum.py +++ /dev/null @@ -1,44 +0,0 @@ -"""This module provides efficiency improvements over torch's einsum through caching.""" - -import functools -from typing import List, Sequence - -import opt_einsum -import torch - - -def custom_einsum( - equation: str, *operands: torch.Tensor, path: List[int] -) -> torch.Tensor: - """Computes einsum with a custom contraction order.""" - - # Justification: For the sake of performance, we need direct access to torch's private methods. - - # pylint:disable-next=protected-access - return torch._VF.einsum(equation, operands, path=path) # type: ignore[attr-defined] - - -def cached_einsum(equation: str, *operands: torch.Tensor) -> torch.Tensor: - """Computes einsum with a cached optimal contraction. - - Inspired by upstream - https://github.com/pytorch/pytorch/blob/v1.13.0/torch/functional.py#L381. - """ - op_shape = tuple(op.shape for op in operands) - path = _get_cached_path_for_equation_and_shapes( - equation=equation, op_shape=op_shape - ) - - return custom_einsum(equation, *operands, path=path) - - -@functools.lru_cache(maxsize=None) -def _get_cached_path_for_equation_and_shapes( - equation: str, op_shape: Sequence[torch.Tensor] -) -> List[int]: - """Provides caching of optimal path.""" - tupled_path = opt_einsum.contract_path( - equation, *op_shape, optimize="optimal", shapes=True - )[0] - - return [item for pair in tupled_path for item in pair] diff --git a/weaver/nn/model/gatr/utils/misc.py b/weaver/nn/model/gatr/utils/misc.py deleted file mode 100644 index 98be6e5b..00000000 --- a/weaver/nn/model/gatr/utils/misc.py +++ /dev/null @@ -1,115 +0,0 @@ -from functools import wraps -from itertools import chain -from typing import Any, Callable, List, Literal, Optional, Union - -import torch -from torch import Tensor - - -def minimum_autocast_precision( - min_dtype: torch.dtype = torch.float32, - output: Optional[Union[Literal["low", "high"], torch.dtype]] = None, - which_args: Optional[List[int]] = None, - which_kwargs: Optional[List[str]] = None, -): - """Decorator that ensures input tensors are autocast to a minimum precision. - Only has an effect in autocast-enabled regions. Otherwise, does not change the function. - Only floating-point inputs are modified. Non-tensors, integer tensors, and boolean tensors are - untouched. - Note: AMP is turned on and off separately for CPU and CUDA. This decorator may fail in - the case where both devices are used, with only one of them on AMP. - Parameters - ---------- - min_dtype : dtype - Minimum dtype. Default: float32. - output: None or "low" or "high" or dtype - Specifies which dtypes the outputs should be cast to. Only floating-point Tensor outputs - are affected. If None, the outputs are not modified. If "low", the lowest-precision input - dtype is used. If "high", `min_dtype` or the highest-precision input dtype is used - (whichever is higher). - which_args : None or list of int - If not None, specifies which positional arguments are to be modified. 
If None (the default), - all positional arguments are modified (if they are Tensors and of a floating-point dtype). - which_kwargs : bool - If not None, specifies which keyword arguments are to be modified. If None (the default), - all keyword arguments are modified (if they are Tensors and of a floating-point dtype). - Returns - ------- - decorator : Callable - Decorator. - """ - - def decorator(func: Callable): - """Decorator that casts input tensors to minimum precision.""" - - def _cast_in(var: Any): - """Casts a single input to at least 32-bit precision.""" - if not isinstance(var, Tensor): - # We don't want to modify non-Tensors - return var - if not var.dtype.is_floating_point: - # Integer / boolean tensors are also not touched - return var - dtype = max(var.dtype, min_dtype, key=lambda dt: torch.finfo(dt).bits) - return var.to(dtype) - - def _cast_out(var: Any, dtype: torch.dtype): - """Casts a single output to desired precision.""" - if not isinstance(var, Tensor): - # We don't want to modify non-Tensors - return var - if not var.dtype.is_floating_point: - # Integer / boolean tensors are also not touched - return var - return var.to(dtype) - - @wraps(func) - def decorated_func(*args: Any, **kwargs: Any): - """Decorated func.""" - # Only change dtypes in autocast-enabled regions - if not (torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled()): - # NB: torch.is_autocast_enabled() only checks for GPU autocast - # See https://github.com/pytorch/pytorch/issues/110966 - return func(*args, **kwargs) - # Cast inputs to at least 32 bit - mod_args = [ - _cast_in(arg) - for i, arg in enumerate(args) - if which_args is None or i in which_args - ] - mod_kwargs = { - key: _cast_in(val) - for key, val in kwargs.items() - if which_kwargs is None or key in which_kwargs - } - # Call function w/o autocast enabled - with torch.autocast(device_type="cuda", enabled=False), torch.autocast( - device_type="cpu", enabled=False - ): - outputs = func(*mod_args, **mod_kwargs) - # Cast outputs to correct dtype - if output is None: - return outputs - if output in ["low", "high"]: - in_dtypes = [ - arg.dtype - for arg in chain(args, kwargs.values()) - if isinstance(arg, Tensor) and arg.dtype.is_floating_point - ] - assert len(in_dtypes) - if output == "low": - out_dtype = min( - [min_dtype] + in_dtypes, key=lambda dt: torch.finfo(dt).bits - ) - else: - out_dtype = max(in_dtypes, key=lambda dt: torch.finfo(dt).bits) - else: - out_dtype = output - if isinstance(outputs, tuple): - return (_cast_out(val, out_dtype) for val in outputs) - else: - return _cast_out(outputs, out_dtype) - - return decorated_func - - return decorator From 0fb23cc6f72af617ef18f1a560d9a6f54bdec7b5 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sat, 15 Mar 2025 22:12:05 +0100 Subject: [PATCH 20/29] Take GATr code from lgatr repo --- requirements.txt | 5 +++-- weaver/nn/model/LGATr.py | 45 ++++++++++++++++++++++++++-------------- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/requirements.txt b/requirements.txt index df8f5768..4ea8c101 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,6 @@ lz4>=3.1.0 xxhash>=1.4.4 tables>=3.6.1 tensorboard>=2.2.0 -einops==0.6.1 # for LGATr (newer einops version has issues with __name__, can be resolved with careful version picking) -opt_einsum>=3.3.0 # for LGATr +einops +opt_einsum +lgatr @ git+https://github.com/heidelberg-hepml/lgatr.git@basics diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 8da15aaf..b2c6ebb1 100644 --- 
a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -1,12 +1,13 @@ import torch from torch import nn -from .gatr import GATr, SelfAttentionConfig, MLPConfig -from .gatr.interface import ( +from lgatr import ( + LGATr, embed_vector, extract_scalar, get_num_spurions, - embed_spurions, + get_spurions, + gatr_config, ) from .ParticleTransformer import SequenceTrimmer @@ -28,11 +29,13 @@ def __init__( num_classes, num_blocks, num_heads, + # symmetry-breaking configurations global_token=True, spurion_token=True, - beam_reference="xyplane", - add_time_reference=True, - two_beams=True, + beam_spurion="xyplane", + add_time_spurion=True, + beam_mirror=True, + # network configurations activation="gelu", multi_query=False, increase_hidden_channels=2, @@ -40,6 +43,11 @@ def __init__( double_layernorm=False, dropout_prob=None, checkpoint_blocks=False, + # gatr configurations + use_fully_connected_subgroup=True, + mix_pseudoscalar_into_scalar=True, + use_bivector=True, + use_geometric_product=True, ): super().__init__() @@ -49,29 +57,34 @@ def __init__( self.spurion_token = spurion_token num_spurions = get_num_spurions( - beam_reference, add_time_reference, two_beams=two_beams + beam_spurion, add_time_spurion, beam_mirror=beam_mirror ) if not self.spurion_token: in_mv_channels += num_spurions self.spurion_kwargs = { - "beam_reference": beam_reference, - "add_time_reference": add_time_reference, - "two_beams": two_beams, + "beam_spurion": beam_spurion, + "add_time_spurion": add_time_spurion, + "beam_mirror": beam_mirror, } - attention = SelfAttentionConfig( + gatr_config.use_fully_connected_subgroup = use_fully_connected_subgroup + gatr_config.mix_pseudoscalar_into_scalar = mix_pseudoscalar_into_scalar + gatr_config.use_bivector = use_bivector + gatr_config.use_geometric_product = use_geometric_product + + attention = dict( multi_query=multi_query, num_heads=num_heads, increase_hidden_channels=increase_hidden_channels, dropout_prob=dropout_prob, head_scale=head_scale, ) - mlp = MLPConfig( + mlp = dict( activation=activation, dropout_prob=dropout_prob, ) - self.net = GATr( + self.net = LGATr( in_mv_channels=in_mv_channels, out_mv_channels=num_classes, hidden_mv_channels=hidden_mv_channels, @@ -98,7 +111,7 @@ def forward(self, x, v, mask): s = x # (batch_size, seq_len, num_fts) # symmetry breaking with spurions - spurions = embed_spurions(**self.spurion_kwargs).to( + spurions = get_spurions(**self.spurion_kwargs).to( device=s.device, dtype=s.dtype ) if self.spurion_token: @@ -128,7 +141,7 @@ def forward(self, x, v, mask): mask = mask[:, None, None, :, 0] # (batch_size, 1, 1, seq_len) # call network - out_mv, _ = self.net(mv, s, mask) + out_mv, _ = self.net(mv, s, attn_mask=mask) output = extract_scalar(out_mv)[..., 0] # (batch_size, seq_len, num_classes) # aggregation @@ -152,7 +165,7 @@ def __init__( for_segmentation=False, **kwargs, ): - super().__init__() # not support this kind of **kwargs for now + super().__init__() self.use_amp = use_amp self.for_inference = for_inference From fc9371e541f7a9a2b3c5e3fb42c8d974e5cec201 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sat, 15 Mar 2025 22:36:03 +0100 Subject: [PATCH 21/29] Add documentation --- weaver/nn/model/LGATr.py | 142 ++++++++++++++++++++++++++++++++------- 1 file changed, 119 insertions(+), 23 deletions(-) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index b2c6ebb1..ec99540d 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -23,32 +23,111 @@ class LGATrWrapper(nn.Module): def __init__( self, - 
in_s_channels, - hidden_mv_channels, - hidden_s_channels, - num_classes, - num_blocks, - num_heads, + in_s_channels: int, + hidden_mv_channels: int, + hidden_s_channels: int, + num_classes: int, + num_blocks: int, + num_heads: int, # symmetry-breaking configurations - global_token=True, - spurion_token=True, - beam_spurion="xyplane", - add_time_spurion=True, - beam_mirror=True, + global_token: bool = True, + spurion_token: bool = True, + beam_spurion: str = "xyplane", + add_time_spurion: bool = True, + beam_mirror: bool = True, # network configurations - activation="gelu", - multi_query=False, - increase_hidden_channels=2, - head_scale=False, - double_layernorm=False, - dropout_prob=None, - checkpoint_blocks=False, + activation: str = "gelu", + multi_query: bool = False, + increase_hidden_channels: int = 2, + head_scale: bool = False, + double_layernorm: bool = False, + dropout_prob: float = None, + checkpoint_blocks: bool = False, # gatr configurations - use_fully_connected_subgroup=True, - mix_pseudoscalar_into_scalar=True, - use_bivector=True, - use_geometric_product=True, + use_fully_connected_subgroup: bool = True, + mix_pseudoscalar_into_scalar: bool = True, + use_bivector: bool = True, + use_geometric_product: bool = True, ): + """ + Parameters + ---------- + in_s_channels : int + Number of scalar input channels. + Examples are PID, trajectory displacements and kinematic features + like log(pT), delta_phi etc that are invariant under z-rotations. + hidden_mv_channels : int + Number of hidden multivector channels, defines width of L-GATr. + hidden_s_channels : int + Number of hidden scalar channels. We find best performance with + roughly hidden_s_channels ~ 2 * hidden_mv_channels. + num_classes : int + Number of classification scores to predict + num_blocks : int + Number of L-GATr blocks. + num_heads : int + Number of attention heads in L-GATr. + global_token : bool + If True, prepend a global token as first particle in the list. + If False, fallback to mean-aggregation. + spurion_token : bool + If True, prepend spurions as extra particles (tokens) in the list. + If False, append spurions as extra mv channels. + beam_spurion : str + How the beam spurion is embedded, see lgatr/interface/spurions.py + add_time_spurion : bool + If True, add a time spurion. + beam_mirror : bool + If True and beam_spurion in ["timelike", "lightlike", "spacelike"], + add a mirrored beam_spurion, i.e. with opposite p_z. + activation : {"relu", "sigmoid", "gelu"} + Activation function in the MLP layers. + multi_query : bool + If True, use the same query for each head in attention. + increase_hidden_channels : int + Factor by which hidden_mv_channels is increased in attention. + head_scale : bool + If True, scale the attention heads with a learnable factor. + Inspired by the NormFormer (https://arxiv.org/pdf/2110.09456) + double_layernorm : bool + If True, applies layer normalization also after attention. + The default is only before attention ('pre-layernorm transformer') + dropout_prob : float + Residual dropout after attention and MLP. + checkpoint_blocks : bool + If True, use torch.utils.checkpoint.checkpoint to save memory + at the cost of a slower backward pass. + use_fully_connected_subgroup : bool + If True, model is only equivariant with respect to + the fully connected subgroup of the Lorentz group, + the proper orthochronous Lorentz group SO^+(1,3), + which does not include parity and time reversal. 
+ This setting affects how the EquiLinear maps work: + For SO^+(1,3), they include transitions scalars/pseudoscalars + vectors/axialvectors and among bivectors, effectively + treating the pseudoscalar/axialvector representations + like another scalar/vector. + Defaults to False, because parity-odd representations + are usually not important in high-energy physics simulations. + mix_pseudoscalar_into_scalar : bool + If True, the pseudoscalar part of the multivector mixes + with the pure-scalar channels in the EquiLinear layer. + This is a technical aspect of how EquiLinear maps work, + and only makes sense it use_fully_connected_subgroup=True. + Attention: The combination use_fully_connected_subgroup=False + and mix_pseudoscalar_into_scalar=True does not make sense, + you are only equivariant w.r.t. the fully connected subgroup + if you choose these settings. + use_bivector : bool + If False, the bivector components are set to zero after they + are created in the GeometricBilinear layer. + This is a toy switch to explore the effect of higher-order + representations. + use_geometric_product : bool + If False, the GeometricBilinear layer is replaced + by a EquiLinear + ScalarGatedNonlinearity layer. + This is a toy switch to explore the effect of the geometric product. + """ super().__init__() # spurion business @@ -100,6 +179,23 @@ def __init__( ) def forward(self, x, v, mask): + """ + Parameters + ---------- + x : torch.Tensor with shape (batch_size, num_fts, seq_len) + Scalar features, i.e. features that are invariant under z-rotations. + Examples: PID, trajectory displacements, kinematic features like + log(pT), delta_phi, delta_eta + v : torch.Tensor with shape (batch_size, 4, seq_len) + Lorentz vectors in format (px, py, pz, E) + mask : torch.Tensor with shape (batch_size, 1, seq_len) + Boolean mask that contains 'False' for padded jet constituents + + Returns + ------- + output : torch.Tensor with shape (batch_size, num_classes) + Tagging scores for each class + """ # reshape input x = x.transpose(1, 2) # (batch_size, seq_len, num_fts) v = v.transpose(1, 2) # (batch_size, seq_len, 4) @@ -155,7 +251,7 @@ def forward(self, x, v, mask): class LGATrTagger(nn.Module): - """Mimic weaver features""" + """Mimic other weaver wrappers""" def __init__( self, From b9970b5fd938aced43a5b147f660ed32b0d09323 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Sat, 15 Mar 2025 22:41:23 +0100 Subject: [PATCH 22/29] Change docu --- weaver/nn/model/LGATr.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index ec99540d..e83d29bf 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -13,13 +13,7 @@ class LGATrWrapper(nn.Module): - """ - Wrapper that handles interface to the GATr code - - create dataclasses for attention and mlp - - append spurions (symmetry-breaking) - - interface to geometric algebra - - extract tagging score with global token or mean-aggregation - """ + """Interface to the LGATr class""" def __init__( self, From 5f99b50c3365e996faf803489b3261b7d35054be Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Mon, 17 Mar 2025 23:55:15 +0100 Subject: [PATCH 23/29] Add 'use_flex_attention' option (uses torch_geometric) --- requirements.txt | 1 + weaver/nn/model/LGATr.py | 82 ++++++++++++++++++++++++++++++++++------ 2 files changed, 71 insertions(+), 12 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4ea8c101..04ff3afc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,4 
@@ tensorboard>=2.2.0 einops opt_einsum lgatr @ git+https://github.com/heidelberg-hepml/lgatr.git@basics +torch_geometric diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index e83d29bf..2da07c7c 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -1,5 +1,7 @@ import torch from torch import nn +from torch.nn.attention.flex_attention import create_block_mask +from torch_geometric.nn.aggr import MeanAggregation from lgatr import ( LGATr, @@ -9,6 +11,7 @@ get_spurions, gatr_config, ) +import lgatr.primitives.attention from .ParticleTransformer import SequenceTrimmer @@ -24,19 +27,22 @@ def __init__( num_blocks: int, num_heads: int, # symmetry-breaking configurations - global_token: bool = True, spurion_token: bool = True, beam_spurion: str = "xyplane", add_time_spurion: bool = True, beam_mirror: bool = True, # network configurations + global_token: bool = True, activation: str = "gelu", multi_query: bool = False, increase_hidden_channels: int = 2, head_scale: bool = False, double_layernorm: bool = False, dropout_prob: float = None, + # time/memory configurations checkpoint_blocks: bool = False, + use_flex_attention: bool = True, + compile_flex_attention: bool = True, # gatr configurations use_fully_connected_subgroup: bool = True, mix_pseudoscalar_into_scalar: bool = True, @@ -61,9 +67,6 @@ def __init__( Number of L-GATr blocks. num_heads : int Number of attention heads in L-GATr. - global_token : bool - If True, prepend a global token as first particle in the list. - If False, fallback to mean-aggregation. spurion_token : bool If True, prepend spurions as extra particles (tokens) in the list. If False, append spurions as extra mv channels. @@ -74,6 +77,9 @@ def __init__( beam_mirror : bool If True and beam_spurion in ["timelike", "lightlike", "spacelike"], add a mirrored beam_spurion, i.e. with opposite p_z. + global_token : bool + If True, prepend a global token as first particle in the list. + If False, fallback to mean-aggregation. activation : {"relu", "sigmoid", "gelu"} Activation function in the MLP layers. multi_query : bool @@ -91,6 +97,11 @@ def __init__( checkpoint_blocks : bool If True, use torch.utils.checkpoint.checkpoint to save memory at the cost of a slower backward pass. + use_flex_attention : bool + If True, embed jets as sparse tensors and use flex-attention + for an efficient block-diagonal attention matrix. + If False, use the default torch attention backend with zero-padding. + Using True saves a factor of ~2 in memory and maybe a bit of speed. 
use_fully_connected_subgroup : bool If True, model is only equivariant with respect to the fully connected subgroup of the Lorentz group, @@ -126,8 +137,24 @@ def __init__( # spurion business in_mv_channels = 1 - self.global_token = global_token + self.global_token = False # global_token self.spurion_token = spurion_token + self.use_flex_attention = use_flex_attention + + if self.use_flex_attention: + if compile_flex_attention: + # torch.compile for attention speedup + if not torch.cuda.is_available(): + # suppress weird error on CPU + torch._dynamo.config.suppress_errors = True + lgatr.primitives.attention.flex_attention = torch.compile( + lgatr.primitives.attention.flex_attention, dynamic=True + ) + global create_block_mask + create_block_mask = torch.compile(create_block_mask, dynamic=True) + + if not self.global_token: + self.aggregator = MeanAggregation() num_spurions = get_num_spurions( beam_spurion, add_time_spurion, beam_mirror=beam_mirror @@ -225,22 +252,53 @@ def forward(self, x, v, mask): mask = torch.cat((mask_ones, mask), dim=1) s_zeros = torch.zeros_like(s[:, [0]]) s = torch.cat((s_zeros, s), dim=1) + is_global = torch.zeros_like(s[:, :, 0], dtype=torch.bool) + is_global[:, 0] = True - # reshape mask to broadcast correctly - mask = mask.bool() - mask = mask[:, None, None, :, 0] # (batch_size, 1, 1, seq_len) + if self.use_flex_attention: + # flatten across batch and sequence dimension + mask = mask[:, :, 0] + mv = mv[mask].unsqueeze(0) + s = s[mask].unsqueeze(0) + + if self.global_token: + is_global = is_global[mask] + else: + batch = torch.arange(mask.shape[0], device=mask.device) + batch = batch.unsqueeze(-1).repeat(1, mask.shape[1]) + batch = batch[mask] + + block_mask_fn = lambda b, h, q_idx, kv_idx: batch[q_idx] == batch[kv_idx] + block_mask = create_block_mask( + block_mask_fn, + B=None, + H=None, + Q_LEN=mv.shape[1], + KV_LEN=mv.shape[1], + device=mv.device, + ) + attn_kwargs = {"block_mask": block_mask} + else: + # reshape mask to broadcast correctly + mask = mask.bool() + attn_mask = mask[:, None, None, :, 0] # (batch_size, 1, 1, seq_len) + attn_kwargs = {"attn_mask": attn_mask} # call network - out_mv, _ = self.net(mv, s, attn_mask=mask) + out_mv, _ = self.net(mv, s, **attn_kwargs) output = extract_scalar(out_mv)[..., 0] # (batch_size, seq_len, num_classes) # aggregation if self.global_token: - output = output[:, 0] + output = output[0] if self.use_flex_attention else output + output = output[is_global] else: # mean aggregation - output[~mask[:, 0, 0]] = 0.0 - output = output.mean(dim=1) + if self.use_flex_attention: + output = self.aggregator(output[0], index=batch) + else: + output[~mask[:, 0, 0]] = 0.0 + output = output.mean(dim=1) return output From 83ff1fca49424e40dd8ad464cf2759c142ec3e6f Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Tue, 18 Mar 2025 16:57:03 +0100 Subject: [PATCH 24/29] Minor change --- weaver/nn/model/LGATr.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 2da07c7c..8f70436b 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -144,9 +144,8 @@ def __init__( if self.use_flex_attention: if compile_flex_attention: # torch.compile for attention speedup - if not torch.cuda.is_available(): - # suppress weird error on CPU - torch._dynamo.config.suppress_errors = True + # suppress weird error on CPU + torch._dynamo.config.suppress_errors = True lgatr.primitives.attention.flex_attention = torch.compile( lgatr.primitives.attention.flex_attention, 
dynamic=True ) From 52f848a86aebb35a52281bf21d62b3f82d30a4c7 Mon Sep 17 00:00:00 2001 From: Jonas Spinner Date: Wed, 19 Mar 2025 09:15:30 +0100 Subject: [PATCH 25/29] Import lgatr package; remove sparse-tensor-business --- requirements.txt | 5 +-- weaver/nn/model/LGATr.py | 67 +++++----------------------------------- 2 files changed, 8 insertions(+), 64 deletions(-) diff --git a/requirements.txt b/requirements.txt index 04ff3afc..aa3763d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,4 @@ lz4>=3.1.0 xxhash>=1.4.4 tables>=3.6.1 tensorboard>=2.2.0 -einops -opt_einsum -lgatr @ git+https://github.com/heidelberg-hepml/lgatr.git@basics -torch_geometric +lgatr diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py index 8f70436b..10154ed2 100644 --- a/weaver/nn/model/LGATr.py +++ b/weaver/nn/model/LGATr.py @@ -1,7 +1,5 @@ import torch from torch import nn -from torch.nn.attention.flex_attention import create_block_mask -from torch_geometric.nn.aggr import MeanAggregation from lgatr import ( LGATr, @@ -11,7 +9,6 @@ get_spurions, gatr_config, ) -import lgatr.primitives.attention from .ParticleTransformer import SequenceTrimmer @@ -41,8 +38,6 @@ def __init__( dropout_prob: float = None, # time/memory configurations checkpoint_blocks: bool = False, - use_flex_attention: bool = True, - compile_flex_attention: bool = True, # gatr configurations use_fully_connected_subgroup: bool = True, mix_pseudoscalar_into_scalar: bool = True, @@ -97,11 +92,6 @@ def __init__( checkpoint_blocks : bool If True, use torch.utils.checkpoint.checkpoint to save memory at the cost of a slower backward pass. - use_flex_attention : bool - If True, embed jets as sparse tensors and use flex-attention - for an efficient block-diagonal attention matrix. - If False, use the default torch attention backend with zero-padding. - Using True saves a factor of ~2 in memory and maybe a bit of speed. 
         use_fully_connected_subgroup : bool
             If True, model is only equivariant with respect to
             the fully connected subgroup of the Lorentz group,
@@ -137,23 +127,8 @@ def __init__(
         # spurion business
         in_mv_channels = 1
-        self.global_token = False  # global_token
+        self.global_token = global_token
         self.spurion_token = spurion_token
-        self.use_flex_attention = use_flex_attention
-
-        if self.use_flex_attention:
-            if compile_flex_attention:
-                # torch.compile for attention speedup
-                # suppress weird error on CPU
-                torch._dynamo.config.suppress_errors = True
-                lgatr.primitives.attention.flex_attention = torch.compile(
-                    lgatr.primitives.attention.flex_attention, dynamic=True
-                )
-            global create_block_mask
-            create_block_mask = torch.compile(create_block_mask, dynamic=True)
-
-        if not self.global_token:
-            self.aggregator = MeanAggregation()

         num_spurions = get_num_spurions(
             beam_spurion, add_time_spurion, beam_mirror=beam_mirror
         )
@@ -254,34 +229,10 @@ def forward(self, x, v, mask):
         is_global = torch.zeros_like(s[:, :, 0], dtype=torch.bool)
         is_global[:, 0] = True

-        if self.use_flex_attention:
-            # flatten across batch and sequence dimension
-            mask = mask[:, :, 0]
-            mv = mv[mask].unsqueeze(0)
-            s = s[mask].unsqueeze(0)
-
-            if self.global_token:
-                is_global = is_global[mask]
-            else:
-                batch = torch.arange(mask.shape[0], device=mask.device)
-                batch = batch.unsqueeze(-1).repeat(1, mask.shape[1])
-                batch = batch[mask]
-
-            block_mask_fn = lambda b, h, q_idx, kv_idx: batch[q_idx] == batch[kv_idx]
-            block_mask = create_block_mask(
-                block_mask_fn,
-                B=None,
-                H=None,
-                Q_LEN=mv.shape[1],
-                KV_LEN=mv.shape[1],
-                device=mv.device,
-            )
-            attn_kwargs = {"block_mask": block_mask}
-        else:
-            # reshape mask to broadcast correctly
-            mask = mask.bool()
-            attn_mask = mask[:, None, None, :, 0]  # (batch_size, 1, 1, seq_len)
-            attn_kwargs = {"attn_mask": attn_mask}
+        # reshape mask to broadcast correctly
+        mask = mask.bool()
+        attn_mask = mask[:, None, None, :, 0]  # (batch_size, 1, 1, seq_len)
+        attn_kwargs = {"attn_mask": attn_mask}

         # call network
         out_mv, _ = self.net(mv, s, **attn_kwargs)
@@ -289,15 +240,11 @@ def forward(self, x, v, mask):

         # aggregation
         if self.global_token:
-            output = output[0] if self.use_flex_attention else output
             output = output[is_global]
         else:
             # mean aggregation
-            if self.use_flex_attention:
-                output = self.aggregator(output[0], index=batch)
-            else:
-                output[~mask[:, 0, 0]] = 0.0
-                output = output.mean(dim=1)
+            output[~mask[:, 0, 0]] = 0.0
+            output = output.mean(dim=1)

         return output

From 1c3aa5f76b143663b31830dca34510bd271bb015 Mon Sep 17 00:00:00 2001
From: Jonas Spinner
Date: Wed, 19 Mar 2025 09:15:51 +0100
Subject: [PATCH 26/29] Add Lion optimizer option

---
 requirements.txt | 1 +
 weaver/train.py  | 3 +++
 2 files changed, 4 insertions(+)

diff --git a/requirements.txt b/requirements.txt
index aa3763d7..f292b3d7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,4 @@ xxhash>=1.4.4
 tables>=3.6.1
 tensorboard>=2.2.0
 lgatr
+pytorch_optimizer>=3.0.0 # Lion optimizer
diff --git a/weaver/train.py b/weaver/train.py
index ec3fbf54..5d3c4470 100644
--- a/weaver/train.py
+++ b/weaver/train.py
@@ -11,6 +11,7 @@
 import math
 import copy
 import torch
+from pytorch_optimizer import Lion
 from torch.utils.data import DataLoader

 from weaver.utils.logger import _logger, _configLogger
@@ -516,6 +517,8 @@ def init_opt(args, model, **optimizer_options):
         opt = torch.optim.AdamW(parameters, lr=args.start_lr, **optimizer_options)
     elif args.optimizer == 'radam':
         opt = torch.optim.RAdam(parameters, lr=args.start_lr, **optimizer_options)
+    elif args.optimizer == 'lion':
+        opt = Lion(parameters, lr=args.start_lr, **optimizer_options)
     else:
         opt = getattr(torch.optim, args.optimizer)(parameters, lr=args.start_lr, **optimizer_options)

From e6d4e45258b3596eb950f616b54d281e59e57d1b Mon Sep 17 00:00:00 2001
From: Jonas Spinner
Date: Wed, 19 Mar 2025 09:16:18 +0100
Subject: [PATCH 27/29] Turn off flop-counter and add comment on why

---
 weaver/train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/weaver/train.py b/weaver/train.py
index 5d3c4470..f419fb36 100644
--- a/weaver/train.py
+++ b/weaver/train.py
@@ -646,7 +646,8 @@ def model_setup(args, data_config, device='cpu'):
         _logger.info('The following weights has been frozen:\n - %s',
                      '\n - '.join([name for name, p in model.named_parameters() if not p.requires_grad]))
     # _logger.info(model)
-    flops(model, model_info, device=device)
+    # don't use the flop-counting tool from 2019 (causes issues in modern torch, specifically with einops; works with e.g. einops==0.6.1)
+    # flops(model, model_info, device=device)
     # loss function
     try:
         loss_func = network_module.get_loss(data_config, **network_options)

From 7093ab256f27a74678e01da0e67b7aaeaefbe520 Mon Sep 17 00:00:00 2001
From: Jonas Spinner
Date: Sun, 1 Jun 2025 13:21:51 +0200
Subject: [PATCH 28/29] Changes based on recent lgatr rework

---
 weaver/nn/model/LGATr.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/weaver/nn/model/LGATr.py b/weaver/nn/model/LGATr.py
index 10154ed2..67f147b7 100644
--- a/weaver/nn/model/LGATr.py
+++ b/weaver/nn/model/LGATr.py
@@ -32,9 +32,10 @@ def __init__(
         global_token: bool = True,
         activation: str = "gelu",
         multi_query: bool = False,
-        increase_hidden_channels: int = 2,
+        increase_hidden_channels_attention: int = 2,
+        increase_hidden_channels_mlp: int = 2,
+        num_hidden_layers_mlp: int = 1,
         head_scale: bool = False,
-        double_layernorm: bool = False,
         dropout_prob: float = None,
         # time/memory configurations
         checkpoint_blocks: bool = False,
@@ -79,14 +80,15 @@ def __init__(
             Activation function in the MLP layers.
         multi_query : bool
             If True, use the same query for each head in attention.
-        increase_hidden_channels : int
+        increase_hidden_channels_attention : int
             Factor by which hidden_mv_channels is increased in attention.
+        increase_hidden_channels_mlp : int
+            Factor by which hidden_mv_channels is increased in the MLP.
+        num_hidden_layers_mlp : int
+            Number of hidden layers in the MLP.
         head_scale : bool
             If True, scale the attention heads with a learnable factor.
             Inspired by the NormFormer (https://arxiv.org/pdf/2110.09456)
-        double_layernorm : bool
-            If True, applies layer normalization also after attention.
-            The default is only before attention ('pre-layernorm transformer')
         dropout_prob : float
             Residual dropout after attention and MLP.
         checkpoint_blocks : bool
@@ -149,16 +151,17 @@ def __init__(
         attention = dict(
             multi_query=multi_query,
             num_heads=num_heads,
-            increase_hidden_channels=increase_hidden_channels,
-            dropout_prob=dropout_prob,
+            increase_hidden_channels=increase_hidden_channels_attention,
             head_scale=head_scale,
         )
         mlp = dict(
             activation=activation,
-            dropout_prob=dropout_prob,
+            increase_hidden_channels=increase_hidden_channels_mlp,
+            num_hidden_layers=num_hidden_layers_mlp,
         )

         self.net = LGATr(
+            num_blocks=num_blocks,
             in_mv_channels=in_mv_channels,
             out_mv_channels=num_classes,
             hidden_mv_channels=hidden_mv_channels,
@@ -167,8 +170,6 @@ def __init__(
             hidden_s_channels=hidden_s_channels,
             attention=attention,
             mlp=mlp,
-            num_blocks=num_blocks,
-            double_layernorm=double_layernorm,
             dropout_prob=dropout_prob,
             checkpoint_blocks=checkpoint_blocks,
         )

From 6bc9a588d4a4c5c8b6b35bef8fbdcdcb655cc0b5 Mon Sep 17 00:00:00 2001
From: Jonas Spinner
Date: Sun, 1 Jun 2025 13:22:09 +0200
Subject: [PATCH 29/29] Add pyarrow to weaver requirements.txt

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index f292b3d7..3c934b4c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ tables>=3.6.1
 tensorboard>=2.2.0
 lgatr
 pytorch_optimizer>=3.0.0 # Lion optimizer
+pyarrow
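
Note on the attention-mask handling in PATCH 25/29: with flex-attention removed, padding is handled by a boolean mask of shape (batch_size, 1, 1, seq_len) that broadcasts against attention scores of shape (batch_size, num_heads, seq_len, seq_len). The snippet below is only a minimal, standalone sketch of that broadcasting with made-up sizes; it is not code from the patch.

    import torch

    batch_size, num_heads, seq_len = 2, 4, 5

    # per-particle validity mask as it arrives in forward(): (batch, seq, 1)
    mask = torch.zeros(batch_size, seq_len, 1, dtype=torch.bool)
    mask[0, :3] = True  # first jet: 3 valid particles
    mask[1, :5] = True  # second jet: all 5 slots valid

    # reshape so the mask broadcasts over heads and query positions
    attn_mask = mask[:, None, None, :, 0]  # (batch, 1, 1, seq)

    scores = torch.randn(batch_size, num_heads, seq_len, seq_len)
    scores = scores.masked_fill(~attn_mask, float("-inf"))  # padded keys get -inf
    weights = scores.softmax(dim=-1)  # zero attention weight on padded keys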
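
Note on the Lion option in PATCH 26/29: Lion is imported from the pytorch_optimizer package and is constructed in init_opt when args.optimizer == 'lion', with the learning rate taken from args.start_lr and any extra keyword arguments forwarded via optimizer_options. A rough standalone sketch of the equivalent call (the hyperparameter values below are illustrative, not taken from the patch):

    import torch
    from pytorch_optimizer import Lion

    model = torch.nn.Linear(16, 2)
    # Lion is commonly run with a smaller learning rate (and larger weight decay)
    # than AdamW; in weaver these values come from the command line instead.
    opt = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

    loss = model(torch.randn(8, 16)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()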