Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
39a5862
added basic GRIT code
ttolhurst Nov 17, 2025
922d6ce
initial connection of model to config
ttolhurst Nov 17, 2025
e8281ac
collect model components and replace old register method
ttolhurst Nov 17, 2025
a67e522
clean up imported layers and encoders
ttolhurst Nov 17, 2025
6966f5f
flow in basic structure for RRWP calculation
ttolhurst Nov 17, 2025
a7bd51d
clean up
ttolhurst Nov 17, 2025
226f2a3
matching up parameters
ttolhurst Nov 17, 2025
88d9ca6
matching up parameters
ttolhurst Nov 17, 2025
b7d9dcf
matching up parameters
ttolhurst Nov 17, 2025
38cc44a
matching up parameters
ttolhurst Nov 17, 2025
7fded95
matching up parameters in grit layer
ttolhurst Nov 17, 2025
0f3b803
matching up parameters in grit layer
ttolhurst Nov 17, 2025
af8ad03
matching up parameters in grit layer
ttolhurst Nov 17, 2025
f430f2a
matching up parameters in data module
ttolhurst Nov 17, 2025
e1c4890
flow over parameters from base model
ttolhurst Nov 17, 2025
36dca00
verified encodings and data flow to model forward method
ttolhurst Nov 17, 2025
a8ec56e
match feature dimensions
ttolhurst Nov 17, 2025
0868b96
match feature dimensions
ttolhurst Nov 17, 2025
3cc21a3
reformat decoder to handle batch format
ttolhurst Nov 17, 2025
1783051
confirmed training loop functions
ttolhurst Nov 17, 2025
c75012f
update toml
ttolhurst Nov 17, 2025
3d3f98b
added forward method to transform class
ttolhurst Nov 17, 2025
d238e75
update readme with install instructions
ttolhurst Nov 17, 2025
17b0889
verifed compat with GPS and GNN
ttolhurst Nov 17, 2025
091f084
work on comments and clean up
ttolhurst Nov 17, 2025
53d5644
deep copy in test method
ttolhurst Nov 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ cd gridfm-graphkit
python -m venv venv
source venv/bin/activate
pip install -e .
pip install torch_sparse torch_scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
```

For documentation generation and unit testing, install with the optional `dev` and `test` extras:
Expand Down
88 changes: 88 additions & 0 deletions examples/config/grit_pretraining.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
callbacks:
patience: 100
tol: 0
data:
baseMVA: 100
learn_mask: false
mask_dim: 6
mask_ratio: 0.5
mask_type: rnd
mask_value: -1.0
networks:
# - Texas2k_case1_2016summerpeak
- case24_ieee_rts
- case118_ieee
- case300_ieee
- case89_pegase
- case240_pserc
normalization: baseMVAnorm
scenarios:
# - 5000
- 50000
- 50000
- 30000
- 50000
- 50000
test_ratio: 0.1
val_ratio: 0.1
workers: 0
posenc_RRWP: # TODO maybe better with data section...
enable: True
ksteps: 21
add_identity: True
add_node_attr: False
add_inverse: False
model:
attention_head: 8
dropout: 0.1
edge_dim: 2
hidden_size: 116 # `gt.dim_hidden` must match `gnn.dim_inner`
input_dim: 9
num_layers: 10
output_dim: 6
pe_dim: 20
type: GRIT
act: relu
encoder:
node_encoder: True
edge_encoder: True
node_encoder_name: TODO
node_encoder_bn: True
edge_encoder_bn: True
gt:
layer_type: GritTransformer
dim_hidden: 116 # `gt.dim_hidden` must match `gnn.dim_inner`
layer_norm: False
batch_norm: True
update_e: True
attn_dropout: 0.2
attn:
clamp: 5.
act: 'relu'
full_attn: True
edge_enhance: True
O_e: True
norm_e: True
signed_sqrt: True
bn_momentum: 0.1
bn_no_runner: False

optimizer:
beta1: 0.9
beta2: 0.999
learning_rate: 0.0001
lr_decay: 0.7
lr_patience: 10
seed: 0
training:
batch_size: 8
epochs: 500
loss_weights:
- 0.01
- 0.99
losses:
- MaskedMSE
- PBE
accelerator: auto
devices: auto
strategy: auto
5 changes: 5 additions & 0 deletions gridfm_graphkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def main_cli(args):
max_epochs=config_args.training.epochs,
callbacks=get_training_callbacks(config_args),
)

# print('******model*****')
# print(model)
# print('******model*****')

if args.command == "train" or args.command == "finetune":
trainer.fit(model=model, datamodule=litGrid)

Expand Down
70 changes: 70 additions & 0 deletions gridfm_graphkit/datasets/posenc_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F

from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix,
to_undirected, to_dense_adj)
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter_add
from functools import partial
from gridfm_graphkit.datasets.rrwp import add_full_rrwp

from torch_geometric.transforms import BaseTransform
from torch_geometric.data import Data
from typing import Any

def compute_posenc_stats(data, pe_types, cfg):
"""Precompute positional encodings for the given graph.
Supported PE statistics to precompute in original implementation,
selected by `pe_types`:
'LapPE': Laplacian eigen-decomposition.
'RWSE': Random walk landing probabilities (diagonals of RW matrices).
'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED)
'HKdiagSE': Diagonals of heat kernel diffusion.
'ElstaticSE': Kernel based on the electrostatic interaction between nodes.
'RRWP': Relative Random Walk Probabilities PE (Ours, for GRIT)
Args:
data: PyG graph
pe_types: Positional encoding types to precompute statistics for.
This can also be a combination, e.g. 'eigen+rw_landing'
is_undirected: True if the graph is expected to be undirected
cfg: Main configuration node

Returns:
Extended PyG Data object.
"""
# Verify PE types.
for t in pe_types:
if t not in ['LapPE', 'EquivStableLapPE', 'SignNet',
'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']:
raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}")

if 'RRWP' in pe_types:
param = cfg.posenc_RRWP
transform = partial(add_full_rrwp,
walk_length=param.ksteps,
attr_name_abs="rrwp",
attr_name_rel="rrwp",
add_identity=True
)
data = transform(data)

return data


class ComputePosencStat(BaseTransform):
def __init__(self, pe_types, cfg):
self.pe_types = pe_types
self.cfg = cfg

def forward(self, data: Any) -> Any:
pass

def __call__(self, data: Data) -> Data:
data = compute_posenc_stats(data,
pe_types=self.pe_types,
cfg=self.cfg
)
return data
15 changes: 15 additions & 0 deletions gridfm_graphkit/datasets/powergrid_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
)
from gridfm_graphkit.datasets.utils import split_dataset
from gridfm_graphkit.datasets.powergrid_dataset import GridDatasetDisk

from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat

import torch_geometric.transforms as T

import numpy as np
import random
import warnings
Expand Down Expand Up @@ -129,6 +134,16 @@ def setup(self, stage: str):
mask_dim=self.args.data.mask_dim,
transform=get_transform(args=self.args),
)

if ('posenc_RRWP' in self.args.data) and self.args.data.posenc_RRWP.enable:
pe_transform = ComputePosencStat(pe_types=['RRWP'],
cfg=self.args.data
)
if dataset.transform is None:
dataset.transform = pe_transform
else:
dataset.transform = T.Compose([pe_transform, dataset.transform])

self.datasets.append(dataset)

num_scenarios = self.args.data.scenarios[i]
Expand Down
87 changes: 87 additions & 0 deletions gridfm_graphkit/datasets/rrwp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Any, Optional
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_sparse import SparseTensor


def add_node_attr(data: Data, value: Any,
attr_name: Optional[str] = None) -> Data:
if attr_name is None:
if 'x' in data:
x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1)
else:
data.x = value
else:
data[attr_name] = value

return data



@torch.no_grad()
def add_full_rrwp(data,
walk_length=8,
attr_name_abs="rrwp", # name: 'rrwp'
attr_name_rel="rrwp", # name: ('rrwp_idx', 'rrwp_val')
add_identity=True,
spd=False,
**kwargs
):
num_nodes = data.num_nodes
edge_index, edge_weight = data.edge_index, data.edge_weight

adj = SparseTensor.from_edge_index(edge_index, edge_weight,
sparse_sizes=(num_nodes, num_nodes),
)

# Compute D^{-1} A:
deg = adj.sum(dim=1)
deg_inv = 1.0 / adj.sum(dim=1)
deg_inv[deg_inv == float('inf')] = 0
adj = adj * deg_inv.view(-1, 1)
adj = adj.to_dense()

pe_list = []
i = 0
if add_identity:
pe_list.append(torch.eye(num_nodes, dtype=torch.float))
i = i + 1

out = adj
pe_list.append(adj)

if walk_length > 2:
for j in range(i + 1, walk_length):
out = out @ adj
pe_list.append(out)

pe = torch.stack(pe_list, dim=-1) # n x n x k

abs_pe = pe.diagonal().transpose(0, 1) # n x k

rel_pe = SparseTensor.from_dense(pe, has_value=True)
rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
# rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0)
rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0)
# the framework of GRIT performing right-mul while adj is row-normalized,
# need to switch the order or row and col.
# note: both can work but the current version is more reasonable.


if spd:
spd_idx = walk_length - torch.arange(walk_length)
val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0)
val = torch.argmax(val, dim=-1)
rel_pe_val = F.one_hot(val, walk_length).type(torch.float)
abs_pe = torch.zeros_like(abs_pe)

data = add_node_attr(data, abs_pe, attr_name=attr_name_abs)
data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index")
data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val")
data.log_deg = torch.log(deg + 1)
data.deg = deg.type(torch.long)

return data

3 changes: 2 additions & 1 deletion gridfm_graphkit/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from gridfm_graphkit.models.gps_transformer import GPSTransformer
from gridfm_graphkit.models.gnn_transformer import GNN_TransformerConv
from gridfm_graphkit.models.grit_transformer import GritTransformer

__all__ = ["GPSTransformer", "GNN_TransformerConv"]
__all__ = ["GPSTransformer", "GNN_TransformerConv", "GRIT"]
8 changes: 7 additions & 1 deletion gridfm_graphkit/models/gnn_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, args):
requires_grad=False,
)

def forward(self, x, pe, edge_index, edge_attr, batch):
def forward(self, data_batch):
"""
Forward pass for the GPSTransformer.

Expand All @@ -88,6 +88,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch):
Returns:
output (Tensor): Output node features of shape [num_nodes, output_dim].
"""
x=data_batch.x
pe=data_batch.pe
edge_index=data_batch.edge_index
edge_attr=data_batch.edge_attr
batch=data_batch.batch

for conv in self.layers:
x = conv(x, edge_index, edge_attr)
x = nn.LeakyReLU()(x)
Expand Down
8 changes: 7 additions & 1 deletion gridfm_graphkit/models/gps_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(self, args):
requires_grad=False,
)

def forward(self, x, pe, edge_index, edge_attr, batch):
def forward(self, data_batch):
"""
Forward pass for the GPSTransformer.

Expand All @@ -119,6 +119,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch):
Returns:
output (Tensor): Output node features of shape [num_nodes, output_dim].
"""
x=data_batch.x
pe=data_batch.pe
edge_index=data_batch.edge_index
edge_attr=data_batch.edge_attr
batch=data_batch.batch

x_pe = self.pe_norm(pe)

x = self.encoder(x)
Expand Down
Loading
Loading