Skip to content
This repository was archived by the owner on Dec 20, 2024. It is now read-only.

Commit d6bc06b

Browse files
authored
Merge branch 'develop' into feature/torch-support
2 parents 3d621f1 + 3898f6f commit d6bc06b

File tree

4 files changed

+106
-3
lines changed

4 files changed

+106
-3
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Keep it human-readable, your future self will thank you!
2222
- feat: Add `RemoveUnconnectedNodes` post processor to clean unconnected nodes in LAM. (#71)
2323
- feat: Define node sets and edges based on an ICON icosahedral mesh (#53)
2424
- feat: Support for multiple edge builders between two sets of nodes (#70)
25+
- feat: Support for providing lon/lat coordinates from a text file (loaded with numpy loadtxt method) to build the graph `TextNodes` (#93)
26+
- feat: Build 2D graphs with `Voronoi` in case `SphericalVoronoi` does not work well/is an overkill (LAM). Set `flat=true` in the nodes attributes to compute area weight using Voronoi with a qhull options preventing the empty region creation (#93)
2527

2628
# Changed
2729

src/anemoi/graphs/nodes/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .builders.from_file import LimitedAreaNPZFileNodes
1111
from .builders.from_file import NPZFileNodes
12+
from .builders.from_file import TextNodes
1213
from .builders.from_file import ZarrDatasetNodes
1314
from .builders.from_healpix import HEALPixNodes
1415
from .builders.from_healpix import LimitedAreaHEALPixNodes
@@ -35,4 +36,5 @@
3536
"ICONMultimeshNodes",
3637
"ICONCellGridNodes",
3738
"ICONNodes",
39+
"TextNodes",
3840
]

src/anemoi/graphs/nodes/attributes.py

+67-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import numpy as np
1919
import torch
2020
from anemoi.datasets import open_dataset
21+
from scipy.spatial import ConvexHull
2122
from scipy.spatial import SphericalVoronoi
23+
from scipy.spatial import Voronoi
2224
from torch_geometric.data import HeteroData
2325
from torch_geometric.data.storage import NodeStorage
2426

@@ -101,6 +103,68 @@ def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
101103
class AreaWeights(BaseNodeAttribute):
102104
"""Implements the area of the nodes as the weights.
103105
106+
Attributes
107+
----------
108+
flat: bool
109+
If True, the area is computed in 2D, otherwise in 3D.
110+
**other: Any
111+
Additional keyword arguments, see PlanarAreaWeights and SphericalAreaWeights
112+
for details.
113+
114+
Methods
115+
-------
116+
compute(self, graph, nodes_name)
117+
Compute the area attributes for each node.
118+
"""
119+
120+
def __new__(cls, flat: bool = False, **kwargs):
121+
logging.warning(
122+
"Creating %s with flat=%s and kwargs=%s. In a future release, AreaWeights will be deprecated: please use directly PlanarAreaWeights or SphericalAreaWeights.",
123+
cls.__name__,
124+
flat,
125+
kwargs,
126+
)
127+
if flat:
128+
return PlanarAreaWeights(**kwargs)
129+
return SphericalAreaWeights(**kwargs)
130+
131+
132+
class PlanarAreaWeights(BaseNodeAttribute):
133+
"""Implements the 2D area of the nodes as the weights.
134+
135+
Attributes
136+
----------
137+
norm : str
138+
Normalisation of the weights.
139+
140+
Methods
141+
-------
142+
compute(self, graph, nodes_name)
143+
Compute the area attributes for each node.
144+
"""
145+
146+
def __init__(
147+
self,
148+
norm: str | None = None,
149+
dtype: str = "float32",
150+
) -> None:
151+
super().__init__(norm, dtype)
152+
153+
def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
154+
latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1]
155+
points = np.stack([latitudes, longitudes], -1)
156+
v = Voronoi(points, qhull_options="QJ Pp")
157+
areas = []
158+
for r in v.regions:
159+
area = ConvexHull(v.vertices[r, :]).volume
160+
areas.append(area)
161+
result = np.asarray(areas)
162+
return result
163+
164+
165+
class SphericalAreaWeights(BaseNodeAttribute):
166+
"""Implements the 3D area of the nodes as the weights.
167+
104168
Attributes
105169
----------
106170
norm : str
@@ -148,8 +212,9 @@ def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
148212
np.ndarray
149213
Attributes.
150214
"""
151-
points = latlon_rad_to_cartesian(nodes.x)
152-
sv = SphericalVoronoi(points.cpu().numpy(), self.radius, self.centre)
215+
latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1]
216+
points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes)))
217+
sv = SphericalVoronoi(points, self.radius, self.centre)
153218
mask = np.array([bool(i) for i in sv.regions])
154219
sv.regions = [region for region in sv.regions if region]
155220
# compute the area weight without empty regions

src/anemoi/graphs/nodes/builders/from_file.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,37 @@ def get_coordinates(self) -> torch.Tensor:
6363
return self.reshape_coords(dataset.latitudes, dataset.longitudes)
6464

6565

66+
class TextNodes(BaseNodeBuilder):
67+
"""Nodes from text file.
68+
69+
Attributes
70+
----------
71+
dataset : str | DictConfig
72+
The path to txt file containing the coordinates of the nodes.
73+
idx_lon : int
74+
The index of the longitude in the dataset.
75+
idx_lat : int
76+
The index of the latitude in the dataset.
77+
"""
78+
79+
def __init__(self, dataset, name: str, idx_lon: int = 0, idx_lat: int = 1) -> None:
80+
LOGGER.info("Reading the dataset from %s.", dataset)
81+
self.dataset = np.loadtxt(dataset)
82+
self.idx_lon = idx_lon
83+
self.idx_lat = idx_lat
84+
super().__init__(name)
85+
86+
def get_coordinates(self) -> torch.Tensor:
87+
"""Get the coordinates of the nodes.
88+
89+
Returns
90+
-------
91+
torch.Tensor of shape (num_nodes, 2)
92+
A 2D tensor with the coordinates, in radians.
93+
"""
94+
return self.reshape_coords(self.dataset[self.idx_lat, :], self.dataset[self.idx_lon, :])
95+
96+
6697
class NPZFileNodes(BaseNodeBuilder):
6798
"""Nodes from NPZ defined grids.
6899
@@ -146,7 +177,10 @@ def get_coordinates(self) -> np.ndarray:
146177
)
147178
area_mask = self.area_mask_builder.get_mask(coords)
148179

149-
LOGGER.info("Dropping %d nodes from the processor mesh.", len(area_mask) - area_mask.sum())
180+
LOGGER.info(
181+
"Dropping %d nodes from the processor mesh.",
182+
len(area_mask) - area_mask.sum(),
183+
)
150184
coords = coords[area_mask]
151185

152186
return coords

0 commit comments

Comments
 (0)