diff --git a/docs/examples/model/graph_examples.ipynb b/docs/examples/model/graph_examples.ipynb index fc92f7ba..33cb4587 100644 --- a/docs/examples/model/graph_examples.ipynb +++ b/docs/examples/model/graph_examples.ipynb @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -60,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -105,6 +105,28 @@ "print(f\"Shortest path: {path}, Length: {length}\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Branches on the Shortest Path\n", + "\n", + "`Grid.iter_branches_in_shortest_path` walks the same nodes returned by `get_shortest_path` but exposes the actual `BranchArray` records for each edge. Iterate the result to inspect branch IDs, statuses, or any other metadata without recomputing the path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from power_grid_model_ds import Grid\n", + "\n", + "grid = Grid.from_txt(\"S1 101\", \"101 102\", \"102 103\")\n", + "for branch in grid.iter_branches_in_shortest_path(101, 103):\n", + " print(f\"Branch {branch.id.item()} runs {branch.from_node.item()} → {branch.to_node.item()}\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -180,7 +202,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": ".venv (3.12.6)", "language": "python", "name": "python3" }, diff --git a/src/power_grid_model_ds/_core/model/arrays/pgm_arrays.py b/src/power_grid_model_ds/_core/model/arrays/pgm_arrays.py index 74e98acc..180d3960 100644 --- a/src/power_grid_model_ds/_core/model/arrays/pgm_arrays.py +++ b/src/power_grid_model_ds/_core/model/arrays/pgm_arrays.py @@ -145,18 +145,21 @@ class Branch3Array(IdArray, Branch3): def as_branches(self) -> BranchArray: """Convert Branch3Array to BranchArray.""" branches_1_2 = BranchArray.empty(self.size) + branches_1_2.id = self.id branches_1_2.from_node = self.node_1 branches_1_2.to_node = self.node_2 branches_1_2.from_status = self.status_1 branches_1_2.to_status = self.status_2 branches_1_3 = BranchArray.empty(self.size) + branches_1_3.id = self.id branches_1_3.from_node = self.node_1 branches_1_3.to_node = self.node_3 branches_1_3.from_status = self.status_1 branches_1_3.to_status = self.status_3 branches_2_3 = BranchArray.empty(self.size) + branches_2_3.id = self.id branches_2_3.from_node = self.node_2 branches_2_3.to_node = self.node_3 branches_2_3.from_status = self.status_2 diff --git a/src/power_grid_model_ds/_core/model/grids/_search.py b/src/power_grid_model_ds/_core/model/grids/_search.py index 45da7f9d..2cb73f74 100644 --- a/src/power_grid_model_ds/_core/model/grids/_search.py +++ b/src/power_grid_model_ds/_core/model/grids/_search.py @@ -3,15 +3,17 @@ # SPDX-License-Identifier: MPL-2.0 import dataclasses -from typing import TYPE_CHECKING +from itertools import pairwise +from typing import TYPE_CHECKING, Iterator import numpy as np import numpy.typing as npt from power_grid_model_ds._core import fancypy as fp from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist -from power_grid_model_ds._core.model.arrays.pgm_arrays import BranchArray +from power_grid_model_ds._core.model.arrays.pgm_arrays import Branch3Array, BranchArray from power_grid_model_ds._core.model.enums.nodes import NodeType +from power_grid_model_ds._core.model.graphs.errors import MissingBranchError if TYPE_CHECKING: from power_grid_model_ds._core.model.grids.base import Grid @@ -71,3 +73,43 @@ def get_downstream_nodes(grid: "Grid", node_id: int, inclusive: bool = False): return grid.graphs.active_graph.get_downstream_nodes( node_id=node_id, start_node_ids=list(substation_nodes.id), inclusive=inclusive ) + + +def _get_branches(grid: "Grid", from_node: int, to_node: int) -> BranchArray: + """Return active branch records and an index filtered to the requested path nodes.""" + + active_branches = grid.branches.filter(from_status=1, to_status=1).filter( + from_node=from_node, to_node=to_node, mode_="AND" + ) + if grid.three_winding_transformer.size: + three_winding_active = grid.three_winding_transformer.as_branches().filter( + from_status=1, to_status=1, from_node=from_node, to_node=to_node, mode_="AND" + ) + if three_winding_active.size: + active_branches = fp.concatenate(active_branches, three_winding_active) + + return active_branches + + +def iter_branches_in_shortest_path( + grid: "Grid", from_node_id: int, to_node_id: int, typed: bool = False +) -> Iterator[BranchArray] | Iterator[BranchArray | Branch3Array]: + """See Grid.iter_branches_in_shortest_path().""" + + path, _ = grid.graphs.active_graph.get_shortest_path(from_node_id, to_node_id) + + for current_node, next_node in pairwise(path): + branches = _get_branches(grid, current_node, next_node) + if branches.size == 0: + raise MissingBranchError( + f"No active branch connects nodes {current_node} -> {next_node} even though a path exists." + ) + if typed: + branch_ids = branches.id.tolist() + try: + typed_branches = grid.get_typed_branches(branch_ids) + except RecordDoesNotExist: + typed_branches = grid.three_winding_transformer.filter(branch_ids) + yield typed_branches + else: + yield branches diff --git a/src/power_grid_model_ds/_core/model/grids/base.py b/src/power_grid_model_ds/_core/model/grids/base.py index 9f9a998d..da1063e4 100644 --- a/src/power_grid_model_ds/_core/model/grids/base.py +++ b/src/power_grid_model_ds/_core/model/grids/base.py @@ -7,7 +7,7 @@ import warnings from dataclasses import dataclass, fields from pathlib import Path -from typing import Literal, Self, Type, TypeVar +from typing import Iterator, Literal, Self, Type, TypeVar import numpy as np import numpy.typing as npt @@ -66,6 +66,7 @@ get_downstream_nodes, get_nearest_substation_node, get_typed_branches, + iter_branches_in_shortest_path, ) from power_grid_model_ds._core.model.grids.serialization.json import deserialize_from_json, serialize_to_json from power_grid_model_ds._core.model.grids.serialization.pickle import load_grid_from_pickle, save_grid_to_pickle @@ -345,6 +346,27 @@ def get_branches_in_path(self, nodes_in_path: list[int]) -> BranchArray: """ return self.branches.filter(from_node=nodes_in_path, to_node=nodes_in_path, from_status=1, to_status=1) + def iter_branches_in_shortest_path( + self, from_node_id: int, to_node_id: int, typed: bool = False + ) -> Iterator[BranchArray] | Iterator[BranchArray | Branch3Array]: + """Returns the ordered active branches that form the shortest path between two nodes. When parallel active edges + are in the path all these branches will be returned for the same from_node and to_node. + + Args: + from_node_id (int): External id of the path start node. + to_node_id (int): External id of the path end node. + typed (bool): If True, each yielded branch is converted to its typed array via + ``get_typed_branches``. + + Yields: + BranchArray: branch arrays for each active branch on the path. + + Raises: + MissingBranchError: If the graph reports an edge on the shortest path but no active branch is found. + """ + + return iter_branches_in_shortest_path(self, from_node_id, to_node_id, typed=typed) + def get_nearest_substation_node(self, node_id: int): """Find the nearest substation node. diff --git a/tests/unit/model/grids/test_search.py b/tests/unit/model/grids/test_search.py index f27cf35f..c3e595a0 100644 --- a/tests/unit/model/grids/test_search.py +++ b/tests/unit/model/grids/test_search.py @@ -71,6 +71,31 @@ def test_get_branches_in_path_empty_path(self, basic_grid): assert 0 == branches.size +class TestIterBranchesInShortestPath: + def test_iter_branches_in_shortest_path_returns_branch_arrays(self, basic_grid): + branches = list(basic_grid.iter_branches_in_shortest_path(101, 106)) + assert branches == [basic_grid.branches.filter(id=201), basic_grid.branches.filter(id=301)] + + def test_iter_branches_in_shortest_path_three_winding_transformer(self, grid_with_3wt): + branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104)) + assert len(branches) == 2 + assert branches[0].id.item() == 301 + assert branches[0].from_node.item() == 101 + assert branches[0].to_node.item() == 102 + assert branches[1] == grid_with_3wt.branches.filter(id=201) + + def test_iter_branches_same_node_returns_empty(self, basic_grid): + assert [] == list(basic_grid.iter_branches_in_shortest_path(101, 101)) + + def test_iter_branches_in_shortest_path_typed(self, basic_grid): + branches = list(basic_grid.iter_branches_in_shortest_path(101, 106, typed=True)) + assert branches == [basic_grid.line.filter(id=201), basic_grid.transformer.filter(id=301)] + + def test_iter_branches_in_shortest_path_three_winding_transformer_typed(self, grid_with_3wt): + branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104, typed=True)) + assert branches == [grid_with_3wt.three_winding_transformer.filter(id=301), grid_with_3wt.line.filter(id=201)] + + def test_component_three_winding_transformer(grid_with_3wt): substation_nodes = grid_with_3wt.node.filter(node_type=NodeType.SUBSTATION_NODE.value).id with grid_with_3wt.graphs.active_graph.tmp_remove_nodes(substation_nodes):