Skip to content

Commit 6e15083

Browse files
committed
simplify function
Signed-off-by: Vincent Koppen <vincent.koppen@alliander.com>
1 parent 5c3734b commit 6e15083

File tree

3 files changed

+25
-49
lines changed

3 files changed

+25
-49
lines changed

src/power_grid_model_ds/_core/model/grids/_search.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
import dataclasses
6-
from collections import defaultdict
6+
from itertools import pairwise
77
from typing import TYPE_CHECKING, Iterator
88

99
import numpy as np
@@ -12,6 +12,7 @@
1212
from power_grid_model_ds._core import fancypy as fp
1313
from power_grid_model_ds._core.model.arrays import BranchArray
1414
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
15+
from power_grid_model_ds._core.model.arrays.pgm_arrays import Branch3Array
1516
from power_grid_model_ds._core.model.enums.nodes import NodeType
1617
from power_grid_model_ds._core.model.graphs.errors import MissingBranchError
1718

@@ -75,58 +76,41 @@ def get_downstream_nodes(grid: "Grid", node_id: int, inclusive: bool = False):
7576
)
7677

7778

78-
_THREE_WINDING_BRANCH_CONFIGS = (
79-
("node_1", "node_2", "status_1", "status_2"),
80-
("node_1", "node_3", "status_1", "status_3"),
81-
("node_2", "node_3", "status_2", "status_3"),
82-
)
83-
84-
85-
def _active_branches_for_path(
86-
grid: "Grid", path_nodes: list[int]
87-
) -> tuple[BranchArray, dict[tuple[int, int], list[int]]]:
79+
def _get_branch(grid: "Grid", from_node: int, to_node: int) -> BranchArray:
8880
"""Return active branch records and an index filtered to the requested path nodes."""
8981

90-
active = grid.branches.filter(from_status=1, to_status=1).filter(
91-
from_node=path_nodes, to_node=path_nodes, mode_="AND"
82+
active_branches = grid.branches.filter(from_status=1, to_status=1).filter(
83+
from_node=from_node, to_node=to_node, mode_="AND"
9284
)
9385
if grid.three_winding_transformer.size:
94-
three_winding_active = (
95-
grid.three_winding_transformer.as_branches()
96-
.filter(from_status=1, to_status=1)
97-
.filter(from_node=path_nodes, to_node=path_nodes, mode_="AND")
86+
three_winding_active = grid.three_winding_transformer.as_branches().filter(
87+
from_status=1, to_status=1, from_node=from_node, to_node=to_node, mode_="AND"
9888
)
9989
if three_winding_active.size:
100-
active = fp.concatenate(active, three_winding_active)
101-
102-
index: dict[tuple[int, int], list[int]] = defaultdict(list)
103-
for position, (source, target) in enumerate(zip(active.from_node, active.to_node)):
104-
index[(int(source), int(target))].append(position)
90+
active_branches = fp.concatenate(active_branches, three_winding_active)
10591

106-
return active, index
92+
return active_branches
10793

10894

10995
def iter_branches_in_shortest_path(
11096
grid: "Grid", from_node_id: int, to_node_id: int, typed: bool = False
111-
) -> Iterator[BranchArray]:
97+
) -> Iterator[BranchArray] | Iterator[BranchArray | Branch3Array]:
11298
"""See Grid.iter_branches_in_shortest_path()."""
11399

114100
path, _ = grid.graphs.active_graph.get_shortest_path(from_node_id, to_node_id)
115-
active_branches, index = _active_branches_for_path(grid, path)
116101

117-
for current_node, next_node in zip(path[:-1], path[1:]):
118-
positions = index.get((current_node, next_node))
119-
if not positions:
102+
for current_node, next_node in pairwise(path):
103+
branches = _get_branch(grid, current_node, next_node)
104+
if branches.size == 0:
120105
raise MissingBranchError(
121106
f"No active branch connects nodes {current_node} -> {next_node} even though a path exists."
122107
)
123-
branch_records = active_branches[positions]
124108
if typed:
125-
branch_ids = branch_records.id.tolist()
109+
branch_ids = branches.id.tolist()
126110
try:
127111
typed_branches = grid.get_typed_branches(branch_ids)
128112
except RecordDoesNotExist:
129113
typed_branches = grid.three_winding_transformer.filter(branch_ids)
130114
yield typed_branches
131115
else:
132-
yield branch_records
116+
yield branches

src/power_grid_model_ds/_core/model/grids/base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@
6161
get_downstream_nodes,
6262
get_nearest_substation_node,
6363
get_typed_branches,
64-
)
65-
from power_grid_model_ds._core.model.grids._search import (
66-
iter_branches_in_shortest_path as _iter_branches_in_shortest_path,
64+
iter_branches_in_shortest_path,
6765
)
6866
from power_grid_model_ds._core.model.grids.serialization.json import deserialize_from_json, serialize_to_json
6967
from power_grid_model_ds._core.model.grids.serialization.pickle import load_grid_from_pickle, save_grid_to_pickle
@@ -327,7 +325,7 @@ def get_branches_in_path(self, nodes_in_path: list[int]) -> BranchArray:
327325

328326
def iter_branches_in_shortest_path(
329327
self, from_node_id: int, to_node_id: int, typed: bool = False
330-
) -> Iterator[BranchArray]:
328+
) -> Iterator[BranchArray] | Iterator[BranchArray | Branch3Array]:
331329
"""Returns the ordered active branches that form the shortest path between two nodes. When parallel active edges
332330
are in the path all these branches will be returned for the same from_node and to_node.
333331
@@ -344,7 +342,7 @@ def iter_branches_in_shortest_path(
344342
MissingBranchError: If the graph reports an edge on the shortest path but no active branch is found.
345343
"""
346344

347-
return _iter_branches_in_shortest_path(self, from_node_id, to_node_id, typed=typed)
345+
return iter_branches_in_shortest_path(self, from_node_id, to_node_id, typed=typed)
348346

349347
def get_nearest_substation_node(self, node_id: int):
350348
"""Find the nearest substation node.

tests/unit/model/grids/test_search.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pytest
77

88
from power_grid_model_ds import Grid
9-
from power_grid_model_ds._core.model.arrays import LineArray, ThreeWindingTransformerArray, TransformerArray
109
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
1110
from power_grid_model_ds._core.model.enums.nodes import NodeType
1211

@@ -60,31 +59,26 @@ def test_get_branches_in_path_empty_path(self, basic_grid):
6059
class TestIterBranchesInShortestPath:
6160
def test_iter_branches_in_shortest_path_returns_branch_arrays(self, basic_grid):
6261
branches = list(basic_grid.iter_branches_in_shortest_path(101, 106))
63-
assert len(branches) == 2
64-
branch_nodes = [(branch.from_node.item(), branch.to_node.item()) for branch in branches]
65-
assert branch_nodes == [(101, 102), (102, 106)]
62+
assert branches == [basic_grid.branches.filter(id=201), basic_grid.branches.filter(id=301)]
6663

6764
def test_iter_branches_in_shortest_path_three_winding_transformer(self, grid_with_3wt):
6865
branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104))
6966
assert len(branches) == 2
70-
branch_nodes = [(branch.from_node.item(), branch.to_node.item()) for branch in branches]
71-
assert branch_nodes == [(101, 102), (102, 104)]
67+
assert branches[0].id.item() == 301
68+
assert branches[0].from_node.item() == 101
69+
assert branches[0].to_node.item() == 102
70+
assert branches[1] == grid_with_3wt.branches.filter(id=201)
7271

7372
def test_iter_branches_same_node_returns_empty(self, basic_grid):
7473
assert [] == list(basic_grid.iter_branches_in_shortest_path(101, 101))
7574

7675
def test_iter_branches_in_shortest_path_typed(self, basic_grid):
7776
branches = list(basic_grid.iter_branches_in_shortest_path(101, 106, typed=True))
78-
assert len(branches) == 2
79-
assert isinstance(branches[0], LineArray)
80-
assert isinstance(branches[1], TransformerArray)
81-
assert [branch.id.item() for branch in branches] == [201, 301]
77+
assert branches == [basic_grid.line.filter(id=201), basic_grid.transformer.filter(id=301)]
8278

8379
def test_iter_branches_in_shortest_path_three_winding_transformer_typed(self, grid_with_3wt):
8480
branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104, typed=True))
85-
assert len(branches) == 2
86-
assert isinstance(branches[0], ThreeWindingTransformerArray)
87-
assert isinstance(branches[1], LineArray)
81+
assert branches == [grid_with_3wt.three_winding_transformer.filter(id=301), grid_with_3wt.line.filter(id=201)]
8882

8983

9084
def test_component_three_winding_transformer(grid_with_3wt):

0 commit comments

Comments
 (0)