Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions docs/examples/model/graph_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -105,6 +105,28 @@
"print(f\"Shortest path: {path}, Length: {length}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Branches on the Shortest Path\n",
"\n",
"`Grid.iter_branches_in_shortest_path` walks the same nodes returned by `get_shortest_path` but exposes the actual `BranchArray` records for each edge. Iterate the result to inspect branch IDs, statuses, or any other metadata without recomputing the path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from power_grid_model_ds import Grid\n",
"\n",
"grid = Grid.from_txt(\"S1 101\", \"101 102\", \"102 103\")\n",
"for branch in grid.iter_branches_in_shortest_path(101, 103):\n",
" print(f\"Branch {branch.id.item()} runs {branch.from_node.item()} → {branch.to_node.item()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -180,7 +202,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": ".venv (3.12.6)",
"language": "python",
"name": "python3"
},
Expand Down
85 changes: 83 additions & 2 deletions src/power_grid_model_ds/_core/model/grids/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
# SPDX-License-Identifier: MPL-2.0

import dataclasses
from typing import TYPE_CHECKING
from collections import defaultdict
from typing import TYPE_CHECKING, Iterator

import numpy as np
import numpy.typing as npt

from power_grid_model_ds._core import fancypy as fp
from power_grid_model_ds._core.model.arrays import BranchArray
from power_grid_model_ds._core.model.arrays import BranchArray, ThreeWindingTransformerArray
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
from power_grid_model_ds._core.model.enums.nodes import NodeType
from power_grid_model_ds._core.model.graphs.errors import MissingBranchError

if TYPE_CHECKING:
from power_grid_model_ds._core.model.grids.base import Grid
Expand Down Expand Up @@ -71,3 +73,82 @@ def get_downstream_nodes(grid: "Grid", node_id: int, inclusive: bool = False):
return grid.graphs.active_graph.get_downstream_nodes(
node_id=node_id, start_node_ids=list(substation_nodes.id), inclusive=inclusive
)


_THREE_WINDING_BRANCH_CONFIGS = (
("node_1", "node_2", "status_1", "status_2"),
("node_1", "node_3", "status_1", "status_3"),
("node_2", "node_3", "status_2", "status_3"),
)


def _lookup_three_winding_branch(grid: "Grid", node_a: int, node_b: int) -> ThreeWindingTransformerArray:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer having both _private function below the public one.
(This is also recommended in the book Clean Clode)

"""Return the first active transformer that connects the node pair or raise if none exist."""

three_winding_array = grid.three_winding_transformer
error_message = f"No active three-winding transformer connects nodes {node_a} -> {node_b}."
if not three_winding_array.size:
raise MissingBranchError(error_message)

for node_col_a, node_col_b, status_col_a, status_col_b in _THREE_WINDING_BRANCH_CONFIGS:
transformer = three_winding_array.filter(
**{
node_col_a: [node_a, node_b],
node_col_b: [node_a, node_b],
status_col_a: 1,
status_col_b: 1,
} # type: ignore[arg-type]
)
if transformer.size:
return transformer
raise MissingBranchError(error_message)


def _active_branches_for_path(
grid: "Grid", path_nodes: list[int]
) -> tuple[BranchArray, dict[tuple[int, int], list[int]]]:
"""Return active branch records and an index filtered to the requested path nodes."""

active = grid.branches.filter(from_status=1, to_status=1).filter(
from_node=path_nodes, to_node=path_nodes, mode_="OR"
)
if grid.three_winding_transformer.size:
three_winding_active = (
grid.three_winding_transformer.as_branches()
.filter(from_status=1, to_status=1)
.filter(from_node=path_nodes, to_node=path_nodes, mode_="OR")
)
if three_winding_active.size:
active = fp.concatenate(active, three_winding_active)

index: dict[tuple[int, int], list[int]] = defaultdict(list)
for position, (source, target) in enumerate(zip(active.from_node, active.to_node)):
index[(int(source), int(target))].append(position)

return active, index


def iter_branches_in_shortest_path(
grid: "Grid", from_node_id: int, to_node_id: int, typed: bool = False
) -> Iterator[BranchArray]:
"""See Grid.iter_branches_in_shortest_path()."""

path, _ = grid.graphs.active_graph.get_shortest_path(from_node_id, to_node_id)
active_branches, index = _active_branches_for_path(grid, path)

for current_node, next_node in zip(path[:-1], path[1:]):
positions = index.get((current_node, next_node))
if not positions:
raise MissingBranchError(
f"No active branch connects nodes {current_node} -> {next_node} even though a path exists."
)
branch_records = active_branches[positions]
if typed:
branch_ids = branch_records.id.tolist()
try:
typed_branches = grid.get_typed_branches(branch_ids)
except RecordDoesNotExist:
typed_branches = _lookup_three_winding_branch(grid, current_node, next_node)
yield typed_branches
else:
yield branch_records
25 changes: 24 additions & 1 deletion src/power_grid_model_ds/_core/model/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Self, Type, TypeVar
from typing import Iterator, Literal, Self, Type, TypeVar

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -62,6 +62,9 @@
get_nearest_substation_node,
get_typed_branches,
)
from power_grid_model_ds._core.model.grids._search import (
iter_branches_in_shortest_path as _iter_branches_in_shortest_path,
)
from power_grid_model_ds._core.model.grids.serialization.json import deserialize_from_json, serialize_to_json
from power_grid_model_ds._core.model.grids.serialization.pickle import load_grid_from_pickle, save_grid_to_pickle
from power_grid_model_ds._core.model.grids.serialization.string import (
Expand Down Expand Up @@ -322,6 +325,26 @@ def get_branches_in_path(self, nodes_in_path: list[int]) -> BranchArray:
"""
return self.branches.filter(from_node=nodes_in_path, to_node=nodes_in_path, from_status=1, to_status=1)

def iter_branches_in_shortest_path(
self, from_node_id: int, to_node_id: int, typed: bool = False
) -> Iterator[BranchArray]:
"""Returns the ordered active branches that form the shortest path between two nodes.

Args:
from_node_id (int): External id of the path start node.
to_node_id (int): External id of the path end node.
typed (bool): If True, each yielded branch is converted to its typed array via
``get_typed_branches``.

Yields:
BranchArray: branch arrays for each active branch on the path.

Raises:
MissingBranchError: If the graph reports an edge on the shortest path but no active branch is found.
"""

return _iter_branches_in_shortest_path(self, from_node_id, to_node_id, typed=typed)

def get_nearest_substation_node(self, node_id: int):
"""Find the nearest substation node.

Expand Down
31 changes: 31 additions & 0 deletions tests/unit/model/grids/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from power_grid_model_ds import Grid
from power_grid_model_ds._core.model.arrays import LineArray, ThreeWindingTransformerArray, TransformerArray
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
from power_grid_model_ds._core.model.enums.nodes import NodeType

Expand Down Expand Up @@ -56,6 +57,36 @@ def test_get_branches_in_path_empty_path(self, basic_grid):
assert 0 == branches.size


class TestIterBranchesInShortestPath:
def test_iter_branches_in_shortest_path_returns_branch_arrays(self, basic_grid):
branches = list(basic_grid.iter_branches_in_shortest_path(101, 106))
assert len(branches) == 2
branch_nodes = [(branch.from_node.item(), branch.to_node.item()) for branch in branches]
assert branch_nodes == [(101, 102), (102, 106)]

def test_iter_branches_in_shortest_path_three_winding_transformer(self, grid_with_3wt):
branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104))
assert len(branches) == 2
branch_nodes = [(branch.from_node.item(), branch.to_node.item()) for branch in branches]
assert branch_nodes == [(101, 102), (102, 104)]

def test_iter_branches_same_node_returns_empty(self, basic_grid):
assert [] == list(basic_grid.iter_branches_in_shortest_path(101, 101))

def test_iter_branches_in_shortest_path_typed(self, basic_grid):
branches = list(basic_grid.iter_branches_in_shortest_path(101, 106, typed=True))
assert len(branches) == 2
assert isinstance(branches[0], LineArray)
assert isinstance(branches[1], TransformerArray)
assert [branch.id.item() for branch in branches] == [201, 301]

def test_iter_branches_in_shortest_path_three_winding_transformer_typed(self, grid_with_3wt):
branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104, typed=True))
assert len(branches) == 2
assert isinstance(branches[0], ThreeWindingTransformerArray)
assert isinstance(branches[1], LineArray)


Comment on lines +97 to +98
Copy link

Copilot AI Feb 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding test coverage for the MissingBranchError case that would be raised at line 142 of _search.py. This would happen when the graph reports a path but no active branch exists between consecutive nodes, which could occur if the graph state is inconsistent with the branch data. A test would help ensure this error is raised with a clear message in such edge cases.

Suggested change
def test_iter_branches_in_shortest_path_missing_branch_raises(self, basic_grid, monkeypatch):
# Simulate an inconsistent state where the graph reports a path but
# there is no active branch between consecutive nodes in that path.
def fake_get_shortest_path(*args, **kwargs):
# Return a path with a node pair that has no active branch
return [101, 999], 1
monkeypatch.object(basic_grid.graphs.active_graph, "get_shortest_path", fake_get_shortest_path)
with pytest.raises(Exception):
list(basic_grid.iter_branches_in_shortest_path(101, 999))

Copilot uses AI. Check for mistakes.
def test_component_three_winding_transformer(grid_with_3wt):
substation_nodes = grid_with_3wt.node.filter(node_type=NodeType.SUBSTATION_NODE.value).id
with grid_with_3wt.graphs.active_graph.tmp_remove_nodes(substation_nodes):
Expand Down