|
3 | 3 | # SPDX-License-Identifier: MPL-2.0 |
4 | 4 |
|
5 | 5 | import dataclasses |
6 | | -from collections import defaultdict |
| 6 | +from itertools import pairwise |
7 | 7 | from typing import TYPE_CHECKING, Iterator |
8 | 8 |
|
9 | 9 | import numpy as np |
|
12 | 12 | from power_grid_model_ds._core import fancypy as fp |
13 | 13 | from power_grid_model_ds._core.model.arrays import BranchArray |
14 | 14 | from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist |
| 15 | +from power_grid_model_ds._core.model.arrays.pgm_arrays import Branch3Array |
15 | 16 | from power_grid_model_ds._core.model.enums.nodes import NodeType |
16 | 17 | from power_grid_model_ds._core.model.graphs.errors import MissingBranchError |
17 | 18 |
|
@@ -75,58 +76,41 @@ def get_downstream_nodes(grid: "Grid", node_id: int, inclusive: bool = False): |
75 | 76 | ) |
76 | 77 |
|
77 | 78 |
|
78 | | -_THREE_WINDING_BRANCH_CONFIGS = ( |
79 | | - ("node_1", "node_2", "status_1", "status_2"), |
80 | | - ("node_1", "node_3", "status_1", "status_3"), |
81 | | - ("node_2", "node_3", "status_2", "status_3"), |
82 | | -) |
83 | | - |
84 | | - |
85 | | -def _active_branches_for_path( |
86 | | - grid: "Grid", path_nodes: list[int] |
87 | | -) -> tuple[BranchArray, dict[tuple[int, int], list[int]]]: |
| 79 | +def _get_branch(grid: "Grid", from_node: int, to_node: int) -> BranchArray: |
88 | 80 | """Return active branch records and an index filtered to the requested path nodes.""" |
89 | 81 |
|
90 | | - active = grid.branches.filter(from_status=1, to_status=1).filter( |
91 | | - from_node=path_nodes, to_node=path_nodes, mode_="AND" |
| 82 | + active_branches = grid.branches.filter(from_status=1, to_status=1).filter( |
| 83 | + from_node=from_node, to_node=to_node, mode_="AND" |
92 | 84 | ) |
93 | 85 | if grid.three_winding_transformer.size: |
94 | | - three_winding_active = ( |
95 | | - grid.three_winding_transformer.as_branches() |
96 | | - .filter(from_status=1, to_status=1) |
97 | | - .filter(from_node=path_nodes, to_node=path_nodes, mode_="AND") |
| 86 | + three_winding_active = grid.three_winding_transformer.as_branches().filter( |
| 87 | + from_status=1, to_status=1, from_node=from_node, to_node=to_node, mode_="AND" |
98 | 88 | ) |
99 | 89 | if three_winding_active.size: |
100 | | - active = fp.concatenate(active, three_winding_active) |
101 | | - |
102 | | - index: dict[tuple[int, int], list[int]] = defaultdict(list) |
103 | | - for position, (source, target) in enumerate(zip(active.from_node, active.to_node)): |
104 | | - index[(int(source), int(target))].append(position) |
| 90 | + active_branches = fp.concatenate(active_branches, three_winding_active) |
105 | 91 |
|
106 | | - return active, index |
| 92 | + return active_branches |
107 | 93 |
|
108 | 94 |
|
109 | 95 | def iter_branches_in_shortest_path( |
110 | 96 | grid: "Grid", from_node_id: int, to_node_id: int, typed: bool = False |
111 | | -) -> Iterator[BranchArray]: |
| 97 | +) -> Iterator[BranchArray] | Iterator[BranchArray | Branch3Array]: |
112 | 98 | """See Grid.iter_branches_in_shortest_path().""" |
113 | 99 |
|
114 | 100 | path, _ = grid.graphs.active_graph.get_shortest_path(from_node_id, to_node_id) |
115 | | - active_branches, index = _active_branches_for_path(grid, path) |
116 | 101 |
|
117 | | - for current_node, next_node in zip(path[:-1], path[1:]): |
118 | | - positions = index.get((current_node, next_node)) |
119 | | - if not positions: |
| 102 | + for current_node, next_node in pairwise(path): |
| 103 | + branches = _get_branch(grid, current_node, next_node) |
| 104 | + if branches.size == 0: |
120 | 105 | raise MissingBranchError( |
121 | 106 | f"No active branch connects nodes {current_node} -> {next_node} even though a path exists." |
122 | 107 | ) |
123 | | - branch_records = active_branches[positions] |
124 | 108 | if typed: |
125 | | - branch_ids = branch_records.id.tolist() |
| 109 | + branch_ids = branches.id.tolist() |
126 | 110 | try: |
127 | 111 | typed_branches = grid.get_typed_branches(branch_ids) |
128 | 112 | except RecordDoesNotExist: |
129 | 113 | typed_branches = grid.three_winding_transformer.filter(branch_ids) |
130 | 114 | yield typed_branches |
131 | 115 | else: |
132 | | - yield branch_records |
| 116 | + yield branches |
0 commit comments