This repository was archived by the owner on Dec 20, 2024. It is now read-only.

Commit 2ff0807

icedoom888 and JPXKQX authored

Feature/hierarchical graphs (#37)
* Hard-coded implementation of the hierarchical graph model
* Added implementation of hierarchical graph networks
* Added instantiate model in interface init
* if-else branching instead of hydra.instantiate; have to fix this in the future
* Added changes before migration
* WORKING implementation of the hierarchical graph network
* Refactor and cleaning
* Added example config
* Minor refactor
* Refactor
* Refactor and rebase
* Refactor and small changes for merge
* Re-added asserts in mapper
* Added entry in changelog
* Refactor pre-merge
* Refactored the hierarchical model
* Test dimensions completed
* Fixed dynamo issue
* Refactored using NamedNodesAttributes
* Fixed with git pre-commit
* Update src/anemoi/models/models/hierarchical.py
* Update src/anemoi/models/models/hierarchical.py

Co-authored-by: Mario Santa Cruz <[email protected]>
1 parent fd2bcf1 commit 2ff0807


4 files changed (+315, -0 lines)


CHANGELOG.md (+1)

```diff
@@ -14,6 +14,7 @@ Keep it human-readable, your future self will thank you!

 ### Added

+- New AnemoiModelEncProcDecHierarchical class available in models [#37](https://github.com/ecmwf/anemoi-models/pull/37)
 - Add anemoi-transform link to documentation
 - Codeowners file
 - Pygrep precommit hooks
```

src/anemoi/models/layers/processor.py (+1)

```diff
@@ -323,6 +323,7 @@ def forward(
         *args,
         **kwargs,
     ) -> Tensor:
+
         shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels)
         edge_attr = self.trainable(self.edge_attr, batch_size)
```

src/anemoi/models/models/__init__.py (+5)

```diff
@@ -6,3 +6,8 @@
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+
+from .encoder_processor_decoder import AnemoiModelEncProcDec
+from .hierarchical import AnemoiModelEncProcDecHierarchical
+
+__all__ = ["AnemoiModelEncProcDec", "AnemoiModelEncProcDecHierarchical"]
```
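With these re-exports in place, both model classes are importable from the package root. A minimal sketch of the resulting usage (illustrative only, not part of the diff):

```python
# Both the flat and the hierarchical encoder-processor-decoder models
# are now exposed at the package level.
from anemoi.models.models import AnemoiModelEncProcDec
from anemoi.models.models import AnemoiModelEncProcDecHierarchical
```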
src/anemoi/models/models/hierarchical.py (+308, new file)

```python
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import logging
from typing import Optional

import einops
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.data import HeteroData

from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.layers.graph import NamedNodesAttributes
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.models import AnemoiModelEncProcDec

LOGGER = logging.getLogger(__name__)


class AnemoiModelEncProcDecHierarchical(AnemoiModelEncProcDec):
    """Message passing hierarchical graph neural network."""

    def __init__(
        self,
        *,
        model_config: DotDict,
        data_indices: dict,
        graph_data: HeteroData,
    ) -> None:
        """Initializes the graph neural network.

        Parameters
        ----------
        model_config : DotDict
            Job configuration
        data_indices : dict
            Data indices
        graph_data : HeteroData
            Graph definition
        """
        nn.Module.__init__(self)

        self._graph_data = graph_data
        self._graph_name_data = model_config.graph.data
        self._graph_hidden_names = model_config.graph.hidden
        self.num_hidden = len(self._graph_hidden_names)

        # Unpack config for hierarchical graph
        self.level_process = model_config.model.enable_hierarchical_level_processing

        # hidden_dims is the dimensionality of features at each depth
        self.hidden_dims = {
            hidden: model_config.model.num_channels * (2**i) for i, hidden in enumerate(self._graph_hidden_names)
        }

        self._calculate_shapes_and_indices(data_indices)
        self._assert_matching_indices(data_indices)
        self.data_indices = data_indices

        self.multi_step = model_config.training.multistep_input

        # self.node_attributes = {hidden_name: NamedNodesAttributes(model_config.model.trainable_parameters[hidden_name], self._graph_data)
        #                         for hidden_name in self._graph_hidden_names}
        self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)

        input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]

        # Encoder data -> hidden
        self.encoder = instantiate(
            model_config.model.encoder,
            in_channels_src=input_dim,
            in_channels_dst=self.node_attributes.attr_ndims[self._graph_hidden_names[0]],
            hidden_dim=self.hidden_dims[self._graph_hidden_names[0]],
            sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_hidden_names[0])],
            src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
            dst_grid_size=self.node_attributes.num_nodes[self._graph_hidden_names[0]],
        )

        # Level processors
        if self.level_process:
            self.down_level_processor = nn.ModuleDict()
            self.up_level_processor = nn.ModuleDict()

            for i in range(0, self.num_hidden):
                nodes_names = self._graph_hidden_names[i]

                self.down_level_processor[nodes_names] = instantiate(
                    model_config.model.processor,
                    num_channels=self.hidden_dims[nodes_names],
                    sub_graph=self._graph_data[(nodes_names, "to", nodes_names)],
                    src_grid_size=self.node_attributes.num_nodes[nodes_names],
                    dst_grid_size=self.node_attributes.num_nodes[nodes_names],
                    num_layers=model_config.model.level_process_num_layers,
                )

                self.up_level_processor[nodes_names] = instantiate(
                    model_config.model.processor,
                    num_channels=self.hidden_dims[nodes_names],
                    sub_graph=self._graph_data[(nodes_names, "to", nodes_names)],
                    src_grid_size=self.node_attributes.num_nodes[nodes_names],
                    dst_grid_size=self.node_attributes.num_nodes[nodes_names],
                    num_layers=model_config.model.level_process_num_layers,
                )

            # delete final upscale (does not exist): |->|->|<-|<-|
            del self.up_level_processor[nodes_names]

        # Downscale
        self.downscale = nn.ModuleDict()

        for i in range(0, self.num_hidden - 1):
            src_nodes_name = self._graph_hidden_names[i]
            dst_nodes_name = self._graph_hidden_names[i + 1]

            self.downscale[src_nodes_name] = instantiate(
                model_config.model.encoder,
                in_channels_src=self.hidden_dims[src_nodes_name],
                in_channels_dst=self.node_attributes.attr_ndims[dst_nodes_name],
                hidden_dim=self.hidden_dims[dst_nodes_name],
                sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)],
                src_grid_size=self.node_attributes.num_nodes[src_nodes_name],
                dst_grid_size=self.node_attributes.num_nodes[dst_nodes_name],
            )

        # Upscale
        self.upscale = nn.ModuleDict()

        for i in range(1, self.num_hidden):
            src_nodes_name = self._graph_hidden_names[i]
            dst_nodes_name = self._graph_hidden_names[i - 1]

            self.upscale[src_nodes_name] = instantiate(
                model_config.model.decoder,
                in_channels_src=self.hidden_dims[src_nodes_name],
                in_channels_dst=self.hidden_dims[dst_nodes_name],
                hidden_dim=self.hidden_dims[src_nodes_name],
                out_channels_dst=self.hidden_dims[dst_nodes_name],
                sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)],
                src_grid_size=self.node_attributes.num_nodes[src_nodes_name],
                dst_grid_size=self.node_attributes.num_nodes[dst_nodes_name],
            )

        # Decoder hidden -> data
        self.decoder = instantiate(
            model_config.model.decoder,
            in_channels_src=self.hidden_dims[self._graph_hidden_names[0]],
            in_channels_dst=input_dim,
            hidden_dim=self.hidden_dims[self._graph_hidden_names[0]],
            out_channels_dst=self.num_output_channels,
            sub_graph=self._graph_data[(self._graph_hidden_names[0], "to", self._graph_name_data)],
            src_grid_size=self.node_attributes.num_nodes[self._graph_hidden_names[0]],
            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
        )

        # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
        self.boundings = nn.ModuleList(
            [
                instantiate(cfg, name_to_index=self.data_indices.internal_model.output.name_to_index)
                for cfg in getattr(model_config.model, "bounding", [])
            ]
        )

    def _create_trainable_attributes(self) -> None:
        """Create all trainable attributes."""
        self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
        self.trainable_hidden = nn.ModuleDict()

        for hidden in self._graph_hidden_names:
            self.trainable_hidden[hidden] = TrainableTensor(
                trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_sizes[hidden]
            )

    def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor:
        batch_size = x.shape[0]
        ensemble_size = x.shape[2]

        # add data positional info (lat/lon)
        x_trainable_data = torch.cat(
            (
                einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
                self.node_attributes(self._graph_name_data, batch_size=batch_size),
            ),
            dim=-1,  # feature dimension
        )

        # Get all trainable parameters for the hidden layers -> initialisation of each hidden, which becomes a trainable bias
        x_trainable_hiddens = {}
        for hidden in self._graph_hidden_names:
            x_trainable_hiddens[hidden] = self.node_attributes(hidden, batch_size=batch_size)

        # Get data and hidden shapes for sharding
        shard_shapes_data = get_shape_shards(x_trainable_data, 0, model_comm_group)
        shard_shapes_hiddens = {}
        for hidden, x_latent in x_trainable_hiddens.items():
            shard_shapes_hiddens[hidden] = get_shape_shards(x_latent, 0, model_comm_group)

        # Run encoder
        x_data_latent, curr_latent = self._run_mapper(
            self.encoder,
            (x_trainable_data, x_trainable_hiddens[self._graph_hidden_names[0]]),
            batch_size=batch_size,
            shard_shapes=(shard_shapes_data, shard_shapes_hiddens[self._graph_hidden_names[0]]),
            model_comm_group=model_comm_group,
        )

        # Run processor
        x_encoded_latents = {}
        x_skip = {}

        ## Downscale
        for i in range(0, self.num_hidden - 1):
            src_hidden_name = self._graph_hidden_names[i]
            dst_hidden_name = self._graph_hidden_names[i + 1]

            # Processing at same level
            if self.level_process:
                curr_latent = self.down_level_processor[src_hidden_name](
                    curr_latent,
                    batch_size=batch_size,
                    shard_shapes=shard_shapes_hiddens[src_hidden_name],
                    model_comm_group=model_comm_group,
                )

            # store latents for skip connections
            x_skip[src_hidden_name] = curr_latent

            # Encode to next hidden level
            x_encoded_latents[src_hidden_name], curr_latent = self._run_mapper(
                self.downscale[src_hidden_name],
                (curr_latent, x_trainable_hiddens[dst_hidden_name]),
                batch_size=batch_size,
                shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]),
                model_comm_group=model_comm_group,
            )

        # Processing hidden-most level
        if self.level_process:
            curr_latent = self.down_level_processor[dst_hidden_name](
                curr_latent,
                batch_size=batch_size,
                shard_shapes=shard_shapes_hiddens[dst_hidden_name],
                model_comm_group=model_comm_group,
            )

        ## Upscale
        for i in range(self.num_hidden - 1, 0, -1):
            src_hidden_name = self._graph_hidden_names[i]
            dst_hidden_name = self._graph_hidden_names[i - 1]

            # Process to next level
            curr_latent = self._run_mapper(
                self.upscale[src_hidden_name],
                (curr_latent, x_encoded_latents[dst_hidden_name]),
                batch_size=batch_size,
                shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]),
                model_comm_group=model_comm_group,
            )

            # Add skip connections
            curr_latent = curr_latent + x_skip[dst_hidden_name]

            # Processing at same level
            if self.level_process:
                curr_latent = self.up_level_processor[dst_hidden_name](
                    curr_latent,
                    batch_size=batch_size,
                    shard_shapes=shard_shapes_hiddens[dst_hidden_name],
                    model_comm_group=model_comm_group,
                )

        # Run decoder
        x_out = self._run_mapper(
            self.decoder,
            (curr_latent, x_data_latent),
            batch_size=batch_size,
            shard_shapes=(shard_shapes_hiddens[self._graph_hidden_names[0]], shard_shapes_data),
            model_comm_group=model_comm_group,
        )

        x_out = (
            einops.rearrange(
                x_out,
                "(batch ensemble grid) vars -> batch ensemble grid vars",
                batch=batch_size,
                ensemble=ensemble_size,
            )
            .to(dtype=x.dtype)
            .clone()
        )

        # residual connection (just for the prognostic variables)
        x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]

        for bounding in self.boundings:
            # bounding performed in the order specified in the config file
            x_out = bounding(x_out)

        return x_out
```
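Two details of the new model are worth spelling out: the latent width doubles at each hidden level (`num_channels * 2**i`), and `forward` traverses the hierarchy in a V-shape, down through the meshes and back up, which is why the up-level processor for the deepest mesh is deleted (the deepest level is only processed once, at the bottom of the V). A small worked sketch; the mesh names and channel count below are hypothetical, not taken from the commit:

```python
# Hypothetical example: three hidden meshes and num_channels = 128.
num_channels = 128
graph_hidden_names = ["hidden_1", "hidden_2", "hidden_3"]  # assumed names

# Same comprehension as in AnemoiModelEncProcDecHierarchical.__init__:
hidden_dims = {h: num_channels * (2**i) for i, h in enumerate(graph_hidden_names)}
print(hidden_dims)  # {'hidden_1': 128, 'hidden_2': 256, 'hidden_3': 512}

# Traversal order in forward(): downscale visits levels 0..n-2 as sources,
# upscale walks back from level n-1 to 1.
down_srcs = graph_hidden_names[:-1]   # ['hidden_1', 'hidden_2']
up_srcs = graph_hidden_names[:0:-1]   # ['hidden_3', 'hidden_2']
print("down:", down_srcs, "| up:", up_srcs)
```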
