Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
feat(tests): add test for LatLonNodes
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Dec 18, 2024
1 parent 94f7819 commit b23e56f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/anemoi/graphs/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
from .builders.from_refined_icosahedron import LimitedAreaTriNodes
from .builders.from_refined_icosahedron import StretchedTriNodes
from .builders.from_refined_icosahedron import TriNodes
from .builders.from_vectors import LatLonNodes

__all__ = [
"ZarrDatasetNodes",
"NPZFileNodes",
"TriNodes",
"HexNodes",
"HEALPixNodes",
"LatLonNodes",
"LimitedAreaHEALPixNodes",
"LimitedAreaNPZFileNodes",
"LimitedAreaTriNodes",
Expand Down
64 changes: 64 additions & 0 deletions tests/nodes/test_arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes.attributes import AreaWeights
from anemoi.graphs.nodes.attributes import UniformWeights
from anemoi.graphs.nodes.builders.from_vectors import LatLonNodes

lats = [45.0, 45.0, 40.0, 40.0]
lons = [5.0, 10.0, 10.0, 5.0]


def test_init():
"""Test LatLonNodes initialization."""
node_builder = LatLonNodes(latitudes=lats, longitudes=lons, name="test_nodes")
assert isinstance(node_builder, LatLonNodes)


def test_fail_init_length_mismatch():
"""Test LatLonNodes initialization with invalid argument."""
lons = [5.0, 10.0, 10.0, 5.0, 5.0]

with pytest.raises(AssertionError):
LatLonNodes(latitudes=lats, longitudes=lons, name="test_nodes")


def test_fail_init_missing_argument():
"""Test NPZFileNodes initialization with missing argument."""
with pytest.raises(TypeError):
LatLonNodes(name="test_nodes")


def test_register_nodes():
"""Test LatLonNodes register correctly the nodes."""
graph = HeteroData()
node_builder = LatLonNodes(latitudes=lats, longitudes=lons, name="test_nodes")
graph = node_builder.register_nodes(graph)

assert graph["test_nodes"].x is not None
assert isinstance(graph["test_nodes"].x, torch.Tensor)
assert graph["test_nodes"].x.shape == (len(lats), 2)
assert graph["test_nodes"].node_type == "LatLonNodes"


@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights])
def test_register_attributes(graph_with_nodes: HeteroData, attr_class):
"""Test LatLonNodes register correctly the weights."""
node_builder = LatLonNodes(latitudes=lats, longitudes=lons, name="test_nodes")
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}}

graph = node_builder.register_attributes(graph_with_nodes, config)

assert graph["test_nodes"]["test_attr"] is not None
assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor)
assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0]

0 comments on commit b23e56f

Please sign in to comment.