diff --git a/README.md b/README.md index 78e8661..5a096dc 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ cd gridfm-graphkit python -m venv venv source venv/bin/activate pip install -e . +pip install torch_sparse torch_scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html ``` For documentation generation and unit testing, install with the optional `dev` and `test` extras: diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml new file mode 100644 index 0000000..8f11c93 --- /dev/null +++ b/examples/config/grit_pretraining.yaml @@ -0,0 +1,88 @@ +callbacks: + patience: 100 + tol: 0 +data: + baseMVA: 100 + learn_mask: false + mask_dim: 6 + mask_ratio: 0.5 + mask_type: rnd + mask_value: -1.0 + networks: + # - Texas2k_case1_2016summerpeak + - case24_ieee_rts + - case118_ieee + - case300_ieee + - case89_pegase + - case240_pserc + normalization: baseMVAnorm + scenarios: + # - 5000 + - 50000 + - 50000 + - 30000 + - 50000 + - 50000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 0 + posenc_RRWP: # TODO maybe better with data section... + enable: True + ksteps: 21 + add_identity: True + add_node_attr: False + add_inverse: False +model: + attention_head: 8 + dropout: 0.1 + edge_dim: 2 + hidden_size: 116 # `gt.dim_hidden` must match `gnn.dim_inner` + input_dim: 9 + num_layers: 10 + output_dim: 6 + pe_dim: 20 + type: GRIT + act: relu + encoder: + node_encoder: True + edge_encoder: True + node_encoder_name: TODO + node_encoder_bn: True + edge_encoder_bn: True + gt: + layer_type: GritTransformer + dim_hidden: 116 # `gt.dim_hidden` must match `gnn.dim_inner` + layer_norm: False + batch_norm: True + update_e: True + attn_dropout: 0.2 + attn: + clamp: 5. + act: 'relu' + full_attn: True + edge_enhance: True + O_e: True + norm_e: True + signed_sqrt: True + bn_momentum: 0.1 + bn_no_runner: False + +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + loss_weights: + - 0.01 + - 0.99 + losses: + - MaskedMSE + - PBE + accelerator: auto + devices: auto + strategy: auto diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index a7507c1..79cb772 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -77,6 +77,11 @@ def main_cli(args): max_epochs=config_args.training.epochs, callbacks=get_training_callbacks(config_args), ) + + # print('******model*****') + # print(model) + # print('******model*****') + if args.command == "train" or args.command == "finetune": trainer.fit(model=model, datamodule=litGrid) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py new file mode 100644 index 0000000..21e7841 --- /dev/null +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -0,0 +1,70 @@ +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F + +from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix, + to_undirected, to_dense_adj) +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter_add +from functools import partial +from gridfm_graphkit.datasets.rrwp import add_full_rrwp + +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import Data +from typing import Any + +def compute_posenc_stats(data, pe_types, cfg): + """Precompute positional encodings for the given graph. + Supported PE statistics to precompute in original implementation, + selected by `pe_types`: + 'LapPE': Laplacian eigen-decomposition. + 'RWSE': Random walk landing probabilities (diagonals of RW matrices). + 'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED) + 'HKdiagSE': Diagonals of heat kernel diffusion. + 'ElstaticSE': Kernel based on the electrostatic interaction between nodes. + 'RRWP': Relative Random Walk Probabilities PE (Ours, for GRIT) + Args: + data: PyG graph + pe_types: Positional encoding types to precompute statistics for. + This can also be a combination, e.g. 'eigen+rw_landing' + is_undirected: True if the graph is expected to be undirected + cfg: Main configuration node + + Returns: + Extended PyG Data object. + """ + # Verify PE types. + for t in pe_types: + if t not in ['LapPE', 'EquivStableLapPE', 'SignNet', + 'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']: + raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}") + + if 'RRWP' in pe_types: + param = cfg.posenc_RRWP + transform = partial(add_full_rrwp, + walk_length=param.ksteps, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True + ) + data = transform(data) + + return data + + +class ComputePosencStat(BaseTransform): + def __init__(self, pe_types, cfg): + self.pe_types = pe_types + self.cfg = cfg + + def forward(self, data: Any) -> Any: + pass + + def __call__(self, data: Data) -> Data: + data = compute_posenc_stats(data, + pe_types=self.pe_types, + cfg=self.cfg + ) + return data diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index c18c360..4b0320f 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -10,6 +10,11 @@ ) from gridfm_graphkit.datasets.utils import split_dataset from gridfm_graphkit.datasets.powergrid_dataset import GridDatasetDisk + +from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat + +import torch_geometric.transforms as T + import numpy as np import random import warnings @@ -129,6 +134,16 @@ def setup(self, stage: str): mask_dim=self.args.data.mask_dim, transform=get_transform(args=self.args), ) + + if ('posenc_RRWP' in self.args.data) and self.args.data.posenc_RRWP.enable: + pe_transform = ComputePosencStat(pe_types=['RRWP'], + cfg=self.args.data + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.Compose([pe_transform, dataset.transform]) + self.datasets.append(dataset) num_scenarios = self.args.data.scenarios[i] diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py new file mode 100644 index 0000000..acbe112 --- /dev/null +++ b/gridfm_graphkit/datasets/rrwp.py @@ -0,0 +1,87 @@ +from typing import Any, Optional +import torch +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_sparse import SparseTensor + + +def add_node_attr(data: Data, value: Any, + attr_name: Optional[str] = None) -> Data: + if attr_name is None: + if 'x' in data: + x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x + data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1) + else: + data.x = value + else: + data[attr_name] = value + + return data + + + +@torch.no_grad() +def add_full_rrwp(data, + walk_length=8, + attr_name_abs="rrwp", # name: 'rrwp' + attr_name_rel="rrwp", # name: ('rrwp_idx', 'rrwp_val') + add_identity=True, + spd=False, + **kwargs + ): + num_nodes = data.num_nodes + edge_index, edge_weight = data.edge_index, data.edge_weight + + adj = SparseTensor.from_edge_index(edge_index, edge_weight, + sparse_sizes=(num_nodes, num_nodes), + ) + + # Compute D^{-1} A: + deg = adj.sum(dim=1) + deg_inv = 1.0 / adj.sum(dim=1) + deg_inv[deg_inv == float('inf')] = 0 + adj = adj * deg_inv.view(-1, 1) + adj = adj.to_dense() + + pe_list = [] + i = 0 + if add_identity: + pe_list.append(torch.eye(num_nodes, dtype=torch.float)) + i = i + 1 + + out = adj + pe_list.append(adj) + + if walk_length > 2: + for j in range(i + 1, walk_length): + out = out @ adj + pe_list.append(out) + + pe = torch.stack(pe_list, dim=-1) # n x n x k + + abs_pe = pe.diagonal().transpose(0, 1) # n x k + + rel_pe = SparseTensor.from_dense(pe, has_value=True) + rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo() + # rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0) + rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0) + # the framework of GRIT performing right-mul while adj is row-normalized, + # need to switch the order or row and col. + # note: both can work but the current version is more reasonable. + + + if spd: + spd_idx = walk_length - torch.arange(walk_length) + val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0) + val = torch.argmax(val, dim=-1) + rel_pe_val = F.one_hot(val, walk_length).type(torch.float) + abs_pe = torch.zeros_like(abs_pe) + + data = add_node_attr(data, abs_pe, attr_name=attr_name_abs) + data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index") + data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val") + data.log_deg = torch.log(deg + 1) + data.deg = deg.type(torch.long) + + return data + diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index de355d3..cc66936 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,5 @@ from gridfm_graphkit.models.gps_transformer import GPSTransformer from gridfm_graphkit.models.gnn_transformer import GNN_TransformerConv +from gridfm_graphkit.models.grit_transformer import GritTransformer -__all__ = ["GPSTransformer", "GNN_TransformerConv"] +__all__ = ["GPSTransformer", "GNN_TransformerConv", "GRIT"] diff --git a/gridfm_graphkit/models/gnn_transformer.py b/gridfm_graphkit/models/gnn_transformer.py index 9e1ab23..37d3632 100644 --- a/gridfm_graphkit/models/gnn_transformer.py +++ b/gridfm_graphkit/models/gnn_transformer.py @@ -74,7 +74,7 @@ def __init__(self, args): requires_grad=False, ) - def forward(self, x, pe, edge_index, edge_attr, batch): + def forward(self, data_batch): """ Forward pass for the GPSTransformer. @@ -88,6 +88,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. """ + x=data_batch.x + pe=data_batch.pe + edge_index=data_batch.edge_index + edge_attr=data_batch.edge_attr + batch=data_batch.batch + for conv in self.layers: x = conv(x, edge_index, edge_attr) x = nn.LeakyReLU()(x) diff --git a/gridfm_graphkit/models/gps_transformer.py b/gridfm_graphkit/models/gps_transformer.py index cc8b648..ca45c5a 100644 --- a/gridfm_graphkit/models/gps_transformer.py +++ b/gridfm_graphkit/models/gps_transformer.py @@ -105,7 +105,7 @@ def __init__(self, args): requires_grad=False, ) - def forward(self, x, pe, edge_index, edge_attr, batch): + def forward(self, data_batch): """ Forward pass for the GPSTransformer. @@ -119,6 +119,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. """ + x=data_batch.x + pe=data_batch.pe + edge_index=data_batch.edge_index + edge_attr=data_batch.edge_attr + batch=data_batch.batch + x_pe = self.pe_norm(pe) x = self.encoder(x) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py new file mode 100644 index 0000000..a1ffc4a --- /dev/null +++ b/gridfm_graphkit/models/grit_layer.py @@ -0,0 +1,347 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_geometric as pyg +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter, scatter_max, scatter_add + +import opt_einsum as oe + +import warnings + + +def pyg_softmax(src, index, num_nodes=None): + """ + Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + + Args: + src (Tensor): The source tensor. + index (LongTensor): The indices of elements for applying the softmax. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + + Returns: + out (Tensor) + """ + + num_nodes = maybe_num_nodes(index, num_nodes) + + out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index] + out = out.exp() + out = out / ( + scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16) + + return out + + + +class MultiHeadAttentionLayerGritSparse(nn.Module): + """ + Attention Computation for GRIT + """ + + def __init__(self, in_dim, out_dim, num_heads, use_bias, + clamp=5., dropout=0., act=None, + edge_enhance=True, + sqrt_relu=False, + signed_sqrt=True, + cfg={}, + **kwargs): + super().__init__() + + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = nn.Dropout(dropout) + self.clamp = np.abs(clamp) if clamp is not None else None + self.edge_enhance = edge_enhance + + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + self.E = nn.Linear(in_dim, out_dim * num_heads * 2, bias=True) + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + nn.init.xavier_normal_(self.Q.weight) + nn.init.xavier_normal_(self.K.weight) + nn.init.xavier_normal_(self.E.weight) + nn.init.xavier_normal_(self.V.weight) + + self.Aw = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True) + nn.init.xavier_normal_(self.Aw) + + if act is None: + self.act = nn.Identity() + else: + self.act = nn.ReLU() + + if self.edge_enhance: + self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True) + nn.init.xavier_normal_(self.VeRow) + + def propagate_attention(self, batch): + src = batch.K_h[batch.edge_index[0]] # (num relative) x num_heads x out_dim + dest = batch.Q_h[batch.edge_index[1]] # (num relative) x num_heads x out_dim + score = src + dest # element-wise multiplication + + if batch.get("E", None) is not None: + batch.E = batch.E.view(-1, self.num_heads, self.out_dim * 2) + E_w, E_b = batch.E[:, :, :self.out_dim], batch.E[:, :, self.out_dim:] + # (num relative) x num_heads x out_dim + score = score * E_w + score = torch.sqrt(torch.relu(score)) - torch.sqrt(torch.relu(-score)) + score = score + E_b + + score = self.act(score) + e_t = score + + # output edge + if batch.get("E", None) is not None: + batch.wE = score.flatten(1) + + # final attn + score = oe.contract("ehd, dhc->ehc", score, self.Aw, backend="torch") + if self.clamp is not None: + score = torch.clamp(score, min=-self.clamp, max=self.clamp) + + raw_attn = score + score = pyg_softmax(score, batch.edge_index[1]) # (num relative) x num_heads x 1 + score = self.dropout(score) + batch.attn = score + + # Aggregate with Attn-Score + msg = batch.V_h[batch.edge_index[0]] * score # (num relative) x num_heads x out_dim + batch.wV = torch.zeros_like(batch.V_h) # (num nodes in batch) x num_heads x out_dim + scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce='add') + + if self.edge_enhance and batch.E is not None: + rowV = scatter(e_t * score, batch.edge_index[1], dim=0, reduce="add") + rowV = oe.contract("nhd, dhc -> nhc", rowV, self.VeRow, backend="torch") + batch.wV = batch.wV + rowV + + def forward(self, batch): + Q_h = self.Q(batch.x) + K_h = self.K(batch.x) + + V_h = self.V(batch.x) + if batch.get("edge_attr", None) is not None: + batch.E = self.E(batch.edge_attr) + else: + batch.E = None + + batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim) + batch.K_h = K_h.view(-1, self.num_heads, self.out_dim) + batch.V_h = V_h.view(-1, self.num_heads, self.out_dim) + self.propagate_attention(batch) + h_out = batch.wV + e_out = batch.get('wE', None) + + return h_out, e_out + + +class GritTransformerLayer(nn.Module): + """ + Transformer Layer for GRIT + """ + def __init__(self, in_dim, out_dim, num_heads, + dropout=0.0, + attn_dropout=0.0, + layer_norm=False, batch_norm=True, + residual=True, + act='relu', + norm_e=True, + O_e=True, + cfg=dict(), + **kwargs): + super().__init__() + + self.debug = False + self.in_channels = in_dim + self.out_channels = out_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + # ------- + self.update_e = getattr(cfg.attn, "update_e", True) + self.bn_momentum = cfg.attn.bn_momentum + self.bn_no_runner = cfg.attn.bn_no_runner + self.rezero = getattr(cfg.attn, "rezero", False) + + if act is not None: + self.act = nn.ReLU() + else: + self.act = nn.Identity() + + if getattr(cfg, "attn", None) is None: + cfg.attn = dict() + self.use_attn = getattr(cfg.attn, "use", True) + self.deg_scaler = getattr(cfg.attn, "deg_scaler", True) + + self.attention = MultiHeadAttentionLayerGritSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=getattr(cfg.attn, "use_bias", False), + dropout=attn_dropout, + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), + edge_enhance=getattr(cfg.attn, "edge_enhance", True), + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn,"scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), + ) + + if getattr(cfg.attn, 'graphormer_attn', False): + self.attention = MultiHeadAttentionLayerGraphormerSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=getattr(cfg.attn, "use_bias", False), + dropout=attn_dropout, + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), + edge_enhance=True, + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn, "scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), + ) + + + + self.O_h = nn.Linear(out_dim//num_heads * num_heads, out_dim) + if O_e: + self.O_e = nn.Linear(out_dim//num_heads * num_heads, out_dim) + else: + self.O_e = nn.Identity() + + # -------- Deg Scaler Option ------ + + if self.deg_scaler: + self.deg_coef = nn.Parameter(torch.zeros(1, out_dim//num_heads * num_heads, 2)) + nn.init.xavier_normal_(self.deg_coef) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + self.layer_norm1_e = nn.LayerNorm(out_dim) if norm_e else nn.Identity() + + if self.batch_norm: + # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) + self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) + self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) if norm_e else nn.Identity() + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) + self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) + + if self.rezero: + self.alpha1_h = nn.Parameter(torch.zeros(1,1)) + self.alpha2_h = nn.Parameter(torch.zeros(1,1)) + self.alpha1_e = nn.Parameter(torch.zeros(1,1)) + + def forward(self, batch): + h = batch.x + num_nodes = batch.num_nodes + log_deg = get_log_deg(batch) + + h_in1 = h # for first residual connection + e_in1 = batch.get("edge_attr", None) + e = None + # multi-head attention out + + h_attn_out, e_attn_out = self.attention(batch) + + h = h_attn_out.view(num_nodes, -1) + h = F.dropout(h, self.dropout, training=self.training) + + # degree scaler + if self.deg_scaler: + h = torch.stack([h, h * log_deg], dim=-1) + h = (h * self.deg_coef).sum(dim=-1) + + h = self.O_h(h) + if e_attn_out is not None: + e = e_attn_out.flatten(1) + e = F.dropout(e, self.dropout, training=self.training) + e = self.O_e(e) + + if self.residual: + if self.rezero: h = h * self.alpha1_h + h = h_in1 + h # residual connection + if e is not None: + if self.rezero: e = e * self.alpha1_e + e = e + e_in1 + + if self.layer_norm: + h = self.layer_norm1_h(h) + if e is not None: e = self.layer_norm1_e(e) + + if self.batch_norm: + h = self.batch_norm1_h(h) + if e is not None: e = self.batch_norm1_e(e) + + # FFN for h + h_in2 = h # for second residual connection + h = self.FFN_h_layer1(h) + h = self.act(h) + h = F.dropout(h, self.dropout, training=self.training) + h = self.FFN_h_layer2(h) + + if self.residual: + if self.rezero: h = h * self.alpha2_h + h = h_in2 + h # residual connection + + if self.layer_norm: + h = self.layer_norm2_h(h) + + if self.batch_norm: + h = self.batch_norm2_h(h) + + batch.x = h + if self.update_e: + batch.edge_attr = e + else: + batch.edge_attr = e_in1 + + return batch + + def __repr__(self): + return '{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]'.format( + self.__class__.__name__, + self.in_channels, + self.out_channels, self.num_heads, self.residual, + super().__repr__(), + ) + + +@torch.no_grad() +def get_log_deg(batch): + if "log_deg" in batch: + log_deg = batch.log_deg + elif "deg" in batch: + deg = batch.deg + log_deg = torch.log(deg + 1).unsqueeze(-1) + else: + warnings.warn("Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs") + deg = pyg.utils.degree(batch.edge_index[1], + num_nodes=batch.num_nodes, + dtype=torch.float + ) + log_deg = torch.log(deg + 1) + log_deg = log_deg.view(batch.num_nodes, 1) + return log_deg + + diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py new file mode 100644 index 0000000..a1717d1 --- /dev/null +++ b/gridfm_graphkit/models/grit_transformer.py @@ -0,0 +1,210 @@ +from gridfm_graphkit.io.registries import MODELS_REGISTRY +import torch +from torch import nn + +from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder +from gridfm_graphkit.models.grit_layer import GritTransformerLayer + + + +class BatchNorm1dNode(torch.nn.Module): + r"""A batch normalization layer for node-level features. + + Args: + dim_in (int): BatchNorm input dimension. + eps (float): BatchNorm eps. + momentum (float): BatchNorm momentum. + """ + def __init__(self, dim_in, eps, momentum): + super().__init__() + self.bn = torch.nn.BatchNorm1d( + dim_in, + eps=eps, + momentum=momentum, + ) + + def forward(self, batch): + batch.x = self.bn(batch.x) + return batch + + +class LinearNodeEncoder(torch.nn.Module): + def __init__(self, dim_in, emb_dim): + super().__init__() + + self.encoder = torch.nn.Linear(dim_in, emb_dim) + + def forward(self, batch): + batch.x = self.encoder(batch.x) + return batch + +class LinearEdgeEncoder(torch.nn.Module): + def __init__(self, edge_dim, emb_dim): + super().__init__() + + self.in_dim = edge_dim + + self.encoder = torch.nn.Linear(self.in_dim, emb_dim) + + def forward(self, batch): + batch.edge_attr = self.encoder(batch.edge_attr.view(-1, self.in_dim)) + return batch + + +class FeatureEncoder(torch.nn.Module): + """ + Encoding node and edge features + + Args: + dim_in (int): Input feature dimension + + """ + def __init__( + self, + dim_in, + dim_inner, + args + ): + super(FeatureEncoder, self).__init__() + self.dim_in = dim_in + if args.encoder.node_encoder: + # Encode integer node features via nn.Embeddings + self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) + if args.encoder.node_encoder_bn: + self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) + # Update dim_in to reflect the new dimension fo the node features + self.dim_in = dim_inner + if args.encoder.edge_encoder: + edge_dim = args.edge_dim + enc_dim_edge = dim_inner + # Encode integer edge features via nn.Embeddings + self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) + if args.encoder.edge_encoder_bn: + self.edge_encoder_bn = BatchNorm1dNode(enc_dim_edge, 1e-5, 0.1) + + def forward(self, batch): + for module in self.children(): + batch = module(batch) + return batch + +class GraphHead(nn.Module): + """ + Prediction head for decoding tasks. + Args: + dim_in (int): Input dimension. + dim_out (int): Output dimension. For binary prediction, dim_out=1. + L (int): Number of hidden layers. + """ + + def __init__(self, dim_in, dim_out): + super().__init__() + + self.FC_layers = nn.Sequential( + nn.Linear(dim_in, dim_in), + nn.LeakyReLU(), + nn.Linear(dim_in, dim_out), + ) + + def _apply_index(self, batch): + return batch.graph_feature, batch.y + + def forward(self, batch): + graph_emb = self.FC_layers(batch.x) + batch.graph_feature = graph_emb + pred, label = self._apply_index(batch) + return pred + + +@MODELS_REGISTRY.register("GRIT") +class GritTransformer(torch.nn.Module): + """ + The GritTransformer (Graph Inductive Bias Transformer) from + Graph Inductive Biases in Transformers without Message Passing, L. Ma et al., + 2023. + + """ + def __init__(self, args): + super().__init__() + + + dim_in = args.model.input_dim + dim_out = args.model.output_dim + dim_inner = args.model.hidden_size + dim_edge = args.model.edge_dim + num_heads = args.model.attention_head + dropout = args.model.dropout + num_layers = args.model.num_layers + self.mask_dim = getattr(args.data, "mask_dim", 6) + self.mask_value = getattr(args.data, "mask_value", -1.0) + self.learn_mask = getattr(args.data, "learn_mask", False) + if self.learn_mask: + self.mask_value = nn.Parameter( + torch.randn(self.mask_dim) + self.mask_value, + requires_grad=True, + ) + else: + self.mask_value = nn.Parameter( + torch.zeros(self.mask_dim) + self.mask_value, + requires_grad=False, + ) + + self.encoder = FeatureEncoder( + dim_in, + dim_inner, + args.model + ) + dim_in = self.encoder.dim_in + + if args.data.posenc_RRWP.enable: + + self.rrwp_abs_encoder = RRWPLinearNodeEncoder( + args.data.posenc_RRWP.ksteps, + dim_inner + ) + rel_pe_dim = args.data.posenc_RRWP.ksteps + self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( + rel_pe_dim, + dim_inner, + pad_to_full_graph=args.model.gt.attn.full_attn, + add_node_attr_as_self_loop=False, + fill_value=0. + ) + + assert args.model.hidden_size == dim_inner == dim_in, \ + "The inner and hidden dims must match." + + layers = [] + for ll in range(num_layers): + layers.append(GritTransformerLayer( + in_dim=args.model.gt.dim_hidden, + out_dim=args.model.gt.dim_hidden, + num_heads=num_heads, + dropout=dropout, + act=args.model.act, + attn_dropout=args.model.gt.attn_dropout, + layer_norm=args.model.gt.layer_norm, + batch_norm=args.model.gt.batch_norm, + residual=True, + norm_e=args.model.gt.attn.norm_e, + O_e=args.model.gt.attn.O_e, + cfg=args.model.gt, + )) + + self.layers = nn.Sequential(*layers) + + self.decoder = GraphHead(dim_inner, dim_out) + + def forward(self, batch): + """ + Forward pass for GRIT. + + Args: + batch (Batch): Pytorch Geometric Batch object, with x, y encodings, etc. + + Returns: + output (Tensor): Output node features of shape [num_nodes, output_dim]. + """ + for module in self.children(): + batch = module(batch) + + return batch diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py new file mode 100644 index 0000000..1f7fd10 --- /dev/null +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -0,0 +1,183 @@ +""" + The RRWP encoder for GRIT (ours) +""" +import torch +from torch import nn +from torch.nn import functional as F +import torch_sparse + +from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops +from torch_scatter import scatter +import warnings + +def full_edge_index(edge_index, batch=None): + """ + Return the Full batched sparse adjacency matrices given by edge indices. + Returns batched sparse adjacency matrices with exactly those edges that + are not in the input `edge_index` while ignoring self-loops. + Implementation inspired by `torch_geometric.utils.to_dense_adj` + Args: + edge_index: The edge indices. + batch: Batch vector, which assigns each node to a specific example. + Returns: + Complementary edge index. + """ + + if batch is None: + batch = edge_index.new_zeros(edge_index.max().item() + 1) + + batch_size = batch.max().item() + 1 + one = batch.new_ones(batch.size(0)) + num_nodes = scatter(one, batch, + dim=0, dim_size=batch_size, reduce='add') + cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) + + negative_index_list = [] + for i in range(batch_size): + n = num_nodes[i].item() + size = [n, n] + adj = torch.ones(size, dtype=torch.short, + device=edge_index.device) + + adj = adj.view(size) + _edge_index = adj.nonzero(as_tuple=False).t().contiguous() + # _edge_index, _ = remove_self_loops(_edge_index) + negative_index_list.append(_edge_index + cum_nodes[i]) + + edge_index_full = torch.cat(negative_index_list, dim=1).contiguous() + return edge_index_full + + + +class RRWPLinearNodeEncoder(torch.nn.Module): + """ + FC_1(RRWP) + FC_2 (Node-attr) + note: FC_2 is given by the Typedict encoder of node-attr in some cases + Parameters: + num_classes - the number of classes for the embedding mapping to learn + """ + def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="rrwp"): + super().__init__() + self.batchnorm = batchnorm + self.layernorm = layernorm + self.name = pe_name + + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + # Encode just the first dimension if more exist + rrwp = batch[f"{self.name}"] + rrwp = self.fc(rrwp) + + if self.batchnorm: + rrwp = self.bn(rrwp) + + if self.layernorm: + rrwp = self.ln(rrwp) + + if "x" in batch: + batch.x = batch.x + rrwp + else: + batch.x = rrwp + + return batch + + +class RRWPLinearEdgeEncoder(torch.nn.Module): + """ + Merge RRWP with given edge-attr and Zero-padding to all pairs of node + FC_1(RRWP) + FC_2(edge-attr) + - FC_2 given by the TypedictEncoder in same cases + - Zero-padding for non-existing edges in fully-connected graph + - (optional) add node-attr as the E_{i,i}'s attr + note: assuming node-attr and edge-attr is with the same dimension after Encoders + """ + def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False, + pad_to_full_graph=True, fill_value=0., + add_node_attr_as_self_loop=False, + overwrite_old_attr=False): + super().__init__() + # note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info + self.emb_dim = emb_dim + self.out_dim = out_dim + self.add_node_attr_as_self_loop = add_node_attr_as_self_loop + self.overwrite_old_attr=overwrite_old_attr # remove the old edge-attr + + self.batchnorm = batchnorm + self.layernorm = layernorm + if self.batchnorm or self.layernorm: + warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") + + # print('--------fc in and out:', emb_dim, out_dim) + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + self.pad_to_full_graph = pad_to_full_graph + self.fill_value = 0. + + padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value + self.register_buffer("padding", padding) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + rrwp_idx = batch.rrwp_index + rrwp_val = batch.rrwp_val + edge_index = batch.edge_index + edge_attr = batch.edge_attr + rrwp_val = self.fc(rrwp_val) + + if edge_attr is None: + edge_attr = edge_index.new_zeros(edge_index.size(1), rrwp_val.size(1)) + # zero padding for non-existing edges + + if self.overwrite_old_attr: + out_idx, out_val = rrwp_idx, rrwp_val + else: + edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) + out_idx, out_val = torch_sparse.coalesce( + torch.cat([edge_index, rrwp_idx], dim=1), + torch.cat([edge_attr, rrwp_val], dim=0), + batch.num_nodes, batch.num_nodes, + op="add" + ) + + + if self.pad_to_full_graph: + edge_index_full = full_edge_index(out_idx, batch=batch.batch) + edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1) + # zero padding to fully-connected graphs + out_idx = torch.cat([out_idx, edge_index_full], dim=1) + out_val = torch.cat([out_val, edge_attr_pad], dim=0) + out_idx, out_val = torch_sparse.coalesce( + out_idx, out_val, batch.num_nodes, batch.num_nodes, + op="add" + ) + + if self.batchnorm: + out_val = self.bn(out_val) + + if self.layernorm: + out_val = self.ln(out_val) + + + batch.edge_index, batch.edge_attr = out_idx, out_val + return batch + + def __repr__(self): + return f"{self.__class__.__name__}" \ + f"(pad_to_full_graph={self.pad_to_full_graph}," \ + f"fill_value={self.fill_value}," \ + f"{self.fc.__repr__()})" + + + diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index cb6963b..e6a4749 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -5,6 +5,7 @@ import numpy as np import os import pandas as pd +import copy from lightning.pytorch.loggers import MLFlowLogger from gridfm_graphkit.io.param_handler import load_model, get_loss_function @@ -74,11 +75,11 @@ def __init__(self, args, node_normalizers, edge_normalizers): self.edge_normalizers = edge_normalizers self.save_hyperparameters() - def forward(self, x, pe, edge_index, edge_attr, batch, mask=None): - if mask is not None: - mask_value_expanded = self.model.mask_value.expand(x.shape[0], -1) - x[:, : mask.shape[1]][mask] = mask_value_expanded[mask] - return self.model(x, pe, edge_index, edge_attr, batch) + def forward(self, batch): + if batch.mask is not None: + mask_value_expanded = self.model.mask_value.expand(batch.x.shape[0], -1) + batch.x[:, : batch.mask.shape[1]][batch.mask] = mask_value_expanded[batch.mask] + return self.model(batch) @rank_zero_only def on_fit_start(self): @@ -111,12 +112,7 @@ def on_fit_start(self): def shared_step(self, batch): output = self.forward( - x=batch.x, - pe=batch.pe, - edge_index=batch.edge_index, - edge_attr=batch.edge_attr, - batch=batch.batch, - mask=batch.mask, + batch ) loss_dict = self.loss_fn( @@ -167,7 +163,7 @@ def validation_step(self, batch, batch_idx): return loss_dict["loss"] def test_step(self, batch, batch_idx, dataloader_idx=0): - output, loss_dict = self.shared_step(batch) + output, loss_dict = self.shared_step(copy.deepcopy(batch)) dataset_name = self.args.data.networks[dataloader_idx] diff --git a/pyproject.toml b/pyproject.toml index 51c8665..4a17ed5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ classifiers = [ dependencies = [ - "torch>2.0", + "torch==2.6", "torch-geometric", "mlflow", "nbformat", @@ -51,6 +51,7 @@ dependencies = [ "pyyaml", "lightning", "seaborn", + "opt-einsum", ] [project.optional-dependencies]