diff --git a/gerrychain/accept.py b/gerrychain/accept.py index 92d3b5c8..fee5393c 100644 --- a/gerrychain/accept.py +++ b/gerrychain/accept.py @@ -21,6 +21,8 @@ def cut_edge_accept(partition: Partition) -> bool: Always accepts the flip if the number of cut_edges increases. Otherwise, uses the Metropolis criterion to decide. + frm: TODO: Documentation: Add documentation on what the "Metropolis criterion" is... + :param partition: The current partition to accept a flip from. :type partition: Partition diff --git a/gerrychain/constraints/contiguity.py b/gerrychain/constraints/contiguity.py index e1077e4a..283e42e2 100644 --- a/gerrychain/constraints/contiguity.py +++ b/gerrychain/constraints/contiguity.py @@ -1,67 +1,129 @@ from heapq import heappop, heappush from itertools import count -import networkx as nx from typing import Callable, Any, Dict, Set from ..partition import Partition import random from .bounds import SelfConfiguringLowerBound - -def are_reachable(G: nx.Graph, source: Any, avoid: Callable, targets: Any) -> bool: +from ..graph import Graph + +# frm: TODO: Performance: Think about the efficiency of the routines in this module. Almost all +# of these involve traversing the entire graph, and I fear that callers +# might make multiple calls. +# +# Possible solutions are to 1) speed up these routines somehow and 2) cache +# results so that at least we don't do the traversals over and over. + +# frm: TODO: Refactoring: Rethink WTF this module is all about. +# +# It seems like a grab bag for lots of different things - used in different places. +# +# What got me to write this comment was looking at the signature for def contiguous() +# which operates on a partition, but lots of other routines here operate on graphs or +# other things. So, what is going on? +# +# Peter replied to this comment in a pull request: +# +# So anything that is prefixed with an underscore in here should be a helper +# function and not a part of the public API. It looks like, other than +# is_connected_bfs (which should probably be marked "private" with an +# underscore) everything here is acting like an updater. +# + + +def _are_reachable(graph: Graph, start_node: Any, avoid: Callable, targets: Any) -> bool: """ A modified version of NetworkX's function `networkx.algorithms.shortest_paths.weighted._dijkstra_multisource()` - This function checks if the targets are reachable from the source node + This function checks if the targets are reachable from the start_node node while avoiding edges based on the avoid condition function. - :param G: The networkx graph - :type G: nx.Graph - :param source: The starting node - :type source: int + :param graph: Graph + :type graph: Graph + :param start_node: The starting node + :type start_node: int :param avoid: The function that determines if an edge should be avoided. It should take in three parameters: the start node, the end node, and the edges to avoid. It should return True if the edge should be avoided, False otherwise. + # frm: TODO: Documentation: Fix the comment above about the "avoid" function parameter. + # It may have once been accurate, but the original code below + # passed parameters to it of (node_id, neighbor_node_id, edge_data_dict) + # from NetworkX.Graph._succ So, "the edges to avoid" above is wrong. + # This whole issue is moot, however, since the only routine + # that is used as an avoid function ignores the third parameter. 
+ # Or rather it used to avoid the third parameter, but it has + # been updated to only take two parameters, and the code below + # has been modified to use Graph.neighbors() instead of _succ + # because 1) we can't use NX and 2) because we don't need the + # edge data dictionary anyways... + # :type avoid: Callable :param targets: The target nodes that we would like to reach :type targets: Any - :returns: True if all of the targets are reachable from the source node + :returns: True if all of the targets are reachable from the start_node node under the avoid condition, False otherwise. :rtype: bool """ - G_succ = G._succ if G.is_directed() else G._adj - push = heappush pop = heappop - dist = {} # dictionary of final distances + node_distances = {} # dictionary of final distances seen = {} # fringe is heapq with 3-tuples (distance,c,node) # use the count c to avoid comparing nodes (may not be able to) c = count() fringe = [] - seen[source] = 0 - push(fringe, (0, next(c), source)) - - while not all(t in seen for t in targets) and fringe: - (d, _, v) = pop(fringe) - if v in dist: + seen[start_node] = 0 + push(fringe, (0, next(c), start_node)) + + + # frm: Original Code: + # + # while not all(t in seen for t in targets) and fringe: + # (d, _, v) = pop(fringe) + # if v in dist: + # continue # already searched this node. + # dist[v] = d + # for u, e in G_succ[v].items(): + # if avoid(v, u, e): + # continue + # + # vu_dist = dist[v] + 1 + # if u not in seen or vu_dist < seen[u]: + # seen[u] = vu_dist + # push(fringe, (vu_dist, next(c), u)) + # + # return all(t in seen for t in targets) + # + + + + # While we have not yet seen all of our targets and while there is + # still some fringe... + while not all(tgt in seen for tgt in targets) and fringe: + (distance, _, node_id) = pop(fringe) + if node_id in node_distances: continue # already searched this node. - dist[v] = d - for u, e in G_succ[v].items(): - if avoid(v, u, e): + node_distances[node_id] = distance + + for neighbor in graph.neighbors(node_id): + if avoid(node_id, neighbor): continue - vu_dist = dist[v] + 1 - if u not in seen or vu_dist < seen[u]: - seen[u] = vu_dist - push(fringe, (vu_dist, next(c), u)) + neighbor_distance = node_distances[node_id] + 1 + if neighbor not in seen or neighbor_distance < seen[neighbor]: + seen[neighbor] = neighbor_distance + push(fringe, (neighbor_distance, next(c), neighbor)) - return all(t in seen for t in targets) + # frm: TODO: Refactoring: Simplify this code. It computes distances and counts but + # never uses them. These must be relics of code copied + # from somewhere else where it had more uses... + return all(tgt in seen for tgt in targets) def single_flip_contiguous(partition: Partition) -> bool: """ @@ -87,7 +149,7 @@ def single_flip_contiguous(partition: Partition) -> bool: graph = partition.graph assignment = partition.assignment - def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict): + def _partition_edge_avoid(start_node: Any, end_node: Any): """ Helper function used in the graph traversal to avoid edges that cross between different assignments. It's crucial for ensuring that the traversal only considers paths within @@ -98,7 +160,7 @@ def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict): :param end_node: The end node of the edge. :type end_node: Any :param edge_attrs: The attributes of the edge (not used in this function). 
Needed - because this function is passed to :func:`are_reachable`, which expects the + because this function is passed to :func:`_are_reachable`, which expects the avoid function to have this signature. :type edge_attrs: Dict @@ -126,8 +188,10 @@ def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict): start_neighbor = random.choice(old_neighbors) # Check if all old neighbors in the same assignment are still reachable. - connected = are_reachable( - graph, start_neighbor, partition_edge_avoid, old_neighbors + # The "_partition_edge_avoid" function will prevent searching across + # a part (district) boundary + connected = _are_reachable( + graph, start_neighbor, _partition_edge_avoid, old_neighbors ) if not connected: @@ -138,7 +202,7 @@ def partition_edge_avoid(start_node: Any, end_node: Any, edge_attrs: Dict): return True -def affected_parts(partition: Partition) -> Set[int]: +def _affected_parts(partition: Partition) -> Set[int]: """ Checks which partitions were affected by the change of nodes. @@ -168,7 +232,7 @@ def affected_parts(partition: Partition) -> Set[int]: def contiguous(partition: Partition) -> bool: """ - Check if the parts of a partition are connected using :func:`networkx.is_connected`. + Check if the parts of a partition are connected :param partition: The proposed next :class:`~gerrychain.partition.Partition` :type partition: Partition @@ -176,11 +240,11 @@ def contiguous(partition: Partition) -> bool: :returns: Whether the partition is contiguous :rtype: bool """ + return all( - nx.is_connected(partition.subgraphs[part]) for part in affected_parts(partition) + is_connected_bfs(partition.subgraphs[part]) for part in _affected_parts(partition) ) - def contiguous_bfs(partition: Partition) -> bool: """ Checks that a given partition's parts are connected as graphs using a simple @@ -192,17 +256,36 @@ def contiguous_bfs(partition: Partition) -> bool: :returns: Whether the parts of this partition are connected :rtype: bool """ - parts_to_check = affected_parts(partition) - - # Generates a subgraph for each district and perform a BFS on it - # to check connectedness. - for part in parts_to_check: - adj = nx.to_dict_of_lists(partition.subgraphs[part]) - if _bfs(adj) is False: - return False - - return True - + + # frm: TODO: Refactoring: Figure out why this routine, contiguous_bfs() exists. + # + # It is mentioned in __init__.py so maybe it is used externally in legacy code. + # + # However, I have changed the code so that it just calls contiguous() and all + # of the tests pass, so I am going to assume that my comment below is accurate, + # that is, I am assuming that this function does not need to exist independently + # except for legacy purposes. Stated differently, if someone can verify that + # this routine is NOT needed for legacy purposes, then we can just delete it. + # + # It seems to be exactly the same conceptually as contiguous(). It looks + # at the "affected" parts - those that have changed node + # assignments from parent, and sees if those parts are + # contiguous. + # + # frm: Original Code: + # + # parts_to_check = _affected_parts(partition) + # + # # Generates a subgraph for each district and perform a BFS on it + # # to check connectedness. 
+ # for part in parts_to_check: + # adj = nx.to_dict_of_lists(partition.subgraphs[part]) + # if _bfs(adj) is False: + # return False + # + # return True + + return contiguous(partition) def number_of_contiguous_parts(partition: Partition) -> int: """ @@ -213,7 +296,7 @@ def number_of_contiguous_parts(partition: Partition) -> int: :rtype: int """ parts = partition.assignment.parts - return sum(1 for part in parts if nx.is_connected(partition.subgraphs[part])) + return sum(1 for part in parts if is_connected_bfs(partition.subgraphs[part])) # Create an instance of SelfConfiguringLowerBound using the number_of_contiguous_parts function. @@ -235,11 +318,31 @@ def contiguous_components(partition: Partition) -> Dict[int, list]: subgraphs of that part of the partition :rtype: dict """ - return { - part: [subgraph.subgraph(nodes) for nodes in nx.connected_components(subgraph)] - for part, subgraph in partition.subgraphs.items() - } + # frm: TODO: Documentation: Migration Guide: NX vs RX Issues here: + # + # The call on subgraph() below is perhaps problematic because it will renumber + # node_ids... + # + # The issue is not that the code is incorrect (with RX there is really no other + # option), but rather that any legacy code will be unprepared to deal with the fact + # that the subgraphs returned are (I think) three node translations away from the + # original NX-Graph object's node_ids. + # + # Translations: + # + # 1) From NX to RX when partition was created + # 2) From top-level RX graph to the partition's subgraphs for each part (district) + # 3) From each part's subgraph to the subgraphs of contiguous_components... + # + + connected_components_in_each_partition = {} + for part, subgraph in partition.subgraphs.items(): + # create a subgraph for each set of connected nodes in the part's nodes + list_of_connected_subgraphs = subgraph.subgraphs_for_connected_components() + connected_components_in_each_partition[part] = list_of_connected_subgraphs + + return connected_components_in_each_partition def _bfs(graph: Dict[int, list]) -> bool: """ @@ -254,11 +357,11 @@ def _bfs(graph: Dict[int, list]) -> bool: """ q = [next(iter(graph))] visited = set() - total_vertices = len(graph) + num_nodes = len(graph) # Check if the district has a single vertex. If it does, then simply return # `True`, as it's trivially connected. - if total_vertices <= 1: + if num_nodes <= 1: return True # bfs! @@ -271,4 +374,29 @@ def _bfs(graph: Dict[int, list]) -> bool: visited.add(neighbor) q += [neighbor] - return total_vertices == len(visited) + return num_nodes == len(visited) + +# frm: TODO: Testing: Verify that is_connected_bfs() works - add a test or two... + +# frm: TODO: Refactoring: Move this code into graph.py. It is all about the Graph... + +# frm: TODO: Documentation: This code was obtained from the web - probably could be optimized... 
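+# A usage sketch for is_connected_bfs(), defined below (a hedged example, not part
+# of the module: it assumes a Graph built via Graph.from_networkx(), whose
+# node_indices and neighbors() behave as described in graph.py, and that the
+# caller has imported networkx as nx):
+#
+#     g = Graph.from_networkx(nx.path_graph(4))
+#     assert is_connected_bfs(g)             # a path graph is one component
+#     g2 = Graph.from_networkx(nx.empty_graph(3))
+#     assert not is_connected_bfs(g2)        # three isolated nodes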
+# This code replaced calls on nx.is_connected()
+def is_connected_bfs(graph: Graph) -> bool:
+    if not graph:
+        return True
+
+    nodes = list(graph.node_indices)
+
+    start_node = random.choice(nodes)
+    visited = {start_node}
+    queue = [start_node]
+
+    while queue:
+        current_node = queue.pop(0)
+        for neighbor in graph.neighbors(current_node):
+            if neighbor not in visited:
+                visited.add(neighbor)
+                queue.append(neighbor)
+
+    return len(visited) == len(nodes)
diff --git a/gerrychain/graph/graph.py b/gerrychain/graph/graph.py
index fdd905a8..c3b3faf0 100644
--- a/gerrychain/graph/graph.py
+++ b/gerrychain/graph/graph.py
@@ -9,6 +9,8 @@
 Note: This module relies on NetworkX, pandas, and geopandas,
 which should be installed and imported as required.
+
+TODO: Documentation: Update top-level documentation for graph.py
 """
 
 import functools
@@ -21,11 +23,25 @@
 from networkx.readwrite import json_graph
 
 import pandas as pd
 
+import rustworkx
+
 from .adjacency import neighbors
 from .geo import GeometryError, invalid_geometries, reprojected
 
-from typing import List, Iterable, Optional, Set, Tuple, Union
+# frm: codereview note: removed type hints that are now baked into Python
+from typing import Iterable, Optional, Generator, Union
+
+import geopandas as gp
+from shapely.ops import unary_union
+from shapely.prepared import prep
+
+import numpy
+import scipy
 
+# frm: TODO: Refactor: Move json_serialize() closer to its use.
+#
+#      It should not be the first thing someone sees when looking at this code...
+#
 def json_serialize(input_object: Any) -> Optional[int]:
     """
     This function is used to handle one of the common issues that
@@ -47,75 +63,595 @@ def json_serialize(input_object: Any) -> Optional[int]:
     return None
 
-
-class Graph(networkx.Graph):
+class Graph:
     """
-    Represents a graph to be partitioned, extending the :class:`networkx.Graph`.
+    frm: TODO: Documentation: Clean up this documentation
+
+    frm: This class encapsulates / hides the underlying graph, which can be either a
+    NetworkX graph or a RustworkX graph.  The intent is that it provides the same
+    external interface as a NetworkX graph (for all of the uses that GerryChain cares
+    about, at least) so that legacy code that operated on NetworkX-based Graph objects
+    can continue to work unchanged.
+
+    When a graph is added to a partition, however, the NX graph will be converted into
+    an RX graph and the NX graph will become inaccessible to the user.  The RX graph
+    may also be "frozen" the way the NX graph was "frozen" in the legacy code, but we
+    have not yet gotten that far in the implementation.
+
+    It is not clear whether the code that does the heavy lifting on partitions will
+    need to use the old NX syntax or whether it will be useful to allow unfettered
+    access to the RX graph so that RX code can be used in these modules.  TBD...
 
-    This class includes additional class methods for constructing graphs from shapefiles,
-    and for saving and loading graphs in JSON format.
     """
 
-    def __repr__(self):
-        return "<Graph [{} nodes, {} edges]>".format(len(self.nodes), len(self.edges))
+    # Note: This class cannot have a constructor - because there is code that assumes
+    #       that it can use the default constructor to create instances of it.
+    #       That code is buried deep in non-GerryChain code, so I don't really understand
+    #       what it is doing, but the assignment of nx_graph and rx_graph class attributes/members
+    #       needs to happen in the "from_xxx()" routines.
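+    # A minimal sketch of the intended construction workflow (hedged - this assumes
+    # the Partition initialization performs the NX-to-RX conversion described above,
+    # or that convert_from_nx_to_rx() is called explicitly):
+    #
+    #     import networkx as nx
+    #     nx_graph = nx.complete_graph(4)             # any NetworkX graph
+    #     graph = Graph.from_networkx(nx_graph)       # NX-backed GerryChain Graph
+    #     rx_based = graph.convert_from_nx_to_rx()    # RX-backed GerryChain Graph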
+ + # frm: TODO: Documentation: Add documentation for new data members I am adding: + # _nx_graph, _rx_graph, _node_id_to_parent_node_id_map, _is_a_subgraph + # _node_id_to_original_nx_node_id_map + # => used to recreate NX graph from an RX graph and also + # as an aid for testing @classmethod - def from_networkx(cls, graph: networkx.Graph) -> "Graph": + def from_networkx(cls, nx_graph: networkx.Graph) -> "Graph": """ - Create a Graph instance from a networkx.Graph object. + Create a :class:`Graph` from a NetworkX.Graph object - :param graph: The networkx graph to be converted. - :type graph: networkx.Graph + This supports the use case of users creating a graph using NetworkX + which is convenient - both for users of the previous implementation of + a GerryChain object which was a subclass of NetworkX.Graph and for + users more generally who are familiar with NetworkX. - :returns: The converted graph as an instance of this class. - :rtype: Graph + Note that most users will not ever call this function directly, + because they can create a GerryChain Partition object directly + from a NetworkX graph, and the Partition initialization code + will use this function to convert the NetworkX graph to a + GerryChain Graph object. + + :param nx_graph: A NetworkX.Graph object with node and edge data + to be converted into a GerryChain Graph object. + :type nx_graph: networkx.Graph + + :returns: ...text... + :rtype: + """ + graph = cls() + graph._nx_graph = nx_graph + graph._rx_graph = None + graph._is_a_subgraph = False # See comments on RX subgraph issues. + # Maps node_ids in the graph to the "parent" node_ids in the parent graph. + # For top-level graphs, this is just an identity map + graph._node_id_to_parent_node_id_map = {node_id: node_id for node_id in graph.node_indices} + # Maps node_ids in the graph to the "original" node_ids in parent graph. + # For top-level graphs, this is just an identity map + graph._node_id_to_original_nx_node_id_map = {node_id: node_id for node_id in graph.node_indices} + graph.nx_to_rx_node_id_map = None # only set when an NX based graph is converted to be an RX based graph + return graph + + @classmethod + def from_rustworkx(cls, rx_graph: rustworkx.PyGraph) -> "Graph": + """ + + + Create a :class:`Graph` from a RustworkX.PyGraph object + + There are three primary use cases for this routine: + 1) converting an NX-based Graph to be an RX-based + Graph, 2) creating a subgraph of an RX-based Graph, and + 3) creating a Graph whose node_ids do not need to be + mapped to some previous graph's node_ids. + + In a little more detail: + + 1) A typical way to use GerryChain is to create a graph + using NetworkX functionality and to then rely on the + initialization code in the Partition class to create + an RX-based Graph object. That initialization code + constructs a RustworkX PyGraph and then uses this + routine to create an RX-based Graph object, and it then + creates maps from the node_ids of the resulting RX-based + Graph back to the original NetworkX.Graph's node_ids. + + 2) When creating a subgraph of a RustworkX PyGraph + object, the node_ids of the subgraph are (in general) + different from those of the parent graph. So we + create a mapping from the subgraph's node_ids to the + node_ids of the parent. The subgraph() routine + creates a RustworkX PyGraph subgraph, then uses this + routine to create an RX-based Graph using that subgraph, + and it then creates the mapping of subgraph node_ids + to the parent (RX) graph's node_ids. 
+
+        3) In those cases where no node_id mapping is needed,
+           this routine provides a simple way to create an
+           RX-based GerryChain graph object.
+
+        :param rx_graph: a RustworkX PyGraph object
+        :type rx_graph: rustworkx.PyGraph
+
+        :returns: a GerryChain Graph object with an embedded RustworkX.PyGraph object
+        :rtype: "Graph"
+        """
+
+        # Ensure that the RX graph has node and edge data dictionaries
+        #
+        # While NX graphs always have node and edge data dictionaries,
+        # the node data for the nodes in RX graphs do not have to be
+        # a data dictionary - they can be any Python object.  Since
+        # gerrychain code depends on having a data dictionary
+        # associated with nodes and edges, we need to check the RX
+        # graph to see if it already has node and edge data and, if so,
+        # whether that node and edge data is a data dictionary.
+        #
+        # Note that there is no way to change the type of the data
+        # associated with an RX node.  So if the data for a node
+        # is not already a dict then we have an unrecoverable error.
+        #
+        # However, RX does allow you to update the data for edges,
+        # so if we find an edge with no data (None), then we can
+        # create an empty dict for the edge data, and if the edge
+        # data is some other type, then we can replace the
+        # existing edge data with a dict (retaining the original
+        # data as a value in the new dict).
+
+        for node_id in rx_graph.node_indices():
+            data_dict = rx_graph[node_id]
+            if not isinstance(data_dict, dict):
+                # Unrecoverable error - see above...
+                raise Exception("from_rustworkx(): RustworkX graph does not have node_data dictionary")
+
+        for edge_id in rx_graph.edge_indices():
+            data_dict = rx_graph.get_edge_data_by_index(edge_id)
+            if data_dict is None:
+                # Create an empty dict for edge_data
+                rx_graph.update_edge_by_index(edge_id, {})
+            elif not isinstance(data_dict, dict):
+                # Create a new dict with the existing edge_data as an item
+                rx_graph.update_edge_by_index(edge_id, {"__original_rx_edge_data": data_dict})
+
+        graph = cls()
+        graph._rx_graph = rx_graph
+        graph._nx_graph = None
+        graph._is_a_subgraph = False  # See comments on RX subgraph issues.
+
+        # frm: TODO: Documentation: from_rustworkx(): Make these comments more coherent
+        #
+        #      Instead of these very specific comments, just say that at this
+        #      point, we don't know whether the graph is derived from NX, is a
+        #      subgraph, or is something that can stand alone, so the maps are
+        #      all identity maps.  It is the responsibility of callers to reset the
+        #      maps if that is appropriate...
+
+        # Maps node_ids in the graph to the "parent" node_ids in the parent graph.
+        # For top-level graphs, this is just an identity map.
+        graph._node_id_to_parent_node_id_map = {node_id: node_id for node_id in graph.node_indices}
+
+        # This routine assumes that the rx_graph was not derived from an "original" NX
+        # graph, so the RX node_ids are considered to be the "original" node_ids and
+        # we create an identity map - each node_id maps to itself as the "original" node_id.
+        #
+        # If this routine is used for an RX-based Graph that was indeed derived from an
+        # NX graph, then it is the responsibility of the caller to set
+        # the _node_id_to_original_nx_node_id_map appropriately.
+        graph._node_id_to_original_nx_node_id_map = {node_id: node_id for node_id in graph.node_indices}
+
+        # only set when an NX based graph is converted to be an RX based graph
+        graph.nx_to_rx_node_id_map = None
+
+        return graph
+
+    def to_networkx_graph(self) -> networkx.Graph:
+        """
+        Create a NetworkX.Graph object that has the same nodes, edges,
+        node_data, and edge_data as the GerryChain Graph object.
+
+        The intended purpose of this routine is to allow a user to
+        run a MarkovChain - which uses an embedded RustworkX graph -
+        and then extract an equivalent version of that graph with all
+        of its data as a NetworkX.Graph object, in order to use
+        NetworkX routines to access and manipulate the graph.
+
+        In short, this routine allows users to use NetworkX
+        functionality on a graph after running a MarkovChain.
+
+        If the GerryChain graph object is NX-based, then this
+        routine merely returns the embedded NetworkX.Graph object.
+
+        :returns: A NetworkX.Graph object that is equivalent to the
+            GerryChain Graph object (nodes, edges, node_data, edge_data)
+        :rtype: networkx.Graph
+        """
+        if self.is_nx_graph():
+            return self.get_nx_graph()
+
+        if not self.is_rx_graph():
+            raise TypeError(
+                "Graph passed to 'to_networkx_graph()' must be a rustworkx graph"
+            )
+
+        # We have an RX-based Graph, and we want to create a NetworkX Graph object
+        # that has all of the node data and edge data, and which has the
+        # node_ids and edge_ids of the original NX graph.
+        #
+        # Original node_ids are those that were used in the original NX
+        # Graph used to create the RX-based Graph object.
+
+        # Confirm that this RX-based graph was derived from an NX graph...
+        if self._node_id_to_original_nx_node_id_map is None:
+            raise Exception("to_networkx_graph(): _node_id_to_original_nx_node_id_map is None")
+
+        rx_graph = self.get_rx_graph()
+
+        # Extract node data
+        node_data = []
+        for node_id in rx_graph.node_indices():
+            node_payload = rx_graph[node_id]
+            # Get the "original" node_id
+            original_nx_node_id = self.original_nx_node_id_for_internal_node_id(node_id)
+            node_data.append({"node_name": original_nx_node_id, **node_payload})
+
+        # Extract edge data
+        edge_data = []
+        for edge_id in rx_graph.edge_indices():
+            edge = rx_graph.get_edge_endpoints_by_index(edge_id)
+            edge_0_node_id = edge[0]
+            edge_1_node_id = edge[1]
+            # Get the "original" node_ids
+            edge_0_original_nx_node_id = self.original_nx_node_id_for_internal_node_id(edge_0_node_id)
+            edge_1_original_nx_node_id = self.original_nx_node_id_for_internal_node_id(edge_1_node_id)
+            edge_payload = rx_graph.get_edge_data_by_index(edge_id)
+            # Add edges and edge data using the original node_ids
+            # as the names/IDs for the nodes that make up the edge
+            edge_data.append({"source": edge_0_original_nx_node_id, "target": edge_1_original_nx_node_id, **edge_payload})
+
+        # Create Pandas DataFrames
+
+        nodes_df = pd.DataFrame(node_data)
+        edges_df = pd.DataFrame(edge_data)
+
+        # Create a NetworkX Graph object from the edges_df, using
+        # "source" and "target" to define edge node_ids, and adding
+        # all attribute data (True).
+ nx_graph = networkx.from_pandas_edgelist(edges_df, 'source', 'target', True, networkx.Graph) + + # Add all of the node_data, using the "node_name" attr as the NX Graph node_id + nodes_df = nodes_df.set_index('node_name') + networkx.set_node_attributes(nx_graph, nodes_df.to_dict(orient='index')) + + return nx_graph + + # frm: TODO: Refactoring: Create a defined type name "node_id" to use instead of "Any" + # + # This is purely cosmetic, but it would provide an opportunity to add a comment that + # talked about NX node_ids vs. RX node_ids and hence why the type for a node_id is + # a vague "Any"... + + def original_nx_node_id_for_internal_node_id(self, internal_node_id: Any) -> Any: + """ + Translate a node_id to its "original" node_id. + + :param internal_node_id: A node_id to be translated + :type internal_node_id: Any + + :returns: A translated node_id + :rtype: Any + """ + return self._node_id_to_original_nx_node_id_map[internal_node_id] + + # frm: TODO: Testing: Create a test for this routine + def original_nx_node_ids_for_set(self, set_of_node_ids: set[Any]) -> Any: + """ + Translate a set of node_ids to their "original" node_ids. + + :param set_of_node_ids: A set of node_ids to be translated + :type set_of_node_ids: set[Any] + + :returns: A set of translated node_ids + :rtype: set[Any] + """ + _node_id_to_original_nx_node_id_map = self._node_id_to_original_nx_node_id_map + new_set = {_node_id_to_original_nx_node_id_map[node_id] for node_id in set_of_node_ids} + return new_set + + # frm: TODO: Testing: Create a test for this routine + def original_nx_node_ids_for_list(self, list_of_node_ids: list[Any]) -> list[Any]: + """ + Translate a list of node_ids to their "original" node_ids. + + :param list_of_node_ids: A list of node_ids to be translated + :type list_of_node_ids: list[Any] + + :returns: A list of translated node_ids + :rtype: list[Any] + """ + # Utility routine to quickly translate a set of node_ids to their original node_ids + _node_id_to_original_nx_node_id_map = self._node_id_to_original_nx_node_id_map + new_list = [_node_id_to_original_nx_node_id_map[node_id] for node_id in list_of_node_ids] + return new_list + + def internal_node_id_for_original_nx_node_id(self, original_nx_node_id: Any) -> Any: + """ + Discover the "internal" node_id in the current GerryChain graph + that corresponds to the "original" node_id in the top-level + graph (presumably an NX-based graph object). + + This was originally created to facilitate testing where it was + convenient to express the test success criteria in terms of + "original" node_ids, but the actual test needed to be made + using the "internal" (RX) node_ids. + + :param original_nx_node_id: The "original" node_id + :type original_nx_node_id: Any + + :returns: The corresponding "internal" node_id + :rtype: Any + """ + # Note: TODO: Performance: This code is inefficient but it is not a priority to fix now... + # + # The code reverses the dict that maps internal node_ids to "original" + # node_ids, which has an entry for every node in the graph - hence large + # for large graphs, which is costly, but worse - it does this every time + # it is called, so if the calling code is looping through a list of nodes + # then this reverse dict computation will happen each time. + # + # The obvious fix is to just create the reverse map once when the "internal" + # graph is created. This would be simple to do and safe, because the + # "internal" graph is frozen. 
+ # + # However, at present (December 2025) this routine is only ever used for + # tests, so I am putting it on the back burner... + + + # reverse the map so we can go from original node_id to internal node_id + orignal_node_id_to_internal_node_id_map = { + v: k for k,v in self._node_id_to_original_nx_node_id_map.items() + } + return orignal_node_id_to_internal_node_id_map[original_nx_node_id] + + def verify_graph_is_valid(self) -> bool: + """ + Verify that the graph is valid. + + This may be overkill, but the idea is that at least in + development mode, it would be prudent to check periodically + to see that the graph data structure has not been corrupted. + + :returns: True if the graph is deemed valid + :rtype: bool """ - g = cls(graph) - return g + + # frm: TODO: Performance: Only check verify_graph_is_valid() in development. + # + # For now, in order to assess performance differences between NX and RX + # I will just return True... + return True + + + # Sanity check - this is where to add additional sanity checks in the future. + + # frm: TODO: Code: Enhance verify_graph_is_valid to do more... + + # frm: TODO: Performance: verify_graph_is_valid() is expensive - called a lot + # + # Come up with a way to run this in "debug mode" - that is, while in development/testing + # but not in production. It actually accounted for 5% of runtime... + + # Checks that there is one and only one graph + if not ( + (self._nx_graph is not None and self._rx_graph is None) + or (self._nx_graph is None and self._rx_graph is not None) + ): + raise Exception("Graph.verify_graph_is_valid(): graph not properly configured") + + # frm: TODO: Performance: is_nx_graph() and is_rx_graph() are expensive. + # + # Not all of the calls on these routines are needed in production - some are just + # sanity checking. Find a way to NOT run this code when in production. + + # frm: TODO: Refactoring: Reorder these following routines in sensible order + + def is_nx_graph(self) -> bool: + """ + Determine if the graph is NX-based + + :rtype: bool + """ + # frm: TODO: Performance: Only check graph_is_valid() in production + # + # Find a clever way to only run this code in development. Commenting it out for now... + # self.verify_graph_is_valid() + return self._nx_graph is not None + + def get_nx_graph(self) -> networkx.Graph: + """ + Return the embedded NX graph object + + :rtype: networkx.Graph + """ + if not self.is_nx_graph(): + raise TypeError( + "Graph passed to 'get_nx_graph()' must be a networkx graph" + ) + return self._nx_graph + + def get_rx_graph(self) -> rustworkx.PyGraph: + """ + Return the embedded RX graph object + + :rtype: rustworkx.PyGraph + """ + if not self.is_rx_graph(): + raise TypeError( + "Graph passed to 'get_rx_graph()' must be a rustworkx graph" + ) + return self._rx_graph + + def is_rx_graph(self) -> bool: + """ + Determine if the graph is RX-based + + :rtype: bool + """ + # frm: TODO: Performance: Only check graph_is_valid() in production + # + # Find a clever way to only run this code in development. Commenting it out for now... + # self.verify_graph_is_valid() + return self._rx_graph is not None + + def convert_from_nx_to_rx(self) -> "Graph": + """ + Convert an NX-based graph object to be an RX-based graph object. + + The primary use case for this routine is support for users + constructing a graph using NetworkX functionality and then + converting that NetworkX graph to RustworkX when creating a + Partition object. 
+ + + :returns: An RX-based graph that is "the same" as the given NX-based graph + :rtype: "Graph" + """ + + # Note that in both cases in the if-stmt below, the nodes are not copied. + # This is arguably dangerous, but in our case I think it is OK. Stated + # differently, the actual node data (the dictionaries) in the original + # graph (self) will be reused in the returned graph - either because we + # are just returning the same graph (if it is already based on rx.PyGraph) + # or if we are converting it from NX. + # + self.verify_graph_is_valid() + if self.is_nx_graph(): + + if (self._is_a_subgraph): + # This routine is intended to be used in exactly one place - in converting + # an NX based Graph object to be RX based when creating a Partition object. + # In the future, it might become useful for other reasons, but until then + # to guard against careless uses, the code will insist that it not be a subgraph. + + # frm: TODO: Documentation: Add a comment about the intended use of this routine to its + # overview comment above. + raise Exception("convert_from_nx_to_rx(): graph to be converted is a subgraph") + + nx_graph = self._nx_graph + rx_graph = rustworkx.networkx_converter(nx_graph, keep_attributes=True) + + # Note that the resulting RX graph will have multigraph set to False which + # ensures that there is never more than one edge between two specific nodes. + # This is perhaps not all that interesting in general, but it is critical + # when getting the edge_id from an edge using RX.edge_indices_from_endpoints() + # routine - because it ensures that only a single edge_id is returned... + + converted_graph = Graph.from_rustworkx(rx_graph) + + # Some graphs have geometry data (from a geodataframe), so preserve it if it exists + if hasattr(self, "geometry"): + converted_graph.geometry = self.geometry + + # Create a mapping from the old NX node_ids to the new RX node_ids (created by + # RX when it converts from NX) + nx_to_rx_node_id_map = { + converted_graph.node_data(node_id)["__networkx_node__"]: node_id + for node_id in converted_graph._rx_graph.node_indices() + } + converted_graph.nx_to_rx_node_id_map = nx_to_rx_node_id_map + + # We also have to update the _node_id_to_original_nx_node_id_map to refer to the node_ids + # in the NX Graph object. + _node_id_to_original_nx_node_id_map = {} + for node_id in converted_graph.node_indices: + original_nx_node_id = converted_graph.node_data(node_id)["__networkx_node__"] + _node_id_to_original_nx_node_id_map[node_id] = original_nx_node_id + converted_graph._node_id_to_original_nx_node_id_map = _node_id_to_original_nx_node_id_map + + return converted_graph + elif self.is_rx_graph(): + return self + else: + raise TypeError( + "Graph passed to 'convert_from_nx_to_rx()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) + + def get_nx_to_rx_node_id_map(self) -> dict[Any, Any]: + """ + Return the dict that maps NX node_ids to RX node_ids + + The primary use case for this routine is to support automatically + converting NX-based graph objects to be RX-based when creating a + Partition object. The issue is that when you convert from NX to RX + the node_ids change and so you need to update the Partition object's + Assignment to use the new RX node_ids. This routine is used + to translate those NX node_ids to the new RX node_ids when + initializing a Partition object. 
+ + :rtype: dict[Any, Any] + """ + # Simple getter method + if not self.is_rx_graph(): + raise TypeError( + "Graph passed to 'get_nx_to_rx_node_id()' is not a rustworkx graph" + ) + + return self.nx_to_rx_node_id_map @classmethod - def from_json(cls, json_file: str) -> "Graph": + def from_json(cls, json_file_name: str) -> "Graph": """ - Load a graph from a JSON file in the NetworkX json_graph format. + Create a :class:`Graph` from a JSON file - :param json_file: Path to JSON file. - :type json_file: str + :param json_file_name: JSON file + # frm: TODO: Documentation: more detail on contents of JSON file needed here + :type json_file_name: str - :returns: The loaded graph as an instance of this class. - :rtype: Graph + :returns: A GerryChain Graph object with data from JSON file + :rtype: "Graph" """ - with open(json_file) as f: + + # Note that this returns an NX-based Graph object. At some point in + # the future, if we embrace an all RX world, it will make sense to + # have it produce an RX-based Graph object. + + with open(json_file_name) as f: data = json.load(f) - g = json_graph.adjacency_graph(data) - graph = cls.from_networkx(g) + + # A bit of Python magic - an adjacency graph is a dict of dict of dicts + # which is structurally equivalent to a NetworkX graph, so you can just + # pretend that is what it is and it all works. + nx_graph = json_graph.adjacency_graph(data) + + graph = cls.from_networkx(nx_graph) graph.issue_warnings() - return graph + return graph - def to_json( - self, json_file: str, *, include_geometries_as_geojson: bool = False - ) -> None: + def to_json(self, json_file_name: str, include_geometries_as_geojson: bool = False) -> None: """ - Save a graph to a JSON file in the NetworkX json_graph format. + Dump a GerryChain Graph object to disk as a JSON file - :param json_file: Path to target JSON file. - :type json_file: str - :param bool include_geometry_as_geojson: Whether to include - any :mod:`shapely` geometry objects encountered in the graph's node - attributes as GeoJSON. The default (``False``) behavior is to remove - all geometry objects because they are not serializable. Including the - GeoJSON will result in a much larger JSON file. - :type include_geometries_as_geojson: bool, optional + :param json_file_name: name of JSON file to be created + :type json_file_name: str - :returns: None + :rtype: None """ - data = json_graph.adjacency_data(self) + # frm TODO: Code: Implement graph.to_json for an RX based graph + if not self.is_nx_graph(): + raise TypeError( + "Graph passed to 'to_json()' is not a networkx graph" + ) + + data = json_graph.adjacency_data(self._nx_graph) if include_geometries_as_geojson: convert_geometries_to_geojson(data) else: remove_geometries(data) - with open(json_file, "w") as f: + with open(json_file_name, "w") as f: json.dump(data, f, default=json_serialize) @classmethod @@ -123,7 +659,7 @@ def from_file( cls, filename: str, adjacency: str = "rook", - cols_to_add: Optional[List[str]] = None, + cols_to_add: Optional[list[str]] = None, reproject: bool = False, ignore_errors: bool = False, ) -> "Graph": @@ -138,7 +674,7 @@ def from_file( :type adjacency: str, optional :param cols_to_add: The names of the columns that you want to add to the graph as node attributes. Default is None. - :type cols_to_add: Optional[List[str]], optional + :type cols_to_add: Optional[list[str]], optional :param reproject: Whether to reproject to a UTM projection before creating the graph. Default is False. 
:type reproject: bool, optional @@ -161,7 +697,6 @@ def from_file( or install ``geopandas`` separately. """ - import geopandas as gp df = gp.read_file(filename) graph = cls.from_geodataframe( @@ -171,7 +706,42 @@ def from_file( reproject=reproject, ignore_errors=ignore_errors, ) - graph.graph["crs"] = df.crs.to_json() + # frm: TODO: Documentation: Make it clear that this creates an NX-based + # Graph object. + # + # Also add some documentation (here or elsewhere) + # about what CRS data is and what it is used for. + # + # Note that the NetworkX.Graph.graph["crs"] is only + # ever accessed in this file (graph.py), so I am not + # clear what it is used for. It seems to just be set + # and never used except to be written back out to JSON. + # + # The issue (I think) is that we do not preserve graph + # attributes when we convert to RX from NX, so if the + # user wants to write an RX based Graph back out to JSON + # this data (and another other graph level data) would be + # lost. + # + # So - need to figure out what CRS is used for... + # + # Peter commented on this in a PR comment: + # + # CRS stands for "Coordinate Reference System" which can be thought of + # as the projection system used for the polygons contained in the + # geodataframe. While it is not used in any of the graph operations of + # GerryChain, it may be used in things like validators and updaters. Since + # the CRS determines the projection system used by the underlying + # geodataframe, any area or perimeter computations encoded on the graph + # are stored with the understanding that those values may inherit + # distortions from projection used. We keep this around as metadata so + # that, in the event that the original geodataframe source is lost, + # the graph metadata still carries enough information for us to sanity + # check the area and perimeter computations if we get weird numbers. + + + # Store CRS data as an attribute of the NX graph + graph._nx_graph.graph["crs"] = df.crs.to_json() return graph @classmethod @@ -179,13 +749,17 @@ def from_geodataframe( cls, dataframe: pd.DataFrame, adjacency: str = "rook", - cols_to_add: Optional[List[str]] = None, + cols_to_add: Optional[list[str]] = None, reproject: bool = False, ignore_errors: bool = False, crs_override: Optional[Union[str, int]] = None, ) -> "Graph": + + # frm: Changed to operate on a NetworkX.Graph object and then convert to a + # Graph object at the end of the function. + """ - Creates the adjacency :class:`Graph` of geometries described by `dataframe`. + Create the adjacency :class:`Graph` of geometries described by `dataframe`. The areas of the polygons are included as node attributes (with key `area`). The shared perimeter of neighboring polygons are included as edge attributes (with key `shared_perim`). @@ -208,7 +782,7 @@ def from_geodataframe( :type adjacency: str, optional :param cols_to_add: The names of the columns that you want to add to the graph as node attributes. Default is None. - :type cols_to_add: Optional[List[str]], optional + :type cols_to_add: Optional[list[str]], optional :param reproject: Whether to reproject to a UTM projection before creating the graph. Default is ``False``. 
:type reproject: bool, optional @@ -252,21 +826,62 @@ def from_geodataframe( # Generate dict of dicts of dicts with shared perimeters according # to the requested adjacency rule - adjacencies = neighbors(df, adjacency) - graph = cls(adjacencies) - - graph.geometry = df.geometry - - graph.issue_warnings() + adjacencies = neighbors(df, adjacency) # Note - this is adjacency.neighbors() + + nx_graph = networkx.Graph(adjacencies) + + # frm: TODO: Documentation: Document what geometry is used for. + # + # Need to grok what geometry is used for - it is used in partition.py.plot() + # and maybe that is the only place it is used, but it is also used below + # to set other data, such as add_boundary_perimeters() and areas. The + # reason this is an issue is because I need to know what to carry over to + # the RX version of a Graph when I convert to RX when making a Partition. + # Partition.plot() uses this information, so it needs to be available in + # the RX version of a Graph - which essentially means that I need to grok + # how plot() works and where it gets its information and how existing + # users use it... + # + # There is a test failure due to geometry not being available after conversion to RX. + # + # Here is what Peter said in the PR: + # + # The geometry attribute on df is a special attribute that only appears on + # geodataframes. This is just a list of polygons representing some real-life + # geometries underneath a certain projection system (CRS). These polygons can + # then be fed to matplotilb to make nice plots of things, or they can be used + # to compute things like area and perimeter for use in updaters and validators + # that employ some sort of Reock score (uncommon, but unfortunately necessary in + # some jurisdictions). We probably don't need to store this as an attribute on + # the Graph._nxgraph object (or the Graph._rxgraph) object, however. In fact, it + # might be best to just make a Graph.dataframe attribute to store all of the + # graph data on, and add attributes to _nxgraph and _rxgraph nodes as needed + # + + nx_graph.geometry = df.geometry + + # frm: TODO: Refactoring: Rethink the name of add_boundary_perimeters + # + # It acts on an nx_graph which seems wrong with the given name. + # Maybe it should be: add_boundary_perimeters_to_nx_graph() + # + # Need to check in with Peter to see if this is considered + # part of the external API. + + # frm: TODO: Refactoring: Create an nx_utilities module + # + # It raises the question of whether there should be an nx_utilities + # module for stuff designed to only work on nx_graph objects. + # + # Note that Peter said: "I like this idea" + # # Add "exterior" perimeters to the boundary nodes - add_boundary_perimeters(graph, df.geometry) + add_boundary_perimeters(nx_graph, df.geometry) # Add area data to the nodes areas = df.geometry.area.to_dict() - networkx.set_node_attributes(graph, name="area", values=areas) - - graph.add_data(df, columns=cols_to_add) + networkx.set_node_attributes(nx_graph, name="area", values=areas) if crs_override is not None: df.set_crs(crs_override, inplace=True) @@ -278,36 +893,317 @@ def from_geodataframe( "Otherwise, please set the CRS using the `crs_override` parameter. " "Attempting to proceed without a CRS." ) - graph.graph["crs"] = None + nx_graph.graph["crs"] = None else: - graph.graph["crs"] = df.crs.to_json() + nx_graph.graph["crs"] = df.crs.to_json() + + graph = cls.from_networkx(nx_graph) + + # frm: Moved from earlier in the function so that we would have a Graph + # object (vs. 
NetworkX.Graph object) + + graph.add_data(df, columns=cols_to_add) + graph.issue_warnings() return graph - def lookup(self, node: Any, field: Any) -> Any: + # Performance Note: + # + # Most of the functions in the Graph class will be called after a + # partition has been created and the underlying graph converted + # to be based on RX. So, by testing first for RX we actually + # save a significant amount of time because we do not need to + # also test for NX (if you test for NX first then you do two tests). + # + + @property + def node_indices(self) -> set[Any]: """ - Lookup a node/field attribute. + Return a set of the node_ids in the graph - :param node: Node to look up. - :type node: Any - :param field: Field to look up. - :type field: Any + :rtype: set[Any] + """ + self.verify_graph_is_valid() + + # frm: TODO: Refactoring: node_indices() does the same thing that graph.nodes does - returning a list of node_ids. + # Do we really want to support two ways of doing the same thing? + # Actually this returns a set rather than a list - not sure that matters though... + # + # My code uses node_indices() to make it clear we are talking about node_ids... + # + # The question is whether to deprecate nodes()... + + if (self.is_rx_graph()): + return set(self._rx_graph.node_indices()) + elif (self.is_nx_graph()): + return set(self._nx_graph.nodes) + else: + raise TypeError( + "Graph passed to 'node_indices()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) - :returns: The value of the attribute `field` at `node`. + @property + def edge_indices(self) -> set[Any]: + """ + Return a set of the edge_ids in the graph + + :rtype: set[Any] + """ + self.verify_graph_is_valid() + + if (self.is_rx_graph()): + # A set of edge_ids for the edges + return set(self._rx_graph.edge_indices()) + elif (self.is_nx_graph()): + # A set of edge_ids (tuples) extracted from the graph's EdgeView + return set(self._nx_graph.edges) + else: + raise TypeError( + "Graph passed to 'edge_indices()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) + + def get_edge_from_edge_id(self, edge_id: Any) -> tuple[Any, Any]: + """ + Return the edge (tuple of node_ids) corresponding to the + given edge_id + + Note that in NX, an edge_id is the same as an edge - it is + just a tuple of node_ids. However, in RX, an edge_id is + an integer, so if you want to get the tuple of node_ids + you need to use the edge_id to get that tuple... + + :param edge_id: The ID of the desired edge + :type edge_id: Any + + :returns: An edge, namely a tuple of node_ids + :rtype: tuple[Any, Any] + """ + + self.verify_graph_is_valid() + + if (self.is_rx_graph()): + # In RX, we need to go get the edge tuple + # frm: TODO: Performance - use get_edge_endpoints_by_index() to get edge + # + # The original RX code (before October 27, 2025): + # return self._rx_graph.edge_list()[edge_id] + endpoints = self._rx_graph.get_edge_endpoints_by_index(edge_id) + return (endpoints[0], endpoints[1]) + elif (self.is_nx_graph()): + # In NX, the edge_id is also the edge tuple + return edge_id + else: + raise TypeError( + "Graph passed to 'get_edge_from_edge_id()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) + + # frm: TODO: Refactoring: Create abstract "edge" and "edge_id" type names + # + # As with node_id, this is cosmetic but it will provide a nice place to + # put a comment about the difference between NX and RX and it will make + # the type annotations make more sense... 
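+    # An illustrative sketch of the edge vs. edge_id distinction (hedged - rx_graph
+    # here stands for any rustworkx.PyGraph; edge_indices_from_endpoints() returns
+    # the indices of the edges between the two endpoints):
+    #
+    #     nx_edge_id = (3, 7)                                          # NX: the tuple IS the edge_id
+    #     rx_edge_id = rx_graph.edge_indices_from_endpoints(3, 7)[0]   # RX: an integer edge index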
+ + def get_edge_id_from_edge(self, edge: tuple[Any, Any]) -> Any: + """ + Get the edge_id that corresponds to the given edge. + + In RX an edge_id is an integer that designates an edge (an edge is + a tuple of node_ids). In NX, an edge_id IS the tuple of node_ids. + So, in general, to support both NX and RX, if you want to get access + to the edge data for an edge (tuple of node_ids), you need to + ask for the edge_id. + + This functionality is needed, for instance, when + + :param edge: A tuple of node_ids. + :type edge: tuple[Any, Any] + + :returns: The ID associated with the given edge :rtype: Any """ - return self.nodes[node][field] + self.verify_graph_is_valid() + + if (self.is_rx_graph()): + + # frm: TODO: Performance: Perhaps get_edge_id_from_edge() is too expensive... + # + # If this routine becomes a signficant performance issue, then perhaps + # we can change the algorithms that use it so that it is not needed. + # In particular, there are several routines in tree.py that use it + # by traversing chains of nodes (successors and predecessors) which + # requires the code to recreate the edges from the nodes in hand. This + # was not a problem in an NX world - the tuple of nodes was exactly what + # and edge_id was, but in the RX world it is not - necessitating this routine. + # + # BUT... If the code had chains of edges rather than chains of nodes, + # then you could have the edge_ids at hand already and avoid having to + # do this lookup. + # + # However, it may be that the RX edge_indices_from_endpoints() is smart + # enough (for instance if it caches a dict mapping) that the performance + # hit is minimal... Here's to hoping that RX is "smart enough"... ;-) + + # Note that while in general the routine, edge_indices_from_endpoints(), + # can return more than one edge in the case of a Multi-Graph (a graph that + # allows more than one edge between two nodes), we can rely on it only + # returning a single edge because the RX graph object has multigraph set + # to false by RX.networkx_converter() - because the NX graph was undirected... + # + edge_indices = self._rx_graph.edge_indices_from_endpoints(edge[0], edge[1]) + return edge_indices[0] # there will always be one and only one + elif (self.is_nx_graph()): + # In NX, the edge_id is also the edge tuple + return edge + else: + raise TypeError( + "Graph passed to 'get_edge_id_from_edge()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) @property - def node_indices(self): - return set(self.nodes) + def nodes(self) -> list[Any]: + """ + Return a list of all of the node_ids in the graph. + + This routine still exists because there is a lot of legacy + code that uses this syntax to iterate through all of the nodes + in a graph. + + There is another routine, node_indices(), which does essentially + the same thing (it returns a set of node_ids, however, rather than + a list). + + Why have two routines that do the same thing? The answer is that with + move to RX, it seemed appropriate to emphasize the distinction + between objects and the IDs for objects, hence the introduction of + node_indices() and edge_indices() routines. This distinction is + critical for edges, but mostly not important for nodes. In fact + this routine is implemented by just converting node_indices to a list. + So, it is essentially a style issue - when referring to nodes, we + are almost always really referring to node_ids, so why not use a + routine called node_indices()? + + Note that there is a subtle point to be made about node names vs. 
+        node_ids.  It was common before the transition to RX to create
+        nodes with IDs that were essentially names.  That is, the ID had
+        semantic weight.  This is not true with RX node_ids.  So, any
+        code that relies on the semantics of a node's ID (treating it
+        like a name) is suspect in the new RX world.
+
+        :returns: A list of all of the node_ids in the graph
+        :rtype: list[Any]
+        """
+
+        # frm: TODO: Documentation: Warn users in Migration Guide that nodes() has gone away
+        #
+        #      Since the legacy code implemented a GerryChain Graph as a subclass of NetworkX.Graph,
+        #      legacy code could take advantage of NX cleverness - NX returns a NodeView object for
+        #      nx_graph.nodes, which supports much more than just a list of node_ids (which is all that
+        #      the code below does).
+        #
+        #      Probably the most common use of nx_graph.nodes was to access node data, as in:
+        #
+        #          nx_graph.nodes[node_id][<attr_name>]
+        #
+        #      In the new world, to do that you need to do:
+        #
+        #          graph.node_data(node_id)[<attr_name>]
+        #
+        #      So, almost the same number of keystrokes, but if a legacy user uses nodes[...] the
+        #      old way, it won't work out well.
+        #
+
+        # frm: TODO: Refactoring: Think about whether to do away entirely with graph.nodes
+        #
+        #      All this routine does now is to coerce the set of nodes obtained by node_indices()
+        #      to be a list (which I think is unnecessary).  So, why have it at all?  Why not just
+        #      tell legacy users via an exception that it no longer exists?
+        #
+        #      On the other hand, it maybe does no harm to allow legacy users to indulge in
+        #      what appears to be a very common idiom in legacy code...
+
+        self.verify_graph_is_valid()
+
+        if (self.is_rx_graph()):
+            # A list of integer node_ids
+            return list(self._rx_graph.node_indices())
+        elif (self.is_nx_graph()):
+            # A list of node_ids
+            return list(self._nx_graph.nodes)
+        else:
+            raise TypeError(
+                "Graph passed to 'nodes()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
 
     @property
-    def edge_indices(self):
-        return set(self.edges)
+    def edges(self) -> set[tuple[Any, Any]]:
+        """
+        Return a set of all of the edges in the graph, where each
+        edge is a tuple of node_ids
+
+        :rtype: set[tuple[Any, Any]]
+        """
+        # Return a set of edge tuples
+
+        # frm: TODO: Code: ???: Should edges return a list instead of a set?
+        #
+        #      Peter said he thought users would expect a list - but why?
+
+        self.verify_graph_is_valid()
+
+        if (self.is_rx_graph()):
+            # A set of tuples for the edges
+            return set(self._rx_graph.edge_list())
+        elif (self.is_nx_graph()):
+            # A set of tuples extracted from the graph's EdgeView
+            return set(self._nx_graph.edges)
+        else:
+            raise TypeError(
+                "Graph passed to 'edges()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
+
+    def add_edge(self, node_id1: Any, node_id2: Any) -> None:
+        """
+        Add an edge to the graph from node_id1 to node_id2
+
+        :param node_id1: The node_id for one endpoint of the edge
+        :type node_id1: Any
+        :param node_id2: The node_id for the other endpoint of the edge
+        :type node_id2: Any
+
+        :rtype: None
+        """
+
+        # frm: TODO: Code: add_edge(): Check that nodes exist and that they have data dicts.
+        #
+        #      This checking should probably be limited to development mode, but
+        #      the issue is that an RX node need not have a data value that is
+        #      a dict, while GerryChain code depends on having a data dict.  So,
+        #      it makes sense to test and make sure that the nodes exist and
+        #      have a data dict...
+
+        # frm: TODO: Code: add_edge(): Do we need to check to make sure the edge does not already exist?
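+        # A possible shape for that development-mode check (a hypothetical sketch,
+        # deliberately not enabled - it assumes RX node payloads must be dicts,
+        # as from_rustworkx() requires):
+        #
+        #     if self.is_rx_graph():
+        #         for nid in (node_id1, node_id2):
+        #             if not isinstance(self._rx_graph[nid], dict):
+        #                 raise Exception("add_edge(): node data must be a dict")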
+ + self.verify_graph_is_valid() + + if (self.is_rx_graph()): + # empty dict tells RX the edge data will be a dict + self._rx_graph.add_edge(node_id1, node_id2, {}) + elif (self.is_nx_graph()): + self._nx_graph.add_edge(node_id1, node_id2) + else: + raise TypeError( + "Graph passed to 'add_edge()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) def add_data( - self, df: pd.DataFrame, columns: Optional[Iterable[str]] = None + self, df: pd.DataFrame, columns: Optional[Iterable[str]] = None ) -> None: """ Add columns of a DataFrame to a graph as node attributes @@ -315,29 +1211,37 @@ def add_data( :param df: Dataframe containing given columns. :type df: :class:`pandas.DataFrame` - :param columns: List of dataframe column names to add. Default is None. + :param columns: list of dataframe column names to add. Default is None. :type columns: Optional[Iterable[str]], optional :returns: None """ + if not (self.is_nx_graph()): + raise TypeError( + "Graph passed to 'add_data()' is not a networkx graph" + ) + if columns is None: columns = list(df.columns) check_dataframe(df[columns]) + # Create dict: {node_id: {attr_name: attr_value}} column_dictionaries = df.to_dict("index") - networkx.set_node_attributes(self, column_dictionaries) + nx_graph = self._nx_graph + networkx.set_node_attributes(nx_graph, column_dictionaries) - if hasattr(self, "data"): - self.data[columns] = df[columns] # type: ignore + if hasattr(nx_graph, "data"): + nx_graph.data[columns] = df[columns] # type: ignore else: - self.data = df[columns] + nx_graph.data = df[columns] + def join( self, dataframe: pd.DataFrame, - columns: Optional[List[str]] = None, + columns: Optional[list[str]] = None, left_index: Optional[str] = None, right_index: Optional[str] = None, ) -> None: @@ -349,7 +1253,7 @@ def join( :type dataframe: :class:`pandas.DataFrame` :columns: The columns whose data you wish to add to the graph. If not provided, all columns are added. Default is None. - :type columns: Optional[List[str]], optional + :type columns: Optional[list[str]], optional :left_index: The node attribute used to match nodes to rows. If not provided, node IDs are used. Default is None. :type left_index: Optional[str], optional @@ -371,8 +1275,21 @@ def join( column_dictionaries = df.to_dict() + # frm: TODO: Code: Implement graph.join() for RX + # + # This is low priority given that current suggested coding + # strategy of creating the graph using NX and then letting + # GerryChain convert it automatically to RX. In this scenario + # any joins would happen to the NX-based graph only. + + if not self.is_nx_graph(): + raise TypeError( + "Graph passed to join() is not a networkx graph" + ) + nx_graph = self._nx_graph + if left_index is not None: - ids_to_index = networkx.get_node_attributes(self, left_index) + ids_to_index = networkx.get_node_attributes(nx_graph, left_index) else: # When the left_index is node ID, the matching is just # a redundant {node: node} dictionary @@ -385,38 +1302,1021 @@ def join( for node_id, index in ids_to_index.items() } - networkx.set_node_attributes(self, node_attributes) + networkx.set_node_attributes(nx_graph, node_attributes) @property - def islands(self) -> Set: - """ - :returns: The set of degree-0 nodes. 
- :rtype: Set + def islands(self) -> set[Any]: """ - return set(node for node in self if self.degree[node] == 0) + Return a set of all node_ids that are not connected via an + edge to any other node in the graph - that is, nodes with + degree = 0 + :returns: A set of all node_ids for nodes of degree 0 + :rtype: set[Any] + """ + # Return all nodes of degree 0 (those not connected in an edge to another node) + return set(node_id for node_id in self.node_indices if self.degree(node_id) == 0) + + def is_directed(self) -> bool: + # frm TODO: Code: Delete this code: graph.is_directed() once convinced it is safe to do so... + # + # I added it because code in contiguity.py + # called nx.is_connected() which eventually called is_directed() + # assuming the graph was an nx_graph. + # + # Changing from return False to raising an exception just to make + # sure nobody uses it. + + raise NotImplementedError("graph.is_directed() should not be used") + + def warn_for_islands(self) -> None: """ - :returns: None + Issue a warning if there are any islands in the graph - that is, + if there are any nodes in the graph that are not connected to any + other node (degree = 0) - :raises: UserWarning if the graph has any islands (degree-0 nodes). + :rtype: None """ islands = self.islands if len(self.islands) > 0: warnings.warn( "Found islands (degree-0 nodes). Indices of islands: {}".format(islands) - ) - + ) + def issue_warnings(self) -> None: """ - :returns: None + Issue any warnings concerning the content or structure + of the graph. - :raises: UserWarning if the graph has any red flags (right now, only islands). + :rtype: None """ self.warn_for_islands() + def __len__(self) -> int: + """ + Return the number of nodes in the graph + + :rtype: int + """ + return len(self.node_indices) + + def __getattr__(self, __name: str) -> Any: + """ + + + :param : ...text... + ...more text... + :type : -def add_boundary_perimeters(graph: Graph, geometries: pd.Series) -> None: + :returns: ...text... + :rtype: + """ + # frm: TODO: Code: Get rid of _getattr_ eventually - it is very dangerous... + + # frm: Interesting bug lurking if __name is "nx_graph". This occurs when legacy code + # uses the default constructor, Graph(), and then references a built-in NX + # Graph method, such as my_graph.add_edges(). In this case the built-in NX + # Graph method is not defined, so __getattr__() is called to try to figure out + # what it could be. This triggers the call below to self.is_nx_graph(), which + # references self._nx_graph (which is undefined/None) which triggers another + # call to __getattr__() which is BAD... + # + # I think the solution is to not rely on testing whether nx_graph and rx_graph + # are None - but rather to have explicit is_nx_or_rx_graph data member which + # is set to one of "NX", "RX", "not_set". + # + # For now, I am just going to return None if __name is "_nx_graph" or "_rx_graph". + # + # Peter's comments from PR: + # + # Oh interesting; good catch! The flag approach seems like a good solution to me. + # It's very, very rare to use the default constructor, so I don't imagine that + # people will really run into this. + + # frm: TODO: Code: Fix this hack (in __getattr__) - see comment above... + if (__name == "_nx_graph") or (__name == "_rx_graph"): + return None + + # If attribute doesn't exist on this object, try + # its underlying graph object... 
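+        # frm: Example: the failure mode described above (hypothetical session):
+        #
+        #     g = Graph()              # default constructor: neither backend is set
+        #     g.add_edges_from(...)    # not defined here -> __getattr__("add_edges_from")
+        #                              #   -> is_nx_graph() -> self._nx_graph
+        #                              #   -> __getattr__("_nx_graph") -> infinite recursion
+        #
+        # The early-return above breaks that cycle by answering None for the
+        # two backing-graph attributes.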
+        if (self.is_rx_graph()):
+            return object.__getattribute__(self._rx_graph, __name)
+        elif (self.is_nx_graph()):
+            return object.__getattribute__(self._nx_graph, __name)
+        else:
+            raise TypeError(
+                "Graph passed to '__getattr__()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
+
+    def __getitem__(self, __name: str) -> Any:
+        """
+        Square-bracket access to the graph's node neighborhoods.
+
+        For an NX-based graph this returns the NetworkX AtlasView of the
+        given node's neighbors; it is not supported for an RX-based graph.
+
+        :param __name: A node_id.
+        :type __name: Any
+
+        :returns: The AtlasView of the given node's neighbors.
+        :rtype: Any
+        """
+        # frm: TODO: Code: Does any of the code actually use __getitem__ ?
+        #
+        # It is a clever Python way to use square bracket
+        # notation to access something (anything) you want.
+        #
+        # In this case, it returns the NetworkX AtlasView
+        # of neighboring nodes - looks like a dictionary
+        # with a key of the neighbor node_id and a value
+        # with the neighboring node's data (another dict).
+        #
+        # I am guessing that it is only ever used to get
+        # a list of the neighbor node_ids, in which case
+        # it is functionally equivalent to self.neighbors().
+        #
+        # *sigh*
+        #
+        self.verify_graph_is_valid()
+
+        if (self.is_rx_graph()):
+            # frm TODO: Code: Decide if __getitem__() should work for RX
+            raise TypeError("Graph.__getitem__() is not defined for a rustworkx graph")
+        elif (self.is_nx_graph()):
+            return self._nx_graph[__name]
+        else:
+            raise TypeError(
+                "Graph passed to '__getitem__()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
+
+    def __iter__(self) -> Iterable[Any]:
+        """
+        Iterate over the node_ids in the graph.
+
+        :returns: An iterator over the node_ids in the graph.
+        :rtype: Iterable[Any]
+        """
+        yield from self.node_indices
+
+    def subgraph(self, nodes: Iterable[Any]) -> "Graph":
+        """
+        Create a subgraph that contains the given nodes.
+
+        Note that creating a subgraph of a RustworkX (RX) graph
+        renumbers the nodes, so that a node that had node_id: 4
+        in the parent graph might have node_id: 2 in the subgraph.
+        This is a HUGE difference from the NX world, where the
+        node_ids in a subgraph do not change from those in the
+        parent graph.
+
+        In order to make sense of the nodes in a subgraph in the
+        RX world, we need to maintain mappings from the node_ids
+        in the subgraph to the node_ids of the immediate parent
+        graph and to the "original" top-level graph that contains
+        all of the nodes. You will notice the creation of those
+        maps in the code below.
+
+        :param nodes: The nodes to be included in the subgraph
+        :type nodes: Iterable[Any]
+
+        :returns: A subgraph containing the given nodes.
+        :rtype: "Graph"
+        """
+
+        """
+        frm: RX Documentation:
+
+        Subgraphs are one of the biggest differences between NX and RX, because RX creates new
+        node_ids for the nodes in the subgraph, starting at 0. So, if you create a subgraph with
+        a list of nodes: [45, 46, 47] the nodes in the subgraph will be [0, 1, 2].
+
+        This creates problems for functions that operate on subgraphs and want to return results
+        involving node_ids to the caller. To solve this, we define a _node_id_to_parent_node_id_map whenever
+        we create a subgraph that will provide the node_id in the parent for each node in the subgraph.
+        For NX this is a no-op, and the _node_id_to_parent_node_id_map is just an identity map - each node_id is
+        mapped to itself. For RX, however, we store the parent_node_id in the node's data before
+        creating the subgraph, and then in the subgraph, we use the parent's node_id to construct
+        a map from the subgraph node_id to the parent_node_id.
+ + This means that any function that wants to return results involving node_ids can safely + just translate node_ids using the _node_id_to_parent_node_id_map, so that the results make sense in + the caller's context. + + A note of caution: if the caller retains the subgraph after using it in a function call, + the caller should almost certainly not use the node_ids in the subgraph for ANYTHING. + It would be safest to reset the value of the subgraph to None after using it as an + argument to a function call. + + Also, for both RX and NX, we set the _node_id_to_parent_node_id_map to be the identity map for top-level + graphs on the off chance that there is a function that takes both top-level graphs and + subgraphs as a parameter. This allows the function to just always do the node translation. + In the case of a top-level graph the translation will be a no-op, but it will be correct. + + Also, we set the _is_a_subgraph = True, so that we can detect whether a parameter passed into + a function is a top-level graph or not. This will allow us to debug the code to determine + if assumptions about a parameter always being a subgraph is accurate. It also helps to + educate future readers of the code that subgraphs are "interesting"... + + """ + + self.verify_graph_is_valid() + + new_subgraph = None + + if (self.is_nx_graph()): + nx_subgraph = self._nx_graph.subgraph(nodes) + new_subgraph = self.from_networkx(nx_subgraph) + # for NX, the node_ids in subgraph are the same as in the parent graph + _node_id_to_parent_node_id_map = {node: node for node in nodes} + _node_id_to_original_nx_node_id_map = {node: node for node in nodes} + elif (self.is_rx_graph()): + if isinstance(nodes, frozenset) or isinstance(nodes, set): + nodes = list(nodes) + + # For RX, the node_ids in the subgraph change, so we need a way to map subgraph node_ids + # into parent graph node_ids. To do so, we add the parent node_id into the node data + # so that in the subgraph we can find it and then create the map. + # + # Note that this works because the node_data dict is shared by the nodes in both the + # parent graph and the subgraph, so we can set the "parent" node_id in the parent before + # creating the subgraph, and that value will be available in the subgraph even though + # the subgraph will have a different node_id for the same node. + # + # This value is removed from the node_data below after creating the subgraph. + # + for node_id in nodes: + self.node_data(node_id)["parent_node_id"] = node_id + + # It is also important for all RX graphs (subgraphs or top-level graphs) to have + # a mapping from RX node_id to the "original" NX node_id. However, we do not need + # to do what we do with the _node_id_to_parent_node_id_map and set the value of + # the "original" node_id now, because this value never changes for a node. It + # should already have been set for each node by the standard RX code that + # converts from NX to RX (which sets the "__networkx_node__" attribute to be + # the NX node_id). We just check to make sure that it is in fact set. 
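+            # frm: A small worked example of the mapping being built below
+            #      (node_ids are illustrative):
+            #
+            #          parent RX node_ids:             [45, 46, 47]
+            #          subgraph RX node_ids:           [0, 1, 2]
+            #          _node_id_to_parent_node_id_map: {0: 45, 1: 46, 2: 47}
+            #
+            #      so a subgraph result such as {1: some_part} translates back
+            #      to {46: some_part} in the parent graph.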
+ # + for node_id in nodes: + if not ("__networkx_node__" in self.node_data(node_id)): + raise Exception("subgraph: internal error: original_nx_node_id not set") + + rx_subgraph = self._rx_graph.subgraph(nodes) + new_subgraph = self.from_rustworkx(rx_subgraph) + + # frm: Create the map from subgraph node_id to parent graph node_id + _node_id_to_parent_node_id_map = {} + for subgraph_node_id in new_subgraph.node_indices: + _node_id_to_parent_node_id_map[subgraph_node_id] = \ + new_subgraph.node_data(subgraph_node_id)["parent_node_id"] + # value no longer needed, so delete it + new_subgraph.node_data(subgraph_node_id).pop("parent_node_id") + + # frm: Create the map from subgraph node_id to the original graph's node_id + _node_id_to_original_nx_node_id_map = {} + for subgraph_node_id in new_subgraph.node_indices: + _node_id_to_original_nx_node_id_map[subgraph_node_id] = \ + new_subgraph.node_data(subgraph_node_id)["__networkx_node__"] + else: + raise TypeError( + "Graph passed to 'subgraph()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) + + new_subgraph._is_a_subgraph = True + new_subgraph._node_id_to_parent_node_id_map = _node_id_to_parent_node_id_map + new_subgraph._node_id_to_original_nx_node_id_map = _node_id_to_original_nx_node_id_map + + return new_subgraph + + # frm: TODO: Refactoring: Create abstract type name for "Flip" and "Flip_Dict". + # + # This is cosmetic, but it would (IMHO) make the code easier to understand, and it + # would provide a logical place to define WTF a flip is... + + def translate_subgraph_node_ids_for_flips(self, flips: dict[Any, int]) -> dict[Any, int]: + """ + Translate the given flips so that the subgraph node_ids in the flips + have been translated to the appropriate node_ids in the + parent graph. + + The flips parameter is a dict mapping node_ids to parts (districts). + + This routine is used when a computation that creates flips is made + on a subgraph, but those flips want to be translated into the context + of the parent graph at the end of the computation. + + For more details, refer to the larger comment on subgraphs... + + :param flips: A dict containing "flips" which associate a node with + a new part in a partition (a "part" is the same as a district in + common parlance). + :type flips: dict[Any, int] + + :returns: A dict containing "flips" that have been translated to have + node_ids appropriate for the parent graph + :rtype: dict[Any, int] + """ + + # frm: TODO: Documentation: Write an overall comment on subgraphs and node_id maps + + translated_flips = {} + for subgraph_node_id, part in flips.items(): + parent_node_id = self._node_id_to_parent_node_id_map[subgraph_node_id] + translated_flips[parent_node_id] = part + + return translated_flips + + def translate_subgraph_node_ids_for_set_of_nodes(self, set_of_nodes: set[Any]) -> set[Any]: + """ + Translate the given set_of_nodes to have the appropriate + node_ids for the parent graph. + + This routine is used when a computation that creates a set of nodes is made + on a subgraph, but those nodes want to be translated into the context + of the parent graph at the end of the computation. + + For more details, refer to the larger comment on subgraphs... 
+ + :param set_of_nodes: A set of node_ids in a subgraph + :type set_of_nodes: set[Any] + + :returns: A set of node_ids that have been translated to have + the node_ids appropriate for the parent graph + :rtype: set[Any] + """ + # This routine replaces the node_ids of the subgraph with the node_ids + # for the same node in the parent graph. This routine is used to + # when a computation is made on a subgraph but the resulting set of nodes + # being returned want to be the appropriate node_ids for the parent graph. + translated_set_of_nodes = set() + for node_id in set_of_nodes: + translated_set_of_nodes.add(self._node_id_to_parent_node_id_map[node_id]) + return translated_set_of_nodes + + def generic_bfs_edges(self, source, neighbors=None, depth_limit=None) -> Generator[tuple[Any, Any], None, None]: + """ + + + :param : ...text... + ...more text... + :type : + + :returns: ...text... + :rtype: + """ + # frm: Code copied from GitHub: + # + # https://github.com/networkx/networkx/blob/main/networkx/algorithms/traversal/breadth_first_search.py + # + # Code was not modified - it worked as written for both rx.PyGraph and a graph.Graph object + # with an RX graph embedded in it... + + """Iterate over edges in a breadth-first search. + + The breadth-first search begins at `source` and enqueues the + neighbors of newly visited nodes specified by the `neighbors` + function. + + Parameters + ---------- + G : RustworkX.PyGraph object (not a NetworkX graph) + + source : node + Starting node for the breadth-first search; this function + iterates over only those edges in the component reachable from + this node. + + neighbors : function + A function that takes a newly visited node of the graph as input + and returns an *iterator* (not just a list) of nodes that are + neighbors of that node with custom ordering. If not specified, this is + just the ``G.neighbors`` method, but in general it can be any function + that returns an iterator over some or all of the neighbors of a + given node, in any order. + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth. + + Yields + ------ + edge + Edges in the breadth-first search starting from `source`. + + Examples + -------- + >>> G = nx.path_graph(7) + >>> list(nx.generic_bfs_edges(G, source=0)) + [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)] + >>> list(nx.generic_bfs_edges(G, source=2)) + [(2, 1), (2, 3), (1, 0), (3, 4), (4, 5), (5, 6)] + >>> list(nx.generic_bfs_edges(G, source=2, depth_limit=2)) + [(2, 1), (2, 3), (1, 0), (3, 4)] + + The `neighbors` param can be used to specify the visitation order of each + node's neighbors generically. In the following example, we modify the default + neighbor to return *odd* nodes first: + + >>> def odd_first(n): + ... return sorted(G.neighbors(n), key=lambda x: x % 2, reverse=True) + + >>> G = nx.star_graph(5) + >>> list(nx.generic_bfs_edges(G, source=0)) # Default neighbor ordering + [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5)] + >>> list(nx.generic_bfs_edges(G, source=0, neighbors=odd_first)) + [(0, 1), (0, 3), (0, 5), (0, 2), (0, 4)] + + Notes + ----- + This implementation is from `PADS`_, which was in the public domain + when it was first accessed in July, 2004. The modifications + to allow depth limits are based on the Wikipedia article + "`Depth-limited-search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS/BFS.py + .. 
_Depth-limited-search: https://en.wikipedia.org/wiki/Depth-limited_search
+        """
+        # frm: These two if-stmts work for both rx.PyGraph and gerrychain.Graph with RX inside
+        if neighbors is None:
+            neighbors = self.neighbors
+        if depth_limit is None:
+            depth_limit = len(self)
+
+        seen = {source}
+        n = len(self)
+        depth = 0
+        next_parents_children = [(source, neighbors(source))]
+        while next_parents_children and depth < depth_limit:
+            this_parents_children = next_parents_children
+            next_parents_children = []
+            for parent, children in this_parents_children:
+                for child in children:
+                    # frm: avoid cycles - don't process a child twice...
+                    if child not in seen:
+                        seen.add(child)
+                        # frm: add this node's children to list to be processed later...
+                        next_parents_children.append((child, neighbors(child)))
+                        yield (parent, child)
+                if len(seen) == n:
+                    return
+            depth += 1
+
+    # frm: TODO: Testing: Add tests for all of the new routines I have added...
+
+    def generic_bfs_successors_generator(self, root_node_id: Any) -> Generator[tuple[Any, Any], None, None]:
+        """
+        Yield (parent, children) tuples in breadth-first order starting
+        from the given root node, where children is the list of node_ids
+        whose BFS parent is the given parent node.
+
+        :param root_node_id: The node_id of the node to start the search from.
+        :type root_node_id: Any
+
+        :returns: A generator of (parent, children) tuples.
+        :rtype: Generator[tuple[Any, Any], None, None]
+        """
+        # frm: Generate in sequence a tuple for the parent (node_id) and
+        #      the children of that node (list of node_ids).
+        parent = root_node_id
+        children = []
+        for p, c in self.generic_bfs_edges(root_node_id):
+            # frm: parent-child pairs appear ordered by their parent, so
+            #      we can collect all of the children for a node by just
+            #      iterating through pairs until the parent changes.
+            if p == parent:
+                children.append(c)
+                continue
+            yield (parent, children)
+            # new parent, so reset parent and children variables to
+            # be the new parent (p) and a new children list containing
+            # this first child (c), and continue looping
+            children = [c]
+            parent = p
+        yield (parent, children)
+
+    def generic_bfs_successors(self, root_node_id: Any) -> dict[Any, Any]:
+        """
+        Return a dict mapping each node_id reachable from the root to the
+        list of its children in a breadth-first search from the root.
+
+        :param root_node_id: The node_id of the node to start the search from.
+        :type root_node_id: Any
+
+        :returns: A dict mapping node_ids to lists of child node_ids.
+        :rtype: dict[Any, Any]
+        """
+        return dict(self.generic_bfs_successors_generator(root_node_id))
+
+    def generic_bfs_predecessors(self, root_node_id: Any) -> dict[Any, Any]:
+        """
+        Return a dict mapping each node_id reachable from the root to its
+        single parent in a breadth-first search from the root.
+
+        :param root_node_id: The node_id of the node to start the search from.
+        :type root_node_id: Any
+
+        :returns: A dict mapping node_ids to their parent node_ids.
+        :rtype: dict[Any, Any]
+        """
+        # frm Note: We had to implement our own, because the built-in RX version only worked
+        #           for directed graphs.
+        predecessors = []
+        for s, t in self.generic_bfs_edges(root_node_id):
+            predecessors.append((t, s))
+        return dict(predecessors)
+
+
+    def predecessors(self, root_node_id: Any) -> dict[Any, Any]:
+        """
+        Return a dict mapping each node_id reachable from the given root
+        to its parent node_id in a breadth-first search from the root.
+
+        :param root_node_id: The node_id of the node to start the search from.
+        :type root_node_id: Any
+
+        :returns: A dict mapping node_ids to their parent node_ids.
+        :rtype: dict[Any, Any]
+        """
+
+        """
+        frm: It took me a while to grok what predecessors() and successors()
+        were all about. In the end, it was simple - they are just the
+        parents and the children of a tree that "starts" at the given root
+        node.
+
+        What took me a while to understand is that this effectively
+        converts an undirected cyclic graph into a DAG. What is clever is
+        that as soon as it detects a cycle it stops traversing the graph.
+        The other thing that is clever is that the DAG that is created
+        either starts at the top or the bottom. For successors(), the
+        DAG starts at the top, so that the argument to successors() is
+        the root of the tree. However, in the case of predecessors()
+        the argument to predecessors() is a leaf node, and the "tree"
+        can have multiple "roots".
+
+        In both cases, you can ask what the associated parent or
+        children are of any node in the graph. If you ask for the
+        successors() you will get a list of the children nodes.
+        If you ask for the predecessors() you will get the single
+        parent node.
+
+        I think that the successors() graph is deterministic (except
+        for the order of the child nodes), meaning that for a given
+        graph no matter what order you created nodes and added edges,
+        you will get the same set of children for a given node.
+        However, for predecessors(), there are many different
+        DAGs that might be created depending on which edge the
+        algorithm decides is the single parent.
+
+        All of this is interesting, but I have not yet spent the
+        time to figure out why it matters in the code.
+
+        TODO: Code: predecessors(): Decide if it makes sense to have different implementations
+              for NX and RX. The code below has the original definition
+              from the pre-RX codebase, but the code for RX will work
+              for NX too - so I think that there is no good reason to
+              have different code for NX. Maybe no harm, but on the other
+              hand, it seems like a needless difference and hence more
+              complexity...
+
+        TODO: Performance: see if the performance of the built-in NX
+              version is significantly better than the generic one.
+        """
+
+        self.verify_graph_is_valid()
+
+        if (self.is_rx_graph()):
+            return self.generic_bfs_predecessors(root_node_id)
+        elif (self.is_nx_graph()):
+            return {a: b for a, b in networkx.bfs_predecessors(self._nx_graph, root_node_id)}
+        else:
+            raise TypeError(
+                "Graph passed to 'predecessors()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
+
+    def successors(self, root_node_id: Any) -> dict[Any, Any]:
+        """
+        Return a dict mapping each node_id reachable from the given root
+        to the list of its children in a breadth-first search from the root.
+
+        :param root_node_id: The node_id of the node to start the search from.
+        :type root_node_id: Any
+
+        :returns: A dict mapping node_ids to lists of child node_ids.
+        :rtype: dict[Any, Any]
+        """
+        self.verify_graph_is_valid()
+
+        if (self.is_rx_graph()):
+            return self.generic_bfs_successors(root_node_id)
+        elif (self.is_nx_graph()):
+            return {a: b for a, b in networkx.bfs_successors(self._nx_graph, root_node_id)}
+        else:
+            raise TypeError(
+                "Graph passed to 'successors()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
+
+    def neighbors(self, node_id: Any) -> list[Any]:
+        """
+        Return a list of the node_ids of the nodes that are neighbors of
+        the given node - that is, all of the nodes that are directly
+        connected to the given node by an edge.
+
+        :param node_id: The ID of a node
+        :type node_id: Any
+
+        :returns: A list of neighbor node_ids
+        :rtype: list[Any]
+        """
+        self.verify_graph_is_valid()
+
+        if (self.is_rx_graph()):
+            return list(self._rx_graph.neighbors(node_id))
+        elif (self.is_nx_graph()):
+            return list(self._nx_graph.neighbors(node_id))
+        else:
+            raise TypeError(
+                "Graph passed to 'neighbors()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
+
+    def degree(self, node_id: Any) -> int:
+        """
+        Return the degree of the given node, that is, the number
+        of other nodes directly connected to the given node.
+
+        :param node_id: The ID of a node
+        :type node_id: Any
+
+        :returns: Number of nodes directly connected to the given node
+        :rtype: int
+        """
+        self.verify_graph_is_valid()
+
+        if (self.is_rx_graph()):
+            return self._rx_graph.degree(node_id)
+        elif (self.is_nx_graph()):
+            return self._nx_graph.degree(node_id)
+        else:
+            raise TypeError(
+                "Graph passed to 'degree()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
+
+    def node_data(self, node_id: Any) -> dict[Any, Any]:
+        """
+        Return the data dictionary that contains the given node's data.
+
+        As documented elsewhere, in GerryChain code before the conversion
+        to RustworkX, users could access node data using the syntax:
+
+            graph.nodes[node_id][attribute_name]
+
+        This was because a GerryChain Graph object in that codebase was a
+        subclass of NetworkX.Graph, and NetworkX was clever and implemented
+        dict-like behavior for the syntax graph.nodes[]...
+
+        This Python cleverness was not carried over to the RustworkX
+        implementation, so in the current GerryChain Graph implementation
+        users need to access node data using the syntax:
+
+            graph.node_data(node_id)[attribute_name]
+
+        :param node_id: The ID of a node
+        :type node_id: Any
+
+        :returns: Data dictionary containing the given node's data.
+        :rtype: dict[Any, Any]
+        """
+
+        self.verify_graph_is_valid()
+
+        if (self.is_rx_graph()):
+            data_dict = self._rx_graph[node_id]
+        elif (self.is_nx_graph()):
+            data_dict = self._nx_graph.nodes[node_id]
+        else:
+            raise TypeError(
+                "Graph passed to 'node_data()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
+
+        if not isinstance(data_dict, dict):
+            raise TypeError("graph.node_data(): data for node is not a dict")
+
+        return data_dict
+
+    def edge_data(self, edge_id: Any) -> dict[Any, Any]:
+        """
+        Return the data dictionary that contains the data for the given edge.
+
+        Note that in NetworkX an edge_id can be almost anything, for instance,
+        a string or even a tuple. However, in RustworkX, an edge_id is
+        an integer. This code handles both kinds of edge_ids - hence the
+        type, Any.
+
+        :param edge_id: The ID of the edge
+        :type edge_id: Any
+
+        :returns: The data dictionary for the given edge's data
+        :rtype: dict[Any, Any]
+        """
+
+        self.verify_graph_is_valid()
+
+        if (self.is_rx_graph()):
+            data_dict = self._rx_graph.get_edge_data_by_index(edge_id)
+        elif (self.is_nx_graph()):
+            data_dict = self._nx_graph.edges[edge_id]
+        else:
+            raise TypeError(
+                "Graph passed to 'edge_data()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )
+
+        # Sanity check - RX edges do not need to have a data dict for edge data
+        #
+        # A GerryChain Graph object should always be constructed with a data dict
+        # for edge data, but it doesn't hurt to check.
+        if not isinstance(data_dict, dict):
+            raise TypeError("graph.edge_data(): data for edge is not a dict")
+
+        return data_dict
+
+
+    # frm: TODO: Documentation: Note: I added the laplacian_matrix routines as methods of the Graph
+    #            class because they are only ever used on Graph objects. It
+    #            bloats the Graph class, but it still seems like the best
+    #            option.
+    #
+    #            A goal is to encapsulate ALL NX dependencies in this file.
+
+    def laplacian_matrix(self) -> scipy.sparse.csr_array:
+        """
+        Return the Laplacian matrix of the graph - the degree matrix
+        minus the adjacency matrix - as a sparse array.
+
+        :returns: The Laplacian matrix of the graph.
+        :rtype: scipy.sparse.csr_array
+        """
+        # A local "gc" (as in GerryChain) version of the laplacian matrix
+
+        # frm: TODO: Code: laplacian_matrix(): should NX and RX return same type (float vs. int)?
+        #
+        # The NX version returns a matrix of integer values while the
+        # RX version returns a matrix of floating point values. I
+        # think the reason is that the RX.adjacency_matrix() call
+        # returns an array of floats.
+        #
+        # Since the laplacian matrix is used for further numeric
+        # processing, I don't think this matters, but I should
+        # check to be 100% certain.
+
+        if self.is_rx_graph():
+            rx_graph = self._rx_graph
+            # 1. Get the adjacency matrix
+            adj_matrix = rustworkx.adjacency_matrix(rx_graph)
+            # 2.
Calculate the degree matrix (simplified for this example) + degree_matrix = numpy.diag([rx_graph.degree(node) for node in rx_graph.node_indices()]) + # 3. Calculate the Laplacian matrix + np_laplacian_matrix = degree_matrix - adj_matrix + # 4. Convert the NumPy array to a scipy.sparse array + laplacian_matrix = scipy.sparse.csr_array(np_laplacian_matrix) + elif self.is_nx_graph(): + nx_graph = self._nx_graph + laplacian_matrix = networkx.laplacian_matrix(nx_graph) + else: + raise TypeError( + "Graph passed into laplacian_matrix() is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) + + return laplacian_matrix + + def normalized_laplacian_matrix(self) -> scipy.sparse.dia_array: + """ + + + :param : ...text... + ...more text... + :type : + + :returns: ...text... + :rtype: + """ + + def create_scipy_sparse_array_from_rx_graph(rx_graph: rustworkx.PyGraph) -> scipy.sparse.coo_matrix: + """ + + + :param : ...text... + ...more text... + :type : + + :returns: ...text... + :rtype: + """ + num_nodes = rx_graph.num_nodes() + + rows = [] + cols = [] + data = [] + + for u, v in rx_graph.edge_list(): + rows.append(u) + cols.append(v) + data.append(1) # simple adjacency matrix, so just 1 not weight attribute + + sparse_array = scipy.sparse.coo_matrix((data, (rows,cols)), shape=(num_nodes, num_nodes)) + + return sparse_array + + if self.is_rx_graph(): + rx_graph = self._rx_graph + """ + The following is code copied from the networkx linalg file, laplacianmatrix.py + for normalized_laplacian_matrix. Below this code has been modified to work for + gerrychain (with an RX-based Graph object) + + import numpy as np + import scipy as sp + + if nodelist is None: + nodelist = list(G) + A = nx.to_scipy_sparse_array(G, nodelist=nodelist, weight=weight, format="csr") + n, _ = A.shape + diags = A.sum(axis=1) + D = sp.sparse.dia_array((diags, 0), shape=(n, n)).tocsr() + L = D - A + with np.errstate(divide="ignore"): + diags_sqrt = 1.0 / np.sqrt(diags) + diags_sqrt[np.isinf(diags_sqrt)] = 0 + DH = sp.sparse.dia_array((diags_sqrt, 0), shape=(n, n)).tocsr() + return DH @ (L @ DH) + + """ + + # frm: TODO: Get someone to validate that this in fact does the right thing. + # + # The one test, test_proposal_returns_a_partition[spectral_recom], in test_proposals.py + # that uses normalized_laplacian_matrix() now passes, but it is for a small 6x6 graph + # and hence is not a real world test... + # + + A = create_scipy_sparse_array_from_rx_graph(rx_graph) + n, _ = A.shape # shape() => dimensions of the array (rows, cols), so n = num_rows + diags = A.sum(axis=1) # sum of values in each row => column vector + diags = diags.T # convert to a row vector / 1D array + D = scipy.sparse.dia_array((diags, [0]), shape=(n, n)).tocsr() + L = D - A + with numpy.errstate(divide="ignore"): + diags_sqrt = 1.0 / numpy.sqrt(diags) + diags_sqrt[numpy.isinf(diags_sqrt)] = 0 + DH = scipy.sparse.dia_array((diags_sqrt, 0), shape=(n, n)).tocsr() + normalized_laplacian = DH @ (L @ DH) + return normalized_laplacian + + elif self.is_nx_graph(): + nx_graph = self._nx_graph + laplacian_matrix = networkx.normalized_laplacian_matrix(nx_graph) + else: + raise TypeError( + "Graph passed into normalized_laplacian_matrix() is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) + + return laplacian_matrix + + def subgraphs_for_connected_components(self) -> list["Graph"]: + """ + Create and return a list of subgraphs for each set of nodes + in the given graph that are connected. 
+ + Note that a connected graph is one in which there is a path + from every node in the graph to every other node in the + graph. + + Note also that each of the subgraphs returned is a + maximal subgraph of connected components, meaning that there + is no other larger subgraph of connected components that + includes it as a subset. + + :returns: A list of "maximal" subgraphs each of which + contains nodes that are connected. + :rtype: list["Graph"] + """ + + if self.is_rx_graph(): + rx_graph = self.get_rx_graph() + subgraphs = [ + self.subgraph(nodes) for nodes in rustworkx.connected_components(rx_graph) + ] + elif self.is_nx_graph(): + nx_graph = self.get_nx_graph() + subgraphs = [ + self.subgraph(nodes) for nodes in networkx.connected_components(nx_graph) + ] + else: + raise TypeError( + "Graph passed to 'subgraphs_for_connected_components()' is " + "neither a networkx-based graph nor a rustworkx-based graph" + ) + + return subgraphs + + def num_connected_components(self) -> int: + """ + Return the number of connected components. + + Note: A connected component is a maximal subgraph + where every vertex is reachable from every other vertex in + that same subgraph. In a graph that is not fully connected, + connected components are the separate, distinct "islands" of + connected nodes. Every node in a graph belongs to exactly + one connected component. + + :returns: The number of connected components + :rtype: int + """ + + # frm: TODO: Performance: num_connected_components(): do both NX and RX have builtins for this? + # + # NetworkX and RustworkX both have a routine number_connected_components(). + # I am guessing that it is more efficient to call these than it is + # to construct the connected components and then determine how many + # of them there are. + # + # So - should be a simple issue of trying it and running tests, but + # I will do that another day... + + if self.is_rx_graph(): + rx_graph = self.get_rx_graph() + connected_components = rustworkx.connected_components(rx_graph) + elif self.is_nx_graph(): + nx_graph = self.get_nx_graph() + connected_components = list(networkx.connected_components(nx_graph)) + else: + raise TypeError( + "Graph passed to 'num_connected_components()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) + + num_cc = len(connected_components) + return num_cc + + def is_a_tree(self) -> bool: + """ + Return whether the current graph is a tree - meaning that + it is connected and that it has no cycles. + + :returns: Whether the current graph is a tree + :rtype: bool + """ + + # frm: TODO: Refactor: is_a_tree() is only called in a test (test_tree.py) - delete it? + # + # Talk to Peter to see if there is any reason to keep this function. Does anyone + # use it externally? + # + # On the other hand, perhaps it is OK to keep it even if it is only ever used in a test... + + if self.is_rx_graph(): + rx_graph = self.get_rx_graph() + num_nodes = rx_graph.num_nodes() + num_edges = rx_graph.num_edges() + + # Condition 1: Check if the number of edges is one less than the number of nodes + if num_edges != num_nodes - 1: + return False + + # Condition 2: Check for connectivity (and implicitly, acyclicity if E = V-1) + # A graph with V-1 edges and no cycles must be connected. + # A graph with V-1 edges and connected must be acyclic. + + # We can check connectivity by ensuring there's only one connected component. 
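+            # frm: For example (illustrative): a path a-b-c has 3 nodes, 2 edges
+            #      (V-1) and one component, so it is a tree; a triangle a-b-c-a
+            #      has 3 nodes and 3 edges, so the edge-count check above
+            #      already rejects it.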
+ connected_components = rustworkx.connected_components(rx_graph) + if len(connected_components) != 1: + return False + + return True + elif self.is_nx_graph(): + nx_graph = self.get_nx_graph() + return networkx.is_tree(nx_graph) + else: + raise TypeError( + "Graph passed to 'is_a_tree()' is neither a " + "networkx-based graph nor a rustworkx-based graph" + ) + + +def add_boundary_perimeters(nx_graph: networkx.Graph, geometries: pd.Series) -> None: """ Add shared perimeter between nodes and the total geometry boundary. @@ -428,23 +2328,34 @@ def add_boundary_perimeters(graph: Graph, geometries: pd.Series) -> None: :returns: The updated graph. :rtype: Graph """ - from shapely.ops import unary_union - from shapely.prepared import prep + + # frm: TODO: add_boundary_perimeters(): Think about whether it is reasonable to require this to work + # on an NetworkX.Graph object. + + # frm: The original code operated on the Graph object which was a subclass of + # NetworkX.Graph. I have changed it to operate on a NetworkX.Graph object + # with the understanding that callers will reach down into a Graph object + # and pass in the inner nx_graph data member. + + if not(isinstance(nx_graph, networkx.Graph)): + raise TypeError( + "Graph passed into add_boundary_perimeters() " + "is not a networkx graph" + ) prepared_boundary = prep(unary_union(geometries).boundary) boundary_nodes = geometries.boundary.apply(prepared_boundary.intersects) - for node in graph: - graph.nodes[node]["boundary_node"] = bool(boundary_nodes[node]) + for node in nx_graph: + nx_graph.nodes[node]["boundary_node"] = bool(boundary_nodes[node]) if boundary_nodes[node]: total_perimeter = geometries[node].boundary.length shared_perimeter = sum( - neighbor_data["shared_perim"] for neighbor_data in graph[node].values() + neighbor_data["shared_perim"] for neighbor_data in nx_graph[node].values() ) boundary_perimeter = total_perimeter - shared_perimeter - graph.nodes[node]["boundary_perim"] = boundary_perimeter - + nx_graph.nodes[node]["boundary_perim"] = boundary_perimeter def check_dataframe(df: pd.DataFrame) -> None: """ @@ -524,6 +2435,15 @@ class FrozenGraph: The class uses `__slots__` for improved memory efficiency. """ + # frm: TODO: Code: Rename the internal data member, "graph", to be something else. + # The reason is that a NetworkX.Graph object already has an internal + # data member named, "graph", which is just a dict for the data + # associated with the Networkx.Graph object. + # + # So to avoid confusion, naming the frozen graph something like + # _frozen_graph would make it easier for a future reader of the + # code to avoid confusion... + __slots__ = ["graph", "size"] def __init__(self, graph: Graph) -> None: @@ -535,30 +2455,84 @@ def __init__(self, graph: Graph) -> None: :returns: None """ - self.graph = networkx.classes.function.freeze(graph) - self.graph.join = frozen - self.graph.add_data = frozen - self.size = len(self.graph) + # frm: Original code follows: + # + # self.graph = networkx.classes.function.freeze(graph) + # + # # frm: frozen is just a function that raises an exception if called... + # self.graph.join = frozen + # self.graph.add_data = frozen + # + # self.size = len(self.graph) + + # frm TODO: Code: Add logic to FrozenGraph so that it is indeed "frozen" (for both NX and RX) + # + # I think this just means redefining those methods that change the graph + # to return an error / exception if called. 
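+        # frm: A minimal sketch of that idea (hypothetical, not implemented here;
+        #      it assumes setattr on the Graph instance is acceptable):
+        #
+        #          def _frozen(*args, **kwargs):
+        #              raise RuntimeError("FrozenGraph does not allow mutation")
+        #
+        #          for name in ("add_edge", "add_data", "join"):
+        #              setattr(graph, name, _frozen)
+        #
+        #      This would mirror what the original NX-based code did with the
+        #      module-level frozen() function, but for both backends.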
+ + self.graph = graph + self.size = len(self.graph.node_indices) def __len__(self) -> int: + """ + + + :param : ...text... + ...more text... + :type : + + :returns: ...text... + :rtype: + """ return self.size def __getattribute__(self, __name: str) -> Any: + """ + + + :param : ...text... + ...more text... + :type : + + :returns: ...text... + :rtype: + """ try: return object.__getattribute__(self, __name) except AttributeError: - return object.__getattribute__(self.graph, __name) + # delegate getting the attribute to the graph data member + return self.graph.__getattribute__(__name) def __getitem__(self, __name: str) -> Any: + """ + + + :param : ...text... + ...more text... + :type : + + :returns: ...text... + :rtype: + """ return self.graph[__name] def __iter__(self) -> Iterable[Any]: + """ + + + :param : ...text... + ...more text... + :type : + + :returns: ...text... + :rtype: + """ yield from self.node_indices @functools.lru_cache(16384) - def neighbors(self, n: Any) -> Tuple[Any, ...]: - return tuple(self.graph.neighbors(n)) + def neighbors(self, n: Any) -> tuple[Any, ...]: + return self.graph.neighbors(n) @functools.cached_property def node_indices(self) -> Iterable[Any]: @@ -572,9 +2546,5 @@ def edge_indices(self) -> Iterable[Any]: def degree(self, n: Any) -> int: return self.graph.degree(n) - @functools.lru_cache(65536) - def lookup(self, node: Any, field: str) -> Any: - return self.graph.nodes[node][field] - def subgraph(self, nodes: Iterable[Any]) -> "FrozenGraph": return FrozenGraph(self.graph.subgraph(nodes)) diff --git a/gerrychain/grid.py b/gerrychain/grid.py index b635c807..688783ae 100644 --- a/gerrychain/grid.py +++ b/gerrychain/grid.py @@ -12,8 +12,17 @@ - typing: Used for type hints. """ + import math import networkx +# frm TODO: Documentation: Clarify what purpose grid.py serves. +# +# It is a convenience module to help users create toy graphs. It leverages +# NX to create graphs, but it returns new Graph objects. So, legacy user +# code will need to be at least reviewed to make sure that it properly +# copes with new Graph objects. +# + from gerrychain.partition import Partition from gerrychain.graph import Graph from gerrychain.updaters import ( @@ -62,6 +71,20 @@ def __init__( assignment: Optional[Dict] = None, updaters: Optional[Dict[str, Callable]] = None, parent: Optional["Grid"] = None, + # frm: ???: TODO: This code indicates that flips are a dict of tuple: int which would be + # correct for edge flips, but not for node flips. Need to check again + # to see if this is correct. Note that flips is used in the constructor + # so it should fall through to Partition._from_parent()... + # + # OK - I think that this is a bug. Parition._from_parent() assumes + # that flips are a mapping from node to partition not tuple/edge to partition. + # I checked ALL of the code and the constructor for Grid is never passed in + # a flips parameter, so there is no example to check / verify, but it sure + # looks and smells like a bug. + # + # The fix would be to just change Dict[Tuple[int, int], int] to be + # Dict[int, int] + # flips: Optional[Dict[Tuple[int, int], int]] = None, ) -> None: """ @@ -95,14 +118,17 @@ def __init__( :raises Exception: If neither dimensions nor parent is provided. """ + + # Note that Grid graphs have node_ids that are tuples not integers. 
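+        # For example (illustrative), a 2x2 grid has node_ids
+        # (0, 0), (0, 1), (1, 0), (1, 1) rather than 0..3.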
+ if dimensions: self.dimensions = dimensions - graph = Graph.from_networkx(create_grid_graph(dimensions, with_diagonals)) + graph = Graph.from_networkx(_create_grid_nx_graph(dimensions, with_diagonals)) if not assignment: thresholds = tuple(math.floor(n / 2) for n in self.dimensions) assignment = { - node: color_quadrants(node, thresholds) for node in graph.nodes # type: ignore + node_id: color_quadrants(node_id, thresholds) for node_id in graph.node_indices # type: ignore } if not updaters: @@ -139,7 +165,22 @@ def as_list_of_lists(self): return [[self.assignment.mapping[(i, j)] for i in range(m)] for j in range(n)] -def create_grid_graph(dimensions: Tuple[int, int], with_diagonals: bool) -> Graph: +# frm: Documentation: Document what grid.py is intended to be used for +# +# I will need to do some research, but my guess is that there are two use +# cases: +# +# 1) Testing - make it easy to create tests +# 2) User Code - make it easy for users to play around. +# +# For #1, it is OK to have some routines return NX-Graph objects and some to return new Graph +# objects, but that is probably confusing to users, so the todo list items are: +# +# * Decide whether to support returning NX-based Graphs sometimes and new Graphs others, +# * Document whatever we decide +# + +def _create_grid_nx_graph(dimensions: Tuple[int, int], with_diagonals: bool) -> Graph: """ Creates a grid graph with the specified dimensions. Optionally includes diagonal connections between nodes. @@ -157,9 +198,9 @@ def create_grid_graph(dimensions: Tuple[int, int], with_diagonals: bool) -> Grap if len(dimensions) != 2: raise ValueError("Expected two dimensions.") m, n = dimensions - graph = networkx.generators.lattice.grid_2d_graph(m, n) + nx_graph = networkx.generators.lattice.grid_2d_graph(m, n) - networkx.set_edge_attributes(graph, 1, "shared_perim") + networkx.set_edge_attributes(nx_graph, 1, "shared_perim") if with_diagonals: nw_to_se = [ @@ -169,18 +210,31 @@ def create_grid_graph(dimensions: Tuple[int, int], with_diagonals: bool) -> Grap ((i, j + 1), (i + 1, j)) for i in range(m - 1) for j in range(n - 1) ] diagonal_edges = nw_to_se + sw_to_ne - graph.add_edges_from(diagonal_edges) + #frm: TODO: Check that graph is an NX graph before calling graph.add_edges_from(). Eventually + # make this work for RX too... + nx_graph.add_edges_from(diagonal_edges) for edge in diagonal_edges: - graph.edges[edge]["shared_perim"] = 0 + # frm: TODO: When/if grid.py is converted to operate on GerryChain Graph + # objects instead of NX.Graph objects, this use of NX + # EdgeView to get/set edge data will need to change to use + # gerrychain_graph.edge_data() + # + # We will also need to think about edge vs edge_id. In this + # case we want an edge_id, so that means we need to look at + # how diagonal_edges are created - but that is for the future... + nx_graph.edges[edge]["shared_perim"] = 0 - networkx.set_node_attributes(graph, 1, "population") - networkx.set_node_attributes(graph, 1, "area") + # frm: These just set all nodes/edges in the graph to have the given attributes with a value of 1 + # frm: TODO: These won't work for the new graph, and they won't work for RX + networkx.set_node_attributes(nx_graph, 1, "population") + networkx.set_node_attributes(nx_graph, 1, "area") - tag_boundary_nodes(graph, dimensions) + _tag_boundary_nodes(nx_graph, dimensions) - return graph + return nx_graph +# frm ???: Why is this here instead of in graph.py? Who is it intended for? Internal vs. External? 
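+# frm: Usage sketch (illustrative) for the routine below - tag every node
+#      with a constant attribute value:
+#
+#          give_constant_attribute(graph, "population", 1)
+#          assert all(graph.node_data(n)["population"] == 1 for n in graph.node_indices)
+#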
def give_constant_attribute(graph: Graph, attribute: Any, value: Any) -> None: """ Sets the specified attribute to the specified value for all nodes in the graph. @@ -194,11 +248,11 @@ def give_constant_attribute(graph: Graph, attribute: Any, value: Any) -> None: :returns: None """ - for node in graph.nodes: - graph.nodes[node][attribute] = value + for node_id in graph.node_indices: + graph.node_data(node_id)[attribute] = value -def tag_boundary_nodes(graph: Graph, dimensions: Tuple[int, int]) -> None: +def _tag_boundary_nodes(nx_graph: networkx.Graph, dimensions: Tuple[int, int]) -> None: """ Adds the boolean attribute ``boundary_node`` to each node in the graph. If the node is on the boundary of the grid, that node also gets the attribute @@ -211,13 +265,30 @@ def tag_boundary_nodes(graph: Graph, dimensions: Tuple[int, int]) -> None: :returns: None """ + # + # frm: Another case of code that is not clear (at least to me). It took me + # a while to figure out that the name/label for a node in a grid graph + # is a tuple and not just a number or string. The tuple indicates its + # position in the grid (x,y) cartesian coordinates, so node[0] below + # means its x-position and node[1] means its y-position. So the if-stmt + # below tests whether a node is all the way on the left or the right or + # all the way on the top or the bottom. If so, it is tagged as a + # boundary node and it gets its boundary_perim value set - still not + # sure what that does/means... + # + # Peter's comment from PR: + # + # I think that being able to identify a boundary edge was needed in some early + # experiments, so it was important to tag them, but I haven't really something + # that cares about this in a while + m, n = dimensions - for node in graph.nodes: + for node in nx_graph.nodes: if node[0] in [0, m - 1] or node[1] in [0, n - 1]: - graph.nodes[node]["boundary_node"] = True - graph.nodes[node]["boundary_perim"] = get_boundary_perim(node, dimensions) + nx_graph.nodes[node]["boundary_node"] = True + nx_graph.nodes[node]["boundary_perim"] = get_boundary_perim(node, dimensions) else: - graph.nodes[node]["boundary_node"] = False + nx_graph.nodes[node]["boundary_node"] = False def get_boundary_perim(node: Tuple[int, int], dimensions: Tuple[int, int]) -> int: diff --git a/gerrychain/metagraph.py b/gerrychain/metagraph.py index 438af206..833af8c2 100644 --- a/gerrychain/metagraph.py +++ b/gerrychain/metagraph.py @@ -22,11 +22,19 @@ def all_cut_edge_flips(partition: Partition) -> Iterator[Dict]: Generate all possible flips of cut edges in a partition without any constraints. + This routine finds all edges on the boundary of + districts - those that are "cut edges" where one node + is in one district and the other node is in another + district. These are all of the places where you + could move the boundary between districts by moving a single + node. + :param partition: The partition object. :type partition: Partition :returns: An iterator that yields dictionaries representing the flipped edges. :rtype: Iterator[Dict] """ + for edge, index in product(partition.cut_edges, (0, 1)): yield {edge[index]: partition.assignment.mapping[edge[1 - index]]} diff --git a/gerrychain/metrics/partisan.py b/gerrychain/metrics/partisan.py index ef8be5cc..af75fb99 100644 --- a/gerrychain/metrics/partisan.py +++ b/gerrychain/metrics/partisan.py @@ -9,12 +9,13 @@ import numpy from typing import Tuple +# frm: TODO: Refactoring: Why are these not just included in the file that defines ElectionResults? 
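+#
+# frm: Worked example (illustrative, assuming the median-minus-mean convention
+#      used by this metric): with first-party vote shares of 0.40, 0.45 and
+#      0.70 across three districts, the median is 0.45 and the mean is ~0.517,
+#      so the score is ~-0.067 - a slight disadvantage for the first party.
+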
def mean_median(election_results) -> float: """ Computes the Mean-Median score for the given ElectionResults. A positive value indicates an advantage for the first party listed - in the Election's parties_to_columns dictionary. + in the Election's party_names_to_node_attribute_names dictionary. :param election_results: An ElectionResults object :type election_results: ElectionResults @@ -32,7 +33,7 @@ def mean_thirdian(election_results) -> float: """ Computes the Mean-Median score for the given ElectionResults. A positive value indicates an advantage for the first party listed - in the Election's parties_to_columns dictionary. + in the Election's party_names_to_node_attribute_names dictionary. The motivation for this score is that the minority party in many states struggles to win even a third of the seats. @@ -56,7 +57,7 @@ def efficiency_gap(election_results) -> float: """ Computes the efficiency gap for the given ElectionResults. A positive value indicates an advantage for the first party listed - in the Election's parties_to_columns dictionary. + in the Election's party_names_to_node_attribute_names dictionary. :param election_results: An ElectionResults object :type election_results: ElectionResults diff --git a/gerrychain/optimization/optimization.py b/gerrychain/optimization/optimization.py index 5b45e3c6..fd7a6a7e 100644 --- a/gerrychain/optimization/optimization.py +++ b/gerrychain/optimization/optimization.py @@ -505,7 +505,7 @@ def tilted_short_bursts( with_progress_bar=with_progress_bar, ) - # TODO: Maybe add a max_time variable so we don't run forever. + # TODO: Refactoring: Maybe add a max_time variable so we don't run forever. def variable_length_short_bursts( self, num_steps: int, diff --git a/gerrychain/partition/assignment.py b/gerrychain/partition/assignment.py index b38ca49d..bdbe8ad7 100644 --- a/gerrychain/partition/assignment.py +++ b/gerrychain/partition/assignment.py @@ -37,6 +37,7 @@ def __init__( :raises ValueError: if the keys of ``parts`` are not unique :raises TypeError: if the values of ``parts`` are not frozensets """ + if validate: number_of_keys = sum(len(keys) for keys in parts.values()) number_of_unique_keys = len(set().union(*parts.values())) @@ -77,6 +78,14 @@ def update_flows(self, flows): """ Update the assignment for some nodes using the given flows. """ + # frm: Update the assignment of nodes to partitions by adding + # all of the new nodes and removing all of the old nodes + # as represented in the flows (dict keyed by district (part) + # of nodes flowing "in" and "out" for that district). + # + # Also, reset the mapping of node to partition (self.mapping) + # to reassign each node to its new partition. + # for part, flow in flows.items(): # Union between frozenset and set returns an object whose type # matches the object on the left, which here is a frozenset @@ -146,9 +155,64 @@ def from_dict(cls, assignment: Dict) -> "Assignment": passed-in dictionary. :rtype: Assignment """ + + # frm: TODO: Refactoring: Clean up from_dict(). + # + # A couple of things: + # * It uses a routine, level_sets(), which is only ever used here, so + # why bother having a separate routine. All it does is convert a dict + # mapping node_ids to parts into a dict mapping parts into sets of + # node_ids. Why not just have that code here inline? + # + # * Also, the constructor for Assignment explicitly allows for the caller + # to pass in a "mapping" of node_id to part, which we have right here. + # Why don't we pass it in and save having to recompute it? 
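+        #
+        #   For reference, an inline equivalent of level_sets() would be
+        #   something like (illustrative; assumes collections.defaultdict):
+        #
+        #       sets = defaultdict(set)
+        #       for node_id, part in assignment.items():
+        #           sets[part].add(node_id)
+        #       parts = {part: frozenset(node_ids) for part, node_ids in sets.items()}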
+        #
+
        parts = {part: frozenset(keys) for part, keys in level_sets(assignment).items()}
        return cls(parts)
+
+    def new_assignment_convert_old_node_ids_to_new_node_ids(self, node_id_mapping: Dict) -> "Assignment":
+        """
+        Create a new Assignment object from the one passed in, where the node_ids are changed
+        according to the node_id_mapping from old node_ids to new node_ids.
+
+        This routine was motivated by the fact that node_ids are changed when converting from a
+        NetworkX based graph to a RustworkX based graph. An Assignment based on the node_ids in
+        the NetworkX based graph would need to be changed to use the new node_ids - the new
+        Assignment would be semantically equivalent - just converted to use the new node_ids in
+        the RX based graph.
+
+        The node_id_mapping is of the form {old_node_id: new_node_id}
+
+        :param node_id_mapping: A dict mapping old node_ids to new node_ids.
+        :type node_id_mapping: Dict
+
+        :returns: A new, semantically equivalent Assignment that uses the new node_ids.
+        :rtype: Assignment
+        """
+
+        # Dict of the form: {node_id: part_id}
+        old_assignment_mapping = self.mapping
+        old_parts = self.parts
+
+        # convert old_node_ids to new_node_ids, keeping part IDs the same
+        new_assignment_mapping = {
+            node_id_mapping[old_node_id]: part
+            for old_node_id, part in old_assignment_mapping.items()
+        }
+        # Now update the parts dict that has a frozenset of all the nodes in each part (district)
+        new_parts = {}
+        for cur_node_id, cur_part in new_assignment_mapping.items():
+            if cur_part not in new_parts:
+                new_parts[cur_part] = set()
+            new_parts[cur_part].add(cur_node_id)
+        for cur_part, set_of_nodes in new_parts.items():
+            new_parts[cur_part] = frozenset(set_of_nodes)
+
+        new_assignment = Assignment(
+            new_parts,
+            new_assignment_mapping
+        )
+
+        return new_assignment


 def get_assignment(
@@ -174,13 +238,22 @@ def get_assignment(
     is not provided.
     :raises TypeError: If the part_assignment is not a string or dictionary.
     """
+
+    # frm: TODO: Refactoring: Think about whether to split this into two functions. At
+    #      present, it does different things based on whether
+    #      the "part_assignment" parameter is a string, a dict,
+    #      or an assignment. Probably not worth the trouble (possible
+    #      legacy issues), but I just can't get used to the Python habit
+    #      of weak typing...
+
     if isinstance(part_assignment, str):
+        # Extract an assignment using the named node attribute
         if graph is None:
             raise TypeError(
                 "You must provide a graph when using a node attribute for the part_assignment"
             )
         return Assignment.from_dict(
-            {node: graph.nodes[node][part_assignment] for node in graph}
+            {node: graph.node_data(node)[part_assignment] for node in graph}
         )
     # Check if assignment is a dict or a mapping type
     elif callable(getattr(part_assignment, "items", None)):
diff --git a/gerrychain/partition/partition.py b/gerrychain/partition/partition.py
index 9f484f61..d676130f 100644
--- a/gerrychain/partition/partition.py
+++ b/gerrychain/partition/partition.py
@@ -1,5 +1,13 @@
 import json
-import networkx
+
+
+# frm: Only used in _first_time() inside __init__() to allow for creating
+#      a Partition from a NetworkX Graph object:
+#
+#          elif isinstance(graph, networkx.Graph):
+#              graph = Graph.from_networkx(graph)
+#              self.graph = FrozenGraph(graph)
+import networkx

 from gerrychain.graph.graph import FrozenGraph, Graph
 from ..updaters import compute_edge_flows, flows_from_changes, cut_edges
@@ -8,6 +16,18 @@
 from ..tree import recursive_tree_part
 from typing import Any, Callable, Dict, Optional, Tuple

+# frm TODO: Documentation: Add documentation about how this all works.
For instance, +# what is computationally expensive and how does a FrozenGraph +# help? Why do we need both assignments and parts? +# +# Since a Partition is intimately tied up with how the Markov Chain +# does its magic, it would make sense to talk about that a bit... +# +# For instance, is there any reason to use a Partition object +# except in a Markov Chain? I suppose they are useful for post +# Markov Chain analysis - but if so, then it would be nice to +# know what functionality is tuned for the Markov Chain and what +# functionality / data is tuned for post Markov Chain analysis. class Partition: """ @@ -56,12 +76,24 @@ def __init__( which the functions compute. :param use_default_updaters: If `False`, do not include default updaters. """ + if parent is None: + if graph is None: + raise Exception("Parition.__init__(): graph object is None") + self._first_time(graph, assignment, updaters, use_default_updaters) else: self._from_parent(parent, flips) self._cache = dict() + + #frm: SubgraphView provides cached access to subgraphs for each of the + # partition's districts. It is important that we asign subgraphs AFTER + # we have established what nodes belong to which parts (districts). In + # the case when the parent is None, the assignments are explicitly provided, + # and in the case when there is a parent, the _from_parent() logic processes + # the flips to update the assignments. + self.subgraphs = SubgraphView(self.graph, self.parts) @classmethod @@ -101,7 +133,9 @@ def from_random_assignment( :returns: The partition created with a random assignment :rtype: Partition """ - total_pop = sum(graph.nodes[n][pop_col] for n in graph) + # frm: TODO: BUG: The param, flips, is never used in this routine... + + total_pop = sum(graph.node_data(n)[pop_col] for n in graph) ideal_pop = total_pop / n_parts assignment = method( @@ -120,18 +154,71 @@ def from_random_assignment( ) def _first_time(self, graph, assignment, updaters, use_default_updaters): - if isinstance(graph, Graph): - self.graph = FrozenGraph(graph) - elif isinstance(graph, networkx.Graph): + # Make sure that the embedded graph for the Partition is based on + # a RustworkX graph, and make sure it is also a FrozenGraph. Both + # of these are important for performance. + + # Note that we automatically convert NetworkX based graphs to use RustworkX + # when we create a Partition object. + # + # Creating and manipulating NX Graphs is easy and users + # are familiar with doing so. It makes sense to preserve the use case of + # creating an NX-Graph and then allowing the code to under-the-covers + # convert to RX - both for legacy compatibility, but also because NX provides + # a really nice and easy way to create graphs. + # + # TODO: Documentation: update the documentation + # to describe the use case of creating a graph using NX. That documentation + # should also describe how to post-process results of a MarkovChain run + # but I haven't figured that out yet... + + # If a NX.Graph, create a Graph object based on NX + if isinstance(graph, networkx.Graph): graph = Graph.from_networkx(graph) + + # if a Graph object, make sure it is based on an embedded RustworkX.PyGraph + if isinstance(graph, Graph): + # frm: TODO: Performance: Remove this short-term hack to do performance testing + # + # This "test_performance_using_NX_graph" hack just forces the partition + # to NOT convert the NX graph to be RX based. This allows me to + # compare RX performance to NX performance with the same code - so that + # whatever is different is crystal clear. 
+ test_performance_using_NX_graph = False + if (graph.is_nx_graph()) and test_performance_using_NX_graph: + self.assignment = get_assignment(assignment, graph) + print("=====================================================") + print("Performance-Test: using NetworkX for Partition object") + print("=====================================================") + + elif (graph.is_nx_graph()): + + # Get the assignment that would be appropriate for the NX-based graph + old_nx_assignment = get_assignment(assignment, graph) + + # Convert the NX graph to be an RX graph + graph = graph.convert_from_nx_to_rx() + + # After converting from NX to RX, we need to update the Partition's assignment + # because it used the old NX node_ids (converting to RX changes node_ids) + nx_to_rx_node_id_map = graph.get_nx_to_rx_node_id_map() + new_rx_assignment = old_nx_assignment.new_assignment_convert_old_node_ids_to_new_node_ids( + nx_to_rx_node_id_map + ) + self.assignment = new_rx_assignment + + else: + self.assignment = get_assignment(assignment, graph) + self.graph = FrozenGraph(graph) + elif isinstance(graph, FrozenGraph): self.graph = graph + self.assignment = get_assignment(assignment, graph) + else: raise TypeError(f"Unsupported Graph object with type {type(graph)}") - self.assignment = get_assignment(assignment, graph) - if set(self.assignment) != set(graph): raise KeyError("The graph's node labels do not match the Assignment's keys") @@ -145,11 +232,25 @@ def _first_time(self, graph, assignment, updaters, use_default_updaters): self.updaters.update(updaters) + # Note that the updater functions are executed lazily - that is, only when + # a caller asks for the results, such as partition["perimeter"]. See the code + # for __getitem__(). + # + # So no need to execute the updater functions now... + self.parent = None self.flips = None self.flows = None self.edge_flows = None + # frm ???: This is only called once and it is tagged as an internal + # function (leading underscore). Is there a good reason + # why this is not internal to the __init__() routine + # where it is used? + # + # That is, is there any reason why anyone might ever + # call this except __init__()? + def _from_parent(self, parent: "Partition", flips: Dict) -> None: self.parent = parent self.flips = flips @@ -173,7 +274,7 @@ def __repr__(self): def __len__(self): return len(self.parts) - def flip(self, flips: Dict) -> "Partition": + def flip(self, flips: Dict, use_original_nx_node_ids=False) -> "Partition": """ Returns the new partition obtained by performing the given `flips` on this partition. @@ -182,6 +283,29 @@ def flip(self, flips: Dict) -> "Partition": :returns: the new :class:`Partition` :rtype: Partition """ + + # frm: TODO: Documentation: Change comments above to document new optional parameter, use_original_nx_node_ids. + # + # This is a new issue that arises from the fact that node_ids in RX are different from those + # in the original NX graph. In the pre-RX code, we did not need to distinguish between + # calls to flip() that were internal code used when doing a MarkovChain versus user code + # for instance in tests. However, in the new RX world, the internal code uses RX node_ids + # and the tests want to use "original" NX node_ids. Hence the new parameter. + + # If the caller identified flips in terms of "original" node_ids (typically node_ids associated with + # an NX-based graph before creating a Partition object), then translate those original node_ids + # into the appropriate internal RX-based node_ids. 
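+    # For example (sketch - the node_id 42 and the part label "01" are illustrative
+    # only; 42 stands for a node_id from the user's original NX graph):
+    #
+    #     new_partition = partition.flip({42: "01"}, use_original_nx_node_ids=True)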
+    #
+    # Note that original node_ids in flips are typically used in tests
+    #
+
+        if use_original_nx_node_ids:
+            new_flips = {}
+            for original_nx_node_id, part in flips.items():
+                internal_node_id = self.graph.internal_node_id_for_original_nx_node_id(original_nx_node_id)
+                new_flips[internal_node_id] = part
+            flips = new_flips
+
         return self.__class__(parent=self, flips=flips)

     def crosses_parts(self, edge: Tuple) -> bool:
@@ -205,11 +329,49 @@ def __getitem__(self, key: str) -> Any:
         :returns: The value of the updater.
         :rtype: Any
         """
+        # frm: Cleverness Alert: Delayed evaluation of updater functions...
+        #
+        # The code immediately below executes the appropriate updater function
+        # if it has not already been executed and then caches the results.
+        # This makes sense - why compute something if nobody ever wants it? -
+        # but it took me a while to figure out why the constructor did not
+        # explicitly call the updaters.
+        #
         if key not in self._cache:
+            # frm: TODO: Testing: Add a test checking what happens if no updater defined
+            #
+            #      This code checks that the desired updater actually is
+            #      defined in the list of updaters.  If not, then this
+            #      would produce a perhaps difficult to debug problem...
+            if key not in self.updaters:
+                raise KeyError(f"__getitem__(): updater: {key} not defined in the updaters for the partition")
+
             self._cache[key] = self.updaters[key](self)
         return self._cache[key]

     def __getattr__(self, key):
+        # frm TODO: Refactor: Not sure it makes sense to allow two ways to accomplish the same thing...
+        #
+        # The code below allows Partition users to get the results of updaters by just
+        # doing:  partition.<updater_name>  which is the same as doing:  partition["<updater_name>"]
+        # It is clever, but perhaps too clever.  Why provide two ways to do the same thing?
+        #
+        # It is also odd on a more general level - this approach means that the attributes of a
+        # Partition are the same as the names of the updaters and return the results of running
+        # the updater functions.  I guess this makes sense, but there is no documentation (that I
+        # am aware of) that makes this clear.
+        #
+        # Peter's comment in PR:
+        #
+        #     This is actually on my list of things that I would prefer removed. When I first
+        #     started working with this codebase, I found the fact that you could just do
+        #     partition.name_of_my_updater really confusing, and, from a Python perspective,
+        #     I think that the more intuitive interface is keyword access like in a dictionary.
+        #     I haven't scoured the codebase for instances of ".attr" yet, but this is one of
+        #     the things that I am 100% okay with getting rid of. Almost all of the people
+        #     that I have seen work with this package use the partition["attr"] paradigm anyway.
+        #
         return self[key]

     def keys(self):
@@ -220,6 +382,15 @@
     def parts(self):
         return self.assignment.parts

     def plot(self, geometries=None, **kwargs):
+        #
+        # frm ???: I think that this plots districts on a map that is defined
+        #          by the geometries parameter (presumably polygons or something similar).
+        #          It converts the partition data into data that the plot routine
+        #          knows how to deal with, but essentially it just assigns each node
+        #          to a district.  The **kwargs are then passed to the plotting
+        #          engine - presumably to define colors and other graph stuff.
+        #
         """
         Plot the partition, using the provided geometries.
@@ -236,9 +407,12 @@ def plot(self, geometries=None, **kwargs): import geopandas if geometries is None: - geometries = self.graph.geometry + if hasattr(self.graph, "geometry"): + geometries = self.graph.geometry + else: + raise Exception("Partition.plot: graph has no geometry data") - if set(geometries.index) != set(self.graph.nodes): + if set(geometries.index) != self.graph.node_indices: raise TypeError( "The provided geometries do not match the nodes of the graph." ) @@ -285,13 +459,15 @@ def from_districtr_file( id_column_key = districtr_plan["idColumn"]["key"] districtr_assignment = districtr_plan["assignment"] try: - node_to_id = {node: str(graph.nodes[node][id_column_key]) for node in graph} + node_to_id = {node: str(graph.node_data(node)[id_column_key]) for node in graph} except KeyError: raise TypeError( "The provided graph is missing the {} column, which is " "needed to match the Districtr assignment to the nodes of the graph." ) - assignment = {node: districtr_assignment[node_to_id[node]] for node in graph} + # frm: TODO: Testing: Verify that there is a test for from_districtr_file() + + assignment = {node_id: districtr_assignment[node_to_id[node_id]] for node_id in graph.node_indices} return cls(graph, assignment, updaters) diff --git a/gerrychain/partition/subgraphs.py b/gerrychain/partition/subgraphs.py index b282a510..3062267e 100644 --- a/gerrychain/partition/subgraphs.py +++ b/gerrychain/partition/subgraphs.py @@ -1,7 +1,6 @@ from typing import List, Any, Tuple from ..graph import Graph - class SubgraphView: """ A view for accessing subgraphs of :class:`Graph` objects. diff --git a/gerrychain/proposals/proposals.py b/gerrychain/proposals/proposals.py index 988c7467..005591ea 100644 --- a/gerrychain/proposals/proposals.py +++ b/gerrychain/proposals/proposals.py @@ -111,8 +111,8 @@ def slow_reversible_propose_bi(partition: Partition) -> Partition: :rtype: Partition """ - b_nodes = {x[0] for x in partition["cut_edges"]}.union( - {x[1] for x in partition["cut_edges"]} + b_nodes = {edge[0] for edge in partition["cut_edges"]}.union( + {edge[1] for edge in partition["cut_edges"]} ) flip = random.choice(list(b_nodes)) diff --git a/gerrychain/proposals/spectral_proposals.py b/gerrychain/proposals/spectral_proposals.py index 3c213a94..ee981e56 100644 --- a/gerrychain/proposals/spectral_proposals.py +++ b/gerrychain/proposals/spectral_proposals.py @@ -1,57 +1,95 @@ -import networkx as nx +import networkx as nx # frm: only used to get access to laplacian functions... from numpy import linalg as LA import random from ..graph import Graph from ..partition import Partition from typing import Dict, Optional - +# frm: only ever used in this file - but maybe it is used externally? def spectral_cut( - graph: Graph, part_labels: Dict, weight_type: str, lap_type: str + subgraph: Graph, part_labels: Dict, weight_type: str, lap_type: str ) -> Dict: """ Spectral cut function. - Uses the signs of the elements in the Fiedler vector of a graph to + Uses the signs of the elements in the Fiedler vector of a subgraph to partition into two components. - :param graph: The graph to be partitioned. - :type graph: Graph - :param part_labels: The current partition of the graph. + :param subgraph: The subgraph to be partitioned. + :type subgraph: Graph + :param part_labels: The current partition of the subgraph. :type part_labels: Dict :param weight_type: The type of weight to be used in the Laplacian. :type weight_type: str :param lap_type: The type of Laplacian to be used. 
:type lap_type: str - :returns: A dictionary assigning nodes of the graph to their new districts. + :returns: A dictionary assigning nodes of the subgraph to their new districts. :rtype: Dict """ - nlist = list(graph.nodes()) - n = len(nlist) + # This routine operates on subgraphs, which is important because the node_ids + # in a subgraph are different from the node_ids of the parent graph, so + # the return value's node_ids need to be translated back into the appropriate + # parent node_ids. + + node_list = list(subgraph.node_indices) + num_nodes = len(node_list) if weight_type == "random": - for edge in graph.edge_indices: - graph.edges[edge]["weight"] = random.random() + # assign a random weight to each edge in the subgraph + for edge_id in subgraph.edge_indices: + subgraph.edge_data(edge_id)["weight"] = random.random() + # Compute the desired laplacian matrix (convert from sparse to dense) if lap_type == "normalized": - LAP = (nx.normalized_laplacian_matrix(graph)).todense() - + laplacian_matrix = (subgraph.normalized_laplacian_matrix()).todense() else: - LAP = (nx.laplacian_matrix(graph)).todense() + laplacian_matrix = (subgraph.laplacian_matrix()).todense() + + # frm TODO: Documentation: Add a better explanation for why eigenvectors are useful + # for determining flips. Perhaps just a URL to an article + # somewhere... + # + # I have added comments to describe the nuts and bolts of what is happening, + # but the overall rationale for this code is missing - and it should be here... + - NLMva, NLMve = LA.eigh(LAP) - NFv = NLMve[:, 1] - xNFv = [NFv.item(x) for x in range(n)] + # LA.eigh(laplacian_matrix) call invokes the eigh() function from + # the Numpy LinAlg module which: + # + # "returns the eigenvalues and eigenvectors of a complex Hermitian + # ... or a real symmetrix matrix." + # + # In our case we have a symmetric matrix, so it returns two + # objects - a 1-D numpy array containing the eigenvalues (which we don't + # care about) and a 2-D numpy square matrix of the eigenvectors. + numpy_eigen_values, numpy_eigen_vectors = LA.eigh(laplacian_matrix) - node_color = [xNFv[x] > 0 for x in range(n)] + # Extract an eigenvector as a numpy array + # frm: ???: Not sure why we want just one of them... + numpy_eigen_vector = numpy_eigen_vectors[:, 1] # frm: ??? I think that this is an eigenvector... - clusters = {nlist[x]: part_labels[node_color[x]] for x in range(n)} + # Convert to an array of normal Python numbers (not numpy based) + eigen_vector_array = [numpy_eigen_vector.item(x) for x in range(num_nodes)] - return clusters + # node_color will be True or False depending on whether the value in the + # eigen_vector_array is positive or negative. In the code below, this + # is equivalent to node_color being 1 or 0 (since Python treats True as 1 + # and False as 0) + node_color = [eigen_vector_array[x] > 0 for x in range(num_nodes)] + # Create flips using the node_color to select which part (district) to assign + # to the node. + flips = {node_list[x]: part_labels[node_color[x]] for x in range(num_nodes)} + # translate subgraph node_ids in flips to parent_graph node_ids + translated_flips = subgraph.translate_subgraph_node_ids_for_flips(flips) + + return translated_flips + + +# frm: only ever used in this file - but maybe it is used externally? 
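+# frm: For intuition, the sign-split in spectral_cut() can be reproduced with
+#      nothing but numpy.  The 4-node path graph and its Laplacian below are
+#      made up for the illustration - this is a sketch, not library code:
+#
+#          import numpy as np
+#          from numpy import linalg as LA
+#
+#          # Laplacian of the path graph 0-1-2-3: degree on the diagonal,
+#          # -1 for each edge.
+#          lap = np.array([[ 1, -1,  0,  0],
+#                          [-1,  2, -1,  0],
+#                          [ 0, -1,  2, -1],
+#                          [ 0,  0, -1,  1]])
+#          _, vecs = LA.eigh(lap)     # columns sorted by ascending eigenvalue
+#          fiedler = vecs[:, 1]       # the Fiedler vector (second-smallest)
+#          side = fiedler > 0         # the sign of each entry picks a side
+#
+#      side comes out [True, True, False, False] (or its mirror image), so the
+#      cut falls on the middle edge and splits the path into two equal halves.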
 def spectral_recom(
     partition: Partition,
     weight_type: Optional[str] = None,
@@ -88,16 +126,23 @@ def spectral_recom(
     :rtype: Partition
     """

-    edge = random.choice(tuple(partition["cut_edges"]))
+    # Select two adjacent parts (districts) at random by first selecting
+    # a cut_edge at random and then figuring out the parts (districts)
+    # associated with the edge.
+    cut_edge = random.choice(tuple(partition["cut_edges"]))
     parts_to_merge = (
-        partition.assignment.mapping[edge[0]],
-        partition.assignment.mapping[edge[1]],
+        partition.assignment.mapping[cut_edge[0]],
+        partition.assignment.mapping[cut_edge[1]],
     )

-    subgraph = partition.graph.subgraph(
-        partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]]
-    )
+    subgraph_nodes = partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]]

-    flips = spectral_cut(subgraph, parts_to_merge, weight_type, lap_type)
+    # Cut the set of all nodes from parts_to_merge into two hopefully new parts (districts)
+    flips = spectral_cut(
+        partition.graph.subgraph(subgraph_nodes),
+        parts_to_merge,
+        weight_type,
+        lap_type
+    )

     return partition.flip(flips)
diff --git a/gerrychain/proposals/tree_proposals.py b/gerrychain/proposals/tree_proposals.py
index e66a718b..d6e1568b 100644
--- a/gerrychain/proposals/tree_proposals.py
+++ b/gerrychain/proposals/tree_proposals.py
@@ -7,14 +7,14 @@
     epsilon_tree_bipartition,
     bipartition_tree,
     bipartition_tree_random,
-    _bipartition_tree_random_all,
+    bipartition_tree_random_with_num_cuts,
     uniform_spanning_tree,
     find_balanced_edge_cuts_memoization,
     ReselectException,
 )
 from typing import Callable, Optional, Dict, Union

-
+# frm: only used in this file
 class MetagraphError(Exception):
     """
     Raised when the partition we are trying to split is a low degree
@@ -24,6 +24,7 @@ class MetagraphError(Exception):
     pass


+# frm: only used in this file
 class ValueWarning(UserWarning):
     """
     Raised when a particular value is technically valid, but may
@@ -89,6 +90,7 @@ def recom(
     :type method: Callable, optional

     :returns: The new partition resulting from the ReCom algorithm.
     :rtype: Partition
     """

@@ -99,9 +101,14 @@
     # Try to add the region aware in if the method accepts the surcharge dictionary
     if "region_surcharge" in signature(method).parameters:
         method = partial(method, region_surcharge=region_surcharge)
-
+
     while len(bad_district_pairs) < tot_pairs:
+        # frm: In no particular order, try to merge and then split pairs of districts
+        #      that have a cut_edge - meaning that they are adjacent, until you either
+        #      find one that can be split, or you have tried all possible pairs
+        #      of adjacent districts...
         try:
+            # frm: TODO: Refactoring: see if there is some way to avoid a while True loop...
             while True:
                 edge = random.choice(tuple(partition["cut_edges"]))
                 # Need to sort the tuple so that the order is consistent
@@ -115,12 +122,11 @@
                 if tuple(parts_to_merge) not in bad_district_pairs:
                     break

-            subgraph = partition.graph.subgraph(
-                partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]]
-            )
+            # frm: Note that the vertical bar operator merges the two sets into one set.
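+            # frm: e.g. (sketch): frozenset({1, 2}) | frozenset({3, 4}) == frozenset({1, 2, 3, 4}),
+            #      so subgraph_nodes below holds every node of both districts being merged.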
+ subgraph_nodes = partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]] flips = epsilon_tree_bipartition( - subgraph.graph, + partition.graph.subgraph(subgraph_nodes), parts_to_merge, pop_col=pop_col, pop_target=pop_target, @@ -132,6 +138,7 @@ def recom( except Exception as e: if isinstance(e, ReselectException): + # frm: Add this pair to list of pairs that did not work... bad_district_pairs.add(tuple(parts_to_merge)) continue else: @@ -176,6 +183,7 @@ def reversible_recom( :param balance_edge_fn: The balance edge function. Default is find_balanced_edge_cuts_memoization. :type balance_edge_fn: Callable, optional + frm: it returns a list of Cuts - a named tuple defined in tree.py :param M: The maximum number of balance edges. Default is 1. :type M: int, optional :param repeat_until_valid: Flag indicating whether to repeat until a valid partition is @@ -189,6 +197,7 @@ def reversible_recom( """ def dist_pair_edges(part, a, b): + # frm: Find all edges that cross from district a into district b return set( e for e in part.graph.edges @@ -212,41 +221,101 @@ def bounded_balance_edge_fn(*args, **kwargs): ) return cuts + """ + frm: Original Code: + bipartition_tree_random_reversible = partial( _bipartition_tree_random_all, repeat_until_valid=repeat_until_valid, spanning_tree_fn=uniform_spanning_tree, balance_edge_fn=bounded_balance_edge_fn, ) + + I deemed this code to be evil, if only because it used an internal tree.py routine + _bipartition_tree_random_all(). This internal routine returns a set of Cut objects + which otherwise never appear outside tree.py, so this just adds complexity. + + The only reason the original code used _bipartition_tree_random_all() instead of just + using bipartition_tree_random() is that it needs to know how many possible new + districts there are. So, I created a new function in tree.py that does EXACTLY + what bipartition_tree_random() does but which also returns the number of possible + new districts. + + """ + bipartition_tree_random_reversible = partial( + bipartition_tree_random_with_num_cuts, + repeat_until_valid=repeat_until_valid, + spanning_tree_fn=uniform_spanning_tree, + balance_edge_fn=bounded_balance_edge_fn, + ) parts = sorted(list(partition.parts.keys())) dist_pairs = [] for out_part in parts: for in_part in parts: dist_pairs.append((out_part, in_part)) + # frm: TODO: Code: ???: Grok why this code considers pairs that are the same part... + # + # For instance, if there are only two parts (districts), then this code will + # produce four pairs: (0,0), (0,1), (1,0), (1,1). The code below tests + # to see if there is any adjacency, but there will never be adjacency between + # the same part (district). Why not just prune out all pairs that have the + # same two values and save an interation of the entire chain? + # + # Stated differently, is there any value in doing an entire chain iteration + # when we randomly select the same part (district) to merge with itself??? + # + # A similar issue comes up if there are no pair_edges (below). We waste + # an entire iteration in that case too - which seems kind of dumb... + # random_pair = random.choice(dist_pairs) pair_edges = dist_pair_edges(partition, *random_pair) if random_pair[0] == random_pair[1] or not pair_edges: return partition # self-loop: no adjacency + # frm: TODO: Code: ???: Grok why it is OK to return the partition unchanged as the next step. + # + # This runs the risk of running an entire chain without ever changing the partition. 
+    # I assume that the logic is that there is deliberate randomness introduced each time,
+    # so eventually, if it is possible, the chain will get started, but it seems like there
+    # should be some kind of check to see if it doesn't ever get started, so that the
+    # user can have a clue about what is going on...
+
     edge = random.choice(list(pair_edges))
     parts_to_merge = (
         partition.assignment.mapping[edge[0]],
         partition.assignment.mapping[edge[1]],
     )

-    subgraph = partition.graph.subgraph(
-        partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]]
-    )
-
-    all_cuts = bipartition_tree_random_reversible(
-        subgraph, pop_col=pop_col, pop_target=pop_target, epsilon=epsilon
+    # Remember node_ids from which subgraph was created - we will need them below
+    subgraph_nodes = partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]]
+
+    # frm: Note: This code has changed to make sure we don't access subgraph node_ids.
+    #      The former code saved the subgraph and used its nodes to compute
+    #      the remaining_nodes, but this doesn't work with RX, because the
+    #      node_ids for the subgraph are different from those in the parent graph.
+    #      The solution is to just remember the parent node_ids that were used
+    #      to create the subgraph, and to move the subgraph call in as an actual
+    #      parameter, so that after the call there is no way to reference it.
+    #
+    #      Going forward, this should be a coding style - only invoke Graph.subgraph()
+    #      as an actual parameter so that there is no way to inadvertently access
+    #      the subgraph's node_ids afterwards.
+    #
+
+    result = bipartition_tree_random_reversible(
+        partition.graph.subgraph(subgraph_nodes),
+        pop_col=pop_col, pop_target=pop_target, epsilon=epsilon
     )
-    if not all_cuts:
+    if not result:
         return partition  # self-loop: no balance edge

-    nodes = choice(all_cuts).subset
-    remaining_nodes = set(subgraph.nodes()) - set(nodes)
+    num_possible_districts, nodes = result
+
+    remaining_nodes = subgraph_nodes - set(nodes)

+    # Note: Clever way to create a single dictionary from
+    #       two dictionaries - the ** operator unpacks each dictionary
+    #       and then they get merged into a new dictionary.
     flips = {
         **{node: parts_to_merge[0] for node in nodes},
         **{node: parts_to_merge[1] for node in remaining_nodes},
@@ -255,7 +324,7 @@ def bounded_balance_edge_fn(*args, **kwargs):
     new_part = partition.flip(flips)
     seam_length = len(dist_pair_edges(new_part, *random_pair))

-    prob = len(all_cuts) / (M * seam_length)
+    prob = num_possible_districts / (M * seam_length)
     if prob > 1:
         raise ReversibilityError(
-            f"Found {len(all_cuts)} balance edges, but "
+            f"Found {num_possible_districts} balance edges, but "
@@ -267,6 +336,24 @@
     return partition  # self-loop


+# frm TODO: Refactoring: I do not think that ReCom() is ever called.  Note that it
+#           only defines a constructor and a __call__() which would allow
+#           you to call the recom() function by creating a ReCom object and then
+#           "calling" that object - why not just call the recom function?
+#
+#           ...confused...
+#
+#           My guess is that someone started writing this code thinking that
+#           a class would make sense but then realized that the only use
+#           was to call the recom() function but never went back to remove
+#           the class.  In short, I think that we should probably remove the
+#           class and just keep the function...
+#
+#           What Peter said in a PR:
+#
+#               Another bit of legacy code. I am also not sure why this exists.
Seems like +# there were plans for this and then it got dropped when someone graduated +# class ReCom: """ ReCom (short for ReCombination) is a class that represents a ReCom proposal diff --git a/gerrychain/tree.py b/gerrychain/tree.py index 06ce8433..bf3beb3e 100644 --- a/gerrychain/tree.py +++ b/gerrychain/tree.py @@ -12,9 +12,9 @@ and methods for assessing and modifying this data. - Functions for finding balanced edge cuts in a populated graph, either through contraction or memoization techniques. -- A suite of functions (`bipartition_tree`, `recursive_tree_part`, `get_seed_chunks`, etc.) +- A suite of functions (`bipartition_tree`, `recursive_tree_part`, `_get_seed_chunks`, etc.) for partitioning graphs into balanced subsets based on population targets and tolerances. -- Utility functions like `get_max_prime_factor_less_than` and `recursive_seed_part_inner` +- Utility functions like `get_max_prime_factor_less_than` and `_recursive_seed_part_inner` to assist in complex partitioning tasks. Dependencies: @@ -25,9 +25,62 @@ Last Updated: 25 April 2024 """ +# frm: This file, tree.py, needed to be modified to operate on new Graph +# objects instead of NetworkX Graph objects because the routines are +# used by the Graph objects inside a Partion, which will soon be based +# on RustworkX. More specifically, these routines are used by Proposals, +# and we will soon switch to having the underlying Graph object used +# in Partitions and Proposals be based on RustworkX. +# +# It may be the case that they are ONLY ever used by Proposals and +# hence could just have been rewritten to operate on RustworkX Graph +# objects, but there seemed to be no harm in having them work either +# way. It was also a good proving ground for testing whether the new +# Graph object could behave like a NetworkX Graph object (in terms of +# attribute access and syntax). + +""" +frm: RX Documentation + +Many of the functions in this file operate on subgraphs which are different from +NX subgraphs because the node_ids change in the subgraph. To deal with this, +in graph.py we have a _node_id_to_parent_node_id_map data member for Graph objects which maps +the node_ids in a subgraph to the corresponding node_id in its parent graph. This +will allow routines operating on subgraphs to return results using the node_ids +of the parent graph. + +Note that for top-level graphs, we still define this _node_id_to_parent_node_id_map, but in +this case it is an identity map that just maps each node_id to itself. This allows +code to always translate correctly, even if operating on a top-level graph. + +As a matter of coding convention, all calls to graph.subgraph() have been placed +in the actual parameter list of function calls. This limits the scope of the +subgraph node_ids to the called function - eliminating the risk of those node_ids +leaking into surrounding code. Stated differently, this eliminates the cognitive +load of trying to remember whether a node_id is a parent or a subgraph node_id. +""" import networkx as nx -from networkx.algorithms import tree +import rustworkx as rx +import numpy as np +from scipy.sparse import csr_array +# frm: TODO: Refactoring: Remove import of networkx and rustworkx once we have moved networkx +# dependencies out of this file - see comments below on +# spanning trees. 
+
+import networkx.algorithms.tree as nxtree
+# frm: TODO: Refactoring: Remove import of "tree" from networkx.algorithms in this file
+#            It is only used to get a spanning tree function:
+#
+#                spanning_tree = nxtree.minimum_spanning_tree(
+#
+#            There is an RX function that also computes a spanning tree - hopefully
+#            it works as we want it to work and hence can be used.
+#
+#            I think it probably makes sense to move this spanning tree function
+#            into graph.py and to encapsulate the NX vs RX code there.
+#
+#            Note Peter agrees with this...

 from functools import partial
 from inspect import signature
@@ -48,96 +101,275 @@
 )
 import warnings

+# frm: import the new Graph object which encapsulates NX and RX Graph...
 from .graph import Graph

-def predecessors(h: nx.Graph, root: Any) -> Dict:
-    return {a: b for a, b in nx.bfs_predecessors(h, root)}
-
-
-def successors(h: nx.Graph, root: Any) -> Dict:
-    return {a: b for a, b in nx.bfs_successors(h, root)}

+# frm TODO: Documentation: Update function param documentation to get rid of nx.Graph and use just Graph
+# frm TODO: Documentation: Migration Guide: tree.py is no longer a general purpose module - it is GerryChain specific
+#
+#          Before the work to integrate RX, many of the routines in tree.py
+#          operated on NetworkX Graph objects, which meant that the module
+#          was not bound to just GerryChain work - someone could conceivably
+#          have used it for a graph oriented project that had nothing to do
+#          with GerryChain or redistricting.
+#
+#          That is no longer true, as the parameters to the routines have
+#          been changed to be GerryChain Graph objects which are not subclasses
+#          of NetworkX Graph objects.


 def random_spanning_tree(
-    graph: nx.Graph, region_surcharge: Optional[Dict] = None
-) -> nx.Graph:
+    graph: Graph,
+    region_surcharge: Optional[Dict] = None
+) -> Graph:
     """
     Builds a spanning tree chosen by Kruskal's method using random weights.

-    :param graph: The input graph to build the spanning tree from. Should be a Networkx Graph.
-    :type graph: nx.Graph
+    :param graph: The input graph to build the spanning tree from.
+    :type graph: Graph
     :param region_surcharge: Dictionary of surcharges to add to the random
         weights used in region-aware variants.
     :type region_surcharge: Optional[Dict], optional

-    :returns: The maximal spanning tree represented as a Networkx Graph.
-    :rtype: nx.Graph
+    :returns: The maximal spanning tree represented as a GerryChain Graph.
+    :rtype: Graph
     """
+    # frm: TODO: Performance
+    #      This seems to me to be an expensive way to build a random spanning
+    #      tree.  It calls a routine to compute a "minimal" spanning tree that
+    #      computes the total "weight" of the spanning tree and selects the
+    #      minimal total weight.  By making the weights random, this will select
+    #      a different spanning tree each time.  This works, but it does not
+    #      in any way depend on the optimization.
+    #
+    #      Why isn't the uniform_spanning_tree() below adequate?  It takes
+    #      a random walk at each point to create the spanning tree.  This
+    #      would seem to be a much cheaper way to calculate a spanning tree.
+    #
+    #      What am I missing???
+    #
+    #      The region_surcharge allows the caller to tweak the randomness
+    #      which might be useful...

+    """
+    frm: RX Documentation:
+
+    As far as I can tell a spanning tree is only ever used to populate a PopulatedGraph
+    and so, there is no need to worry about translating the spanning tree's nodes into
+    the context of the parent.
Stated differently, a spanning tree is not used to + compute something about a subgraph but rather to compute something about whatever + graph is currently being dealt with. + + In short, I am assuming that we can ignore the fact that RX subgraphs have different + node_ids for this function and all will be well... + """ + + # frm: TODO: Refactoring: WTF is up with region_surcharge being unset? The region_surcharge + # is only ever accessed in this routine in the for-loop below to + # increase the weight on the edge - setting it to be an empty dict + # just prevents the code below from blowing up. Why not just put + # a test for the surcharge for-loop alone: + # + # if not region_surcharge is None: + # for key, value in region_surcharge.items(): + # ... + # + # Peter's comments from PR: + # + # peterrrock2 last week + # This is one of mine. I added the region surcharge stuff in an afternoon, + # so I probably did this to prevent the more than 3 levels of indentation + # and to make the reasoning easier to track as I was adding the feature. + # + # Collaborator + # Author + # @peterrrock2 peterrrock2 last week + # Also, I imagine that I originally wanted the function modification to look like + # + # def random_spanning_tree( + # graph: Graph, + # region_surcharge: dict = dict() + # ) -> Graph: + # + # but doing this sort of thing is generally a bad idea in python since the + # dict() is instantiated at import time and then all future calls to the + # function reference the same dict when the surcharge is unset. Not a problem + # for this function, but the accepted best-practice is to change the above to + # + # def random_spanning_tree( + # graph: Graph, + # region_surcharge: Optional[Dict] = None + # ) -> Graph: + # if region_surcharge is None: + # region_surcharge = dict() + # + # since this doesn't reuse the reference. + if region_surcharge is None: region_surcharge = dict() - for edge in graph.edges(): + # Add a random weight to each edge in the graph with the goal of + # causing the selection of a different (random) spanning tree based + # on those weights. + # + # If a region_surcharge was passed in, then we want to add additional + # weight to edges that cross regions or that have a node that is + # not in any region. For example, if we want to keep municipalities + # together in the same district, the region_surcharge would contain + # an additional weight associated with the key for municipalities (say + # "mini") and if an edge went from one municipality to another or if + # either of the nodes in the edge were not in a municipality, then + # the edge would be given the additional weight (value) associated + # with the region_surcharge. This would preference/bias the + # spanning_tree algorithm to select other edges... which would have + # the effect of prioritizing keeping regions intact. + + # frm: TODO: Documentation: Verify that the comment above about region_surcharge is accurate + + # Add random weights to the edges in the graph so that the spanning tree + # algorithm will select a different spanning tree each time. 
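+    # frm: For example (illustrative keys and values only - not taken from the
+    #      library), a region-aware chain might pass:
+    #
+    #          region_surcharge = {"muni": 0.2, "county": 0.8}
+    #
+    #      An edge whose endpoints disagree on (or lack) a "muni" value picks up
+    #      an extra 0.2, and likewise 0.8 for "county", so such edges are less
+    #      likely to end up in the minimum spanning tree - biasing the tree, and
+    #      therefore the eventual cut, toward keeping those regions intact.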
+ # + for edge_id in graph.edge_indices: + edge = graph.get_edge_from_edge_id(edge_id) weight = random.random() + + # If there are any entries in the region_surcharge dict, then add + # additional weight to the edge for 1) edges that cross region boundaries (one + # node is in one region and the other node is in a different region) and 2) edges + # where one (or both) of the nodes is not in a region for key, value in region_surcharge.items(): # We surcharge edges that cross regions and those that are not in any region if ( - graph.nodes[edge[0]][key] != graph.nodes[edge[1]][key] - or graph.nodes[edge[0]][key] is None - or graph.nodes[edge[1]][key] is None + graph.node_data(edge[0])[key] != graph.node_data(edge[1])[key] + or graph.node_data(edge[0])[key] is None + or graph.node_data(edge[1])[key] is None ): weight += value - graph.edges[edge]["random_weight"] = weight - - spanning_tree = tree.minimum_spanning_tree( - graph, algorithm="kruskal", weight="random_weight" - ) - return spanning_tree + graph.edge_data(edge_id)["random_weight"] = weight + + # frm: TODO: Refactoring: Code: CROCK: (for the moment) + # We need to create a minimum spanning tree but the way to do so + # is different for NX and RX. I am sure that there is a more elegant + # way to do this, and in any event, this dependence on NX vs RX + # should not be in this file, tree.py, but for now, I am just trying + # to get this to work, so I am using CROCKS... + + graph.verify_graph_is_valid() + + # frm: TODO: Refactoring: Remove NX / RX dependency - maybe move to graph.py + + # frm: TODO: Documentation: Think a bit about original_nx_node_ids + # + # Original node_ids refer to the node_ids used when a graph was created. + # This mostly means remembering the NX node_ids when you create an RX + # based Graph object. In the code below, we create an RX based Graph + # object, but we do not do anything to map original node_ids. This is + # probably OK, but it depends on how the spanning tree is used elsewhere. + # + # In short, worth some thought... + + if (graph.is_nx_graph()): + nx_graph = graph.get_nx_graph() + spanning_tree = nxtree.minimum_spanning_tree( + nx_graph, algorithm="kruskal", weight="random_weight" + ) + spanningGraph = Graph.from_networkx(spanning_tree) + elif (graph.is_rx_graph()): + rx_graph = graph.get_rx_graph() + def get_weight(edge_data): + # function to get the weight of an edge from its data + # This function is passed a dict with the data for the edge. + return edge_data["random_weight"] + spanning_tree = rx.minimum_spanning_tree(rx_graph, get_weight) + spanningGraph = Graph.from_rustworkx(spanning_tree) + else: + raise Exception("random_spanning_tree - bad kind of graph object") + return spanningGraph def uniform_spanning_tree( - graph: nx.Graph, choice: Callable = random.choice -) -> nx.Graph: + graph: Graph, + choice: Callable = random.choice +) -> Graph: """ Builds a spanning tree chosen uniformly from the space of all spanning trees of the graph. Uses Wilson's algorithm. - :param graph: Networkx Graph - :type graph: nx.Graph + :param graph: Graph + :type graph: Graph :param choice: :func:`random.choice`. Defaults to :func:`random.choice`. :type choice: Callable, optional :returns: A spanning tree of the graph chosen uniformly at random. 
- :rtype: nx.Graph + :rtype: Graph """ - root = choice(list(graph.node_indices)) - tree_nodes = set([root]) - next_node = {root: None} - - for node in graph.node_indices: - u = node + + """ + frm: RX Docmentation: + + As with random_spanning_tree, I am assuming that the issue of RX subgraphs having + different node_ids is not an issue for this routine... + """ + # Pick a starting point at random + root_id = choice(list(graph.node_indices)) + tree_nodes = set([root_id]) + next_node_id = {root_id: None} + + # frm: I think that this builds a tree bottom up. It takes + # every node in the graph (in sequence). If the node + # is already in the list of nodes that have been seen + # which means it has a neighbor registered as a next_node, + # then it is skipped. If this node does not yet have + # a neighbor registered, then it is given one, and + # that neighbor becomes the next node looked at. + # + # This essentially takes a node and travels "up" until + # it finds a node that is already in the tree. Multiple + # nodes can end up with the same "next_node" - which + # in tree-speak means that next_node is the parent of + # all of the nodes that end on it. + + for node_id in graph.node_indices: + u = node_id while u not in tree_nodes: - next_node[u] = choice(list(graph.neighbors(u))) - u = next_node[u] + next_node_id[u] = choice(list(graph.neighbors(u))) + u = next_node_id[u] - u = node + u = node_id while u not in tree_nodes: tree_nodes.add(u) - u = next_node[u] + u = next_node_id[u] + + # frm DONE: To support RX, I added an add_edge() method to Graph. + + # frm: TODO: Refactoring: Remove dependency on NX below + + nx_graph = nx.Graph() + G = Graph.from_networkx(nx_graph) - G = nx.Graph() - for node in tree_nodes: - if next_node[node] is not None: - G.add_edge(node, next_node[node]) + for node_id in tree_nodes: + if next_node_id[node_id] is not None: + G.add_edge(node_id, next_node_id[node_id]) return G +# frm TODO: Documentation: PopulatedGraph - state that this only exists in tree.py +# +# I think that this is only ever used inside this module (except) +# for testing. +# +# Decide if this is intended to only ever be used inside tree.py (and for testing), +# and if so: 1) document that fact and 2) see if there is any Pythonic convention +# for a class that is intended to NOT be used externally (like a leading underscore) +# class PopulatedGraph: """ A class representing a graph with population information. :ivar graph: The underlying graph structure. - :type graph: nx.Graph + :type graph: Graph :ivar subsets: A dictionary mapping nodes to their subsets. :type subsets: Dict :ivar population: A dictionary mapping nodes to their populations. @@ -153,14 +385,14 @@ class PopulatedGraph: def __init__( self, - graph: nx.Graph, + graph: Graph, populations: Dict, ideal_pop: Union[float, int], epsilon: float, ) -> None: """ :param graph: The underlying graph structure. - :type graph: nx.Graph + :type graph: Graph :param populations: A dictionary mapping nodes to their populations. :type populations: Dict :param ideal_pop: The ideal population for each district. 
@@ -170,15 +402,40 @@ def __init__( :type epsilon: float """ self.graph = graph - self.subsets = {node: {node} for node in graph.nodes} + self.subsets = {node_id: {node_id} for node_id in graph.node_indices} self.population = populations.copy() self.tot_pop = sum(self.population.values()) self.ideal_pop = ideal_pop self.epsilon = epsilon - self._degrees = {node: graph.degree(node) for node in graph.nodes} + self._degrees = {node_id: graph.degree(node_id) for node_id in graph.node_indices} + + # frm: TODO: Refactor: _degrees ??? Why separately store the degree of every node? + # + # The _degrees data member above is used to define a method below called "degree()" + # What is odd is that the implementation of this degree() method could just as + # easily have been self.graph.degree(node_id). And in fact, every call on the + # new degree function could be replaced with just .graph.degree(node_id) + # + # So unless there is a big performace gain (or some other reason), I would be + # in favor of deleting the degree() method below and just using + # .graph.degree(node_id) on the assumption that both NX and RX + # have an efficient implementation of degree()... + def __iter__(self): - return iter(self.graph) + # Note: in the pre RustworkX code, this was implemented as: + # + # return iter(self.graph) + # + # But RustworkX does not support __iter__() - it is not iterable. + # + # The way to do this in the new RustworkX based code is to use + # the node_indices() method which is accessed as a property as in: + # + # for node_id in graph.node_indices: + # ...do something with the node_id + # + raise NotImplementedError("Graph is not iterable - use graph.node_indices instead") def degree(self, node) -> int: return self._degrees[node] @@ -188,6 +445,8 @@ def contract_node(self, node, parent) -> None: self.subsets[parent] |= self.subsets[node] self._degrees[parent] -= 1 + # frm: only ever used inside this file + # But maybe this is intended to be used externally... def has_ideal_population(self, node, one_sided_cut: bool = False) -> bool: """ Checks if a node has an ideal population within the graph up to epsilon. @@ -204,6 +463,23 @@ def has_ideal_population(self, node, one_sided_cut: bool = False) -> bool: :returns: True if the node has an ideal population within the graph up to epsilon. :rtype: bool """ + + # frm: TODO: Refactoring: Create a helper function for this + # + # This logic is repeated several times in this file. Consider + # refactoring the code so that the logic lives in exactly + # one place. + # + # When thinking about refactoring, consider whether it makes + # sense to toggle what this routine does by the "one_sided_cut" + # parameter. Why not have two separate routines with + # similar but distinguishing names. I need to be absolutely + # clear about what the two cases are all about, but my current + # hypothesis is that when one_sided_cut == False, we are looking + # for the edge which when cut produces two districts of + # approximately equal size - so a bisect rather than a find all + # meaning... 
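+        # frm: Roughly (a sketch of the two checks in the code below, writing
+        #      pop for self.population[node] and ideal for self.ideal_pop):
+        #
+        #          one_sided_cut=True:   the subtree below the node is itself
+        #                                close to one ideal district, i.e.
+        #                                abs(pop - ideal) <= epsilon * ideal
+        #
+        #          one_sided_cut=False:  the subtree AND its complement must both
+        #                                be close to ideal, i.e. the cut bisects
+        #                                the region into two valid districts.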
+ if one_sided_cut: return ( abs(self.population[node] - self.ideal_pop) @@ -218,7 +494,7 @@ def has_ideal_population(self, node, one_sided_cut: bool = False) -> bool: def __repr__(self) -> str: graph_info = ( - f"Graph(nodes={len(self.graph.nodes)}, edges={len(self.graph.edges)})" + f"Graph(nodes={len(self.graph.node_indices)}, edges={len(self.graph.edges)})" ) return ( f"{self.__class__.__name__}(" @@ -229,6 +505,9 @@ def __repr__(self) -> str: ) +# frm: ???: Is a Cut used anywhere outside this file? + +# Definition of Cut namedtuple # Tuple that is used in the find_balanced_edge_cuts function Cut = namedtuple("Cut", "edge weight subset") Cut.__new__.__defaults__ = (None, None, None) @@ -239,6 +518,26 @@ def __repr__(self) -> str: "The (frozen) subset of nodes on one side of the cut. Defaults to None." ) +# frm: TODO: Documentation: Document what Cut objects are used for +# +# Not sure how this is used, and so I do not know whether it needs +# to translate node_ids to the parent_node_id context. I am assuming not... +# +# Here is an example of how it is used (in test_tree.py): +# +# method=partial( +# bipartition_tree, +# max_attempts=10000, +# balance_edge_fn=find_balanced_edge_cuts_contraction, +# +# and another in the same test file: +# +# populated_tree = PopulatedGraph( +# tree, {node: 1 for node in tree}, len(tree) / 2, 0.5 +# ) +# cuts = find_balanced_edge_cuts_contraction(populated_tree) + + def find_balanced_edge_cuts_contraction( h: PopulatedGraph, one_sided_cut: bool = False, choice: Callable = random.choice @@ -261,27 +560,54 @@ def find_balanced_edge_cuts_contraction( :rtype: List[Cut] """ - root = choice([x for x in h if h.degree(x) > 1]) + root = choice([node_id for node_id in h.graph.node_indices if h.degree(node_id) > 1]) # BFS predecessors for iteratively contracting leaves - pred = predecessors(h.graph, root) + pred = h.graph.predecessors(root) cuts = [] - leaves = deque(x for x in h if h.degree(x) == 1) + + # frm: Work up from leaf nodes to find subtrees with the "correct" + # population. The algorighm starts with real leaf nodes, but + # if a node does not have the "correct" population, then that + # node is merged (contracted) into its parent, effectively + # creating another leaf node which is then added to the end + # of the queue. + # + # In this way, we calculate the total population of subtrees + # by going bottom up, until we find a subtree that has the + # "correct" population for a cut. + + # frm: ??? Note that there is at least one other routine in this file + # that does something similar (perhaps exactly the same). + # Need to figure out why there are more than one way to do this... + + leaves = deque(node_id for node_id in h.graph.node_indices if h.degree(node_id) == 1) while len(leaves) > 0: leaf = leaves.popleft() if h.has_ideal_population(leaf, one_sided_cut=one_sided_cut): + # frm: If the population of the subtree rooted in this node is the correct + # size, then add it to the cut list. Note that if one_sided_cut == False, + # then the cut means the cut bisects the partition (frm: ??? need to verify this). 
e = (leaf, pred[leaf]) cuts.append( Cut( edge=e, - weight=h.graph.edges[e].get("random_weight", random.random()), + weight=h.graph.edge_data( + h.graph.get_edge_id_from_edge(e) + ).get("random_weight", random.random()), subset=frozenset(h.subsets[leaf].copy()), ) ) - # Contract the leaf: + # Contract the leaf: frm: merge the leaf's population into the parent and add the parent to "leaves" parent = pred[leaf] + # frm: Add child population and subsets to parent, reduce parent's degree by 1 + # This effectively removes the leaf from the tree, adding all of its data + # to the parent. h.contract_node(leaf, parent) if h.degree(parent) == 1 and parent != root: + # frm: Only add the parent to the end of the queue when we are merging + # the last leaf - this makes sure we only add the parent node to + # the queue one time... leaves.append(parent) return cuts @@ -301,6 +627,18 @@ def _calc_pops(succ, root, h): :returns: A dictionary mapping nodes to their subtree populations. :rtype: Dict """ + # frm: This took me a while to sort out what was going on. + # Conceptually it is easy - given a tree anchored in a root node, + # calculate the population in each subtree going bottom-up. + # The stack (deque) provides the mechanism for going bottom-up. + # On the way down, you just put nodes in the stack (append is like + # push() which seems odd to me, but whatever...) then on the way back + # up, you add the totals for each child to your own population and + # presto you have the total population for each subtree... + # + # For this to work, you just need to have a list of nodes with + # their successors associated with them... + # subtree_pops: Dict[Any, Union[int, float]] = {} stack = deque(n for n in succ[root]) while stack: @@ -322,6 +660,7 @@ def _calc_pops(succ, root, h): return subtree_pops +# frm: Only used in one function and only in this module... def _part_nodes(start, succ): """ Partitions the nodes of a graph into two sets. @@ -335,6 +674,39 @@ def _part_nodes(start, succ): :returns: A set of nodes for a particular district (only one side of the cut). :rtype: Set """ + + """ + frm: Compute the nodes in a subtree defined by a Cut. + + This routine computes the set of nodes in a subtree rooted in the + node identified by "start" in the tree defined by "succ". + + As such it is highly dependent on context and is not generally + useful. That is, it is essentially just a way to refactor some + code used in a couple of places so that the logic in the code is + in one place instead of several. + + To be specific, Cuts are always relative to a specific tree for + a partition. This tree is a "spanning tree" that converts the + graph into a DAG. Cuts are then computed by finding subtrees + of that DAG that have the appropriate population (this could + presumably be modified to include other factors). + + When a Cut is created, we want to collect all of the nodes that + are in the subtree, and this is what this routine does. It + merely starts at the root of the subtree (start) and goes down + the subtree, adding each node to a set. + + frm: TODO: Documentation: Rename this to be more descriptive - perhaps ] + something like: _nodes_in_subtree() or + _nodes_for_cut() + + frm: TODO: Documentation: Add the above explanation for what a Cut is and how + we find them by converting the graph to a DAG and + then looking for subtrees to a block header at the + top of this file. It will give the reader some + idea wtf is going on... 
;-)
    """
    nodes = set()
    queue = deque([start])
    while queue:
        next_node = queue.pop()
        if next_node not in nodes:
            nodes.add(next_node)
            for c in succ[next_node]:
                if c not in nodes:
                    queue.append(c)
    return nodes

-
+#frm: used externally by tree_proposals.py
 def find_balanced_edge_cuts_memoization(
     h: PopulatedGraph, one_sided_cut: bool = False, choice: Callable = random.choice
 ) -> List[Cut]:
     """
@@ -373,12 +745,38 @@
     :returns: A list of balanced edge cuts.
     :rtype: List[Cut]
     """
+
+    """
+    frm: ???: confused...
+
+    This function seems to be used for two very different purposes, depending on the
+    value of the parameter, one_sided_cut.  When true, the code looks for lots of cuts
+    that would create a district with the right population - both above and below the
+    node being considered.  Given that it is operating on a tree, one would assume that
+    there is only one (or perhaps two if one node's population was tiny) cut for the top
+    of the tree, but there should be many for the bottom of the tree.
+
+    However, if the parameter is set to false (the default), then the code checks to see
+    whether a cut would produce two districts - one above and one below the tree - that
+    have the right populations.  In this case, the code is presumably looking for the
+    single node (again there might be two if one node's population was way below epsilon)
+    that would bisect the graph into two districts with a tolerable population.
+
+    If I am correct, then there is an opportunity to clarify these two uses - perhaps
+    with wrapper functions.  I am also a bit surprised that snippets of code are repeated.
+    Again - this causes mental load for the reader, and it is an opportunity for bugs to
+    creep in later (you fix it in one place but not the other).  Not sure this "clarification"
+    is desired, but it is worth considering...
+    """
+
+    # frm: ???: Why does a root have to have degree > 1?  I would think that any node would do...

-    root = choice([x for x in h if h.degree(x) > 1])
-    pred = predecessors(h.graph, root)
-    succ = successors(h.graph, root)
+    root = choice([node_id for node_id in h.graph.node_indices if h.degree(node_id) > 1])
+    pred = h.graph.predecessors(root)
+    succ = h.graph.successors(root)
     total_pop = h.tot_pop

+    # Calculate the population of each subtree in the "succ" tree
     subtree_pops = _calc_pops(succ, root, h)

     cuts = []
@@ -386,44 +784,70 @@
     if one_sided_cut:
         for node, tree_pop in subtree_pops.items():
             if abs(tree_pop - h.ideal_pop) <= h.ideal_pop * h.epsilon:
-                e = (node, pred[node])
+                # frm: If the subtree for this node has a population within epsilon
+                #      of the ideal, then add it to the cuts list.
+                e = (node, pred[node])  # get the edge from the parent to this node
                 wt = random.random()
+                # frm: Add the cut - set its weight if it does not already have one
+                #      and remember all of the nodes in the subtree in the frozenset
                 cuts.append(
                     Cut(
                         edge=e,
-                        weight=h.graph.edges[e].get("random_weight", wt),
+                        weight=h.graph.edge_data(
+                            h.graph.get_edge_id_from_edge(e)
+                        ).get("random_weight", wt),
                         subset=frozenset(_part_nodes(node, succ)),
                     )
                 )
             elif abs((total_pop - tree_pop) - h.ideal_pop) <= h.ideal_pop * h.epsilon:
+                # frm: If the population of everything ABOVE this node in the tree is
+                #      within epsilon of the ideal, then add it to the cut list too.
e = (node, pred[node]) wt = random.random() cuts.append( Cut( edge=e, - weight=h.graph.edges[e].get("random_weight", wt), - subset=frozenset(set(h.graph.nodes) - _part_nodes(node, succ)), + weight=h.graph.edge_data( + h.graph.get_edge_id_from_edge(e) + ).get("random_weight", wt), + subset=frozenset(set(h.graph.node_indices) - _part_nodes(node, succ)), ) ) return cuts + # frm: TODO: Refactoring: this code to make its two use cases clearer: + # + # One use case is bisecting the graph (one_sided_cut is False). The + # other use case is to peel off one part (district) with the appropriate + # population. + # + # Not quite clear yet exactly how to do this, but a return stmt in the middle + # of the routine (above) is a clear sign that something is odd. Perhaps + # we keep the existing function signature but immediately split the code + # into calls on two separate routines - one for each use case. + + # We are looking for a way to bisect the graph (one_sided_cut is False) for node, tree_pop in subtree_pops.items(): + if (abs(tree_pop - h.ideal_pop) <= h.ideal_pop * h.epsilon) and ( abs((total_pop - tree_pop) - h.ideal_pop) <= h.ideal_pop * h.epsilon ): e = (node, pred[node]) wt = random.random() + # frm: TODO: Performance: Think if code below can be made faster... cuts.append( Cut( edge=e, - weight=h.graph.edges[e].get("random_weight", wt), - subset=frozenset(set(h.graph.nodes) - _part_nodes(node, succ)), + weight=h.graph.edge_data( + h.graph.get_edge_id_from_edge(e) + ).get("random_weight", wt), + subset=frozenset(set(h.graph.node_indices) - _part_nodes(node, succ)), ) ) return cuts - +# frm: only used in this file and in a test class BipartitionWarning(UserWarning): """ Generally raised when it is proving difficult to find a balanced cut. @@ -431,7 +855,7 @@ class BipartitionWarning(UserWarning): pass - +# frm: only used in this file and in a test class ReselectException(Exception): """ Raised when the tree-splitting algorithm is unable to find a @@ -477,9 +901,22 @@ def _max_weight_choice(cut_edge_list: List[Cut]) -> Cut: if not isinstance(cut_edge_list[0], Cut) or cut_edge_list[0].weight is None: return random.choice(cut_edge_list) + # frm: ???: this strikes me as possibly expensive. Computing the + # max in a list is O(N) so not terrible, but this + # might be called lots of times (need to know more about + # how it is used). Would it make sense to have the + # cut_edge_list sorted before it is frozen? I think it + # is now a set, so it would need to be a list... Not + # urgent, but worth looking into at some point... + # return max(cut_edge_list, key=lambda cut: cut.weight) +# frm: TODO: Documentation: document what _power_set_sorted_by_size_then_sum() does +# +# Figure out what this does. There is no NX/RX issue here, I just +# don't yet know what it does or why... +# Note that this is only ever used once... def _power_set_sorted_by_size_then_sum(d): power_set = [ s for i in range(1, len(d) + 1) for s in itertools.combinations(d.keys(), i) @@ -501,6 +938,8 @@ def _power_set_sorted_by_size_then_sum(d): def _region_preferred_max_weight_choice( populated_graph: PopulatedGraph, region_surcharge: Dict, cut_edge_list: List[Cut] ) -> Cut: + # frm: ???: There is no NX/RX dependency in this routine, but I do + # not yet understand what it does or why... """ This function is used in the case of a region-aware chain. 
It is similar to the :meth:`_max_weight_choice` function except @@ -551,9 +990,24 @@ def _region_preferred_max_weight_choice( # Prepare data for efficient access edge_region_info = { cut: { + # frm: This code is a bit dense (at least for me). + # Given a cut_edge_list (whose elements have an + # "edge" attribute) construct a dict + # that associates with each "cut" the + # values that both nodes in the edge have + # for each region_surcharge key. + # + # So, if the region_surcharge dict was + # {"muni": 0.2, "water": 0.8} then for + # each cut, cut_n, there would be a + # dict value that looked like: + # {"muni": ("siteA", "siteA"), + # "water": ("water1", "water2")} + # key: ( - populated_graph.graph.nodes[cut.edge[0]].get(key), - populated_graph.graph.nodes[cut.edge[1]].get(key), + populated_graph.graph.node_data(cut.edge[0]).get(key), + populated_graph.graph.node_data(cut.edge[1]).get(key), ) for key in region_surcharge } @@ -577,14 +1031,43 @@ return _max_weight_choice(cut_edge_list) +# frm: TODO: Refactoring: def bipartition_tree( +# +# This might get complicated depending on what kinds of functions +# are used as parameters. That is, do the functions used as parameters +# assume they are working with an NX graph? +# +# I think all of the functions used as parameters have been converted +# to work on the new Graph object, but perhaps end users have created +# their own? Should probably add logic to verify that the +# functions are not written to be operating on an NX Graph. Not sure +# how to do that though... +# +# Peter's comments from PR: +# +# Users do sometimes write custom spanning tree and cut edge functions. My +# recommendation would be to make this simple for now. Have a list of "RX_compatible" +# functions and then have the MarkovChain class do some coersion to store an +# appropriate graph and partition object at initialization. We always expect +# the workflow to be something like +# +# Graph -> Partition -> MarkovChain +# +# But we do copy operations in each step, so I wouldn't expect any weird +# side-effects from pushing the determination of what graph type to use +# off onto the MarkovChain class + +# frm: used in this file and in tree_proposals.py +# But maybe this is intended to be used externally... +# def bipartition_tree( - graph: nx.Graph, + subgraph_to_split: Graph, pop_col: str, pop_target: Union[int, float], epsilon: float, node_repeats: int = 1, - spanning_tree: Optional[nx.Graph] = None, + spanning_tree: Optional[Graph] = None, spanning_tree_fn: Callable = random_spanning_tree, region_surcharge: Optional[Dict] = None, balance_edge_fn: Callable = find_balanced_edge_cuts_memoization, @@ -595,6 +1078,9 @@ def bipartition_tree( allow_pair_reselection: bool = False, cut_choice: Callable = _region_preferred_max_weight_choice, ) -> Set: + # frm: TODO: Refactoring: Change the names of ALL function formal parameters to end in "_fn" - to make it clear + # that the parameter is a function. This will make it easier to do a global search + # to find all function parameters - as well as just being good coding practice... """ This function finds a balanced 2 partition of a graph by drawing a spanning tree and finding an edge to cut that leaves at most an epsilon @@ -605,7 +1091,7 @@ is ``epsilon * pop_target`` away from ``pop_target``. :param graph: The graph to partition. - :type graph: nx.Graph + :type graph: Graph :param pop_col: The node attribute holding the population of each node.
:type pop_col: str :param pop_target: The target population for the returned subset of nodes. @@ -618,7 +1104,7 @@ :type node_repeats: int, optional :param spanning_tree: The spanning tree for the algorithm to use (used when the algorithm chooses a new root and for testing). - :type spanning_tree: Optional[nx.Graph], optional + :type spanning_tree: Optional[Graph], optional :param spanning_tree_fn: The random spanning tree algorithm to use if a spanning tree is not provided. Defaults to :func:`random_spanning_tree`. :type spanning_tree_fn: Callable, optional @@ -661,40 +1147,112 @@ given by ``max_attempts``. """ # Try to add the region-aware in if the spanning_tree_fn accepts a surcharge dictionary + # frm: ???: REALLY??? You are going to change the semantics of your program based on + # a function argument's signature? What if someone refactors the code to have + # different names??? *sigh* + # + # A better strategy would be to lock in the function signature for ALL spanning_tree + # functions and then just have the region_surcharge parameter not be used in some of them... + # + # Same with "one_sided_cut" + # + # Oh - and change "one_sided_cut" to be something a little more intuitive. I have to + # reset my mind every time I see it to figure out whether it means to split into + # two districts or just peel off one district... *sigh* Before doing this, check to + # see if "one_sided_cut" is a term of art that might make sense to some set of experts... + # if "region_surcharge" in signature(spanning_tree_fn).parameters: spanning_tree_fn = partial(spanning_tree_fn, region_surcharge=region_surcharge) if "one_sided_cut" in signature(balance_edge_fn).parameters: balance_edge_fn = partial(balance_edge_fn, one_sided_cut=one_sided_cut) - populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices} + # dict of node_id: population for the nodes in the subgraph + populations = { + node_id: subgraph_to_split.node_data(node_id)[pop_col] + for node_id in subgraph_to_split.node_indices + } + + # frm: TODO: Debugging: Remove debugging code + # print(" ") + # print(f"bipartition_tree(): Entering...") + # print(f"bipartition_tree(): balance_edge_fn is: {balance_edge_fn}") + # print(f"bipartition_tree(): spanning_tree_fn is: {spanning_tree_fn}") + # print(f"bipartition_tree(): populations in subgraph are: {populations}") possible_cuts: List[Cut] = [] if spanning_tree is None: - spanning_tree = spanning_tree_fn(graph) + spanning_tree = spanning_tree_fn(subgraph_to_split) + + # print(" ") + # print(f"bipartition_tree(): subgraph edges: {subgraph_to_split.edges}") + # print(f"bipartition_tree(): initial spanning_tree edges: {spanning_tree.edges}") restarts = 0 attempts = 0 while max_attempts is None or attempts < max_attempts: if restarts == node_repeats: - spanning_tree = spanning_tree_fn(graph) + spanning_tree = spanning_tree_fn(subgraph_to_split) + # print(f"bipartition_tree(): new spanning_tree edges: {spanning_tree.edges}") restarts = 0 h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon) + # frm: TODO: Refactoring: Again - we should NOT be changing semantics based + # on the names in signatures... + # A better approach is to have all of the possible parameters exist + # in ALL of the versions of the cut_choice() functions and to + # have them default to None if not used by one of the functions. + # Then this code could just pass in the values to the + # cut_choice function, and it could make sense of what to do.
+ # + # This makes it clear what the overall and comprehensive purpose + # of cut_choice functions is. This centralizes the knowledge + # of what a cut_choice() function is supposed to do - or at least + # it prompts the programmer to document that a param in the + # general scheme does not apply in a given instance. + # + # I realize that this is perhaps not "pythonic" - in that it + # forces the programmer to document overall behavior instead + # of just finding a convenient way to sneak in something new. + # However, when code gets complicated, sneaky/clever code + # is just not worth it - better to have each change be a little + # more painful (needing to change the function signature for + # all instances of a generic function to add new functionality + # that is only needed by one new instance). This provides + # a natural place (in comments of the generic function instances) + # to describe what is going on - and it alerts programmers + # that a given generic function has perhaps many different + # instances - but that they all share the same high level + # responsibility. + is_region_cut = ( "region_surcharge" in signature(cut_choice).parameters and "populated_graph" in signature(cut_choice).parameters ) + # frm: Find one or more edges in the spanning tree that, if cut, would + # result in a subtree with the appropriate population. + # This returns a list of Cut objects with attributes edge and subset possible_cuts = balance_edge_fn(h, choice=choice) + # frm: TODO: Debugging: Remove debugging code below + # print(f"bipartition_tree(): possible_cuts = {possible_cuts}") + + # frm: RX Subgraph if len(possible_cuts) != 0: + chosen_cut = None if is_region_cut: - return cut_choice(h, region_surcharge, possible_cuts).subset - - return cut_choice(possible_cuts).subset + chosen_cut = cut_choice(h, region_surcharge, possible_cuts) + else: + chosen_cut = cut_choice(possible_cuts) + translated_nodes = subgraph_to_split.translate_subgraph_node_ids_for_set_of_nodes( + chosen_cut.subset ) + # print(f"bipartition_tree(): translated_nodes = {translated_nodes}") + # frm: Not sure if it is important that the returned set be a frozenset... + return frozenset(translated_nodes) restarts += 1 attempts += 1 @@ -708,7 +1266,7 @@ "a different pair of districts for recombination.", BipartitionWarning, ) - + if allow_pair_reselection: raise ReselectException( f"Failed to find a balanced cut after {max_attempts} attempts.\n" @@ -717,25 +1275,32 @@ raise RuntimeError(f"Could not find a possible cut after {max_attempts} attempts.") - def _bipartition_tree_random_all( - graph: nx.Graph, + # + # Note: Complexity Alert... _bipartition_tree_random_all does NOT translate node_ids to parent + # + # Unlike many/most of the routines in this module, _bipartition_tree_random_all() does + # not translate node_ids into the IDs of the parent, because calls to it are not made + # on subgraphs. That is, it returns possible Cuts using the same node_ids as the parent. + # It is up to the caller to translate node_ids (if appropriate).
+ # + graph_to_split: Graph, pop_col: str, pop_target: Union[int, float], epsilon: float, node_repeats: int = 1, repeat_until_valid: bool = True, - spanning_tree: Optional[nx.Graph] = None, + spanning_tree: Optional[Graph] = None, spanning_tree_fn: Callable = random_spanning_tree, balance_edge_fn: Callable = find_balanced_edge_cuts_memoization, choice: Callable = random.choice, max_attempts: Optional[int] = 100000, -) -> List[Tuple[Hashable, Hashable]]: +) -> List[Tuple[Hashable, Hashable]]: # frm: TODO: Documentation: Change this to be a set of node_ids (ints) """ Randomly bipartitions a tree into two subgraphs until a valid bipartition is found. :param graph: The input graph. - :type graph: nx.Graph + :type graph: Graph :param pop_col: The name of the column in the graph nodes that contains the population data. :type pop_col: str :param pop_target: The target population for each subgraph. @@ -750,7 +1315,7 @@ :type repeat_until_valid: bool, optional :param spanning_tree: The spanning tree to use for bipartitioning. If None, a random spanning tree will be generated. Defaults to None. - :type spanning_tree: Optional[nx.Graph], optional + :type spanning_tree: Optional[Graph], optional :param spanning_tree_fn: The function to generate a spanning tree. Defaults to random_spanning_tree. :type spanning_tree_fn: Callable, optional @@ -770,18 +1335,22 @@ attempts. """ - populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices} + # dict of node_id: population for the nodes in the subgraph + populations = { + node_id: graph_to_split.node_data(node_id)[pop_col] + for node_id in graph_to_split.node_indices + } possible_cuts = [] if spanning_tree is None: - spanning_tree = spanning_tree_fn(graph) + spanning_tree = spanning_tree_fn(graph_to_split) restarts = 0 attempts = 0 while max_attempts is None or attempts < max_attempts: if restarts == node_repeats: - spanning_tree = spanning_tree_fn(graph) + spanning_tree = spanning_tree_fn(graph_to_split) restarts = 0 h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon) possible_cuts = balance_edge_fn(h, choice=choice) @@ -794,15 +1363,128 @@ raise RuntimeError(f"Could not find a possible cut after {max_attempts} attempts.") +# frm: used in this file and in tree_proposals.py +# But maybe this is intended to be used externally... + +####################### +# frm: Note: This routine is EXACTLY the same as bipartition_tree_random() except +# that it returns, in addition to the nodes for a new district, the +# number of possible cuts that were found. This additional information +# is needed by reversible_recom(), but I did not want to change the +# function signature of bipartition_tree_random() in case it is used +# as part of the public API by someone. +# +# It is bad form to have two functions that are the same except for +# a tweak - an invitation for future bugs when you fix something in +# one place and not the other, so maybe this is something we should +# revisit when we decide a general code cleanup is in order...
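# frm: Sketch (illustrative only, not part of this change): one way to remove the
#      near-duplication described above would be a single private worker that
#      always computes the pair, with the two public functions reduced to thin
#      wrappers. The names and the **kwargs plumbing below are hypothetical:
#
#          def _bipartition_tree_random_impl(graph, choice=random.choice, **kwargs):
#              possible_cuts = _bipartition_tree_random_all(
#                  graph_to_split=graph, choice=choice, **kwargs
#              )
#              if not possible_cuts:
#                  return None
#              chosen_cut = choice(possible_cuts)
#              nodes = graph.translate_subgraph_node_ids_for_set_of_nodes(chosen_cut.subset)
#              return len(possible_cuts), frozenset(nodes)
#
#          def bipartition_tree_random(graph, **kwargs):
#              result = _bipartition_tree_random_impl(graph, **kwargs)
#              return None if result is None else result[1]
#
#          def bipartition_tree_random_with_num_cuts(graph, **kwargs):
#              return _bipartition_tree_random_impl(graph, **kwargs)
#
#      A fix would then land in exactly one place while both public signatures stay stable.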
+ # +def bipartition_tree_random_with_num_cuts( + graph: Graph, + pop_col: str, + pop_target: Union[int, float], + epsilon: float, + node_repeats: int = 1, + repeat_until_valid: bool = True, + spanning_tree: Optional[Graph] = None, + spanning_tree_fn: Callable = random_spanning_tree, + balance_edge_fn: Callable = find_balanced_edge_cuts_memoization, + one_sided_cut: bool = False, + choice: Callable = random.choice, + max_attempts: Optional[int] = 100000, +) -> Union[Tuple[int, Set[Any]], None]: + """ + This is like :func:`bipartition_tree` except it chooses a random balanced + cut, rather than the first cut it finds. Unlike :func:`bipartition_tree_random`, + it also returns the number of possible cuts that were found. + + This function finds a balanced 2 partition of a graph by drawing a + spanning tree and finding an edge to cut that leaves at most an epsilon + imbalance between the populations of the parts. If a root fails, new roots + are tried until node_repeats in which case a new tree is drawn. + + Builds up a connected subgraph with a connected complement whose population + is ``epsilon * pop_target`` away from ``pop_target``. + + :param graph: The graph to partition. + :type graph: Graph + :param pop_col: The node attribute holding the population of each node. + :type pop_col: str + :param pop_target: The target population for the returned subset of nodes. + :type pop_target: Union[int, float] + :param epsilon: The allowable deviation from ``pop_target`` (as a percentage of + ``pop_target``) for the subgraph's population. + :type epsilon: float + :param node_repeats: A parameter for the algorithm: how many different choices + of root to use before drawing a new spanning tree. Defaults to 1. + :type node_repeats: int + :param repeat_until_valid: Determines whether to keep drawing spanning trees + until a tree with a balanced cut is found. If `True`, a set of nodes will + always be returned; if `False`, `None` will be returned if a valid spanning + tree is not found on the first try. Defaults to True. + :type repeat_until_valid: bool, optional + :param spanning_tree: The spanning tree for the algorithm to use (used when the + algorithm chooses a new root and for testing). Defaults to None. + :type spanning_tree: Optional[Graph], optional + :param spanning_tree_fn: The random spanning tree algorithm to use if a spanning + tree is not provided. Defaults to :func:`random_spanning_tree`. + :type spanning_tree_fn: Callable, optional + :param balance_edge_fn: The algorithm used to find balanced cut edges. Defaults to + :func:`find_balanced_edge_cuts_memoization`. + :type balance_edge_fn: Callable, optional + :param one_sided_cut: Passed to the ``balance_edge_fn``. Determines whether or not we are + cutting off a single district when partitioning the tree. When + set to False, we check if the node we are cutting and the remaining graph + are both within epsilon of the ideal population. When set to True, we only + check if the node we are cutting is within epsilon of the ideal population. + Defaults to False. + :type one_sided_cut: bool, optional + :param choice: The random choice function. Can be substituted for testing. Defaults + to :func:`random.choice`. + :type choice: Callable, optional + :param max_attempts: The max number of attempts that should be made to bipartition. + Defaults to 100000. + :type max_attempts: Optional[int], optional + + :returns: A tuple of the number of possible cuts and a subset of nodes of ``graph`` + (whose induced subgraph is connected), or None if a valid spanning tree + is not found. + :rtype: Union[Tuple[int, Set[Any]], None] + """ + + # frm: TODO: Refactoring: Again - semantics should not depend on signatures...
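# frm: Sketch (illustrative only): the "lock in the signature" alternative suggested
#      above would give every balance_edge function the same fixed parameter list,
#      with unused parameters simply ignored, so callers never need to probe
#      signature(...).parameters. A hypothetical conforming function:
#
#          def my_balance_edge_fn(h, one_sided_cut=False, choice=random.choice):
#              # This implementation happens to ignore one_sided_cut, but it still
#              # accepts it, so every caller can always pass it unconditionally.
#              return find_balanced_edge_cuts_memoization(h, choice=choice)
#
#      The if-statement below would then collapse to an unconditional
#      balance_edge_fn = partial(balance_edge_fn, one_sided_cut=True).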
+ if "one_sided_cut" in signature(balance_edge_fn).parameters: + balance_edge_fn = partial(balance_edge_fn, one_sided_cut=True) + possible_cuts = _bipartition_tree_random_all( + graph_to_split=graph, + pop_col=pop_col, + pop_target=pop_target, + epsilon=epsilon, + node_repeats=node_repeats, + repeat_until_valid=repeat_until_valid, + spanning_tree=spanning_tree, + spanning_tree_fn=spanning_tree_fn, + balance_edge_fn=balance_edge_fn, + choice=choice, + max_attempts=max_attempts, + ) + if possible_cuts: + chosen_cut = choice(possible_cuts) + num_cuts = len(possible_cuts) + parent_nodes = graph.translate_subgraph_node_ids_for_set_of_nodes(chosen_cut.subset) + return num_cuts, frozenset(parent_nodes) # frm: Not sure if important that it be frozenset + else: + return None + +####################### +# frm TODO: Testing: Check to make sure there is a test for this... def bipartition_tree_random( - graph: nx.Graph, + subgraph_to_split: Graph, pop_col: str, pop_target: Union[int, float], epsilon: float, node_repeats: int = 1, repeat_until_valid: bool = True, - spanning_tree: Optional[nx.Graph] = None, + spanning_tree: Optional[Graph] = None, spanning_tree_fn: Callable = random_spanning_tree, balance_edge_fn: Callable = find_balanced_edge_cuts_memoization, one_sided_cut: bool = False, @@ -822,7 +1504,7 @@ def bipartition_tree_random( is ``epsilon * pop_target`` away from ``pop_target``. :param graph: The graph to partition. - :type graph: nx.Graph + :type graph: Graph :param pop_col: The node attribute holding the population of each node. :type pop_col: str :param pop_target: The target population for the returned subset of nodes. @@ -840,7 +1522,7 @@ def bipartition_tree_random( :type repeat_until_valid: bool, optional :param spanning_tree: The spanning tree for the algorithm to use (used when the algorithm chooses a new root and for testing). Defaults to None. - :type spanning_tree: Optional[nx.Graph], optional + :type spanning_tree: Optional[Graph], optional :param spanning_tree_fn: The random spanning tree algorithm to use if a spanning tree is not provided. Defaults to :func:`random_spanning_tree`. :type spanning_tree_fn: Callable, optional @@ -865,11 +1547,25 @@ def bipartition_tree_random( valid spanning tree is not found. :rtype: Union[Set[Any], None] """ + + # frm: TODO: Refactoring: Again - semantics should not depend on signatures... + # + # This is odd - there are two balance_edge_functions defined in tree.py but + # both of them have a formal parameter with the name "one_sided_cut", so this + # code is not picking one of them. Perhaps there was an earlier version of + # the code where it allowed functions that did not support "one_sided_cut". + # In any event, it looks like this if-stmt is a no-op as far as the current + # codebase is concerned... + # + # Even odder - there is a formal parameter, one_sided_cut, which is never + # used... 
+ + if "one_sided_cut" in signature(balance_edge_fn).parameters: balance_edge_fn = partial(balance_edge_fn, one_sided_cut=True) possible_cuts = _bipartition_tree_random_all( - graph=graph, + graph_to_split=subgraph_to_split, pop_col=pop_col, pop_target=pop_target, epsilon=epsilon, @@ -882,11 +1578,16 @@ def bipartition_tree_random( max_attempts=max_attempts, ) if possible_cuts: - return choice(possible_cuts).subset + chosen_cut = choice(possible_cuts) + translated_nodes = subgraph_to_split.translate_subgraph_node_ids_for_set_of_nodes(chosen_cut.subset) + return frozenset(translated_nodes) # frm: Not sure if important that it be frozenset +# frm: used in this file and in tree_proposals.py +# But maybe this is intended to be used externally... +# frm: Note that this routine is only used in recom() def epsilon_tree_bipartition( - graph: nx.Graph, + subgraph_to_split: Graph, parts: Sequence, pop_target: Union[float, int], pop_col: str, @@ -899,7 +1600,7 @@ def epsilon_tree_bipartition( two parts of population ``pop_target`` (within ``epsilon``). :param graph: The graph to partition into two :math:`\varepsilon`-balanced parts. - :type graph: nx.Graph + :type graph: Graph :param parts: Iterable of part (district) labels (like ``[0,1,2]`` or ``range(4)``). :type parts: Sequence :param pop_target: Target population for each part of the partition. @@ -926,14 +1627,14 @@ def epsilon_tree_bipartition( ) flips = {} - remaining_nodes = graph.node_indices + remaining_nodes = subgraph_to_split.node_indices lb_pop = pop_target * (1 - epsilon) ub_pop = pop_target * (1 + epsilon) check_pop = lambda x: lb_pop <= x <= ub_pop nodes = method( - graph.subgraph(remaining_nodes), + subgraph_to_split.subgraph(remaining_nodes), pop_col=pop_col, pop_target=pop_target, epsilon=epsilon, @@ -944,10 +1645,15 @@ def epsilon_tree_bipartition( if nodes is None: raise BalanceError() + # Calculate the total population for the two districts based on the + # results of the "method()" partitioning. part_pop = 0 for node in nodes: + # frm: ???: The code above has already confirmed that len(parts) is 2 + # so why use negative index values - why not just use + # parts[0] and parts[1]? flips[node] = parts[-2] - part_pop += graph.nodes[node][pop_col] + part_pop += subgraph_to_split.node_data(node)[pop_col] if not check_pop(part_pop): raise PopulationBalanceError() @@ -958,18 +1664,22 @@ def epsilon_tree_bipartition( part_pop = 0 for node in remaining_nodes: flips[node] = parts[-1] - part_pop += graph.nodes[node][pop_col] + part_pop += subgraph_to_split.node_data(node)[pop_col] if not check_pop(part_pop): raise PopulationBalanceError() - return flips + # translate subgraph node_ids back into node_ids in parent graph + translated_flips = subgraph_to_split.translate_subgraph_node_ids_for_flips(flips) + return translated_flips -# TODO: Move these recursive partition functions to their own module. They are not +# frm: TODO: Refactoring: Move these recursive partition functions to their own module. They are not # central to the operation of the recom function despite being tree methods. +# frm: defined here but only used in partition.py +# But maybe this is intended to be used externally... def recursive_tree_part( - graph: nx.Graph, + graph: Graph, parts: Sequence, pop_target: Union[float, int], pop_col: str, @@ -983,7 +1693,7 @@ def recursive_tree_part( generate initial seed plans or to implement ReCom-like "merge walk" proposals. :param graph: The graph to partition into ``len(parts)`` :math:`\varepsilon`-balanced parts. 
- :type graph: nx.Graph + :type graph: Graph :param parts: Iterable of part (district) labels (like ``[0,1,2]`` or ``range(4)``). :type parts: Sequence :param pop_target: Target population for each part of the partition. @@ -1018,13 +1728,23 @@ def recursive_tree_part( ub_pop = pop_target * (1 + epsilon) check_pop = lambda x: lb_pop <= x <= ub_pop + # frm: Notes to self: The code in the for-loop creates n-2 districts (where n is + # the number of partitions desired) by calling the "method" + # function, whose job it is to produce a connected set of + # nodes that has the desired population target. + # + # Note that it sets one_sided_cut=True which tells the + # "method" function that it is NOT bisecting the graph + # but is rather supposed to just find one connected + # set of nodes of the correct population size. + for part in parts[:-2]: min_pop = max(pop_target * (1 - epsilon), pop_target * (1 - epsilon) - debt) max_pop = min(pop_target * (1 + epsilon), pop_target * (1 + epsilon) - debt) new_pop_target = (min_pop + max_pop) / 2 try: - nodes = method( + node_ids = method( graph.subgraph(remaining_nodes), pop_col=pop_col, pop_target=new_pop_target, @@ -1035,23 +1755,27 @@ def recursive_tree_part( except Exception: raise - if nodes is None: + if node_ids is None: raise BalanceError() part_pop = 0 - for node in nodes: + for node in node_ids: flips[node] = part - part_pop += graph.nodes[node][pop_col] + part_pop += graph.node_data(node)[pop_col] if not check_pop(part_pop): raise PopulationBalanceError() debt += part_pop - pop_target - remaining_nodes -= nodes + remaining_nodes -= node_ids # After making n-2 districts, we need to make sure that the last # two districts are both balanced. - nodes = method( + + # frm: For the last call to "method", set one_sided_cut=False to + # request that "method" create two equal sized districts + # with the given population goal by bisecting the graph. + node_ids = method( graph.subgraph(remaining_nodes), pop_col=pop_col, pop_target=pop_target, @@ -1060,33 +1784,40 @@ def recursive_tree_part( one_sided_cut=False, ) - if nodes is None: + if node_ids is None: raise BalanceError() part_pop = 0 - for node in nodes: - flips[node] = parts[-2] - part_pop += graph.nodes[node][pop_col] + for node_id in node_ids: + flips[node_id] = parts[-2] + # frm: this code fragment: graph.node_data(node_id)[pop_col] is used + # many times and is a candidate for being wrapped with + # a function that has a meaningful name, such as perhaps: + # get_population_for_node(node_id, pop_col). + # This is an example of code-bloat from the perspective of + # code gurus, but it really helps a new code reviewer understand + # WTF is going on... + part_pop += graph.node_data(node_id)[pop_col] if not check_pop(part_pop): raise PopulationBalanceError() - remaining_nodes -= nodes + remaining_nodes -= node_ids # All of the remaining nodes go in the last part part_pop = 0 for node in remaining_nodes: flips[node] = parts[-1] - part_pop += graph.nodes[node][pop_col] + part_pop += graph.node_data(node)[pop_col] if not check_pop(part_pop): raise PopulationBalanceError() return flips - -def get_seed_chunks( - graph: nx.Graph, +# frm: only used in this file, so I changed the name to have a leading underscore +def _get_seed_chunks( + graph: Graph, num_chunks: int, num_dists: int, pop_target: Union[int, float], @@ -1100,7 +1831,7 @@ def get_seed_chunks( balanced within new_epsilon <= ``epsilon`` of a balanced target population. 
:param graph: The graph - :type graph: nx.Graph + :type graph: Graph :param num_chunks: The number of chunks to partition the graph into :type num_chunks: int :param num_dists: The number of districts @@ -1122,22 +1853,48 @@ :returns: New assignments for the nodes of ``graph``. :rtype: List[List[int]] """ + + # frm: TODO: Refactoring: Change the name of num_chunks_left to instead be num_districts_per_chunk. + # frm: ???: It is not clear to me when num_chunks will not evenly divide num_dists. In + # the only place where _get_seed_chunks() is called, it is inside an if-stmt + # branch that validates that num_chunks evenly divides num_dists... + # num_chunks_left = num_dists // num_chunks + + # frm: TODO: Refactoring: Change the name of parts below to be something / anything else. Normally + # parts refers to districts, but here it is just a way to keep track of + # sets of nodes for chunks. Yes - they eventually become districts when + # this code gets to the base cases, but I found it confusing at this + # level... + # parts = range(num_chunks) + # frm: ???: I think that new_epsilon is the epsilon to use for each district, in which + # case the epsilon passed in would be for the HERE... new_epsilon = epsilon / (num_chunks_left * num_chunks) if num_chunks_left == 1: new_epsilon = epsilon chunk_pop = 0 for node in graph.node_indices: - chunk_pop += graph.nodes[node][pop_col] + chunk_pop += graph.node_data(node)[pop_col] + # frm: TODO: Refactoring: See if there is a better way to structure this instead of a while True loop... while True: epsilon = abs(epsilon) flips = {} - remaining_nodes = set(graph.nodes) - + remaining_nodes = graph.node_indices + + # frm: ???: What is the distinction between num_chunks and num_districts? + # I think that a chunk is typically a multiple of districts, so + # if we want 15 districts we might only ask for 5 chunks. Stated + # differently, a chunk will always have at least enough nodes + # for a given number of districts. As the chunk size gets + # smaller, the number of nodes more closely matches what + # is needed for a set number of districts. + + # frm: Note: This just scales epsilon by the number of districts for each chunk + # so we can get chunks with the appropriate population sizes... min_pop = pop_target * (1 - new_epsilon) * num_chunks_left max_pop = pop_target * (1 + new_epsilon) * num_chunks_left @@ -1146,6 +1903,26 @@ diff = min(max_pop - chunk_pop_target, chunk_pop_target - min_pop) new_new_epsilon = diff / chunk_pop_target + # frm: Note: This code is clever... It loops through all of the + # parts (districts) except the last, and on each + # iteration, it finds nodes for the given part. + # Each time through the loop it assigns the + # unassigned nodes to the last part, but + # most of this gets overwritten by the next + # iteration, so that at the end the only nodes + # still assigned to the last part are the ones + # that had not been previously assigned. + # + # It works, but is a little too clever for me. + # + # I would just have assigned all nodes to + # the last part before entering the loop + # with a comment saying that by the end of the loop + # the nodes not assigned in the loop will + # default to the last part.
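# frm: Sketch (illustrative only) of the simpler structure described above:
#
#          # Default every node to the last part once, up front...
#          flips = {node: parts[-1] for node in remaining_nodes}
#          # ...then the loop over parts[:-1] below only overwrites the default
#          # for the nodes it claims, and the per-iteration re-assignment of
#          # the leftovers to parts[-1] goes away entirely.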
+ # + + # Assign all nodes to one of the parts for i in range(len(parts[:-1])): part = parts[i] @@ -1168,13 +1945,21 @@ def get_seed_chunks( for node in remaining_nodes: flips[node] = parts[-1] + # frm: ???: Look at remaining_nodes to see if we are done part_pop = 0 + # frm: ???: Compute population total for remaining nodes. for node in remaining_nodes: - part_pop += graph.nodes[node][pop_col] + part_pop += graph.node_data(node)[pop_col] + # frm: ???: Compute what the population total would be for each district in chunk part_pop_as_dist = part_pop / num_chunks_left fake_epsilon = epsilon + # frm: ???: If the chunk is for more than one district, divide epsilon by two if num_chunks_left != 1: fake_epsilon = epsilon / 2 + # frm: ???: Calculate max and min populations on a district level + # This will just be based on epsilon if we only want one district from chunk, but + # it will be based on half of epsilon if we want more than one district from chunk. + # This is odd - why wouldn't we use an epsilon min_pop_as_dist = pop_target * (1 - fake_epsilon) max_pop_as_dist = pop_target * (1 + fake_epsilon) @@ -1193,10 +1978,11 @@ def get_seed_chunks( return list(chunks.values()) - +# frm: only used in this file +# But maybe this is intended to be used externally... def get_max_prime_factor_less_than(n: int, ceil: int) -> Optional[int]: """ - Helper function for recursive_seed_part_inner. Returns the largest prime factor of ``n`` + Helper function for _recursive_seed_part_inner. Returns the largest prime factor of ``n`` less than ``ceil``, or None if all are greater than ceil. :param n: The number to find the largest prime factor for. @@ -1229,9 +2015,8 @@ def get_max_prime_factor_less_than(n: int, ceil: int) -> Optional[int]: return largest_factor - -def recursive_seed_part_inner( - graph: nx.Graph, +def _recursive_seed_part_inner( + graph: Graph, num_dists: int, pop_target: Union[float, int], pop_col: str, @@ -1245,6 +2030,16 @@ def recursive_seed_part_inner( Inner function for recursive_seed_part. Returns a partition with ``num_dists`` districts balanced within ``epsilon`` of ``pop_target``. + + frm: TODO: Documentation: Correct the above statement that this function returns a partition. + In fact, it returns a list of sets of nodes, which is conceptually + equivalent to a partition, but is not a Partition object. Each + set of nodes constitutes a district, but the district does not + have an ID, and there is nothing that associates these nodes + with a specific graph - that is implicit, depending on the graph + object passed in, so the caller is responsible for knowing that + the returned list of sets belongs to the graph passed in... + Splits graph into num_chunks chunks, and then recursively splits each chunk into ``num_dists``/num_chunks chunks. The number num_chunks of chunks is chosen based on ``n`` and ``ceil`` as follows: @@ -1259,8 +2054,15 @@ def recursive_seed_part_inner( this function bites off a single district from the graph and recursively partitions the remaining graph into ``num_dists - 1`` districts. + frm: ???: OK, but why is the logic above for num_chunks the correct number? Is there + a mathematical reason for it? I assume so, but that explanation is missing... + + I presume that the reason is that something in the code that finds a + district scales exponentially, so it makes sense to divide and conquer. + Even so, why this particular strategy for divide and conquer? + :param graph: The underlying graph structure. 
- :type graph: nx.Graph + :type graph: Graph :param num_dists: number of districts to partition the graph into :type num_dists: int :param pop_target: Target population for each part of the partition @@ -1292,6 +2094,18 @@ :rtype: List of sets, each set is a district """ + """ + frm: This code is quite nice once you grok it. + + The goal is to find the given number of districts - but to do it in an + efficient way - meaning with smaller graphs. So conceptually, you want + to + HERE + + There are two base cases when the number of districts still to be found is + either 1 or 2. + + """ # Chooses num_chunks if n is None: if ceil is None: else: raise ValueError("ceil must be None or at least 2") elif n > 1: + # frm: Note: This is not guaranteed to evenly divide num_dists num_chunks = n else: raise ValueError("n must be None or a positive integer") # base case if num_dists == 1: - return [set(graph.nodes)] + # Just return an assignment with all of the nodes in the graph + # Translate the node_ids into parent_node_ids + translated_set_of_nodes = graph.translate_subgraph_node_ids_for_set_of_nodes( + graph.node_indices + ) + translated_assignment = [] + translated_assignment.append(translated_set_of_nodes) + return translated_assignment + + # frm: In the case when there are exactly 2 districts, split the graph by setting + # one_sided_cut to be False. if num_dists == 2: nodes = method( - graph, + graph.subgraph(graph.node_indices), # needs to be a subgraph pop_col=pop_col, pop_target=pop_target, epsilon=epsilon, @@ -1319,11 +2144,38 @@ one_sided_cut=False, ) - return [set(nodes), set(graph.nodes) - set(nodes)] + # frm: Note to Self: the name "one_sided_cut" seems unnecessarily opaque. What it really + # means is whether to split the graph into two equal districts or + # whether to just find one district from a larger graph. When we + # clean up this code, consider changing the name of this parameter + # to something like: find_two_equal_sized_districts... + # + # Consider creating a wrapper function which has the better + # name that delegates to a private method to do the work. + + nodes_for_one_district = set(nodes) + nodes_for_the_other_district = set(graph.node_indices) - nodes_for_one_district + + # Translate the subgraph node_ids into parent_node_ids + translated_set_1 = graph.translate_subgraph_node_ids_for_set_of_nodes( + nodes_for_one_district + ) + translated_set_2 = graph.translate_subgraph_node_ids_for_set_of_nodes( + nodes_for_the_other_district + ) + + return [translated_set_1, translated_set_2] # bite off a district and recurse into the remaining subgraph + # frm: Note: In the case when num_chunks does not evenly divide num_dists, + # just find one district, remove those nodes from + # the unassigned nodes and try again with num_dists + # set to be one less. Stated differently, reduce + # number of desired districts until you get to + # one that is evenly divided by num_chunks and then + # do chunk stuff...
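# frm: Worked example (illustrative): with num_dists = 7 and num_chunks = 3,
#      7 % 3 != 0, so the elif below peels off a single district and recurses
#      with num_dists = 6. Since 6 % 3 == 0, that call splits the remaining
#      graph into 3 chunks of 2 districts each and recurses into each chunk
#      with num_dists = 2, which is handled by the bisection base case above.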
elif num_chunks is None or num_dists % num_chunks != 0: - remaining_nodes = set(graph.nodes) + remaining_nodes = graph.node_indices nodes = method( graph.subgraph(remaining_nodes), pop_col=pop_col, @@ -1333,7 +2185,9 @@ one_sided_cut=True, ) remaining_nodes -= nodes - assignment = [nodes] + recursive_seed_part_inner( + # frm: Create a list with the set of nodes returned by method() and then recurse + # to get the rest of the sets of nodes for remaining districts. + assignment = [nodes] + _recursive_seed_part_inner( graph.subgraph(remaining_nodes), num_dists - 1, pop_target, @@ -1345,9 +2199,10 @@ ) # split graph into num_chunks chunks, and recurse into each chunk + # frm: TODO: Documentation: Add documentation for why a subgraph is needed in the call below elif num_dists % num_chunks == 0: - chunks = get_seed_chunks( - graph, + chunks = _get_seed_chunks( + graph.subgraph(graph.node_indices), # needs to be a subgraph num_chunks, num_dists, pop_target, @@ -1358,9 +2213,9 @@ assignment = [] for chunk in chunks: - chunk_assignment = recursive_seed_part_inner( + chunk_assignment = _recursive_seed_part_inner( graph.subgraph(chunk), - num_dists // num_chunks, + num_dists // num_chunks, # new target number of districts pop_target, pop_col, epsilon, @@ -1369,12 +2224,30 @@ ceil=ceil, ) assignment += chunk_assignment + else: + # frm: From the logic above, this should never happen, but if it did + # because of a future edit (bug), at least this will catch it + # early before really bizarre things happen... + raise Exception("_recursive_seed_part_inner(): Should never happen...") + + # The assignment object that has been created needs to have its + # node_ids translated into parent_node_ids + + translated_assignment = [] + for set_of_nodes in assignment: + translated_set_of_nodes = graph.translate_subgraph_node_ids_for_set_of_nodes( + set_of_nodes + ) + translated_assignment.append(translated_set_of_nodes) + + return translated_assignment - return assignment +# frm: TODO: Refactoring: This routine is never called - not in this file and not in any other GerryChain file. +# Is it intended to be used by end-users? And if so, for what purpose? def recursive_seed_part( - graph: nx.Graph, + graph: Graph, parts: Sequence, pop_target: Union[float, int], pop_col: str, @@ -1386,10 +2259,10 @@ ) -> Dict: """ Returns a partition with ``num_dists`` districts balanced within ``epsilon`` of - ``pop_target`` by recursively splitting graph using recursive_seed_part_inner. + ``pop_target`` by recursively splitting graph using _recursive_seed_part_inner. :param graph: The graph - :type graph: nx.Graph + :type graph: Graph :param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``) :type parts: Sequence :param pop_target: Target population for each part of the partition @@ -1420,9 +2293,24 @@ :returns: New assignments for the nodes of ``graph``. :rtype: dict """ + + # frm: Note: It is not strictly necessary to use a subgraph in the call below on + # _recursive_seed_part_inner(), because the top-level graph has + # a _node_id_to_parent_node_id_map that just maps node_ids to themselves. However, + # it seemed a good practice to ALWAYS pass a subgraph to routines that are + # intended to deal with subgraphs, even when not strictly + # necessary. Just one less cognitive load to worry about.
+ # + # This probably means that the identity _node_id_to_parent_node_id_map for top-level + # graphs will never be used, but I still think that it makes sense to retain + # it - again, for consistency: Every graph knows how to translate to + # parent_node_ids even if it is a top-level graph. + # + # In short - an argument based on invariants being a good thing... + # flips = {} - assignment = recursive_seed_part_inner( - graph, + assignment = _recursive_seed_part_inner( + graph.subgraph(graph.node_indices), len(parts), pop_target, pop_col, @@ -1444,3 +2332,6 @@ class BalanceError(Exception): class PopulationBalanceError(Exception): """Raised when the population of a district is outside the acceptable epsilon range.""" + + + \ No newline at end of file diff --git a/gerrychain/updaters/compactness.py b/gerrychain/updaters/compactness.py index 7b42e201..5683acb7 100644 --- a/gerrychain/updaters/compactness.py +++ b/gerrychain/updaters/compactness.py @@ -16,13 +16,20 @@ def boundary_nodes(partition, alias: str = "boundary_nodes") -> Set: :returns: The set of nodes in the partition that are on the boundary. :rtype: Set """ + + # Note that the "alias" parameter is used as the attribute name + # on the partition - using this "alias" you can retrieve + # the data stored by an updater that uses this routine... + if partition.parent: return partition.parent[alias] - return { - node - for node in partition.graph.nodes - if partition.graph.nodes[node]["boundary_node"] - } + else: + result = { + node_id + for node_id in partition.graph.node_indices + if partition.graph.node_data(node_id)["boundary_node"] + } + return result def initialize_exterior_boundaries_as_a_set(partition) -> Dict[int, Set]: @@ -37,6 +44,7 @@ part_boundaries = collections.defaultdict(set) for node in partition["boundary_nodes"]: part_boundaries[partition.assignment.mapping[node]].add(node) + return part_boundaries @@ -63,6 +71,16 @@ partition. :rtype: Set """ + # Compute the new set of boundary nodes for the partition. + # + # The term, (inflow & graph_boundary), computes new nodes that are boundary nodes. + # + # The term, (previous | (inflow & graph_boundary)), adds those new boundary nodes to the + # set of previous boundary nodes. + # + # Then all you need to do is subtract all of the nodes in the outflow to remove any of those + # that happen to be boundary nodes...
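# frm: Tiny worked example (illustrative): with
#
#          previous       = {1, 2}
#          inflow         = {3, 4}
#          outflow        = {2}
#          graph_boundary = {1, 2, 3}
#
#      the expression below evaluates as ({1, 2} | ({3, 4} & {1, 2, 3})) - {2} = {1, 3}:
#      node 3 moved in and is on the boundary, node 2 moved out, and node 4 moved in
#      but is interior, so it never enters the set.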
+ graph_boundary = partition["boundary_nodes"] return (previous | (inflow & graph_boundary)) - outflow @@ -80,7 +98,7 @@ def initialize_exterior_boundaries(partition) -> Dict[int, float]: boundaries = collections.defaultdict(lambda: 0) for node in graph_boundary: part = partition.assignment.mapping[node] - boundaries[part] += partition.graph.nodes[node]["boundary_perim"] + boundaries[part] += partition.graph.node_data(node)["boundary_perim"] return boundaries @@ -107,11 +125,11 @@ def exterior_boundaries(partition, previous: Set, inflow: Set, outflow: Set) -> """ graph_boundary = partition["boundary_nodes"] added_perimeter = sum( - partition.graph.nodes[node]["boundary_perim"] + partition.graph.node_data(node)["boundary_perim"] for node in inflow & graph_boundary ) removed_perimeter = sum( - partition.graph.nodes[node]["boundary_perim"] + partition.graph.node_data(node)["boundary_perim"] for node in outflow & graph_boundary ) return previous + added_perimeter - removed_perimeter @@ -126,13 +144,33 @@ def initialize_interior_boundaries(partition): perimeter the given part shares with other parts. :rtype: Dict[int, float] """ - return { - part: sum( - partition.graph.edges[edge]["shared_perim"] + + # RustworkX Note: + # + # The old NX code did not distinguish between edges and edge_ids - they were one + # and the same. However, in RX an edge is a tuple and an edge_id is an integer. + # The edges stored in partition["cut_edges_by_part"] are edges (tuples), so + # we need to get the edge_id for each edge in order to access the data for the edge. + + # Get edge_ids for each edge (tuple) + edge_ids_for_part = { + part: [ + partition.graph.get_edge_id_from_edge(edge) for edge in partition["cut_edges_by_part"][part] + ] + for part in partition.parts + } + + # Compute length of the shared perimeter of each part + shared_perimeters_for_part = { + part: sum( + partition.graph.edge_data(edge_id)["shared_perim"] + for edge_id in edge_ids_for_part[part] ) for part in partition.parts } + + return shared_perimeters_for_part @on_edge_flow(initialize_interior_boundaries, alias="interior_boundaries") @@ -159,11 +197,16 @@ def interior_boundaries( boundary of that part. :rtype: Dict """ + added_perimeter = sum( - partition.graph.edges[edge]["shared_perim"] for edge in new_edges + partition.graph.edge_data( + partition.graph.get_edge_id_from_edge(edge) + )["shared_perim"] for edge in new_edges ) removed_perimeter = sum( - partition.graph.edges[edge]["shared_perim"] for edge in old_edges + partition.graph.edge_data( + partition.graph.get_edge_id_from_edge(edge) + )["shared_perim"] for edge in old_edges ) return previous + added_perimeter - removed_perimeter @@ -177,6 +220,7 @@ def flips(partition) -> Dict: given partition. :rtype: Dict """ + # frm: ???: Does anyone ever use this? It seems kind of useless... return partition.flips @@ -184,7 +228,7 @@ def perimeter_of_part(partition, part: int) -> float: """ Totals up the perimeter of the part in the partition. - .. Warning:: + .. Warning:: frm: TODO: Refactoring: Add code to enforce this warning... Requires that 'boundary_perim' be a node attribute, 'shared_perim' be an edge attribute, 'cut_edges' be an updater, and 'exterior_boundaries' be an updater. diff --git a/gerrychain/updaters/county_splits.py b/gerrychain/updaters/county_splits.py index fad28f4c..91350ec5 100644 --- a/gerrychain/updaters/county_splits.py +++ b/gerrychain/updaters/county_splits.py @@ -79,21 +79,27 @@ def compute_county_splits( # Create the initial county data containers. 
if not partition.parent: + county_dict = dict() - for node in partition.graph.node_indices: - county = partition.graph.lookup(node, county_field) + for node_id in partition.graph.node_indices: + + # First, get the current status of the county's information + county = partition.graph.node_data(node_id)[county_field] if county in county_dict: split, nodes, seen = county_dict[county] else: split, nodes, seen = CountySplit.NOT_SPLIT, [], set() - nodes.append(node) - seen.update(set([partition.assignment.mapping[node]])) + # Now update "nodes" and "seen" with this node_id and the part (district) from partition's assignment. + nodes.append(node_id) + seen.update(set([partition.assignment.mapping[node_id]])) + # Lastly, if we have "seen" more than one part (district), then the county is split across parts. if len(seen) > 1: split = CountySplit.OLD_SPLIT + # Update the county_dict with new information county_dict[county] = CountyInfo(split, nodes, seen) return county_dict @@ -102,7 +108,7 @@ parent = partition.parent for county, county_info in parent[partition_field].items(): - seen = set(partition.assignment.mapping[node] for node in county_info.nodes) + seen = set(partition.assignment.mapping[node_id] for node_id in county_info.nodes) split = CountySplit.NOT_SPLIT @@ -145,17 +151,17 @@ def _get_splits(partition): def total_reg_splits(partition, reg_attr): """Returns the total number of times that reg_attr is split in the partition.""" all_region_names = set( - partition.graph.nodes[node][reg_attr] for node in partition.graph.nodes + partition.graph.node_data(node_id)[reg_attr] for node_id in partition.graph.node_indices ) split = {name: 0 for name in all_region_names} # Require that the cut_edges updater is attached to the partition for node1, node2 in partition["cut_edges"]: if ( partition.assignment[node1] != partition.assignment[node2] - and partition.graph.nodes[node1][reg_attr] - == partition.graph.nodes[node2][reg_attr] + and partition.graph.node_data(node1)[reg_attr] + == partition.graph.node_data(node2)[reg_attr] ): - split[partition.graph.nodes[node1][reg_attr]] += 1 - split[partition.graph.nodes[node2][reg_attr]] += 1 + split[partition.graph.node_data(node1)[reg_attr]] += 1 + split[partition.graph.node_data(node2)[reg_attr]] += 1 return sum(1 for value in split.values() if value > 0) diff --git a/gerrychain/updaters/cut_edges.py b/gerrychain/updaters/cut_edges.py index 7fac766e..e7852df9 100644 --- a/gerrychain/updaters/cut_edges.py +++ b/gerrychain/updaters/cut_edges.py @@ -3,29 +3,29 @@ from .flows import on_edge_flow, neighbor_flips -def put_edges_into_parts(edges: List, assignment: Dict) -> Dict: + +def _put_edges_into_parts(cut_edges: List, assignment: Dict) -> Dict: """ - :param edges: A list of edges in a graph which are to be separated + :param cut_edges: A list of cut_edges in a graph which are to be separated into their respective parts within the partition according to the given assignment. - :type edges: List + :type cut_edges: List :param assignment: A dictionary mapping nodes to their respective parts within the partition. :type assignment: Dict - :returns: A dictionary mapping each part of a partition to the set of edges + :returns: A dictionary mapping each part of a partition to the set of cut_edges in that part.
:rtype: Dict """ by_part = collections.defaultdict(set) - for edge in edges: + for edge in cut_edges: # add edge to the sets corresponding to the parts it touches by_part[assignment.mapping[edge[0]]].add(edge) by_part[assignment.mapping[edge[1]]].add(edge) return by_part - -def new_cuts(partition) -> Set[Tuple]: +def _new_cuts(partition) -> Set[Tuple]: """ :param partition: A partition of a Graph :type partition: :class:`~gerrychain.partition.Partition` @@ -40,7 +40,7 @@ } -def obsolete_cuts(partition) -> Set[Tuple]: +def _obsolete_cuts(partition) -> Set[Tuple]: """ :param partition: A partition of a Graph :type partition: :class:`~gerrychain.partition.Partition` @@ -55,28 +55,48 @@ and not partition.crosses_parts((node, neighbor)) } - def initialize_cut_edges(partition): """ :param partition: A partition of a Graph :type partition: :class:`~gerrychain.partition.Partition` + frm: TODO: Documentation: This description should be updated. Cut_edges are edges that touch + two different parts (districts). They are the internal boundaries + between parts (districts). This routine finds all of the cut_edges + in the graph and then creates a dict that stores all of the cut_edges + for each part (district). This dict becomes the value of + partition["cut_edges_by_part"]. + + Peter agreed: + Ah, you are correct. It maps parts to cut edges, not just any edges in the partition + + + :returns: A dictionary mapping each part of a partition to the set of edges in that part. :rtype: Dict """ - edges = { + # Compute the set of edges that are "cut_edges" - that is, edges that go from + # one part (district) to another. + cut_edges = { tuple(sorted(edge)) + # frm: edges vs edge_ids: edges are wanted here (tuples) for edge in partition.graph.edges if partition.crosses_parts(edge) } - return put_edges_into_parts(edges, partition.assignment) + return _put_edges_into_parts(cut_edges, partition.assignment) @on_edge_flow(initialize_cut_edges, alias="cut_edges_by_part") def cut_edges_by_part( partition, previous: Set[Tuple], new_edges: Set[Tuple], old_edges: Set[Tuple] ) -> Set[Tuple]: + # + # frm: TODO: Documentation: Update / expand the documentation for this routine. + # + # This only operates on cut-edges and not on all of the + # edges in a partition. A "cut-edge" is an edge that spans two districts. + # """ Updater function that responds to the flow of edges between different partitions. @@ -115,6 +135,6 @@ def cut_edges(partition): # Edges that weren't cut, but now are cut # We sort the tuples to make sure we don't accidentally end # up with both (4,5) and (5,4) (for example) in it - new, obsolete = new_cuts(partition), obsolete_cuts(partition) + new, obsolete = _new_cuts(partition), _obsolete_cuts(partition) return (parent["cut_edges"] | new) - obsolete diff --git a/gerrychain/updaters/election.py b/gerrychain/updaters/election.py index 2415de42..ab9efc66 100644 --- a/gerrychain/updaters/election.py +++ b/gerrychain/updaters/election.py @@ -48,12 +48,12 @@ class Election: :type name: str :ivar parties: A list of the names of the parties in the election. :type parties: List[str] - :ivar columns: A list of the columns in the graph's node data that hold + :ivar node_attribute_names: A list of the node_attribute_names in the graph's node data that hold the vote totals for each party.
- :type columns: List[str] - :ivar parties_to_columns: A dictionary mapping party names to the columns + :type node_attribute_names: List[str] + :ivar party_names_to_node_attribute_names: A dictionary mapping party names to the node_attribute_names in the graph's node data that hold the vote totals for that party. - :type parties_to_columns: Dict[str, str] + :type party_names_to_node_attribute_names: Dict[str, str] :ivar tallies: A dictionary mapping party names to :class:`DataTally` objects that manage the vote totals for that party. :type tallies: Dict[str, DataTally] @@ -68,54 +68,110 @@ class Election: def __init__( self, name: str, - parties_to_columns: Union[Dict, List], + party_names_to_node_attribute_names: Union[Dict, List], alias: Optional[str] = None, ) -> None: """ :param name: The name of the election. (e.g. "2008 Presidential") :type name: str - :param parties_to_columns: A dictionary matching party names to their - data columns, either as actual columns (list-like, indexed by nodes) + :param party_names_to_node_attribute_names: A mapping from the name of a + party to the name of an attribute of a node that contains the + vote totals for that party. This parameter can be either a list or + a dict. If a list, then the name of the party and the name of the + node attribute are the same, for instance: ["Dem", "Rep"] would + indicate that the "Dem" party vote totals are stored in the "Dem" + node attribute. If a dict, then there are two possibilities: + + the dictionary matches party names to their data, where each + value is either an actual column of data (list-like, indexed by + nodes) or a string key for the node attribute that holds the + party's vote totals. - :type parties_to_columns: Union[Dict, List] + :type party_names_to_node_attribute_names: Union[Dict, List] :param alias: Alias that the election is registered under in the Partition's dictionary of updaters.
:type alias: Optional[str], optional """ + self.name = name if alias is None: alias = name self.alias = alias - if isinstance(parties_to_columns, dict): - self.parties = list(parties_to_columns.keys()) - self.columns = list(parties_to_columns.values()) - self.parties_to_columns = parties_to_columns - elif isinstance(parties_to_columns, list): - self.parties = parties_to_columns - self.columns = parties_to_columns - self.parties_to_columns = dict(zip(self.parties, self.columns)) + # Canonicalize "parties", "node_attribute_names", and "party_names_to_node_attribute_names": + # + # "parties" are the names of the parties for purposes of reporting + # "node_attribute_names" are the names of the node attributes storing vote counts + # "party_names_to_node_attribute_names" is a mapping from one to the other + # + if isinstance(party_names_to_node_attribute_names, dict): + self.parties = list(party_names_to_node_attribute_names.keys()) + self.node_attribute_names = list(party_names_to_node_attribute_names.values()) + self.party_names_to_node_attribute_names = party_names_to_node_attribute_names + elif isinstance(party_names_to_node_attribute_names, list): + # name of the party and the attribute name containing value is the same + self.parties = party_names_to_node_attribute_names + self.node_attribute_names = party_names_to_node_attribute_names + self.party_names_to_node_attribute_names = dict(zip(self.parties, self.node_attribute_names)) else: - raise TypeError("Election expects parties_to_columns to be a dict or list") + raise TypeError("Election expects party_names_to_node_attribute_names to be a dict or list") + + # frm: TODO: Documentation: Migration: Using node_ids to vote tally maps... + # + # A DataTally used to support a first parameter that was either a string + # or a dict. + # + # The idea was that in most cases, the values to be tallied would be present + # as the values of attributes associated with nodes, so it made sense to just + # provide the name of the attribute (a string) to identify what to tally. + # + # However, the code also supported providing an explicit mapping from node_id + # to the value to be tallied (a dict). This was useful for testing because + # it allowed for tallying values without having to implement an updater that + # would be based on a node's attribute. It provided a way to map values that + # were not part of the graph to vote totals. + # + # The problem was that when we started using RX for the embedded graph for + # partitions, the node_ids were no longer the same as the ones the user + # specified when creating the (NX) graph. This complicated the logic of + # having an explicit mapping from node_id to a value to be tallied - to + # make this work the code would have needed to translate the node_ids into + # the internal RX node_ids. + # + # The decision was made (Fred and Peter) that this extra complexity was not + # worth the trouble, so we now disallow passing in an explicit mapping (dict). + # + + for party in self.parties: + if isinstance(self.party_names_to_node_attribute_names[party], dict): + raise Exception("Election: Using a map from node_id to vote totals is no longer permitted") self.tallies = { - party: DataTally(self.parties_to_columns[party], party) + party: DataTally(self.party_names_to_node_attribute_names[party], party) for party in self.parties } self.updater = ElectionUpdater(self) + def _initialize_self(self, partition): + + # Create DataTally objects for each party in the election. 
+ self.tallies = { + # For each party, create a DataTally using the string for the node + # attribute where that party's vote totals can be found. + party: DataTally(self.party_names_to_node_attribute_names[party], party) + for party in self.parties + } + def __str__(self): - return "Election '{}' with vote totals for parties {} from columns {}.".format( - self.name, str(self.parties), str(self.columns) + return "Election '{}' with vote totals for parties {} from node_attribute_names {}.".format( + self.name, str(self.parties), str(self.node_attribute_names) ) def __repr__(self): - return "Election(parties={}, columns={}, alias={})".format( - str(self.parties), str(self.columns), str(self.alias) + return "Election(parties={}, node_attribute_names={}, alias={})".format( + str(self.parties), str(self.node_attribute_names), str(self.alias) ) def __call__(self, *args, **kwargs): @@ -167,6 +223,10 @@ def get_previous_values(self, partition) -> Dict[str, Dict[int, float]]: return previous_totals_for_party +# frm: TODO: Refactoring: This routine, get_percents(), is only ever used inside ElectionResults. +# +# Why is it not defined as an internal function inside ElectionResults? +# def get_percents(counts: Dict, totals: Dict) -> Dict: """ :param counts: A dictionary mapping each part in a partition to the diff --git a/gerrychain/updaters/flows.py b/gerrychain/updaters/flows.py index bf00096b..0dd3f9f1 100644 --- a/gerrychain/updaters/flows.py +++ b/gerrychain/updaters/flows.py @@ -2,6 +2,9 @@ import functools from typing import Dict, Set, Tuple, Callable +# frm: TODO: Documentation: This file needs documentation / comments!!! +# +# Peter agrees... @functools.lru_cache(maxsize=2) def neighbor_flips(partition) -> Set[Tuple]: @@ -36,6 +39,13 @@ def flows_from_changes(old_partition, new_partition) -> Dict: `{'in': , 'out': }`. :rtype: Dict """ + + # frm: TODO: Code: ???: Grok why there is a test for: source != target + # + # It would seem to me that it would be a logic bug if there + # was a "flip" that did not in fact change the partition mapping... + # + flows = collections.defaultdict(create_flow) for node, target in new_partition.flips.items(): source = old_partition.assignment.mapping[node] @@ -129,18 +139,40 @@ def compute_edge_flows(partition) -> Dict: new_source = assignment.mapping[node] new_target = assignment.mapping[neighbor] - cut = new_source != new_target - was_cut = old_source != old_target + # frm: Clarification to myself... + # A "cut edge" is one where the nodes in the edge are assigned to different + # districts. So, how does a flip change whether an edge is a cut edge? There + # are three possibilities: 1) the edge goes from not being a cut edge to being + # a cut edge, 2) the edge goes from being a cut edge to not being a cut edge, + # and 3) the edge was a cut edge before and is still a cut edge after the flip, + # but the partition assignments to one or the other nodes in the edge changes. + # + # That is what the if-stmt below is doing - determining which of the three + # cases each flip falls into. It updates the flows accordingly... 
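+            # frm: Worked micro-example (hypothetical labels): take an edge (a, b)
+            # with a assigned to part 1 and b to part 2, and a flip that moves
+            # a into part 2:
+            #
+            #     before: assignment = {a: 1, b: 2}   -> (a, b) is a cut edge
+            #     after:  assignment = {a: 2, b: 2}   -> (a, b) is not a cut edge
+            #
+            # That is case 2) above: the edge flows "out" of both old parts, 1 and 2.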
+            #
+            cut = new_source != new_target      # after flip, the edge is a cut edge
+            was_cut = old_source != old_target  # before flip, the edge was a cut edge

             if not cut and was_cut:
+                # was a cut edge before, but now is not, so it flows out of both old parts
                 edge_flows[old_target]["out"].add(edge)
                 edge_flows[old_source]["out"].add(edge)
             elif cut and not was_cut:
+                # was not a cut edge before, but now is, so it flows into both new parts
                 edge_flows[new_target]["in"].add(edge)
                 edge_flows[new_source]["in"].add(edge)
             elif cut and was_cut:
                 # If an edge was cut and still is cut, we need to make sure the
                 # edge is listed under the correct parts.
+                # frm: Clarification to myself... Python set subtraction removes
+                #      from the set on the left any members of the set on the right,
+                #      so no_longer_incident_parts determines whether old_target or
+                #      old_source is no longer incident to the edge - that is, whether
+                #      the assignment of one of the old parts has changed - and if so,
+                #      the edge has gone "out" of that part. Doing the subtraction the
+                #      other way around finds the newly incident parts, which is used
+                #      to update the "in" flows.
+                #
                 no_longer_incident_parts = {old_target, old_source} - {
                     new_target,
                     new_source,
@@ -151,6 +183,7 @@ def compute_edge_flows(partition) -> Dict:
             newly_incident_parts = {new_target, new_source} - {old_target, old_source}
             for part in newly_incident_parts:
                 edge_flows[part]["in"].add(edge)
+
     return edge_flows
diff --git a/gerrychain/updaters/locality_split_scores.py b/gerrychain/updaters/locality_split_scores.py
index 28720b2f..1a211ed7 100644
--- a/gerrychain/updaters/locality_split_scores.py
+++ b/gerrychain/updaters/locality_split_scores.py
@@ -1,9 +1,19 @@
 # Imports
 from collections import defaultdict, Counter
+# frm: TODO: Refactoring: Remove dependence on NetworkX.
+#      The only use is:
+#          pieces += nx.number_connected_components(subgraph)
 import networkx as nx
 import math
 from typing import List

+# frm: TODO: Performance: Do performance testing and improve the performance of these routines.
+#
+# Peter made the comment in a PR that we should make this code more efficient:
+#
+#     A note on this file: A ton of the code in here is inefficient. This was
+#     made 6 years ago and hasn't really been touched since then other than
+#     when I was doing an overhaul on many of the doc strings.

 class LocalitySplits:
     """
@@ -134,8 +144,26 @@ def __init__(

     def __call__(self, partition):

+        # frm: TODO: Refactoring: LocalitySplits: Figure out how this is intended to be used...
+        #
+        # Not quite sure why it is better to have a "__call__()" method instead of a
+        # get_scores(self) method, but whatever...
+        #
+        # This routine just computes the requested scores (specified in the constructor).
+        # It stashes those scores as a data member on the class and also returns them
+        # to the caller.
+        #
+        # This all seems kind of misguided to me - and there is no instance of this being
+        # used in the gerrychain code except in a test, so I am not sure how it is intended
+        # to be used.
+        #
+        # Probably need to look at some user code that Peter sent me to see if anyone
+        # actually uses this and if so, how...
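+        # frm: Note: the likely answer is that updaters are required to be
+        #      callables so that a Partition can invoke them by alias. A usage
+        #      sketch (the constructor arguments here are hypothetical):
+        #
+        #          splits = LocalitySplits("splits", "county", "TOTPOP", ["num_pieces"])
+        #          partition = Partition(graph, assignment, updaters={"splits": splits})
+        #          partition["splits"]   # invokes splits.__call__(partition)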
+        #
         if self.localities == []:
-            self.localitydict = dict(partition.graph.nodes(data=self.col_id))
+            self.localitydict = {}
+            for node_id in partition.graph.node_indices:
+                self.localitydict[node_id] = partition.graph.node_data(node_id)[self.col_id]
+
             self.localities = set(list(self.localitydict.values()))

         locality_splits = {
@@ -154,23 +182,48 @@ def __call__(self, partition):
         allowed_pieces = {}
         totpop = 0
-        for node in partition.graph.nodes:
-            totpop += partition.graph.nodes[node][self.pop_col]
+        for node_id in partition.graph.node_indices:
+            # frm: TODO: Refactoring: Once you have a partition, you cannot change the total
+            #      population in the Partition, so why don't we cache the total population
+            #      as a data member in Partition?
+            #
+            #      Peter agreed that this would be a good thing to do.

+            totpop += partition.graph.node_data(node_id)[self.pop_col]

+        # frm: TODO: Refactoring: Ditto with num_districts - isn't this a constant once you
+        #      create a Partition?
+        #
+        #      Peter agreed that this would be a good thing to do.

         num_districts = len(partition.assignment.parts.keys())

+        # Compute the total population for each locality and then the number of "allowed pieces"
         for loc in self.localities:
-            sg = partition.graph.subgraph(
-                n
-                for n, v in partition.graph.nodes(data=True)
-                if v[self.col_id] == loc
-            )
-
-            pop = 0
-            for n in sg.nodes():
-                pop += sg.nodes[n][self.pop_col]
+            # frm: TODO: Refactoring: The code below just calculates the total population for a
+            #      set of nodes. This sounds like a good candidate for a utility function.
+            #      See if this logic is repeated elsewhere... Note also that it rebuilds the
+            #      whole locality_population dict once per locality, even though a single
+            #      pass over the graph would suffice.

+            # Compute the population associated with each locality
+            the_graph = partition.graph
+            locality_population = {}  # dict mapping locality name to population in that locality
+            for node_id in the_graph.node_indices:
+                locality_name = the_graph.node_data(node_id)[self.col_id]
+                locality_pop = the_graph.node_data(node_id)[self.pop_col]
+                if locality_name not in locality_population:
+                    locality_population[locality_name] = locality_pop
+                else:
+                    locality_population[locality_name] += locality_pop

+            # frm: TODO: Refactoring: Peter commented (in PR) that this is another thing that
+            #      could be cached so we didn't recompute it over and over...
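+            # frm: Example with made-up numbers: if totpop = 64 and num_districts = 8,
+            #      then ideal_population_per_district = 8, and a locality with a
+            #      population of 12 gets ceil(12 / 8) = 2 allowed pieces.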
+ ideal_population_per_district = totpop / num_districts + + # Compute the number of "allowed pieces" for each locality + allowed_pieces = {} + for locality_name in locality_population.keys(): + pop_for_locality = locality_population[locality_name] + allowed_pieces[locality_name] = math.ceil(pop_for_locality / ideal_population_per_district) - allowed_pieces[loc] = math.ceil(pop / (totpop / num_districts)) self.allowed_pieces = allowed_pieces for s in self.scores: @@ -227,8 +280,8 @@ def num_pieces(self, partition) -> int: """ locality_intersections = {} - for n in partition.graph.nodes(): - locality = partition.graph.nodes[n][self.col_id] + for n in partition.graph.node_indices: + locality = partition.graph.node_data(n)[self.col_id] if locality not in locality_intersections: locality_intersections[locality] = set( [partition.assignment.mapping[n]] @@ -243,11 +296,11 @@ def num_pieces(self, partition) -> int: [ x for x in partition.parts[d] - if partition.graph.nodes[x][self.col_id] == locality + if partition.graph.node_data(x)[self.col_id] == locality ] ) - pieces += nx.number_connected_components(subgraph) + pieces += subgraph.num_connected_components() return pieces def naked_boundary(self, partition) -> int: @@ -380,7 +433,7 @@ def symmetric_entropy(self, partition) -> float: # IN PROGRESS vtds = district_dict[district] locality_pop = {k: 0 for k in self.localities} for vtd in vtds: - locality_pop[self.localitydict[vtd]] += partition.graph.nodes[vtd][ + locality_pop[self.localitydict[vtd]] += partition.graph.node_data(vtd)[ self.pop_col ] district_dict[district] = locality_pop diff --git a/gerrychain/updaters/spanning_trees.py b/gerrychain/updaters/spanning_trees.py index 307daf40..9150151f 100644 --- a/gerrychain/updaters/spanning_trees.py +++ b/gerrychain/updaters/spanning_trees.py @@ -4,7 +4,6 @@ import math import numpy -import networkx from typing import Dict @@ -25,7 +24,7 @@ def _num_spanning_trees_in_district(partition, district: int) -> int: :rtype: int """ graph = partition.subgraphs[district] - laplacian = networkx.laplacian_matrix(graph) + laplacian = partition.graph.laplacian_matrix() L = numpy.delete(numpy.delete(laplacian.todense(), 0, 0), 1, 1) return math.exp(numpy.linalg.slogdet(L)[1]) diff --git a/gerrychain/updaters/tally.py b/gerrychain/updaters/tally.py index 97305b38..91c05407 100644 --- a/gerrychain/updaters/tally.py +++ b/gerrychain/updaters/tally.py @@ -19,6 +19,16 @@ class DataTally: :type alias: str """ + # frm: TODO: Code: Check to see if DataTally used for data that is NOT attribute of a node + # + # The comment above indicates that you can use a DataTally for data that is not stored + # as an attribute of a node. Check to see if it is ever actually used that way. If so, + # then update the documentation above to state the use cases for adding up data that is + # NOT stored as a node attribute... + # + # It appears that some tests use the ability to specify tallies that do not involve a + # node attribute, but it is not clear if any "real" code does that... 
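+    # frm: For reference, the two shapes of the first argument (the attribute
+    #      name and values here are hypothetical):
+    #
+    #          DataTally("TOTPOP", alias="population")      # tally a node attribute
+    #          DataTally({0: 10.0, 1: 20.0}, alias="pop")   # tally an explicit
+    #                                                       # {node_id: value} map
+    #
+    #      Per the TODO above, the second (dict) form may only survive in tests.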
+
     __slots__ = ["data", "alias", "_call"]

     def __init__(self, data: Union[Dict, pandas.Series, str], alias: str) -> None:
@@ -35,23 +45,40 @@ def __init__(self, data: Union[Dict, pandas.Series, str], alias: str) -> None:
         self.alias = alias

         def initialize_tally(partition):

+            # If the "data" passed in was a string, then interpret that string
+            # as the name of a node attribute in the graph, and construct
+            # a dict of the form: {node_id: node_attribute_value}
+            #
+            # If not, then assume that the "data" passed in is already of the
+            # form: {node_id: data_value}

             if isinstance(self.data, str):
-                nodes = partition.graph.nodes

+                # The "data" passed in was a string, so replace it with
+                # a dict of {node_id: attribute_value of the node}
+                graph = partition.graph
+                node_ids = partition.graph.node_indices
                 attribute = self.data
-                self.data = {node: nodes[node][attribute] for node in nodes}
+                self.data = {node_id: graph.node_data(node_id)[attribute] for node_id in node_ids}

             tally = collections.defaultdict(int)
-            for node, part in partition.assignment.items():
-                add = self.data[node]
+            for node_id, part in partition.assignment.items():
+                add = self.data[node_id]
+                # Note: math.isnan() will raise an exception if the value passed in is not
+                # numeric, so there is no need for a separate check that the value is
+                # numeric - that test is implicit in math.isnan()
+                #
                 if math.isnan(add):
                     warnings.warn(
-                        "ignoring nan encountered at node '{}' for attribute '{}'".format(
-                            node, self.alias
+                        "ignoring nan encountered at node_id '{}' for attribute '{}'".format(
+                            node_id, self.alias
                         )
                     )
                 else:
                     tally[part] += add
+
             return dict(tally)

         @on_flow(initialize_tally, alias=alias)
@@ -167,7 +194,7 @@ def _update_tally(self, partition):
         return new_tally

     def _get_tally_from_node(self, partition, node):
-        return sum(partition.graph.lookup(node, field) for field in self.fields)
+        return sum(partition.graph.node_data(node)[field] for field in self.fields)


def compute_out_flow(graph, fields: Union[str, List[str]], flow: Dict) -> int:
@@ -185,7 +212,7 @@ def compute_out_flow(graph, fields: Union[str, List[str]], flow: Dict) -> int:
    :returns: The sum of the "field" attribute of nodes in the "out" set of the flow.
    :rtype: int
    """
-    return sum(graph.lookup(node, field) for node in flow["out"] for field in fields)
+    return sum(graph.node_data(node)[field] for node in flow["out"] for field in fields)


def compute_in_flow(graph, fields: Union[str, List[str]], flow: Dict) -> int:
@@ -203,4 +230,4 @@ def compute_in_flow(graph, fields: Union[str, List[str]], flow: Dict) -> int:
    :returns: The sum of the "field" attribute of nodes in the "in" set of the flow.
    :rtype: int
    """
-    return sum(graph.lookup(node, field) for node in flow["in"] for field in fields)
+    return sum(graph.node_data(node)[field] for node in flow["in"] for field in fields)
diff --git a/tests/README.txt b/tests/README.txt
new file mode 100644
index 00000000..b2a0024e
--- /dev/null
+++ b/tests/README.txt
@@ -0,0 +1,6 @@
+This folder contains tests (and subfolders that also contain tests).
+
+As a convention (at least for now - October 2025), tests that Fred
+Mueller adds will be named test_frm_...
Eventually the names should
+be changed to have the "_frm" deleted, but for now it helps
+identify big changes in testing.
diff --git a/tests/_foo/do_laplacian.py b/tests/_foo/do_laplacian.py
new file mode 100644
index 00000000..110a15d6
--- /dev/null
+++ b/tests/_foo/do_laplacian.py
@@ -0,0 +1,48 @@
+
+import networkx as nx
+import rustworkx as rx
+import numpy as np
+from graph import Graph
+import tree as gc_tree
+
+# Create an RX graph (replace with your graph data)
+rx_graph = rx.PyGraph()
+rx_graph.add_nodes_from([0, 1, 2, 3])
+rx_graph.add_edges_from([(0, 1, "data"), (0, 2, "data"), (1, 2, "data"), (2, 3, "data")])
+
+# 1. Get the adjacency matrix
+adj_matrix = rx.adjacency_matrix(rx_graph)
+
+# 2. Calculate the degree matrix (simplified for this example)
+degree_matrix = np.diag([rx_graph.degree(node) for node in rx_graph.node_indices()])
+
+# 3. Calculate the Laplacian matrix
+rx_laplacian_matrix = degree_matrix - adj_matrix
+
+# frm: TODO: Debugging: Remove Debugging Code
+
+# print("RX Adjacency Matrix:")
+# print(adj_matrix)
+
+# print("\nRX Degree Matrix:")
+# print(degree_matrix)
+
+# print("\nRX Laplacian Matrix:")
+# print(rx_laplacian_matrix)
+
+# print("type of RX laplacian_matrix is: ", type(rx_laplacian_matrix))
+
+# Create an NX graph (replace with your graph data)
+nx_graph = nx.Graph([(0, 1), (0, 2), (1, 2), (2, 3)])
+nx_laplacian_matrix = nx.laplacian_matrix(nx_graph)
+
+# print("\nNX Laplacian Matrix:")
+# print(nx_laplacian_matrix)
+
+# print("type of NX laplacian_matrix is: ", type(nx_laplacian_matrix))
+
+gc_nx_graph = Graph.from_nx_graph(nx_graph)
+gc_rx_graph = Graph.from_rx_graph(rx_graph)
+
+# print("\ngc_laplacian(nx_graph) is: ", gc_tree.gc_laplacian_matrix(gc_nx_graph))
+# print("\ngc_laplacian(rx_graph) is: ", gc_tree.gc_laplacian_matrix(gc_rx_graph))
diff --git a/tests/_perf_tests/perf_test.py b/tests/_perf_tests/perf_test.py
new file mode 100644
index 00000000..40a2a8b4
--- /dev/null
+++ b/tests/_perf_tests/perf_test.py
@@ -0,0 +1,63 @@
+# Code copied from the GerryChain User Guide / Tutorial:
+
+import matplotlib.pyplot as plt
+from gerrychain import (Partition, Graph, MarkovChain,
+                        updaters, constraints, accept)
+from gerrychain.proposals import recom
+from gerrychain.constraints import contiguous
+from functools import partial
+import pandas
+
+import cProfile
+
+# Set the random seed so that the results are reproducible!
+import random
+
+def main():
+
+    random.seed(2024)
+    graph = Graph.from_json("./gerrymandria.json")
+
+    my_updaters = {
+        "population": updaters.Tally("TOTPOP"),
+        "cut_edges": updaters.cut_edges
+    }
+
+    initial_partition = Partition(
+        graph,
+        assignment="district",
+        updaters=my_updaters
+    )
+
+    # This should be 8 since each district has 1 person in it.
+    # Note that the key "population" corresponds to the population updater
+    # that we defined above and not to the population column in the json file.
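+    # frm: i.e. gerrymandria.json has 64 nodes with TOTPOP = 1 each and
+    #      8 districts, so 64 / 8 = 8.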
+ ideal_population = sum(initial_partition["population"].values()) / len(initial_partition) + + proposal = partial( + recom, + pop_col="TOTPOP", + pop_target=ideal_population, + epsilon=0.01, + node_repeats=2 + ) + + recom_chain = MarkovChain( + proposal=proposal, + constraints=[contiguous], + accept=accept.always_accept, + initial_state=initial_partition, + total_steps=40 + ) + + assignments = list(recom_chain) + + assignment_list = [] + + for i, item in enumerate(recom_chain): + print(f"Finished step {i+1}/{len(recom_chain)}", end="\r") + assignment_list.append(item.assignment) + +if __name__ == "__main__": + cProfile.run('main()', sort='tottime') + diff --git a/tests/_perf_tests/perf_test2.py b/tests/_perf_tests/perf_test2.py new file mode 100644 index 00000000..7d2e78ae --- /dev/null +++ b/tests/_perf_tests/perf_test2.py @@ -0,0 +1,87 @@ + +import matplotlib.pyplot as plt +from gerrychain import (GeographicPartition, Partition, Graph, MarkovChain, + proposals, updaters, constraints, accept, Election) +from gerrychain.proposals import recom +from functools import partial +import pandas +import gerrychain + +import sys +import cProfile + +def main(): + + graph = Graph.from_json("./PA_VTDs.json") + + elections = [ + Election("SEN10", {"Democratic": "SEN10D", "Republican": "SEN10R"}), + Election("SEN12", {"Democratic": "USS12D", "Republican": "USS12R"}), + Election("SEN16", {"Democratic": "T16SEND", "Republican": "T16SENR"}), + Election("PRES12", {"Democratic": "PRES12D", "Republican": "PRES12R"}), + Election("PRES16", {"Democratic": "T16PRESD", "Republican": "T16PRESR"}) + ] + + # Population updater, for computing how close to equality the district + # populations are. "TOTPOP" is the population column from our shapefile. + my_updaters = {"population": updaters.Tally("TOT_POP", alias="population")} + + # Election updaters, for computing election results using the vote totals + # from our shapefile. + election_updaters = {election.name: election for election in elections} + my_updaters.update(election_updaters) + + initial_partition = GeographicPartition( + graph, + assignment="2011_PLA_1", # This identifies the district plan in 2011 + updaters=my_updaters + ) + + # The ReCom proposal needs to know the ideal population for the districts so that + # we can improve speed by bailing early on unbalanced partitions. + + ideal_population = sum(initial_partition["population"].values()) / len(initial_partition) + + # We use functools.partial to bind the extra parameters (pop_col, pop_target, epsilon, node_repeats) + # of the recom proposal. 
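+    # frm: Note: functools.partial pre-binds arguments, e.g.
+    #
+    #          from functools import partial
+    #          def power(base, exponent):
+    #              return base ** exponent
+    #          square = partial(power, exponent=2)
+    #          square(3)   # -> 9
+    #
+    #      so "proposal" below is a one-argument function of the partition.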
+ proposal = partial( + recom, + pop_col="TOT_POP", + pop_target=ideal_population, + epsilon=0.02, + node_repeats=2 + ) + + def cut_edges_length(p): + return len(p["cut_edges"]) + + compactness_bound = constraints.UpperBound( + cut_edges_length, + 2*len(initial_partition["cut_edges"]) + ) + + pop_constraint = constraints.within_percent_of_ideal_population(initial_partition, 0.02) + + print("About to call MarkovChain", file=sys.stderr) + + chain = MarkovChain( + proposal=proposal, + constraints=[ + pop_constraint, + compactness_bound + ], + accept=accept.always_accept, + initial_state=initial_partition, + total_steps=1000 + ) + + print("Done with calling MarkovChain", file=sys.stderr) + + print("About to get all assignments from the chain", file=sys.stderr) + assignments = list(chain) + print("Done getting all assignments from the chain", file=sys.stderr) + + +if __name__ == "__main__": + cProfile.run('main()', sort='tottime') + diff --git a/tests/conftest.py b/tests/conftest.py index 501906ab..e5797be9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,8 +15,8 @@ def three_by_three_grid(): 3 4 5 6 7 8 """ - graph = Graph() - graph.add_edges_from( + nx_graph = nx.Graph() + nx_graph.add_edges_from( [ (0, 1), (0, 3), @@ -32,8 +32,7 @@ def three_by_three_grid(): (7, 8), ] ) - return graph - + return Graph.from_networkx(nx_graph) @pytest.fixture def four_by_five_grid_for_opt(): @@ -47,8 +46,8 @@ def four_by_five_grid_for_opt(): # 5 6 7 8 9 # 0 1 2 3 4 - graph = Graph() - graph.add_nodes_from( + nx_graph = nx.Graph() + nx_graph.add_nodes_from( [ (0, {"population": 10, "opt_value": 1, "MVAP": 2}), (1, {"population": 10, "opt_value": 1, "MVAP": 2}), @@ -73,7 +72,7 @@ def four_by_five_grid_for_opt(): ] ) - graph.add_edges_from( + nx_graph.add_edges_from( [ (0, 1), (0, 5), @@ -109,26 +108,35 @@ def four_by_five_grid_for_opt(): ] ) - return graph + return Graph.from_networkx(nx_graph) @pytest.fixture def graph_with_random_data_factory(three_by_three_grid): + def factory(columns): graph = three_by_three_grid attach_random_data(graph, columns) return graph + # A closure - will add random data (int) to all nodes for each named "column" return factory +# frm: TODO: Refactoring: This routine is only ever used immediately above in def factory(columns). +# Is it part of the external API? If not, then it should be moved inside +# the graph_with_random_data_factory() routine def attach_random_data(graph, columns): for node in graph.nodes: for col in columns: - graph.nodes[node][col] = random.randint(1, 1000) + graph.node_data(node)[col] = random.randint(1, 1000) @pytest.fixture +# frm: ???: Why not just always use three_by_three_grid? At least that gives +# the reader an idea of how many nodes there are? What is the +# value of just having a generic "graph" test fixture??? +# def graph(three_by_three_grid): return three_by_three_grid diff --git a/tests/constraints/test_contiguity.py b/tests/constraints/test_contiguity.py index b94f1b5d..87ade1d7 100644 --- a/tests/constraints/test_contiguity.py +++ b/tests/constraints/test_contiguity.py @@ -3,6 +3,7 @@ def test_contiguous_components(graph): + partition = Partition(graph, {0: 1, 1: 1, 2: 1, 3: 2, 4: 2, 5: 2, 6: 1, 7: 1, 8: 1}) components = contiguous_components(partition) @@ -10,9 +11,14 @@ def test_contiguous_components(graph): assert len(components[1]) == 2 assert len(components[2]) == 1 - assert set(frozenset(g.nodes) for g in components[1]) == { + # Confirm that the appropriate connected subgraphs were found. 
Note that we need + # to compare against the original node_ids, since RX node_ids change every time + # you create a subgraph. + + assert set(frozenset(g.original_nx_node_ids_for_set(g.nodes)) for g in components[1]) == { frozenset([0, 1, 2]), frozenset([6, 7, 8]), } - - assert set(components[2][0].nodes) == {3, 4, 5} + assert set(frozenset(g.original_nx_node_ids_for_set(g.nodes)) for g in components[2]) == { + frozenset([3, 4, 5]), + } diff --git a/tests/constraints/test_validity.py b/tests/constraints/test_validity.py index 7bc3e01d..c0817e6c 100644 --- a/tests/constraints/test_validity.py +++ b/tests/constraints/test_validity.py @@ -5,7 +5,7 @@ import pytest from gerrychain.constraints import (SelfConfiguringLowerBound, Validator, - contiguous, contiguous_bfs, + contiguous, districts_within_tolerance, no_vanishing_districts, single_flip_contiguous, @@ -50,15 +50,16 @@ def discontiguous_partition(discontiguous_partition_with_flips): def test_contiguous_with_contiguity_no_flips_is_true(contiguous_partition): assert contiguous(contiguous_partition) assert single_flip_contiguous(contiguous_partition) - assert contiguous_bfs(contiguous_partition) + assert contiguous(contiguous_partition) def test_contiguous_with_contiguity_flips_is_true(contiguous_partition_with_flips): contiguous_partition, test_flips = contiguous_partition_with_flips + # frm: TODO: Testing: Figure out whether test_flips are in original node_ids or internal RX node_ids contiguous_partition2 = contiguous_partition.flip(test_flips) assert contiguous(contiguous_partition2) assert single_flip_contiguous(contiguous_partition2) - assert contiguous_bfs(contiguous_partition2) + assert contiguous(contiguous_partition2) def test_discontiguous_with_contiguous_no_flips_is_false(discontiguous_partition): @@ -71,14 +72,15 @@ def test_discontiguous_with_single_flip_contiguous_no_flips_is_false( assert not single_flip_contiguous(discontiguous_partition) -def test_discontiguous_with_contiguous_bfs_no_flips_is_false(discontiguous_partition): - assert not contiguous_bfs(discontiguous_partition) +def test_discontiguous_with_contiguous_no_flips_is_false(discontiguous_partition): + assert not contiguous(discontiguous_partition) def test_discontiguous_with_contiguous_flips_is_false( discontiguous_partition_with_flips ): part, test_flips = discontiguous_partition_with_flips + # frm: TODO: Testing: Figure out whether test_flips are in original node_ids or internal RX node_ids discontiguous_partition2 = part.flip(test_flips) assert not contiguous(discontiguous_partition2) @@ -91,16 +93,18 @@ def test_discontiguous_with_single_flip_contiguous_flips_is_false( discontiguous_partition_with_flips ): part, test_flips = discontiguous_partition_with_flips + # frm: TODO: Testing: Figure out whether test_flips are in original node_ids or internal RX node_ids discontiguous_partition2 = part.flip(test_flips) assert not single_flip_contiguous(discontiguous_partition2) -def test_discontiguous_with_contiguous_bfs_flips_is_false( +def test_discontiguous_with_contiguous_flips_is_false( discontiguous_partition_with_flips ): part, test_flips = discontiguous_partition_with_flips + # frm: TODO: Testing: Figure out whether test_flips are in original node_ids or internal RX node_ids discontiguous_partition2 = part.flip(test_flips) - assert not contiguous_bfs(discontiguous_partition2) + assert not contiguous(discontiguous_partition2) def test_districts_within_tolerance_returns_false_if_districts_are_not_within_tolerance(): diff --git a/tests/frm_tests/README.txt 
b/tests/frm_tests/README.txt new file mode 100644 index 00000000..037dbec1 --- /dev/null +++ b/tests/frm_tests/README.txt @@ -0,0 +1,6 @@ +This directory contains tests added by Fred Mueller +for the work he is doing / did to convert GerryChain +from using NetworkX to using RustworkX. + +Eventually if his code becomes the new thing, these +tests should be rolled into the normal tests directory. diff --git a/tests/frm_tests/__init__.py b/tests/frm_tests/__init__.py new file mode 100644 index 00000000..b0fefc57 --- /dev/null +++ b/tests/frm_tests/__init__.py @@ -0,0 +1,2 @@ + +print("__init__.py invoked") diff --git a/tests/frm_tests/frm_regression_test.README.txt b/tests/frm_tests/frm_regression_test.README.txt new file mode 100644 index 00000000..a1051154 --- /dev/null +++ b/tests/frm_tests/frm_regression_test.README.txt @@ -0,0 +1,12 @@ +I created a regression test based on the User Guide code so that +I could make changes and quickly test whether they affected +user code. + +The 3 files that I added are: + + * frm_regression_test.py + * Code copied from the User Guide + * gerrymandria.json + * JSON for the graph used in the regression test + * frm_regression_test.README.txt + * This file diff --git a/tests/frm_tests/gerrymandria.json b/tests/frm_tests/gerrymandria.json new file mode 100644 index 00000000..a6ca2fae --- /dev/null +++ b/tests/frm_tests/gerrymandria.json @@ -0,0 +1,1641 @@ +{ + "directed": false, + "multigraph": false, + "graph": [], + "nodes": [ + { + "TOTPOP": 1, + "x": 0, + "y": 0, + "county": "1", + "district": "1", + "precinct": 0, + "muni": "1", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "2", + "id": 0 + }, + { + "TOTPOP": 1, + "x": 0, + "y": 1, + "county": "1", + "district": "1", + "precinct": 1, + "muni": "1", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "2", + "id": 1 + }, + { + "TOTPOP": 1, + "x": 0, + "y": 2, + "county": "1", + "district": "1", + "precinct": 2, + "muni": "5", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "2", + "id": 2 + }, + { + "TOTPOP": 1, + "x": 0, + "y": 3, + "county": "1", + "district": "1", + "precinct": 3, + "muni": "5", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "2", + "id": 3 + }, + { + "TOTPOP": 1, + "x": 0, + "y": 4, + "county": "3", + "district": "1", + "precinct": 4, + "muni": "9", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "2", + "id": 4 + }, + { + "TOTPOP": 1, + "x": 0, + "y": 5, + "county": "3", + "district": "1", + "precinct": 5, + "muni": "9", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "2", + "id": 5 + }, + { + "TOTPOP": 1, + "x": 0, + "y": 6, + "county": "3", + "district": "1", + "precinct": 6, + "muni": "13", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "2", + "id": 6 + }, + { + "TOTPOP": 1, + "x": 0, + "y": 7, + "county": "3", + "district": "1", + "precinct": 7, + "muni": "13", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "2", + "id": 7 + }, + { + "TOTPOP": 1, + "x": 1, + "y": 0, + "county": "1", + "district": "2", + "precinct": 8, + "muni": "1", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "2", + "id": 8 + }, + { + "TOTPOP": 1, + "x": 1, + "y": 1, + "county": "1", + "district": "2", + "precinct": 9, + "muni": "1", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "2", + "id": 9 + }, + { + "TOTPOP": 1, + "x": 1, + "y": 2, + "county": "1", + "district": "2", + "precinct": 10, + "muni": "5", + "boundary_node": false, + "boundary_perim": 
0, + "water_dist": "2", + "id": 10 + }, + { + "TOTPOP": 1, + "x": 1, + "y": 3, + "county": "1", + "district": "2", + "precinct": 11, + "muni": "5", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "2", + "id": 11 + }, + { + "TOTPOP": 1, + "x": 1, + "y": 4, + "county": "3", + "district": "2", + "precinct": 12, + "muni": "9", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "2", + "id": 12 + }, + { + "TOTPOP": 1, + "x": 1, + "y": 5, + "county": "3", + "district": "2", + "precinct": 13, + "muni": "9", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "2", + "id": 13 + }, + { + "TOTPOP": 1, + "x": 1, + "y": 6, + "county": "3", + "district": "2", + "precinct": 14, + "muni": "13", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 14 + }, + { + "TOTPOP": 1, + "x": 1, + "y": 7, + "county": "3", + "district": "2", + "precinct": 15, + "muni": "13", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "4", + "id": 15 + }, + { + "TOTPOP": 1, + "x": 2, + "y": 0, + "county": "1", + "district": "3", + "precinct": 16, + "muni": "2", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "4", + "id": 16 + }, + { + "TOTPOP": 1, + "x": 2, + "y": 1, + "county": "1", + "district": "3", + "precinct": 17, + "muni": "2", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 17 + }, + { + "TOTPOP": 1, + "x": 2, + "y": 2, + "county": "1", + "district": "3", + "precinct": 18, + "muni": "6", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "2", + "id": 18 + }, + { + "TOTPOP": 1, + "x": 2, + "y": 3, + "county": "1", + "district": "3", + "precinct": 19, + "muni": "6", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "2", + "id": 19 + }, + { + "TOTPOP": 1, + "x": 2, + "y": 4, + "county": "3", + "district": "3", + "precinct": 20, + "muni": "10", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 20 + }, + { + "TOTPOP": 1, + "x": 2, + "y": 5, + "county": "3", + "district": "3", + "precinct": 21, + "muni": "10", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 21 + }, + { + "TOTPOP": 1, + "x": 2, + "y": 6, + "county": "3", + "district": "3", + "precinct": 22, + "muni": "14", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 22 + }, + { + "TOTPOP": 1, + "x": 2, + "y": 7, + "county": "3", + "district": "3", + "precinct": 23, + "muni": "14", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "4", + "id": 23 + }, + { + "TOTPOP": 1, + "x": 3, + "y": 0, + "county": "1", + "district": "4", + "precinct": 24, + "muni": "2", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "4", + "id": 24 + }, + { + "TOTPOP": 1, + "x": 3, + "y": 1, + "county": "1", + "district": "4", + "precinct": 25, + "muni": "2", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 25 + }, + { + "TOTPOP": 1, + "x": 3, + "y": 2, + "county": "1", + "district": "4", + "precinct": 26, + "muni": "6", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 26 + }, + { + "TOTPOP": 1, + "x": 3, + "y": 3, + "county": "1", + "district": "4", + "precinct": 27, + "muni": "6", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 27 + }, + { + "TOTPOP": 1, + "x": 3, + "y": 4, + "county": "3", + "district": "4", + "precinct": 28, + "muni": "10", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 28 + }, + { + "TOTPOP": 1, 
+ "x": 3, + "y": 5, + "county": "3", + "district": "4", + "precinct": 29, + "muni": "10", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "1", + "id": 29 + }, + { + "TOTPOP": 1, + "x": 3, + "y": 6, + "county": "3", + "district": "4", + "precinct": 30, + "muni": "14", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "1", + "id": 30 + }, + { + "TOTPOP": 1, + "x": 3, + "y": 7, + "county": "3", + "district": "4", + "precinct": 31, + "muni": "14", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "1", + "id": 31 + }, + { + "TOTPOP": 1, + "x": 4, + "y": 0, + "county": "2", + "district": "5", + "precinct": 32, + "muni": "3", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "3", + "id": 32 + }, + { + "TOTPOP": 1, + "x": 4, + "y": 1, + "county": "2", + "district": "5", + "precinct": 33, + "muni": "3", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "3", + "id": 33 + }, + { + "TOTPOP": 1, + "x": 4, + "y": 2, + "county": "2", + "district": "5", + "precinct": 34, + "muni": "7", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "3", + "id": 34 + }, + { + "TOTPOP": 1, + "x": 4, + "y": 3, + "county": "2", + "district": "5", + "precinct": 35, + "muni": "7", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 35 + }, + { + "TOTPOP": 1, + "x": 4, + "y": 4, + "county": "4", + "district": "5", + "precinct": 36, + "muni": "11", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 36 + }, + { + "TOTPOP": 1, + "x": 4, + "y": 5, + "county": "4", + "district": "5", + "precinct": 37, + "muni": "11", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "1", + "id": 37 + }, + { + "TOTPOP": 1, + "x": 4, + "y": 6, + "county": "4", + "district": "5", + "precinct": 38, + "muni": "15", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "1", + "id": 38 + }, + { + "TOTPOP": 1, + "x": 4, + "y": 7, + "county": "4", + "district": "5", + "precinct": 39, + "muni": "15", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "1", + "id": 39 + }, + { + "TOTPOP": 1, + "x": 5, + "y": 0, + "county": "2", + "district": "6", + "precinct": 40, + "muni": "3", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "3", + "id": 40 + }, + { + "TOTPOP": 1, + "x": 5, + "y": 1, + "county": "2", + "district": "6", + "precinct": 41, + "muni": "3", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "3", + "id": 41 + }, + { + "TOTPOP": 1, + "x": 5, + "y": 2, + "county": "2", + "district": "6", + "precinct": 42, + "muni": "7", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "3", + "id": 42 + }, + { + "TOTPOP": 1, + "x": 5, + "y": 3, + "county": "2", + "district": "6", + "precinct": 43, + "muni": "7", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "4", + "id": 43 + }, + { + "TOTPOP": 1, + "x": 5, + "y": 4, + "county": "4", + "district": "6", + "precinct": 44, + "muni": "11", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "1", + "id": 44 + }, + { + "TOTPOP": 1, + "x": 5, + "y": 5, + "county": "4", + "district": "6", + "precinct": 45, + "muni": "11", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "1", + "id": 45 + }, + { + "TOTPOP": 1, + "x": 5, + "y": 6, + "county": "4", + "district": "6", + "precinct": 46, + "muni": "15", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "1", + "id": 46 + }, + { + "TOTPOP": 1, + "x": 5, + "y": 7, + "county": "4", + "district": "6", 
+ "precinct": 47, + "muni": "15", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "1", + "id": 47 + }, + { + "TOTPOP": 1, + "x": 6, + "y": 0, + "county": "2", + "district": "7", + "precinct": 48, + "muni": "4", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "3", + "id": 48 + }, + { + "TOTPOP": 1, + "x": 6, + "y": 1, + "county": "2", + "district": "7", + "precinct": 49, + "muni": "4", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "3", + "id": 49 + }, + { + "TOTPOP": 1, + "x": 6, + "y": 2, + "county": "2", + "district": "7", + "precinct": 50, + "muni": "8", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "3", + "id": 50 + }, + { + "TOTPOP": 1, + "x": 6, + "y": 3, + "county": "2", + "district": "7", + "precinct": 51, + "muni": "8", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "3", + "id": 51 + }, + { + "TOTPOP": 1, + "x": 6, + "y": 4, + "county": "4", + "district": "7", + "precinct": 52, + "muni": "12", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "3", + "id": 52 + }, + { + "TOTPOP": 1, + "x": 6, + "y": 5, + "county": "4", + "district": "7", + "precinct": 53, + "muni": "12", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "1", + "id": 53 + }, + { + "TOTPOP": 1, + "x": 6, + "y": 6, + "county": "4", + "district": "7", + "precinct": 54, + "muni": "16", + "boundary_node": false, + "boundary_perim": 0, + "water_dist": "1", + "id": 54 + }, + { + "TOTPOP": 1, + "x": 6, + "y": 7, + "county": "4", + "district": "7", + "precinct": 55, + "muni": "16", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "1", + "id": 55 + }, + { + "TOTPOP": 1, + "x": 7, + "y": 0, + "county": "2", + "district": "8", + "precinct": 56, + "muni": "4", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "3", + "id": 56 + }, + { + "TOTPOP": 1, + "x": 7, + "y": 1, + "county": "2", + "district": "8", + "precinct": 57, + "muni": "4", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "3", + "id": 57 + }, + { + "TOTPOP": 1, + "x": 7, + "y": 2, + "county": "2", + "district": "8", + "precinct": 58, + "muni": "8", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "3", + "id": 58 + }, + { + "TOTPOP": 1, + "x": 7, + "y": 3, + "county": "2", + "district": "8", + "precinct": 59, + "muni": "8", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "3", + "id": 59 + }, + { + "TOTPOP": 1, + "x": 7, + "y": 4, + "county": "4", + "district": "8", + "precinct": 60, + "muni": "12", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "3", + "id": 60 + }, + { + "TOTPOP": 1, + "x": 7, + "y": 5, + "county": "4", + "district": "8", + "precinct": 61, + "muni": "12", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "1", + "id": 61 + }, + { + "TOTPOP": 1, + "x": 7, + "y": 6, + "county": "4", + "district": "8", + "precinct": 62, + "muni": "16", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "1", + "id": 62 + }, + { + "TOTPOP": 1, + "x": 7, + "y": 7, + "county": "4", + "district": "8", + "precinct": 63, + "muni": "16", + "boundary_node": true, + "boundary_perim": 1, + "water_dist": "1", + "id": 63 + } + ], + "adjacency": [ + [ + { + "id": 8 + }, + { + "id": 1 + } + ], + [ + { + "id": 0 + }, + { + "id": 9 + }, + { + "id": 2 + } + ], + [ + { + "id": 1 + }, + { + "id": 10 + }, + { + "id": 3 + } + ], + [ + { + "id": 2 + }, + { + "id": 11 + }, + { + "id": 4 + } + ], + [ + { + "id": 3 + }, + { + "id": 12 + }, + { + "id": 5 + } + ], + [ 
+ { + "id": 4 + }, + { + "id": 13 + }, + { + "id": 6 + } + ], + [ + { + "id": 5 + }, + { + "id": 14 + }, + { + "id": 7 + } + ], + [ + { + "id": 6 + }, + { + "id": 15 + } + ], + [ + { + "id": 0 + }, + { + "id": 16 + }, + { + "id": 9 + } + ], + [ + { + "id": 1 + }, + { + "id": 8 + }, + { + "id": 17 + }, + { + "id": 10 + } + ], + [ + { + "id": 2 + }, + { + "id": 9 + }, + { + "id": 18 + }, + { + "id": 11 + } + ], + [ + { + "id": 3 + }, + { + "id": 10 + }, + { + "id": 19 + }, + { + "id": 12 + } + ], + [ + { + "id": 4 + }, + { + "id": 11 + }, + { + "id": 20 + }, + { + "id": 13 + } + ], + [ + { + "id": 5 + }, + { + "id": 12 + }, + { + "id": 21 + }, + { + "id": 14 + } + ], + [ + { + "id": 6 + }, + { + "id": 13 + }, + { + "id": 22 + }, + { + "id": 15 + } + ], + [ + { + "id": 7 + }, + { + "id": 14 + }, + { + "id": 23 + } + ], + [ + { + "id": 8 + }, + { + "id": 24 + }, + { + "id": 17 + } + ], + [ + { + "id": 9 + }, + { + "id": 16 + }, + { + "id": 25 + }, + { + "id": 18 + } + ], + [ + { + "id": 10 + }, + { + "id": 17 + }, + { + "id": 26 + }, + { + "id": 19 + } + ], + [ + { + "id": 11 + }, + { + "id": 18 + }, + { + "id": 27 + }, + { + "id": 20 + } + ], + [ + { + "id": 12 + }, + { + "id": 19 + }, + { + "id": 28 + }, + { + "id": 21 + } + ], + [ + { + "id": 13 + }, + { + "id": 20 + }, + { + "id": 29 + }, + { + "id": 22 + } + ], + [ + { + "id": 14 + }, + { + "id": 21 + }, + { + "id": 30 + }, + { + "id": 23 + } + ], + [ + { + "id": 15 + }, + { + "id": 22 + }, + { + "id": 31 + } + ], + [ + { + "id": 16 + }, + { + "id": 32 + }, + { + "id": 25 + } + ], + [ + { + "id": 17 + }, + { + "id": 24 + }, + { + "id": 33 + }, + { + "id": 26 + } + ], + [ + { + "id": 18 + }, + { + "id": 25 + }, + { + "id": 34 + }, + { + "id": 27 + } + ], + [ + { + "id": 19 + }, + { + "id": 26 + }, + { + "id": 35 + }, + { + "id": 28 + } + ], + [ + { + "id": 20 + }, + { + "id": 27 + }, + { + "id": 36 + }, + { + "id": 29 + } + ], + [ + { + "id": 21 + }, + { + "id": 28 + }, + { + "id": 37 + }, + { + "id": 30 + } + ], + [ + { + "id": 22 + }, + { + "id": 29 + }, + { + "id": 38 + }, + { + "id": 31 + } + ], + [ + { + "id": 23 + }, + { + "id": 30 + }, + { + "id": 39 + } + ], + [ + { + "id": 24 + }, + { + "id": 40 + }, + { + "id": 33 + } + ], + [ + { + "id": 25 + }, + { + "id": 32 + }, + { + "id": 41 + }, + { + "id": 34 + } + ], + [ + { + "id": 26 + }, + { + "id": 33 + }, + { + "id": 42 + }, + { + "id": 35 + } + ], + [ + { + "id": 27 + }, + { + "id": 34 + }, + { + "id": 43 + }, + { + "id": 36 + } + ], + [ + { + "id": 28 + }, + { + "id": 35 + }, + { + "id": 44 + }, + { + "id": 37 + } + ], + [ + { + "id": 29 + }, + { + "id": 36 + }, + { + "id": 45 + }, + { + "id": 38 + } + ], + [ + { + "id": 30 + }, + { + "id": 37 + }, + { + "id": 46 + }, + { + "id": 39 + } + ], + [ + { + "id": 31 + }, + { + "id": 38 + }, + { + "id": 47 + } + ], + [ + { + "id": 32 + }, + { + "id": 48 + }, + { + "id": 41 + } + ], + [ + { + "id": 33 + }, + { + "id": 40 + }, + { + "id": 49 + }, + { + "id": 42 + } + ], + [ + { + "id": 34 + }, + { + "id": 41 + }, + { + "id": 50 + }, + { + "id": 43 + } + ], + [ + { + "id": 35 + }, + { + "id": 42 + }, + { + "id": 51 + }, + { + "id": 44 + } + ], + [ + { + "id": 36 + }, + { + "id": 43 + }, + { + "id": 52 + }, + { + "id": 45 + } + ], + [ + { + "id": 37 + }, + { + "id": 44 + }, + { + "id": 53 + }, + { + "id": 46 + } + ], + [ + { + "id": 38 + }, + { + "id": 45 + }, + { + "id": 54 + }, + { + "id": 47 + } + ], + [ + { + "id": 39 + }, + { + "id": 46 + }, + { + "id": 55 + } + ], + [ + { + "id": 40 + }, + { + "id": 56 + }, + { + "id": 49 + } + ], + [ 
+ { + "id": 41 + }, + { + "id": 48 + }, + { + "id": 57 + }, + { + "id": 50 + } + ], + [ + { + "id": 42 + }, + { + "id": 49 + }, + { + "id": 58 + }, + { + "id": 51 + } + ], + [ + { + "id": 43 + }, + { + "id": 50 + }, + { + "id": 59 + }, + { + "id": 52 + } + ], + [ + { + "id": 44 + }, + { + "id": 51 + }, + { + "id": 60 + }, + { + "id": 53 + } + ], + [ + { + "id": 45 + }, + { + "id": 52 + }, + { + "id": 61 + }, + { + "id": 54 + } + ], + [ + { + "id": 46 + }, + { + "id": 53 + }, + { + "id": 62 + }, + { + "id": 55 + } + ], + [ + { + "id": 47 + }, + { + "id": 54 + }, + { + "id": 63 + } + ], + [ + { + "id": 48 + }, + { + "id": 57 + } + ], + [ + { + "id": 49 + }, + { + "id": 56 + }, + { + "id": 58 + } + ], + [ + { + "id": 50 + }, + { + "id": 57 + }, + { + "id": 59 + } + ], + [ + { + "id": 51 + }, + { + "id": 58 + }, + { + "id": 60 + } + ], + [ + { + "id": 52 + }, + { + "id": 59 + }, + { + "id": 61 + } + ], + [ + { + "id": 53 + }, + { + "id": 60 + }, + { + "id": 62 + } + ], + [ + { + "id": 54 + }, + { + "id": 61 + }, + { + "id": 63 + } + ], + [ + { + "id": 55 + }, + { + "id": 62 + } + ] + ] +} \ No newline at end of file diff --git a/tests/frm_tests/test_frm_make_graph.py b/tests/frm_tests/test_frm_make_graph.py new file mode 100644 index 00000000..cee8d27f --- /dev/null +++ b/tests/frm_tests/test_frm_make_graph.py @@ -0,0 +1,291 @@ +################################################################ +# +# frm: This file was copied from test_make_graph.py (to make +# use of their fixtures. It should eventually evolve into +# a reasonable test of additional functions added by me +# to gerrychain.graph +# +################################################################ + +import pathlib +from tempfile import TemporaryDirectory +from unittest.mock import patch + +import geopandas as gp +import pandas +import pytest +from shapely.geometry import Polygon +from pyproj import CRS + +from gerrychain.graph import Graph +from gerrychain.graph.geo import GeometryError + +import networkx + + +@pytest.fixture +def geodataframe(): + a = Polygon([(0, 0), (0, 1), (1, 1), (1, 0)]) + b = Polygon([(0, 1), (0, 2), (1, 2), (1, 1)]) + c = Polygon([(1, 0), (1, 1), (2, 1), (2, 0)]) + d = Polygon([(1, 1), (1, 2), (2, 2), (2, 1)]) + df = gp.GeoDataFrame({"ID": ["a", "b", "c", "d"], "geometry": [a, b, c, d]}) + df.crs = "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs" + return df + + +@pytest.fixture +def gdf_with_data(geodataframe): + geodataframe["data"] = list(range(len(geodataframe))) + geodataframe["data2"] = list(range(len(geodataframe))) + return geodataframe + + +@pytest.fixture +def geodataframe_with_boundary(): + """ + abe + ade + ace + """ + a = Polygon([(0, 0), (0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (1, 1), (1, 0)]) + b = Polygon([(1, 2), (1, 3), (2, 3), (2, 2)]) + c = Polygon([(1, 0), (1, 1), (2, 1), (2, 0)]) + d = Polygon([(1, 1), (1, 2), (2, 2), (2, 1)]) + e = Polygon([(2, 0), (2, 1), (2, 2), (2, 3), (3, 3), (3, 2), (3, 1), (3, 0)]) + df = gp.GeoDataFrame({"ID": ["a", "b", "c", "d", "e"], "geometry": [a, b, c, d, e]}) + df.crs = "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs" + return df + + +@pytest.fixture +def shapefile(gdf_with_data): + with TemporaryDirectory() as d: + filepath = pathlib.Path(d) / "temp.shp" + filename = str(filepath.absolute()) + gdf_with_data.to_file(filename) + yield filename + + +@pytest.fixture +def target_file(): + with TemporaryDirectory() as d: + filepath = pathlib.Path(d) / "temp.shp" + filename = str(filepath.absolute()) + yield filename + + +def 
test_add_data_to_graph_can_handle_column_names_that_start_with_numbers(): + nx_graph = networkx.Graph([("01", "02"), ("02", "03"), ("03", "01")]) + df = pandas.DataFrame({"16SenDVote": [20, 30, 50], "node": ["01", "02", "03"]}) + df = df.set_index("node") + + graph = Graph.from_networkx(nx_graph) + graph.add_data(df, ["16SenDVote"]) + + assert nx_graph.nodes["01"]["16SenDVote"] == 20 + assert nx_graph.nodes["02"]["16SenDVote"] == 30 + assert nx_graph.nodes["03"]["16SenDVote"] == 50 + + assert graph.node_data("01")["16SenDVote"] == 20 + assert graph.node_data("02")["16SenDVote"] == 30 + assert graph.node_data("03")["16SenDVote"] == 50 + + +def test_join_can_handle_right_index(): + nx_graph = networkx.Graph([("01", "02"), ("02", "03"), ("03", "01")]) + df = pandas.DataFrame({"16SenDVote": [20, 30, 50], "node": ["01", "02", "03"]}) + + graph = Graph.from_networkx(nx_graph) + + graph.join(df, ["16SenDVote"], right_index="node") + + assert graph.node_data("01")["16SenDVote"] == 20 + assert graph.node_data("02")["16SenDVote"] == 30 + assert graph.node_data("03")["16SenDVote"] == 50 + + +def test_make_graph_from_dataframe_creates_graph(geodataframe): + graph = Graph.from_geodataframe(geodataframe) + assert isinstance(graph, Graph) + + +def test_make_graph_from_dataframe_preserves_df_index(geodataframe): + df = geodataframe.set_index("ID") + graph = Graph.from_geodataframe(df) + assert set(graph.nodes) == {"a", "b", "c", "d"} + + +def test_make_graph_from_dataframe_gives_correct_graph(geodataframe): + df = geodataframe.set_index("ID") + graph = Graph.from_geodataframe(df) + + assert edge_set_equal( + set(graph.edges), {("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")} + ) + + +def test_make_graph_works_with_queen_adjacency(geodataframe): + df = geodataframe.set_index("ID") + graph = Graph.from_geodataframe(df, adjacency="queen") + + assert edge_set_equal( + set(graph.edges), + {("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("a", "d"), ("b", "c")}, + ) + + +def test_can_pass_queen_or_rook_strings_to_control_adjacency(geodataframe): + df = geodataframe.set_index("ID") + graph = Graph.from_geodataframe(df, adjacency="queen") + + assert edge_set_equal( + set(graph.edges), + {("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("a", "d"), ("b", "c")}, + ) + + +def test_can_insist_on_not_reprojecting(geodataframe): + df = geodataframe.set_index("ID") + graph = Graph.from_geodataframe(df, reproject=False) + + for node in ("a", "b", "c", "d"): + assert graph.node_data(node)["area"] == 1 + + for edge in graph.edges: + assert graph.edge_data(edge)["shared_perim"] == 1 + + +def test_does_not_reproject_by_default(geodataframe): + df = geodataframe.set_index("ID") + graph = Graph.from_geodataframe(df) + + for node in ("a", "b", "c", "d"): + assert graph.node_data(node)["area"] == 1.0 + + for edge in graph.edges: + assert graph.edge_data(edge)["shared_perim"] == 1.0 + + +def test_reproject(geodataframe): + # I don't know what the areas and perimeters are in UTM for these made-up polygons, + # but I'm pretty sure they're not 1. 
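+    # frm: Note: reproject=True re-projects the lon/lat polygons into a UTM CRS
+    #      before "area" and "shared_perim" are computed, which is why the unit
+    #      squares from the fixture should no longer measure exactly 1.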
+ df = geodataframe.set_index("ID") + graph = Graph.from_geodataframe(df, reproject=True) + + for node in ("a", "b", "c", "d"): + assert graph.node_data(node)["area"] != 1 + + for edge in graph.edges: + assert graph.edge_data(edge)["shared_perim"] != 1 + + +def test_identifies_boundary_nodes(geodataframe_with_boundary): + df = geodataframe_with_boundary.set_index("ID") + graph = Graph.from_geodataframe(df) + + for node in ("a", "b", "c", "e"): + assert graph.node_data(node)["boundary_node"] + assert not graph.node_data("d")["boundary_node"] + + +def test_computes_boundary_perims(geodataframe_with_boundary): + df = geodataframe_with_boundary.set_index("ID") + graph = Graph.from_geodataframe(df, reproject=False) + + expected = {"a": 5, "e": 5, "b": 1, "c": 1} + + for node, value in expected.items(): + assert graph.node_data(node)["boundary_perim"] == value + + +def edge_set_equal(set1, set2): + return {(y, x) for x, y in set1} | set1 == {(y, x) for x, y in set2} | set2 + + +def test_from_file_adds_all_data_by_default(shapefile): + graph = Graph.from_file(shapefile) + + # data dictionaries for all of the nodes + all_node_data = [graph.node_data(node_id) for node_id in graph.node_indices] + + assert all("data" in node_data for node_data in all_node_data) + assert all("data2" in node_data for node_data in all_node_data) + + +def test_from_file_and_then_to_json_does_not_error(shapefile, target_file): + graph = Graph.from_file(shapefile) + + # Even the geometry column is copied to the graph + + # data dictionaries for all of the nodes + all_node_data = [graph.node_data(node_id) for node_id in graph.node_indices] + + assert all("geometry" in node_data for node_data in all_node_data) + + graph.to_json(target_file) + + +def test_from_file_and_then_to_json_with_geometries(shapefile, target_file): + graph = Graph.from_file(shapefile) + + # data dictionaries for all of the nodes + all_node_data = [graph.node_data(node_id) for node_id in graph.node_indices] + + # Even the geometry column is copied to the graph + assert all("geometry" in node_data for node_data in all_node_data) + + # frm: ??? Does anything check that the file is actually written? 
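+    # frm: One possible follow-up check (a sketch, not currently in the test):
+    #
+    #          import json, os
+    #          assert os.path.exists(target_file)
+    #          with open(target_file) as f:
+    #              json.load(f)   # raises if the file is empty or not valid JSON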
+ graph.to_json(target_file, include_geometries_as_geojson=True) + + +def test_graph_warns_for_islands(): + nx_graph = networkx.Graph() + nx_graph.add_node(0) + graph = Graph.from_networkx(nx_graph) + + with pytest.warns(Warning): + graph.warn_for_islands() + + +def test_graph_raises_if_crs_is_missing_when_reprojecting(geodataframe): + geodataframe.crs = None + + with pytest.raises(ValueError): + Graph.from_geodataframe(geodataframe, reproject=True) + + +def test_raises_geometry_error_if_invalid_geometry(shapefile): + with patch("gerrychain.graph.geo.explain_validity") as explain: + explain.return_value = "Invalid geometry" + with pytest.raises(GeometryError): + Graph.from_file(shapefile, ignore_errors=False) + + +def test_can_ignore_errors_while_making_graph(shapefile): + with patch("gerrychain.graph.geo.explain_validity") as explain: + explain.return_value = "Invalid geometry" + assert Graph.from_file(shapefile, ignore_errors=True) + + +def test_data_and_geometry(gdf_with_data): + df = gdf_with_data + graph = Graph.from_geodataframe(df, cols_to_add=["data","data2"]) + assert graph.geometry is df.geometry + #graph.add_data(df[["data"]]) + assert (graph.data["data"] == df["data"]).all() + #graph.add_data(df[["data2"]]) + assert list(graph.data.columns) == ["data", "data2"] + + +def test_make_graph_from_dataframe_has_crs(gdf_with_data): + graph = Graph.from_geodataframe(gdf_with_data) + assert CRS.from_json(graph.graph["crs"]).equals(gdf_with_data.crs) + +def test_make_graph_from_shapefile_has_crs(shapefile): + graph = Graph.from_file(shapefile) + df = gp.read_file(shapefile) + assert CRS.from_json(graph.graph["crs"]).equals(df.crs) + + + diff --git a/tests/frm_tests/test_frm_nx_rx_graph.py b/tests/frm_tests/test_frm_nx_rx_graph.py new file mode 100644 index 00000000..84f0c32c --- /dev/null +++ b/tests/frm_tests/test_frm_nx_rx_graph.py @@ -0,0 +1,232 @@ +####################################################### +# Overview of test_frm_nx_rx_graph.py +####################################################### +""" + +A collection of tests to verify that the new GerryChain +Graph object works the same with NetworkX and RustworkX. + + +""" + +import matplotlib.pyplot as plt +from gerrychain import (Partition, Graph, MarkovChain, + updaters, constraints, accept) +from gerrychain.proposals import recom +from gerrychain.constraints import contiguous +from functools import partial +import pandas + +import os +import rustworkx as rx +import networkx as nx + +import pytest + + +# Set the random seed so that the results are reproducible! +import random +random.seed(2024) + +############################################################ +# Create Graph Objects - both direct NX.Graph and RX.PyGraph +# objects and two GerryChain Graph objects that embed the +# NX and RX graphs. 
+############################################################ + +@pytest.fixture(scope="module") +def json_file_path(): + # Get path to the JSON containing graph data + test_file_path = os.path.abspath(__file__) + cur_directory = os.path.dirname(test_file_path) + path_for_json_file = os.path.join(cur_directory, "gerrymandria.json") + # print("json file is: ", json_file_path) + return path_for_json_file + +@pytest.fixture(scope="module") +def gerrychain_nx_graph(json_file_path): + # Create an NX based Graph object from the JSON + graph = Graph.from_json(json_file_path) + print("gerrychain_nx_graph: len(graph): ", len(graph)) + return(graph) + +@pytest.fixture(scope="module") +def nx_graph(gerrychain_nx_graph): + # Fetch the NX graph object from inside the Graph object + return gerrychain_nx_graph.get_nx_graph() + +@pytest.fixture(scope="module") +def rx_graph(nx_graph): + # Create an RX graph object from NX, preserving node data + return rx.networkx_converter(nx_graph, keep_attributes=True) + +@pytest.fixture(scope="module") +def gerrychain_rx_graph(rx_graph): + # Create a Graph object with an RX graph inside + return Graph.from_rustworkx(rx_graph) + +################## +# Start of Tests +################## + +def test_sanity(): + # frm: if you call pytest with -rP, then it will show stdout for tests + print("test_sanity(): called") + assert True + +def test_nx_rx_sets_of_nodes_agree(nx_graph, rx_graph): + nx_set_of_nodes = set(nx_graph.nodes()) + rx_set_of_nodes = set(rx_graph.node_indices()) + assert nx_set_of_nodes == rx_set_of_nodes + +def test_nx_rx_node_data_agree(gerrychain_nx_graph, gerrychain_rx_graph): + nx_data_dict = gerrychain_nx_graph.node_data(1) + rx_data_dict = gerrychain_rx_graph.node_data(1) + assert nx_data_dict == rx_data_dict + +def test_nx_rx_node_indices_agree(gerrychain_nx_graph, gerrychain_rx_graph): + nx_node_indices = gerrychain_nx_graph.node_indices + rx_node_indices = gerrychain_rx_graph.node_indices + assert nx_node_indices == rx_node_indices + +def test_nx_rx_edges_agree(gerrychain_nx_graph, gerrychain_rx_graph): + # TODO: Testing: Rethink this test. At the moment it relies on the edge_list() + # call which does not exist on a GerryChain Graph object + # being handled by RX through clever __getattr__ stuff. + # I think we should add an edge_list() method to GerryChain Graph + nx_edges = set(gerrychain_nx_graph.edges) + rx_edges = set(gerrychain_rx_graph.edge_list()) + assert nx_edges == rx_edges + +def test_nx_rx_node_neighbors_agree(gerrychain_nx_graph, gerrychain_rx_graph): + for i in gerrychain_nx_graph: + # Need to convert to set, because ordering of neighbor nodes differs in the lists + nx_neighbors = set(gerrychain_nx_graph.neighbors(i)) + rx_neighbors = set(gerrychain_rx_graph.neighbors(i)) + assert nx_neighbors == rx_neighbors + +def test_nx_rx_subgraphs_agree(gerrychain_nx_graph, gerrychain_rx_graph): + subgraph_nodes = [0,1,2,3,4,5] # TODO: Testing: make this a fixture dependent on JSON graph + nx_subgraph = gerrychain_nx_graph.subgraph(subgraph_nodes) + rx_subgraph = gerrychain_rx_graph.subgraph(subgraph_nodes) + for node_id in nx_subgraph: + nx_node_data = nx_subgraph.node_data(node_id) + rx_node_data = rx_subgraph.node_data(node_id) + assert nx_node_data == rx_node_data + # frm: TODO: Testing: This does not test that the rx_subgraph has the exact same number of + # nodes as the nx_subgraph, and it does not test edge data... 
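+    # frm: A sketch of the missing assertions (assuming len() and edge_list()
+    #      behave on subgraphs the way they do on the full Graph objects above):
+    #
+    #          assert len(nx_subgraph) == len(rx_subgraph)
+    #          assert set(nx_subgraph.edges) == set(rx_subgraph.edge_list())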
+ +def test_nx_rx_degrees_agree(gerrychain_nx_graph, gerrychain_rx_graph): + # Verify that the degree of each node agrees between NX and RX versions + nx_degrees = { + node_id: gerrychain_nx_graph.degree(node_id) for node_id in gerrychain_nx_graph.node_indices + } + rx_degrees = { + node_id: gerrychain_rx_graph.degree(node_id) for node_id in gerrychain_rx_graph.node_indices + } + for node_id in gerrychain_nx_graph.node_indices: + assert nx_degrees[node_id] == rx_degrees[node_id] + + +""" +frm: TODO: Testing: + + * Functions: + * predecessors() + * successors() + * is_connected() + * laplacian_matrix() + * normalized_laplacian_matrix() + * neighbors() + I think this has been done for both NX and RX + * networkx.generators.lattice.grid_2d_graph() + * nx.to_dict_of_lists() + * nx.tree.minimum_spanning_tree() + * nx.number_connected_components() + * nx.set_edge_attributes() + * nx.set_node_attributes() + + * Syntax: + * graph.edges + NX - note that edges and edges() do exactly the same thing. They return + an EdgeView of a list of edges with edge_id being a tuple indicating + the start and end node_ids for the edge. + Need to find out how edges and edges() is used in the code to know + what the right thing to do is for RX - that is, what aspect of an + EdgeView is used in the code? Is a set of tuples OK? + * graph.nodes + NX returns a NodeView with the node_ids for the nodes + RX does not have a "nodes" attribute, but it does have a nodes() + method which does something different. It returns a list (indexed + by node_id) of the data associated with nodes. + So, I need to see how Graph.nodes is used in the code to see what the + right way is to support it in RX. + * graph.nodes[node_id] + returns data dictionary for the node + * graph.nodes[node_id][attr_id] + returns the value for the given attribute for that node's data + * graph.add_edge() + Done differently in NX and RX + * graph.degree + * graph.subgraph + * for edge in graph.edge_indices: + graph.edges[edge]["weight"] = random.random() + In RX, assigning the weight to an edge is done differently... + Note that edge_indices currently works exactly the same for both + NX and RX - returning a set of tuples (for edges). However, + assigning a value to the "weight" attribute of an edge is done + differently... + * islands() +""" + + + + + + +### my_updaters = { +### "population": updaters.Tally("TOTPOP"), +### "cut_edges": updaters.cut_edges +### } +### +### initial_partition = Partition( +### nx_graph, +### assignment="district", +### updaters=my_updaters +### ) +### +### # This should be 8 since each district has 1 person in it. +### # Note that the key "population" corresponds to the population updater +### # that we defined above and not with the population column in the json file. 
+### ideal_population = sum(initial_partition["population"].values()) / len(initial_partition)
+###
+### proposal = partial(
+###     recom,
+###     pop_col="TOTPOP",
+###     pop_target=ideal_population,
+###     epsilon=0.01,
+###     node_repeats=2
+### )
+###
+### print("Got proposal")
+###
+### recom_chain = MarkovChain(
+###     proposal=proposal,
+###     constraints=[contiguous],
+###     accept=accept.always_accept,
+###     initial_state=initial_partition,
+###     total_steps=40
+### )
+###
+### print("Set up Markov Chain")
+###
+### assignment_list = []
+###
+### for i, item in enumerate(recom_chain):
+###     print(f"Finished step {i+1}/{len(recom_chain)}")
+###     assignment_list.append(item.assignment)
+###
+### print("Enumerated the chain: number of entries in list is: ", len(assignment_list))
+###
+### def test_success():
+###     assert len(assignment_list) == 40
diff --git a/tests/frm_tests/test_frm_regression.py b/tests/frm_tests/test_frm_regression.py
new file mode 100644
index 00000000..29c4a897
--- /dev/null
+++ b/tests/frm_tests/test_frm_regression.py
@@ -0,0 +1,77 @@
+###############################################################
+#
+# frm: Overview of test_frm_regression.py
+#
+# This code was copied from the GerryChain User Guide / Tutorial as a way
+# to have a functional test that exercises the overall logic of GerryChain.
+#
+# It is NOT comprehensive, but it does get all the way to executing
+# a chain.
+#
+# It is a quick and dirty way to make sure I haven't really screwed things up ;-)
+#
+
+import matplotlib.pyplot as plt
+from gerrychain import (Partition, Graph, MarkovChain,
+                        updaters, constraints, accept)
+from gerrychain.proposals import recom
+from gerrychain.constraints import contiguous
+from functools import partial
+import pandas
+
+import os
+
+
+# Set the random seed so that the results are reproducible!
+import random
+random.seed(2024)
+
+
+test_file_path = os.path.abspath(__file__)
+cur_directory = os.path.dirname(test_file_path)
+json_file_path = os.path.join(cur_directory, "gerrymandria.json")
+
+graph = Graph.from_json(json_file_path)
+
+my_updaters = {
+    "population": updaters.Tally("TOTPOP"),
+    "cut_edges": updaters.cut_edges
+}
+
+initial_partition = Partition(
+    graph,
+    assignment="district",
+    updaters=my_updaters
+)
+
+# This should be 8 since each district has 1 person in it.
+# Note that the key "population" corresponds to the population updater
+# that we defined above and not to the population column in the json file.
+ideal_population = sum(initial_partition["population"].values()) / len(initial_partition)
+
+proposal = partial(
+    recom,
+    pop_col="TOTPOP",
+    pop_target=ideal_population,
+    epsilon=0.01,
+    node_repeats=2
+)
+
+recom_chain = MarkovChain(
+    proposal=proposal,
+    constraints=[contiguous],
+    accept=accept.always_accept,
+    initial_state=initial_partition,
+    total_steps=40
+)
+
+assignment_list = []
+
+for i, item in enumerate(recom_chain):
+    print(f"Finished step {i+1}/{len(recom_chain)}")
+    assignment_list.append(item.assignment)
+
+print("Enumerated the chain: number of entries in list is: ", len(assignment_list))
+
+def test_success():
+    assert len(assignment_list) == 40
diff --git a/tests/frm_tests/test_to_networkx_graph.py b/tests/frm_tests/test_to_networkx_graph.py
new file mode 100644
index 00000000..54f85ab3
--- /dev/null
+++ b/tests/frm_tests/test_to_networkx_graph.py
@@ -0,0 +1,195 @@
+#
+# This tests whether the routine, to_networkx_graph(), works
+# properly.
+#
+# This routine extracts a new NetworkX.Graph object from a
+# Graph object that is based on RustworkX.
When we create
+# a Partition object from a NetworkX Graph, we convert the
+# graph to RustworkX for performance. However, users might
+# want access to a NetworkX Graph for a variety of reasons:
+# most commonly because they built their initial graph as a
+# NetworkX Graph, chose node_ids that made sense to them at
+# the time, and would like to access the graph at the end of
+# a MarkovChain run using those same "original" IDs.
+#
+# The extracted NetworkX Graph should have the "original"
+# node_ids, and it should have all of the node and edge
+# data that was in the RustworkX Graph object.
+#
+
+
+import networkx as nx
+
+from gerrychain.graph import Graph
+from gerrychain.partition import Partition
+
+def test_to_networkx_graph_works():
+
+    """
+    Create an NX graph (grid) that looks like this:
+
+        'A' 'B' 'C'
+        'D' 'E' 'F'
+        'G' 'H' 'I'
+    """
+
+    nx_graph = nx.Graph()
+    nx_graph.add_edges_from(
+        [
+            ('A', 'B'),
+            ('A', 'D'),
+            ('B', 'C'),
+            ('B', 'E'),
+            ('C', 'F'),
+            ('D', 'E'),
+            ('D', 'G'),
+            ('E', 'F'),
+            ('E', 'H'),
+            ('F', 'I'),
+            ('G', 'H'),
+            ('H', 'I'),
+        ]
+    )
+
+    # Add some node and edge data to the nx_graph
+
+    graph_node_ids = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
+    for node_id in graph_node_ids:
+        nx_graph.nodes[node_id]["nx-node-data"] = node_id
+
+    # Tag each edge with data equal to its own (node, node) tuple
+    for u, v in nx_graph.edges:
+        nx_graph.edges[(u, v)]["nx-edge-data"] = (u, v)
+
+    graph = Graph.from_networkx(nx_graph)
+
+    """
+    Create a partition assigning each "row" of
+    nodes to a part (district), so the assignment
+    looks like:
+
+        0 0 0
+        1 1 1
+        2 2 2
+    """
+
+    initial_assignment = {
+        'A': 0,
+        'B': 0,
+        'C': 0,
+        'D': 1,
+        'E': 1,
+        'F': 1,
+        'G': 2,
+        'H': 2,
+        'I': 2,
+    }
+
+    # Create a partition
+    partition = Partition(graph, initial_assignment)
+
+    # The partition's graph object has been converted to be based on RX
+    new_graph = partition.graph
+
+    # Add some additional data
+    for node_id in new_graph.node_indices:
+        new_graph.node_data(node_id)["internal-node-data"] = new_graph.original_nx_node_id_for_internal_node_id(node_id)
+    for edge_id in new_graph.edge_indices:
+        new_graph.edge_data(edge_id)["internal-edge-data"] = "internal-edge-data"
+
+    # Now create a second partition by flipping the
+    # nodes in the first row to be in part (district) 1
+
+    """
+    The new partition's mapping of nodes to parts should look like this:
+
+        1 1 1
+        1 1 1
+        2 2 2
+    """
+
+    flips = {'A': 1, 'B': 1, 'C': 1}
+    # Create a new partition based on these flips - using "original" node_ids
+    new_partition = partition.flip(flips, use_original_nx_node_ids=True)
+
+
+    # Get the NX graph after doing the flips.
+ extracted_nx_graph = new_partition.graph.to_networkx_graph() + + # Get the assignments for both the initial partition and the new_partition + + internal_assignment_0 = partition.assignment + internal_assignment_1 = new_partition.assignment + + # convert the internal assignments into "original" node_ids + original_assignment_0 = {} + for node_id, part in internal_assignment_0.items(): + original_nx_node_id = partition.graph.original_nx_node_id_for_internal_node_id(node_id) + original_assignment_0[original_nx_node_id] = part + original_assignment_1 = {} + for node_id, part in internal_assignment_1.items(): + original_nx_node_id = partition.graph.original_nx_node_id_for_internal_node_id(node_id) + original_assignment_1[original_nx_node_id] = part + + # Check that all is well... + + # Check that the initial assignment is the same as the internal RX-based assignment + for node_id, part in initial_assignment.items(): + assert (part == original_assignment_0[node_id]) + + # Check that the flips did what they were supposed to do + for node_id in ['A', 'B', 'C', 'D', 'E', 'F']: + assert(original_assignment_1[node_id] == 1) + for node_id in ['G', 'H', 'I']: + assert(original_assignment_1[node_id] == 2) + + # Check that the node and edge data is present + + # Check node data + for node_id in extracted_nx_graph.nodes: + # Data assigned to the NX-Graph should still be there... + assert( + extracted_nx_graph.nodes[node_id]["nx-node-data"] + == + nx_graph.nodes[node_id]["nx-node-data"] + ) + # Data assigned to the partition's RX-Graph should still be there... + assert( + extracted_nx_graph.nodes[node_id]["internal-node-data"] + == + node_id + ) + # Node_id agrees with __networkx_node__ (created by RX conversion) + assert( + node_id + == + extracted_nx_graph.nodes[node_id]["__networkx_node__"] + ) + + + # Check edge data + for edge in extracted_nx_graph.edges: + assert( + extracted_nx_graph.edges[edge]["nx-edge-data"] + == + nx_graph.edges[edge]["nx-edge-data"] + ) + # Data assigned to the partition's RX-Graph should still be there... + assert( + extracted_nx_graph.edges[edge]["internal-edge-data"] + == + "internal-edge-data" + ) + # compare the extracted_nx_graph's nodes and edges to see if they make sense + # Compare node_data and edge_data + diff --git a/tests/partition/test_partition.py b/tests/partition/test_partition.py index d9e42dbc..ea7e380f 100644 --- a/tests/partition/test_partition.py +++ b/tests/partition/test_partition.py @@ -12,6 +12,13 @@ def test_Partition_can_be_flipped(example_partition): + # frm: TODO: Testing: Verify that this flip is in internal RX-based graph node_ids and not "original" NX node_ids + # + # My guess is that this flip is intended to be in original node_ids but that the test works + # anyways because the assertion uses the same numbers. It should probably be changed to use + # original node_ids and to translate the node_id and part in the assert into internal node_ids + # just to make it crystal clear to anyone following later what is going on... + flip = {1: 2} new_partition = example_partition.flip(flip) assert new_partition.assignment[1] == 2 @@ -45,6 +52,9 @@ def test_Partition_knows_cut_edges_K3(example_partition): def test_propose_random_flip_proposes_a_partition(example_partition): partition = example_partition + + # frm: TODO: Testing: Verify that propose_random_flip() to make sure it is doing the right thing + # wrt RX-based node_ids vs. original node_ids. 
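+    #
+    # frm: A hedged sketch of one such check (it assumes the proposal's
+    #      flips are exposed as proposal.flips, as in test_reproducibility.py,
+    #      and that every internal id appears in partition.graph.node_indices);
+    #      commented out until the TODO is resolved:
+    #
+    #          proposal = propose_random_flip(partition)
+    #          assert all(node in partition.graph.node_indices
+    #                     for node in proposal.flips)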
    proposal = propose_random_flip(partition)
    assert isinstance(proposal, partition.__class__)
@@ -54,10 +64,10 @@ def example_geographic_partition():
    graph = Graph.from_networkx(networkx.complete_graph(3))
    assignment = {0: 1, 1: 1, 2: 2}
    for node in graph.nodes:
-        graph.nodes[node]["boundary_node"] = False
-        graph.nodes[node]["area"] = 1
+        graph.node_data(node)["boundary_node"] = False
+        graph.node_data(node)["area"] = 1
    for edge in graph.edges:
-        graph.edges[edge]["shared_perim"] = 1
+        graph.edge_data(edge)["shared_perim"] = 1
    return GeographicPartition(graph, assignment, None, None, None)
@@ -69,15 +79,32 @@ def test_geographic_partition_can_be_instantiated(example_geographic_partition):
def test_Partition_parts_is_a_dictionary_of_parts_to_nodes(example_partition):
    partition = example_partition
    flip = {1: 2}
-    new_partition = partition.flip(flip)
+    new_partition = partition.flip(flip, use_original_nx_node_ids=True)
    assert all(isinstance(nodes, frozenset) for nodes in new_partition.parts.values())
    assert all(isinstance(nodes, frozenset) for nodes in partition.parts.values())
def test_Partition_has_subgraphs(example_partition):
+    # Test that subgraphs work as intended.
+    # The partition has two parts (districts) with IDs: 1, 2
+    # Part #1 has nodes 0, 1, so the subgraph for part #1 should have these nodes
+    # Part #2 has node 2, so the subgraph for part #2 should have this node
+
+    # Note that the original node_ids are based on the original NX-based graph.
+    # The node_ids in the partition's graph have been changed by the conversion
+    # from NX to RX, so we need to be careful about when to use "original" node_ids
+    # and when to use "internal" RX-based node_ids
    partition = example_partition
-    assert set(partition.subgraphs[1].nodes) == {0, 1}
-    assert set(partition.subgraphs[2].nodes) == {2}
+
+    subgraph_for_part_1 = partition.subgraphs[1]
+    internal_node_id_0 = subgraph_for_part_1.internal_node_id_for_original_nx_node_id(0)
+    internal_node_id_1 = subgraph_for_part_1.internal_node_id_for_original_nx_node_id(1)
+    assert set(partition.subgraphs[1].nodes) == {internal_node_id_0, internal_node_id_1}
+
+    subgraph_for_part_2 = partition.subgraphs[2]
+    internal_node_id = subgraph_for_part_2.internal_node_id_for_original_nx_node_id(2)
+    assert set(partition.subgraphs[2].nodes) == {internal_node_id}
    assert len(list(partition.subgraphs)) == 2
@@ -92,10 +119,20 @@ def test_partition_implements_getattr_for_updater_access(example_partition):
def test_can_be_created_from_a_districtr_file(graph, districtr_plan_file):
    for node in graph:
-        graph.nodes[node]["area_num_1"] = node
+        graph.node_data(node)["area_num_1"] = node
+
+    # frm: TODO: Testing: NX vs. RX node_id issues here...
partition = Partition.from_districtr_file(graph, districtr_plan_file) - assert partition.assignment.to_dict() == { + + # Convert internal node_ids of the partition's graph to "original" node_ids + internal_node_assignment = partition.assignment.to_dict() + original_node_assignment = {} + for internal_node_id, part in internal_node_assignment.items(): + original_nx_node_id = partition.graph.original_nx_node_id_for_internal_node_id(internal_node_id) + original_node_assignment[original_nx_node_id] = part + + assert original_node_assignment == { 0: 1, 1: 1, 2: 1, diff --git a/tests/partition/test_plotting.py b/tests/partition/test_plotting.py index b9916319..1f5bb393 100644 --- a/tests/partition/test_plotting.py +++ b/tests/partition/test_plotting.py @@ -3,13 +3,15 @@ import geopandas as gp import pytest from shapely.geometry import Polygon +import networkx from gerrychain import Graph, Partition @pytest.fixture def partition(): - graph = Graph([(0, 1), (1, 3), (2, 3), (0, 2)]) + nx_graph = networkx.Graph([(0, 1), (1, 3), (2, 3), (0, 2)]) + graph = Graph.from_networkx(nx_graph) return Partition(graph, {0: 1, 1: 1, 2: 2, 3: 2}) @@ -66,5 +68,35 @@ def test_uses_graph_geometries_by_default(self, geodataframe): graph = Graph.from_geodataframe(geodataframe) partition = Partition(graph=graph, assignment={node: 0 for node in graph}) + + # frm: TODO: Testing: how to handle geometry? + # + # Originally, the following statement blew up because we do not copy + # geometry data from NX to RX when we convert to RX. + # + # I said at the time: + # Need to grok what the right way to deal with geometry + # data is (is it only an issue for from_geodataframe() or + # are there other ways a geometry value might be set?) + # + # Peter comments (from PR): + # + # The geometry data should only exist on the attached geodataframe. + # In fact, if there is no "geometry" column in the dataframe, this call + # should fail. + # + # Fixing the plotting functions is a low-priority. I need to set up + # snapshot tests for these anyway, so if you find working with + # matplotlib a PITA (because it is), then don't worry about the + # plotting functions for now. + # + # Worst-case scenario, I can just add some temporary verbage to + # readthedocs telling people to use + # + # my_partition.df.plot() + + # Which will just use all of the plotting stuff that Pandas has set up internally. 
+ partition.plot() assert mock_plot.call_count == 1 + \ No newline at end of file diff --git a/tests/test_chain.py b/tests/test_chain.py index 3910df2f..1ffd7554 100644 --- a/tests/test_chain.py +++ b/tests/test_chain.py @@ -4,12 +4,12 @@ class MockState: - def flip(self, changes): + def flip(self, changes, use_original_nx_node_ids): return MockState() def mock_proposal(state): - return state.flip({1: 2}) + return state.flip({1: 2}, use_original_nx_node_ids=True) def mock_accept(state): diff --git a/tests/test_frm_graph.py b/tests/test_frm_graph.py new file mode 100644 index 00000000..19b729aa --- /dev/null +++ b/tests/test_frm_graph.py @@ -0,0 +1,650 @@ +import pytest +import networkx as nx +import rustworkx as rx +from gerrychain import Graph + +############################################### +# This file contains tests routines in graph.py +############################################### + +@pytest.fixture +def four_by_five_grid_nx(): + + # Create an NX Graph object with attributes + # + # This graph has the following properties + # which are important for the tests below: + # + # * The "nx_node_id" attribute serves as an + # effective "original" node_id so that we + # can track a node even when its internal + # node_id changes. + # + # * The graph has two "connected" components: + # the first two rows and the last two + # rows. This is used in the connected + # components tests + + # nx_node_id + # + # 0 1 2 3 4 + # 5 6 7 8 9 + # 10 11 12 13 14 + # 15 16 17 18 19 + + # MVAP: + # + # 2 2 2 2 2 + # 2 2 2 2 2 + # 2 2 2 2 2 + # 2 2 2 2 2 + + nx_graph = nx.Graph() + nx_graph.add_nodes_from( + [ + (0, {"population": 10, "nx_node_id": 0, "MVAP": 2}), + (1, {"population": 10, "nx_node_id": 1, "MVAP": 2}), + (2, {"population": 10, "nx_node_id": 2, "MVAP": 2}), + (3, {"population": 10, "nx_node_id": 3, "MVAP": 2}), + (4, {"population": 10, "nx_node_id": 4, "MVAP": 2}), + (5, {"population": 10, "nx_node_id": 5, "MVAP": 2}), + (6, {"population": 10, "nx_node_id": 6, "MVAP": 2}), + (7, {"population": 10, "nx_node_id": 7, "MVAP": 2}), + (8, {"population": 10, "nx_node_id": 8, "MVAP": 2}), + (9, {"population": 10, "nx_node_id": 9, "MVAP": 2}), + (10, {"population": 10, "nx_node_id": 10, "MVAP": 2}), + (11, {"population": 10, "nx_node_id": 11, "MVAP": 2}), + (12, {"population": 10, "nx_node_id": 12, "MVAP": 2}), + (13, {"population": 10, "nx_node_id": 13, "MVAP": 2}), + (14, {"population": 10, "nx_node_id": 14, "MVAP": 2}), + (15, {"population": 10, "nx_node_id": 15, "MVAP": 2}), + (16, {"population": 10, "nx_node_id": 16, "MVAP": 2}), + (17, {"population": 10, "nx_node_id": 17, "MVAP": 2}), + (18, {"population": 10, "nx_node_id": 18, "MVAP": 2}), + (19, {"population": 10, "nx_node_id": 19, "MVAP": 2}), + ] + ) + + nx_graph.add_edges_from( + [ + (0, 1), + (0, 5), + (1, 2), + (1, 6), + (2, 3), + (2, 7), + (3, 4), + (3, 8), + (4, 9), + (5, 6), + # (5, 10), + (6, 7), + # (6, 11), + (7, 8), + # (7, 12), + (8, 9), + # (8, 13), + # (9, 14), + (10, 11), + (10, 15), + (11, 12), + (11, 16), + (12, 13), + (12, 17), + (13, 14), + (13, 18), + (14, 19), + (15, 16), + (16, 17), + (17, 18), + (18, 19), + ] + ) + + return nx_graph + +@pytest.fixture +def four_by_five_grid_rx(four_by_five_grid_nx): + # Create an RX Graph object with attributes + rx_graph = rx.networkx_converter(four_by_five_grid_nx, keep_attributes=True) + return rx_graph + +def top_level_graph_is_properly_configured(graph): + # This routine tests that top-level graphs (not a subgraph) + # are properly configured + assert graph._is_a_subgraph == 
False, \
+        "Top-level graph _is_a_subgraph is True"
+    assert hasattr(graph, "_node_id_to_parent_node_id_map"), \
+        "Graph._node_id_to_parent_node_id_map is not set"
+    assert hasattr(graph, "_node_id_to_original_nx_node_id_map"), \
+        "Graph._node_id_to_original_nx_node_id_map is not set"
+
+def test_from_networkx(four_by_five_grid_nx):
+    graph = Graph.from_networkx(four_by_five_grid_nx)
+    assert len(graph.node_indices) == 20, \
+        f"Expected 20 nodes but got {len(graph.node_indices)}"
+    assert len(graph.edge_indices) == 26, \
+        f"Expected 26 edges but got {len(graph.edge_indices)}"
+    assert graph.node_data(1)["population"] == 10, \
+        f"Expected population of 10 but got {graph.node_data(1)['population']}"
+    top_level_graph_is_properly_configured(graph)
+
+def test_from_rustworkx(four_by_five_grid_nx):
+    rx_graph = rx.networkx_converter(four_by_five_grid_nx, keep_attributes=True)
+    graph = Graph.from_rustworkx(rx_graph)
+    assert len(graph.node_indices) == 20, \
+        f"Expected 20 nodes but got {len(graph.node_indices)}"
+    assert graph.node_data(1)["population"] == 10, \
+        f"Expected population of 10 but got {graph.node_data(1)['population']}"
+    top_level_graph_is_properly_configured(graph)
+
+@pytest.fixture
+def four_by_five_graph_nx(four_by_five_grid_nx):
+    # Create an NX-based Graph object with attributes
+    graph = Graph.from_networkx(four_by_five_grid_nx)
+    return graph
+
+@pytest.fixture
+def four_by_five_graph_rx(four_by_five_grid_nx):
+    # Create an RX-based Graph object with attributes
+    #
+    # Instead of using from_rustworkx(), we use
+    # convert_from_nx_to_rx() because tests below
+    # depend on the node_id maps that are created
+    # by convert_from_nx_to_rx()
+    #
+    graph = Graph.from_networkx(four_by_five_grid_nx)
+    converted_graph = graph.convert_from_nx_to_rx()
+    return converted_graph
+
+def test_convert_from_nx_to_rx(four_by_five_graph_nx):
+    graph = four_by_five_graph_nx  # more readable
+    converted_graph = graph.convert_from_nx_to_rx()
+
+    # Same number of nodes
+    assert len(graph.node_indices) == 20, \
+        f"Expected 20 nodes but got {len(graph.node_indices)}"
+    assert len(converted_graph.node_indices) == 20, \
+        f"Expected 20 nodes but got {len(converted_graph.node_indices)}"
+
+    # Same number of edges
+    assert len(graph.edge_indices) == 26, \
+        f"Expected 26 edges but got {len(graph.edge_indices)}"
+    assert len(converted_graph.edge_indices) == 26, \
+        f"Expected 26 edges but got {len(converted_graph.edge_indices)}"
+
+    # Node data is the same
+    # frm: TODO: Refactoring: Do this the clever Python way and test ALL at the same time
+    for node_id in graph.node_indices:
+        assert graph.node_data(node_id)["population"] == 10, \
+            f"Expected population of 10 but got {graph.node_data(node_id)['population']}"
+        assert graph.node_data(node_id)["nx_node_id"] == node_id, \
+            f"Expected nx_node_id of {node_id} but got {graph.node_data(node_id)['nx_node_id']}"
+        assert graph.node_data(node_id)["MVAP"] == 2, \
+            f"Expected MVAP of 2 but got {graph.node_data(node_id)['MVAP']}"
+    for node_id in converted_graph.node_indices:
+        assert converted_graph.node_data(node_id)["population"] == 10, \
+            f"Expected population of 10 but got {converted_graph.node_data(node_id)['population']}"
+        # frm: TODO: Code: Need to use node_id map to get appropriate node_ids for RX graph
+        # assert converted_graph.node_data(node_id)["nx_node_id"] == node_id, \
+        #     f"Expected nx_node_id of {node_id} but got {converted_graph.node_data(node_id)['nx_node_id']}"
+        assert converted_graph.node_data(node_id)["MVAP"] == 2, \
+            f"Expected MVAP of 2 but got {converted_graph.node_data(node_id)['MVAP']}"
+
+    # Confirm that the
node_id map to the "original" NX node_ids is correct + for node_id in converted_graph.nodes: + # get the "original" NX node_id + nx_node_id = converted_graph._node_id_to_original_nx_node_id_map[node_id] + # confirm that the converted node has "nx_node_id" set to the NX node_id. This + # is an artifact of the way the NX graph was constructed. + assert converted_graph.node_data(node_id)["nx_node_id"] == nx_node_id + +def test_get_edge_from_edge_id(four_by_five_graph_nx, four_by_five_graph_rx): + + # Test that get_edge_from_edge_id works for both NX and RX based Graph objects + + # NX edges and edge_ids are the same, so this first test is trivial + # + nx_edge_id = (0, 1) + nx_edge = four_by_five_graph_nx.get_edge_from_edge_id(nx_edge_id) + assert nx_edge == (0, 1) + + # RX edge_ids are assigned arbitrarily, so without using the nx_to_rx_node_id_map + # we can't know which edge got what edge_id, so this test just verifies that + # there is an edge tuple associated with edge_id, 0 + # + rx_edge_id = 0 # arbitrary id - but there is always an edge with id == 0 + rx_edge = four_by_five_graph_rx.get_edge_from_edge_id(rx_edge_id) + assert isinstance(rx_edge[0], int), "RX edge does not exist (0)" + assert isinstance(rx_edge[1], int), "RX edge does not exist (1)" + +def test_get_edge_id_from_edge(four_by_five_graph_nx, four_by_five_graph_rx): + + # Test that get_edge_id_from_edge works for both NX and RX based Graph objects + + # NX edges and edge_ids are the same, so this first test is trivial + # + nx_edge = (0, 1) + nx_edge_id = four_by_five_graph_nx.get_edge_id_from_edge(nx_edge) + assert nx_edge_id == (0, 1) + + # Test that get_edge_id_from_edge returns an integer value and that + # when that value is used to retrieve an edge tuple, we get the + # tuple value that is expected + # + rx_edge = (0, 1) + rx_edge_id = four_by_five_graph_rx.get_edge_id_from_edge(rx_edge) + assert isinstance(rx_edge_id, int), "Edge ID not found for edge" + found_rx_edge = four_by_five_graph_rx.get_edge_from_edge_id(rx_edge_id) + assert found_rx_edge == rx_edge, "Edge ID does not yield correct edge value" + +def test_add_edge(): + # At present (October 2025), there is nothing to test. The + # code just delegates to NetworkX or RustworkX to create + # the edge. + # + # However, it is conceivable that in the future, when users + # stop using NX altogether, there might be a reason for a + # test, so this is just a placeholder for that future test... + # + assert True + +def test_subgraph(four_by_five_graph_rx): + + """ + Subgraphs are one of the most dangerous areas of the code. + In NX, subgraphs preserve node_ids - that is, the node_id + in the subgraph is the same as the node_id in the parent. + However, in RX, that is not the case - RX always creates + new node_ids starting at 0 and increasing by one + sequentially, so in general a node in an RX subgraph + will have a different node_id than it has in the parent + graph. + + To deal with this, the code creates a map from the + node_id in a subgraph to the node_id in the parent + graph, _node_id_to_parent_node_id_map. This test verifies + that this map is properly created. + + In addition, all RX based graphs that came from an NX + based graph record the "original" NX node_ids in + another node_id map, _node_id_to_original_nx_node_id_map + + When we create a subgraph, this map needs to be + established for the subgraph. This test verifies + that this map is properly created. 
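+
+    For example (illustrative only - RX chooses the actual numbering),
+    subgraph([2, 4, 5]) might yield subgraph node_ids 0, 1, 2, giving a
+    _node_id_to_parent_node_id_map of {0: 2, 1: 4, 2: 5}.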
+
+    Note that this test is only configured to work on
+    RX-based Graph objects because the only uses of subgraph
+    in the gerrychain codebase are on RX-based Graph objects.
+    """
+
+    # Create a subgraph for an arbitrary set of nodes:
+    subgraph_node_ids = [2, 4, 5, 8, 11, 13]
+    parent_graph_rx = four_by_five_graph_rx  # make the code below clearer
+    subgraph_rx = parent_graph_rx.subgraph(subgraph_node_ids)
+
+    assert \
+        len(subgraph_node_ids) == len(subgraph_rx), \
+        "Number of nodes do not agree"
+
+    # verify that _node_id_to_parent_node_id_map is correct
+    for subgraph_node_id, parent_node_id in subgraph_rx._node_id_to_parent_node_id_map.items():
+        # check that each node in subgraph has the same data (is the same node)
+        # as the node in the parent that it is mapped to
+        #
+        subgraph_stored_node_id = subgraph_rx.node_data(subgraph_node_id)["nx_node_id"]
+        parent_stored_node_id = parent_graph_rx.node_data(parent_node_id)["nx_node_id"]
+        assert \
+            parent_stored_node_id == subgraph_stored_node_id, \
+            "_node_id_to_parent_node_id_map is incorrect"
+
+    # verify that _node_id_to_original_nx_node_id_map is correct
+    for subgraph_node_id, original_node_id in subgraph_rx._node_id_to_original_nx_node_id_map.items():
+        subgraph_stored_node_id = subgraph_rx.node_data(subgraph_node_id)["nx_node_id"]
+        assert \
+            subgraph_stored_node_id == original_node_id, \
+            "_node_id_to_original_nx_node_id_map is incorrect"
+
+def test_num_connected_components(four_by_five_graph_nx, four_by_five_graph_rx):
+    num_components_nx = four_by_five_graph_nx.num_connected_components()
+    num_components_rx = four_by_five_graph_rx.num_connected_components()
+    assert \
+        num_components_nx == 2, \
+        f"num_components: expected 2 but got {num_components_nx}"
+    assert \
+        num_components_rx == 2, \
+        f"num_components: expected 2 but got {num_components_rx}"
+
+def test_subgraphs_for_connected_components(four_by_five_graph_nx, four_by_five_graph_rx):
+
+    subgraphs_nx = four_by_five_graph_nx.subgraphs_for_connected_components()
+    subgraphs_rx = four_by_five_graph_rx.subgraphs_for_connected_components()
+
+    assert len(subgraphs_nx) == 2
+    assert len(subgraphs_rx) == 2
+
+    assert len(subgraphs_nx[0]) == 10
+    assert len(subgraphs_nx[1]) == 10
+    assert len(subgraphs_rx[0]) == 10
+    assert len(subgraphs_rx[1]) == 10
+
+    # Check that each subgraph (NX-based Graph) has correct nodes in it
+    node_ids_nx_0 = subgraphs_nx[0].node_indices
+    node_ids_nx_1 = subgraphs_nx[1].node_indices
+    assert node_ids_nx_0 == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+    assert node_ids_nx_1 == {10, 11, 12, 13, 14, 15, 16, 17, 18, 19}
+
+    # Check that each subgraph (RX-based Graph) has correct nodes in it
+    node_ids_rx_0 = subgraphs_rx[0].node_indices
+    node_ids_rx_1 = subgraphs_rx[1].node_indices
+    original_nx_node_ids_rx_0 = subgraphs_rx[0].original_nx_node_ids_for_set(node_ids_rx_0)
+    original_nx_node_ids_rx_1 = subgraphs_rx[1].original_nx_node_ids_for_set(node_ids_rx_1)
+    assert original_nx_node_ids_rx_0 == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+    assert original_nx_node_ids_rx_1 == {10, 11, 12, 13, 14, 15, 16, 17, 18, 19}
+
+def test_to_networkx_graph():
+    # There is already a test for this in another file
+    assert True
+
+def test_add_data():
+    # This is already tested in test_make_graph.py
+    assert True
+
+########################################################
+# Long utility routine to determine if there is a cycle
+# in a graph (with integer node_ids).
+########################################################
+
+def graph_has_cycle(set_of_edges):
+
+    #
+    # Given a set of edges that define a graph, determine
+    # if the graph has cycles.
+    #
+    # This will allow us to test that predecessors and
+    # successors are in fact trees with no cycles.
+    #
+    # The approach is to do a depth-first search that
+    # remembers each node it has visited, and which
+    # signals that a cycle exists if it revisits a node
+    # it has already visited.
+    #
+    # Note that this code assumes that the set of nodes
+    # is a sequential list starting at zero with no gaps
+    # in the sequence. This allows us to use a simplified
+    # adjacency matrix which is adequate for testing
+    # purposes.
+    #
+    # The adjacency matrix is just a 2D square matrix
+    # that has a 1 value for element (i,j) iff there
+    # is an edge from node i to node j. Note that
+    # because we deal with undirected graphs the matrix
+    # is symmetrical - edges go both ways...
+    #
+
+    def add_edge(adj_matrix, s, t):
+        # Add an edge to an adjacency matrix
+        adj_matrix[s][t] = 1
+        adj_matrix[t][s] = 1  # Since it's an undirected graph
+
+    def delete_edge(adj_matrix, s, t):
+        # Delete an edge from an adjacency matrix
+        adj_matrix[s][t] = 0
+        adj_matrix[t][s] = 0  # Since it's an undirected graph
+
+    def create_empty_adjacency_matrix(num_nodes):
+        # create 2D array, num_nodes x num_nodes
+        adj_matrix = [[0] * num_nodes for _ in range(num_nodes)]
+        return adj_matrix
+
+    def create_adjacency_matrix_from_set_of_edges(set_of_edges):
+
+        # determine num_nodes
+        #
+        set_of_nodes = set()
+        for edge in set_of_edges:
+            for node in edge:
+                set_of_nodes.add(node)
+        num_nodes = len(set_of_nodes)
+        list_of_nodes = list(set_of_nodes)
+
+        # We need node_ids that start at zero and go
+        # up sequentially with no gaps, so create a
+        # map for new node_ids
+        new_node_id_map = {}
+        for index, node_id in enumerate(list_of_nodes):
+            new_node_id_map[node_id] = index
+
+        # Now create a new set of edges with the new node_ids
+        new_set_of_edges = set()
+        for edge in set_of_edges:
+            new_edge = (
+                new_node_id_map[edge[0]], new_node_id_map[edge[1]]
+            )
+            new_set_of_edges.add(new_edge)
+
+        # create an empty adjacency matrix
+        #
+        adj_matrix = create_empty_adjacency_matrix(num_nodes)
+
+        # add the edges to the adjacency matrix
+        #
+        for edge in new_set_of_edges:
+            add_edge(adj_matrix, edge[0], edge[1])
+
+        return adj_matrix
+
+    def inner_has_cycle(adj_matrix, visited, s, visit_list):
+        # This routine does a depth-first search looking
+        # for cycles - if it encounters a node that it has
+        # already seen then it returns True.
+
+        # Record having visited this node
+        #
+        visited[s] = True
+        visit_list.append(s)
+
+        # Recursively visit all adjacent vertices looking for cycles
+        # If we have already visited a node, then there is a cycle...
+        #
+        for i in range(len(adj_matrix)):
+            # Recurse on every adjacent / child node...
+            if adj_matrix[s][i] == 1:
+                if visited[i]:
+                    return True
+                else:
+                    # remove this edge from adjacency matrix so we
+                    # don't follow link back to node, i.
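+                    # (Deleting the traversed edge is what makes this work
+                    # for undirected graphs: without it, every edge would be
+                    # "revisited" through its reverse direction and falsely
+                    # reported as a cycle.)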
+ # + delete_edge(adj_matrix, s, i) + if inner_has_cycle(adj_matrix, visited, i, visit_list): + return True + return False + + adj_matrix = create_adjacency_matrix_from_set_of_edges(set_of_edges) + visited = [False] * len(adj_matrix) + visit_list = [] + root_node = 0 # arbitrary, but every graph has a node 0 + cycle_found = \ + inner_has_cycle(adj_matrix, visited, root_node, visit_list) + return cycle_found + +def test_graph_has_cycle(): + # Test to make sure the utility routine, graph_has_cycle, works + + # First try with no cycle + # Define the edges of the graph + set_of_edges = { + (11, 2), + (11, 0), + # (2, 0), # no cycle without this edge + (2, 3), + (2, 4) + } + the_graph_has_a_cycle = graph_has_cycle(set_of_edges) + assert the_graph_has_a_cycle == False + + # Now try with a cycle + # Define the edges of the graph + set_of_edges = { + (11, 2), + (11, 0), + (2, 0), # this edge creates a cycle + (2, 3), + (2, 4) + } + the_graph_has_a_cycle = graph_has_cycle(set_of_edges) + assert the_graph_has_a_cycle == True + +def test_generic_bfs_edges(four_by_five_graph_nx, four_by_five_graph_rx): + # + # The routine, generic_bfs_edges() returns an ordered list of + # edges from a breadth-first traversal of a graph, starting + # at the given node. + # + # For our graphs, there are two connected components (the first + # two rows and the last two rows) and each component is a + # grid: + # + # 0 - 1 - 2 - 3 - 4 + # | | | | | + # 5 - 6 - 7 - 8 - 9 + # + # 10 - 11 - 12 - 13 - 14 + # | | | | | + # 15 - 16 - 17 - 18 - 19 + # + # So, a BFS starting at 0 should produce something like: + # + # [ (0,5), (0,1), (1,6), (1,2), (2,7), (2,3), (3,8), (3,4), (4,9) ] + # + # However, the specific order that is returned depends on the + # internals of the algorithm. + # + + # + bfs_edges_nx_0 = set(four_by_five_graph_nx.generic_bfs_edges(0)) + expected_set_of_edges = { \ + (0,5), (0,1), (1,6), (1,2), (2,7), (2,3), (3,8), (3,4), (4,9) \ + } + # debugging: + assert bfs_edges_nx_0 == expected_set_of_edges + + # Check that generic_bfs_edges() does not produce a cycle + the_graph_has_a_cycle = graph_has_cycle(bfs_edges_nx_0) + assert the_graph_has_a_cycle == False + bfs_edges_nx_12 = set(four_by_five_graph_nx.generic_bfs_edges(12)) + the_graph_has_a_cycle = graph_has_cycle(bfs_edges_nx_12) + assert the_graph_has_a_cycle == False + + """ + TODO: Testing: + * Think about whether this test is actually appropriate. The + issue is that the expected_set_of_edges is the right set + for this particular graph, but I am not sure that this is + a good enough test. Think about other situations... + + * Think about whether to verify that the BFS returned + has no cycles. It doesn't in this particular case, + but perhaps we should have more cases that stress the test... + """ +def test_generic_bfs_successors_generator(): + # TODO: Testing: Write a test for this routine + # + # Note that the code for this routine is very straight-forward, so + # writing a test is not high-priority. The only reason I did not + # just go ahead and write one is because it was not immediately + # clear to me how to write the test - more work than doing a + # thorough code review... + # + assert True + +def test_generic_bfs_successors(): + # TODO: Testing: Write a test for this routine + # + # Code is trivial, but because this routine is important it + # deserves a test - just not clear off top of my head how + # to write the test... 
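+    #
+    # frm: One possible shape (hedged - it assumes generic_bfs_successors()
+    #      returns a {parent: [children]} mapping); using graph_has_cycle()
+    #      from above to confirm the successor structure is a tree:
+    #
+    #          successors = graph.generic_bfs_successors(0)
+    #          child_edges = {(parent, child)
+    #                         for parent, children in successors.items()
+    #                         for child in children}
+    #          assert not graph_has_cycle(child_edges)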
+ # + assert True + +def test_generic_bfs_predecessors(): + # TODO: Testing: Write a test for this routine + # + # Code is trivial, but because this routine is important it + # deserves a test - just not clear off top of my head how + # to write the test... + # + assert True + +def test_predecessors(): + # TODO: Testing: Write a test for this routine + # + # Code is trivial, but because this routine is important it + # deserves a test - just not clear off top of my head how + # to write the test... + # + assert True + +def test_successors(): + # TODO: Testing: Write a test for this routine + # + # Code is trivial, but because this routine is important it + # deserves a test - just not clear off top of my head how + # to write the test... + # + assert True + +def test_laplacian_matrix(): + # TODO: Testing: Write a test for this routine + # + # Not clear off the top of my head how + # to write the test... + # + assert True + +def test_normalized_laplacian_matrix(): + # TODO: Testing: Write a test for this routine + # + # This routine has not yet been implemented (as + # of October 2025), but when it is implemented + # we should add a test for it... + # + assert True + + +""" +============================================================= + +TODO: Code: ??? + + * Aliasing concerns: + + It occurs to me that the RX node_data is aliased with the NX node_data. + That is, the data dictionaries in the NX Graph are just retained + when the NX Graph is converted to be an RX Graph - so if you change + the data in the RX Graph, the NX Graph from which we created the RX + graph will also be changed. + + I believe that this is also true for subgraphs for both NX and RX, + meaning that the node_data in the subgraph is the exact same + data dictionary in the parent graph and the subgraph. + + I am not sure if this is a problem or not, but it is something + to be tested / thought about... + + * NX allows node_ids to be almost anything - they can be integers, + strings, even tuples. I think that they just need to be hashable. + + I don't know if we need to test that non-integer NX node_ids + don't cause a problem. There are tests elsewhere that have + NX node_ids that are tuples, and that test passes, so I think + we are OK, but there are no tests specifically targeting this + issue that I know of. + +============================================================= +""" diff --git a/tests/test_laplacian.py b/tests/test_laplacian.py new file mode 100644 index 00000000..927da85d --- /dev/null +++ b/tests/test_laplacian.py @@ -0,0 +1,98 @@ + +import pytest +import networkx as nx +import rustworkx as rx +import numpy as np +from gerrychain.graph import Graph +import gerrychain.tree as gctree + +""" +This tests whether we compute the same laplacian matrix for NX and RX +based Graph objects. + +The NX version is computed (as was true in the old code) by a built-in +NetworkX routine. The RX version is computed by code added when we +supported RX as the embedded graph object. + +The NX version produces ints from the code below, while the RX +version produces floats. I don't think this matters as the laplacian +matrix is used to do numerical calculations, so that code should +happily use ints or floats, but it means that for this test I need +to convert the NX version's result to have floating point values. +""" + +# frm: TODO: Testing: Add additional tests for laplacian matrix calculations, in +# particular, add a test for normalized_laplacian_matrix() +# once that routine has been implemented. 
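+# frm: Note: For reference, the Laplacian here is L = D - A (the degree
+#       matrix minus the adjacency matrix). For the 4-node graph used below,
+#       with edges (0, 1), (0, 2), (1, 2), (2, 3), the expected dense form is:
+#
+#           [[ 2, -1, -1,  0],
+#            [-1,  2, -1,  0],
+#            [-1, -1,  3, -1],
+#            [ 0,  0, -1,  1]]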
+ + +def are_sparse_matrices_equal(sparse_matrix1, sparse_matrix2, rtol=1e-05, atol=1e-08): + """ + Checks if two scipy.sparse.csr_matrix objects are equal, considering + potential floating-point inaccuracies in the data. + + Args: + sparse_matrix1 (scipy.sparse.csr_matrix): The first sparse matrix. + sparse_matrix2 (scipy.sparse.csr_matrix): The second sparse matrix. + rtol (float): The relative tolerance parameter for np.allclose. + atol (float): The absolute tolerance parameter for np.allclose. + + Returns: + bool: True if the sparse matrices are equal, False otherwise. + """ + # Check if shapes are equal + if sparse_matrix1.shape != sparse_matrix2.shape: + return False + + # Check if the number of non-zero elements is equal + if sparse_matrix1.nnz != sparse_matrix2.nnz: + return False + + # Check for equality of structural components (indices and indptr) + # These should be exact matches + if not (np.array_equal(sparse_matrix1.indices, sparse_matrix2.indices) and + np.array_equal(sparse_matrix1.indptr, sparse_matrix2.indptr)): + return False + + # Check for approximate equality of data (values) + # Use np.allclose to handle floating-point comparisons + if not np.allclose(sparse_matrix1.data, sparse_matrix2.data, rtol=rtol, atol=atol): + return False + + return True + +# Create equivalent NX and RX graphs from scratch + +@pytest.fixture +def nx_graph(): + this_nx_graph = nx.Graph([(0, 1), (0, 2), (1, 2), (2, 3)]) + return this_nx_graph + +@pytest.fixture +def rx_graph(): + this_rx_graph = rx.PyGraph() + # argument to add_node_from() is the data to be associated with each node. + # To be compatible with GerryChain, nodes need to have data values that are dictionaries + # so we just have an empty dict for each node's data + this_rx_graph.add_nodes_from([{}, {}, {}, {}]) + this_rx_graph.add_edges_from([(0, 1, {}), (0, 2, {}), (1, 2, {}), (2, 3, {})]) + return this_rx_graph + + +def test_nx_rx_laplacian_matrix_equality(nx_graph, rx_graph): + + # Create Graph objects from the NX and RX graphs + gc_nx_graph = Graph.from_networkx(nx_graph) + gc_rx_graph = Graph.from_rustworkx(rx_graph) + + # Compute the laplacian_matrix for both the NX and RX based Graph objects + gc_nx_laplacian_matrix = gc_nx_graph.laplacian_matrix() + gc_rx_laplacian_matrix = gc_rx_graph.laplacian_matrix() + + # Convert values in the NX version to be floating point + float_gc_nx_laplacian_matrix = gc_nx_laplacian_matrix.astype(float) + + # test equality + matrices_are_equal = are_sparse_matrices_equal(float_gc_nx_laplacian_matrix, gc_rx_laplacian_matrix) + assert(matrices_are_equal) + diff --git a/tests/test_make_graph.py b/tests/test_make_graph.py index e7c1a5c8..c0efc9e2 100644 --- a/tests/test_make_graph.py +++ b/tests/test_make_graph.py @@ -8,9 +8,14 @@ from shapely.geometry import Polygon from pyproj import CRS +import networkx + from gerrychain.graph import Graph from gerrychain.graph.geo import GeometryError +# frm: added following import +# from gerrychain.graph import node_data + @pytest.fixture def geodataframe(): @@ -65,27 +70,44 @@ def target_file(): def test_add_data_to_graph_can_handle_column_names_that_start_with_numbers(): - graph = Graph([("01", "02"), ("02", "03"), ("03", "01")]) + + # frm: Test has been modified to work with new Graph object that has an NetworkX.Graph + # object embedded inside it. I am not sure if this test actually tests + # anything useful anymore... 
+ + nx_graph = networkx.Graph([("01", "02"), ("02", "03"), ("03", "01")]) df = pandas.DataFrame({"16SenDVote": [20, 30, 50], "node": ["01", "02", "03"]}) df = df.set_index("node") + # frm: Note that the new Graph only supports the add_data() routine if + # the underlying graph object is an NX Graph + + graph = Graph.from_networkx(nx_graph) + graph.add_data(df, ["16SenDVote"]) - assert graph.nodes["01"]["16SenDVote"] == 20 - assert graph.nodes["02"]["16SenDVote"] == 30 - assert graph.nodes["03"]["16SenDVote"] == 50 + # Test that the embedded nx_graph object has the added data + assert nx_graph.nodes["01"]["16SenDVote"] == 20 + assert nx_graph.nodes["02"]["16SenDVote"] == 30 + assert nx_graph.nodes["03"]["16SenDVote"] == 50 + + # Test that the graph object has the added data + assert graph.node_data("01")["16SenDVote"] == 20 + assert graph.node_data("02")["16SenDVote"] == 30 + assert graph.node_data("03")["16SenDVote"] == 50 def test_join_can_handle_right_index(): - graph = Graph([("01", "02"), ("02", "03"), ("03", "01")]) + nx_graph = networkx.Graph([("01", "02"), ("02", "03"), ("03", "01")]) df = pandas.DataFrame({"16SenDVote": [20, 30, 50], "node": ["01", "02", "03"]}) - graph.join(df, ["16SenDVote"], right_index="node") + graph = Graph.from_networkx(nx_graph) - assert graph.nodes["01"]["16SenDVote"] == 20 - assert graph.nodes["02"]["16SenDVote"] == 30 - assert graph.nodes["03"]["16SenDVote"] == 50 + graph.join(df, ["16SenDVote"], right_index="node") + assert graph.node_data("01")["16SenDVote"] == 20 + assert graph.node_data("02")["16SenDVote"] == 30 + assert graph.node_data("03")["16SenDVote"] == 50 def test_make_graph_from_dataframe_creates_graph(geodataframe): graph = Graph.from_geodataframe(geodataframe) @@ -132,10 +154,10 @@ def test_can_insist_on_not_reprojecting(geodataframe): graph = Graph.from_geodataframe(df, reproject=False) for node in ("a", "b", "c", "d"): - assert graph.nodes[node]["area"] == 1 + assert graph.node_data(node)["area"] == 1 for edge in graph.edges: - assert graph.edges[edge]["shared_perim"] == 1 + assert graph.edge_data(edge)["shared_perim"] == 1 def test_does_not_reproject_by_default(geodataframe): @@ -143,10 +165,10 @@ def test_does_not_reproject_by_default(geodataframe): graph = Graph.from_geodataframe(df) for node in ("a", "b", "c", "d"): - assert graph.nodes[node]["area"] == 1.0 + assert graph.node_data(node)["area"] == 1.0 for edge in graph.edges: - assert graph.edges[edge]["shared_perim"] == 1.0 + assert graph.edge_data(edge)["shared_perim"] == 1.0 def test_reproject(geodataframe): @@ -156,10 +178,10 @@ def test_reproject(geodataframe): graph = Graph.from_geodataframe(df, reproject=True) for node in ("a", "b", "c", "d"): - assert graph.nodes[node]["area"] != 1 + assert graph.node_data(node)["area"] != 1 for edge in graph.edges: - assert graph.edges[edge]["shared_perim"] != 1 + assert graph.edge_data(edge)["shared_perim"] != 1 def test_identifies_boundary_nodes(geodataframe_with_boundary): @@ -167,8 +189,8 @@ def test_identifies_boundary_nodes(geodataframe_with_boundary): graph = Graph.from_geodataframe(df) for node in ("a", "b", "c", "e"): - assert graph.nodes[node]["boundary_node"] - assert not graph.nodes["d"]["boundary_node"] + assert graph.node_data(node)["boundary_node"] + assert not graph.node_data("d")["boundary_node"] def test_computes_boundary_perims(geodataframe_with_boundary): @@ -178,41 +200,57 @@ def test_computes_boundary_perims(geodataframe_with_boundary): expected = {"a": 5, "e": 5, "b": 1, "c": 1} for node, value in expected.items(): - 
assert graph.nodes[node]["boundary_perim"] == value + assert graph.node_data(node)["boundary_perim"] == value def edge_set_equal(set1, set2): - return {(y, x) for x, y in set1} | set1 == {(y, x) for x, y in set2} | set2 - + """ + Returns true if the two sets have the same edges. + + The complication is that for an edge, (1,2) is the same as (2,1), so to compare them you + need to canonicalize the edges somehow. This code just takes set1 and set2 and creates + a new set for each that has both edge pairs for each edge, and it then compares those new sets. + """ + canonical_set1 = {(y, x) for x, y in set1} | set1 + canonical_set2 = {(y, x) for x, y in set2} | set2 + return canonical_set1 == canonical_set2 def test_from_file_adds_all_data_by_default(shapefile): graph = Graph.from_file(shapefile) - assert all("data" in node_data for node_data in graph.nodes.values()) - assert all("data2" in node_data for node_data in graph.nodes.values()) + nx_graph = graph.get_nx_graph() + + assert all("data" in node_data for node_data in nx_graph.nodes.values()) + assert all("data2" in node_data for node_data in nx_graph.nodes.values()) def test_from_file_and_then_to_json_does_not_error(shapefile, target_file): graph = Graph.from_file(shapefile) + nx_graph = graph.get_nx_graph() + # Even the geometry column is copied to the graph - assert all("geometry" in node_data for node_data in graph.nodes.values()) + assert all("geometry" in node_data for node_data in nx_graph.nodes.values()) graph.to_json(target_file) def test_from_file_and_then_to_json_with_geometries(shapefile, target_file): graph = Graph.from_file(shapefile) + + nx_graph = graph.get_nx_graph() # Even the geometry column is copied to the graph - assert all("geometry" in node_data for node_data in graph.nodes.values()) + assert all("geometry" in node_data for node_data in nx_graph.nodes.values()) graph.to_json(target_file, include_geometries_as_geojson=True) def test_graph_warns_for_islands(): - graph = Graph() - graph.add_node(0) + nx_graph = networkx.Graph() + nx_graph.add_node(0) + + graph = Graph.from_networkx(nx_graph) with pytest.warns(Warning): graph.warn_for_islands() @@ -255,4 +293,4 @@ def test_make_graph_from_dataframe_has_crs(gdf_with_data): def test_make_graph_from_shapefile_has_crs(shapefile): graph = Graph.from_file(shapefile) df = gp.read_file(shapefile) - assert CRS.from_json(graph.graph["crs"]).equals(df.crs) \ No newline at end of file + assert CRS.from_json(graph.graph["crs"]).equals(df.crs) diff --git a/tests/test_metagraph.py b/tests/test_metagraph.py index 03aa2d59..a4bad50d 100644 --- a/tests/test_metagraph.py +++ b/tests/test_metagraph.py @@ -12,12 +12,30 @@ def partition(graph): def test_all_cut_edge_flips(partition): + + # frm: TODO: Testing: Maybe change all_cut_edge_flips to return a dict + # + # At present, it returns an iterator, which makes the code below + # more complicated than it needs to be. 
If it just returned
+    #                     a dict, then the code would be:
+    #
+    #                         result = set(all_cut_edge_flips(partition).items())
+
    result = set(
        (node, part)
        for flip in all_cut_edge_flips(partition)
        for node, part in flip.items()
    )
-    assert result == {(6, 1), (7, 1), (8, 1), (4, 2), (5, 2), (3, 2)}
+
+    # Convert from internal node_ids to "original" node_ids
+    new_result = set()
+    for internal_node_id, part in result:
+        original_nx_node_id = partition.graph.original_nx_node_id_for_internal_node_id(internal_node_id)
+        new_result.add((original_nx_node_id, part))
+
+    assert new_result == {(6, 1), (7, 1), (8, 1), (4, 2), (5, 2), (3, 2)}
class TestAllValidStatesOneFlipAway:
@@ -35,6 +53,7 @@ def test_accepts_list_of_constraints(self, partition):
def test_all_valid_flips(partition):
+    # frm: TODO: Testing: NX vs. RX node_id issues...
    def disallow_six_to_one(partition):
        for node, part in partition.flips.items():
            if node == 6 and part == 1:
@@ -43,9 +62,21 @@ def disallow_six_to_one(partition):
    constraints = [disallow_six_to_one]
+    # frm: TODO: Testing: If I created a utility routine to convert
+    #                     a list of flips to original node_ids,
+    #                     then I could use that here and then
+    #                     convert the resulting list to a set...
+
    result = set(
        (node, part)
        for flip in all_valid_flips(partition, constraints)
        for node, part in flip.items()
    )
-    assert result == {(7, 1), (8, 1), (4, 2), (5, 2), (3, 2)}
+
+    # Convert from internal node_ids to "original" node_ids
+    new_result = set()
+    for internal_node_id, part in result:
+        original_nx_node_id = partition.graph.original_nx_node_id_for_internal_node_id(internal_node_id)
+        new_result.add((original_nx_node_id, part))
+
+    assert new_result == {(7, 1), (8, 1), (4, 2), (5, 2), (3, 2)}
diff --git a/tests/test_region_aware.py b/tests/test_region_aware.py
index bcbe1e0f..0fb1c305 100644
--- a/tests/test_region_aware.py
+++ b/tests/test_region_aware.py
@@ -161,9 +161,11 @@ def straddled_regions(partition, reg_attr, all_reg_names):
-    """Returns the total number of district that straddle two regions in the partition."""
+    """Returns the total number of districts that straddle two regions in the partition."""
    split = {name: 0 for name in all_reg_names}
+    # frm: TODO: Testing: Grok what this tests - not clear to me at this time...
+
    for node1, node2 in set(partition.graph.edges() - partition["cut_edges"]):
-        split[partition.graph.nodes[node1][reg_attr]] += 1
-        split[partition.graph.nodes[node2][reg_attr]] += 1
+        split[partition.graph.node_data(node1)[reg_attr]] += 1
+        split[partition.graph.node_data(node2)[reg_attr]] += 1
    return sum(1 for value in split.values() if value > 0)
diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py
index 85d2122c..49ccafec 100644
--- a/tests/test_reproducibility.py
+++ b/tests/test_reproducibility.py
@@ -58,7 +58,6 @@ def test_repeatable(three_by_three_grid):
        {3: 1},
    ]
    flips = [partition.flips for partition in chain]
-    print(flips)
    assert flips == expected_flips
diff --git a/tests/test_tally.py b/tests/test_tally.py
index 220bc149..24757cc3 100644
--- a/tests/test_tally.py
+++ b/tests/test_tally.py
@@ -8,12 +8,15 @@
import random
from gerrychain.updaters.tally import DataTally, Tally
random.seed(2018)
+import networkx
def random_assignment(graph, num_districts):
    return {node: random.choice(range(num_districts)) for node in graph.nodes}
def test_data_tally_works_as_an_updater(three_by_three_grid):
+    # Simple test that a DataTally creates an attribute on a partition.
+    # Another test (below) checks that the computed "tally" is correct.
assignment = random_assignment(three_by_three_grid, 4) data = {node: random.randint(1, 100) for node in three_by_three_grid.nodes} parts = tuple(set(assignment.values())) @@ -27,17 +30,30 @@ def test_data_tally_works_as_an_updater(three_by_three_grid): def test_data_tally_gives_expected_value(three_by_three_grid): + # Put all but one of the nodes in part #1, and put the one "first_node" + # into part #2. + first_node = next(iter(three_by_three_grid.nodes)) assignment = {node: 1 for node in three_by_three_grid.nodes} assignment[first_node] = 2 + # All nodes get a value of 1 for the data to be tallied data = {node: 1 for node in three_by_three_grid} updaters = {"tally": DataTally(data, alias="tally")} partition = Partition(three_by_three_grid, assignment, updaters) + # Note that in general a flip using node_ids generated before creating + # a partition should be translated into "internal" RX-Graph based + # node_ids. In this case it is not needed, because it doesn't matter + # whether we are using the "original" or the "internal" node_id for + # first_node because it still refers to the same node and nothing else + # is going on. + + # Create a new partition, adding the "first_node" to part #1 flip = {first_node: 1} new_partition = partition.flip(flip) + # The "tally" should increase by one because of the flipped node's data assert new_partition["tally"][1] == partition["tally"][1] + 1 @@ -49,7 +65,7 @@ def test_data_tally_mimics_old_tally_usage(graph_with_random_data_factory): assignment = {i: 1 if i in range(4) else 2 for i in range(9)} partition = Partition(graph, assignment, updaters) - expected_total_in_district_one = sum(graph.nodes[i]["total"] for i in range(4)) + expected_total_in_district_one = sum(graph.node_data(i)["total"] for i in range(4)) assert partition["total"][1] == expected_total_in_district_one @@ -68,7 +84,7 @@ def get_expected_tally(partition): expected = defaultdict(int) for node in partition.graph.nodes: part = partition.assignment[node] - expected[part] += partition.graph.nodes[node]["population"] + expected[part] += partition.graph.node_data(node)["population"] return expected for state in chain: @@ -77,9 +93,10 @@ def get_expected_tally(partition): def test_works_when_no_flips_occur(): - graph = Graph([(0, 1), (1, 2), (2, 3), (3, 0)]) + nx_graph = networkx.Graph([(0, 1), (1, 2), (2, 3), (3, 0)]) + graph = Graph.from_networkx(nx_graph) for node in graph: - graph.nodes[node]["pop"] = node + 1 + graph.node_data(node)["pop"] = node + 1 partition = Partition(graph, {0: 0, 1: 0, 2: 1, 3: 1}, {"pop": Tally("pop")}) chain = MarkovChain(lambda p: p.flip({}), [], always_accept, partition, 10) diff --git a/tests/test_tree.py b/tests/test_tree.py index 1805b8ca..fad9741a 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -1,6 +1,7 @@ import functools import networkx +import rustworkx import pytest from gerrychain import MarkovChain @@ -26,89 +27,178 @@ random.seed(2018) +# +# This code is complicated by the need to test both NX-based +# and RX-based Graph objects. +# +# The pattern is to define the test logic in a routine that +# will be run with both NX-based and RX-based Graph objects +# and to then have the actual test case call that logic. +# This buries the asserts down a level, which means that +# figuring out what went wrong if a test fails will be +# slightly more challenging, but it keeps the logic for +# testing both NX-based and RX-based Graph objects clean. 
+# + +# frm: TODO: Documentation: test_tree.py: explain nx_to_rx_node_id_map @pytest.fixture -def graph_with_pop(three_by_three_grid): +def graph_with_pop_nx(three_by_three_grid): + # NX-based Graph object for node in three_by_three_grid: - three_by_three_grid.nodes[node]["pop"] = 1 - return Graph.from_networkx(three_by_three_grid) + three_by_three_grid.node_data(node)["pop"] = 1 + return three_by_three_grid +@pytest.fixture +def graph_with_pop_rx(graph_with_pop_nx): + # RX-based Graph object (same data as NX-based version) + graph_rx = graph_with_pop_nx.convert_from_nx_to_rx() + return graph_rx @pytest.fixture -def partition_with_pop(graph_with_pop): +def partition_with_pop(graph_with_pop_nx): + # No need for an RX-based Graph here because creating the + # Partition object converts the graph to be RX-based if + # it is not already RX-based + # return Partition( - graph_with_pop, + graph_with_pop_nx, {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 1, 6: 1, 7: 1, 8: 1}, updaters={"pop": Tally("pop"), "cut_edges": cut_edges}, ) - @pytest.fixture -def twelve_by_twelve_with_pop(): +def twelve_by_twelve_with_pop_nx(): + # NX-based Graph object + xy_grid = networkx.grid_graph([12, 12]) + + # Relabel nodes with integers rather than tuples. Node + # in cartesian coordinate (x,y) will be relabeled with + # the integer = x*12 + y , which just numbers nodes + # sequentially from 0 by row... + # nodes = {node: node[1] + 12 * node[0] for node in xy_grid} grid = networkx.relabel_nodes(xy_grid, nodes) + for node in grid: grid.nodes[node]["pop"] = 1 return Graph.from_networkx(grid) +@pytest.fixture +def twelve_by_twelve_with_pop_rx(twelve_by_twelve_with_pop_nx): + # RX-based Graph object (same data as NX-based version) + graph_rx = twelve_by_twelve_with_pop_nx.convert_from_nx_to_rx() + return graph_rx + +# --------------------------------------------------------------------- -def test_bipartition_tree_returns_a_subset_of_nodes(graph_with_pop): - ideal_pop = sum(graph_with_pop.nodes[node]["pop"] for node in graph_with_pop) / 2 - result = bipartition_tree(graph_with_pop, "pop", ideal_pop, 0.25, 10) +def do_test_bipartition_tree_random_returns_a_subset_of_nodes(graph): + ideal_pop = sum(graph.node_data(node)["pop"] for node in graph) / 2 + result = bipartition_tree_random(graph, "pop", ideal_pop, 0.25, 10) assert isinstance(result, frozenset) - assert all(node in graph_with_pop.nodes for node in result) + assert all(node in graph.nodes for node in result) + +def test_bipartition_tree_random_returns_a_subset_of_nodes(graph_with_pop_nx, graph_with_pop_rx): + # Test both NX-based and RX-based Graph objects + do_test_bipartition_tree_random_returns_a_subset_of_nodes(graph_with_pop_nx) + do_test_bipartition_tree_random_returns_a_subset_of_nodes(graph_with_pop_rx) +# --------------------------------------------------------------------- -def test_bipartition_tree_returns_within_epsilon_of_target_pop(graph_with_pop): - ideal_pop = sum(graph_with_pop.nodes[node]["pop"] for node in graph_with_pop) / 2 +def do_test_bipartition_tree_random_returns_within_epsilon_of_target_pop(graph): + ideal_pop = sum(graph.node_data(node)["pop"] for node in graph) / 2 epsilon = 0.25 - result = bipartition_tree(graph_with_pop, "pop", ideal_pop, epsilon, 10) + result = bipartition_tree_random(graph, "pop", ideal_pop, epsilon, 10) - part_pop = sum(graph_with_pop.nodes[node]["pop"] for node in result) + part_pop = sum(graph.node_data(node)["pop"] for node in result) assert abs(part_pop - ideal_pop) / ideal_pop < epsilon - -def 
test_recursive_tree_part_returns_within_epsilon_of_target_pop( - twelve_by_twelve_with_pop, +def test_bipartition_tree_random_returns_within_epsilon_of_target_pop( + graph_with_pop_nx, + graph_with_pop_rx ): + # Test both NX-based and RX-based Graph objects + do_test_bipartition_tree_random_returns_within_epsilon_of_target_pop(graph_with_pop_nx) + do_test_bipartition_tree_random_returns_within_epsilon_of_target_pop(graph_with_pop_rx) + +# --------------------------------------------------------------------- + +def do_test_bipartition_tree_returns_a_subset_of_nodes(graph): + ideal_pop = sum(graph.node_data(node)["pop"] for node in graph) / 2 + result = bipartition_tree(graph, "pop", ideal_pop, 0.25, 10) + assert isinstance(result, frozenset) + assert all(node in graph.nodes for node in result) + +def test_bipartition_tree_returns_a_subset_of_nodes(graph_with_pop_nx, graph_with_pop_rx): + # Test both NX-based and RX-based Graph objects + do_test_bipartition_tree_returns_a_subset_of_nodes(graph_with_pop_nx) + do_test_bipartition_tree_returns_a_subset_of_nodes(graph_with_pop_rx) + +# --------------------------------------------------------------------- + +def do_test_bipartition_tree_returns_within_epsilon_of_target_pop(graph): + ideal_pop = sum(graph.node_data(node)["pop"] for node in graph) / 2 + epsilon = 0.25 + result = bipartition_tree(graph, "pop", ideal_pop, epsilon, 10) + + part_pop = sum(graph.node_data(node)["pop"] for node in result) + assert abs(part_pop - ideal_pop) / ideal_pop < epsilon + +def test_bipartition_tree_returns_within_epsilon_of_target_pop(graph_with_pop_nx, graph_with_pop_rx): + # Test both NX-based and RX-based Graph objects + do_test_bipartition_tree_returns_within_epsilon_of_target_pop(graph_with_pop_nx) + do_test_bipartition_tree_returns_within_epsilon_of_target_pop(graph_with_pop_rx) + +# --------------------------------------------------------------------- + +def do_test_recursive_tree_part_returns_within_epsilon_of_target_pop(twelve_by_twelve_with_pop_graph): n_districts = 7 # 144/7 ≈ 20.5 nodes/subgraph (1 person/node) ideal_pop = ( sum( - twelve_by_twelve_with_pop.nodes[node]["pop"] - for node in twelve_by_twelve_with_pop + twelve_by_twelve_with_pop_graph.node_data(node)["pop"] + for node in twelve_by_twelve_with_pop_graph ) ) / n_districts epsilon = 0.05 result = recursive_tree_part( - twelve_by_twelve_with_pop, + twelve_by_twelve_with_pop_graph, range(n_districts), ideal_pop, "pop", epsilon, ) partition = Partition( - twelve_by_twelve_with_pop, result, updaters={"pop": Tally("pop")} + twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - return all( + assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() ) +def test_recursive_tree_part_returns_within_epsilon_of_target_pop( + twelve_by_twelve_with_pop_nx, + twelve_by_twelve_with_pop_rx +): + # Test both NX-based and RX-based Graph objects + do_test_recursive_tree_part_returns_within_epsilon_of_target_pop(twelve_by_twelve_with_pop_nx) + do_test_recursive_tree_part_returns_within_epsilon_of_target_pop(twelve_by_twelve_with_pop_rx) + +# --------------------------------------------------------------------- -def test_recursive_tree_part_returns_within_epsilon_of_target_pop_using_contraction( - twelve_by_twelve_with_pop, +def do_test_recursive_tree_part_returns_within_epsilon_of_target_pop_using_contraction( + twelve_by_twelve_with_pop_graph, ): n_districts = 7 # 144/7 ≈ 20.5 nodes/subgraph (1 person/node) ideal_pop = ( sum( - 
twelve_by_twelve_with_pop.nodes[node]["pop"] - for node in twelve_by_twelve_with_pop + twelve_by_twelve_with_pop_graph.node_data(node)["pop"] + for node in twelve_by_twelve_with_pop_graph ) ) / n_districts epsilon = 0.05 result = recursive_tree_part( - twelve_by_twelve_with_pop, + twelve_by_twelve_with_pop_graph, range(n_districts), ideal_pop, "pop", @@ -120,27 +210,40 @@ def test_recursive_tree_part_returns_within_epsilon_of_target_pop_using_contract ), ) partition = Partition( - twelve_by_twelve_with_pop, result, updaters={"pop": Tally("pop")} + twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - return all( + assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() ) +def test_recursive_tree_part_returns_within_epsilon_of_target_pop_using_contraction( + twelve_by_twelve_with_pop_nx, + twelve_by_twelve_with_pop_rx +): + # Test both NX-based and RX-based Graph objects + do_test_recursive_tree_part_returns_within_epsilon_of_target_pop_using_contraction( + twelve_by_twelve_with_pop_nx + ) + do_test_recursive_tree_part_returns_within_epsilon_of_target_pop_using_contraction( + twelve_by_twelve_with_pop_rx + ) + +# --------------------------------------------------------------------- -def test_recursive_seed_part_returns_within_epsilon_of_target_pop( - twelve_by_twelve_with_pop, +def do_test_recursive_seed_part_returns_within_epsilon_of_target_pop( + twelve_by_twelve_with_pop_graph, ): n_districts = 7 # 144/7 ≈ 20.5 nodes/subgraph (1 person/node) ideal_pop = ( sum( - twelve_by_twelve_with_pop.nodes[node]["pop"] - for node in twelve_by_twelve_with_pop + twelve_by_twelve_with_pop_graph.node_data(node)["pop"] + for node in twelve_by_twelve_with_pop_graph ) ) / n_districts epsilon = 0.1 result = recursive_seed_part( - twelve_by_twelve_with_pop, + twelve_by_twelve_with_pop_graph, range(n_districts), ideal_pop, "pop", @@ -149,27 +252,36 @@ def test_recursive_seed_part_returns_within_epsilon_of_target_pop( ceil=None, ) partition = Partition( - twelve_by_twelve_with_pop, result, updaters={"pop": Tally("pop")} + twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - return all( + assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() ) +def test_recursive_seed_part_returns_within_epsilon_of_target_pop( + twelve_by_twelve_with_pop_nx, + twelve_by_twelve_with_pop_rx +): + # Test both NX-based and RX-based Graph objects + do_test_recursive_seed_part_returns_within_epsilon_of_target_pop(twelve_by_twelve_with_pop_nx) + do_test_recursive_seed_part_returns_within_epsilon_of_target_pop(twelve_by_twelve_with_pop_rx) + +# --------------------------------------------------------------------- -def test_recursive_seed_part_returns_within_epsilon_of_target_pop_using_contraction( - twelve_by_twelve_with_pop, +def do_test_recursive_seed_part_returns_within_epsilon_of_target_pop_using_contraction( + twelve_by_twelve_with_pop_graph, ): n_districts = 7 # 144/7 ≈ 20.5 nodes/subgraph (1 person/node) ideal_pop = ( sum( - twelve_by_twelve_with_pop.nodes[node]["pop"] - for node in twelve_by_twelve_with_pop + twelve_by_twelve_with_pop_graph.node_data(node)["pop"] + for node in twelve_by_twelve_with_pop_graph ) ) / n_districts epsilon = 0.1 result = recursive_seed_part( - twelve_by_twelve_with_pop, + twelve_by_twelve_with_pop_graph, range(n_districts), ideal_pop, "pop", @@ -183,15 +295,28 @@ def test_recursive_seed_part_returns_within_epsilon_of_target_pop_using_contract ), ) partition = 
Partition( - twelve_by_twelve_with_pop, result, updaters={"pop": Tally("pop")} + twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - return all( + assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() ) +def test_recursive_seed_part_returns_within_epsilon_of_target_pop_using_contraction( + twelve_by_twelve_with_pop_nx, + twelve_by_twelve_with_pop_rx +): + # Test both NX-based and RX-based Graph objects + do_test_recursive_seed_part_returns_within_epsilon_of_target_pop_using_contraction( + twelve_by_twelve_with_pop_nx + ) + do_test_recursive_seed_part_returns_within_epsilon_of_target_pop_using_contraction( + twelve_by_twelve_with_pop_rx + ) + +# --------------------------------------------------------------------- -def test_recursive_seed_part_uses_method(twelve_by_twelve_with_pop): +def do_test_recursive_seed_part_uses_method(twelve_by_twelve_with_pop_graph): calls = 0 def dummy_method(graph, pop_col, pop_target, epsilon, node_repeats, one_sided_cut): @@ -210,13 +335,13 @@ def dummy_method(graph, pop_col, pop_target, epsilon, node_repeats, one_sided_cu n_districts = 7 # 144/7 ≈ 20.5 nodes/subgraph (1 person/node) ideal_pop = ( sum( - twelve_by_twelve_with_pop.nodes[node]["pop"] - for node in twelve_by_twelve_with_pop + twelve_by_twelve_with_pop_graph.node_data(node)["pop"] + for node in twelve_by_twelve_with_pop_graph ) ) / n_districts epsilon = 0.1 result = recursive_seed_part( - twelve_by_twelve_with_pop, + twelve_by_twelve_with_pop_graph, range(n_districts), ideal_pop, "pop", @@ -231,20 +356,26 @@ def dummy_method(graph, pop_col, pop_target, epsilon, node_repeats, one_sided_cu # implementation detail) assert calls >= n_districts - 1 +def test_recursive_seed_part_uses_method(twelve_by_twelve_with_pop_nx, twelve_by_twelve_with_pop_rx): + # Test both NX-based and RX-based Graph objects + do_test_recursive_seed_part_uses_method(twelve_by_twelve_with_pop_nx) + do_test_recursive_seed_part_uses_method(twelve_by_twelve_with_pop_rx) -def test_recursive_seed_part_with_n_unspecified_within_epsilon( - twelve_by_twelve_with_pop, +# --------------------------------------------------------------------- + +def do_test_recursive_seed_part_with_n_unspecified_within_epsilon( + twelve_by_twelve_with_pop_graph, ): n_districts = 6 # This should set n=3 ideal_pop = ( sum( - twelve_by_twelve_with_pop.nodes[node]["pop"] - for node in twelve_by_twelve_with_pop + twelve_by_twelve_with_pop_graph.node_data(node)["pop"] + for node in twelve_by_twelve_with_pop_graph ) ) / n_districts epsilon = 0.05 result = recursive_seed_part( - twelve_by_twelve_with_pop, + twelve_by_twelve_with_pop_graph, range(n_districts), ideal_pop, "pop", @@ -252,45 +383,128 @@ def test_recursive_seed_part_with_n_unspecified_within_epsilon( ceil=None, ) partition = Partition( - twelve_by_twelve_with_pop, result, updaters={"pop": Tally("pop")} + twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - return all( + assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() ) +def test_recursive_seed_part_with_n_unspecified_within_epsilon( + twelve_by_twelve_with_pop_nx, + twelve_by_twelve_with_pop_rx +): + # Test both NX-based and RX-based Graph objects + do_test_recursive_seed_part_with_n_unspecified_within_epsilon(twelve_by_twelve_with_pop_nx) + do_test_recursive_seed_part_with_n_unspecified_within_epsilon(twelve_by_twelve_with_pop_rx) -def 
test_random_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop):
-    tree = random_spanning_tree(graph_with_pop)
-    assert networkx.is_tree(tree)
+# ---------------------------------------------------------------------

+def do_test_random_spanning_tree_returns_tree_with_pop_attribute(graph):
+    tree = random_spanning_tree(graph)
+    assert tree.is_a_tree()

-def test_uniform_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop):
-    tree = uniform_spanning_tree(graph_with_pop)
-    assert networkx.is_tree(tree)
+def test_random_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop_nx, graph_with_pop_rx):
+    # Test both NX-based and RX-based Graph objects
+    do_test_random_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop_nx)
+    do_test_random_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop_rx)

+# ---------------------------------------------------------------------

-def test_bipartition_tree_returns_a_tree(graph_with_pop):
-    ideal_pop = sum(graph_with_pop.nodes[node]["pop"] for node in graph_with_pop) / 2
-    tree = Graph.from_networkx(
-        networkx.Graph([(0, 1), (1, 2), (1, 4), (3, 4), (4, 5), (3, 6), (6, 7), (6, 8)])
-    )
-    for node in tree:
-        tree.nodes[node]["pop"] = graph_with_pop.nodes[node]["pop"]
+def do_test_uniform_spanning_tree_returns_tree_with_pop_attribute(graph):
+    tree = uniform_spanning_tree(graph)
+    assert tree.is_a_tree()

-    result = bipartition_tree(
-        graph_with_pop, "pop", ideal_pop, 0.25, 10, tree, lambda x: 4
-    )
+def test_uniform_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop_nx, graph_with_pop_rx):
+    # Test both NX-based and RX-based Graph objects
+    do_test_uniform_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop_nx)
+    do_test_uniform_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop_rx)
+
+# ---------------------------------------------------------------------

-    assert networkx.is_tree(tree.subgraph(result))
-    assert networkx.is_tree(
-        tree.subgraph({node for node in tree if node not in result})
+def do_test_bipartition_tree_returns_a_tree(graph, spanning_tree):
+    ideal_pop = sum(graph.node_data(node)["pop"] for node in graph) / 2
+
+    result = bipartition_tree(
+        graph, "pop", ideal_pop, 0.25, 10, spanning_tree, lambda x: 4
     )
+    assert spanning_tree.subgraph(result).is_a_tree()
+    assert spanning_tree.subgraph({node for node in spanning_tree if node not in result}).is_a_tree()
+
+def create_graphs_from_nx_edges(num_nodes, list_of_edges_nx, nx_to_rx_node_id_map):
+
+    # NX is easy - just use the list of NX edges
+    graph_nx = Graph.from_networkx(networkx.Graph(list_of_edges_nx))
+
+    # RX requires more work.
+    #
+    # First we create the RX graph and add nodes.
+    #
+    # frm: TODO: Testing: Update test so that the number of nodes is not hard-coded...
+    #
+    # Then we need to create the appropriate RX edges - the ones that
+    # correspond to the NX edges, but expressed in RX node_ids.
+    #
+    # To do that, we translate the node_ids used in the list of NX edges
+    # into the node_ids used in the RX graph via the nx_to_rx_node_id_map,
+    # build a rustworkx.PyGraph, and then create a "new" Graph object from it.
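+    #
+    # For example (illustrative values only): if nx_to_rx_node_id_map were
+    # {0: 2, 1: 0, 2: 1}, then the NX edge (0, 1) would become the RX edge
+    # (2, 0, {}) - the trailing {} being the (empty) edge data payload.
+    #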
+
+    # Create the RX graph
+    rx_graph = rustworkx.PyGraph()
+    for i in range(num_nodes):
+        rx_graph.add_node({})   # empty data dict for node_data
+    # Verify that the nodes created have node_ids 0-(num_nodes-1)
+    assert set(rx_graph.node_indices()) == set(range(num_nodes))
+    # Set the attribute identifying the "original" NX node_id.
+    # This is normally set by the code that converts an NX graph to RX,
+    # but we are cobbling together stuff for a test and so have to
+    # just do it here...
+    rx_to_nx_node_id_map = {v: k for k, v in nx_to_rx_node_id_map.items()}
+    for node_id in rx_graph.node_indices():
+        rx_graph[node_id]["__networkx_node__"] = rx_to_nx_node_id_map[node_id]
+
+    # Translate the NX edges into the appropriate node_ids for the derived RX graph
+    list_of_edges_rx = [
+        (
+            nx_to_rx_node_id_map[edge[0]],
+            nx_to_rx_node_id_map[edge[1]],
+            {}   # empty data dict for edge_data
+        )
+        for edge in list_of_edges_nx
+    ]
+
+    # Add the RX edges
+    rx_graph.add_edges_from(list_of_edges_rx)
+    graph_rx = Graph.from_rustworkx(rx_graph)
+
+    return graph_nx, graph_rx
+
+def test_bipartition_tree_returns_a_tree(graph_with_pop_nx, graph_with_pop_rx):
+    # Test both NX-based and RX-based Graph objects
+
+    spanning_tree_edges_nx = [(0, 1), (1, 2), (1, 4), (3, 4), (4, 5), (3, 6), (6, 7), (6, 8)]
+
+    spanning_tree_nx, spanning_tree_rx = \
+        create_graphs_from_nx_edges(
+            9,
+            spanning_tree_edges_nx,
+            graph_with_pop_rx.nx_to_rx_node_id_map
+        )
+
+    # Give the nodes a population
+    for node in spanning_tree_nx:
+        spanning_tree_nx.node_data(node)["pop"] = 1
+    for node in spanning_tree_rx:
+        spanning_tree_rx.node_data(node)["pop"] = 1
+
+    do_test_bipartition_tree_returns_a_tree(graph_with_pop_nx, spanning_tree_nx)
+    do_test_bipartition_tree_returns_a_tree(graph_with_pop_rx, spanning_tree_rx)


 def test_recom_works_as_a_proposal(partition_with_pop):
     graph = partition_with_pop.graph
-    ideal_pop = sum(graph.nodes[node]["pop"] for node in graph) / 2
+    ideal_pop = sum(graph.node_data(node)["pop"] for node in graph) / 2
     proposal = functools.partial(
         recom, pop_col="pop", pop_target=ideal_pop, epsilon=0.25, node_repeats=5
     )
@@ -301,23 +515,72 @@ def test_recom_works_as_a_proposal(partition_with_pop):
     for state in chain:
         assert contiguous(state)

-
 def test_reversible_recom_works_as_a_proposal(partition_with_pop):
     random.seed(2018)
     graph = partition_with_pop.graph
-    ideal_pop = sum(graph.nodes[node]["pop"] for node in graph) / 2
+    ideal_pop = sum(graph.node_data(node)["pop"] for node in graph) / 2
     proposal = functools.partial(
         reversible_recom, pop_col="pop", pop_target=ideal_pop, epsilon=0.10, M=1
     )
     constraints = [within_percent_of_ideal_population(partition_with_pop, 0.25, "pop")]

+    # frm: ???: I am not sure how epsilon of 0.10 interacts with the constraint.
+    #
+    # The issue is that there are 9 nodes, each with a population of 1, so the ideal population
+    # is 4.5.  But no matter how you split the graph, you end up with an integer population, say,
+    # 4 or 5 - so you will never get within 10% of 4.5 (the allowed interval is
+    # [4.05, 4.95], and both 4 and 5 fall outside it).
+    #
+    # I am not quite sure what is being tested here...
+    #
+    # within_percent_of_ideal_population() returns a Bounds object which contains the lower and
+    # upper bounds for a given value - in this case 25 percent (0.25) of the ideal population.
+    #
+    # The more I dig into this the more I shake my head.  The value of "epsilon" passed into
+    # reversible_recom() seems to only ever be used when creating a PopulatedGraph, which in turn
+    # only ever uses it when doing a specific balanced edge cut algorithm.
+    # That is, the value of epsilon is very rarely used, and yet it is passed in as one of the
+    # important parameters to reversible_recom().  It looks like the original coders thought that
+    # it would be a great thing to have in the PopulatedGraph object, but then they didn't
+    # actually use it.  *sigh*
+    #
+    # Then this test defines a population constraint which says that it is OK for the population
+    # to be within 25% of ideal - which is at odds with the value of epsilon above of 10%, but
+    # since the value of epsilon (of 10%) is never used, whatever...
+    #

+    # frm: TODO: Testing: Grok this test - what is it trying to accomplish?
+    #
+    # The proposal uses reversible_recom() with the default value for the "repeat_until_valid"
+    # parameter, which is False.  This means that the call to try to combine and then split two
+    # parts (districts) only gets one shot at it before it fails.  In this case, that means that
+    # it fails EVERY time - because the initial spanning tree that is returned is not balanced
+    # enough to satisfy the population constraint.  If you let it run, then it succeeds after
+    # several attempts (I think 10), but it never succeeds on the first try, and there is no
+    # randomness possible since we only have two parts (districts) that we can merge.
+    #
+    # So this test runs through 100 chain iterations doing NOTHING - returning the same partition
+    # each iteration, and in fact returning the same partition at the end that it started with.
+    #
+    # This raises all sorts of issues:
+    #
+    #   * Makes no sense for this test
+    #   * Questions the logic in reversible_recom() to not detect an infinite loop
+    #   * Questions the logic that does not inform the user somehow that the chain is ineffective
+    #   * Raises the issue of documentation of the code - it took me quite a long time to
+    #     figure out WTF was going on...
+    #

     chain = MarkovChain(proposal, constraints, lambda x: True, partition_with_pop, 100)

     for state in chain:
         assert contiguous(state)

+# frm: TODO: Testing: Add more tests using MarkovChain...

 def test_find_balanced_cuts_contraction():
+
+    # frm: TODO: Testing: Add test for RX-based Graph object
+
     tree = Graph.from_networkx(
         networkx.Graph([(0, 1), (1, 2), (1, 4), (3, 4), (4, 5), (3, 6), (6, 7), (6, 8)])
     )
@@ -339,50 +602,112 @@ def test_find_balanced_cuts_contraction():

 def test_no_balanced_cuts_contraction_when_one_side_okay():
-    tree = Graph.from_networkx(networkx.Graph([(0, 1), (1, 2), (2, 3), (3, 4)]))
+    list_of_edges_nx = [(0, 1), (1, 2), (2, 3), (3, 4)]
+
+    # For this test we are not dealing with an RX-based Graph object
+    # that is derived from an NX-based Graph object, so the
+    # nx_to_rx_node_id_map can just be the identity map...
+    #
+    nx_to_rx_node_id_map = {node: node for node in range(5)}
+
+    tree_nx, tree_rx = \
+        create_graphs_from_nx_edges(
+            5,
+            list_of_edges_nx,
+            nx_to_rx_node_id_map
+        )
+
+    # OK to use the same populations for NX and RX graphs
     populations = {0: 4, 1: 4, 2: 3, 3: 3, 4: 3}

-    populated_tree = PopulatedGraph(
-        graph=tree, populations=populations, ideal_pop=10, epsilon=0.1
+    populated_tree_nx = PopulatedGraph(
+        graph=tree_nx, populations=populations, ideal_pop=10, epsilon=0.1
+    )
+    populated_tree_rx = PopulatedGraph(
+        graph=tree_rx, populations=populations, ideal_pop=10, epsilon=0.1
     )

-    cuts = find_balanced_edge_cuts_contraction(populated_tree, one_sided_cut=False)
-    assert cuts == []
+    cuts_nx = find_balanced_edge_cuts_contraction(populated_tree_nx, one_sided_cut=False)
+    assert cuts_nx == []
+
+    cuts_rx = find_balanced_edge_cuts_contraction(populated_tree_rx, one_sided_cut=False)
+    assert cuts_rx == []


 def test_find_balanced_cuts_memo():
-    tree = Graph.from_networkx(
-        networkx.Graph([(0, 1), (1, 2), (1, 4), (3, 4), (4, 5), (3, 6), (6, 7), (6, 8)])
-    )
-    # 0 - 1 - 2
-    #     ||
-    # 3=  4 - 5
-    #     ||
-    # 6-  7
-    #     |
-    # 8
+    list_of_edges_nx = [(0, 1), (1, 2), (1, 4), (3, 4), (4, 5), (3, 6), (6, 7), (6, 8)]

-    populated_tree = PopulatedGraph(
-        tree, {node: 1 for node in tree}, len(tree) / 2, 0.5
+    # For this test we are not dealing with an RX-based Graph object
+    # that is derived from an NX-based Graph object, so the
+    # nx_to_rx_node_id_map can just be the identity map...
+    #
+    nx_to_rx_node_id_map = {node: node for node in range(9)}
+
+    tree_nx, tree_rx = \
+        create_graphs_from_nx_edges(
+            9,
+            list_of_edges_nx,
+            nx_to_rx_node_id_map
+        )
+
+    # 0 - 1 - 2
+    #     |
+    # 3 - 4 - 5
+    # |
+    # 6 - 7
+    # |
+    # 8
+
+    populated_tree_nx = PopulatedGraph(
+        tree_nx, {node: 1 for node in tree_nx}, len(tree_nx) / 2, 0.5
     )
-    cuts = find_balanced_edge_cuts_memoization(populated_tree)
-    edges = set(tuple(sorted(cut.edge)) for cut in cuts)
-    assert edges == {(1, 4), (3, 4), (3, 6)}
+    populated_tree_rx = PopulatedGraph(
+        tree_rx, {node: 1 for node in tree_rx}, len(tree_rx) / 2, 0.5
+    )
+
+    cuts_nx = find_balanced_edge_cuts_memoization(populated_tree_nx)
+    edges_nx = set(tuple(sorted(cut.edge)) for cut in cuts_nx)
+    assert edges_nx == {(1, 4), (3, 4), (3, 6)}
+
+    cuts_rx = find_balanced_edge_cuts_memoization(populated_tree_rx)
+    edges_rx = set(tuple(sorted(cut.edge)) for cut in cuts_rx)
+    assert edges_rx == {(1, 4), (3, 4), (3, 6)}


 def test_no_balanced_cuts_memo_when_one_side_okay():
-    tree = Graph.from_networkx(networkx.Graph([(0, 1), (1, 2), (2, 3), (3, 4)]))
+    list_of_edges_nx = [(0, 1), (1, 2), (2, 3), (3, 4)]
+
+    # For this test we are not dealing with an RX-based Graph object
+    # that is derived from an NX-based Graph object, so the
+    # nx_to_rx_node_id_map can just be the identity map...
+    #
+    nx_to_rx_node_id_map = {node: node for node in range(5)}
+
+    tree_nx, tree_rx = \
+        create_graphs_from_nx_edges(
+            5,
+            list_of_edges_nx,
+            nx_to_rx_node_id_map
+        )
+
+    # OK to use the same populations with both NX and RX Graphs
     populations = {0: 4, 1: 4, 2: 3, 3: 3, 4: 3}

-    populated_tree = PopulatedGraph(
-        graph=tree, populations=populations, ideal_pop=10, epsilon=0.1
+    populated_tree_nx = PopulatedGraph(
+        graph=tree_nx, populations=populations, ideal_pop=10, epsilon=0.1
+    )
+    populated_tree_rx = PopulatedGraph(
+        graph=tree_rx, populations=populations, ideal_pop=10, epsilon=0.1
     )

-    cuts = find_balanced_edge_cuts_memoization(populated_tree)
-    assert cuts == []
+    cuts_nx = find_balanced_edge_cuts_memoization(populated_tree_nx)
+    assert cuts_nx == []
+
+    cuts_rx = find_balanced_edge_cuts_memoization(populated_tree_rx)
+    assert cuts_rx == []


 def test_prime_bound():
@@ -394,17 +719,4 @@ def test_prime_bound():
     )


-def test_bipartition_tree_random_returns_a_subset_of_nodes(graph_with_pop):
-    ideal_pop = sum(graph_with_pop.nodes[node]["pop"] for node in graph_with_pop) / 2
-    result = bipartition_tree_random(graph_with_pop, "pop", ideal_pop, 0.25, 10)
-    assert isinstance(result, frozenset)
-    assert all(node in graph_with_pop.nodes for node in result)
-
-def test_bipartition_tree_random_returns_within_epsilon_of_target_pop(graph_with_pop):
-    ideal_pop = sum(graph_with_pop.nodes[node]["pop"] for node in graph_with_pop) / 2
-    epsilon = 0.25
-    result = bipartition_tree_random(graph_with_pop, "pop", ideal_pop, epsilon, 10)
-
-    part_pop = sum(graph_with_pop.nodes[node]["pop"] for node in result)
-    assert abs(part_pop - ideal_pop) / ideal_pop < epsilon
diff --git a/tests/updaters/dbg.py b/tests/updaters/dbg.py
new file mode 100644
index 00000000..41c19b30
--- /dev/null
+++ b/tests/updaters/dbg.py
@@ -0,0 +1,79 @@
+import math
+
+import networkx
+import pytest
+
+from gerrychain import MarkovChain
+from gerrychain.constraints import Validator, no_vanishing_districts
+from gerrychain.graph import Graph
+from gerrychain.partition import Partition
+from gerrychain.proposals import propose_random_flip
+import random
+from gerrychain.updaters import (Election, Tally, boundary_nodes, cut_edges,
+                                 cut_edges_by_part, exterior_boundaries,
+                                 exterior_boundaries_as_a_set,
+                                 interior_boundaries, perimeter)
+from gerrychain.updaters.election import ElectionResults
+random.seed(2018)
+
+
+def create_three_by_three_grid():
+    """Returns a graph that looks like this:
+    0 1 2
+    3 4 5
+    6 7 8
+    """
+    nx_graph = networkx.Graph()
+    nx_graph.add_edges_from(
+        [
+            (0, 1),
+            (0, 3),
+            (1, 2),
+            (1, 4),
+            (2, 5),
+            (3, 4),
+            (3, 6),
+            (4, 5),
+            (4, 7),
+            (5, 8),
+            (6, 7),
+            (7, 8),
+        ]
+    )
+    return Graph.from_networkx(nx_graph)
+
+
+def random_assignment(graph, num_districts):
+    assignment = {node: random.choice(range(num_districts)) for node in graph.nodes}
+    # Make sure that there are cut edges:
+    while len(set(assignment.values())) == 1:
+        assignment = {node: random.choice(range(num_districts)) for node in graph.nodes}
+    return assignment
+
+
+def test_vote_proportion_returns_nan_if_total_votes_is_zero(three_by_three_grid):
+    election = Election("Mock Election", ["D", "R"], alias="election")
+    graph = three_by_three_grid
+
+    for node in graph.nodes:
+        for col in election.node_attribute_names:
+            graph.node_data(node)[col] = 0
+
+    updaters = {"election": election}
+    assignment = random_assignment(graph, 3)
+
+    partition = Partition(graph, assignment, updaters)
+
+    assert all(
+        math.isnan(value)
+        for party_percents in 
partition["election"].percents_for_party.values() + for value in party_percents.values() + ) + + +three_by_three_grid = create_three_by_three_grid() +test_vote_proportion_returns_nan_if_total_votes_is_zero(three_by_three_grid) diff --git a/tests/updaters/test_cut_edges.py b/tests/updaters/test_cut_edges.py index e6582f41..a308482d 100644 --- a/tests/updaters/test_cut_edges.py +++ b/tests/updaters/test_cut_edges.py @@ -27,6 +27,13 @@ def invalid_cut_edges(partition): ] return invalid +def translate_flips_to_internal_node_ids(partition, flips): + # Translate flips into the internal_node_ids for the partition + internal_flips = {} + for original_nx_node_id, part in flips.items(): + internal_node_id = partition.graph.internal_node_id_for_original_nx_node_id(original_nx_node_id) + internal_flips[internal_node_id] = part + return internal_flips def test_cut_edges_doesnt_duplicate_edges_with_different_order_of_nodes( three_by_three_grid, @@ -39,10 +46,13 @@ def test_cut_edges_doesnt_duplicate_edges_with_different_order_of_nodes( # 222 222 flip = {4: 2, 2: 1, 5: 1} - new_partition = Partition(parent=partition, flips=flip) + internal_flips = translate_flips_to_internal_node_ids(partition, flip) + + new_partition = Partition(parent=partition, flips=internal_flips) result = new_partition["cut_edges"] + # Verify that the same edge is not in the result twice (just in different node_id order) for edge in result: assert (edge[1], edge[0]) not in result @@ -56,13 +66,16 @@ def test_cut_edges_can_handle_multiple_flips(three_by_three_grid): # 222 222 flip = {4: 2, 2: 1, 5: 1} - new_partition = Partition(parent=partition, flips=flip) + internal_flips = translate_flips_to_internal_node_ids(partition, flip) + + new_partition = Partition(parent=partition, flips=internal_flips) result = new_partition["cut_edges"] naive_cut_edges = { - tuple(sorted(edge)) for edge in graph.edges if new_partition.crosses_parts(edge) + tuple(sorted(edge)) for edge in partition.graph.edges if new_partition.crosses_parts(edge) } + assert result == naive_cut_edges @@ -78,7 +91,9 @@ def test_cut_edges_by_part_doesnt_duplicate_edges_with_opposite_order_of_nodes( # 222 222 flip = {4: 2, 2: 1, 5: 1} - new_partition = Partition(parent=partition, flips=flip) + internal_flips = translate_flips_to_internal_node_ids(partition, flip) + + new_partition = Partition(parent=partition, flips=internal_flips) result = new_partition["cut_edges_by_part"] @@ -97,11 +112,13 @@ def test_cut_edges_by_part_gives_same_total_edges_as_naive_method(three_by_three # 222 222 flip = {4: 2, 2: 1, 5: 1} - new_partition = Partition(parent=partition, flips=flip) + internal_flips = translate_flips_to_internal_node_ids(partition, flip) + + new_partition = Partition(parent=partition, flips=internal_flips) result = new_partition["cut_edges_by_part"] naive_cut_edges = { - tuple(sorted(edge)) for edge in graph.edges if new_partition.crosses_parts(edge) + tuple(sorted(edge)) for edge in partition.graph.edges if new_partition.crosses_parts(edge) } assert naive_cut_edges == { @@ -115,11 +132,15 @@ def test_implementation_of_cut_edges_matches_naive_method(three_by_three_grid): partition = Partition(graph, assignment, {"cut_edges": cut_edges}) flip = {4: 2} - new_partition = Partition(parent=partition, flips=flip) + + internal_flips = translate_flips_to_internal_node_ids(partition, flip) + + new_partition = Partition(parent=partition, flips=internal_flips) + result = cut_edges(new_partition) naive_cut_edges = { - edge for edge in graph.edges if new_partition.crosses_parts(edge) + 
edge for edge in partition.graph.edges if new_partition.crosses_parts(edge) } assert edge_set_equal(result, naive_cut_edges) diff --git a/tests/updaters/test_perimeters.py b/tests/updaters/test_perimeters.py index 05c1f156..d98d1f5a 100644 --- a/tests/updaters/test_perimeters.py +++ b/tests/updaters/test_perimeters.py @@ -8,8 +8,13 @@ def setup(): + + # Note that the node_ids for the NX graph for a grid are tuples with the (x,y) position of the node + grid = Grid((4, 4), with_diagonals=False) - flipped_grid = grid.flip({(2, 1): 3}) + + flipped_grid = grid.flip({(2, 1): 3}, use_original_nx_node_ids=True) + return grid, flipped_grid @@ -30,29 +35,47 @@ def test_interior_perimeter_handles_flips_with_a_simple_grid(): def test_cut_edges_by_part_handles_flips_with_a_simple_grid(): + + # frm: TODO: Testing: Add a graphic here + # + # That will allow the person reading this code to make sense + # of what it does... + # grid, flipped_grid = setup() result = flipped_grid["cut_edges_by_part"] - assert result[0] == { + # Translate internal edges so that they can be compared to the literals below + new_result = {} + for part, set_of_edges in result.items(): + new_set_of_edges = set() + for edge in set_of_edges: + new_edge = ( + flipped_grid.graph.original_nx_node_id_for_internal_node_id(edge[0]), + flipped_grid.graph.original_nx_node_id_for_internal_node_id(edge[1]), + ) + new_set_of_edges.add(new_edge) + new_result[part] = new_set_of_edges + + assert new_result[0] == { ((1, 0), (2, 0)), ((1, 1), (2, 1)), ((0, 1), (0, 2)), ((1, 1), (1, 2)), } - assert result[1] == { + assert new_result[1] == { ((1, 0), (2, 0)), ((2, 0), (2, 1)), ((2, 1), (3, 1)), ((3, 1), (3, 2)), } - assert result[2] == { + assert new_result[2] == { ((0, 1), (0, 2)), ((1, 1), (1, 2)), ((1, 2), (2, 2)), ((1, 3), (2, 3)), } - assert result[3] == { + assert new_result[3] == { ((1, 1), (2, 1)), ((2, 0), (2, 1)), ((2, 1), (3, 1)), @@ -99,9 +122,9 @@ def test_perimeter_match_naive_perimeter_at_every_step(): def get_exterior_boundaries(partition): graph_boundary = partition["boundary_nodes"] exterior = defaultdict(lambda: 0) - for node in graph_boundary: - part = partition.assignment[node] - exterior[part] += partition.graph.nodes[node]["boundary_perim"] + for node_id in graph_boundary: + part = partition.assignment[node_id] + exterior[part] += partition.graph.node_data(node_id)["boundary_perim"] return exterior def get_interior_boundaries(partition): @@ -111,9 +134,9 @@ def get_interior_boundaries(partition): interior = defaultdict(int) for edge in cut_edges: for node in edge: - interior[partition.assignment[node]] += partition.graph.edges[edge][ - "shared_perim" - ] + interior[partition.assignment[node]] += partition.graph.edge_data( + partition.graph.get_edge_id_from_edge(edge) + )["shared_perim"] return interior def expected_perimeter(partition): diff --git a/tests/updaters/test_split_scores.py b/tests/updaters/test_split_scores.py index c26a32de..27d37504 100644 --- a/tests/updaters/test_split_scores.py +++ b/tests/updaters/test_split_scores.py @@ -4,6 +4,7 @@ from gerrychain.updaters.locality_split_scores import LocalitySplits from gerrychain.updaters.cut_edges import cut_edges from gerrychain import Graph +import networkx @pytest.fixture def three_by_three_grid(): @@ -12,8 +13,8 @@ def three_by_three_grid(): 3 4 5 6 7 8 """ - graph = Graph() - graph.add_edges_from( + nx_graph = networkx.Graph() + nx_graph.add_edges_from( [ (0, 1), (0, 3), @@ -29,20 +30,21 @@ def three_by_three_grid(): (7, 8), ] ) + graph = 
Graph.from_networkx(nx_graph) return graph @pytest.fixture def graph_with_counties(three_by_three_grid): for node in [0, 1, 2]: - three_by_three_grid.nodes[node]["county"] = "a" - three_by_three_grid.nodes[node]["pop"] = 1 + three_by_three_grid.node_data(node)["county"] = "a" + three_by_three_grid.node_data(node)["pop"] = 1 for node in [3, 4, 5]: - three_by_three_grid.nodes[node]["county"] = "b" - three_by_three_grid.nodes[node]["pop"] = 1 + three_by_three_grid.node_data(node)["county"] = "b" + three_by_three_grid.node_data(node)["pop"] = 1 for node in [6, 7, 8]: - three_by_three_grid.nodes[node]["county"] = "c" - three_by_three_grid.nodes[node]["pop"] = 1 + three_by_three_grid.node_data(node)["county"] = "c" + three_by_three_grid.node_data(node)["pop"] = 1 return three_by_three_grid @@ -69,11 +71,6 @@ def split_partition(graph_with_counties): ) return partition - - - - - class TestSplittingScores: diff --git a/tests/updaters/test_splits.py b/tests/updaters/test_splits.py index 1b6c26fa..62f6e395 100644 --- a/tests/updaters/test_splits.py +++ b/tests/updaters/test_splits.py @@ -9,11 +9,11 @@ @pytest.fixture def graph_with_counties(three_by_three_grid): for node in [0, 1, 2]: - three_by_three_grid.nodes[node]["county"] = "a" + three_by_three_grid.node_data(node)["county"] = "a" for node in [3, 4, 5]: - three_by_three_grid.nodes[node]["county"] = "b" + three_by_three_grid.node_data(node)["county"] = "b" for node in [6, 7, 8]: - three_by_three_grid.nodes[node]["county"] = "c" + three_by_three_grid.node_data(node)["county"] = "c" return three_by_three_grid @@ -43,12 +43,14 @@ def test_describes_splits_for_all_counties(self, partition): assert set(result.keys()) == {"a", "b", "c"} - after_a_flip = partition.flip({3: 1}) + after_a_flip = partition.flip({3: 1}, use_original_nx_node_ids=True) second_result = after_a_flip["splits"] assert set(second_result.keys()) == {"a", "b", "c"} def test_no_splits(self, graph_with_counties): + + # frm: TODO: Testing: Why does this not just use "split_partition"? Isn't it the same? partition = Partition(graph_with_counties, assignment="county") result = compute_county_splits(partition, "county", None) @@ -57,7 +59,9 @@ def test_no_splits(self, graph_with_counties): assert splits_info.split == CountySplit.NOT_SPLIT def test_new_split(self, partition): - after_a_flip = partition.flip({3: 1}) + # Do a flip, using the node_ids of the original assignment (rather than the + # node_ids used internally in the RX-based graph) + after_a_flip = partition.flip({3: 1}, use_original_nx_node_ids=True) result = after_a_flip["splits"] # County b is now split, but a and c are not @@ -74,7 +78,9 @@ def test_initial_split(self, split_partition): assert result["c"].split == CountySplit.NOT_SPLIT def test_old_split(self, split_partition): - after_a_flip = split_partition.flip({4: 1}) + # Do a flip, using the node_ids of the original assignment (rather than the + # node_ids used internally in the RX-based graph) + after_a_flip = split_partition.flip({4: 1}, use_original_nx_node_ids=True) result = after_a_flip["splits"] # County b becomes more split @@ -87,11 +93,11 @@ def test_old_split(self, split_partition): "previous partition, which is not the intuitive behavior." 
) def test_initial_split_that_disappears_and_comes_back(self, split_partition): - no_splits = split_partition.flip({3: 2}) + no_splits = split_partition.flip({3: 2}, use_original_nx_node_ids=True) result = no_splits["splits"] assert all(info.split == CountySplit.NOT_SPLIT for info in result.values()) - split_comes_back = no_splits.flip({3: 1}) + split_comes_back = no_splits.flip({3: 1}, use_original_nx_node_ids=True) new_result = split_comes_back["splits"] assert new_result["a"].split == CountySplit.NOT_SPLIT assert new_result["b"].split == CountySplit.OLD_SPLIT diff --git a/tests/updaters/test_updaters.py b/tests/updaters/test_updaters.py index 37a4b97e..2855cb4d 100644 --- a/tests/updaters/test_updaters.py +++ b/tests/updaters/test_updaters.py @@ -33,11 +33,10 @@ def random_assignment(graph, num_districts): def partition_with_election(graph_with_d_and_r_cols): graph = graph_with_d_and_r_cols assignment = random_assignment(graph, 3) - parties_to_columns = { - "D": {node: graph.nodes[node]["D"] for node in graph.nodes}, - "R": {node: graph.nodes[node]["R"] for node in graph.nodes}, - } - election = Election("Mock Election", parties_to_columns) + + party_names_to_node_attribute_names = ["D", "R"] + + election = Election("Mock Election", party_names_to_node_attribute_names) updaters = {"Mock Election": election, "cut_edges": cut_edges} return Partition(graph, assignment, updaters) @@ -54,24 +53,36 @@ def chain_with_election(partition_with_election): def test_Partition_can_update_stats(): - graph = networkx.complete_graph(3) + nx_graph = networkx.complete_graph(3) assignment = {0: 1, 1: 1, 2: 2} - graph.nodes[0]["stat"] = 1 - graph.nodes[1]["stat"] = 2 - graph.nodes[2]["stat"] = 3 + nx_graph.nodes[0]["stat"] = 1 + nx_graph.nodes[1]["stat"] = 2 + nx_graph.nodes[2]["stat"] = 7 + + graph = Graph.from_networkx(nx_graph) updaters = {"total_stat": Tally("stat", alias="total_stat")} - partition = Partition(Graph.from_networkx(graph), assignment, updaters) - assert partition["total_stat"][2] == 3 + # This test is complicated by the fact that "original" node_ids are typically based + # on the node_ids for NX-based graphs, so in this test's case, those would be: 0, 1, 2 . + # However, when we create a Partition, we convert to an RX-based graph object and + # as a result the internal node_ids for the RX-based graph change. So, when we ask + # for graph data from a partition we need to be careful to use its internal node_ids. 
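+    #
+    # A sketch of the manual translation used in some of the other tests,
+    # which is what flip(..., use_original_nx_node_ids=True) (used below)
+    # saves us from writing by hand (assuming the graph was converted from NX):
+    #
+    #     internal_flip = {
+    #         partition.graph.internal_node_id_for_original_nx_node_id(n): part
+    #         for n, part in flip.items()
+    #     }
+    #     new_partition = partition.flip(internal_flip)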
+
+    # Verify that the "total_stat" for the part (district) 2 is 7
+    partition = Partition(graph, assignment, updaters)
+    assert partition["total_stat"][2] == 7
+
+    # Flip the node with original node_id 1 into part (district) 2
     flip = {1: 2}
-    new_partition = partition.flip(flip)
-    assert new_partition["total_stat"][2] == 5
+    new_partition = partition.flip(flip, use_original_nx_node_ids=True)
+
+    assert new_partition["total_stat"][2] == 9


-def test_tally_multiple_columns(graph_with_d_and_r_cols):
+def test_tally_multiple_node_attribute_names(graph_with_d_and_r_cols):
     graph = graph_with_d_and_r_cols

     updaters = {"total": Tally(["D", "R"], alias="total")}
@@ -79,7 +90,7 @@
     partition = Partition(graph, assignment, updaters)

     expected_total_in_district_one = sum(
-        graph.nodes[i]["D"] + graph.nodes[i]["R"] for i in range(4)
+        graph.node_data(i)["D"] + graph.node_data(i)["R"] for i in range(4)
     )

     assert partition["total"][1] == expected_total_in_district_one
@@ -103,12 +114,13 @@ def test_vote_proportion_updater_returns_percentage_or_nan(partition_with_electi


 def test_vote_proportion_returns_nan_if_total_votes_is_zero(three_by_three_grid):
+
     election = Election("Mock Election", ["D", "R"], alias="election")
     graph = three_by_three_grid

     for node in graph.nodes:
-        for col in election.columns:
-            graph.nodes[node][col] = 0
+        for col in election.node_attribute_names:
+            graph.node_data(node)[col] = 0

     updaters = {"election": election}
     assignment = random_assignment(graph, 3)
@@ -179,12 +191,41 @@ def test_election_result_has_a_cute_str_method():
     assert str(results) == expected


+def _convert_dict_of_set_of_rx_node_ids_to_set_of_nx_node_ids(dict_of_set_of_rx_nodes, nx_to_rx_node_id_map):
+
+    # frm: TODO: Testing: This way to convert node_ids is clumsy and inconvenient. Think of something better...
+
+    # When we create a partition from an NX-based Graph, we convert it to an
+    # RX-based Graph, which changes the node_ids of the graph.  To convert
+    # sets of RX-based node_ids back to the node_ids of the original NX
+    # graph, we can use the nx_to_rx_node_id_map that is generated and saved
+    # when the NX-based graph is converted to RX.
+    #
+    # This routine converts the data that some updaters create - namely a
+    # mapping from each part (district) to a set of node_ids.
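+    #
+    # For example (illustrative values only): with
+    # nx_to_rx_node_id_map == {10: 0, 11: 1} (NX id -> RX id), the input
+    # {"A": {0, 1}} would be converted to {"A": {10, 11}}.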
+
+    converted_set = {}
+    if nx_to_rx_node_id_map is not None:   # means graph was converted from NX
+        # reverse the map
+        rx_to_nx_node_id_map = {value: key for key, value in nx_to_rx_node_id_map.items()}
+        for part, set_of_rx_nodes in dict_of_set_of_rx_nodes.items():
+            converted_set_of_rx_nodes = {rx_to_nx_node_id_map[rx_node_id] for rx_node_id in set_of_rx_nodes}
+            converted_set[part] = converted_set_of_rx_nodes
+        # Equivalently, as a dict comprehension:
+        #
+        #     converted_set = {
+        #         part: {rx_to_nx_node_id_map[rx_node_id] for rx_node_id in set_of_rx_node_ids}
+        #         for part, set_of_rx_node_ids in dict_of_set_of_rx_nodes.items()
+        #     }
+    return converted_set
+
+
 def test_exterior_boundaries_as_a_set(three_by_three_grid):
     graph = three_by_three_grid
     for i in [0, 1, 2, 3, 5, 6, 7, 8]:
-        graph.nodes[i]["boundary_node"] = True
-    graph.nodes[4]["boundary_node"] = False
+        graph.node_data(i)["boundary_node"] = True
+    graph.node_data(4)["boundary_node"] = False

     assignment = {0: 1, 1: 1, 2: 2, 3: 1, 4: 1, 5: 2, 6: 2, 7: 2, 8: 2}
     updaters = {
@@ -194,27 +235,61 @@
     partition = Partition(graph, assignment, updaters)

     result = partition["exterior_boundaries_as_a_set"]
-    assert result[1] == {0, 1, 3} and result[2] == {2, 5, 6, 7, 8}
-    # 112    111
-    # 112 -> 121
-    # 222    222
-    flips = {4: 2, 2: 1, 5: 1}
+
+    # frm: TODO: Testing: Come up with a nice way to convert the result which uses
+    #                     RX based node_ids back to the original NX based node_ids...
+
+    # If the original graph that the partition was based on was an NX graph
+    # then we need to convert the RX node_ids in the partition's graph
+    # back to what they were in the NX graph.
+    nx_to_rx_node_id_map = partition.graph.get_nx_to_rx_node_id_map()
+    if nx_to_rx_node_id_map is not None:
+        converted_result = _convert_dict_of_set_of_rx_node_ids_to_set_of_nx_node_ids(result, nx_to_rx_node_id_map)
+        result = converted_result

-    new_partition = Partition(parent=partition, flips=flips)
+    assert result[1] == {0, 1, 3} and result[2] == {2, 5, 6, 7, 8}
+
+    # Flip nodes and then recompute partition
+    # boundaries to make sure the updater works properly.
+    # The new partition map will look like this:
+    #
+    # 112    111
+    # 112 -> 121
+    # 222    222
+    #
+    # In terms of the original NX graph's node_ids, we would
+    # do the following flips: 4->2, 2->1, and 5->1
+    #
+    # However, the node_ids in the partition's graph have changed due to
+    # conversion to RX, so we need to translate the flips into RX node_ids.
+    # Note that nx_to_rx_node_id_map maps NX node_ids to RX node_ids,
+    # which is exactly the direction we need here.

+    nx_flips = {4: 2, 2: 1, 5: 1}
+    rx_flips = {nx_to_rx_node_id_map[nx_node_id]: part for nx_node_id, part in nx_flips.items()}
+
+    new_partition = Partition(parent=partition, flips=rx_flips)

     result = new_partition["exterior_boundaries_as_a_set"]

+    # If the original graph that the partition was based on was an NX graph
+    # then we need to convert the RX node_ids in the partition's graph
+    # back to what they were in the NX graph.
+ nx_to_rx_node_id_map = new_partition.graph.get_nx_to_rx_node_id_map() + if nx_to_rx_node_id_map is not None: + converted_result = _convert_dict_of_set_of_rx_node_ids_to_set_of_nx_node_ids(result, nx_to_rx_node_id_map) + result = converted_result + assert result[1] == {0, 1, 2, 3, 5} and result[2] == {6, 7, 8} def test_exterior_boundaries(three_by_three_grid): + graph = three_by_three_grid for i in [0, 1, 2, 3, 5, 6, 7, 8]: - graph.nodes[i]["boundary_node"] = True - graph.nodes[i]["boundary_perim"] = 2 - graph.nodes[4]["boundary_node"] = False + graph.node_data(i)["boundary_node"] = True + graph.node_data(i)["boundary_perim"] = 2 + graph.node_data(4)["boundary_node"] = False assignment = {0: 1, 1: 1, 2: 2, 3: 1, 4: 1, 5: 2, 6: 2, 7: 2, 8: 2} updaters = { @@ -229,9 +304,15 @@ def test_exterior_boundaries(three_by_three_grid): # 112 111 # 112 -> 121 # 222 222 - flips = {4: 2, 2: 1, 5: 1} + flips = {4: 2, 2: 1, 5: 1} + + # Convert the flips into internal node_ids + internal_flips = {} + for node_id, part in flips.items(): + internal_node_id = partition.graph.internal_node_id_for_original_nx_node_id(node_id) + internal_flips[internal_node_id] = part - new_partition = Partition(parent=partition, flips=flips) + new_partition = Partition(parent=partition, flips=internal_flips) result = new_partition["exterior_boundaries"] @@ -241,12 +322,13 @@ def test_exterior_boundaries(three_by_three_grid): def test_perimeter(three_by_three_grid): graph = three_by_three_grid for i in [0, 1, 2, 3, 5, 6, 7, 8]: - graph.nodes[i]["boundary_node"] = True - graph.nodes[i]["boundary_perim"] = 1 - graph.nodes[4]["boundary_node"] = False + graph.node_data(i)["boundary_node"] = True + # frm: TODO: Testing: Update test - boundary_perim should be 2 for corner nodes... + graph.node_data(i)["boundary_perim"] = 1 + graph.node_data(4)["boundary_node"] = False for edge in graph.edges: - graph.edges[edge]["shared_perim"] = 1 + graph.edge_data(edge)["shared_perim"] = 1 assignment = {0: 1, 1: 1, 2: 2, 3: 1, 4: 1, 5: 2, 6: 2, 7: 2, 8: 2} updaters = { @@ -275,6 +357,7 @@ def reject_half_of_all_flips(partition): def test_elections_match_the_naive_computation(partition_with_election): + chain = MarkovChain( propose_random_flip, Validator([no_vanishing_districts, reject_half_of_all_flips]), @@ -292,8 +375,8 @@ def test_elections_match_the_naive_computation(partition_with_election): assert expected_party_totals == election_view.totals_for_party -def expected_tally(partition, column): +def expected_tally(partition, node_attribute_name): return { - part: sum(partition.graph.nodes[node][column] for node in nodes) + part: sum(partition.graph.node_data(node)[node_attribute_name] for node in nodes) for part, nodes in partition.parts.items() }
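+
+# A sketch of how expected_tally is meant to pair with a Tally updater in
+# these tests (assuming a partition built with Tally("D", alias="D")):
+#
+#     assert partition["D"] == expected_tally(partition, "D")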