From 39a5862ed5d6a8d8b4ff61cd9e614cb374e42732 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:35 -0500 Subject: [PATCH 01/26] added basic GRIT code Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 57 ++++++ gridfm_graphkit/models/grit_transformer.py | 195 +++++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 examples/config/grit_pretraining.yaml create mode 100644 gridfm_graphkit/models/grit_transformer.py diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml new file mode 100644 index 0000000..a6566e0 --- /dev/null +++ b/examples/config/grit_pretraining.yaml @@ -0,0 +1,57 @@ +callbacks: + patience: 100 + tol: 0 +data: + baseMVA: 100 + learn_mask: false + mask_dim: 6 + mask_ratio: 0.5 + mask_type: rnd + mask_value: -1.0 + networks: + # - Texas2k_case1_2016summerpeak + - case24_ieee_rts + # - case118_ieee + # - case300_ieee + - case89_pegase + # - case240_pserc + normalization: baseMVAnorm + scenarios: + # - 5000 + - 5000 + - 5000 + # - 30000 + # - 50000 + # - 50000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 4 +model: + attention_head: 8 + dropout: 0.1 + edge_dim: 2 + hidden_size: 123 + input_dim: 9 + num_layers: 14 + output_dim: 6 + pe_dim: 20 + type: GPSTransformer # +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + loss_weights: + - 0.01 + - 0.99 + losses: + - MaskedMSE + - PBE + accelerator: auto + devices: auto + strategy: auto diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py new file mode 100644 index 0000000..3ee5e8e --- /dev/null +++ b/gridfm_graphkit/models/grit_transformer.py @@ -0,0 +1,195 @@ +from gridfm_graphkit.io.registries import MODELS_REGISTRY +from torch import nn +import torch +import torch_geometric.graphgym.register as register +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.models.gnn import GNNPreMP +from torch_geometric.graphgym.models.layer import (new_layer_config, + BatchNorm1dNode) +from torch_geometric.graphgym.register import register_network +from torch_geometric.graphgym.models.layer import new_layer_config, MLP + + + +class FeatureEncoder(torch.nn.Module): + """ + Encoding node and edge features + + Args: + dim_in (int): Input feature dimension + """ + def __init__(self, dim_in): + super(FeatureEncoder, self).__init__() + self.dim_in = dim_in + if cfg.dataset.node_encoder: + # Encode integer node features via nn.Embeddings + NodeEncoder = register.node_encoder_dict[ + cfg.dataset.node_encoder_name] + self.node_encoder = NodeEncoder(cfg.gnn.dim_inner) + if cfg.dataset.node_encoder_bn: + self.node_encoder_bn = BatchNorm1dNode( + new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False, + has_bias=False, cfg=cfg)) + # Update dim_in to reflect the new dimension fo the node features + self.dim_in = cfg.gnn.dim_inner + if cfg.dataset.edge_encoder: + # Hard-limit max edge dim for PNA. 
+ if 'PNA' in cfg.gt.layer_type: + cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner) + else: + cfg.gnn.dim_edge = cfg.gnn.dim_inner + # Encode integer edge features via nn.Embeddings + EdgeEncoder = register.edge_encoder_dict[ + cfg.dataset.edge_encoder_name] + self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge) + if cfg.dataset.edge_encoder_bn: + self.edge_encoder_bn = BatchNorm1dNode( + new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False, + has_bias=False, cfg=cfg)) + + def forward(self, batch): + for module in self.children(): + batch = module(batch) + return batch + + +@register_head('decoder_head') +class GNNDecoderHead(nn.Module): + """ + Predictoin head for encoder-decoder networks. + + Args: + dim_in (int): Input dimension # TODO update arg comments as needed + dim_out (int): Output dimension. For binary prediction, dim_out=1. + """ + + def __init__(self, dim_in, dim_out): + super(GNNDecoderHead, self).__init__() + + + + # note that the input and output dimensions are from the config file + # if we want this to be variable that will have to change with + # each layer + + # TODO consider use of a bottleneck + + # note the config is imported as in other modules + + # the number of config layers should apriori be different than the encoder + + + global_model_type = cfg.gt.get('layer_type', "GritTransformer") + + TransformerLayer = register.layer_dict.get(global_model_type) + + layers = [] + for l in range(cfg.gnn.layers_decode): + layers.append(TransformerLayer( + in_dim=cfg.gt.dim_hidden, + out_dim=cfg.gt.dim_hidden, + num_heads=cfg.gt.n_heads, + dropout=cfg.gt.dropout, # TODO could migrate this and others to gnn in config + act=cfg.gnn.act, + attn_dropout=cfg.gt.attn_dropout, + layer_norm=cfg.gt.layer_norm, + batch_norm=cfg.gt.batch_norm, + residual=True, + norm_e=cfg.gt.attn.norm_e, + O_e=cfg.gt.attn.O_e, + cfg=cfg.gt, + )) + # layers = [] + + self.layers = torch.nn.Sequential(*layers) + + + + self.layer_post_mp = MLP( + new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp, + has_act=False, has_bias=True, cfg=cfg)) + + + + def _apply_index(self, batch): + return batch.x, batch.y + + def forward(self, batch): + batch = self.layers(batch) + + # follow GMAE here and make a final linear projection from the + # hiden dimension to the output dimension + batch = self.layer_post_mp(batch) + + pred, label = self._apply_index(batch) + #print('>>>>>>', pred.size(),label.size()) + return pred, label + + + +@MODELS_REGISTRY.register("GRIT") +class GritTransformer(torch.nn.Module): + ''' + The proposed GritTransformer (Graph Inductive Bias Transformer) + ''' + + def __init__(self, dim_in, dim_out): + super().__init__() + self.encoder = FeatureEncoder(dim_in) + dim_in = self.encoder.dim_in + + self.ablation = True + self.ablation = False + + if cfg.posenc_RRWP.enable: + self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"]\ + (cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner) + rel_pe_dim = cfg.posenc_RRWP.ksteps + self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"] \ + (rel_pe_dim, cfg.gnn.dim_edge, + pad_to_full_graph=cfg.gt.attn.full_attn, + add_node_attr_as_self_loop=False, + fill_value=0. + ) + + + if cfg.gnn.layers_pre_mp > 0: + self.pre_mp = GNNPreMP( + dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) + dim_in = cfg.gnn.dim_inner + + assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ + "The inner and hidden dims must match." 
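+        # The transformer layer class is looked up by name in the graphgym
+        # registry (`cfg.gt.layer_type`) and stacked `cfg.gt.layers` times;
+        # the prediction head is resolved the same way via `cfg.gnn.head`.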
+ + global_model_type = cfg.gt.get('layer_type', "GritTransformer") + # global_model_type = "GritTransformer" + + TransformerLayer = register.layer_dict.get(global_model_type) + + layers = [] + for l in range(cfg.gt.layers): + layers.append(TransformerLayer( + in_dim=cfg.gt.dim_hidden, + out_dim=cfg.gt.dim_hidden, + num_heads=cfg.gt.n_heads, + dropout=cfg.gt.dropout, + act=cfg.gnn.act, + attn_dropout=cfg.gt.attn_dropout, + layer_norm=cfg.gt.layer_norm, + batch_norm=cfg.gt.batch_norm, + residual=True, + norm_e=cfg.gt.attn.norm_e, + O_e=cfg.gt.attn.O_e, + cfg=cfg.gt, + )) + # layers = [] + + self.layers = torch.nn.Sequential(*layers) + GNNHead = register.head_dict[cfg.gnn.head] + self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) + + def forward(self, batch): + for module in self.children(): + batch = module(batch) + + return batch \ No newline at end of file From 922d6cefe51a19d31cff2b3cda09603e2600a349 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:36 -0500 Subject: [PATCH 02/26] initial connection of model to config Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 35 +++- gridfm_graphkit/models/grit_transformer.py | 202 ++++++++------------- 2 files changed, 113 insertions(+), 124 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index a6566e0..904d6dc 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -32,10 +32,41 @@ model: edge_dim: 2 hidden_size: 123 input_dim: 9 - num_layers: 14 + num_layers: 10 output_dim: 6 pe_dim: 20 - type: GPSTransformer # + type: GRIT #GPSTransformer # + layers_pre_mp: 0 + act: relu + encoder: + node_encoder: True + edge_encoder: True + node_encoder_name: TODO + node_encoder_bn: True + gt: + layer_type: GritTransformer + # layers: 10 + # n_heads: 8 + dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` + # dropout: 0.0 + layer_norm: False + batch_norm: True + update_e: True + attn_dropout: 0.2 + attn: + clamp: 5. 
+ act: 'relu' + full_attn: True + edge_enhance: True + O_e: True + norm_e: True + signed_sqrt: True + posenc_RRWP: + enable: True + ksteps: 21 + add_identity: True + add_node_attr: False + add_inverse: False optimizer: beta1: 0.9 beta2: 0.999 diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 3ee5e8e..b09f527 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,12 +1,10 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY from torch import nn import torch -import torch_geometric.graphgym.register as register -from torch_geometric.graphgym.config import cfg + from torch_geometric.graphgym.models.gnn import GNNPreMP from torch_geometric.graphgym.models.layer import (new_layer_config, BatchNorm1dNode) -from torch_geometric.graphgym.register import register_network from torch_geometric.graphgym.models.layer import new_layer_config, MLP @@ -17,114 +15,49 @@ class FeatureEncoder(torch.nn.Module): Args: dim_in (int): Input feature dimension + + + TODO replace 'register' with local version of it + """ - def __init__(self, dim_in): + def __init__( + self, + dim_in, + dim_inner, + args + ): super(FeatureEncoder, self).__init__() self.dim_in = dim_in - if cfg.dataset.node_encoder: + if args.node_encoder: # Encode integer node features via nn.Embeddings NodeEncoder = register.node_encoder_dict[ - cfg.dataset.node_encoder_name] - self.node_encoder = NodeEncoder(cfg.gnn.dim_inner) - if cfg.dataset.node_encoder_bn: + args.node_encoder_name] + self.node_encoder = NodeEncoder(dim_inner) + if args.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode( - new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False, + new_layer_config(dim_inner, -1, -1, has_act=False, has_bias=False, cfg=cfg)) # Update dim_in to reflect the new dimension fo the node features - self.dim_in = cfg.gnn.dim_inner - if cfg.dataset.edge_encoder: + self.dim_in = dim_inner + if args.edge_encoder: # Hard-limit max edge dim for PNA. - if 'PNA' in cfg.gt.layer_type: - cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner) + if 'PNA' in args.model.gt.layer_type: # TODO remove condition if PNA not needed + dim_edge = min(128, dim_inner) else: - cfg.gnn.dim_edge = cfg.gnn.dim_inner + dim_edge = dim_inner # Encode integer edge features via nn.Embeddings EdgeEncoder = register.edge_encoder_dict[ cfg.dataset.edge_encoder_name] - self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge) + self.edge_encoder = EdgeEncoder(dim_edge) if cfg.dataset.edge_encoder_bn: self.edge_encoder_bn = BatchNorm1dNode( - new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False, + new_layer_config(dim_edge, -1, -1, has_act=False, has_bias=False, cfg=cfg)) def forward(self, batch): for module in self.children(): batch = module(batch) return batch - - -@register_head('decoder_head') -class GNNDecoderHead(nn.Module): - """ - Predictoin head for encoder-decoder networks. - - Args: - dim_in (int): Input dimension # TODO update arg comments as needed - dim_out (int): Output dimension. For binary prediction, dim_out=1. 
- """ - - def __init__(self, dim_in, dim_out): - super(GNNDecoderHead, self).__init__() - - - - # note that the input and output dimensions are from the config file - # if we want this to be variable that will have to change with - # each layer - - # TODO consider use of a bottleneck - - # note the config is imported as in other modules - - # the number of config layers should apriori be different than the encoder - - - global_model_type = cfg.gt.get('layer_type', "GritTransformer") - - TransformerLayer = register.layer_dict.get(global_model_type) - - layers = [] - for l in range(cfg.gnn.layers_decode): - layers.append(TransformerLayer( - in_dim=cfg.gt.dim_hidden, - out_dim=cfg.gt.dim_hidden, - num_heads=cfg.gt.n_heads, - dropout=cfg.gt.dropout, # TODO could migrate this and others to gnn in config - act=cfg.gnn.act, - attn_dropout=cfg.gt.attn_dropout, - layer_norm=cfg.gt.layer_norm, - batch_norm=cfg.gt.batch_norm, - residual=True, - norm_e=cfg.gt.attn.norm_e, - O_e=cfg.gt.attn.O_e, - cfg=cfg.gt, - )) - # layers = [] - - self.layers = torch.nn.Sequential(*layers) - - - - self.layer_post_mp = MLP( - new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp, - has_act=False, has_bias=True, cfg=cfg)) - - - - def _apply_index(self, batch): - return batch.x, batch.y - - def forward(self, batch): - batch = self.layers(batch) - - # follow GMAE here and make a final linear projection from the - # hiden dimension to the output dimension - batch = self.layer_post_mp(batch) - - pred, label = self._apply_index(batch) - #print('>>>>>>', pred.size(),label.size()) - return pred, label - @MODELS_REGISTRY.register("GRIT") @@ -133,60 +66,85 @@ class GritTransformer(torch.nn.Module): The proposed GritTransformer (Graph Inductive Bias Transformer) ''' - def __init__(self, dim_in, dim_out): + def __init__(self, args): super().__init__() - self.encoder = FeatureEncoder(dim_in) - dim_in = self.encoder.dim_in - self.ablation = True - self.ablation = False + # ### TODO remove default args not needed #### + # self.input_dim = + # self.hidden_dim = + # self.output_dim = + # self.edge_dim = + # self.num_layers = args.model.num_layers + # self.heads = getattr(args.model, "attention_head", 1) + # self.dropout = getattr(args.model, "dropout", 0.0) + # ### ### + + dim_in = args.model.input_dim + dim_out = args.model.output_dim + dim_inner = args.model.hidden_size + dim_edge = args.model.edge_dim + num_heads = args.model.attention_head + dropout = args.model.dropout + num_layers = args.model.num_layers + + self.encoder = FeatureEncoder( + dim_in, + dim_inner, + args.model.encoder + ) # TODO add args + dim_in = self.encoder.dim_in + - if cfg.posenc_RRWP.enable: + if args.model.posenc_RRWP.enable: + # TODO connect 'register' to local version self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"]\ - (cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner) - rel_pe_dim = cfg.posenc_RRWP.ksteps + (args.model.posenc_RRWP.ksteps, dim_inner) + rel_pe_dim = args.model.posenc_RRWP.ksteps self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"] \ - (rel_pe_dim, cfg.gnn.dim_edge, - pad_to_full_graph=cfg.gt.attn.full_attn, + (rel_pe_dim, dim_edge, + pad_to_full_graph=args.model.gt.attn.full_attn, add_node_attr_as_self_loop=False, fill_value=0. 
) - if cfg.gnn.layers_pre_mp > 0: + if args.model.layers_pre_mp > 0: self.pre_mp = GNNPreMP( - dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) - dim_in = cfg.gnn.dim_inner + dim_in, dim_inner, args.model.layers_pre_mp) + dim_in = dim_inner - assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ + assert args.model.hidden_size == dim_inner == dim_in, \ "The inner and hidden dims must match." - global_model_type = cfg.gt.get('layer_type', "GritTransformer") + global_model_type = args.model.gt.layer_type # global_model_type = "GritTransformer" - + # TODO replace this with local register logic TransformerLayer = register.layer_dict.get(global_model_type) layers = [] - for l in range(cfg.gt.layers): + for ll in range(num_layers): layers.append(TransformerLayer( - in_dim=cfg.gt.dim_hidden, - out_dim=cfg.gt.dim_hidden, - num_heads=cfg.gt.n_heads, - dropout=cfg.gt.dropout, - act=cfg.gnn.act, - attn_dropout=cfg.gt.attn_dropout, - layer_norm=cfg.gt.layer_norm, - batch_norm=cfg.gt.batch_norm, + in_dim=args.model.gt.dim_hidden, + out_dim=args.model.gt.dim_hidden, + num_heads=num_heads, + dropout=dropout, + act=args.model.act, + attn_dropout=args.model.gt.attn_dropout, + layer_norm=args.model.gt.layer_norm, + batch_norm=args.model.gt.batch_norm, residual=True, - norm_e=cfg.gt.attn.norm_e, - O_e=cfg.gt.attn.O_e, - cfg=cfg.gt, + norm_e=args.model.gt.attn.norm_e, + O_e=args.model.gt.attn.O_e, + cfg=args.model.gt, )) - # layers = [] - self.layers = torch.nn.Sequential(*layers) - GNNHead = register.head_dict[cfg.gnn.head] - self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) + self.layers = nn.Sequential(*layers) + + self.decoder = nn.Sequential( + nn.Linear(dim_inner, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, dim_out), + ) def forward(self, batch): for module in self.children(): From e8281ac03c1f97fed0b57419b96665fc4fb90020 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:36 -0500 Subject: [PATCH 03/26] collect model components and replace old register method Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 346 +++++++++++++++++++++ gridfm_graphkit/models/grit_transformer.py | 79 +++-- gridfm_graphkit/models/rrwp_encoder.py | 192 ++++++++++++ 3 files changed, 583 insertions(+), 34 deletions(-) create mode 100644 gridfm_graphkit/models/grit_layer.py create mode 100644 gridfm_graphkit/models/rrwp_encoder.py diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py new file mode 100644 index 0000000..dc6cf97 --- /dev/null +++ b/gridfm_graphkit/models/grit_layer.py @@ -0,0 +1,346 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_geometric as pyg +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter, scatter_max, scatter_add + +from grit.utils import negate_edge_index +from torch_geometric.graphgym.register import * +import opt_einsum as oe + +from yacs.config import CfgNode as CN + +import warnings + +def pyg_softmax(src, index, num_nodes=None): + r"""Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + + Args: + src (Tensor): The source tensor. 
+ index (LongTensor): The indices of elements for applying the softmax. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + + :rtype: :class:`Tensor` + """ + + num_nodes = maybe_num_nodes(index, num_nodes) + + out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index] + out = out.exp() + out = out / ( + scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16) + + return out + + + +class MultiHeadAttentionLayerGritSparse(nn.Module): + """ + Proposed Attention Computation for GRIT + """ + + def __init__(self, in_dim, out_dim, num_heads, use_bias, + clamp=5., dropout=0., act=None, + edge_enhance=True, + sqrt_relu=False, + signed_sqrt=True, + cfg=CN(), + **kwargs): + super().__init__() + + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = nn.Dropout(dropout) + self.clamp = np.abs(clamp) if clamp is not None else None + self.edge_enhance = edge_enhance + + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + self.E = nn.Linear(in_dim, out_dim * num_heads * 2, bias=True) + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + nn.init.xavier_normal_(self.Q.weight) + nn.init.xavier_normal_(self.K.weight) + nn.init.xavier_normal_(self.E.weight) + nn.init.xavier_normal_(self.V.weight) + + self.Aw = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True) + nn.init.xavier_normal_(self.Aw) + + if act is None: + self.act = nn.Identity() + else: + self.act = act_dict[act]() + + if self.edge_enhance: + self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True) + nn.init.xavier_normal_(self.VeRow) + + def propagate_attention(self, batch): + src = batch.K_h[batch.edge_index[0]] # (num relative) x num_heads x out_dim + dest = batch.Q_h[batch.edge_index[1]] # (num relative) x num_heads x out_dim + score = src + dest # element-wise multiplication + + if batch.get("E", None) is not None: + batch.E = batch.E.view(-1, self.num_heads, self.out_dim * 2) + E_w, E_b = batch.E[:, :, :self.out_dim], batch.E[:, :, self.out_dim:] + # (num relative) x num_heads x out_dim + score = score * E_w + score = torch.sqrt(torch.relu(score)) - torch.sqrt(torch.relu(-score)) + score = score + E_b + + score = self.act(score) + e_t = score + + # output edge + if batch.get("E", None) is not None: + batch.wE = score.flatten(1) + + # final attn + score = oe.contract("ehd, dhc->ehc", score, self.Aw, backend="torch") + if self.clamp is not None: + score = torch.clamp(score, min=-self.clamp, max=self.clamp) + + raw_attn = score + score = pyg_softmax(score, batch.edge_index[1]) # (num relative) x num_heads x 1 + score = self.dropout(score) + batch.attn = score + + # Aggregate with Attn-Score + msg = batch.V_h[batch.edge_index[0]] * score # (num relative) x num_heads x out_dim + batch.wV = torch.zeros_like(batch.V_h) # (num nodes in batch) x num_heads x out_dim + scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce='add') + + if self.edge_enhance and batch.E is not None: + rowV = scatter(e_t * score, batch.edge_index[1], dim=0, reduce="add") + rowV = oe.contract("nhd, dhc -> nhc", rowV, self.VeRow, backend="torch") + batch.wV = batch.wV + rowV + + def forward(self, batch): + Q_h = self.Q(batch.x) + K_h = self.K(batch.x) + + V_h = self.V(batch.x) + if batch.get("edge_attr", None) is not None: + batch.E = self.E(batch.edge_attr) + else: + batch.E = None + + batch.Q_h = 
Q_h.view(-1, self.num_heads, self.out_dim) + batch.K_h = K_h.view(-1, self.num_heads, self.out_dim) + batch.V_h = V_h.view(-1, self.num_heads, self.out_dim) + self.propagate_attention(batch) + h_out = batch.wV + e_out = batch.get('wE', None) + + return h_out, e_out + + +@register_layer("GritTransformer") +class GritTransformerLayer(nn.Module): + """ + Proposed Transformer Layer for GRIT + """ + def __init__(self, in_dim, out_dim, num_heads, + dropout=0.0, + attn_dropout=0.0, + layer_norm=False, batch_norm=True, + residual=True, + act='relu', + norm_e=True, + O_e=True, + cfg=dict(), + **kwargs): + super().__init__() + + self.debug = False + self.in_channels = in_dim + self.out_channels = out_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + # ------- + self.update_e = cfg.get("update_e", True) + self.bn_momentum = cfg.bn_momentum + self.bn_no_runner = cfg.bn_no_runner + self.rezero = cfg.get("rezero", False) + + self.act = act_dict[act]() if act is not None else nn.Identity() + if cfg.get("attn", None) is None: + cfg.attn = dict() + self.use_attn = cfg.attn.get("use", True) + # self.sigmoid_deg = cfg.attn.get("sigmoid_deg", False) + self.deg_scaler = cfg.attn.get("deg_scaler", True) + + self.attention = MultiHeadAttentionLayerGritSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=cfg.attn.get("use_bias", False), + dropout=attn_dropout, + clamp=cfg.attn.get("clamp", 5.), + act=cfg.attn.get("act", "relu"), + edge_enhance=cfg.attn.get("edge_enhance", True), + sqrt_relu=cfg.attn.get("sqrt_relu", False), + signed_sqrt=cfg.attn.get("signed_sqrt", False), + scaled_attn =cfg.attn.get("scaled_attn", False), + no_qk=cfg.attn.get("no_qk", False), + ) + + if cfg.attn.get('graphormer_attn', False): + self.attention = MultiHeadAttentionLayerGraphormerSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=cfg.attn.get("use_bias", False), + dropout=attn_dropout, + clamp=cfg.attn.get("clamp", 5.), + act=cfg.attn.get("act", "relu"), + edge_enhance=True, + sqrt_relu=cfg.attn.get("sqrt_relu", False), + signed_sqrt=cfg.attn.get("signed_sqrt", False), + scaled_attn =cfg.attn.get("scaled_attn", False), + no_qk=cfg.attn.get("no_qk", False), + ) + + + + self.O_h = nn.Linear(out_dim//num_heads * num_heads, out_dim) + if O_e: + self.O_e = nn.Linear(out_dim//num_heads * num_heads, out_dim) + else: + self.O_e = nn.Identity() + + # -------- Deg Scaler Option ------ + + if self.deg_scaler: + self.deg_coef = nn.Parameter(torch.zeros(1, out_dim//num_heads * num_heads, 2)) + nn.init.xavier_normal_(self.deg_coef) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + self.layer_norm1_e = nn.LayerNorm(out_dim) if norm_e else nn.Identity() + + if self.batch_norm: + # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) + self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) + self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) if norm_e else nn.Identity() + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) + self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + 
self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) + + if self.rezero: + self.alpha1_h = nn.Parameter(torch.zeros(1,1)) + self.alpha2_h = nn.Parameter(torch.zeros(1,1)) + self.alpha1_e = nn.Parameter(torch.zeros(1,1)) + + def forward(self, batch): + h = batch.x + num_nodes = batch.num_nodes + log_deg = get_log_deg(batch) + + h_in1 = h # for first residual connection + e_in1 = batch.get("edge_attr", None) + e = None + # multi-head attention out + + h_attn_out, e_attn_out = self.attention(batch) + + h = h_attn_out.view(num_nodes, -1) + h = F.dropout(h, self.dropout, training=self.training) + + # degree scaler + if self.deg_scaler: + h = torch.stack([h, h * log_deg], dim=-1) + h = (h * self.deg_coef).sum(dim=-1) + + h = self.O_h(h) + if e_attn_out is not None: + e = e_attn_out.flatten(1) + e = F.dropout(e, self.dropout, training=self.training) + e = self.O_e(e) + + if self.residual: + if self.rezero: h = h * self.alpha1_h + h = h_in1 + h # residual connection + if e is not None: + if self.rezero: e = e * self.alpha1_e + e = e + e_in1 + + if self.layer_norm: + h = self.layer_norm1_h(h) + if e is not None: e = self.layer_norm1_e(e) + + if self.batch_norm: + h = self.batch_norm1_h(h) + if e is not None: e = self.batch_norm1_e(e) + + # FFN for h + h_in2 = h # for second residual connection + h = self.FFN_h_layer1(h) + h = self.act(h) + h = F.dropout(h, self.dropout, training=self.training) + h = self.FFN_h_layer2(h) + + if self.residual: + if self.rezero: h = h * self.alpha2_h + h = h_in2 + h # residual connection + + if self.layer_norm: + h = self.layer_norm2_h(h) + + if self.batch_norm: + h = self.batch_norm2_h(h) + + batch.x = h + if self.update_e: + batch.edge_attr = e + else: + batch.edge_attr = e_in1 + + return batch + + def __repr__(self): + return '{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]'.format( + self.__class__.__name__, + self.in_channels, + self.out_channels, self.num_heads, self.residual, + super().__repr__(), + ) + + +@torch.no_grad() +def get_log_deg(batch): + if "log_deg" in batch: + log_deg = batch.log_deg + elif "deg" in batch: + deg = batch.deg + log_deg = torch.log(deg + 1).unsqueeze(-1) + else: + warnings.warn("Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs") + deg = pyg.utils.degree(batch.edge_index[1], + num_nodes=batch.num_nodes, + dtype=torch.float + ) + log_deg = torch.log(deg + 1) + log_deg = log_deg.view(batch.num_nodes, 1) + return log_deg + + diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index b09f527..e3c6047 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -2,12 +2,42 @@ from torch import nn import torch +from rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder +from grit_layer import GritTransformerLayer + +# TODO verify use from torch_geometric.graphgym.models.gnn import GNNPreMP from torch_geometric.graphgym.models.layer import (new_layer_config, BatchNorm1dNode) from torch_geometric.graphgym.models.layer import new_layer_config, MLP +class LinearNodeEncoder(torch.nn.Module): + def __init__(self, emb_dim): + super().__init__() + + self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim) + + def forward(self, batch): + batch.x = self.encoder(batch.x) + return batch + +class LinearEdgeEncoder(torch.nn.Module): + def __init__(self, emb_dim): + super().__init__() + if cfg.dataset.name in 
['MNIST', 'CIFAR10']: + self.in_dim = 1 + elif cfg.dataset.name.startswith('attributed_triangle-'): + self.in_dim = 2 + else: + raise ValueError("Input edge feature dim is required to be hardset " + "or refactored to use a cfg option.") + self.encoder = torch.nn.Linear(self.in_dim, emb_dim) + + def forward(self, batch): + batch.edge_attr = self.encoder(batch.edge_attr.view(-1, self.in_dim)) + return batch + class FeatureEncoder(torch.nn.Module): """ @@ -16,9 +46,6 @@ class FeatureEncoder(torch.nn.Module): Args: dim_in (int): Input feature dimension - - TODO replace 'register' with local version of it - """ def __init__( self, @@ -30,9 +57,7 @@ def __init__( self.dim_in = dim_in if args.node_encoder: # Encode integer node features via nn.Embeddings - NodeEncoder = register.node_encoder_dict[ - args.node_encoder_name] - self.node_encoder = NodeEncoder(dim_inner) + self.node_encoder = LinearNodeEncoder(dim_inner) if args.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode( new_layer_config(dim_inner, -1, -1, has_act=False, @@ -46,9 +71,7 @@ def __init__( else: dim_edge = dim_inner # Encode integer edge features via nn.Embeddings - EdgeEncoder = register.edge_encoder_dict[ - cfg.dataset.edge_encoder_name] - self.edge_encoder = EdgeEncoder(dim_edge) + self.edge_encoder = LinearEdgeEncoder(dim_edge) if cfg.dataset.edge_encoder_bn: self.edge_encoder_bn = BatchNorm1dNode( new_layer_config(dim_edge, -1, -1, has_act=False, @@ -65,19 +88,9 @@ class GritTransformer(torch.nn.Module): ''' The proposed GritTransformer (Graph Inductive Bias Transformer) ''' - def __init__(self, args): super().__init__() - # ### TODO remove default args not needed #### - # self.input_dim = - # self.hidden_dim = - # self.output_dim = - # self.edge_dim = - # self.num_layers = args.model.num_layers - # self.heads = getattr(args.model, "attention_head", 1) - # self.dropout = getattr(args.model, "dropout", 0.0) - # ### ### dim_in = args.model.input_dim dim_out = args.model.output_dim @@ -96,16 +109,19 @@ def __init__(self, args): if args.model.posenc_RRWP.enable: - # TODO connect 'register' to local version - self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"]\ - (args.model.posenc_RRWP.ksteps, dim_inner) + + self.rrwp_abs_encoder = RRWPLinearNodeEncoder( + args.model.posenc_RRWP.ksteps, + dim_inner + ) rel_pe_dim = args.model.posenc_RRWP.ksteps - self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"] \ - (rel_pe_dim, dim_edge, - pad_to_full_graph=args.model.gt.attn.full_attn, - add_node_attr_as_self_loop=False, - fill_value=0. - ) + self.rrwp_rel_encoder = RRWPLinearNodeEncoder( + rel_pe_dim, + dim_edge, + pad_to_full_graph=args.model.gt.attn.full_attn, + add_node_attr_as_self_loop=False, + fill_value=0. + ) if args.model.layers_pre_mp > 0: @@ -116,14 +132,9 @@ def __init__(self, args): assert args.model.hidden_size == dim_inner == dim_in, \ "The inner and hidden dims must match." 
- global_model_type = args.model.gt.layer_type - # global_model_type = "GritTransformer" - # TODO replace this with local register logic - TransformerLayer = register.layer_dict.get(global_model_type) - layers = [] for ll in range(num_layers): - layers.append(TransformerLayer( + layers.append(GritTransformerLayer( in_dim=args.model.gt.dim_hidden, out_dim=args.model.gt.dim_hidden, num_heads=num_heads, diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py new file mode 100644 index 0000000..f98118e --- /dev/null +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -0,0 +1,192 @@ +''' + The RRWP encoder for GRIT (ours) +''' +import torch +from torch import nn +from torch.nn import functional as F +from ogb.utils.features import get_bond_feature_dims +import torch_sparse + +import torch_geometric as pyg +from torch_geometric.graphgym.register import ( + register_edge_encoder, + register_node_encoder, +) + +from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops +from torch_scatter import scatter +import warnings + +def full_edge_index(edge_index, batch=None): + """ + Retunr the Full batched sparse adjacency matrices given by edge indices. + Returns batched sparse adjacency matrices with exactly those edges that + are not in the input `edge_index` while ignoring self-loops. + Implementation inspired by `torch_geometric.utils.to_dense_adj` + Args: + edge_index: The edge indices. + batch: Batch vector, which assigns each node to a specific example. + Returns: + Complementary edge index. + """ + + if batch is None: + batch = edge_index.new_zeros(edge_index.max().item() + 1) + + batch_size = batch.max().item() + 1 + one = batch.new_ones(batch.size(0)) + num_nodes = scatter(one, batch, + dim=0, dim_size=batch_size, reduce='add') + cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) + + negative_index_list = [] + for i in range(batch_size): + n = num_nodes[i].item() + size = [n, n] + adj = torch.ones(size, dtype=torch.short, + device=edge_index.device) + + adj = adj.view(size) + _edge_index = adj.nonzero(as_tuple=False).t().contiguous() + # _edge_index, _ = remove_self_loops(_edge_index) + negative_index_list.append(_edge_index + cum_nodes[i]) + + edge_index_full = torch.cat(negative_index_list, dim=1).contiguous() + return edge_index_full + + + +class RRWPLinearNodeEncoder(torch.nn.Module): + """ + FC_1(RRWP) + FC_2 (Node-attr) + note: FC_2 is given by the Typedict encoder of node-attr in some cases + Parameters: + num_classes - the number of classes for the embedding mapping to learn + """ + def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="rrwp"): + super().__init__() + self.batchnorm = batchnorm + self.layernorm = layernorm + self.name = pe_name + + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + # Encode just the first dimension if more exist + rrwp = batch[f"{self.name}"] + rrwp = self.fc(rrwp) + + if self.batchnorm: + rrwp = self.bn(rrwp) + + if self.layernorm: + rrwp = self.ln(rrwp) + + if "x" in batch: + batch.x = batch.x + rrwp + else: + batch.x = rrwp + + return batch + + +class RRWPLinearEdgeEncoder(torch.nn.Module): + ''' + Merge RRWP with given edge-attr and Zero-padding to all pairs of node + FC_1(RRWP) + FC_2(edge-attr) + - FC_2 given by the TypedictEncoder 
in same cases + - Zero-padding for non-existing edges in fully-connected graph + - (optional) add node-attr as the E_{i,i}'s attr + note: assuming node-attr and edge-attr is with the same dimension after Encoders + ''' + def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False, + pad_to_full_graph=True, fill_value=0., + add_node_attr_as_self_loop=False, + overwrite_old_attr=False): + super().__init__() + # note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info + self.emb_dim = emb_dim + self.out_dim = out_dim + self.add_node_attr_as_self_loop = add_node_attr_as_self_loop + self.overwrite_old_attr=overwrite_old_attr # remove the old edge-attr + + self.batchnorm = batchnorm + self.layernorm = layernorm + if self.batchnorm or self.layernorm: + warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") + + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + self.pad_to_full_graph = pad_to_full_graph + self.fill_value = 0. + + padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value + self.register_buffer("padding", padding) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + rrwp_idx = batch.rrwp_index + rrwp_val = batch.rrwp_val + edge_index = batch.edge_index + edge_attr = batch.edge_attr + rrwp_val = self.fc(rrwp_val) + + if edge_attr is None: + edge_attr = edge_index.new_zeros(edge_index.size(1), rrwp_val.size(1)) + # zero padding for non-existing edges + + if self.overwrite_old_attr: + out_idx, out_val = rrwp_idx, rrwp_val + else: + # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) + edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
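+            # The self-loops added above give every node an explicit (i, i)
+            # edge with zero attributes; the coalesce below merges them with
+            # the RRWP relative encodings.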
+ + #print('-->>>>', edge_attr.size(), rrwp_val.size()) + out_idx, out_val = torch_sparse.coalesce( + torch.cat([edge_index, rrwp_idx], dim=1), + torch.cat([edge_attr, rrwp_val], dim=0), + batch.num_nodes, batch.num_nodes, + op="add" + ) + + + if self.pad_to_full_graph: + edge_index_full = full_edge_index(out_idx, batch=batch.batch) + edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1) + # zero padding to fully-connected graphs + out_idx = torch.cat([out_idx, edge_index_full], dim=1) + out_val = torch.cat([out_val, edge_attr_pad], dim=0) + out_idx, out_val = torch_sparse.coalesce( + out_idx, out_val, batch.num_nodes, batch.num_nodes, + op="add" + ) + + if self.batchnorm: + out_val = self.bn(out_val) + + if self.layernorm: + out_val = self.ln(out_val) + + + batch.edge_index, batch.edge_attr = out_idx, out_val + return batch + + def __repr__(self): + return f"{self.__class__.__name__}" \ + f"(pad_to_full_graph={self.pad_to_full_graph}," \ + f"fill_value={self.fill_value}," \ + f"{self.fc.__repr__()})" + + + From a67e52285b69d5fae828b5c6231a586daa6d629e Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:36 -0500 Subject: [PATCH 04/26] clean up imported layers and encoders Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 6 +-- gridfm_graphkit/models/grit_layer.py | 3 -- gridfm_graphkit/models/grit_transformer.py | 50 +++++++++++----------- gridfm_graphkit/models/rrwp_encoder.py | 14 ++---- 4 files changed, 29 insertions(+), 44 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 904d6dc..dc6f3a1 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -35,8 +35,7 @@ model: num_layers: 10 output_dim: 6 pe_dim: 20 - type: GRIT #GPSTransformer # - layers_pre_mp: 0 + type: GRIT act: relu encoder: node_encoder: True @@ -45,10 +44,7 @@ model: node_encoder_bn: True gt: layer_type: GritTransformer - # layers: 10 - # n_heads: 8 dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` - # dropout: 0.0 layer_norm: False batch_norm: True update_e: True diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index dc6cf97..b477980 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -6,8 +6,6 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_scatter import scatter, scatter_max, scatter_add -from grit.utils import negate_edge_index -from torch_geometric.graphgym.register import * import opt_einsum as oe from yacs.config import CfgNode as CN @@ -141,7 +139,6 @@ def forward(self, batch): return h_out, e_out -@register_layer("GritTransformer") class GritTransformerLayer(nn.Module): """ Proposed Transformer Layer for GRIT diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index e3c6047..715c25f 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,15 +1,29 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY -from torch import nn import torch - +from torch import nn from rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder from grit_layer import GritTransformerLayer -# TODO verify use -from torch_geometric.graphgym.models.gnn import GNNPreMP -from torch_geometric.graphgym.models.layer import (new_layer_config, - BatchNorm1dNode) -from 
torch_geometric.graphgym.models.layer import new_layer_config, MLP + + +class BatchNorm1dNode(torch.nn.Module): + r"""A batch normalization layer for node-level features. + + Args: + dim_in (int): BatchNorm input dimension. + TODO fill in comments + """ + def __init__(self, dim_in, eps, momentum): + super().__init__() + self.bn = torch.nn.BatchNorm1d( + dim_in, + eps=eps, + momentum=momentum, + ) + + def forward(self, batch): + batch.x = self.bn(batch.x) + return batch class LinearNodeEncoder(torch.nn.Module): @@ -59,23 +73,16 @@ def __init__( # Encode integer node features via nn.Embeddings self.node_encoder = LinearNodeEncoder(dim_inner) if args.node_encoder_bn: - self.node_encoder_bn = BatchNorm1dNode( - new_layer_config(dim_inner, -1, -1, has_act=False, - has_bias=False, cfg=cfg)) + self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) # Update dim_in to reflect the new dimension fo the node features self.dim_in = dim_inner if args.edge_encoder: - # Hard-limit max edge dim for PNA. - if 'PNA' in args.model.gt.layer_type: # TODO remove condition if PNA not needed - dim_edge = min(128, dim_inner) - else: - dim_edge = dim_inner + + dim_edge = dim_inner # Encode integer edge features via nn.Embeddings self.edge_encoder = LinearEdgeEncoder(dim_edge) if cfg.dataset.edge_encoder_bn: - self.edge_encoder_bn = BatchNorm1dNode( - new_layer_config(dim_edge, -1, -1, has_act=False, - has_bias=False, cfg=cfg)) + self.edge_encoder_bn = BatchNorm1dNode(dim_edge, 1e-5, 0.1) def forward(self, batch): for module in self.children(): @@ -107,7 +114,6 @@ def __init__(self, args): ) # TODO add args dim_in = self.encoder.dim_in - if args.model.posenc_RRWP.enable: self.rrwp_abs_encoder = RRWPLinearNodeEncoder( @@ -123,12 +129,6 @@ def __init__(self, args): fill_value=0. ) - - if args.model.layers_pre_mp > 0: - self.pre_mp = GNNPreMP( - dim_in, dim_inner, args.model.layers_pre_mp) - dim_in = dim_inner - assert args.model.hidden_size == dim_inner == dim_in, \ "The inner and hidden dims must match." diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index f98118e..b73e463 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -1,25 +1,18 @@ -''' +""" The RRWP encoder for GRIT (ours) -''' +""" import torch from torch import nn from torch.nn import functional as F -from ogb.utils.features import get_bond_feature_dims import torch_sparse -import torch_geometric as pyg -from torch_geometric.graphgym.register import ( - register_edge_encoder, - register_node_encoder, -) - from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops from torch_scatter import scatter import warnings def full_edge_index(edge_index, batch=None): """ - Retunr the Full batched sparse adjacency matrices given by edge indices. + Return the Full batched sparse adjacency matrices given by edge indices. Returns batched sparse adjacency matrices with exactly those edges that are not in the input `edge_index` while ignoring self-loops. Implementation inspired by `torch_geometric.utils.to_dense_adj` @@ -152,7 +145,6 @@ def forward(self, batch): # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
- #print('-->>>>', edge_attr.size(), rrwp_val.size()) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), From 6966f5ffc56a2384a2cb730324aa1bbf8269efb7 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:37 -0500 Subject: [PATCH 05/26] flow in basic structure for RRWP calculation Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 13 +- gridfm_graphkit/datasets/posenc_stats.py | 423 ++++++++++++++++++ .../datasets/powergrid_datamodule.py | 14 + gridfm_graphkit/datasets/rrwp.py | 103 +++++ 4 files changed, 547 insertions(+), 6 deletions(-) create mode 100644 gridfm_graphkit/datasets/posenc_stats.py create mode 100644 gridfm_graphkit/datasets/rrwp.py diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index dc6f3a1..05f31be 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -26,6 +26,12 @@ data: test_ratio: 0.1 val_ratio: 0.1 workers: 4 + posenc_RRWP: # TODO maybe better with data section... + enable: True + ksteps: 21 + add_identity: True + add_node_attr: False + add_inverse: False model: attention_head: 8 dropout: 0.1 @@ -57,12 +63,7 @@ model: O_e: True norm_e: True signed_sqrt: True - posenc_RRWP: - enable: True - ksteps: 21 - add_identity: True - add_node_attr: False - add_inverse: False + optimizer: beta1: 0.9 beta2: 0.999 diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py new file mode 100644 index 0000000..492a0a6 --- /dev/null +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -0,0 +1,423 @@ +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F + +from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix, + to_undirected, to_dense_adj) +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter_add +from functools import partial +from gridfm_graphkit.datasets.rrwp import add_full_rrwp + + +def compute_posenc_stats(data, pe_types, is_undirected, cfg): + """Precompute positional encodings for the given graph. + Supported PE statistics to precompute, selected by `pe_types`: + 'LapPE': Laplacian eigen-decomposition. + 'RWSE': Random walk landing probabilities (diagonals of RW matrices). + 'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED) + 'HKdiagSE': Diagonals of heat kernel diffusion. + 'ElstaticSE': Kernel based on the electrostatic interaction between nodes. + 'RRWP': Relative Random Walk Probabilities PE (Ours, for GRIT) + Args: + data: PyG graph + pe_types: Positional encoding types to precompute statistics for. + This can also be a combination, e.g. 'eigen+rw_landing' + is_undirected: True if the graph is expected to be undirected + cfg: Main configuration node + + Returns: + Extended PyG Data object. + """ + # Verify PE types. + for t in pe_types: + if t not in ['LapPE', 'EquivStableLapPE', 'SignNet', + 'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']: + raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}") + + # Basic preprocessing of the input graph. + if hasattr(data, 'num_nodes'): + N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa + else: + N = data.x.shape[0] # Number of nodes, including disconnected nodes. 
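+    # The Laplacian normalization setting below is only consumed by the
+    # eigen-decomposition based encodings (LapPE / EquivStableLapPE and the
+    # heat-kernel variants); RWSE and RRWP operate on the raw random-walk matrix.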
+ laplacian_norm_type = cfg.posenc_LapPE.eigen.laplacian_norm.lower() + if laplacian_norm_type == 'none': + laplacian_norm_type = None + if is_undirected: + undir_edge_index = data.edge_index + else: + undir_edge_index = to_undirected(data.edge_index) + + # Eigen values and vectors. + evals, evects = None, None + if 'LapPE' in pe_types or 'EquivStableLapPE' in pe_types: + # Eigen-decomposition with numpy, can be reused for Heat kernels. + L = to_scipy_sparse_matrix( + *get_laplacian(undir_edge_index, normalization=laplacian_norm_type, + num_nodes=N) + ) + evals, evects = np.linalg.eigh(L.toarray()) + + if 'LapPE' in pe_types: + max_freqs=cfg.posenc_LapPE.eigen.max_freqs + eigvec_norm=cfg.posenc_LapPE.eigen.eigvec_norm + elif 'EquivStableLapPE' in pe_types: + max_freqs=cfg.posenc_EquivStableLapPE.eigen.max_freqs + eigvec_norm=cfg.posenc_EquivStableLapPE.eigen.eigvec_norm + + data.EigVals, data.EigVecs = get_lap_decomp_stats( + evals=evals, evects=evects, + max_freqs=max_freqs, + eigvec_norm=eigvec_norm) + + if 'SignNet' in pe_types: + # Eigen-decomposition with numpy for SignNet. + norm_type = cfg.posenc_SignNet.eigen.laplacian_norm.lower() + if norm_type == 'none': + norm_type = None + L = to_scipy_sparse_matrix( + *get_laplacian(undir_edge_index, normalization=norm_type, + num_nodes=N) + ) + evals_sn, evects_sn = np.linalg.eigh(L.toarray()) + data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats( + evals=evals_sn, evects=evects_sn, + max_freqs=cfg.posenc_SignNet.eigen.max_freqs, + eigvec_norm=cfg.posenc_SignNet.eigen.eigvec_norm) + + # Random Walks. + if 'RWSE' in pe_types: + kernel_param = cfg.posenc_RWSE.kernel + if len(kernel_param.times) == 0: + raise ValueError("List of kernel times required for RWSE") + rw_landing = get_rw_landing_probs(ksteps=kernel_param.times, + edge_index=data.edge_index, + num_nodes=N) + data.pestat_RWSE = rw_landing + + # Heat Kernels. + if 'HKdiagSE' in pe_types or 'HKfullPE' in pe_types: + # Get the eigenvalues and eigenvectors of the regular Laplacian, + # if they have not yet been computed for 'eigen'. + if laplacian_norm_type is not None or evals is None or evects is None: + L_heat = to_scipy_sparse_matrix( + *get_laplacian(undir_edge_index, normalization=None, num_nodes=N) + ) + evals_heat, evects_heat = np.linalg.eigh(L_heat.toarray()) + else: + evals_heat, evects_heat = evals, evects + evals_heat = torch.from_numpy(evals_heat) + evects_heat = torch.from_numpy(evects_heat) + + # Get the full heat kernels. + if 'HKfullPE' in pe_types: + # The heat kernels can't be stored in the Data object without + # additional padding because in PyG's collation of the graphs the + # sizes of tensors must match except in dimension 0. Do this when + # the full heat kernels are actually used downstream by an Encoder. + raise NotImplementedError() + # heat_kernels, hk_diag = get_heat_kernels(evects_heat, evals_heat, + # kernel_times=kernel_param.times) + # data.pestat_HKdiagSE = hk_diag + # Get heat kernel diagonals in more efficient way. + if 'HKdiagSE' in pe_types: + kernel_param = cfg.posenc_HKdiagSE.kernel + if len(kernel_param.times) == 0: + raise ValueError("Diffusion times are required for heat kernel") + hk_diag = get_heat_kernels_diag(evects_heat, evals_heat, + kernel_times=kernel_param.times, + space_dim=0) + data.pestat_HKdiagSE = hk_diag + + # Electrostatic interaction inspired kernel. 
+ if 'ElstaticSE' in pe_types: + elstatic = get_electrostatic_function_encoding(undir_edge_index, N) + data.pestat_ElstaticSE = elstatic + + if 'RRWP' in pe_types: + param = cfg.posenc_RRWP + transform = partial(add_full_rrwp, + walk_length=param.ksteps, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True, + spd=param.spd, # by default False + ) + data = transform(data) + + return data + + +def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm='L2'): + """Compute Laplacian eigen-decomposition-based PE stats of the given graph. + + Args: + evals, evects: Precomputed eigen-decomposition + max_freqs: Maximum number of top smallest frequencies / eigenvecs to use + eigvec_norm: Normalization for the eigen vectors of the Laplacian + Returns: + Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node + Tensor (num_nodes, max_freqs) of eigenvector values per node + """ + N = len(evals) # Number of nodes, including disconnected nodes. + + # Keep up to the maximum desired number of frequencies. + idx = evals.argsort()[:max_freqs] + evals, evects = evals[idx], np.real(evects[:, idx]) + evals = torch.from_numpy(np.real(evals)).clamp_min(0) + + # Normalize and pad eigen vectors. + evects = torch.from_numpy(evects).float() + evects = eigvec_normalizer(evects, evals, normalization=eigvec_norm) + if N < max_freqs: + EigVecs = F.pad(evects, (0, max_freqs - N), value=float('nan')) + else: + EigVecs = evects + + # Pad and save eigenvalues. + if N < max_freqs: + EigVals = F.pad(evals, (0, max_freqs - N), value=float('nan')).unsqueeze(0) + else: + EigVals = evals.unsqueeze(0) + EigVals = EigVals.repeat(N, 1).unsqueeze(2) + + return EigVals, EigVecs + + +def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, + num_nodes=None, space_dim=0): + """Compute Random Walk landing probabilities for given list of K steps. + + Args: + ksteps: List of k-steps for which to compute the RW landings + edge_index: PyG sparse representation of the graph + edge_weight: (optional) Edge weights + num_nodes: (optional) Number of nodes in the graph + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the random-walk diagonal by a factor `k^(space_dim/2)`. + In euclidean space, this correction means that the height of + the gaussian distribution stays almost constant across the number of + steps, if `space_dim` is the dimension of the euclidean space. + + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + if edge_weight is None: + edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) + num_nodes = maybe_num_nodes(edge_index, num_nodes) + source, dest = edge_index[0], edge_index[1] + deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. + deg_inv = deg.pow(-1.) + deg_inv.masked_fill_(deg_inv == float('inf'), 0) + + if edge_index.numel() == 0: + P = edge_index.new_zeros((1, num_nodes, num_nodes)) + else: + # P = D^-1 * A + P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes) + rws = [] + if ksteps == list(range(min(ksteps), max(ksteps) + 1)): + # Efficient way if ksteps are a consecutive sequence (most of the time the case) + Pk = P.clone().detach().matrix_power(min(ksteps)) + for k in range(min(ksteps), max(ksteps) + 1): + rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + Pk = Pk @ P + else: + # Explicitly raising P to power k for each k \in ksteps. 
+ for k in ksteps: + rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) + + return rw_landing + + +def get_heat_kernels_diag(evects, evals, kernel_times=[], space_dim=0): + """Compute Heat kernel diagonal. + + This is a continuous function that represents a Gaussian in the Euclidean + space, and is the solution to the diffusion equation. + The random-walk diagonal should converge to this. + + Args: + evects: Eigenvectors of the Laplacian matrix + evals: Eigenvalues of the Laplacian matrix + kernel_times: Time for the diffusion. Analogous to the k-steps in random + walk. The time is equivalent to the variance of the kernel. + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the diffusion diagonal by a factor `t^(space_dim/2)`. In + euclidean space, this correction means that the height of the + gaussian stays constant across time, if `space_dim` is the dimension + of the euclidean space. + + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + heat_kernels_diag = [] + if len(kernel_times) > 0: + evects = F.normalize(evects, p=2., dim=0) + + # Remove eigenvalues == 0 from the computation of the heat kernel + idx_remove = evals < 1e-8 + evals = evals[~idx_remove] + evects = evects[:, ~idx_remove] + + # Change the shapes for the computations + evals = evals.unsqueeze(-1) # lambda_{i, ..., ...} + evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node + + # Compute the heat kernels diagonal only for each time + eigvec_mul = evects ** 2 + for t in kernel_times: + # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j} * phi_{i, j}) + this_kernel = torch.sum(torch.exp(-t * evals) * eigvec_mul, + dim=0, keepdim=False) + + # Multiply by `t` to stabilize the values, since the gaussian height + # is proportional to `1/t` + heat_kernels_diag.append(this_kernel * (t ** (space_dim / 2))) + heat_kernels_diag = torch.stack(heat_kernels_diag, dim=0).transpose(0, 1) + + return heat_kernels_diag + + +def get_heat_kernels(evects, evals, kernel_times=[]): + """Compute full Heat diffusion kernels. + + Args: + evects: Eigenvectors of the Laplacian matrix + evals: Eigenvalues of the Laplacian matrix + kernel_times: Time for the diffusion. Analogous to the k-steps in random + walk. The time is equivalent to the variance of the kernel. + """ + heat_kernels, rw_landing = [], [] + if len(kernel_times) > 0: + evects = F.normalize(evects, p=2., dim=0) + + # Remove eigenvalues == 0 from the computation of the heat kernel + idx_remove = evals < 1e-8 + evals = evals[~idx_remove] + evects = evects[:, ~idx_remove] + + # Change the shapes for the computations + evals = evals.unsqueeze(-1).unsqueeze(-1) # lambda_{i, ..., ...} + evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node + + # Compute the heat kernels for each time + eigvec_mul = (evects.unsqueeze(2) * evects.unsqueeze(1)) # (phi_{i, j1, ...} * phi_{i, ..., j2}) + for t in kernel_times: + # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j1, ...} * phi_{i, ..., j2}) + heat_kernels.append( + torch.sum(torch.exp(-t * evals) * eigvec_mul, + dim=0, keepdim=False) + ) + + heat_kernels = torch.stack(heat_kernels, dim=0) # (Num kernel times) x (Num nodes) x (Num nodes) + + # Take the diagonal of each heat kernel, + # i.e. 
the landing probability of each of the random walks + rw_landing = torch.diagonal(heat_kernels, dim1=-2, dim2=-1).transpose(0, 1) # (Num nodes) x (Num kernel times) + + return heat_kernels, rw_landing + + +def get_electrostatic_function_encoding(edge_index, num_nodes): + """Kernel based on the electrostatic interaction between nodes. + """ + L = to_scipy_sparse_matrix( + *get_laplacian(edge_index, normalization=None, num_nodes=num_nodes) + ).todense() + L = torch.as_tensor(L) + Dinv = torch.eye(L.shape[0]) * (L.diag() ** -1) + A = deepcopy(L).abs() + A.fill_diagonal_(0) + DinvA = Dinv.matmul(A) + + electrostatic = torch.pinverse(L) + electrostatic = electrostatic - electrostatic.diag() + green_encoding = torch.stack([ + electrostatic.min(dim=0)[0], # Min of Vi -> j + electrostatic.max(dim=0)[0], # Max of Vi -> j + electrostatic.mean(dim=0), # Mean of Vi -> j + electrostatic.std(dim=0), # Std of Vi -> j + electrostatic.min(dim=1)[0], # Min of Vj -> i + electrostatic.max(dim=0)[0], # Max of Vj -> i + electrostatic.mean(dim=1), # Mean of Vj -> i + electrostatic.std(dim=1), # Std of Vj -> i + (DinvA * electrostatic).sum(dim=0), # Mean of interaction on direct neighbour + (DinvA * electrostatic).sum(dim=1), # Mean of interaction from direct neighbour + ], dim=1) + + return green_encoding + + +def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12): + """ + Implement different eigenvector normalizations. + """ + + EigVals = EigVals.unsqueeze(0) + + if normalization == "L1": + # L1 normalization: eigvec / sum(abs(eigvec)) + denom = EigVecs.norm(p=1, dim=0, keepdim=True) + + elif normalization == "L2": + # L2 normalization: eigvec / sqrt(sum(eigvec^2)) + denom = EigVecs.norm(p=2, dim=0, keepdim=True) + + elif normalization == "abs-max": + # AbsMax normalization: eigvec / max|eigvec| + denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values + + elif normalization == "wavelength": + # AbsMax normalization, followed by wavelength multiplication: + # eigvec * pi / (2 * max|eigvec| * sqrt(eigval)) + denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values + eigval_denom = torch.sqrt(EigVals) + eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 + denom = denom * eigval_denom * 2 / np.pi + + elif normalization == "wavelength-asin": + # AbsMax normalization, followed by arcsin and wavelength multiplication: + # arcsin(eigvec / max|eigvec|) / sqrt(eigval) + denom_temp = torch.max(EigVecs.abs(), dim=0, keepdim=True).values.clamp_min(eps).expand_as(EigVecs) + EigVecs = torch.asin(EigVecs / denom_temp) + eigval_denom = torch.sqrt(EigVals) + eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 + denom = eigval_denom + + elif normalization == "wavelength-soft": + # AbsSoftmax normalization, followed by wavelength multiplication: + # eigvec / (softmax|eigvec| * sqrt(eigval)) + denom = (F.softmax(EigVecs.abs(), dim=0) * EigVecs.abs()).sum(dim=0, keepdim=True) + eigval_denom = torch.sqrt(EigVals) + eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 + denom = denom * eigval_denom + + else: + raise ValueError(f"Unsupported normalization `{normalization}`") + + denom = denom.clamp_min(eps).expand_as(EigVecs) + EigVecs = EigVecs / denom + + return EigVecs + +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import Data, HeteroData + +class ComputePosencStat(BaseTransform): + def __init__(self, pe_types, is_undirected, cfg): + self.pe_types = pe_types + self.is_undirected = is_undirected + self.cfg = cfg + + def __call__(self, data: Data) -> Data: 
+ data = compute_posenc_stats(data, pe_types=self.pe_types, + is_undirected=self.is_undirected, + cfg=self.cfg + ) + return data \ No newline at end of file diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index c18c360..ad68f4f 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -10,6 +10,9 @@ ) from gridfm_graphkit.datasets.utils import split_dataset from gridfm_graphkit.datasets.powergrid_dataset import GridDatasetDisk + +from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat + import numpy as np import random import warnings @@ -129,6 +132,17 @@ def setup(self, stage: str): mask_dim=self.args.data.mask_dim, transform=get_transform(args=self.args), ) + + if self.args.data.posenc_RRWP.enable: + pe_transform = ComputePosencStat(pe_types=pe_enabled_list, # TODO connect arguments + is_undirected=is_undirected, + cfg=cfg + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.compose([pe_transform, dataset.transform]) + self.datasets.append(dataset) num_scenarios = self.args.data.scenarios[i] diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py new file mode 100644 index 0000000..d88e3e7 --- /dev/null +++ b/gridfm_graphkit/datasets/rrwp.py @@ -0,0 +1,103 @@ +# ------------------------ : new rwpse ---------------- +from typing import Union, Any, Optional +import numpy as np +import torch +import torch.nn.functional as F +import torch_geometric as pyg +from torch_geometric.data import Data, HeteroData +from torch_geometric.transforms import BaseTransform +from torch_scatter import scatter, scatter_add, scatter_max + +from torch_geometric.graphgym.config import cfg + +from torch_geometric.utils import ( + get_laplacian, + get_self_loop_attr, + to_scipy_sparse_matrix, +) +import torch_sparse +from torch_sparse import SparseTensor + + +def add_node_attr(data: Data, value: Any, + attr_name: Optional[str] = None) -> Data: + if attr_name is None: + if 'x' in data: + x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x + data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1) + else: + data.x = value + else: + data[attr_name] = value + + return data + + + +@torch.no_grad() +def add_full_rrwp(data, + walk_length=8, + attr_name_abs="rrwp", # name: 'rrwp' + attr_name_rel="rrwp", # name: ('rrwp_idx', 'rrwp_val') + add_identity=True, + spd=False, + **kwargs + ): + device=data.edge_index.device + ind_vec = torch.eye(walk_length, dtype=torch.float, device=device) + num_nodes = data.num_nodes + edge_index, edge_weight = data.edge_index, data.edge_weight + + adj = SparseTensor.from_edge_index(edge_index, edge_weight, + sparse_sizes=(num_nodes, num_nodes), + ) + + # Compute D^{-1} A: + deg = adj.sum(dim=1) + deg_inv = 1.0 / adj.sum(dim=1) + deg_inv[deg_inv == float('inf')] = 0 + adj = adj * deg_inv.view(-1, 1) + adj = adj.to_dense() + + pe_list = [] + i = 0 + if add_identity: + pe_list.append(torch.eye(num_nodes, dtype=torch.float)) + i = i + 1 + + out = adj + pe_list.append(adj) + + if walk_length > 2: + for j in range(i + 1, walk_length): + out = out @ adj + pe_list.append(out) + + pe = torch.stack(pe_list, dim=-1) # n x n x k + + abs_pe = pe.diagonal().transpose(0, 1) # n x k + + rel_pe = SparseTensor.from_dense(pe, has_value=True) + rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo() + # rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0) + rel_pe_idx = torch.stack([rel_pe_col, 
rel_pe_row], dim=0) + # the framework of GRIT performing right-mul while adj is row-normalized, + # need to switch the order or row and col. + # note: both can work but the current version is more reasonable. + + + if spd: + spd_idx = walk_length - torch.arange(walk_length) + val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0) + val = torch.argmax(val, dim=-1) + rel_pe_val = F.one_hot(val, walk_length).type(torch.float) + abs_pe = torch.zeros_like(abs_pe) + + data = add_node_attr(data, abs_pe, attr_name=attr_name_abs) + data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index") + data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val") + data.log_deg = torch.log(deg + 1) + data.deg = deg.type(torch.long) + + return data + From a7bd51d1747ae3dd6f12a712b2962484abc59d6f Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:38 -0500 Subject: [PATCH 06/26] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 372 +----------------- .../datasets/powergrid_datamodule.py | 5 +- 2 files changed, 9 insertions(+), 368 deletions(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 492a0a6..049633c 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -11,8 +11,10 @@ from functools import partial from gridfm_graphkit.datasets.rrwp import add_full_rrwp +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import Data -def compute_posenc_stats(data, pe_types, is_undirected, cfg): +def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. Supported PE statistics to precompute, selected by `pe_types`: 'LapPE': Laplacian eigen-decomposition. @@ -37,387 +39,27 @@ def compute_posenc_stats(data, pe_types, is_undirected, cfg): 'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']: raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}") - # Basic preprocessing of the input graph. - if hasattr(data, 'num_nodes'): - N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa - else: - N = data.x.shape[0] # Number of nodes, including disconnected nodes. - laplacian_norm_type = cfg.posenc_LapPE.eigen.laplacian_norm.lower() - if laplacian_norm_type == 'none': - laplacian_norm_type = None - if is_undirected: - undir_edge_index = data.edge_index - else: - undir_edge_index = to_undirected(data.edge_index) - - # Eigen values and vectors. - evals, evects = None, None - if 'LapPE' in pe_types or 'EquivStableLapPE' in pe_types: - # Eigen-decomposition with numpy, can be reused for Heat kernels. - L = to_scipy_sparse_matrix( - *get_laplacian(undir_edge_index, normalization=laplacian_norm_type, - num_nodes=N) - ) - evals, evects = np.linalg.eigh(L.toarray()) - - if 'LapPE' in pe_types: - max_freqs=cfg.posenc_LapPE.eigen.max_freqs - eigvec_norm=cfg.posenc_LapPE.eigen.eigvec_norm - elif 'EquivStableLapPE' in pe_types: - max_freqs=cfg.posenc_EquivStableLapPE.eigen.max_freqs - eigvec_norm=cfg.posenc_EquivStableLapPE.eigen.eigvec_norm - - data.EigVals, data.EigVecs = get_lap_decomp_stats( - evals=evals, evects=evects, - max_freqs=max_freqs, - eigvec_norm=eigvec_norm) - - if 'SignNet' in pe_types: - # Eigen-decomposition with numpy for SignNet. 
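To make the RRWP construction in add_full_rrwp above concrete, here is a hedged standalone sketch (toy triangle graph, illustrative variable names) of stacking I, P, P^2, ... and splitting the result into the absolute (diagonal) and relative (pairwise, kept sparse) encodings:

import torch
from torch_sparse import SparseTensor

# Toy undirected triangle graph 0-1-2.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 0],
                           [1, 0, 2, 1, 0, 2]])
num_nodes, walk_length = 3, 4

adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(num_nodes, num_nodes))
deg_inv = 1.0 / adj.sum(dim=1)
deg_inv[deg_inv == float('inf')] = 0
P = (adj * deg_inv.view(-1, 1)).to_dense()   # row-normalized D^-1 A

# Stack [I, P, P^2, P^3] along the last dimension: n x n x walk_length.
mats, out = [torch.eye(num_nodes)], P
for _ in range(walk_length - 1):
    mats.append(out)
    out = out @ P
pe = torch.stack(mats, dim=-1)

abs_pe = pe.diagonal().transpose(0, 1)                 # node-level RRWP, n x walk_length
rel_pe = SparseTensor.from_dense(pe, has_value=True)   # pairwise RRWP as a sparse tensor
rel_row, rel_col, rel_val = rel_pe.coo()

In the patch, the node-level part is stored as the "rrwp" attribute and the sparse pairwise part as "rrwp_index" / "rrwp_val".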
- norm_type = cfg.posenc_SignNet.eigen.laplacian_norm.lower() - if norm_type == 'none': - norm_type = None - L = to_scipy_sparse_matrix( - *get_laplacian(undir_edge_index, normalization=norm_type, - num_nodes=N) - ) - evals_sn, evects_sn = np.linalg.eigh(L.toarray()) - data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats( - evals=evals_sn, evects=evects_sn, - max_freqs=cfg.posenc_SignNet.eigen.max_freqs, - eigvec_norm=cfg.posenc_SignNet.eigen.eigvec_norm) - - # Random Walks. - if 'RWSE' in pe_types: - kernel_param = cfg.posenc_RWSE.kernel - if len(kernel_param.times) == 0: - raise ValueError("List of kernel times required for RWSE") - rw_landing = get_rw_landing_probs(ksteps=kernel_param.times, - edge_index=data.edge_index, - num_nodes=N) - data.pestat_RWSE = rw_landing - - # Heat Kernels. - if 'HKdiagSE' in pe_types or 'HKfullPE' in pe_types: - # Get the eigenvalues and eigenvectors of the regular Laplacian, - # if they have not yet been computed for 'eigen'. - if laplacian_norm_type is not None or evals is None or evects is None: - L_heat = to_scipy_sparse_matrix( - *get_laplacian(undir_edge_index, normalization=None, num_nodes=N) - ) - evals_heat, evects_heat = np.linalg.eigh(L_heat.toarray()) - else: - evals_heat, evects_heat = evals, evects - evals_heat = torch.from_numpy(evals_heat) - evects_heat = torch.from_numpy(evects_heat) - - # Get the full heat kernels. - if 'HKfullPE' in pe_types: - # The heat kernels can't be stored in the Data object without - # additional padding because in PyG's collation of the graphs the - # sizes of tensors must match except in dimension 0. Do this when - # the full heat kernels are actually used downstream by an Encoder. - raise NotImplementedError() - # heat_kernels, hk_diag = get_heat_kernels(evects_heat, evals_heat, - # kernel_times=kernel_param.times) - # data.pestat_HKdiagSE = hk_diag - # Get heat kernel diagonals in more efficient way. - if 'HKdiagSE' in pe_types: - kernel_param = cfg.posenc_HKdiagSE.kernel - if len(kernel_param.times) == 0: - raise ValueError("Diffusion times are required for heat kernel") - hk_diag = get_heat_kernels_diag(evects_heat, evals_heat, - kernel_times=kernel_param.times, - space_dim=0) - data.pestat_HKdiagSE = hk_diag - - # Electrostatic interaction inspired kernel. - if 'ElstaticSE' in pe_types: - elstatic = get_electrostatic_function_encoding(undir_edge_index, N) - data.pestat_ElstaticSE = elstatic - if 'RRWP' in pe_types: param = cfg.posenc_RRWP transform = partial(add_full_rrwp, walk_length=param.ksteps, attr_name_abs="rrwp", attr_name_rel="rrwp", - add_identity=True, - spd=param.spd, # by default False + add_identity=True ) data = transform(data) return data -def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm='L2'): - """Compute Laplacian eigen-decomposition-based PE stats of the given graph. - - Args: - evals, evects: Precomputed eigen-decomposition - max_freqs: Maximum number of top smallest frequencies / eigenvecs to use - eigvec_norm: Normalization for the eigen vectors of the Laplacian - Returns: - Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node - Tensor (num_nodes, max_freqs) of eigenvector values per node - """ - N = len(evals) # Number of nodes, including disconnected nodes. - - # Keep up to the maximum desired number of frequencies. - idx = evals.argsort()[:max_freqs] - evals, evects = evals[idx], np.real(evects[:, idx]) - evals = torch.from_numpy(np.real(evals)).clamp_min(0) - - # Normalize and pad eigen vectors. 
- evects = torch.from_numpy(evects).float() - evects = eigvec_normalizer(evects, evals, normalization=eigvec_norm) - if N < max_freqs: - EigVecs = F.pad(evects, (0, max_freqs - N), value=float('nan')) - else: - EigVecs = evects - - # Pad and save eigenvalues. - if N < max_freqs: - EigVals = F.pad(evals, (0, max_freqs - N), value=float('nan')).unsqueeze(0) - else: - EigVals = evals.unsqueeze(0) - EigVals = EigVals.repeat(N, 1).unsqueeze(2) - - return EigVals, EigVecs - - -def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, - num_nodes=None, space_dim=0): - """Compute Random Walk landing probabilities for given list of K steps. - - Args: - ksteps: List of k-steps for which to compute the RW landings - edge_index: PyG sparse representation of the graph - edge_weight: (optional) Edge weights - num_nodes: (optional) Number of nodes in the graph - space_dim: (optional) Estimated dimensionality of the space. Used to - correct the random-walk diagonal by a factor `k^(space_dim/2)`. - In euclidean space, this correction means that the height of - the gaussian distribution stays almost constant across the number of - steps, if `space_dim` is the dimension of the euclidean space. - - Returns: - 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs - """ - if edge_weight is None: - edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) - num_nodes = maybe_num_nodes(edge_index, num_nodes) - source, dest = edge_index[0], edge_index[1] - deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. - deg_inv = deg.pow(-1.) - deg_inv.masked_fill_(deg_inv == float('inf'), 0) - - if edge_index.numel() == 0: - P = edge_index.new_zeros((1, num_nodes, num_nodes)) - else: - # P = D^-1 * A - P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes) - rws = [] - if ksteps == list(range(min(ksteps), max(ksteps) + 1)): - # Efficient way if ksteps are a consecutive sequence (most of the time the case) - Pk = P.clone().detach().matrix_power(min(ksteps)) - for k in range(min(ksteps), max(ksteps) + 1): - rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \ - (k ** (space_dim / 2))) - Pk = Pk @ P - else: - # Explicitly raising P to power k for each k \in ksteps. - for k in ksteps: - rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ - (k ** (space_dim / 2))) - rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) - - return rw_landing - - -def get_heat_kernels_diag(evects, evals, kernel_times=[], space_dim=0): - """Compute Heat kernel diagonal. - - This is a continuous function that represents a Gaussian in the Euclidean - space, and is the solution to the diffusion equation. - The random-walk diagonal should converge to this. - - Args: - evects: Eigenvectors of the Laplacian matrix - evals: Eigenvalues of the Laplacian matrix - kernel_times: Time for the diffusion. Analogous to the k-steps in random - walk. The time is equivalent to the variance of the kernel. - space_dim: (optional) Estimated dimensionality of the space. Used to - correct the diffusion diagonal by a factor `t^(space_dim/2)`. In - euclidean space, this correction means that the height of the - gaussian stays constant across time, if `space_dim` is the dimension - of the euclidean space. 
- - Returns: - 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs - """ - heat_kernels_diag = [] - if len(kernel_times) > 0: - evects = F.normalize(evects, p=2., dim=0) - - # Remove eigenvalues == 0 from the computation of the heat kernel - idx_remove = evals < 1e-8 - evals = evals[~idx_remove] - evects = evects[:, ~idx_remove] - - # Change the shapes for the computations - evals = evals.unsqueeze(-1) # lambda_{i, ..., ...} - evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node - - # Compute the heat kernels diagonal only for each time - eigvec_mul = evects ** 2 - for t in kernel_times: - # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j} * phi_{i, j}) - this_kernel = torch.sum(torch.exp(-t * evals) * eigvec_mul, - dim=0, keepdim=False) - - # Multiply by `t` to stabilize the values, since the gaussian height - # is proportional to `1/t` - heat_kernels_diag.append(this_kernel * (t ** (space_dim / 2))) - heat_kernels_diag = torch.stack(heat_kernels_diag, dim=0).transpose(0, 1) - - return heat_kernels_diag - - -def get_heat_kernels(evects, evals, kernel_times=[]): - """Compute full Heat diffusion kernels. - - Args: - evects: Eigenvectors of the Laplacian matrix - evals: Eigenvalues of the Laplacian matrix - kernel_times: Time for the diffusion. Analogous to the k-steps in random - walk. The time is equivalent to the variance of the kernel. - """ - heat_kernels, rw_landing = [], [] - if len(kernel_times) > 0: - evects = F.normalize(evects, p=2., dim=0) - - # Remove eigenvalues == 0 from the computation of the heat kernel - idx_remove = evals < 1e-8 - evals = evals[~idx_remove] - evects = evects[:, ~idx_remove] - - # Change the shapes for the computations - evals = evals.unsqueeze(-1).unsqueeze(-1) # lambda_{i, ..., ...} - evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node - - # Compute the heat kernels for each time - eigvec_mul = (evects.unsqueeze(2) * evects.unsqueeze(1)) # (phi_{i, j1, ...} * phi_{i, ..., j2}) - for t in kernel_times: - # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j1, ...} * phi_{i, ..., j2}) - heat_kernels.append( - torch.sum(torch.exp(-t * evals) * eigvec_mul, - dim=0, keepdim=False) - ) - - heat_kernels = torch.stack(heat_kernels, dim=0) # (Num kernel times) x (Num nodes) x (Num nodes) - - # Take the diagonal of each heat kernel, - # i.e. the landing probability of each of the random walks - rw_landing = torch.diagonal(heat_kernels, dim1=-2, dim2=-1).transpose(0, 1) # (Num nodes) x (Num kernel times) - - return heat_kernels, rw_landing - - -def get_electrostatic_function_encoding(edge_index, num_nodes): - """Kernel based on the electrostatic interaction between nodes. 
- """ - L = to_scipy_sparse_matrix( - *get_laplacian(edge_index, normalization=None, num_nodes=num_nodes) - ).todense() - L = torch.as_tensor(L) - Dinv = torch.eye(L.shape[0]) * (L.diag() ** -1) - A = deepcopy(L).abs() - A.fill_diagonal_(0) - DinvA = Dinv.matmul(A) - - electrostatic = torch.pinverse(L) - electrostatic = electrostatic - electrostatic.diag() - green_encoding = torch.stack([ - electrostatic.min(dim=0)[0], # Min of Vi -> j - electrostatic.max(dim=0)[0], # Max of Vi -> j - electrostatic.mean(dim=0), # Mean of Vi -> j - electrostatic.std(dim=0), # Std of Vi -> j - electrostatic.min(dim=1)[0], # Min of Vj -> i - electrostatic.max(dim=0)[0], # Max of Vj -> i - electrostatic.mean(dim=1), # Mean of Vj -> i - electrostatic.std(dim=1), # Std of Vj -> i - (DinvA * electrostatic).sum(dim=0), # Mean of interaction on direct neighbour - (DinvA * electrostatic).sum(dim=1), # Mean of interaction from direct neighbour - ], dim=1) - - return green_encoding - - -def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12): - """ - Implement different eigenvector normalizations. - """ - - EigVals = EigVals.unsqueeze(0) - - if normalization == "L1": - # L1 normalization: eigvec / sum(abs(eigvec)) - denom = EigVecs.norm(p=1, dim=0, keepdim=True) - - elif normalization == "L2": - # L2 normalization: eigvec / sqrt(sum(eigvec^2)) - denom = EigVecs.norm(p=2, dim=0, keepdim=True) - - elif normalization == "abs-max": - # AbsMax normalization: eigvec / max|eigvec| - denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values - - elif normalization == "wavelength": - # AbsMax normalization, followed by wavelength multiplication: - # eigvec * pi / (2 * max|eigvec| * sqrt(eigval)) - denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values - eigval_denom = torch.sqrt(EigVals) - eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 - denom = denom * eigval_denom * 2 / np.pi - - elif normalization == "wavelength-asin": - # AbsMax normalization, followed by arcsin and wavelength multiplication: - # arcsin(eigvec / max|eigvec|) / sqrt(eigval) - denom_temp = torch.max(EigVecs.abs(), dim=0, keepdim=True).values.clamp_min(eps).expand_as(EigVecs) - EigVecs = torch.asin(EigVecs / denom_temp) - eigval_denom = torch.sqrt(EigVals) - eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 - denom = eigval_denom - - elif normalization == "wavelength-soft": - # AbsSoftmax normalization, followed by wavelength multiplication: - # eigvec / (softmax|eigvec| * sqrt(eigval)) - denom = (F.softmax(EigVecs.abs(), dim=0) * EigVecs.abs()).sum(dim=0, keepdim=True) - eigval_denom = torch.sqrt(EigVals) - eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 - denom = denom * eigval_denom - - else: - raise ValueError(f"Unsupported normalization `{normalization}`") - - denom = denom.clamp_min(eps).expand_as(EigVecs) - EigVecs = EigVecs / denom - - return EigVecs - -from torch_geometric.transforms import BaseTransform -from torch_geometric.data import Data, HeteroData - class ComputePosencStat(BaseTransform): - def __init__(self, pe_types, is_undirected, cfg): + def __init__(self, pe_types, cfg): self.pe_types = pe_types - self.is_undirected = is_undirected self.cfg = cfg def __call__(self, data: Data) -> Data: - data = compute_posenc_stats(data, pe_types=self.pe_types, - is_undirected=self.is_undirected, + data = compute_posenc_stats(data, + pe_types=self.pe_types, cfg=self.cfg ) return data \ No newline at end of file diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py 
b/gridfm_graphkit/datasets/powergrid_datamodule.py index ad68f4f..d73f6fa 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -134,9 +134,8 @@ def setup(self, stage: str): ) if self.args.data.posenc_RRWP.enable: - pe_transform = ComputePosencStat(pe_types=pe_enabled_list, # TODO connect arguments - is_undirected=is_undirected, - cfg=cfg + pe_transform = ComputePosencStat(pe_types=['RRWP'], + cfg=self.args.data ) if dataset.transform is None: dataset.transform = pe_transform From 226f2a3098c4ed4499ac982d53f1e49338cce9fd Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:38 -0500 Subject: [PATCH 07/26] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 1 + gridfm_graphkit/datasets/posenc_stats.py | 3 +- .../datasets/powergrid_datamodule.py | 2 + gridfm_graphkit/datasets/rrwp.py | 1 - gridfm_graphkit/models/__init__.py | 3 +- gridfm_graphkit/models/grit_layer.py | 3 +- gridfm_graphkit/models/grit_transformer.py | 42 +++++++++---------- 7 files changed, 28 insertions(+), 27 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 05f31be..9896057 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -48,6 +48,7 @@ model: edge_encoder: True node_encoder_name: TODO node_encoder_bn: True + .edge_encoder_bn: True gt: layer_type: GritTransformer dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 049633c..8bb2b9d 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -16,7 +16,8 @@ def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. - Supported PE statistics to precompute, selected by `pe_types`: + Supported PE statistics to precompute in original implementation, + selected by `pe_types`: 'LapPE': Laplacian eigen-decomposition. 'RWSE': Random walk landing probabilities (diagonals of RW matrices). 'HKfullPE': Full heat kernels and their diagonals. 
(NOT IMPLEMENTED) diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index d73f6fa..9960e08 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -13,6 +13,8 @@ from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat +import torch_geometric.transforms as T + import numpy as np import random import warnings diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py index d88e3e7..26218f0 100644 --- a/gridfm_graphkit/datasets/rrwp.py +++ b/gridfm_graphkit/datasets/rrwp.py @@ -8,7 +8,6 @@ from torch_geometric.transforms import BaseTransform from torch_scatter import scatter, scatter_add, scatter_max -from torch_geometric.graphgym.config import cfg from torch_geometric.utils import ( get_laplacian, diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index de355d3..cc66936 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,5 @@ from gridfm_graphkit.models.gps_transformer import GPSTransformer from gridfm_graphkit.models.gnn_transformer import GNN_TransformerConv +from gridfm_graphkit.models.grit_transformer import GritTransformer -__all__ = ["GPSTransformer", "GNN_TransformerConv"] +__all__ = ["GPSTransformer", "GNN_TransformerConv", "GRIT"] diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index b477980..53e7217 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -8,7 +8,6 @@ import opt_einsum as oe -from yacs.config import CfgNode as CN import warnings @@ -48,7 +47,7 @@ def __init__(self, in_dim, out_dim, num_heads, use_bias, edge_enhance=True, sqrt_relu=False, signed_sqrt=True, - cfg=CN(), + cfg={}, **kwargs): super().__init__() diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 715c25f..49bfdf2 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,8 +1,9 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY import torch from torch import nn -from rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder -from grit_layer import GritTransformerLayer + +from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder +from gridfm_graphkit.models.grit_layer import GritTransformerLayer @@ -27,25 +28,21 @@ def forward(self, batch): class LinearNodeEncoder(torch.nn.Module): - def __init__(self, emb_dim): + def __init__(self, dim_in, emb_dim): super().__init__() - self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim) + self.encoder = torch.nn.Linear(dim_in, emb_dim) def forward(self, batch): batch.x = self.encoder(batch.x) return batch class LinearEdgeEncoder(torch.nn.Module): - def __init__(self, emb_dim): + def __init__(self, edge_dim, emb_dim): super().__init__() - if cfg.dataset.name in ['MNIST', 'CIFAR10']: - self.in_dim = 1 - elif cfg.dataset.name.startswith('attributed_triangle-'): - self.in_dim = 2 - else: - raise ValueError("Input edge feature dim is required to be hardset " - "or refactored to use a cfg option.") + + self.in_dim = edge_dim + self.encoder = torch.nn.Linear(self.in_dim, emb_dim) def forward(self, batch): @@ -69,20 +66,20 @@ def __init__( ): super(FeatureEncoder, self).__init__() self.dim_in = dim_in - if args.node_encoder: + if args.encoder.node_encoder: # Encode integer node features via 
nn.Embeddings - self.node_encoder = LinearNodeEncoder(dim_inner) - if args.node_encoder_bn: + self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) + if args.encoder.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) # Update dim_in to reflect the new dimension fo the node features self.dim_in = dim_inner - if args.edge_encoder: - - dim_edge = dim_inner + if args.encoder.edge_encoder: + args.edge_dim + enc_dim_edge = dim_inner # Encode integer edge features via nn.Embeddings - self.edge_encoder = LinearEdgeEncoder(dim_edge) - if cfg.dataset.edge_encoder_bn: - self.edge_encoder_bn = BatchNorm1dNode(dim_edge, 1e-5, 0.1) + self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) + if args.encoder.edge_encoder_bn: + self.edge_encoder_bn = BatchNorm1dNode(enc_dim_edge, 1e-5, 0.1) def forward(self, batch): for module in self.children(): @@ -121,7 +118,7 @@ def __init__(self, args): dim_inner ) rel_pe_dim = args.model.posenc_RRWP.ksteps - self.rrwp_rel_encoder = RRWPLinearNodeEncoder( + self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( rel_pe_dim, dim_edge, pad_to_full_graph=args.model.gt.attn.full_attn, @@ -158,6 +155,7 @@ def __init__(self, args): ) def forward(self, batch): + print('process--->>', batch) # TODO remove print for module in self.children(): batch = module(batch) From 88d9ca6a7d2cea3f6f935b123b236a203452809f Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:38 -0500 Subject: [PATCH 08/26] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 +- gridfm_graphkit/models/grit_transformer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 9896057..bd76278 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -48,7 +48,7 @@ model: edge_encoder: True node_encoder_name: TODO node_encoder_bn: True - .edge_encoder_bn: True + edge_encoder_bn: True gt: layer_type: GritTransformer dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 49bfdf2..e8746b8 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -74,7 +74,7 @@ def __init__( # Update dim_in to reflect the new dimension fo the node features self.dim_in = dim_inner if args.encoder.edge_encoder: - args.edge_dim + edge_dim = args.edge_dim enc_dim_edge = dim_inner # Encode integer edge features via nn.Embeddings self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) @@ -107,8 +107,8 @@ def __init__(self, args): self.encoder = FeatureEncoder( dim_in, dim_inner, - args.model.encoder - ) # TODO add args + args.model + ) dim_in = self.encoder.dim_in if args.model.posenc_RRWP.enable: From b7d9dcf035026c797c1d68577d8b04ed9d3247ee Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 09/26] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 ++ gridfm_graphkit/models/grit_layer.py | 4 ++-- gridfm_graphkit/models/grit_transformer.py | 6 +++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/config/grit_pretraining.yaml 
b/examples/config/grit_pretraining.yaml index bd76278..73e5817 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -64,6 +64,8 @@ model: O_e: True norm_e: True signed_sqrt: True + bn_momentum: 0.1 + bn_no_runner: False optimizer: beta1: 0.9 diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 53e7217..ffcf584 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -166,10 +166,10 @@ def __init__(self, in_dim, out_dim, num_heads, self.batch_norm = batch_norm # ------- - self.update_e = cfg.get("update_e", True) + self.update_e = getattr(cfg, "update_e", True) self.bn_momentum = cfg.bn_momentum self.bn_no_runner = cfg.bn_no_runner - self.rezero = cfg.get("rezero", False) + self.rezero = getattr(cfg, "rezero", False) self.act = act_dict[act]() if act is not None else nn.Identity() if cfg.get("attn", None) is None: diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index e8746b8..8d4f696 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -111,13 +111,13 @@ def __init__(self, args): ) dim_in = self.encoder.dim_in - if args.model.posenc_RRWP.enable: + if args.data.posenc_RRWP.enable: self.rrwp_abs_encoder = RRWPLinearNodeEncoder( - args.model.posenc_RRWP.ksteps, + args.data.posenc_RRWP.ksteps, dim_inner ) - rel_pe_dim = args.model.posenc_RRWP.ksteps + rel_pe_dim = args.data.posenc_RRWP.ksteps self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( rel_pe_dim, dim_edge, From 38cc44a31d909cba09a8296784fc5f02e9637dba Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 10/26] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index ffcf584..98d0b6c 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -166,10 +166,10 @@ def __init__(self, in_dim, out_dim, num_heads, self.batch_norm = batch_norm # ------- - self.update_e = getattr(cfg, "update_e", True) - self.bn_momentum = cfg.bn_momentum - self.bn_no_runner = cfg.bn_no_runner - self.rezero = getattr(cfg, "rezero", False) + self.update_e = getattr(cfg.attn, "update_e", True) + self.bn_momentum = cfg.attn.bn_momentum + self.bn_no_runner = cfg.attn.bn_no_runner + self.rezero = getattr(cfg.attn, "rezero", False) self.act = act_dict[act]() if act is not None else nn.Identity() if cfg.get("attn", None) is None: From 7fded9556996acec2d55919e95059fe4faba93fa Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 11/26] matching up parameters in grit layer Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 98d0b6c..9723304 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -72,7 +72,7 @@ def __init__(self, in_dim, out_dim, num_heads, use_bias, if act is None: self.act = nn.Identity() else: - self.act = 
act_dict[act]() + self.act = nn.ReLU() if self.edge_enhance: self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True) @@ -171,12 +171,15 @@ def __init__(self, in_dim, out_dim, num_heads, self.bn_no_runner = cfg.attn.bn_no_runner self.rezero = getattr(cfg.attn, "rezero", False) - self.act = act_dict[act]() if act is not None else nn.Identity() - if cfg.get("attn", None) is None: + if act is not None + self.act = nn.ReLU() + else: + self.act = nn.Identity() + + if getattr(cfg, "attn", None) is None: cfg.attn = dict() - self.use_attn = cfg.attn.get("use", True) - # self.sigmoid_deg = cfg.attn.get("sigmoid_deg", False) - self.deg_scaler = cfg.attn.get("deg_scaler", True) + self.use_attn = getattr(cfg.attn, "use", True) + self.deg_scaler = getattr(cfg.attn, "deg_scaler", True) self.attention = MultiHeadAttentionLayerGritSparse( in_dim=in_dim, From 0f3b803a2a2ccfa69b11bac643b1ad14e2e24d38 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 12/26] matching up parameters in grit layer Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 9723304..0bcdf73 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -171,7 +171,7 @@ def __init__(self, in_dim, out_dim, num_heads, self.bn_no_runner = cfg.attn.bn_no_runner self.rezero = getattr(cfg.attn, "rezero", False) - if act is not None + if act is not None: self.act = nn.ReLU() else: self.act = nn.Identity() From af8ad03ad589f781a35d5deaa557a616c5463e80 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 13/26] matching up parameters in grit layer Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 38 ++++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 0bcdf73..f95ffc7 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -185,31 +185,31 @@ def __init__(self, in_dim, out_dim, num_heads, in_dim=in_dim, out_dim=out_dim // num_heads, num_heads=num_heads, - use_bias=cfg.attn.get("use_bias", False), + use_bias=getattr(cfg.attn, "use_bias", False), dropout=attn_dropout, - clamp=cfg.attn.get("clamp", 5.), - act=cfg.attn.get("act", "relu"), - edge_enhance=cfg.attn.get("edge_enhance", True), - sqrt_relu=cfg.attn.get("sqrt_relu", False), - signed_sqrt=cfg.attn.get("signed_sqrt", False), - scaled_attn =cfg.attn.get("scaled_attn", False), - no_qk=cfg.attn.get("no_qk", False), + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), + edge_enhance=getattr(cfg.attn, "edge_enhance", True), + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn,"scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), ) - if cfg.attn.get('graphormer_attn', False): + if getattr(cfg.attn, 'graphormer_attn', False): self.attention = MultiHeadAttentionLayerGraphormerSparse( in_dim=in_dim, out_dim=out_dim // num_heads, num_heads=num_heads, - 
use_bias=cfg.attn.get("use_bias", False), + use_bias=getattr(cfg.attn, "use_bias", False), dropout=attn_dropout, - clamp=cfg.attn.get("clamp", 5.), - act=cfg.attn.get("act", "relu"), + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), edge_enhance=True, - sqrt_relu=cfg.attn.get("sqrt_relu", False), - signed_sqrt=cfg.attn.get("signed_sqrt", False), - scaled_attn =cfg.attn.get("scaled_attn", False), - no_qk=cfg.attn.get("no_qk", False), + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn, "scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), ) @@ -232,8 +232,8 @@ def __init__(self, in_dim, out_dim, num_heads, if self.batch_norm: # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) - self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) - self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) if norm_e else nn.Identity() + self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) + self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) if norm_e else nn.Identity() # FFN for h self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) @@ -243,7 +243,7 @@ def __init__(self, in_dim, out_dim, num_heads, self.layer_norm2_h = nn.LayerNorm(out_dim) if self.batch_norm: - self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) + self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) if self.rezero: self.alpha1_h = nn.Parameter(torch.zeros(1,1)) From f430f2a7562c3db8b570b673422e7406446de593 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 14/26] matching up parameters in data module Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/powergrid_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index 9960e08..e67956e 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -142,7 +142,7 @@ def setup(self, stage: str): if dataset.transform is None: dataset.transform = pe_transform else: - dataset.transform = T.compose([pe_transform, dataset.transform]) + dataset.transform = T.Compose([pe_transform, dataset.transform]) self.datasets.append(dataset) From e1c489050bd9995e10efe69d6f356515dc6b965d Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 15/26] flow over parameters from base model Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 8d4f696..10af25a 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ 
b/gridfm_graphkit/models/grit_transformer.py @@ -103,6 +103,19 @@ def __init__(self, args): num_heads = args.model.attention_head dropout = args.model.dropout num_layers = args.model.num_layers + self.mask_dim = getattr(args.data, "mask_dim", 6) + self.mask_value = getattr(args.data, "mask_value", -1.0) + self.learn_mask = getattr(args.data, "learn_mask", False) + if self.learn_mask: + self.mask_value = nn.Parameter( + torch.randn(self.mask_dim) + self.mask_value, + requires_grad=True, + ) + else: + self.mask_value = nn.Parameter( + torch.zeros(self.mask_dim) + self.mask_value, + requires_grad=False, + ) self.encoder = FeatureEncoder( dim_in, From 36dca0095cecea391a586b3ec205a229b3511e7c Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 16/26] verified encodings and data flow to model forward method Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 2 +- .../tasks/feature_reconstruction_task.py | 24 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 10af25a..cf07a8c 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -167,7 +167,7 @@ def __init__(self, args): nn.Linear(dim_inner, dim_out), ) - def forward(self, batch): + def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters print('process--->>', batch) # TODO remove print for module in self.children(): batch = module(batch) diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index cb6963b..da2f478 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -74,11 +74,11 @@ def __init__(self, args, node_normalizers, edge_normalizers): self.edge_normalizers = edge_normalizers self.save_hyperparameters() - def forward(self, x, pe, edge_index, edge_attr, batch, mask=None): - if mask is not None: - mask_value_expanded = self.model.mask_value.expand(x.shape[0], -1) - x[:, : mask.shape[1]][mask] = mask_value_expanded[mask] - return self.model(x, pe, edge_index, edge_attr, batch) + def forward(self, batch): + if batch.mask is not None: + mask_value_expanded = self.model.mask_value.expand(batch.x.shape[0], -1) + batch.x[:, : batch.mask.shape[1]][batch.mask] = mask_value_expanded[batch.mask] + return self.model(batch) @rank_zero_only def on_fit_start(self): @@ -111,12 +111,14 @@ def on_fit_start(self): def shared_step(self, batch): output = self.forward( - x=batch.x, - pe=batch.pe, - edge_index=batch.edge_index, - edge_attr=batch.edge_attr, - batch=batch.batch, - mask=batch.mask, + # TODO update args list in the GPS Transf. 
for consistency + # x=batch.x, + # pe=batch.pe, + # edge_index=batch.edge_index, + # edge_attr=batch.edge_attr, + # batch=batch.batch, + # mask=batch.mask, + batch ) loss_dict = self.loss_fn( From a8ec56efdbef89eef493f3e6f1b897f3e17140b6 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:41 -0500 Subject: [PATCH 17/26] match feature dimensions Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 2 +- gridfm_graphkit/models/rrwp_encoder.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index cf07a8c..50d0fec 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -133,7 +133,7 @@ def __init__(self, args): rel_pe_dim = args.data.posenc_RRWP.ksteps self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( rel_pe_dim, - dim_edge, + dim_inner, pad_to_full_graph=args.model.gt.attn.full_attn, add_node_attr_as_self_loop=False, fill_value=0. diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index b73e463..33c5215 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -114,6 +114,7 @@ def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias= if self.batchnorm or self.layernorm: warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") + print('--------fc in and out:', emb_dim, out_dim) self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) self.pad_to_full_graph = pad_to_full_graph @@ -144,7 +145,8 @@ def forward(self, batch): else: # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
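As a rough sketch of how the masked model input described in patches 15 and 16 above is assembled — broadcast the mask fill value across nodes and overwrite only the masked entries of the first mask_dim feature columns — with illustrative tensors standing in for the real batch (mask_value plays the role of the module parameter of the same name):

import torch

num_nodes, mask_dim, input_dim = 4, 6, 9

x = torch.randn(num_nodes, input_dim)           # raw node features
mask = torch.rand(num_nodes, mask_dim) < 0.5    # True where a feature is hidden
mask_value = torch.full((mask_dim,), -1.0)      # fixed fill value (a learnable Parameter when learn_mask is set)

# Mirrors the two lines in the task's forward(): expand, then overwrite masked entries in place.
mask_value_expanded = mask_value.expand(x.shape[0], -1)
x[:, :mask.shape[1]][mask] = mask_value_expanded[mask]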
- + print('xxxx', edge_attr.size(), rrwp_val.size()) + print('yyyy', edge_index.size(), rrwp_idx.size()) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), From 0868b96e7a5aed4e2648f6288d85073a260afd44 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:41 -0500 Subject: [PATCH 18/26] match feature dimensions Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 73e5817..52c44c8 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -36,7 +36,7 @@ model: attention_head: 8 dropout: 0.1 edge_dim: 2 - hidden_size: 123 + hidden_size: 64 # `gt.dim_hidden` must match `gnn.dim_inner` input_dim: 9 num_layers: 10 output_dim: 6 From 3cc21a38eea36b3dddca0e3a2d1817e15ac66b86 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:41 -0500 Subject: [PATCH 19/26] reformat decoder to handle batch format Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/cli.py | 5 ++ gridfm_graphkit/models/grit_transformer.py | 58 ++++++++++++++++++++-- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index a7507c1..79cb772 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -77,6 +77,11 @@ def main_cli(args): max_epochs=config_args.training.epochs, callbacks=get_training_callbacks(config_args), ) + + # print('******model*****') + # print(model) + # print('******model*****') + if args.command == "train" or args.command == "finetune": trainer.fit(model=model, datamodule=litGrid) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 50d0fec..2e85d4e 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -85,6 +85,49 @@ def forward(self, batch): for module in self.children(): batch = module(batch) return batch + +class GraphHead(nn.Module): + """ + Prediction head for graph prediction tasks. + Args: + dim_in (int): Input dimension. + dim_out (int): Output dimension. For binary prediction, dim_out=1. + L (int): Number of hidden layers. 
+ """ + + def __init__(self, dim_in, dim_out): + super().__init__() + # self.deg_scaler = False + # self.fwl = False + + # list_FC_layers = [ + # nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) + # for l in range(L)] + # list_FC_layers.append( + # nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) + self.FC_layers = nn.Sequential( + nn.Linear(dim_in, dim_in), + nn.LeakyReLU(), + nn.Linear(dim_in, dim_out), + ) #nn.ModuleList(list_FC_layers) + # self.L = L + # self.activation = register.act_dict[cfg.gnn.act]() + # note: modified to add () in the end from original code of 'GPS' + # potentially due to the change of PyG/GraphGym version + + def _apply_index(self, batch): + return batch.graph_feature, batch.y + + def forward(self, batch): + # graph_emb = self.pooling_fun(batch.x, batch.batch) + graph_emb = self.FC_layers(batch.x) + # for l in range(self.L): + # graph_emb = self.FC_layers[l](graph_emb) + # graph_emb = self.activation(graph_emb) + # graph_emb = self.FC_layers[self.L](graph_emb) + batch.graph_feature = graph_emb + pred, label = self._apply_index(batch) + return pred @MODELS_REGISTRY.register("GRIT") @@ -161,15 +204,20 @@ def __init__(self, args): self.layers = nn.Sequential(*layers) - self.decoder = nn.Sequential( - nn.Linear(dim_inner, dim_inner), - nn.LeakyReLU(), - nn.Linear(dim_inner, dim_out), - ) + # self.decoder = nn.Sequential( + # nn.Linear(dim_inner, dim_inner), + # nn.LeakyReLU(), + # nn.Linear(dim_inner, dim_out), + # ) + + self.decoder = GraphHead(dim_inner, dim_out) def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters print('process--->>', batch) # TODO remove print for module in self.children(): + print('----------') + print(module) batch = module(batch) + print('--passed--') return batch \ No newline at end of file From 17830516f74d297db439b659a14edf8db99506f3 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:42 -0500 Subject: [PATCH 20/26] confirmed training loop functions Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 8 ++++---- gridfm_graphkit/models/rrwp_encoder.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 2e85d4e..4e09de9 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -213,11 +213,11 @@ def __init__(self, args): self.decoder = GraphHead(dim_inner, dim_out) def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters - print('process--->>', batch) # TODO remove print + # print('process--->>', batch) # TODO remove print for module in self.children(): - print('----------') - print(module) + # print('----------') + # print(module) batch = module(batch) - print('--passed--') + # print('--passed--') return batch \ No newline at end of file diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 33c5215..270ca86 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -145,8 +145,8 @@ def forward(self, batch): else: # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
- print('xxxx', edge_attr.size(), rrwp_val.size()) - print('yyyy', edge_index.size(), rrwp_idx.size()) + # print('xxxx', edge_attr.size(), rrwp_val.size()) + # print('yyyy', edge_index.size(), rrwp_idx.size()) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), From c75012fa845f0775f47df55eda7f2fe2c85d25d9 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:42 -0500 Subject: [PATCH 21/26] update toml Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 22 +++++++++++----------- gridfm_graphkit/models/rrwp_encoder.py | 2 +- pyproject.toml | 1 + 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 52c44c8..30ee213 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -11,21 +11,21 @@ data: networks: # - Texas2k_case1_2016summerpeak - case24_ieee_rts - # - case118_ieee - # - case300_ieee + - case118_ieee + - case300_ieee - case89_pegase - # - case240_pserc + - case240_pserc normalization: baseMVAnorm scenarios: # - 5000 - - 5000 - - 5000 - # - 30000 - # - 50000 - # - 50000 + - 50000 + - 50000 + - 30000 + - 50000 + - 50000 test_ratio: 0.1 val_ratio: 0.1 - workers: 4 + workers: 8 posenc_RRWP: # TODO maybe better with data section... enable: True ksteps: 21 @@ -36,7 +36,7 @@ model: attention_head: 8 dropout: 0.1 edge_dim: 2 - hidden_size: 64 # `gt.dim_hidden` must match `gnn.dim_inner` + hidden_size: 116 # `gt.dim_hidden` must match `gnn.dim_inner` input_dim: 9 num_layers: 10 output_dim: 6 @@ -51,7 +51,7 @@ model: edge_encoder_bn: True gt: layer_type: GritTransformer - dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` + dim_hidden: 116 # `gt.dim_hidden` must match `gnn.dim_inner` layer_norm: False batch_norm: True update_e: True diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 270ca86..2dadd35 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -114,7 +114,7 @@ def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias= if self.batchnorm or self.layernorm: warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") - print('--------fc in and out:', emb_dim, out_dim) + # print('--------fc in and out:', emb_dim, out_dim) self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) self.pad_to_full_graph = pad_to_full_graph diff --git a/pyproject.toml b/pyproject.toml index 51c8665..2250ae9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "pyyaml", "lightning", "seaborn", + "opt-einsum", ] [project.optional-dependencies] From 3d3f98b3123defb965144e8d50868c46bfdab455 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:42 -0500 Subject: [PATCH 22/26] added forward method to transform class Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 6 +++++- gridfm_graphkit/models/grit_transformer.py | 4 ++-- pyproject.toml | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py 
index 8bb2b9d..21e7841 100644
--- a/gridfm_graphkit/datasets/posenc_stats.py
+++ b/gridfm_graphkit/datasets/posenc_stats.py
@@ -13,6 +13,7 @@
 from torch_geometric.transforms import BaseTransform
 from torch_geometric.data import Data
+from typing import Any
 
 
 def compute_posenc_stats(data, pe_types, cfg):
     """Precompute positional encodings for the given graph.
@@ -58,9 +59,12 @@ def __init__(self, pe_types, cfg):
         self.pe_types = pe_types
         self.cfg = cfg
 
+    def forward(self, data: Any) -> Any:
+        pass
+
     def __call__(self, data: Data) -> Data:
         data = compute_posenc_stats(data,
                                     pe_types=self.pe_types,
                                     cfg=self.cfg
                                     )
-        return data
\ No newline at end of file
+        return data
diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py
index 4e09de9..7caed0f 100644
--- a/gridfm_graphkit/models/grit_transformer.py
+++ b/gridfm_graphkit/models/grit_transformer.py
@@ -213,11 +213,11 @@ def __init__(self, args):
         self.decoder = GraphHead(dim_inner, dim_out)
 
     def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters
-        # print('process--->>', batch) # TODO remove print
+        #print('process--->>', batch) # TODO remove print
 
         for module in self.children():
             # print('----------')
             # print(module)
             batch = module(batch)
             # print('--passed--')
-        return batch
\ No newline at end of file
+        return batch
diff --git a/pyproject.toml b/pyproject.toml
index 2250ae9..4a17ed5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ classifiers = [
 
 
 dependencies = [
-    "torch>2.0",
+    "torch==2.6",
     "torch-geometric",
     "mlflow",
     "nbformat",

From d238e7591d4f358ac2877f838debd4ab64aa7afe Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:07:43 -0500
Subject: [PATCH 23/26] update readme with install instructions

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 78e8661..5a096dc 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,7 @@ cd gridfm-graphkit
 python -m venv venv
 source venv/bin/activate
 pip install -e .
+pip install torch_sparse torch_scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
 ```
 
 For documentation generation and unit testing, install with the optional `dev` and
 `test` extras:

From 17b0889e339f0378d761afe047f5b08db0802d03 Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:07:43 -0500
Subject: [PATCH 24/26] verified compat with GPS and GNN

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 examples/config/grit_pretraining.yaml            | 2 +-
 gridfm_graphkit/datasets/powergrid_datamodule.py | 2 +-
 gridfm_graphkit/models/gnn_transformer.py        | 8 +++++++-
 gridfm_graphkit/models/gps_transformer.py        | 8 +++++++-
 4 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml
index 30ee213..8f11c93 100644
--- a/examples/config/grit_pretraining.yaml
+++ b/examples/config/grit_pretraining.yaml
@@ -25,7 +25,7 @@ data:
     - 50000
   test_ratio: 0.1
   val_ratio: 0.1
-  workers: 8
+  workers: 0
 posenc_RRWP: # TODO maybe better with data section...
enable: True ksteps: 21 diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index e67956e..4b0320f 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -135,7 +135,7 @@ def setup(self, stage: str): transform=get_transform(args=self.args), ) - if self.args.data.posenc_RRWP.enable: + if ('posenc_RRWP' in self.args.data) and self.args.data.posenc_RRWP.enable: pe_transform = ComputePosencStat(pe_types=['RRWP'], cfg=self.args.data ) diff --git a/gridfm_graphkit/models/gnn_transformer.py b/gridfm_graphkit/models/gnn_transformer.py index 9e1ab23..37d3632 100644 --- a/gridfm_graphkit/models/gnn_transformer.py +++ b/gridfm_graphkit/models/gnn_transformer.py @@ -74,7 +74,7 @@ def __init__(self, args): requires_grad=False, ) - def forward(self, x, pe, edge_index, edge_attr, batch): + def forward(self, data_batch): """ Forward pass for the GPSTransformer. @@ -88,6 +88,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. """ + x=data_batch.x + pe=data_batch.pe + edge_index=data_batch.edge_index + edge_attr=data_batch.edge_attr + batch=data_batch.batch + for conv in self.layers: x = conv(x, edge_index, edge_attr) x = nn.LeakyReLU()(x) diff --git a/gridfm_graphkit/models/gps_transformer.py b/gridfm_graphkit/models/gps_transformer.py index cc8b648..ca45c5a 100644 --- a/gridfm_graphkit/models/gps_transformer.py +++ b/gridfm_graphkit/models/gps_transformer.py @@ -105,7 +105,7 @@ def __init__(self, args): requires_grad=False, ) - def forward(self, x, pe, edge_index, edge_attr, batch): + def forward(self, data_batch): """ Forward pass for the GPSTransformer. @@ -119,6 +119,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. 
""" + x=data_batch.x + pe=data_batch.pe + edge_index=data_batch.edge_index + edge_attr=data_batch.edge_attr + batch=data_batch.batch + x_pe = self.pe_norm(pe) x = self.encoder(x) From 091f0849ce997ff6d820b2102fb4d631faa7f789 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:43 -0500 Subject: [PATCH 25/26] work on comments and clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/rrwp.py | 19 +------ gridfm_graphkit/models/grit_layer.py | 12 ++-- gridfm_graphkit/models/grit_transformer.py | 55 +++++++------------ gridfm_graphkit/models/rrwp_encoder.py | 27 ++++----- .../tasks/feature_reconstruction_task.py | 7 --- 5 files changed, 42 insertions(+), 78 deletions(-) diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py index 26218f0..acbe112 100644 --- a/gridfm_graphkit/datasets/rrwp.py +++ b/gridfm_graphkit/datasets/rrwp.py @@ -1,20 +1,7 @@ -# ------------------------ : new rwpse ---------------- -from typing import Union, Any, Optional -import numpy as np +from typing import Any, Optional import torch import torch.nn.functional as F -import torch_geometric as pyg -from torch_geometric.data import Data, HeteroData -from torch_geometric.transforms import BaseTransform -from torch_scatter import scatter, scatter_add, scatter_max - - -from torch_geometric.utils import ( - get_laplacian, - get_self_loop_attr, - to_scipy_sparse_matrix, -) -import torch_sparse +from torch_geometric.data import Data from torch_sparse import SparseTensor @@ -42,8 +29,6 @@ def add_full_rrwp(data, spd=False, **kwargs ): - device=data.edge_index.device - ind_vec = torch.eye(walk_length, dtype=torch.float, device=device) num_nodes = data.num_nodes edge_index, edge_weight = data.edge_index, data.edge_weight diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index f95ffc7..a1ffc4a 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -8,11 +8,12 @@ import opt_einsum as oe - import warnings + def pyg_softmax(src, index, num_nodes=None): - r"""Computes a sparsely evaluated softmax. + """ + Computes a sparsely evaluated softmax. Given a value tensor :attr:`src`, this function first groups the values along the first dimension based on the indices specified in :attr:`index`, and then proceeds to compute the softmax individually for each group. @@ -23,7 +24,8 @@ def pyg_softmax(src, index, num_nodes=None): num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`index`. 
(default: :obj:`None`) - :rtype: :class:`Tensor` + Returns: + out (Tensor) """ num_nodes = maybe_num_nodes(index, num_nodes) @@ -39,7 +41,7 @@ def pyg_softmax(src, index, num_nodes=None): class MultiHeadAttentionLayerGritSparse(nn.Module): """ - Proposed Attention Computation for GRIT + Attention Computation for GRIT """ def __init__(self, in_dim, out_dim, num_heads, use_bias, @@ -140,7 +142,7 @@ def forward(self, batch): class GritTransformerLayer(nn.Module): """ - Proposed Transformer Layer for GRIT + Transformer Layer for GRIT """ def __init__(self, in_dim, out_dim, num_heads, dropout=0.0, diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 7caed0f..a1717d1 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -12,7 +12,8 @@ class BatchNorm1dNode(torch.nn.Module): Args: dim_in (int): BatchNorm input dimension. - TODO fill in comments + eps (float): BatchNorm eps. + momentum (float): BatchNorm momentum. """ def __init__(self, dim_in, eps, momentum): super().__init__() @@ -88,7 +89,7 @@ def forward(self, batch): class GraphHead(nn.Module): """ - Prediction head for graph prediction tasks. + Prediction head for decoding tasks. Args: dim_in (int): Input dimension. dim_out (int): Output dimension. For binary prediction, dim_out=1. @@ -97,34 +98,18 @@ class GraphHead(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() - # self.deg_scaler = False - # self.fwl = False - - # list_FC_layers = [ - # nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) - # for l in range(L)] - # list_FC_layers.append( - # nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) + self.FC_layers = nn.Sequential( nn.Linear(dim_in, dim_in), nn.LeakyReLU(), nn.Linear(dim_in, dim_out), - ) #nn.ModuleList(list_FC_layers) - # self.L = L - # self.activation = register.act_dict[cfg.gnn.act]() - # note: modified to add () in the end from original code of 'GPS' - # potentially due to the change of PyG/GraphGym version + ) def _apply_index(self, batch): return batch.graph_feature, batch.y def forward(self, batch): - # graph_emb = self.pooling_fun(batch.x, batch.batch) graph_emb = self.FC_layers(batch.x) - # for l in range(self.L): - # graph_emb = self.FC_layers[l](graph_emb) - # graph_emb = self.activation(graph_emb) - # graph_emb = self.FC_layers[self.L](graph_emb) batch.graph_feature = graph_emb pred, label = self._apply_index(batch) return pred @@ -132,9 +117,12 @@ def forward(self, batch): @MODELS_REGISTRY.register("GRIT") class GritTransformer(torch.nn.Module): - ''' - The proposed GritTransformer (Graph Inductive Bias Transformer) - ''' + """ + The GritTransformer (Graph Inductive Bias Transformer) from + Graph Inductive Biases in Transformers without Message Passing, L. Ma et al., + 2023. + + """ def __init__(self, args): super().__init__() @@ -204,20 +192,19 @@ def __init__(self, args): self.layers = nn.Sequential(*layers) - # self.decoder = nn.Sequential( - # nn.Linear(dim_inner, dim_inner), - # nn.LeakyReLU(), - # nn.Linear(dim_inner, dim_out), - # ) - self.decoder = GraphHead(dim_inner, dim_out) - def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters - #print('process--->>', batch) # TODO remove print + def forward(self, batch): + """ + Forward pass for GRIT. + + Args: + batch (Batch): Pytorch Geometric Batch object, with x, y encodings, etc. + + Returns: + output (Tensor): Output node features of shape [num_nodes, output_dim]. 
+        """
         for module in self.children():
-            # print('----------')
-            # print(module)
             batch = module(batch)
-            # print('--passed--')
         return batch
diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py
index 2dadd35..1f7fd10 100644
--- a/gridfm_graphkit/models/rrwp_encoder.py
+++ b/gridfm_graphkit/models/rrwp_encoder.py
@@ -51,10 +51,10 @@ def full_edge_index(edge_index, batch=None):
 
 class RRWPLinearNodeEncoder(torch.nn.Module):
     """
-        FC_1(RRWP) + FC_2 (Node-attr)
-        note: FC_2 is given by the Typedict encoder of node-attr in some cases
-        Parameters:
-        num_classes - the number of classes for the embedding mapping to learn
+    FC_1(RRWP) + FC_2 (Node-attr)
+    note: FC_2 is given by the Typedict encoder of node-attr in some cases
+    Parameters:
+    num_classes - the number of classes for the embedding mapping to learn
     """
     def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="rrwp"):
         super().__init__()
@@ -90,14 +90,14 @@ def forward(self, batch):
 
 
 class RRWPLinearEdgeEncoder(torch.nn.Module):
-    '''
-        Merge RRWP with given edge-attr and Zero-padding to all pairs of node
-        FC_1(RRWP) + FC_2(edge-attr)
-        - FC_2 given by the TypedictEncoder in same cases
-        - Zero-padding for non-existing edges in fully-connected graph
-        - (optional) add node-attr as the E_{i,i}'s attr
-        note: assuming node-attr and edge-attr is with the same dimension after Encoders
-    '''
+    """
+    Merge RRWP with the given edge-attr and zero-padding to all pairs of nodes
+    FC_1(RRWP) + FC_2(edge-attr)
+    - FC_2 given by the TypedictEncoder in some cases
+    - Zero-padding for non-existing edges in the fully-connected graph
+    - (optional) add node-attr as the E_{i,i}'s attr
+    note: assumes node-attr and edge-attr have the same dimension after the encoders
+    """
     def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False,
                  pad_to_full_graph=True,
                  fill_value=0.,
                  add_node_attr_as_self_loop=False,
@@ -143,10 +143,7 @@ def forward(self, batch):
         if self.overwrite_old_attr:
             out_idx, out_val = rrwp_idx, rrwp_val
         else:
-            # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.)
             edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.)
-            # print('xxxx', edge_attr.size(), rrwp_val.size())
-            # print('yyyy', edge_index.size(), rrwp_idx.size())
             out_idx, out_val = torch_sparse.coalesce(
                 torch.cat([edge_index, rrwp_idx], dim=1),
                 torch.cat([edge_attr, rrwp_val], dim=0),
diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py
index da2f478..0d1743b 100644
--- a/gridfm_graphkit/tasks/feature_reconstruction_task.py
+++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py
@@ -111,13 +111,6 @@ def on_fit_start(self):
 
     def shared_step(self, batch):
         output = self.forward(
-            # TODO update args list in the GPS Transf.
for consistency - # x=batch.x, - # pe=batch.pe, - # edge_index=batch.edge_index, - # edge_attr=batch.edge_attr, - # batch=batch.batch, - # mask=batch.mask, batch ) From 53d564415e4a8c99f10c7882c343bf38fdadf766 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:44 -0500 Subject: [PATCH 26/26] deep copy in test method Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/feature_reconstruction_task.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index 0d1743b..e6a4749 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -5,6 +5,7 @@ import numpy as np import os import pandas as pd +import copy from lightning.pytorch.loggers import MLFlowLogger from gridfm_graphkit.io.param_handler import load_model, get_loss_function @@ -162,7 +163,7 @@ def validation_step(self, batch, batch_idx): return loss_dict["loss"] def test_step(self, batch, batch_idx, dataloader_idx=0): - output, loss_dict = self.shared_step(batch) + output, loss_dict = self.shared_step(copy.deepcopy(batch)) dataset_name = self.args.data.networks[dataloader_idx]
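Note on the deep copy in PATCH 26: `shared_step` pushes the batch through the model, and the modules modify the `Batch` object in place (the decoder head, for instance, writes `graph_feature` onto it), so reusing the same object after `test_step` would expose altered features. The snippet below is a minimal sketch of that rationale; `fake_shared_step` is a hypothetical stand-in, not part of gridfm_graphkit.

```python
import copy

import torch
from torch_geometric.data import Data


def fake_shared_step(batch):
    # Hypothetical stand-in for the task's shared_step: like the model modules,
    # it mutates the batch in place (rescales x, attaches graph_feature).
    batch.x = batch.x * 2.0
    batch.graph_feature = batch.x.sum(dim=0)
    return batch.graph_feature


original = Data(x=torch.ones(4, 3), y=torch.zeros(4, 6))

# Deep-copying before the step keeps the original test batch untouched, so any
# later use of its features (per-network metrics, unnormalization) is unaffected.
_ = fake_shared_step(copy.deepcopy(original))
assert torch.equal(original.x, torch.ones(4, 3))
```

Under the same assumption, PyG's `Data.clone()` could serve as an equivalent guard for tensor attributes; `copy.deepcopy` is simply the most general option.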