diff --git a/CHANGELOG.md b/CHANGELOG.md index 886fc7b..eaed146 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ Keep it human-readable, your future self will thank you! - feat: Add `RemoveUnconnectedNodes` post processor to clean unconnected nodes in LAM. (#71) - feat: Define node sets and edges based on an ICON icosahedral mesh (#53) - feat: Support for multiple edge builders between two sets of nodes (#70) +- feat: Support for providing lon/lat coordinates from a text file (loaded with numpy loadtxt method) to build the graph `TextNodes` (#93) +- feat: Build 2D graphs with `Voronoi` in case `SphericalVoronoi` does not work well/is an overkill (LAM). Set `flat=true` in the nodes attributes to compute area weight using Voronoi with a qhull options preventing the empty region creation (#93) # Changed diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 228a282..6fc429b 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -9,6 +9,7 @@ from .builders.from_file import LimitedAreaNPZFileNodes from .builders.from_file import NPZFileNodes +from .builders.from_file import TextNodes from .builders.from_file import ZarrDatasetNodes from .builders.from_healpix import HEALPixNodes from .builders.from_healpix import LimitedAreaHEALPixNodes @@ -35,4 +36,5 @@ "ICONMultimeshNodes", "ICONCellGridNodes", "ICONNodes", + "TextNodes", ] diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index 3817467..2b9884f 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -18,7 +18,9 @@ import numpy as np import torch from anemoi.datasets import open_dataset +from scipy.spatial import ConvexHull from scipy.spatial import SphericalVoronoi +from scipy.spatial import Voronoi from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage @@ -101,6 +103,68 @@ def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: class AreaWeights(BaseNodeAttribute): """Implements the area of the nodes as the weights. + Attributes + ---------- + flat: bool + If True, the area is computed in 2D, otherwise in 3D. + **other: Any + Additional keyword arguments, see PlanarAreaWeights and SphericalAreaWeights + for details. + + Methods + ------- + compute(self, graph, nodes_name) + Compute the area attributes for each node. + """ + + def __new__(cls, flat: bool = False, **kwargs): + logging.warning( + "Creating %s with flat=%s and kwargs=%s. In a future release, AreaWeights will be deprecated: please use directly PlanarAreaWeights or SphericalAreaWeights.", + cls.__name__, + flat, + kwargs, + ) + if flat: + return PlanarAreaWeights(**kwargs) + return SphericalAreaWeights(**kwargs) + + +class PlanarAreaWeights(BaseNodeAttribute): + """Implements the 2D area of the nodes as the weights. + + Attributes + ---------- + norm : str + Normalisation of the weights. + + Methods + ------- + compute(self, graph, nodes_name) + Compute the area attributes for each node. + """ + + def __init__( + self, + norm: str | None = None, + dtype: str = "float32", + ) -> None: + super().__init__(norm, dtype) + + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] + points = np.stack([latitudes, longitudes], -1) + v = Voronoi(points, qhull_options="QJ Pp") + areas = [] + for r in v.regions: + area = ConvexHull(v.vertices[r, :]).volume + areas.append(area) + result = np.asarray(areas) + return result + + +class SphericalAreaWeights(BaseNodeAttribute): + """Implements the 3D area of the nodes as the weights. + Attributes ---------- norm : str @@ -148,8 +212,9 @@ def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: np.ndarray Attributes. """ - points = latlon_rad_to_cartesian(nodes.x) - sv = SphericalVoronoi(points.cpu().numpy(), self.radius, self.centre) + latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] + points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes))) + sv = SphericalVoronoi(points, self.radius, self.centre) mask = np.array([bool(i) for i in sv.regions]) sv.regions = [region for region in sv.regions if region] # compute the area weight without empty regions diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index c6bdc88..069b4ab 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -63,6 +63,37 @@ def get_coordinates(self) -> torch.Tensor: return self.reshape_coords(dataset.latitudes, dataset.longitudes) +class TextNodes(BaseNodeBuilder): + """Nodes from text file. + + Attributes + ---------- + dataset : str | DictConfig + The path to txt file containing the coordinates of the nodes. + idx_lon : int + The index of the longitude in the dataset. + idx_lat : int + The index of the latitude in the dataset. + """ + + def __init__(self, dataset, name: str, idx_lon: int = 0, idx_lat: int = 1) -> None: + LOGGER.info("Reading the dataset from %s.", dataset) + self.dataset = np.loadtxt(dataset) + self.idx_lon = idx_lon + self.idx_lat = idx_lat + super().__init__(name) + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + return self.reshape_coords(self.dataset[self.idx_lat, :], self.dataset[self.idx_lon, :]) + + class NPZFileNodes(BaseNodeBuilder): """Nodes from NPZ defined grids. @@ -146,7 +177,10 @@ def get_coordinates(self) -> np.ndarray: ) area_mask = self.area_mask_builder.get_mask(coords) - LOGGER.info("Dropping %d nodes from the processor mesh.", len(area_mask) - area_mask.sum()) + LOGGER.info( + "Dropping %d nodes from the processor mesh.", + len(area_mask) - area_mask.sum(), + ) coords = coords[area_mask] return coords