diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 6be801b..3aca8a8 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -155,3 +155,74 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: values = 1 - values return values + + +class BooleanBaseEdgeAttribute: + """Base class for boolean edge attributes.""" + + def __init__(self) -> None: + pass + + @abstractmethod + def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ... + + def post_process(self, values: np.ndarray) -> torch.Tensor: + """Post-process the values.""" + return torch.tensor(values, dtype=torch.bool) + + def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor: + """Compute the edge attributes.""" + source_name, _, target_name = edges_name + assert ( + source_name in graph.node_types + ), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." + assert ( + target_name in graph.node_types + ), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." + + values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs) + return self.post_process(values) + + +class AttributeFromNode(BooleanBaseEdgeAttribute): + """ + Copy an attribute of either the source or destination node to the edge. + Accesses origin/target node attribute and propagates it to the edge. + Used for example to identify if an encoder edge originates from a LAM or global node. + + Attributes + ---------- + node_attr_name : str + Name of the node attribute to propagate. + + node_type : str + Pick the node to copy from. Options: "src, dst" + + Methods + ------- + get_raw_values(graph, source_name, target_name) + Computes the edge attribute from the source or destination node attribute. + """ + + def __init__(self, node_attr_name: str, node_type: str) -> None: + self.node_attr_name = node_attr_name + assert node_type in ["src", "dst"] + self.node_type = node_type + + def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: + + edge_index = graph[(source_name, "to", target_name)].edge_index + + if self.node_type == "src": + name_to_copy = source_name + idx = 0 + + else: + name_to_copy = target_name + idx = 1 + + assert hasattr(graph[name_to_copy], self.node_attr_name) + + val = getattr(graph[name_to_copy], self.node_attr_name).numpy()[edge_index[idx]] + + return val