Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 94 additions & 8 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import combinations
from typing import TYPE_CHECKING, Counter, Generator

from numpy._typing import NDArray
Expand All @@ -26,6 +27,7 @@ class BaseGraphModel(ABC):

def __init__(self, active_only=False) -> None:
self.active_only = active_only
self._three_winding_branch_groups: set[tuple[int]] = set()

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -184,7 +186,11 @@ def add_branch_array(self, branch_array: BranchArray) -> None:
def add_branch3_array(self, branch3_array: Branch3Array) -> None:
"""Add all branch3s in the branch3 array to the graph."""
for branch3 in branch3_array:
self.add_branch_array(branch3.as_branches())
seperate_branches = branch3.as_branches()
self.add_branch_array(seperate_branches)
self._three_winding_branch_groups.add(
tuple(sorted([branch3.node_1.item(), branch3.node_2.item(), branch3.node_3.item()]))
)

def delete_branch_array(self, branch_array: BranchArray, raise_on_fail: bool = True) -> None:
"""Delete all branches in branch_array from the graph."""
Expand All @@ -196,6 +202,8 @@ def delete_branch3_array(self, branch3_array: Branch3Array, raise_on_fail: bool
"""Delete all branch3s in the branch3 array from the graph."""
for branch3 in branch3_array:
self.delete_branch_array(branch3.as_branches(), raise_on_fail=raise_on_fail)
nodes = tuple(sorted([branch3.node_1.item(), branch3.node_2.item(), branch3.node_3.item()]))
self._three_winding_branch_groups.remove(nodes)

@contextmanager
def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
Expand All @@ -219,6 +227,16 @@ def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
for source, target in edge_list:
self.add_branch(source, target)

@contextmanager
def tmp_remove_branches(self, branches: list[tuple[int, int]]) -> Generator:
for from_node, to_node in branches:
self.delete_branch(from_node, to_node)

yield

for from_node, to_node in branches:
self.add_branch(from_node, to_node)

def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
"""Calculate the shortest path between two nodes

Expand Down Expand Up @@ -246,19 +264,73 @@ def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tup
except NoPathBetweenNodes as e:
raise NoPathBetweenNodes(f"No path between nodes {ext_start_node_id} and {ext_end_node_id}") from e

def _squash_paths_inside_three_winding_transformers(self, paths: list[list[int]]) -> list[list[int]]:
"""For each path check if we can squash it.

If a path contains three nodes of the three winding transformer, we can remove the middle node.
Since we can always go directly.

NOTE: if a node has an inactive status, we can never have a situation where
we go through 3 nodes of the transformer directly after each other.
So this also works for inactive three winding transformers."""
replacements = set()
for group in self._three_winding_branch_groups:
replacements.add(frozenset(group))

for i in range(len(paths)):
path = paths[i]

invalid_indices = set()
for index in range(1, len(path) - 1):
if frozenset(sorted([path[index - 1], path[index], path[index + 1]])) in replacements:
invalid_indices.add(index)

paths[i] = [node for index, node in enumerate(path) if index not in invalid_indices]
return paths

def _branches_to_remove_from_three_winding_transformers(self) -> list[tuple[int, int]]:
"""Returns a list of branches that should be temporarily removed cycles by three winding transformers.

By removing one branch, we still can make all the paths but do not create to many solutions/invalid solutions.
Later on, we can remove any detour we made because we didn't use the removed branch.

NOTE: we only want to remove a branch for a three winding transformer if all three of branches are active.
"""
branches_to_remove = []
for group in self._three_winding_branch_groups:
if all(self.has_branch(from_node, to_node) for from_node, to_node in combinations(group, 2)):
branches_to_remove.append((group[1], group[2]))

return branches_to_remove

def get_all_paths(self, ext_start_node_id: int, ext_end_node_id: int) -> list[list[int]]:
"""Retrieves all paths between two (external) nodes.
Returns a list of paths, each path containing a list of external nodes.
"""
if ext_start_node_id == ext_end_node_id:
return []

internal_paths = self._get_all_paths(
source=self.external_to_internal(ext_start_node_id),
target=self.external_to_internal(ext_end_node_id),
)
correct_for_three_winding = False
if self._three_winding_branch_groups:
branches_to_remove = self._branches_to_remove_from_three_winding_transformers()
correct_for_three_winding = True

if correct_for_three_winding:
with self.tmp_remove_branches(branches_to_remove):
internal_paths = self._get_all_paths(
source=self.external_to_internal(ext_start_node_id),
target=self.external_to_internal(ext_end_node_id),
)
else:
internal_paths = self._get_all_paths(
source=self.external_to_internal(ext_start_node_id),
target=self.external_to_internal(ext_end_node_id),
)

return [self._internals_to_externals(path) for path in internal_paths]
paths = [self._internals_to_externals(path) for path in internal_paths]
if correct_for_three_winding:
return self._squash_paths_inside_three_winding_transformers(paths)
return paths

def get_components(self) -> list[list[int]]:
"""Returns all separate components of the graph as lists
Expand Down Expand Up @@ -343,8 +415,22 @@ def find_fundamental_cycles(self) -> list[list[int]]:
Returns:
list[list[int]]: list of cycles, each cycle is a list of (external) node ids
"""
internal_cycles = self._find_fundamental_cycles()
return [self._internals_to_externals(nodes) for nodes in internal_cycles]
correct_for_three_winding = False
if self._three_winding_branch_groups:
branches_to_remove = self._branches_to_remove_from_three_winding_transformers()
correct_for_three_winding = True

if correct_for_three_winding:
with self.tmp_remove_branches(branches_to_remove):
internal_cycles = self._find_fundamental_cycles()
else:
internal_cycles = self._find_fundamental_cycles()

cycles = [self._internals_to_externals(nodes) for nodes in internal_cycles]

if correct_for_three_winding:
return self._squash_paths_inside_three_winding_transformers(cycles)
return cycles

@classmethod
def from_arrays(cls, arrays: "Grid", active_only=False) -> "BaseGraphModel":
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,27 @@ def test_find_first_connected_no_match(self, graph_with_2_routes: BaseGraphModel
graph.find_first_connected(1, candidate_node_ids=[99])


class TestTmpRemoveBranches:
def test_tmp_remove_branches(self, graph_with_2_routes: BaseGraphModel):
graph = deepcopy(graph_with_2_routes)

assert graph.has_branch(1, 2)
assert graph.has_branch(2, 3)

with graph.tmp_remove_branches([(1, 2), (2, 3)]):
assert not graph.has_branch(1, 2)
assert not graph.has_branch(2, 3)

assert graph == graph_with_2_routes

def test_tmp_remove_branches_non_existent_branch(self, graph_with_2_routes: BaseGraphModel):
graph = deepcopy(graph_with_2_routes)

with pytest.raises(MissingBranchError):
with graph.tmp_remove_branches([(1, 99)]):
pass


class TestEq:
def test_eq(self, graph_with_2_routes: BaseGraphModel):
copied_graph = deepcopy(graph_with_2_routes)
Expand Down
194 changes: 194 additions & 0 deletions tests/unit/model/graphs/test_three_winding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import pytest

from power_grid_model_ds._core.model.arrays.pgm_arrays import NodeArray, ThreeWindingTransformerArray
from power_grid_model_ds._core.model.graphs.models.rustworkx import RustworkxGraphModel

Check failure on line 4 in tests/unit/model/graphs/test_three_winding.py

View workflow job for this annotation

GitHub Actions / check-code-quality / check-code-quality (3.12)

ruff (I001)

tests/unit/model/graphs/test_three_winding.py:1:1: I001 Import block is un-sorted or un-formatted help: Organize imports

Check failure on line 4 in tests/unit/model/graphs/test_three_winding.py

View workflow job for this annotation

GitHub Actions / check-code-quality / check-code-quality (3.13)

ruff (I001)

tests/unit/model/graphs/test_three_winding.py:1:1: I001 Import block is un-sorted or un-formatted help: Organize imports

Check failure on line 4 in tests/unit/model/graphs/test_three_winding.py

View workflow job for this annotation

GitHub Actions / check-code-quality / check-code-quality (3.14)

ruff (I001)

tests/unit/model/graphs/test_three_winding.py:1:1: I001 Import block is un-sorted or un-formatted help: Organize imports


def _setup_graph(graph: RustworkxGraphModel) -> None:
nodes = NodeArray.empty(12)
nodes.id = [1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60]

graph.add_node_array(nodes)

# The three winding arrays are open at 5 and 50.
# TODO: add all the 10-60 and 10-40 as well.
# -------------------
# | |
# 1 -- 2 --- 3 -- 4 10 -- 20 --- 30 -- 40
# | | | |
# 5 | 50 |
# | | | |
# 6 ---------60--------

three = ThreeWindingTransformerArray.empty(2)
three.node_1 = [2, 20]
three.node_2 = [3, 30]
three.node_3 = [5, 50]
three.status_1 = [1, 1]
three.status_2 = [1, 1]
three.status_3 = [0, 0]
graph.add_branch3_array(three)

graph.add_branch(1, 2)
graph.add_branch(3, 4)
graph.add_branch(5, 6)

graph.add_branch(10, 20)
graph.add_branch(30, 40)
graph.add_branch(50, 60)

graph.add_branch(10, 60)
graph.add_branch(10, 40)
graph.add_branch(40, 60)
return graph


@pytest.fixture(params=[True, False], ids=["active", "complete"])
def active_only(request):
return request.param


@pytest.fixture
def graph(active_only):
return _setup_graph(RustworkxGraphModel(active_only=active_only))


def list_of_paths_to_set(paths, ordered_paths=False):
if ordered_paths:
return {tuple(path) for path in paths}
else:
return {tuple(sorted(path)) for path in paths}


@pytest.mark.usefixtures("graph")
class TestThreeWindingTransformer:
def test_three_winding_transformer_group(self, graph):
assert graph._three_winding_branch_groups == {(2, 3, 5), (20, 30, 50)}

@pytest.mark.parametrize(
("source", "dest", "active_expected", "complete_expected"),
[
pytest.param(1, 6, set(), {(1, 2, 5, 6)}, id="1->6"),
pytest.param(1, 5, set(), {(1, 2, 5)}, id="1->5"),
pytest.param(
10,
50,
{
(10, 40, 60, 50),
(10, 60, 50),
(10, 20, 30, 40, 60, 50),
},
{
(10, 40, 60, 50),
(10, 40, 30, 50),
(10, 60, 40, 30, 50),
(10, 60, 50),
(10, 20, 50),
(10, 20, 30, 40, 60, 50),
},
id="10->50",
),
],
)
def test_get_all_paths(self, graph, active_only, source, dest, active_expected, complete_expected):
expected = active_expected if active_only else complete_expected
actual_paths = graph.get_all_paths(source, dest)

assert {tuple(path) for path in actual_paths} == expected

@pytest.mark.parametrize(
("source", "dest", "active_expected", "complete_expected"),
[
pytest.param(1, 6, set(), {(1, 2, 5, 6)}, id="1->6"),
pytest.param(1, 5, set(), {(1, 2, 5)}, id="1->5"),
pytest.param(
10,
50,
{
(10, 40, 60, 50),
(10, 60, 50),
},
{
(10, 60, 50),
(10, 40, 60, 50),
(10, 40, 30, 50),
(10, 60, 40, 30, 50),
(10, 20, 50),
},
id="10->50",
),
],
)
def test_get_all_paths_removed_branch(self, graph, active_only, source, dest, active_expected, complete_expected):
expected = active_expected if active_only else complete_expected
with graph.tmp_remove_branches([(2, 3), (20, 30)]):
actual_paths = graph.get_all_paths(source, dest)

assert {tuple(path) for path in actual_paths} == expected

def test_nr_branches(self, graph, active_only):
expected = 11 if active_only else 15
assert graph.nr_branches == expected

def test_get_cycles(self, graph):
expected = (
{(10, 20, 30, 40, 10), (40, 30, 20, 10, 60, 40)}
if graph.active_only
else {(10, 20, 30, 40, 10), (40, 30, 50, 60, 40), (10, 20, 50, 60, 10)}
)
actual_cycles = graph.find_fundamental_cycles()
assert {tuple(cycle) for cycle in actual_cycles} == expected

def test_get_cycles_removed_branch(self, graph):
branches_to_remove = [(2, 3), (20, 30)] if graph.active_only else [(2, 3), (3, 5), (20, 30), (30, 50)]
expected = {(40, 10, 60, 40)} if graph.active_only else {(40, 10, 20, 50, 60, 40), (10, 20, 50, 60, 10)}

with graph.tmp_remove_branches(branches_to_remove):
actual_cycles = graph.find_fundamental_cycles()
assert {tuple(cycle) for cycle in actual_cycles} == expected

def test_get_components(self, graph, active_only):
expected = (
{(1, 2, 3, 4), (5, 6), (10, 20, 30, 40, 50, 60)}
if active_only
else {(1, 2, 3, 4, 5, 6), (10, 20, 30, 40, 50, 60)}
)
actual_components = graph.get_components()
assert list_of_paths_to_set(actual_components) == expected

with graph.tmp_remove_nodes([2, 60]):
expected = (
{(1,), (3, 4), (5, 6), (10, 20, 30, 40), (50,)}
if active_only
else {(1,), (3, 4, 5, 6), (10, 20, 30, 40, 50)}
)
assert list_of_paths_to_set(graph.get_components()) == list_of_paths_to_set(expected)

def test_all_branches(self, graph, active_only):
expected_edges = {
(1, 2),
(2, 3),
(3, 4),
(5, 6),
(10, 20),
(20, 30),
(30, 40),
(50, 60),
(10, 40),
(10, 60),
(40, 60),
}
if not active_only:
expected_edges.update({(2, 5), (3, 5), (20, 50), (30, 50)})

assert set(graph.all_branches) == expected_edges

def test_parallel_branches(self, graph, active_only):
graph.add_branch(2, 5)
assert graph.has_parallel_edges() is not active_only

graph.add_branch(2, 3)
assert graph.has_parallel_edges() is True

def test_find_first_connected(self, graph, active_only):
assert graph.find_first_connected(10, [20, 50]) == 20
Loading