diff --git a/src/power_grid_model_ds/_core/model/graphs/models/base.py b/src/power_grid_model_ds/_core/model/graphs/models/base.py index ec0bd4a6..0488c56f 100644 --- a/src/power_grid_model_ds/_core/model/graphs/models/base.py +++ b/src/power_grid_model_ds/_core/model/graphs/models/base.py @@ -4,6 +4,7 @@ import warnings from abc import ABC, abstractmethod from contextlib import contextmanager +from itertools import combinations from typing import TYPE_CHECKING, Counter, Generator from numpy._typing import NDArray @@ -26,6 +27,7 @@ class BaseGraphModel(ABC): def __init__(self, active_only=False) -> None: self.active_only = active_only + self._three_winding_branch_groups: set[tuple[int]] = set() def __repr__(self) -> str: return ( @@ -184,7 +186,11 @@ def add_branch_array(self, branch_array: BranchArray) -> None: def add_branch3_array(self, branch3_array: Branch3Array) -> None: """Add all branch3s in the branch3 array to the graph.""" for branch3 in branch3_array: - self.add_branch_array(branch3.as_branches()) + seperate_branches = branch3.as_branches() + self.add_branch_array(seperate_branches) + self._three_winding_branch_groups.add( + tuple(sorted([branch3.node_1.item(), branch3.node_2.item(), branch3.node_3.item()])) + ) def delete_branch_array(self, branch_array: BranchArray, raise_on_fail: bool = True) -> None: """Delete all branches in branch_array from the graph.""" @@ -196,6 +202,8 @@ def delete_branch3_array(self, branch3_array: Branch3Array, raise_on_fail: bool """Delete all branch3s in the branch3 array from the graph.""" for branch3 in branch3_array: self.delete_branch_array(branch3.as_branches(), raise_on_fail=raise_on_fail) + nodes = tuple(sorted([branch3.node_1.item(), branch3.node_2.item(), branch3.node_3.item()])) + self._three_winding_branch_groups.remove(nodes) @contextmanager def tmp_remove_nodes(self, nodes: list[int]) -> Generator: @@ -219,6 +227,16 @@ def tmp_remove_nodes(self, nodes: list[int]) -> Generator: for source, target in edge_list: self.add_branch(source, target) + @contextmanager + def tmp_remove_branches(self, branches: list[tuple[int, int]]) -> Generator: + for from_node, to_node in branches: + self.delete_branch(from_node, to_node) + + yield + + for from_node, to_node in branches: + self.add_branch(from_node, to_node) + def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]: """Calculate the shortest path between two nodes @@ -246,6 +264,45 @@ def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tup except NoPathBetweenNodes as e: raise NoPathBetweenNodes(f"No path between nodes {ext_start_node_id} and {ext_end_node_id}") from e + def _squash_paths_inside_three_winding_transformers(self, paths: list[list[int]]) -> list[list[int]]: + """For each path check if we can squash it. + + If a path contains three nodes of the three winding transformer, we can remove the middle node. + Since we can always go directly. + + NOTE: if a node has an inactive status, we can never have a situation where + we go through 3 nodes of the transformer directly after each other. + So this also works for inactive three winding transformers.""" + replacements = set() + for group in self._three_winding_branch_groups: + replacements.add(frozenset(group)) + + for i in range(len(paths)): + path = paths[i] + + invalid_indices = set() + for index in range(1, len(path) - 1): + if frozenset(sorted([path[index - 1], path[index], path[index + 1]])) in replacements: + invalid_indices.add(index) + + paths[i] = [node for index, node in enumerate(path) if index not in invalid_indices] + return paths + + def _branches_to_remove_from_three_winding_transformers(self) -> list[tuple[int, int]]: + """Returns a list of branches that should be temporarily removed cycles by three winding transformers. + + By removing one branch, we still can make all the paths but do not create to many solutions/invalid solutions. + Later on, we can remove any detour we made because we didn't use the removed branch. + + NOTE: we only want to remove a branch for a three winding transformer if all three of branches are active. + """ + branches_to_remove = [] + for group in self._three_winding_branch_groups: + if all(self.has_branch(from_node, to_node) for from_node, to_node in combinations(group, 2)): + branches_to_remove.append((group[1], group[2])) + + return branches_to_remove + def get_all_paths(self, ext_start_node_id: int, ext_end_node_id: int) -> list[list[int]]: """Retrieves all paths between two (external) nodes. Returns a list of paths, each path containing a list of external nodes. @@ -253,12 +310,27 @@ def get_all_paths(self, ext_start_node_id: int, ext_end_node_id: int) -> list[li if ext_start_node_id == ext_end_node_id: return [] - internal_paths = self._get_all_paths( - source=self.external_to_internal(ext_start_node_id), - target=self.external_to_internal(ext_end_node_id), - ) + correct_for_three_winding = False + if self._three_winding_branch_groups: + branches_to_remove = self._branches_to_remove_from_three_winding_transformers() + correct_for_three_winding = True + + if correct_for_three_winding: + with self.tmp_remove_branches(branches_to_remove): + internal_paths = self._get_all_paths( + source=self.external_to_internal(ext_start_node_id), + target=self.external_to_internal(ext_end_node_id), + ) + else: + internal_paths = self._get_all_paths( + source=self.external_to_internal(ext_start_node_id), + target=self.external_to_internal(ext_end_node_id), + ) - return [self._internals_to_externals(path) for path in internal_paths] + paths = [self._internals_to_externals(path) for path in internal_paths] + if correct_for_three_winding: + return self._squash_paths_inside_three_winding_transformers(paths) + return paths def get_components(self) -> list[list[int]]: """Returns all separate components of the graph as lists @@ -343,8 +415,22 @@ def find_fundamental_cycles(self) -> list[list[int]]: Returns: list[list[int]]: list of cycles, each cycle is a list of (external) node ids """ - internal_cycles = self._find_fundamental_cycles() - return [self._internals_to_externals(nodes) for nodes in internal_cycles] + correct_for_three_winding = False + if self._three_winding_branch_groups: + branches_to_remove = self._branches_to_remove_from_three_winding_transformers() + correct_for_three_winding = True + + if correct_for_three_winding: + with self.tmp_remove_branches(branches_to_remove): + internal_cycles = self._find_fundamental_cycles() + else: + internal_cycles = self._find_fundamental_cycles() + + cycles = [self._internals_to_externals(nodes) for nodes in internal_cycles] + + if correct_for_three_winding: + return self._squash_paths_inside_three_winding_transformers(cycles) + return cycles @classmethod def from_arrays(cls, arrays: "Grid", active_only=False) -> "BaseGraphModel": diff --git a/tests/unit/model/graphs/test_graph_model.py b/tests/unit/model/graphs/test_graph_model.py index cedf2251..7b0ef81f 100644 --- a/tests/unit/model/graphs/test_graph_model.py +++ b/tests/unit/model/graphs/test_graph_model.py @@ -369,6 +369,27 @@ def test_find_first_connected_no_match(self, graph_with_2_routes: BaseGraphModel graph.find_first_connected(1, candidate_node_ids=[99]) +class TestTmpRemoveBranches: + def test_tmp_remove_branches(self, graph_with_2_routes: BaseGraphModel): + graph = deepcopy(graph_with_2_routes) + + assert graph.has_branch(1, 2) + assert graph.has_branch(2, 3) + + with graph.tmp_remove_branches([(1, 2), (2, 3)]): + assert not graph.has_branch(1, 2) + assert not graph.has_branch(2, 3) + + assert graph == graph_with_2_routes + + def test_tmp_remove_branches_non_existent_branch(self, graph_with_2_routes: BaseGraphModel): + graph = deepcopy(graph_with_2_routes) + + with pytest.raises(MissingBranchError): + with graph.tmp_remove_branches([(1, 99)]): + pass + + class TestEq: def test_eq(self, graph_with_2_routes: BaseGraphModel): copied_graph = deepcopy(graph_with_2_routes) diff --git a/tests/unit/model/graphs/test_three_winding.py b/tests/unit/model/graphs/test_three_winding.py new file mode 100644 index 00000000..b4edf898 --- /dev/null +++ b/tests/unit/model/graphs/test_three_winding.py @@ -0,0 +1,194 @@ +import pytest + +from power_grid_model_ds._core.model.arrays.pgm_arrays import NodeArray, ThreeWindingTransformerArray +from power_grid_model_ds._core.model.graphs.models.rustworkx import RustworkxGraphModel + + +def _setup_graph(graph: RustworkxGraphModel) -> None: + nodes = NodeArray.empty(12) + nodes.id = [1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60] + + graph.add_node_array(nodes) + + # The three winding arrays are open at 5 and 50. + # TODO: add all the 10-60 and 10-40 as well. + # ------------------- + # | | + # 1 -- 2 --- 3 -- 4 10 -- 20 --- 30 -- 40 + # | | | | + # 5 | 50 | + # | | | | + # 6 ---------60-------- + + three = ThreeWindingTransformerArray.empty(2) + three.node_1 = [2, 20] + three.node_2 = [3, 30] + three.node_3 = [5, 50] + three.status_1 = [1, 1] + three.status_2 = [1, 1] + three.status_3 = [0, 0] + graph.add_branch3_array(three) + + graph.add_branch(1, 2) + graph.add_branch(3, 4) + graph.add_branch(5, 6) + + graph.add_branch(10, 20) + graph.add_branch(30, 40) + graph.add_branch(50, 60) + + graph.add_branch(10, 60) + graph.add_branch(10, 40) + graph.add_branch(40, 60) + return graph + + +@pytest.fixture(params=[True, False], ids=["active", "complete"]) +def active_only(request): + return request.param + + +@pytest.fixture +def graph(active_only): + return _setup_graph(RustworkxGraphModel(active_only=active_only)) + + +def list_of_paths_to_set(paths, ordered_paths=False): + if ordered_paths: + return {tuple(path) for path in paths} + else: + return {tuple(sorted(path)) for path in paths} + + +@pytest.mark.usefixtures("graph") +class TestThreeWindingTransformer: + def test_three_winding_transformer_group(self, graph): + assert graph._three_winding_branch_groups == {(2, 3, 5), (20, 30, 50)} + + @pytest.mark.parametrize( + ("source", "dest", "active_expected", "complete_expected"), + [ + pytest.param(1, 6, set(), {(1, 2, 5, 6)}, id="1->6"), + pytest.param(1, 5, set(), {(1, 2, 5)}, id="1->5"), + pytest.param( + 10, + 50, + { + (10, 40, 60, 50), + (10, 60, 50), + (10, 20, 30, 40, 60, 50), + }, + { + (10, 40, 60, 50), + (10, 40, 30, 50), + (10, 60, 40, 30, 50), + (10, 60, 50), + (10, 20, 50), + (10, 20, 30, 40, 60, 50), + }, + id="10->50", + ), + ], + ) + def test_get_all_paths(self, graph, active_only, source, dest, active_expected, complete_expected): + expected = active_expected if active_only else complete_expected + actual_paths = graph.get_all_paths(source, dest) + + assert {tuple(path) for path in actual_paths} == expected + + @pytest.mark.parametrize( + ("source", "dest", "active_expected", "complete_expected"), + [ + pytest.param(1, 6, set(), {(1, 2, 5, 6)}, id="1->6"), + pytest.param(1, 5, set(), {(1, 2, 5)}, id="1->5"), + pytest.param( + 10, + 50, + { + (10, 40, 60, 50), + (10, 60, 50), + }, + { + (10, 60, 50), + (10, 40, 60, 50), + (10, 40, 30, 50), + (10, 60, 40, 30, 50), + (10, 20, 50), + }, + id="10->50", + ), + ], + ) + def test_get_all_paths_removed_branch(self, graph, active_only, source, dest, active_expected, complete_expected): + expected = active_expected if active_only else complete_expected + with graph.tmp_remove_branches([(2, 3), (20, 30)]): + actual_paths = graph.get_all_paths(source, dest) + + assert {tuple(path) for path in actual_paths} == expected + + def test_nr_branches(self, graph, active_only): + expected = 11 if active_only else 15 + assert graph.nr_branches == expected + + def test_get_cycles(self, graph): + expected = ( + {(10, 20, 30, 40, 10), (40, 30, 20, 10, 60, 40)} + if graph.active_only + else {(10, 20, 30, 40, 10), (40, 30, 50, 60, 40), (10, 20, 50, 60, 10)} + ) + actual_cycles = graph.find_fundamental_cycles() + assert {tuple(cycle) for cycle in actual_cycles} == expected + + def test_get_cycles_removed_branch(self, graph): + branches_to_remove = [(2, 3), (20, 30)] if graph.active_only else [(2, 3), (3, 5), (20, 30), (30, 50)] + expected = {(40, 10, 60, 40)} if graph.active_only else {(40, 10, 20, 50, 60, 40), (10, 20, 50, 60, 10)} + + with graph.tmp_remove_branches(branches_to_remove): + actual_cycles = graph.find_fundamental_cycles() + assert {tuple(cycle) for cycle in actual_cycles} == expected + + def test_get_components(self, graph, active_only): + expected = ( + {(1, 2, 3, 4), (5, 6), (10, 20, 30, 40, 50, 60)} + if active_only + else {(1, 2, 3, 4, 5, 6), (10, 20, 30, 40, 50, 60)} + ) + actual_components = graph.get_components() + assert list_of_paths_to_set(actual_components) == expected + + with graph.tmp_remove_nodes([2, 60]): + expected = ( + {(1,), (3, 4), (5, 6), (10, 20, 30, 40), (50,)} + if active_only + else {(1,), (3, 4, 5, 6), (10, 20, 30, 40, 50)} + ) + assert list_of_paths_to_set(graph.get_components()) == list_of_paths_to_set(expected) + + def test_all_branches(self, graph, active_only): + expected_edges = { + (1, 2), + (2, 3), + (3, 4), + (5, 6), + (10, 20), + (20, 30), + (30, 40), + (50, 60), + (10, 40), + (10, 60), + (40, 60), + } + if not active_only: + expected_edges.update({(2, 5), (3, 5), (20, 50), (30, 50)}) + + assert set(graph.all_branches) == expected_edges + + def test_parallel_branches(self, graph, active_only): + graph.add_branch(2, 5) + assert graph.has_parallel_edges() is not active_only + + graph.add_branch(2, 3) + assert graph.has_parallel_edges() is True + + def test_find_first_connected(self, graph, active_only): + assert graph.find_first_connected(10, [20, 50]) == 20