Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions docs/examples/model/graph_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -105,6 +105,28 @@
"print(f\"Shortest path: {path}, Length: {length}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Branches on the Shortest Path\n",
"\n",
"`Grid.iter_branches_in_shortest_path` walks the same nodes returned by `get_shortest_path` but exposes the actual `BranchArray` records for each edge. Iterate the result to inspect branch IDs, statuses, or any other metadata without recomputing the path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from power_grid_model_ds import Grid\n",
"\n",
"grid = Grid.from_txt(\"S1 101\", \"101 102\", \"102 103\")\n",
"for branch in grid.iter_branches_in_shortest_path(101, 103):\n",
" print(f\"Branch {branch.id.item()} runs {branch.from_node.item()} → {branch.to_node.item()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -180,7 +202,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": ".venv (3.12.6)",
"language": "python",
"name": "python3"
},
Expand Down
3 changes: 3 additions & 0 deletions src/power_grid_model_ds/_core/model/arrays/pgm_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,21 @@ class Branch3Array(IdArray, Branch3):
def as_branches(self) -> BranchArray:
"""Convert Branch3Array to BranchArray."""
branches_1_2 = BranchArray.empty(self.size)
branches_1_2.id = self.id
branches_1_2.from_node = self.node_1
branches_1_2.to_node = self.node_2
branches_1_2.from_status = self.status_1
branches_1_2.to_status = self.status_2

branches_1_3 = BranchArray.empty(self.size)
branches_1_3.id = self.id
branches_1_3.from_node = self.node_1
branches_1_3.to_node = self.node_3
branches_1_3.from_status = self.status_1
branches_1_3.to_status = self.status_3

branches_2_3 = BranchArray.empty(self.size)
branches_2_3.id = self.id
branches_2_3.from_node = self.node_2
branches_2_3.to_node = self.node_3
branches_2_3.from_status = self.status_2
Expand Down
46 changes: 44 additions & 2 deletions src/power_grid_model_ds/_core/model/grids/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
# SPDX-License-Identifier: MPL-2.0

import dataclasses
from typing import TYPE_CHECKING
from itertools import pairwise
from typing import TYPE_CHECKING, Iterator

import numpy as np
import numpy.typing as npt

from power_grid_model_ds._core import fancypy as fp
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
from power_grid_model_ds._core.model.arrays.pgm_arrays import BranchArray
from power_grid_model_ds._core.model.arrays.pgm_arrays import Branch3Array, BranchArray
from power_grid_model_ds._core.model.enums.nodes import NodeType
from power_grid_model_ds._core.model.graphs.errors import MissingBranchError

if TYPE_CHECKING:
from power_grid_model_ds._core.model.grids.base import Grid
Expand Down Expand Up @@ -71,3 +73,43 @@ def get_downstream_nodes(grid: "Grid", node_id: int, inclusive: bool = False):
return grid.graphs.active_graph.get_downstream_nodes(
node_id=node_id, start_node_ids=list(substation_nodes.id), inclusive=inclusive
)


def _get_branches(grid: "Grid", from_node: int, to_node: int) -> BranchArray:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather have _private functions below the functions that use them. This is also suggested in the book "Clean Code"

"""Return active branch records and an index filtered to the requested path nodes."""

active_branches = grid.branches.filter(from_status=1, to_status=1).filter(
from_node=from_node, to_node=to_node, mode_="AND"
)
if grid.three_winding_transformer.size:
three_winding_active = grid.three_winding_transformer.as_branches().filter(
from_status=1, to_status=1, from_node=from_node, to_node=to_node, mode_="AND"
)
if three_winding_active.size:
active_branches = fp.concatenate(active_branches, three_winding_active)

return active_branches


def iter_branches_in_shortest_path(
grid: "Grid", from_node_id: int, to_node_id: int, typed: bool = False
) -> Iterator[BranchArray] | Iterator[BranchArray | Branch3Array]:
"""See Grid.iter_branches_in_shortest_path()."""

path, _ = grid.graphs.active_graph.get_shortest_path(from_node_id, to_node_id)

for current_node, next_node in pairwise(path):
branches = _get_branches(grid, current_node, next_node)
if branches.size == 0:
raise MissingBranchError(
f"No active branch connects nodes {current_node} -> {next_node} even though a path exists."
)
if typed:
branch_ids = branches.id.tolist()
try:
typed_branches = grid.get_typed_branches(branch_ids)
except RecordDoesNotExist:
typed_branches = grid.three_winding_transformer.filter(branch_ids)
yield typed_branches
else:
yield branches
24 changes: 23 additions & 1 deletion src/power_grid_model_ds/_core/model/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Literal, Self, Type, TypeVar
from typing import Iterator, Literal, Self, Type, TypeVar

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -66,6 +66,7 @@
get_downstream_nodes,
get_nearest_substation_node,
get_typed_branches,
iter_branches_in_shortest_path,
)
from power_grid_model_ds._core.model.grids.serialization.json import deserialize_from_json, serialize_to_json
from power_grid_model_ds._core.model.grids.serialization.pickle import load_grid_from_pickle, save_grid_to_pickle
Expand Down Expand Up @@ -345,6 +346,27 @@ def get_branches_in_path(self, nodes_in_path: list[int]) -> BranchArray:
"""
return self.branches.filter(from_node=nodes_in_path, to_node=nodes_in_path, from_status=1, to_status=1)

def iter_branches_in_shortest_path(
self, from_node_id: int, to_node_id: int, typed: bool = False
) -> Iterator[BranchArray] | Iterator[BranchArray | Branch3Array]:
"""Returns the ordered active branches that form the shortest path between two nodes. When parallel active edges
are in the path all these branches will be returned for the same from_node and to_node.

Args:
from_node_id (int): External id of the path start node.
to_node_id (int): External id of the path end node.
typed (bool): If True, each yielded branch is converted to its typed array via
``get_typed_branches``.

Yields:
BranchArray: branch arrays for each active branch on the path.

Raises:
MissingBranchError: If the graph reports an edge on the shortest path but no active branch is found.
"""

return iter_branches_in_shortest_path(self, from_node_id, to_node_id, typed=typed)

def get_nearest_substation_node(self, node_id: int):
"""Find the nearest substation node.

Expand Down
25 changes: 25 additions & 0 deletions tests/unit/model/grids/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def test_get_branches_in_path_empty_path(self, basic_grid):
assert 0 == branches.size


class TestIterBranchesInShortestPath:
def test_iter_branches_in_shortest_path_returns_branch_arrays(self, basic_grid):
branches = list(basic_grid.iter_branches_in_shortest_path(101, 106))
assert branches == [basic_grid.branches.filter(id=201), basic_grid.branches.filter(id=301)]

def test_iter_branches_in_shortest_path_three_winding_transformer(self, grid_with_3wt):
branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104))
assert len(branches) == 2
assert branches[0].id.item() == 301
assert branches[0].from_node.item() == 101
assert branches[0].to_node.item() == 102
assert branches[1] == grid_with_3wt.branches.filter(id=201)

def test_iter_branches_same_node_returns_empty(self, basic_grid):
assert [] == list(basic_grid.iter_branches_in_shortest_path(101, 101))

def test_iter_branches_in_shortest_path_typed(self, basic_grid):
branches = list(basic_grid.iter_branches_in_shortest_path(101, 106, typed=True))
assert branches == [basic_grid.line.filter(id=201), basic_grid.transformer.filter(id=301)]

def test_iter_branches_in_shortest_path_three_winding_transformer_typed(self, grid_with_3wt):
branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104, typed=True))
assert branches == [grid_with_3wt.three_winding_transformer.filter(id=301), grid_with_3wt.line.filter(id=201)]


Comment on lines +97 to +98
Copy link

Copilot AI Feb 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding test coverage for the MissingBranchError case that would be raised at line 142 of _search.py. This would happen when the graph reports a path but no active branch exists between consecutive nodes, which could occur if the graph state is inconsistent with the branch data. A test would help ensure this error is raised with a clear message in such edge cases.

Suggested change
def test_iter_branches_in_shortest_path_missing_branch_raises(self, basic_grid, monkeypatch):
# Simulate an inconsistent state where the graph reports a path but
# there is no active branch between consecutive nodes in that path.
def fake_get_shortest_path(*args, **kwargs):
# Return a path with a node pair that has no active branch
return [101, 999], 1
monkeypatch.object(basic_grid.graphs.active_graph, "get_shortest_path", fake_get_shortest_path)
with pytest.raises(Exception):
list(basic_grid.iter_branches_in_shortest_path(101, 999))

Copilot uses AI. Check for mistakes.
def test_component_three_winding_transformer(grid_with_3wt):
substation_nodes = grid_with_3wt.node.filter(node_type=NodeType.SUBSTATION_NODE.value).id
with grid_with_3wt.graphs.active_graph.tmp_remove_nodes(substation_nodes):
Expand Down