From 7353f83f13fb232efd079ce060367dc7bbd86cfa Mon Sep 17 00:00:00 2001
From: icedoom888 <alberto.pennino.8@gmail.com>
Date: Tue, 17 Dec 2024 19:32:21 +0100
Subject: [PATCH] Implemented new attribute

---
 src/anemoi/graphs/edges/attributes.py | 71 +++++++++++++++++++++++++++
 1 file changed, 71 insertions(+)

diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py
index 6be801b..3aca8a8 100644
--- a/src/anemoi/graphs/edges/attributes.py
+++ b/src/anemoi/graphs/edges/attributes.py
@@ -155,3 +155,74 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
             values = 1 - values
 
         return values
+
+
+class BooleanBaseEdgeAttribute:
+    """Base class for boolean edge attributes."""
+
+    def __init__(self) -> None:
+        pass
+
+    @abstractmethod
+    def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ...
+
+    def post_process(self, values: np.ndarray) -> torch.Tensor:
+        """Post-process the values."""
+        return torch.tensor(values, dtype=torch.bool)
+
+    def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor:
+        """Compute the edge attributes."""
+        source_name, _, target_name = edges_name
+        assert (
+            source_name in graph.node_types
+        ), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."
+        assert (
+            target_name in graph.node_types
+        ), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."
+
+        values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs)
+        return self.post_process(values)
+
+
+class AttributeFromNode(BooleanBaseEdgeAttribute):
+    """
+    Copy an attribute of either the source or destination node to the edge.
+    Accesses origin/target node attribute and propagates it to the edge.
+    Used for example to identify if an encoder edge originates from a LAM or global node.
+
+    Attributes
+    ----------
+    node_attr_name : str
+        Name of the node attribute to propagate.
+
+    node_type : str
+        Pick the node to copy from. Options: "src, dst"
+
+    Methods
+    -------
+    get_raw_values(graph, source_name, target_name)
+        Computes the edge attribute from the source or destination node attribute.
+    """
+
+    def __init__(self, node_attr_name: str, node_type: str) -> None:
+        self.node_attr_name = node_attr_name
+        assert node_type in ["src", "dst"]
+        self.node_type = node_type
+
+    def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:
+
+        edge_index = graph[(source_name, "to", target_name)].edge_index
+
+        if self.node_type == "src":
+            name_to_copy = source_name
+            idx = 0
+
+        else:
+            name_to_copy = target_name
+            idx = 1
+
+        assert hasattr(graph[name_to_copy], self.node_attr_name)
+
+        val = getattr(graph[name_to_copy], self.node_attr_name).numpy()[edge_index[idx]]
+
+        return val