diff --git a/CHANGELOG.md b/CHANGELOG.md index 4420e8c..33c692b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,8 +24,7 @@ Keep it human-readable, your future self will thank you! - feat: Support for multiple edge builders between two sets of nodes (#70) - feat: Support for providing lon/lat coordinates from a text file (loaded with numpy loadtxt method) to build the graph `TextNodes` (#93) - feat: Build 2D graphs with `Voronoi` in case `SphericalVoronoi` does not work well/is an overkill (LAM). Set `flat=true` in the nodes attributes to compute area weight using Voronoi with a qhull options preventing the empty region creation (#93) -- feat: Add `AttributeFromNode` edge attribute to copy attribute from source or destination node. Set `node_attr_name` and `node_type : src | dst` in the config to specify which attribute to copy from the source | destination node (#94) - +- feat: Add `AttributeFromSourceNode` and `AttributeFromTargetNode` edge attribute to copy attribute from source or target node. Set `node_attr_name` in the config to specify which attribute to copy from the source | target node (#94) # Changed diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 3aca8a8..3d8122e 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -24,8 +24,9 @@ class BaseEdgeAttribute(ABC, NormaliserMixin): """Base class for edge attributes.""" - def __init__(self, norm: str | None = None) -> None: + def __init__(self, norm: str | None = None, dtype: str = "float32") -> None: self.norm = norm + self.dtype = dtype @abstractmethod def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ... @@ -35,9 +36,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: if values.ndim == 1: values = values[:, np.newaxis] - normed_values = self.normalise(values) + norm_values = self.normalise(values) - return torch.tensor(normed_values, dtype=torch.float32) + return torch.tensor(norm_values.astype(self.dtype)) def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor: """Compute the edge attributes.""" @@ -157,36 +158,18 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: return values -class BooleanBaseEdgeAttribute: +class BooleanBaseEdgeAttribute(BaseEdgeAttribute): """Base class for boolean edge attributes.""" def __init__(self) -> None: - pass - - @abstractmethod - def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ... - - def post_process(self, values: np.ndarray) -> torch.Tensor: - """Post-process the values.""" - return torch.tensor(values, dtype=torch.bool) - - def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor: - """Compute the edge attributes.""" - source_name, _, target_name = edges_name - assert ( - source_name in graph.node_types - ), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." - assert ( - target_name in graph.node_types - ), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." - - values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs) - return self.post_process(values) + super().__init__(norm=None, dtype="bool") class AttributeFromNode(BooleanBaseEdgeAttribute): """ - Copy an attribute of either the source or destination node to the edge. + Base class for Attribute from Node. + + Copy an attribute of either the source or target node to the edge. Accesses origin/target node attribute and propagates it to the edge. Used for example to identify if an encoder edge originates from a LAM or global node. @@ -195,34 +178,55 @@ class AttributeFromNode(BooleanBaseEdgeAttribute): node_attr_name : str Name of the node attribute to propagate. - node_type : str - Pick the node to copy from. Options: "src, dst" - Methods ------- + get_node_name(source_name, target_name) + Return the name of the node to copy. + get_raw_values(graph, source_name, target_name) - Computes the edge attribute from the source or destination node attribute. + Computes the edge attribute from the source or target node attribute. + """ - def __init__(self, node_attr_name: str, node_type: str) -> None: + def __init__(self, node_attr_name: str) -> None: + super().__init__() self.node_attr_name = node_attr_name - assert node_type in ["src", "dst"] - self.node_type = node_type + self.idx = None + + @abstractmethod + def get_node_name(self, source_name: str, target_name: str): ... def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: + node_name = self.get_node_name(source_name, target_name) + edge_index = graph[(source_name, "to", target_name)].edge_index + assert hasattr(graph[node_name], self.node_attr_name) + val = getattr(graph[node_name], self.node_attr_name).numpy()[edge_index[self.idx]] + return val - if self.node_type == "src": - name_to_copy = source_name - idx = 0 - else: - name_to_copy = target_name - idx = 1 +class AttributeFromSourceNode(AttributeFromNode): + """ + Copy an attribute of the source node to the edge. + """ - assert hasattr(graph[name_to_copy], self.node_attr_name) + def __init__(self, node_attr_name: str) -> None: + super().__init__(node_attr_name) + self.idx = 0 - val = getattr(graph[name_to_copy], self.node_attr_name).numpy()[edge_index[idx]] + def get_node_name(self, source_name: str, target_name: str): + return source_name - return val + +class AttributeFromTargetNode(AttributeFromNode): + """ + Copy an attribute of the target node to the edge. + """ + + def __init__(self, node_attr_name: str) -> None: + super().__init__(node_attr_name) + self.idx = 1 + + def get_node_name(self, source_name: str, target_name: str): + return target_name