Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions gnn_model/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ The pipeline is built with PyTorch Lightning and PyTorch Geometric and features
- A deep Processor uses multiple layers of InteractionNetwork blocks for complex message-passing across the mesh graph.
- A Decoder maps the processed mesh state back to the observation locations to make predictions.

**Scalable & Efficient Training**:

- Supports multi-node, multi-GPU distributed training using PyTorch Lightning's DDPStrategy.
- Implements gradient checkpointing to reduce memory usage, allowing for deeper models and larger batch sizes.
- Features a flexible data pipeline with random window resampling for robust, generalized training on massive time-series datasets.
**Scalable & Efficient Training**:

- Supports multi-node, multi-GPU distributed training using PyTorch Lightning's DDPStrategy.
- Optional PhysicsNeMo domain parallel mesh sharding to split the global mesh across GPUs while keeping data-parallel bin sampling.
- Implements gradient checkpointing to reduce memory usage, allowing for deeper models and larger batch sizes.
- Features a flexible data pipeline with random window resampling for robust, generalized training on massive time-series datasets.

## Core Scripts
- `train_gnn.py`: The main script for launching training and evaluation.
Expand All @@ -33,28 +34,40 @@ Create a Conda environment and install the necessary packages.
pip install -r requirements.txt

```
Or a minimalist install
```bash
pip install numpy pandas scipy torch trimesh networkx torch-geometric scikit-learn zarr joblib lightning psutil

```
Or a minimalist install
```bash
pip install numpy pandas scipy torch trimesh networkx torch-geometric scikit-learn zarr joblib lightning psutil

```
Optional (for domain parallel sharding):
```bash
pip install physicsnemo
```
## Usage
### Configure Your Experiment
Modify `train_gnn.py` to set the hyperparameters for your run:
- Set the full date range for the experiment (FULL_START_DATE, FULL_END_DATE).
- Configure the observation_config dictionary to define which instruments and features to use.
- Adjust model hyperparameters like mesh_resolution, hidden_dim, and num_layers.

### Launch training
Use the provided SLURM script (run_gnn.sh) to launch a multi-node training job, or run directly for a single-machine test:
- To start a new run, use this command from `sbatch run_gnn.sh`:
```bash
srun --cpu-bind=map_cpu:0,1,2,3 python train_gnn.py
```
- To resume a run from the last saved checkpoint, from `sbatch run_gnn.sh`:
```bash
python train_gnn.py --resume_from_checkpoint checkpoints/last.ckpt
```
### Launch training
Use the provided SLURM script (run_gnn.sh) to launch a multi-node training job, or run directly for a single-machine test:
- To start a new run, use this command from `sbatch run_gnn.sh`:
```bash
srun --cpu-bind=map_cpu:0,1,2,3 python train_gnn.py
```
- To enable PhysicsNeMo domain parallel mesh sharding (requires `physicsnemo` installed):
```bash
srun --cpu-bind=map_cpu:0,1,2,3 python train_gnn.py --domain_parallel
```
- To explicitly set the PhysicsNeMo device mesh (optional):
```bash
srun --cpu-bind=map_cpu:0,1,2,3 python train_gnn.py --domain_parallel --domain_parallel_mesh_shape 4 --domain_parallel_mesh_dim_names domain
```
- To resume a run from the last saved checkpoint, from `sbatch run_gnn.sh`:
```bash
python train_gnn.py --resume_from_checkpoint checkpoints/last.ckpt
```
- Then run the script:
```bash
sbatch run_gnn.sh
Expand Down
118 changes: 118 additions & 0 deletions gnn_model/domain_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

from dataclasses import dataclass
import importlib.util
from typing import Optional, Tuple

import torch
import torch.distributed as dist


@dataclass(frozen=True)
class MeshShard:
    """Contiguous slice ``[start, end)`` of a global mesh owned by one rank.

    Immutable record of the 1-D block decomposition: ``num_nodes`` global mesh
    nodes are split across ``world_size`` ranks, and this rank owns the rows
    in ``[start, end)``.
    """

    rank: int
    world_size: int
    start: int
    end: int
    num_nodes: int

    @property
    def local_size(self) -> int:
        """Number of mesh nodes owned by this rank."""
        return self.end - self.start

    def global_to_local(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Return a length-``num_nodes`` tensor mapping global node ids to local ids.

        Entries outside this shard are -1; owned entries map to ``0..local_size-1``.
        """
        local_ids = torch.arange(self.local_size, device=device)
        mapping = torch.full((self.num_nodes,), -1, dtype=torch.long, device=device)
        mapping[self.start:self.end] = local_ids
        return mapping


@dataclass(frozen=True)
class DomainParallelContext:
    """Handle to the device mesh used for PhysicsNeMo domain-parallel sharding."""

    # Device mesh returned by DistributedManager.initialize_mesh; string
    # annotation so neither physicsnemo nor DeviceMesh must be importable here.
    mesh: "DeviceMesh"
    # DTensor placements describing how tensors are laid out on the mesh
    # (init_domain_parallel_context uses (Shard(0),) — sharded along dim 0).
    placements: Tuple[object, ...]


def _physicsnemo_available() -> bool:
return importlib.util.find_spec("physicsnemo") is not None


def build_mesh_shard(num_nodes: int, rank: Optional[int] = None, world_size: Optional[int] = None) -> MeshShard:
    """Compute this rank's contiguous, near-balanced slice of ``num_nodes`` mesh nodes.

    When ``rank``/``world_size`` are not given they are taken from the active
    torch.distributed process group, falling back to a single-process layout
    (rank 0 of 1) when no group is initialized. The first ``num_nodes %
    world_size`` ranks each receive one extra node, so shard sizes differ by
    at most one.
    """
    in_process_group = dist.is_available() and dist.is_initialized()
    if rank is None:
        rank = dist.get_rank() if in_process_group else 0
    if world_size is None:
        world_size = dist.get_world_size() if in_process_group else 1

    quotient, remainder = divmod(num_nodes, world_size)
    # Ranks below `remainder` own one extra node each.
    start = rank * quotient + min(rank, remainder)
    extra = 1 if rank < remainder else 0
    end = start + quotient + extra
    return MeshShard(rank=rank, world_size=world_size, start=start, end=end, num_nodes=num_nodes)


def shard_tensor_dim0(tensor: torch.Tensor, shard: MeshShard) -> torch.Tensor:
    """Return the rows of ``tensor`` along dim 0 owned by ``shard`` (a view, no copy)."""
    owned_rows = slice(shard.start, shard.end)
    return tensor[owned_rows]


def filter_mesh_edges(
    edge_index: torch.Tensor, edge_attr: torch.Tensor, shard: MeshShard
) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep mesh->mesh edges whose *both* endpoints lie inside the shard.

    Edges crossing the shard boundary are dropped entirely, and the surviving
    endpoint ids are shifted into local coordinates (global - shard.start).
    Returns the filtered ``(edge_index, edge_attr)`` pair as new tensors.
    """
    lo, hi = shard.start, shard.end
    inside = (edge_index >= lo) & (edge_index < hi)
    keep = inside[0] & inside[1]
    local_edge_index = edge_index[:, keep] - lo
    return local_edge_index, edge_attr[keep]


def filter_bipartite_edges(
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    shard: MeshShard,
    mesh_dim: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep obs<->mesh edges whose mesh endpoint (row ``mesh_dim``) is in the shard.

    Only the mesh row is shifted into local coordinates; the observation row
    is left untouched. Returns the filtered ``(edge_index, edge_attr)`` pair.
    """
    mesh_ids = edge_index[mesh_dim]
    keep = (mesh_ids >= shard.start) & (mesh_ids < shard.end)
    local_edge_index = edge_index[:, keep].clone()  # clone before in-place shift
    local_edge_index[mesh_dim] -= shard.start
    return local_edge_index, edge_attr[keep]


def init_domain_parallel_context(
    mesh_shape: Optional[list[int]] = None,
    mesh_dim_names: Optional[list[str]] = None,
) -> Optional[DomainParallelContext]:
    """Initialize a PhysicsNeMo device mesh for domain-parallel sharding.

    Returns None when ``physicsnemo`` is not installed, so callers can treat
    domain parallelism as an optional feature. Defaults to a single mesh
    dimension named "domain" spanning all ranks (``mesh_shape=[-1]``), with
    tensors sharded along dim 0.
    """
    if importlib.util.find_spec("physicsnemo") is None:
        return None

    # Imported lazily so this module stays importable without physicsnemo.
    from physicsnemo.distributed import DistributedManager
    from torch.distributed.tensor.placement_types import Shard

    DistributedManager.initialize()
    manager = DistributedManager()
    device_mesh = manager.initialize_mesh(
        mesh_shape=mesh_shape or [-1],
        mesh_dim_names=mesh_dim_names or ["domain"],
    )
    return DomainParallelContext(mesh=device_mesh, placements=(Shard(0),))


def maybe_build_shardtensor(
    local_tensor: torch.Tensor,
    context: Optional[DomainParallelContext],
):
    """Wrap ``local_tensor`` in a physicsnemo ShardTensor when possible.

    Returns None when domain parallelism is disabled (``context is None``) or
    the optional ``physicsnemo`` package is not installed; otherwise a
    ShardTensor built from the local shard with the context's mesh/placements.
    """
    # Short-circuits before touching physicsnemo when sharding is disabled.
    if context is None or not _physicsnemo_available():
        return None

    from physicsnemo.distributed import ShardTensor

    return ShardTensor.from_local(
        local_tensor=local_tensor,
        device_mesh=context.mesh,
        placements=context.placements,
    )
63 changes: 58 additions & 5 deletions gnn_model/gnn_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@
import zarr
from zarr.storage import LRUStoreCache
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader as PyGDataLoader

from nnja_adapter import build_zlike_from_df
from domain_parallel import (
build_mesh_shard,
filter_bipartite_edges,
filter_mesh_edges,
init_domain_parallel_context,
shard_tensor_dim0,
)
from process_timeseries import extract_features, organize_bins_times
from create_mesh_graph_global import obs_mesh_conn

Expand Down Expand Up @@ -85,6 +93,9 @@ def __init__(
latent_step_hours=12, # latent rollout support
window_size="12h", # binning window
train_val_split_ratio=0.9, # Default fallback, should be passed from training script
domain_parallel: bool = False,
domain_parallel_mesh_shape: list[int] | None = None,
domain_parallel_mesh_dim_names: list[str] | None = None,
**kwargs,
):
super().__init__()
Expand All @@ -105,6 +116,8 @@ def __init__(
# Version counters (for debugging staleness)
self._train_version = 0
self._val_version = 0
self.domain_parallel_context = None
self.domain_shard = None

# If callbacks want separate windows, they will set these:
# Default: create non-overlapping train/val split to prevent data leakage
Expand Down Expand Up @@ -149,6 +162,14 @@ def __init__(
def setup(self, stage=None):
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

if self.hparams.domain_parallel and self.domain_shard is None:
self.domain_parallel_context = init_domain_parallel_context(
mesh_shape=self.hparams.domain_parallel_mesh_shape,
mesh_dim_names=self.hparams.domain_parallel_mesh_dim_names,
)
num_nodes = self.mesh_structure["mesh_features_torch"][0].shape[0]
self.domain_shard = build_mesh_shard(num_nodes)

# Open Zarrs once
if self.z is None:
self.z = {}
Expand Down Expand Up @@ -296,11 +317,21 @@ def _create_graph_structure(self, bin_data):
data = HeteroData()

# 1) Mesh nodes and edges
data["mesh"].x = _t32(self.mesh_structure["mesh_features_torch"][0])
data["mesh"].pos = _t32(self.mesh_structure["mesh_lat_lon_list"][0])

mesh_features = self.mesh_structure["mesh_features_torch"][0]
mesh_pos = self.mesh_structure["mesh_lat_lon_list"][0]
m2m_edge_index = self.mesh_structure["m2m_edge_index_torch"][0]
m2m_edge_attr = self.mesh_structure["m2m_features_torch"][0]

if self.domain_shard is not None:
mesh_features = shard_tensor_dim0(mesh_features, self.domain_shard)
mesh_pos = shard_tensor_dim0(mesh_pos, self.domain_shard)
m2m_edge_index, m2m_edge_attr = filter_mesh_edges(
m2m_edge_index, m2m_edge_attr, self.domain_shard
)

data["mesh"].x = _t32(mesh_features)
data["mesh"].pos = _t32(mesh_pos)

reverse_edges = torch.stack([m2m_edge_index[1], m2m_edge_index[0]], dim=0)
data["mesh", "to", "mesh"].edge_index = torch.cat([m2m_edge_index, reverse_edges], dim=1)
data["mesh", "to", "mesh"].edge_attr = torch.cat([m2m_edge_attr, m2m_edge_attr], dim=0)
Expand Down Expand Up @@ -358,6 +389,13 @@ def _create_latent_nodes(self, data, inst_name, inst_dict, num_latent_steps):
self.mesh_structure["mesh_list"],
o2m=True,
)
if self.domain_shard is not None:
edge_index_encoder, edge_attr_encoder = filter_bipartite_edges(
edge_index_encoder,
edge_attr_encoder,
self.domain_shard,
mesh_dim=1,
)
data[node_type_input, "to", "mesh"].edge_index = edge_index_encoder
data[node_type_input, "to", "mesh"].edge_attr = edge_attr_encoder

Expand Down Expand Up @@ -465,6 +503,13 @@ def _create_latent_nodes(self, data, inst_name, inst_dict, num_latent_steps):
self.mesh_structure["mesh_list"],
o2m=False,
)
if self.domain_shard is not None:
edge_index_decoder, edge_attr_decoder = filter_bipartite_edges(
edge_index_decoder,
edge_attr_decoder,
self.domain_shard,
mesh_dim=0,
)
data["mesh", "to", node_type_target].edge_index = edge_index_decoder
data["mesh", "to", node_type_target].edge_attr = edge_attr_decoder

Expand Down Expand Up @@ -514,10 +559,14 @@ def train_dataloader(self):
feature_stats=self.feature_stats,
tag="TRAIN",
)
sampler = None
if dist.is_available() and dist.is_initialized():
sampler = DistributedSampler(ds, shuffle=True)
loader = PyGDataLoader(
ds,
batch_size=self.hparams.batch_size,
shuffle=True,
shuffle=sampler is None,
sampler=sampler,
num_workers=4,
pin_memory=True,
persistent_workers=False, # safer while debugging stale refs
Expand All @@ -539,10 +588,14 @@ def val_dataloader(self):
feature_stats=self.feature_stats,
tag="VAL",
)
sampler = None
if dist.is_available() and dist.is_initialized():
sampler = DistributedSampler(ds, shuffle=False)
loader = PyGDataLoader(
ds,
batch_size=self.hparams.batch_size,
shuffle=False,
shuffle=sampler is None,
sampler=sampler,
num_workers=4,
pin_memory=True,
persistent_workers=False,
Expand Down
36 changes: 36 additions & 0 deletions gnn_model/gnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
from loss import weighted_huber_loss
from processor_transformer import SlidingWindowTransformerProcessor
from attn_bipartite import BipartiteGAT
from domain_parallel import (
build_mesh_shard,
filter_mesh_edges,
init_domain_parallel_context,
maybe_build_shardtensor,
shard_tensor_dim0,
)


def _build_instrument_map(observation_config: dict) -> dict[str, int]:
Expand Down Expand Up @@ -73,6 +80,9 @@ def __init__(
decoder_layers: int = 2,
encoder_dropout: float = 0.0,
decoder_dropout: float = 0.0,
domain_parallel: bool = False,
domain_parallel_mesh_shape: list[int] | None = None,
domain_parallel_mesh_dim_names: list[str] | None = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -112,6 +122,10 @@ def __init__(
print("[MODEL] instrument_weights:", {self.instrument_id_to_name[k]: float(v) for k, v in self.instrument_weights.items()})

self.hidden_dim = hidden_dim
self.domain_parallel_context = None
self.domain_shard = None
self.mesh_x_sharded = None
self.mesh_edge_attr_sharded = None

# --- Create and store the mesh structure as part of the model ---
self.mesh_structure = create_mesh(splits=mesh_resolution, levels=4, hierarchical=False, plot=False)
Expand Down Expand Up @@ -260,6 +274,28 @@ def _as_i64(x):
self.register_buffer("mesh_edge_index", _as_i64(mesh_edge_index))
self.register_buffer("mesh_edge_attr", _as_f32(mesh_edge_attr))

def on_fit_start(self):
    """Shard the registered mesh buffers across ranks when domain parallelism is on.

    Builds the PhysicsNeMo context (None when physicsnemo is absent), computes
    this rank's contiguous node slice, and replaces the full-mesh buffers
    (mesh_x, mesh_edge_index, mesh_edge_attr) with their local shards.

    Fix: the original guard only protected shard construction, so calling
    fit() a second time re-sliced the already-sharded buffers (shrinking them
    twice). The hook is now idempotent — once a shard exists it returns early.
    """
    super().on_fit_start()
    if not self.hparams.domain_parallel:
        return
    if self.domain_shard is not None:
        # Buffers were already sharded on a previous fit; never slice twice.
        return

    self.domain_parallel_context = init_domain_parallel_context(
        mesh_shape=self.hparams.domain_parallel_mesh_shape,
        mesh_dim_names=self.hparams.domain_parallel_mesh_dim_names,
    )
    self.domain_shard = build_mesh_shard(self.mesh_x.shape[0])

    # Replace full-mesh buffers with this rank's slice; edges that cross the
    # shard boundary are dropped by filter_mesh_edges.
    self.mesh_x = shard_tensor_dim0(self.mesh_x, self.domain_shard)
    self.mesh_edge_index, self.mesh_edge_attr = filter_mesh_edges(
        self.mesh_edge_index, self.mesh_edge_attr, self.domain_shard
    )

    # Optional ShardTensor wrappers; None unless physicsnemo is available.
    self.mesh_x_sharded = maybe_build_shardtensor(self.mesh_x, self.domain_parallel_context)
    self.mesh_edge_attr_sharded = maybe_build_shardtensor(
        self.mesh_edge_attr, self.domain_parallel_context
    )

def transfer_batch_to_device(self, batch, device, dataloader_idx):
# PyG Data/HeteroData implements .to()
if hasattr(batch, "to"):
Expand Down
Loading