Skip to content
2 changes: 2 additions & 0 deletions gerrychain/accept.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def cut_edge_accept(partition: Partition) -> bool:
Always accepts the flip if the number of cut_edges increases.
Otherwise, uses the Metropolis criterion to decide.

frm: TODO: Add documentation on what the "Metropolis criterion" is...

:param partition: The current partition to accept a flip from.
:type partition: Partition

Expand Down
228 changes: 179 additions & 49 deletions gerrychain/constraints/contiguity.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,122 @@
import random
from collections import deque
from heapq import heappop, heappush
from itertools import count
from typing import Callable, Any, Dict, Set

import networkx as nx

from ..partition import Partition
from .bounds import SelfConfiguringLowerBound
from ..graph import Graph

def are_reachable(G: nx.Graph, source: Any, avoid: Callable, targets: Any) -> bool:
# frm TODO: Remove this comment about NX dependencies (once we are all set with the work)
#
# NX dependencies:
# def _are_reachable(G: nx.Graph, ...)
# nx.is_connected(partition.subgraphs[part]) for part in _affected_parts(partition)
# adj = nx.to_dict_of_lists(partition.subgraphs[part])
#

# frm: TODO: Think about the efficiency of the routines in this module. Almost all
# of these involve traversing the entire graph, and I fear that callers
# might make multiple calls.
#
# Possible solutions are to 1) speed up these routines somehow and 2) cache
# results so that at least we don't do the traversals over and over.

def _are_reachable(graph: Graph, start_node: Any, avoid: Callable, targets: Any) -> bool:
    """
    Check whether every target node can be reached from ``start_node``
    without traversing any edge rejected by ``avoid``.

    This routine was originally adapted from NetworkX's
    ``networkx.algorithms.shortest_paths.weighted._dijkstra_multisource()``.
    Because every edge is treated as having weight 1 and the computed
    distances were never used, the search is implemented as a plain
    breadth-first search, which visits nodes in the same order.

    :param graph: The graph to search. Only ``graph.neighbors(node)`` is used.
    :type graph: Graph
    :param start_node: The node the search starts from.
    :type start_node: Any
    :param avoid: Predicate deciding whether an edge may be traversed. It is
        called as ``avoid(node, neighbor)`` and should return True if the
        edge from ``node`` to ``neighbor`` must be skipped, False otherwise.
    :type avoid: Callable
    :param targets: The target nodes that we would like to reach.
    :type targets: Any

    :returns: True if all of the targets are reachable from ``start_node``
        under the avoid condition, False otherwise.
    :rtype: bool
    """
    # Targets still waiting to be discovered. Tracking them in a set avoids
    # rescanning the whole target collection on every loop iteration.
    remaining = set(targets)
    remaining.discard(start_node)

    visited = {start_node}
    queue = deque([start_node])

    # Stop as soon as every target has been seen or the frontier is empty.
    while remaining and queue:
        node_id = queue.popleft()
        for neighbor in graph.neighbors(node_id):
            # Check the cheap set lookup first so that ``avoid`` is only
            # evaluated for nodes that have not been discovered yet.
            if neighbor in visited or avoid(node_id, neighbor):
                continue
            visited.add(neighbor)
            remaining.discard(neighbor)
            queue.append(neighbor)

    return not remaining

def single_flip_contiguous(partition: Partition) -> bool:
"""
Expand All @@ -87,7 +142,7 @@ def single_flip_contiguous(partition: Partition) -> bool:
graph = partition.graph
assignment = partition.assignment

def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict):
def _partition_edge_avoid(start_node: Any, end_node: Any):
"""
Helper function used in the graph traversal to avoid edges that cross between different
assignments. It's crucial for ensuring that the traversal only considers paths within
Expand All @@ -98,7 +153,7 @@ def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict):
:param end_node: The end node of the edge.
:type end_node: Any
:param edge_attrs: The attributes of the edge (not used in this function). Needed
because this function is passed to :func:`are_reachable`, which expects the
because this function is passed to :func:`_are_reachable`, which expects the
avoid function to have this signature.
:type edge_attrs: Dict

Expand Down Expand Up @@ -126,8 +181,10 @@ def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict):
start_neighbor = random.choice(old_neighbors)

# Check if all old neighbors in the same assignment are still reachable.
connected = are_reachable(
graph, start_neighbor, partition_edge_avoid, old_neighbors
# The "_partition_edge_avoid" function will prevent searching across
# a part (district) boundary
connected = _are_reachable(
graph, start_neighbor, _partition_edge_avoid, old_neighbors
)

if not connected:
Expand All @@ -138,7 +195,7 @@ def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict):
return True


def affected_parts(partition: Partition) -> Set[int]:
def _affected_parts(partition: Partition) -> Set[int]:
"""
Checks which partitions were affected by the change of nodes.

Expand Down Expand Up @@ -168,19 +225,24 @@ def affected_parts(partition: Partition) -> Set[int]:

def contiguous(partition: Partition) -> bool:
    """
    Report whether each affected part of the partition is connected.

    Only the parts returned by :func:`_affected_parts` are examined; each
    one's subgraph is tested with :func:`is_connected_bfs`.

    :param partition: The proposed next :class:`~gerrychain.partition.Partition`
    :type partition: Partition

    :returns: Whether the partition is contiguous
    :rtype: bool
    """
    for part in _affected_parts(partition):
        # A single disconnected part makes the whole partition non-contiguous.
        if not is_connected_bfs(partition.subgraphs[part]):
            return False
    return True


def contiguous_bfs(partition: Partition) -> bool:
"""
Checks that a given partition's parts are connected as graphs using a simple
Expand All @@ -192,17 +254,31 @@ def contiguous_bfs(partition: Partition) -> bool:
:returns: Whether the parts of this partition are connected
:rtype: bool
"""
parts_to_check = affected_parts(partition)

# Generates a subgraph for each district and perform a BFS on it
# to check connectedness.
for part in parts_to_check:
adj = nx.to_dict_of_lists(partition.subgraphs[part])
if _bfs(adj) is False:
return False

return True


# frm: TODO: Try to figure out why this routine exists. It seems to be
# exactly the same conceptually as contiguous(). It looks
# at the "affected" parts - those that have changed node
# assignments from parent, and sees if those parts are
# contiguous.
#
# For now, I have just replaced the existing code which depended
# on NX with a call on contiguous(partition).
#

# frm: Original Code:
#
# parts_to_check = _affected_parts(partition)
#
# # Generates a subgraph for each district and perform a BFS on it
# # to check connectedness.
# for part in parts_to_check:
# adj = nx.to_dict_of_lists(partition.subgraphs[part])
# if _bfs(adj) is False:
# return False
#
# return True

return contiguous(partition)

def number_of_contiguous_parts(partition: Partition) -> int:
"""
Expand All @@ -212,8 +288,12 @@ def number_of_contiguous_parts(partition: Partition) -> int:
:returns: Number of contiguous parts in the partition.
:rtype: int
"""
# frm: Original Code:
# parts = partition.assignment.parts
# return sum(1 for part in parts if nx.is_connected(partition.subgraphs[part]))
#
parts = partition.assignment.parts
return sum(1 for part in parts if nx.is_connected(partition.subgraphs[part]))
return sum(1 for part in parts if is_connected_bfs(partition.subgraphs[part]))


# Create an instance of SelfConfiguringLowerBound using the number_of_contiguous_parts function.
Expand All @@ -235,11 +315,37 @@ def contiguous_components(partition: Partition) -> Dict[int, list]:
subgraphs of that part of the partition
:rtype: dict
"""
return {
part: [subgraph.subgraph(nodes) for nodes in nx.connected_components(subgraph)]
for part, subgraph in partition.subgraphs.items()
}

# frm: TODO: NX vs RX Issues here:
#
# The call on subgraph() below is perhaps problematic because it will renumber
# node_ids...
#
# The issue is not that the code is incorrect (with RX there is really no other
# option), but rather that any legacy code will be unprepared to deal with the fact
# that the subgraphs returned are (I think) three node translations away from the
# original NX-Graph object's node_ids.
#
# Translations:
#
# 1) From NX to RX when partition was created
# 2) From top-level RX graph to the partition's subgraphs for each part (district)
# 3) From each part's subgraph to the subgraphs of contiguous_components...
#

# frm: Original Code:
# return {
# part: [subgraph.subgraph(nodes) for nodes in nx.connected_components(subgraph)]
# for part, subgraph in partition.subgraphs.items()
# }
#
connected_components_in_each_partition = {}
for part, subgraph in partition.subgraphs.items():
# create a subgraph for each set of connected nodes in the part's nodes
list_of_connected_subgraphs = subgraph.subgraphs_for_connected_components()
connected_components_in_each_partition[part] = list_of_connected_subgraphs

return connected_components_in_each_partition

def _bfs(graph: Dict[int, list]) -> bool:
"""
Expand All @@ -254,6 +360,7 @@ def _bfs(graph: Dict[int, list]) -> bool:
"""
q = [next(iter(graph))]
visited = set()
# frm TODO: Make sure len() is defined on Graph object...
total_vertices = len(graph)

# Check if the district has a single vertex. If it does, then simply return
Expand All @@ -272,3 +379,26 @@ def _bfs(graph: Dict[int, list]) -> bool:
q += [neighbor]

return total_vertices == len(visited)

# frm: TODO: Verify that is_connected_bfs() works - add a test or two...

# frm: Code obtained from the web - probably could be optimized...
# This code replaced calls on nx.is_connected()
def is_connected_bfs(graph: Graph) -> bool:
    """
    Check whether a graph is connected using a breadth-first search.

    This routine replaces calls on ``networkx.is_connected()``.

    :param graph: The graph to test. Its ``node_indices`` attribute and
        ``neighbors(node)`` method are used.
    :type graph: Graph

    :returns: True if ``graph`` is falsy, has no nodes, or every node can be
        reached from the first node in ``graph.node_indices``; False
        otherwise.
    :rtype: bool
    """
    if not graph:
        return True

    nodes = list(graph.node_indices)
    if not nodes:
        # A graph object can be truthy and still have no nodes (for example
        # when it defines neither __len__ nor __bool__); treat it the same
        # way as a falsy graph instead of failing on an empty node list.
        return True

    # Any start node gives the same answer for an undirected graph, so use
    # the first one rather than drawing from (and perturbing) the global
    # random number generator.
    start_node = nodes[0]
    visited = {start_node}
    queue = deque([start_node])

    while queue:
        # popleft() is O(1); list.pop(0) would shift the whole queue.
        current_node = queue.popleft()
        for neighbor in graph.neighbors(current_node):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)

    return len(visited) == len(nodes)
Loading