Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6c76575
First commit for RustworkX conversion project.
May 6, 2025
dd26473
Get recent changes made by Peter
May 6, 2025
043e24a
Added logic to support RX as the Graph provider after creating
May 23, 2025
e724e5d
Added laplacian_matrix() and normalized_laplacian_matrix() methods to…
Jul 8, 2025
4859115
More changes to get tests to pass
Jul 16, 2025
b15f6b5
Deleted files from the repo that should not be tracked...
Jul 17, 2025
fba884e
Lots and lots of changes to get tests to run.
Jul 27, 2025
a1e0953
Yet more changes to get tests to pass...
Jul 29, 2025
069ed26
Mostly cosmetic changes to change Camel-Case to Snake-Case such
Jul 30, 2025
e0a3f7e
Yet more changes to get tests to pass.
Aug 2, 2025
05cc344
Removed debugging code (print stmts)
Aug 4, 2025
d747493
Merge pull request #428 from chief-dweeb/frm_rustworkx
peterrrock2 Aug 5, 2025
e2b4398
Yet more changes to get tests to pass.
Aug 16, 2025
ebced33
Added test for conversion from NX graph back to NX graph after runnin…
Aug 16, 2025
74eceda
Since returning from vacation, I have been focused on getting all of
Oct 12, 2025
22fb203
Performance Changes:
Oct 28, 2025
ae645af
Merge pull request #431 from chief-dweeb/frm_rustworkx
cdonnay Nov 4, 2025
a81233f
CLEAN-UP work: no new functionality
Nov 5, 2025
da18c77
Added a test file and a README.txt file for tests.
Nov 5, 2025
afbec61
Trivial changes - I just categorized all of the TODO: comments
Nov 6, 2025
10bfc86
Mostly a "cleanup" set of changes - code readability,
Dec 8, 2025
4d1d1f2
Added two performance tests - using cProfile
Dec 8, 2025
195720e
Trivial one-line change in a comment...
Dec 8, 2025
37e2983
Merge pull request #432 from chief-dweeb/frm_rustworkx
peterrrock2 Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gerrychain/accept.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def cut_edge_accept(partition: Partition) -> bool:
Always accepts the flip if the number of cut_edges increases.
Otherwise, uses the Metropolis criterion to decide.

frm: TODO: Documentation: Add documentation on what the "Metropolis criterion" is...

:param partition: The current partition to accept a flip from.
:type partition: Partition

Expand Down
234 changes: 181 additions & 53 deletions gerrychain/constraints/contiguity.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,129 @@
from collections import deque
from heapq import heappop, heappush
from itertools import count
import random
from typing import Callable, Any, Dict, Set

import networkx as nx

from ..partition import Partition
from .bounds import SelfConfiguringLowerBound


def are_reachable(G: nx.Graph, source: Any, avoid: Callable, targets: Any) -> bool:
from ..graph import Graph

# frm: TODO: Performance: Think about the efficiency of the routines in this module. Almost all
# of these involve traversing the entire graph, and I fear that callers
# might make multiple calls.
#
# Possible solutions are to 1) speed up these routines somehow and 2) cache
# results so that at least we don't do the traversals over and over.

# frm: TODO: Refactoring: Rethink what this module is actually for.
#
# It seems like a grab bag for lots of different things - used in different places.
#
# What got me to write this comment was looking at the signature for def contiguous()
# which operates on a partition, but lots of other routines here operate on graphs or
# other things. So, what is going on?
#
# Peter replied to this comment in a pull request:
#
# So anything that is prefixed with an underscore in here should be a helper
# function and not a part of the public API. It looks like, other than
# is_connected_bfs (which should probably be marked "private" with an
# underscore) everything here is acting like an updater.
#


def _are_reachable(graph: Graph, start_node: Any, avoid: Callable, targets: Any) -> bool:
    """
    Check whether every target node can be reached from ``start_node``.

    Performs a breadth-first search over ``graph.neighbors()``, never crossing
    an edge for which ``avoid`` returns True, and stops as soon as every
    target has been found.

    Adapted from NetworkX's function
    `networkx.algorithms.shortest_paths.weighted._dijkstra_multisource()`.
    The heap, counter and distance bookkeeping of that routine were dropped
    because only reachability (not distance) is ever reported here.

    :param graph: The graph to search. Only its ``neighbors()`` method is used.
    :type graph: Graph
    :param start_node: The starting node
    :type start_node: Any
    :param avoid: The function that determines if an edge should be avoided.
        It is called with two parameters, ``avoid(node, neighbor)``, and should
        return True if the edge between them should be avoided, False otherwise.
    :type avoid: Callable
    :param targets: The target nodes that we would like to reach
    :type targets: Any

    :returns: True if all of the targets are reachable from start_node
        under the avoid condition, False otherwise.
    :rtype: bool
    """
    # Targets that have not been found yet. The start node counts as reached.
    remaining = set(targets)
    remaining.discard(start_node)
    if not remaining:
        return True

    seen = {start_node}
    queue = deque([start_node])

    while queue:
        node_id = queue.popleft()
        for neighbor in graph.neighbors(node_id):
            if neighbor in seen or avoid(node_id, neighbor):
                continue
            seen.add(neighbor)
            remaining.discard(neighbor)
            if not remaining:
                # Every target has been found - no need to search further.
                return True
            queue.append(neighbor)

    # The search was exhausted with at least one target never reached.
    return False

def single_flip_contiguous(partition: Partition) -> bool:
"""
Expand All @@ -87,7 +149,7 @@ def single_flip_contiguous(partition: Partition) -> bool:
graph = partition.graph
assignment = partition.assignment

def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict):
def _partition_edge_avoid(start_node: Any, end_node: Any):
"""
Helper function used in the graph traversal to avoid edges that cross between different
assignments. It's crucial for ensuring that the traversal only considers paths within
Expand All @@ -98,7 +160,7 @@ def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict):
:param end_node: The end node of the edge.
:type end_node: Any
:param edge_attrs: The attributes of the edge (not used in this function). Needed
because this function is passed to :func:`are_reachable`, which expects the
because this function is passed to :func:`_are_reachable`, which expects the
avoid function to have this signature.
:type edge_attrs: Dict

Expand Down Expand Up @@ -126,8 +188,10 @@ def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict):
start_neighbor = random.choice(old_neighbors)

# Check if all old neighbors in the same assignment are still reachable.
connected = are_reachable(
graph, start_neighbor, partition_edge_avoid, old_neighbors
# The "_partition_edge_avoid" function will prevent searching across
# a part (district) boundary
connected = _are_reachable(
graph, start_neighbor, _partition_edge_avoid, old_neighbors
)

if not connected:
Expand All @@ -138,7 +202,7 @@ def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict):
return True


def affected_parts(partition: Partition) -> Set[int]:
def _affected_parts(partition: Partition) -> Set[int]:
"""
Checks which partitions were affected by the change of nodes.

Expand Down Expand Up @@ -168,19 +232,19 @@ def affected_parts(partition: Partition) -> Set[int]:

def contiguous(partition: Partition) -> bool:
    """
    Check if the parts of a partition are connected.

    Only the parts returned by :func:`_affected_parts` are examined, and each
    one is tested with :func:`is_connected_bfs` on its subgraph.

    :param partition: The proposed next :class:`~gerrychain.partition.Partition`
    :type partition: Partition

    :returns: Whether the partition is contiguous
    :rtype: bool
    """
    return all(
        is_connected_bfs(partition.subgraphs[part]) for part in _affected_parts(partition)
    )


def contiguous_bfs(partition: Partition) -> bool:
"""
Checks that a given partition's parts are connected as graphs using a simple
Expand All @@ -192,17 +256,36 @@ def contiguous_bfs(partition: Partition) -> bool:
:returns: Whether the parts of this partition are connected
:rtype: bool
"""
parts_to_check = affected_parts(partition)

# Generates a subgraph for each district and perform a BFS on it
# to check connectedness.
for part in parts_to_check:
adj = nx.to_dict_of_lists(partition.subgraphs[part])
if _bfs(adj) is False:
return False

return True


# frm: TODO: Refactoring: Figure out why this routine, contiguous_bfs() exists.
#
# It is mentioned in __init__.py so maybe it is used externally in legacy code.
#
# However, I have changed the code so that it just calls contiguous() and all
# of the tests pass, so I am going to assume that my comment below is accurate,
# that is, I am assuming that this function does not need to exist independently
# except for legacy purposes. Stated differently, if someone can verify that
# this routine is NOT needed for legacy purposes, then we can just delete it.
#
# It seems to be exactly the same conceptually as contiguous(). It looks
# at the "affected" parts - those that have changed node
# assignments from parent, and sees if those parts are
# contiguous.
#
# frm: Original Code:
#
# parts_to_check = _affected_parts(partition)
#
# # Generates a subgraph for each district and perform a BFS on it
# # to check connectedness.
# for part in parts_to_check:
# adj = nx.to_dict_of_lists(partition.subgraphs[part])
# if _bfs(adj) is False:
# return False
#
# return True

return contiguous(partition)

def number_of_contiguous_parts(partition: Partition) -> int:
"""
Expand All @@ -213,7 +296,7 @@ def number_of_contiguous_parts(partition: Partition) -> int:
:rtype: int
"""
parts = partition.assignment.parts
return sum(1 for part in parts if nx.is_connected(partition.subgraphs[part]))
return sum(1 for part in parts if is_connected_bfs(partition.subgraphs[part]))


# Create an instance of SelfConfiguringLowerBound using the number_of_contiguous_parts function.
Expand All @@ -235,11 +318,31 @@ def contiguous_components(partition: Partition) -> Dict[int, list]:
subgraphs of that part of the partition
:rtype: dict
"""
return {
part: [subgraph.subgraph(nodes) for nodes in nx.connected_components(subgraph)]
for part, subgraph in partition.subgraphs.items()
}

# frm: TODO: Documentation: Migration Guide: NX vs RX Issues here:
#
# The call on subgraph() below is perhaps problematic because it will renumber
# node_ids...
#
# The issue is not that the code is incorrect (with RX there is really no other
# option), but rather that any legacy code will be unprepared to deal with the fact
# that the subgraphs returned are (I think) three node translations away from the
# original NX-Graph object's node_ids.
#
# Translations:
#
# 1) From NX to RX when partition was created
# 2) From top-level RX graph to the partition's subgraphs for each part (district)
# 3) From each part's subgraph to the subgraphs of contiguous_components...
#

connected_components_in_each_partition = {}
for part, subgraph in partition.subgraphs.items():
# create a subgraph for each set of connected nodes in the part's nodes
list_of_connected_subgraphs = subgraph.subgraphs_for_connected_components()
connected_components_in_each_partition[part] = list_of_connected_subgraphs

return connected_components_in_each_partition

def _bfs(graph: Dict[int, list]) -> bool:
"""
Expand All @@ -254,11 +357,11 @@ def _bfs(graph: Dict[int, list]) -> bool:
"""
q = [next(iter(graph))]
visited = set()
total_vertices = len(graph)
num_nodes = len(graph)

# Check if the district has a single vertex. If it does, then simply return
# `True`, as it's trivially connected.
if total_vertices <= 1:
if num_nodes <= 1:
return True

# bfs!
Expand All @@ -271,4 +374,29 @@ def _bfs(graph: Dict[int, list]) -> bool:
visited.add(neighbor)
q += [neighbor]

return total_vertices == len(visited)
return num_nodes == len(visited)

# frm: TODO: Testing: Verify that is_connected_bfs() works - add a test or two...

# frm: TODO: Refactoring: Move this code into graph.py. It is all about the Graph...

# frm: TODO: Documentation: This code was obtained from the web - probably could be optimized...
# This code replaced calls on nx.is_connected()
def is_connected_bfs(graph: Graph) -> bool:
    """
    Check whether a graph is connected using a breadth-first search.

    A falsy graph, or a graph with no nodes, is reported as connected.

    :param graph: The graph to check. Its ``node_indices`` attribute and
        ``neighbors()`` method are used.
    :type graph: Graph

    :returns: True if every node was reached from the starting node,
        False otherwise.
    :rtype: bool
    """
    if not graph:
        return True

    nodes = list(graph.node_indices)
    if not nodes:
        # Nothing to traverse, and random.choice() would raise on an empty list.
        return True

    # The result does not depend on which node the search starts from.
    # random.choice() is kept so that the shared random stream is consumed
    # exactly as before.
    start_node = random.choice(nodes)
    visited = {start_node}
    # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
    queue = deque([start_node])

    while queue:
        current_node = queue.popleft()
        for neighbor in graph.neighbors(current_node):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)

    return len(visited) == len(nodes)
Loading
Loading