diff --git a/gerrychain/accept.py b/gerrychain/accept.py index 5463f13c..fee5393c 100644 --- a/gerrychain/accept.py +++ b/gerrychain/accept.py @@ -21,7 +21,7 @@ def cut_edge_accept(partition: Partition) -> bool: Always accepts the flip if the number of cut_edges increases. Otherwise, uses the Metropolis criterion to decide. - frm: TODO: Add documentation on what the "Metropolis criterion" is... + frm: TODO: Documentation: Add documentation on what the "Metropolis criterion" is... :param partition: The current partition to accept a flip from. :type partition: Partition diff --git a/gerrychain/constraints/contiguity.py b/gerrychain/constraints/contiguity.py index 1c92e6f8..283e42e2 100644 --- a/gerrychain/constraints/contiguity.py +++ b/gerrychain/constraints/contiguity.py @@ -8,14 +8,14 @@ from ..graph import Graph -# frm: TODO: Think about the efficiency of the routines in this module. Almost all +# frm: TODO: Performance: Think about the efficiency of the routines in this module. Almost all # of these involve traversing the entire graph, and I fear that callers # might make multiple calls. # # Possible solutions are to 1) speed up these routines somehow and 2) cache # results so that at least we don't do the traversals over and over. -# frm: TODO: Rethink WTF this module is all about. +# frm: TODO: Refactoring: Rethink WTF this module is all about. # # It seems like a grab bag for lots of different things - used in different places. # @@ -23,6 +23,13 @@ # which operates on a partition, but lots of other routines here operate on graphs or # other things. So, what is going on? # +# Peter replied to this comment in a pull request: +# +# So anything that is prefixed with an underscore in here should be a helper +# function and not a part of the public API. It looks like, other than +# is_connected_bfs (which should probably be marked "private" with an +# underscore) everything here is acting like an updater. +# def _are_reachable(graph: Graph, start_node: Any, avoid: Callable, targets: Any) -> bool: @@ -41,7 +48,7 @@ def _are_reachable(graph: Graph, start_node: Any, avoid: Callable, targets: Any) It should take in three parameters: the start node, the end node, and the edges to avoid. It should return True if the edge should be avoided, False otherwise. - # frm: TODO: Fix the comment above about the "avoid" function parameter. + # frm: TODO: Documentation: Fix the comment above about the "avoid" function parameter. # It may have once been accurate, but the original code below # passed parameters to it of (node_id, neighbor_node_id, edge_data_dict) # from NetworkX.Graph._succ So, "the edges to avoid" above is wrong. @@ -112,7 +119,7 @@ def _are_reachable(graph: Graph, start_node: Any, avoid: Callable, targets: Any) seen[neighbor] = neighbor_distance push(fringe, (neighbor_distance, next(c), neighbor)) - # frm: TODO: Simplify this code. It computes distances and counts but + # frm: TODO: Refactoring: Simplify this code. It computes distances and counts but # never uses them. These must be relics of code copied # from somewhere else where it had more uses... 
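+    # frm: note: a sketch (not applied in this PR) of the simplification suggested
+    # above - a plain BFS with none of the distance/count bookkeeping.  It assumes
+    # avoid() takes (node, neighbor, edge_data) as described in the corrected
+    # comment above, and that graph.neighbors() behaves as it does elsewhere;
+    # passing None for the edge data is a shortcut for brevity:
+    #
+    #     from collections import deque
+    #
+    #     def _are_reachable_simplified(graph, start_node, avoid, targets):
+    #         seen = {start_node}
+    #         queue = deque([start_node])
+    #         while queue:
+    #             node = queue.popleft()
+    #             for neighbor in graph.neighbors(node):
+    #                 if neighbor in seen or avoid(node, neighbor, None):
+    #                     continue
+    #                 seen.add(neighbor)
+    #                 queue.append(neighbor)
+    #         return all(target in seen for target in targets)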
@@ -233,11 +240,6 @@ def contiguous(partition: Partition) -> bool: :returns: Whether the partition is contiguous :rtype: bool """ - # frm: Original code: - # - # return all( - # nx.is_connected(partition.subgraphs[part]) for part in _affected_parts(partition) - # ) return all( is_connected_bfs(partition.subgraphs[part]) for part in _affected_parts(partition) @@ -255,16 +257,21 @@ def contiguous_bfs(partition: Partition) -> bool: :rtype: bool """ - # frm: TODO: Try to figure out why this routine exists. It seems to be - # exactly the same conceptually as contiguous(). It looks - # at the "affected" parts - those that have changed node - # assignments from parent, and sees if those parts are - # contiguous. + # frm: TODO: Refactoring: Figure out why this routine, contiguous_bfs() exists. # - # For now, I have just replaced the existing code which depended - # on NX with a call on contiguous(partition). + # It is mentioned in __init__.py so maybe it is used externally in legacy code. + # + # However, I have changed the code so that it just calls contiguous() and all + # of the tests pass, so I am going to assume that my comment below is accurate, + # that is, I am assuming that this function does not need to exist independently + # except for legacy purposes. Stated differently, if someone can verify that + # this routine is NOT needed for legacy purposes, then we can just delete it. + # + # It seems to be exactly the same conceptually as contiguous(). It looks + # at the "affected" parts - those that have changed node + # assignments from parent, and sees if those parts are + # contiguous. # - # frm: Original Code: # # parts_to_check = _affected_parts(partition) @@ -288,10 +295,6 @@ def number_of_contiguous_parts(partition: Partition) -> int: :returns: Number of contiguous parts in the partition. :rtype: int """ - # frm: Original Code: - # parts = partition.assignment.parts - # return sum(1 for part in parts if nx.is_connected(partition.subgraphs[part])) - # parts = partition.assignment.parts return sum(1 for part in parts if is_connected_bfs(partition.subgraphs[part])) @@ -316,7 +319,7 @@ def contiguous_components(partition: Partition) -> Dict[int, list]: :rtype: dict """ - # frm: TODO: NX vs RX Issues here: + # frm: TODO: Documentation: Migration Guide: NX vs RX Issues here: # # The call on subgraph() below is perhaps problematic because it will renumber # node_ids... @@ -333,12 +336,6 @@ def contiguous_components(partition: Partition) -> Dict[int, list]: # 3) From each part's subgraph to the subgraphs of contiguous_components... # - # frm: Original Code: - # return { - # part: [subgraph.subgraph(nodes) for nodes in nx.connected_components(subgraph)] - # for part, subgraph in partition.subgraphs.items() - # } - # connected_components_in_each_partition = {} for part, subgraph in partition.subgraphs.items(): # create a subgraph for each set of connected nodes in the part's nodes @@ -379,11 +376,11 @@ def _bfs(graph: Dict[int, list]) -> bool: return num_nodes == len(visited) -# frm: TODO: Verify that is_connected_bfs() works - add a test or two... +# frm: TODO: Testing: Verify that is_connected_bfs() works - add a test or two... -# frm: TODO: Move this code into graph.py. It is all about the Graph... +# frm: TODO: Refactoring: Move this code into graph.py. It is all about the Graph... -# frm: Code obtained from the web - probably could be optimized... +# frm: TODO: Documentation: This code was obtained from the web - probably could be optimized... 
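+# frm: note: per the Testing TODO above, two candidate tests (a sketch only -
+# the test names are invented, and the import paths are assumptions):
+#
+#     import networkx as nx
+#     from gerrychain.graph import Graph
+#     from gerrychain.constraints.contiguity import is_connected_bfs
+#
+#     def test_is_connected_bfs_connected():
+#         # A path graph on four nodes is connected
+#         assert is_connected_bfs(Graph.from_networkx(nx.path_graph(4)))
+#
+#     def test_is_connected_bfs_disconnected():
+#         # Two separate edges form two components
+#         g = nx.Graph()
+#         g.add_edges_from([(0, 1), (2, 3)])
+#         assert not is_connected_bfs(Graph.from_networkx(g))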
# This code replaced calls on nx.is_connected() def is_connected_bfs(graph: Graph): if not graph: diff --git a/gerrychain/graph/graph.py b/gerrychain/graph/graph.py index 3dc58313..c3b3faf0 100644 --- a/gerrychain/graph/graph.py +++ b/gerrychain/graph/graph.py @@ -9,6 +9,8 @@ Note: This module relies on NetworkX, pandas, and geopandas, which should be installed and imported as required. + +TODO: Documentation: Update top-level documentation for graph.py """ import functools @@ -21,12 +23,13 @@ from networkx.readwrite import json_graph import pandas as pd -# frm: added to support RustworkX graphs (in the future) import rustworkx from .adjacency import neighbors from .geo import GeometryError, invalid_geometries, reprojected -from typing import List, Iterable, Optional, Set, Tuple, Union + +# frm: codereview note: removed type hints that are now baked into Python +from typing import Iterable, Optional, Generator, Union import geopandas as gp from shapely.ops import unary_union @@ -35,94 +38,10 @@ import numpy import scipy -######################################################### -# frm Overview of changes (May 2025): -# frm: TODO: Revise (or perhaps just delete) this comment / discussion... -""" -This comment is temporary - it describes the work done to encapsulate dependency on -NetworkX so that this file is the only file that has any NetworkX dependencies. -That work is not completely done - there are bits and bobs of NetworkX -dependencies outside this file, but they are at least commented. In short, -this comment attempts to make clear what I am trying to do. - -The idea is to replace the old Graph object (that was a subclass of NetworkX.Graph) with -a new Graph object that is not a subclass of anything. This new Graph class would look -and act like the old NetworkX based Graph object. Under the covers it would have -either an NX Graph or an RX PyGraph. - -There is a legitimate question - why bother to retain the option to use a NetworkX Graph -as the underlying Graph object, if the user cannot know what the underlying graph object -is? There are two answers: - 1) It seemed possible to me that users took advantage in their own code that the - Graph object was in fact a NetworkX Graph object. If that is so, then we can - make life easier for them by providing them an easy way to gain access to the - internal NetworkX Graph object so they can continue to use that code. - 2) It was a convenient way to evolve the code - I could make changes, but still - have some old NX code that I could use as short-term hacks. It allowed me to - use a regression test to make sure that it all still ran - some of it running - with the new Graph object and some of it hacked to operate on the underlying - NX Graph data member. - -In the future, if #1 is not an issue, we can just expunge NetworkX completely. - -I noticed that the FrozenGraph class had already implemented the behavior of the Graph -class but without being a subclass of the NetworkX Graph object. So, my new Graph class -was based in large part on the FrozenGraph code. It helped me grok property decorators -and __getattr__ and __getattribute__ - interesting Pythonic stuff! - -It is not the case that ALL of the behavior of the NX based Graph class is replicated -in the new Graph class - I have not implemented NodeView and EdgeView functionality -and maybe I will not have to. - -Note that one of the biggest differences between NetworkX and RustworkX is in how nodes -and edges are identified. 
In NetworkX there is not really a difference between the "index" -of a node or an edge and its "name" or "ID" or "value". In NetworkX the way you index into nodes -and edges is by using the node's name/Id or the edge's tuple - in effect the index and the -name/ID/value are the same. However, in RustworkX, the index is always an integer, and furthermore -the set of indexes for both nodes and edges stats at zero with consecutive integer values. This -is one of the things that allows RustworkX to be faster than NetworkX. Converting to using -RustworkX, therefore, required that the code distinguish between a node/edge's index and its value/ID/name. -This is most visible in the use of node_data() and edge_data() functions and -in the changes made to the use of subgraphs (which unfortunately have different index values for nodes than -the parent graph in RX). - -A note on subgraphs: Creating subgraphs is a fundamental operation for GerryChain. When using NX, -a subgraph's node (and also edge) indexes were unchanged from the parent's, so it was safe to do -calculations on a subgraph and pass back node information (like flips). However, when using RX, the -node and edge indexes change, so in order to pass back information in the parent's index systems, -the subgraph's nodes and edges need to be translated back into those of the parent's index system. -In order to do this, every graph contains two new bits of information 1) whether it is a subgraph and -2) a mapping from the subgraph index values to those of its parent. For top-level (non subgraphs), this -mapping is just an identity mapping - this is just a convenience so that routines can always use the -map without having to worry about whether it is a subgraph or not. - -The current state of affairs (early May 2025) is that the code in tree.py has mostly -been converted to use the new Graph object instead of nx.Graph, and that the regression -test works (which only tests some of the functionality, but it does run a chain...) - -I have left the original code for the old Graph object in the file so that I could test -that the original and the new code behave the same way - see tests/frm_tests/test_frm_old_vs_new_graph.py -These frm_tests are not yet configured to run as pytest tests, but they soon will be. -I will add additional tests here over time. - -Most of the NetworkX dependencies that remain are on NX algorithms (like is_connected() and -laplacian_matrix()). These need to be replaced with functions that work on RustworkX Graphs. -I have not yet determined whether they all need to work on both NX and RX graphs - if they -only ever need to work on graphs inside Paritions, then they only need to work for RX, but -it may be convenient to have them work both ways - needs some thought, and it might be easier -to just provide compatibility to cover any edge case that I can't think of... - -After getting rid of all NX dependencies outside this file, it will be time to switch to -RX which will involve: - - 1) Creating RX versions of NX functionality - such as laplacian_matrix(). There are - lots of comments in the code saying: # frm TODO: RX version NYI... - - 2) Adding code so that when we "freeze" a graph, we also convert it to RX. - -""" -######################################################### - +# frm: TODO: Refactor: Move json_serialize() closer to its use. +# +# It should not be the first thing someone sees when looking at this code... 
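+#
+# For reference, the typical use of json_serialize() is as the `default` hook
+# for json.dump(), as in to_json() further down in this file:
+#
+#     with open(json_file_name, "w") as f:
+#         json.dump(data, f, default=json_serialize)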
+#

def json_serialize(input_object: Any) -> Optional[int]:
    """
    This function is used to handle one of the common issues that
@@ -146,7 +65,7 @@ def json_serialize(input_object: Any) -> Optional[int]:

class Graph:
    """
-    frm TODO: Clean up this documentation
+    frm TODO: Documentation: Clean up this documentation

    frm: this class encapsulates / hides the underlying graph which can either
        be a NetworkX graph or a RustworkX graph.  The intent is that it provides the same
@@ -166,19 +85,13 @@ class Graph:

    """

-    # frm: TODO: Update the comment below - making sure it is 100% accurate (and useful)
-    # frm: This class cannot have a constructor - because there is code that assumes
-    #      that it can use the default constructor to create instances of it.
-    #      That code is buried deep in non GerryChain code, so I don't really understand
-    #      what it is doing, but the assignment of nx_graph and rx_graph class attributes/members
-    #      needs to happen in the "from_xxx()" routines.
-    #
-    # def __init__(self, nx_graph: networkx.Graph, rx_graph: rustworkx.PyGraph) -> None:
-    #     # frm TODO: check that exactly one param is not None - need one and only one graph...
-    #     self._nx_graph = nx_graph
-    #     self._rx_graph = rx_graph
+    # Note: This class cannot have a constructor - because there is code that assumes
+    #       that it can use the default constructor to create instances of it.
+    #       That code is buried deep in non GerryChain code, so I don't really understand
+    #       what it is doing, but the assignment of nx_graph and rx_graph class attributes/members
+    #       needs to happen in the "from_xxx()" routines.

-    # frm: TODO: Add documentation for new data members I am adding:
+    # frm: TODO: Documentation: Add documentation for new data members I am adding:
    #      _nx_graph, _rx_graph, _node_id_to_parent_node_id_map, _is_a_subgraph
    #      _node_id_to_original_nx_node_id_map
    #          => used to recreate NX graph from an RX graph and also
@@ -186,68 +99,184 @@ class Graph:

    @classmethod
    def from_networkx(cls, nx_graph: networkx.Graph) -> "Graph":
+        """
+        Create a :class:`Graph` from a NetworkX.Graph object
+
+        This supports the use case of users creating a graph using NetworkX,
+        which is convenient - both for users of the previous implementation of
+        a GerryChain Graph object (which was a subclass of NetworkX.Graph) and
+        for users more generally who are familiar with NetworkX.
+
+        Note that most users will not ever call this function directly,
+        because they can create a GerryChain Partition object directly
+        from a NetworkX graph, and the Partition initialization code
+        will use this function to convert the NetworkX graph to a
+        GerryChain Graph object.
+
+        :param nx_graph: A NetworkX.Graph object with node and edge data
+            to be converted into a GerryChain Graph object.
+        :type nx_graph: networkx.Graph
+
+        :returns: A GerryChain Graph object with the given NetworkX.Graph
+            object embedded in it
+        :rtype: "Graph"
+        """
        graph = cls()
        graph._nx_graph = nx_graph
        graph._rx_graph = None
        graph._is_a_subgraph = False  # See comments on RX subgraph issues.

        # Maps node_ids in the graph to the "parent" node_ids in the parent graph.
        # For top-level graphs, this is just an identity map
-        graph._node_id_to_parent_node_id_map = {node: node for node in graph.node_indices}
+        graph._node_id_to_parent_node_id_map = {node_id: node_id for node_id in graph.node_indices}

        # Maps node_ids in the graph to the "original" node_ids in parent graph.
# For top-level graphs, this is just an identity map - graph._node_id_to_original_nx_node_id_map = {node: node for node in graph.node_indices} + graph._node_id_to_original_nx_node_id_map = {node_id: node_id for node_id in graph.node_indices} graph.nx_to_rx_node_id_map = None # only set when an NX based graph is converted to be an RX based graph return graph @classmethod def from_rustworkx(cls, rx_graph: rustworkx.PyGraph) -> "Graph": - # This routine is intended to be used to create top level graphs that - # are 1) not subgraphs and 2) not based on NetworkX Graphs. Stated - # differently, subgraphs and RX graphs derived from NetworkX graphs - # need to create translation maps for node_ids (either to the parent - # of a subgraph or the "original" node in a NetworkX Graph), and this - # routine does neither of those things. - - # frm: TODO: Think about node data dictionaries - do I need to check for them? + """ + + + Create a :class:`Graph` from a RustworkX.PyGraph object + + There are three primary use cases for this routine: + 1) converting an NX-based Graph to be an RX-based + Graph, 2) creating a subgraph of an RX-based Graph, and + 3) creating a Graph whose node_ids do not need to be + mapped to some previous graph's node_ids. + + In a little more detail: + + 1) A typical way to use GerryChain is to create a graph + using NetworkX functionality and to then rely on the + initialization code in the Partition class to create + an RX-based Graph object. That initialization code + constructs a RustworkX PyGraph and then uses this + routine to create an RX-based Graph object, and it then + creates maps from the node_ids of the resulting RX-based + Graph back to the original NetworkX.Graph's node_ids. + + 2) When creating a subgraph of a RustworkX PyGraph + object, the node_ids of the subgraph are (in general) + different from those of the parent graph. So we + create a mapping from the subgraph's node_ids to the + node_ids of the parent. The subgraph() routine + creates a RustworkX PyGraph subgraph, then uses this + routine to create an RX-based Graph using that subgraph, + and it then creates the mapping of subgraph node_ids + to the parent (RX) graph's node_ids. + + 3) In those cases where no node_id mapping is needed + this routine provides a simple way to create an + RX-based GerryChain graph object. + + :param rx_graph: a RustworkX PyGraph object + :type rx_graph: rustworkx.PyGraph + + :returns: a GerryChain Graph object with an embedded RustworkX.PyGraph object + :rtype: "Graph" + """ + + # Ensure that the RX graph has node and edge data dictionaries # - # While NX graphs always have a node data dictionary, RX graphs do not have to - # have any data, and the data does not need to be a data dictionary. Since - # gerrychain code depends on having a data dictionary associated with nodes, - # it probably makes sense to make sure that the rustworkx graph provided - # as a parameter does in fact have a data dictionary for every node in the - # graph... + # While NX graphs always have node and edge data dictionaries, + # the node data for the nodes in RX graphs do not have to be + # a data dictionary - they can be any Python object. Since + # gerrychain code depends on having a data dictionary + # associated with nodes and edges, we need to check the RX + # graph to see if it already has node and edge data and if so, + # whether that node and edge data is a data dictionary. # - # The same applies for edge data... 
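+        # For illustration, a minimal PyGraph that satisfies the requirement
+        # described above - every node and edge carries a data dict (the
+        # attribute names here are invented for the example):
+        #
+        #     rx_graph = rustworkx.PyGraph()
+        #     a = rx_graph.add_node({"population": 100})
+        #     b = rx_graph.add_node({"population": 200})
+        #     rx_graph.add_edge(a, b, {"shared_perim": 1.0})
+        #     graph = Graph.from_rustworkx(rx_graph)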
+        # Note that there is no way to change the type of the data
+        # associated with an RX node.  So if the data for a node
+        # is not already a dict then we have an unrecoverable error.
+        #
+        # However, RX does allow you to update the data for edges,
+        # so if we find an edge with no data (None), then we can
+        # create an empty dict for the edge data, and if the edge
+        # data is some other type, then we can also replace the
+        # existing edge data with a dict (retaining the original
+        # data as a value in the new dict)

+        for node_id in rx_graph.node_indices():
+            data_dict = rx_graph[node_id]
+            if not isinstance(data_dict, dict):
+                # Unrecoverable error - see above...
+                raise Exception("from_rustworkx(): RustworkX graph does not have node_data dictionary")
+
+        for edge_id in rx_graph.edge_indices():
+            data_dict = rx_graph.get_edge_data_by_index(edge_id)
+            if data_dict is None:
+                # Create an empty dict for edge_data
+                rx_graph.update_edge_by_index(edge_id, {})
+            elif not isinstance(data_dict, dict):
+                # Create a new dict with the existing edge_data as an item
+                rx_graph.update_edge_by_index(edge_id, {"__original_rx_edge_data": data_dict})

        graph = cls()
        graph._rx_graph = rx_graph
        graph._nx_graph = None
        graph._is_a_subgraph = False  # See comments on RX subgraph issues.
+
+        # frm: TODO: Documentation: from_rustworkx(): Make these comments more coherent
+        #
+        #      Instead of these very specific comments, just say that at this
+        #      point, we don't know whether the graph is derived from NX, is a
+        #      subgraph, or is something that can stand alone, so the maps are
+        #      all identity maps.  It is the responsibility of callers to reset
+        #      the maps if that is appropriate...
+
        # Maps node_ids in the graph to the "parent" node_ids in the parent graph.
        # For top-level graphs, this is just an identity map
-        graph._node_id_to_parent_node_id_map = {node: node for node in graph.node_indices}
-        # Maps node_ids in the graph to the "original" NX node_ids.  At this point, we
-        # don't know if this RX based graph was indeed based on an NX graph, so we just
-        # set this to the identity map - which makes sense if this was not derived from
-        # and NX-based Graph.  However, in the case when the graph is indeed derived from
-        # an NX-based Graph, it is the responsibility of the caller to set
+        graph._node_id_to_parent_node_id_map = {node_id: node_id for node_id in graph.node_indices}
+
+        # This routine assumes that the rx_graph was not derived from an "original" NX
+        # graph, so the RX node_ids are considered to be the "original" node_ids and
+        # we create an identity map - each node_id maps to itself as the "original" node_id
+        #
+        # If this routine is used for an RX-based Graph that was indeed derived from an
+        # NX graph, then it is the responsibility of the caller to set
        # the _node_id_to_original_nx_node_id_map appropriately.
-        graph._node_id_to_original_nx_node_id_map = {node: node for node in graph.node_indices}
-        graph.nx_to_rx_node_id_map = None # only set when an NX based graph is converted to be an RX based graph
+        graph._node_id_to_original_nx_node_id_map = {node_id: node_id for node_id in graph.node_indices}
+
+        # only set when an NX based graph is converted to be an RX based graph
+        graph.nx_to_rx_node_id_map = None
+
        return graph

-    def to_networkx_graph(self):
+    def to_networkx_graph(self) -> networkx.Graph:
+        """
+        Create a NetworkX.Graph object that has the same nodes, edges,
+        node_data, and edge_data as the GerryChain Graph object.
+
+        The intended purpose of this routine is to allow a user to
+        run a MarkovChain (which uses an embedded RustworkX graph)
+        and then extract an equivalent version of that graph, with all
+        of its data, as a NetworkX.Graph object - in order to use
+        NetworkX routines to access and manipulate the graph.
+
+        In short, this routine allows users to use NetworkX
+        functionality on a graph after running a MarkovChain.
+
+        If the GerryChain graph object is NX-based, then this
+        routine merely returns the embedded NetworkX.Graph object.
+
+        :returns: A NetworkX.Graph object that is equivalent to the
+            GerryChain Graph object (nodes, edges, node_data, edge_data)
+        :rtype: networkx.Graph
+        """
        if self.is_nx_graph():
            return self.get_nx_graph()

        if not self.is_rx_graph():
-            # frm: TODO: Raise specific exception - type error?
-            raise Exception("to_networkx: bad Graph type")
+            raise TypeError(
+                "Graph passed to 'to_networkx_graph()' must be a rustworkx graph"
+            )

-        # OK - we have an RX-based Graph, so create a NetworkX Graph object
+        # We have an RX-based Graph, and we want to create a NetworkX Graph object
        # that has all of the node data and edge data, and which has the
-        # original node_ids and edge_ids.
+        # node_ids and edge_ids of the original NX graph.
        #
        # Original node_ids are those that were used in the original NX
        # Graph used to create the RX-based Graph object.
@@ -255,7 +284,7 @@ def to_networkx_graph(self):

        # Confirm that this RX based graph was derived from an NX graph...
        if self._node_id_to_original_nx_node_id_map == None:
-            raise Exception("to_networkx_graph: _node_id_to_original_nx_node_id_map is None")
+            raise Exception("to_networkx_graph(): _node_id_to_original_nx_node_id_map is None")

        rx_graph = self.get_rx_graph()

@@ -293,38 +322,91 @@ def to_networkx_graph(self):

        # Add all of the node_data, using the "node_name" attr as the NX Graph node_id
        nodes_df = nodes_df.set_index('node_name')

-        # frm: TODO: WTF is going on with the line immediately below?
-        node_data_dict = nodes_df.to_dict(orient='index')
        networkx.set_node_attributes(nx_graph, nodes_df.to_dict(orient='index'))

        return nx_graph

-    # frm: TODO: Create a test for this routine
-    def original_nx_node_ids_for_set(self, set_of_nodes):
-        # Utility routine to quickly translate a set of node_ids to their original node_ids
+    # frm: TODO: Refactoring: Create a defined type name "node_id" to use instead of "Any"
+    #
+    # This is purely cosmetic, but it would provide an opportunity to add a comment that
+    # talked about NX node_ids vs. RX node_ids and hence why the type for a node_id is
+    # a vague "Any"...
+
+    def original_nx_node_id_for_internal_node_id(self, internal_node_id: Any) -> Any:
+        """
+        Translate a node_id to its "original" node_id.
+
+        :param internal_node_id: A node_id to be translated
+        :type internal_node_id: Any
+
+        :returns: A translated node_id
+        :rtype: Any
+        """
+        return self._node_id_to_original_nx_node_id_map[internal_node_id]
+
+    # frm: TODO: Testing: Create a test for this routine
+    def original_nx_node_ids_for_set(self, set_of_node_ids: set[Any]) -> set[Any]:
+        """
+        Translate a set of node_ids to their "original" node_ids.
+ + :param set_of_node_ids: A set of node_ids to be translated + :type set_of_node_ids: set[Any] + + :returns: A set of translated node_ids + :rtype: set[Any] + """ _node_id_to_original_nx_node_id_map = self._node_id_to_original_nx_node_id_map - new_set = {_node_id_to_original_nx_node_id_map[node_id] for node_id in set_of_nodes} + new_set = {_node_id_to_original_nx_node_id_map[node_id] for node_id in set_of_node_ids} return new_set - # frm: TODO: Create a test for this routine - def original_nx_node_ids_for_list(self, list_of_nodes): + # frm: TODO: Testing: Create a test for this routine + def original_nx_node_ids_for_list(self, list_of_node_ids: list[Any]) -> list[Any]: + """ + Translate a list of node_ids to their "original" node_ids. + + :param list_of_node_ids: A list of node_ids to be translated + :type list_of_node_ids: list[Any] + + :returns: A list of translated node_ids + :rtype: list[Any] + """ # Utility routine to quickly translate a set of node_ids to their original node_ids _node_id_to_original_nx_node_id_map = self._node_id_to_original_nx_node_id_map - new_list = [_node_id_to_original_nx_node_id_map[node_id] for node_id in list_of_nodes] + new_list = [_node_id_to_original_nx_node_id_map[node_id] for node_id in list_of_node_ids] return new_list - def original_nx_node_id_for_internal_node_id(self, internal_node_id): - return self._node_id_to_original_nx_node_id_map[internal_node_id] + def internal_node_id_for_original_nx_node_id(self, original_nx_node_id: Any) -> Any: + """ + Discover the "internal" node_id in the current GerryChain graph + that corresponds to the "original" node_id in the top-level + graph (presumably an NX-based graph object). + + This was originally created to facilitate testing where it was + convenient to express the test success criteria in terms of + "original" node_ids, but the actual test needed to be made + using the "internal" (RX) node_ids. + + :param original_nx_node_id: The "original" node_id + :type original_nx_node_id: Any - # frm: TODO: Create a test for this routine - def internal_node_id_for_original_nx_node_id(self, original_nx_node_id): - # frm: TODO: Think about a better way to map original_nx_node_ids to internal node_ids + :returns: The corresponding "internal" node_id + :rtype: Any + """ + # Note: TODO: Performance: This code is inefficient but it is not a priority to fix now... + # + # The code reverses the dict that maps internal node_ids to "original" + # node_ids, which has an entry for every node in the graph - hence large + # for large graphs, which is costly, but worse - it does this every time + # it is called, so if the calling code is looping through a list of nodes + # then this reverse dict computation will happen each time. # - # The problem is that when this routine is called, it may often be called repeatedly - # for a list of nodes, and we create the reverse dict every time this is called which - # is needlessly expensive. We could just cache this reverse map, but that is often - # dangerous because we have two sources of truth and if someone needs to update one - # they may forget to update the other... + # The obvious fix is to just create the reverse map once when the "internal" + # graph is created. This would be simple to do and safe, because the + # "internal" graph is frozen. + # + # However, at present (December 2025) this routine is only ever used for + # tests, so I am putting it on the back burner... 
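+        #
+        # A sketch of that fix (hypothetical - it assumes the graph really is
+        # frozen once built; functools is already imported at the top of this file):
+        #
+        #     @functools.cached_property
+        #     def _original_nx_node_id_to_internal_node_id_map(self):
+        #         return {
+        #             original: internal
+        #             for internal, original in self._node_id_to_original_nx_node_id_map.items()
+        #         }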
+ # reverse the map so we can go from original node_id to internal node_id orignal_node_id_to_internal_node_id_map = { @@ -332,7 +414,17 @@ def internal_node_id_for_original_nx_node_id(self, original_nx_node_id): } return orignal_node_id_to_internal_node_id_map[original_nx_node_id] - def verify_graph_is_valid(self): + def verify_graph_is_valid(self) -> bool: + """ + Verify that the graph is valid. + + This may be overkill, but the idea is that at least in + development mode, it would be prudent to check periodically + to see that the graph data structure has not been corrupted. + + :returns: True if the graph is deemed valid + :rtype: bool + """ # frm: TODO: Performance: Only check verify_graph_is_valid() in development. # @@ -343,7 +435,7 @@ def verify_graph_is_valid(self): # Sanity check - this is where to add additional sanity checks in the future. - # frm: TODO: Enhance this routine to do more... + # frm: TODO: Code: Enhance verify_graph_is_valid to do more... # frm: TODO: Performance: verify_graph_is_valid() is expensive - called a lot # @@ -355,31 +447,57 @@ def verify_graph_is_valid(self): (self._nx_graph is not None and self._rx_graph is None) or (self._nx_graph is None and self._rx_graph is not None) ): - raise Exception("Graph.verify_graph_is_valid - graph not properly configured") + raise Exception("Graph.verify_graph_is_valid(): graph not properly configured") # frm: TODO: Performance: is_nx_graph() and is_rx_graph() are expensive. # - # Not all of the calls on these routines is needed in production - some are just + # Not all of the calls on these routines are needed in production - some are just # sanity checking. Find a way to NOT run this code when in production. - def is_nx_graph(self): + # frm: TODO: Refactoring: Reorder these following routines in sensible order + + def is_nx_graph(self) -> bool: + """ + Determine if the graph is NX-based + + :rtype: bool + """ # frm: TODO: Performance: Only check graph_is_valid() in production # # Find a clever way to only run this code in development. Commenting it out for now... # self.verify_graph_is_valid() return self._nx_graph is not None - def get_nx_graph(self): + def get_nx_graph(self) -> networkx.Graph: + """ + Return the embedded NX graph object + + :rtype: networkx.Graph + """ if not self.is_nx_graph(): - raise Exception("get_nx_graph - graph is not an NX version of Graph") + raise TypeError( + "Graph passed to 'get_nx_graph()' must be a networkx graph" + ) return self._nx_graph - def get_rx_graph(self): + def get_rx_graph(self) -> rustworkx.PyGraph: + """ + Return the embedded RX graph object + + :rtype: rustworkx.PyGraph + """ if not self.is_rx_graph(): - raise Exception("get_rx_graph - graph is not an RX version of Graph") + raise TypeError( + "Graph passed to 'get_rx_graph()' must be a rustworkx graph" + ) return self._rx_graph - def is_rx_graph(self): + def is_rx_graph(self) -> bool: + """ + Determine if the graph is RX-based + + :rtype: bool + """ # frm: TODO: Performance: Only check graph_is_valid() in production # # Find a clever way to only run this code in development. Commenting it out for now... @@ -387,9 +505,19 @@ def is_rx_graph(self): return self._rx_graph is not None def convert_from_nx_to_rx(self) -> "Graph": - # Return a Graph object which has a RustworkX Graph object as its - # embedded graph object. - # + """ + Convert an NX-based graph object to be an RX-based graph object. 
+ + The primary use case for this routine is support for users + constructing a graph using NetworkX functionality and then + converting that NetworkX graph to RustworkX when creating a + Partition object. + + + :returns: An RX-based graph that is "the same" as the given NX-based graph + :rtype: "Graph" + """ + # Note that in both cases in the if-stmt below, the nodes are not copied. # This is arguably dangerous, but in our case I think it is OK. Stated # differently, the actual node data (the dictionaries) in the original @@ -406,9 +534,9 @@ def convert_from_nx_to_rx(self) -> "Graph": # In the future, it might become useful for other reasons, but until then # to guard against careless uses, the code will insist that it not be a subgraph. - # frm: TODO: Add a comment about the intended use of this routine to its + # frm: TODO: Documentation: Add a comment about the intended use of this routine to its # overview comment above. - raise Exception("convert_from_nx_to_rx: graph to be converted is a subgraph") + raise Exception("convert_from_nx_to_rx(): graph to be converted is a subgraph") nx_graph = self._nx_graph rx_graph = rustworkx.networkx_converter(nx_graph, keep_attributes=True) @@ -436,7 +564,7 @@ def convert_from_nx_to_rx(self) -> "Graph": # We also have to update the _node_id_to_original_nx_node_id_map to refer to the node_ids # in the NX Graph object. _node_id_to_original_nx_node_id_map = {} - for node_id in converted_graph.nodes: + for node_id in converted_graph.node_indices: original_nx_node_id = converted_graph.node_data(node_id)["__networkx_node__"] _node_id_to_original_nx_node_id_map[node_id] = original_nx_node_id converted_graph._node_id_to_original_nx_node_id_map = _node_id_to_original_nx_node_id_map @@ -445,38 +573,76 @@ def convert_from_nx_to_rx(self) -> "Graph": elif self.is_rx_graph(): return self else: - raise Exception("convert_from_nx_to_rx: Bad kind of Graph object") + raise TypeError( + "Graph passed to 'convert_from_nx_to_rx()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) + + def get_nx_to_rx_node_id_map(self) -> dict[Any, Any]: + """ + Return the dict that maps NX node_ids to RX node_ids + + The primary use case for this routine is to support automatically + converting NX-based graph objects to be RX-based when creating a + Partition object. The issue is that when you convert from NX to RX + the node_ids change and so you need to update the Partition object's + Assignment to use the new RX node_ids. This routine is used + to translate those NX node_ids to the new RX node_ids when + initializing a Partition object. - def get_nx_to_rx_node_id_map(self): + :rtype: dict[Any, Any] + """ # Simple getter method if not self.is_rx_graph(): - raise Exception("get_nx_to_rx_node_id_map: Graph is not an RX based Graph") + raise TypeError( + "Graph passed to 'get_nx_to_rx_node_id()' is not a rustworkx graph" + ) return self.nx_to_rx_node_id_map @classmethod - def from_json(cls, json_file: str) -> "Graph": - # frm TODO: Do we want to be able to go from JSON directly to RX? - # - # Peter said that this is not a priority - that we only need RX after - # creating a partition, but maybe in the future if we decide to - # encourage an all RX world... 
- # + def from_json(cls, json_file_name: str) -> "Graph": + """ + Create a :class:`Graph` from a JSON file + + :param json_file_name: JSON file + # frm: TODO: Documentation: more detail on contents of JSON file needed here + :type json_file_name: str - with open(json_file) as f: + :returns: A GerryChain Graph object with data from JSON file + :rtype: "Graph" + """ + + # Note that this returns an NX-based Graph object. At some point in + # the future, if we embrace an all RX world, it will make sense to + # have it produce an RX-based Graph object. + + with open(json_file_name) as f: data = json.load(f) - # frm: A bit of Python magic - an adjacency graph is a dict of dict of dicts - # which is structurally equivalent to a NetworkX graph, so you can just - # pretend that is what it is and it all works. + + # A bit of Python magic - an adjacency graph is a dict of dict of dicts + # which is structurally equivalent to a NetworkX graph, so you can just + # pretend that is what it is and it all works. nx_graph = json_graph.adjacency_graph(data) + graph = cls.from_networkx(nx_graph) graph.issue_warnings() return graph - def to_json(self, json_file: str, include_geometries_as_geojson: bool = False) -> None: - # frm TODO: Implement this for an RX based graph + def to_json(self, json_file_name: str, include_geometries_as_geojson: bool = False) -> None: + """ + Dump a GerryChain Graph object to disk as a JSON file + + :param json_file_name: name of JSON file to be created + :type json_file_name: str + + :rtype: None + """ + # frm TODO: Code: Implement graph.to_json for an RX based graph if not self.is_nx_graph(): - raise Exception("At present, can only create JSON for NetworkX graph") + raise TypeError( + "Graph passed to 'to_json()' is not a networkx graph" + ) data = json_graph.adjacency_data(self._nx_graph) @@ -485,7 +651,7 @@ def to_json(self, json_file: str, include_geometries_as_geojson: bool = False) - else: remove_geometries(data) - with open(json_file, "w") as f: + with open(json_file_name, "w") as f: json.dump(data, f, default=json_serialize) @classmethod @@ -493,7 +659,7 @@ def from_file( cls, filename: str, adjacency: str = "rook", - cols_to_add: Optional[List[str]] = None, + cols_to_add: Optional[list[str]] = None, reproject: bool = False, ignore_errors: bool = False, ) -> "Graph": @@ -508,7 +674,7 @@ def from_file( :type adjacency: str, optional :param cols_to_add: The names of the columns that you want to add to the graph as node attributes. Default is None. - :type cols_to_add: Optional[List[str]], optional + :type cols_to_add: Optional[list[str]], optional :param reproject: Whether to reproject to a UTM projection before creating the graph. Default is False. :type reproject: bool, optional @@ -540,10 +706,12 @@ def from_file( reproject=reproject, ignore_errors=ignore_errors, ) - # frm: TODO: Need to make sure this works for RX also - # To do so, need to find out how CRS data is used - # and whether it is used externally or only internally... + # frm: TODO: Documentation: Make it clear that this creates an NX-based + # Graph object. # + # Also add some documentation (here or elsewhere) + # about what CRS data is and what it is used for. + # # Note that the NetworkX.Graph.graph["crs"] is only # ever accessed in this file (graph.py), so I am not # clear what it is used for. 
It seems to just be set @@ -581,7 +749,7 @@ def from_geodataframe( cls, dataframe: pd.DataFrame, adjacency: str = "rook", - cols_to_add: Optional[List[str]] = None, + cols_to_add: Optional[list[str]] = None, reproject: bool = False, ignore_errors: bool = False, crs_override: Optional[Union[str, int]] = None, @@ -591,7 +759,7 @@ def from_geodataframe( # Graph object at the end of the function. """ - Creates the adjacency :class:`Graph` of geometries described by `dataframe`. + Create the adjacency :class:`Graph` of geometries described by `dataframe`. The areas of the polygons are included as node attributes (with key `area`). The shared perimeter of neighboring polygons are included as edge attributes (with key `shared_perim`). @@ -614,7 +782,7 @@ def from_geodataframe( :type adjacency: str, optional :param cols_to_add: The names of the columns that you want to add to the graph as node attributes. Default is None. - :type cols_to_add: Optional[List[str]], optional + :type cols_to_add: Optional[list[str]], optional :param reproject: Whether to reproject to a UTM projection before creating the graph. Default is ``False``. :type reproject: bool, optional @@ -659,12 +827,12 @@ def from_geodataframe( # Generate dict of dicts of dicts with shared perimeters according # to the requested adjacency rule adjacencies = neighbors(df, adjacency) # Note - this is adjacency.neighbors() - # frm: TODO: Make it explicit that neighbors() above is adjacency.neighbors() - # frm: Original Code: graph = cls(adjacencies) nx_graph = networkx.Graph(adjacencies) - # frm: TODO: Need to grok what geometry is used for - it is used in partition.py.plot() + # frm: TODO: Documentation: Document what geometry is used for. + # + # Need to grok what geometry is used for - it is used in partition.py.plot() # and maybe that is the only place it is used, but it is also used below # to set other data, such as add_boundary_perimeters() and areas. The # reason this is an issue is because I need to know what to carry over to @@ -692,12 +860,18 @@ def from_geodataframe( nx_graph.geometry = df.geometry - # frm: TODO: Rethink the name of add_boundary_perimeters - it acts on an nx_graph - # which seems wrong with the given name. Maybe it should be: - # add_boundary_perimeters_to_nx_graph() + # frm: TODO: Refactoring: Rethink the name of add_boundary_perimeters + # + # It acts on an nx_graph which seems wrong with the given name. + # Maybe it should be: add_boundary_perimeters_to_nx_graph() + # + # Need to check in with Peter to see if this is considered + # part of the external API. + + # frm: TODO: Refactoring: Create an nx_utilities module # - # It raises the question of whether there should be an nx_utilities - # module for stuff designed to only work on nx_graph objects. + # It raises the question of whether there should be an nx_utilities + # module for stuff designed to only work on nx_graph objects. # # Note that Peter said: "I like this idea" # @@ -733,29 +907,6 @@ def from_geodataframe( return graph - def lookup(self, node: Any, field: Any): - # Not quite sure why this routine existed in the original graph.py - # code, since most of the other code does not use it, and instead - # does graph.nodes[node_id][key] - back when a Graph was a subclass - # of NetworkX.Graph. - # - # It is left because a couple of other files use it (versioneer.py, - # county_splits.py, and tally.py) and because perhaps an end user also - # uses it. Leaving it does not significant harm - it is just code bloat... 
- - # frm: TODO: Remove this routine: def lookup() => in FrozenGraph too - # - # As per Peter's PR comment: - # - # Yeah, I will get rid of this in the future. This is very old code - # that someone probably wrote to make their life easier in the early - # stages of the package, but it's not really useful. I am going to be - # changing all of the old setup and versioning systems over to use UV - # anyway, and county_splits.py and tally.py are easy changes - # - - return self.node_data(node, field) - # Performance Note: # # Most of the functions in the Graph class will be called after a @@ -766,21 +917,39 @@ def lookup(self, node: Any, field: Any): # @property - def node_indices(self): + def node_indices(self) -> set[Any]: + """ + Return a set of the node_ids in the graph + + :rtype: set[Any] + """ self.verify_graph_is_valid() - # frm: TODO: This does the same thing that graph.nodes does - returning a list of node_ids. + # frm: TODO: Refactoring: node_indices() does the same thing that graph.nodes does - returning a list of node_ids. # Do we really want to support two ways of doing the same thing? + # Actually this returns a set rather than a list - not sure that matters though... + # + # My code uses node_indices() to make it clear we are talking about node_ids... + # + # The question is whether to deprecate nodes()... if (self.is_rx_graph()): return set(self._rx_graph.node_indices()) elif (self.is_nx_graph()): return set(self._nx_graph.nodes) else: - raise Exception("Graph.node_indices - bad kind of graph object") + raise TypeError( + "Graph passed to 'node_indices()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) @property - def edge_indices(self): + def edge_indices(self) -> set[Any]: + """ + Return a set of the edge_ids in the graph + + :rtype: set[Any] + """ self.verify_graph_is_valid() if (self.is_rx_graph()): @@ -790,17 +959,28 @@ def edge_indices(self): # A set of edge_ids (tuples) extracted from the graph's EdgeView return set(self._nx_graph.edges) else: - raise Exception("Graph.edges - bad kind of graph object") + raise TypeError( + "Graph passed to 'edge_indices()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) - def get_edge_from_edge_id(self, edge_id): + def get_edge_from_edge_id(self, edge_id: Any) -> tuple[Any, Any]: """ - In NX, an edge_id is a tuple of node_ids, but in RX an edge_id - is an integer. To get the tuple of node_ids in RX, you need to - make a call using the edge_id. + Return the edge (tuple of node_ids) corresponding to the + given edge_id - Stated differently, in NX an edge and an edge ID are the same, but - not in RX... + Note that in NX, an edge_id is the same as an edge - it is + just a tuple of node_ids. However, in RX, an edge_id is + an integer, so if you want to get the tuple of node_ids + you need to use the edge_id to get that tuple... 
+
+        :param edge_id: The ID of the desired edge
+        :type edge_id: Any
+
+        :returns: An edge, namely a tuple of node_ids
+        :rtype: tuple[Any, Any]
        """
+
        self.verify_graph_is_valid()

        if (self.is_rx_graph()):
@@ -815,17 +995,57 @@ def get_edge_from_edge_id(self, edge_id):
            # In NX, the edge_id is also the edge tuple
            return edge_id
        else:
-            raise Exception("Graph.get_edge_from_edge_id - bad kind of graph object")
+            raise TypeError(
+                "Graph passed to 'get_edge_from_edge_id()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )

-    def get_edge_id_from_edge(self, edge):
+    # frm: TODO: Refactoring: Create abstract "edge" and "edge_id" type names
+    #
+    # As with node_id, this is cosmetic but it will provide a nice place to
+    # put a comment about the difference between NX and RX and it will make
+    # the type annotations make more sense...
+
+    def get_edge_id_from_edge(self, edge: tuple[Any, Any]) -> Any:
        """
-        Another case where we need to deal with the fact that in
-        NX an edge ID is a tuple of node_ids, where in RX an edge ID
-        is an integer assocaited with an edge.
+        Get the edge_id that corresponds to the given edge.
+
+        In RX an edge_id is an integer that designates an edge (an edge is
+        a tuple of node_ids).  In NX, an edge_id IS the tuple of node_ids.
+        So, in general, to support both NX and RX, if you want to get access
+        to the edge data for an edge (tuple of node_ids), you need to
+        ask for the edge_id.
+
+        This functionality is needed, for instance, when code that walks
+        chains of nodes (such as the routines in tree.py mentioned below)
+        needs the edge_id for the edge between two adjacent nodes in order
+        to look up that edge's data.
+
+        :param edge: A tuple of node_ids.
+        :type edge: tuple[Any, Any]
+
+        :returns: The ID associated with the given edge
+        :rtype: Any
        """

        self.verify_graph_is_valid()

        if (self.is_rx_graph()):
+
+            # frm: TODO: Performance: Perhaps get_edge_id_from_edge() is too expensive...
+            #
+            # If this routine becomes a significant performance issue, then perhaps
+            # we can change the algorithms that use it so that it is not needed.
+            # In particular, there are several routines in tree.py that use it
+            # by traversing chains of nodes (successors and predecessors) which
+            # requires the code to recreate the edges from the nodes in hand.  This
+            # was not a problem in an NX world - the tuple of nodes was exactly what
+            # an edge_id was, but in the RX world it is not - necessitating this routine.
+            #
+            # BUT...  If the code had chains of edges rather than chains of nodes,
+            # then you could have the edge_ids at hand already and avoid having to
+            # do this lookup.
+            #
+            # However, it may be that the RX edge_indices_from_endpoints() is smart
+            # enough (for instance if it caches a dict mapping) that the performance
+            # hit is minimal...  Here's to hoping that RX is "smart enough"...  ;-)
+
            # Note that while in general the routine, edge_indices_from_endpoints(),
            # can return more than one edge in the case of a Multi-Graph (a graph that
            # allows more than one edge between two nodes), we can rely on it only
@@ -838,10 +1058,73 @@
            # In NX, the edge_id is also the edge tuple
            return edge
        else:
-            raise Exception("Graph.get_edge_id_from_edge - bad kind of graph object")
+            raise TypeError(
+                "Graph passed to 'get_edge_id_from_edge()' is neither "
+                "a networkx-based graph nor a rustworkx-based graph"
+            )

    @property
-    def nodes(self):
+    def nodes(self) -> list[Any]:
+        """
+        Return a list of all of the node_ids in the graph.
+
+        This routine still exists because there is a lot of legacy
+        code that uses this syntax to iterate through all of the nodes
+        in a graph.
+ + There is another routine, node_indices(), which does essentially + the same thing (it returns a set of node_ids, however, rather than + a list). + + Why have two routines that do the same thing? The answer is that with + move to RX, it seemed appropriate to emphasize the distinction + between objects and the IDs for objects, hence the introduction of + node_indices() and edge_indices() routines. This distinction is + critical for edges, but mostly not important for nodes. In fact + this routine is implemented by just converting node_indices to a list. + So, it is essentially a style issue - when referring to nodes, we + are almost always really referring to node_ids, so why not use a + routine called node_indices()? + + Note that there is a subtle point to be made about node names vs. + node_ids. It was common before the transition to RX to create + nodes with IDs that were essentially names. That is, the ID had + semantic weight. This is not true with RX node_ids. So, any + code that relies on the semantics of a node's ID (treating it + like a name) is suspect in the new RX world. + + :returns: A list of all of the node_ids in the graph + :rtype: list[Any] + """ + + # frm: TODO: Documentation: Warn users in Migration Guide that nodes() has gone away + # + # Since the legacy code implemented a GerryChain Graph as a subclass of NetworkX.Graph + # legacy code could take advantage of NX cleverness - NX returns a NodeView object for + # nx_graph.nodes which supports much more than just a list of node_ids (which is all that + # code below does). + # + # Probably the most common use of nx_graph.nodes was to access node data as in: + # + # nx_graph.nodes[node_id][] + # + # In the new world, to do that you need to do: + # + # graph.node_data(node_id)[] + # + # So, almost the same number of keystrokes, but if a legacy user uses nodes[...] the + # old way, it won't work out well. + # + + # frm: TODO: Refactoring: Think about whether to do away entirely with graph.nodes + # + # All this routine does now is to coerce the set of nodes obtained by node_indices() + # to be a list (which I think is unnecessary). So, why have it at all? Why not just + # tell legacy users via an exception that it no longer exists? + # + # On the other hand, it maybe does no harm to allow legacy users to indulge in + # what appears to be a very common idiom in legacy code... + self.verify_graph_is_valid() if (self.is_rx_graph()): @@ -851,69 +1134,26 @@ def nodes(self): # A list of node_ids - return list(self._nx_graph.nodes) else: - raise Exception("Graph.sdges - bad kind of graph object") + raise TypeError( + "Graph passed to 'nodes()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) @property - def edges(self): - # frm: TODO: Confirm that this will work - returning different kinds of values - + def edges(self) -> set[tuple[Any, Any]]: """ - Edges are one of the areas where NX and RX differ. - - Conceptually an edge is just a tuple identifying the two nodes comprising the edge. - To be a little more specific, we will consider an edge to be a tuple of node_ids. - - But what is an edge_id? In NX, the edge_id is just the tuple of node_ids. I do - not know if NX is smart enough in an undirected graph to know that (3,4) is the same - as (4,3), but I assume that it is. In RX, however, the edge_id is just an integer. - Stated differently, in NX there is no difference between an "edge" and an "edge_id", - but in RX there is. - - So, the new Graph object is going to distinguish between edges and edge_ids. 
- Graph.edges will return a set of tuples in both cases, and Graph.edge_indices will - return a set of edge_ids in both cases. This is a little funky as the return type - for Graph.edge_indices will be structurally different for NX and RX version of Graph - objects, but hey - this is Python, so why not? Sorry for the snide attack... + Return a set of all of the edges in the graph, where each + edge is a tuple of node_ids - Another issue (that should probably be documented elsewhere instead of here) is that - in NX, Graph.edges returns an EdgeView object which allows for access to several - different bits of information about edges. If you iterate over Graph.edges you - get a sequence of tuples for the edges, but if you use square bracket notation, - as in: Graph.edges[(n1, n2)] you get access to the data dictionary for the edge. - - Here are some examples: - - for e in nx_graph.edges: - print("This edge goes between the following nodes: ", e) - - The above will print out all of the edge_id tuples: - - This edge goes between nodes: (46, 47) - This edge goes between nodes: (47, 55) - This edge goes between nodes: (48, 56) - This edge goes between nodes: (48, 49) - ... - - However, if you want to get the data dictionary associated with the edge that goes - between nodes 46, and 47, then you can do: - - print("node: (46,47) has data: ", nx_graph.edges[(46,47)]) - - node: (46,47) has data: {'weight': 5.5, 'total_population': 123445} - - RX does not support the EdgeView object, so we will use the same approach as for nodes. - To get access to an edge's data dictionary, one will need to use the new function, - edge_data(edge_id) - where edge_id will be either a tuple or an integer depending - on what flavor of Graph is being operated on. + :rtype: set[tuple[Any, Any]]: """ + # Return a set of edge tuples - self.verify_graph_is_valid() - - # frm: TODO: Think about whether edges() should return a set or a list. + # frm: TODO: Code: ???: Should edges return a list instead of a set? # - # nodes() returns a list, so my first take is that edges() should do so - # as well (or maybe nodes() should return a set). Seems odd for them to return - # different types... + # Peter said he thought users would expect a list - but why? + + self.verify_graph_is_valid() if (self.is_rx_graph()): # A set of tuples for the edges @@ -922,47 +1162,45 @@ def edges(self): # A set of tuples extracted from the graph's EdgeView return set(self._nx_graph.edges) else: - raise Exception("Graph.edges - bad kind of graph object") - - def add_edge(self, node_id1, node_id2): - # frm: TODO: Think about whether this routine makes sense for RX based Graph objects - # - # The current paradigm is that the user will use NX to create and modify graphs, but - # then transition (automatically) to using RX once a Partition object is created. 
- # This paradigm would not allow modifying an RX graph, so maybe (until we transition - # to allowing folks to start with RX based Graph objects) we disallow this for RX - # based graphs - # + raise TypeError( + "Graph passed to 'edges()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) - self.verify_graph_is_valid() + def add_edge(self, node_id1: Any, node_id2: Any) -> None: + """ + Add an edge to the graph from node_id1 to node_id2 - if (self.is_rx_graph()): - # empty dict tells RX the edge data will be a dict - self._rx_graph.add_edge(node_id1, node_id2, {}) - elif (self.is_nx_graph()): - self._nx_graph.add_edge(node_id1, node_id2) - else: - raise Exception("Graph.add_edge - bad kind of graph object") + :param node_id1: The node_id for one of the nodes in the edge + :type node_id1: Any + :param node_id2: The node_id for one of the nodes in the edge + :type node_id2: Any - def get_edge_tuple(self, edge_id): + :rtype: None + """ - # frm: TODO: Delete this routine after verifying that nobody uses it. + # frm: TODO: Code: add_edge(): Check that nodes exist and that they have data dicts. # - # It appears that the gerrychain code does not use this. Need to - # verify that it was added by me (and hence is not part of any - # external legacy code), but then it should be deleted as it is - # semantically the same as get_edge_from_edge_id() + # This checking should probably be limited to development mode, but + # the issue is that an RX node need not have a data value that is + # a dict, but GerryChain code depends on having a data dict. So, + # it makes sense to test and make sure that the nodes exist and + # have a data dict... + + # frm: TODO: Code: add_edge(): Do we need to check to make sure the edge does not already exist? self.verify_graph_is_valid() if (self.is_rx_graph()): - # frm: TODO: Performance: There is probably a more efficnient way to do this - return self._rx_graph.edge_list()[edge_id] + # empty dict tells RX the edge data will be a dict + self._rx_graph.add_edge(node_id1, node_id2, {}) elif (self.is_nx_graph()): - # In NX, the edge_id is already a tuple with the two node_ids - return edge_id + self._nx_graph.add_edge(node_id1, node_id2) else: - raise Exception("Graph.get_edge_tuple - bad kind of graph object") + raise TypeError( + "Graph passed to 'add_edge()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) def add_data( self, df: pd.DataFrame, columns: Optional[Iterable[str]] = None @@ -973,14 +1211,16 @@ def add_data( :param df: Dataframe containing given columns. :type df: :class:`pandas.DataFrame` - :param columns: List of dataframe column names to add. Default is None. + :param columns: list of dataframe column names to add. Default is None. :type columns: Optional[Iterable[str]], optional :returns: None """ if not (self.is_nx_graph()): - raise Exception("Graph.add_data only valid for NetworkX based graphs") + raise TypeError( + "Graph passed to 'add_data()' is not a networkx graph" + ) if columns is None: columns = list(df.columns) @@ -1001,7 +1241,7 @@ def add_data( def join( self, dataframe: pd.DataFrame, - columns: Optional[List[str]] = None, + columns: Optional[list[str]] = None, left_index: Optional[str] = None, right_index: Optional[str] = None, ) -> None: @@ -1013,7 +1253,7 @@ def join( :type dataframe: :class:`pandas.DataFrame` :columns: The columns whose data you wish to add to the graph. If not provided, all columns are added. Default is None. 
- :type columns: Optional[List[str]], optional
+ :type columns: Optional[list[str]], optional
 :left_index: The node attribute used to match nodes to rows. If not
 provided, node IDs are used. Default is None.
 :type left_index: Optional[str], optional
@@ -1035,8 +1275,17 @@ def join(

 column_dictionaries = df.to_dict()

+ # frm: TODO: Code: Implement graph.join() for RX
+ #
+ # This is low priority given the currently suggested coding
+ # strategy of creating the graph using NX and then letting
+ # GerryChain convert it automatically to RX. In this scenario
+ # any joins would happen to the NX-based graph only.
+
 if not self.is_nx_graph():
- raise Exception("Graph.join only valid for NetworkX based Graph objects")
+ raise TypeError(
+ "Graph passed to join() is not a networkx graph"
+ )

 nx_graph = self._nx_graph
 if left_index is not None:
@@ -1044,7 +1293,6 @@ def join(
 else:
 # When the left_index is node ID, the matching is just
 # a redundant {node: node} dictionary
- # frm: TODO: don't think self.nodes works for RX...
 ids_to_index = dict(zip(self.nodes, self.nodes))

 node_attributes = {
@@ -1057,17 +1305,39 @@ def join(
 networkx.set_node_attributes(nx_graph, node_attributes)

 @property
- def islands(self):
+ def islands(self) -> set[Any]:
+ """
+ Return a set of all node_ids that are not connected via an
+ edge to any other node in the graph - that is, nodes with
+ degree = 0
+
+ :returns: A set of all node_ids for nodes of degree 0
+ :rtype: set[Any]
+ """
 # Return all nodes of degree 0 (those not connected in an edge to another node)
- return set(node for node in self.node_indices if self.degree(node) == 0)
+ return set(node_id for node_id in self.node_indices if self.degree(node_id) == 0)
+
+ def is_directed(self) -> bool:
+ # frm TODO: Code: Delete this code: graph.is_directed() once convinced it is safe to do so...
+ #
+ # I added it because code in contiguity.py
+ # called nx.is_connected() which eventually called is_directed()
+ # assuming the graph was an nx_graph.
+ #
+ # Changing from return False to raising an exception just to make
+ # sure nobody uses it.

+ raise NotImplementedError("graph.is_directed() should not be used")

- def is_directed(self):
- # frm TODO: Get rid of this hack. I added it because code in contiguity.py
- # called nx.is_connected() which eventually called is_directed()
- # assuming the graph was an nx_graph.
- return False

 def warn_for_islands(self) -> None:
+ """
+ Issue a warning if there are any islands in the graph - that is,
+ if there are any nodes in the graph that are not connected to any
+ other node (degree = 0)
+
+ :rtype: None
+ """
 islands = self.islands
 if len(self.islands) > 0:
 warnings.warn(
@@ -1075,19 +1345,34 @@ def warn_for_islands(self) -> None:
 )

 def issue_warnings(self) -> None:
- self.warn_for_islands()
+ """
+ Issue any warnings concerning the content or structure
+ of the graph.

- # frm TODO: Implement a FrozenGraph that supports RX...
- # self.graph.join = frozen
- # self.graph.add_data = frozen
- # self.size = len(self.graph)
+ :rtype: None
+ """
+ self.warn_for_islands()

 def __len__(self) -> int:
- # Relies on self.node_indices to work on both NX and RX
+ """
+ Return the number of nodes in the graph
+
+ :rtype: int
+ """
 return len(self.node_indices)
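+ # frm: Example: an illustrative sketch (hypothetical data; assumes the
+ # from_networkx constructor) of how islands and __len__ behave:
+ #
+ #     import networkx as nx
+ #     nx_g = nx.Graph([(0, 1), (1, 2)])
+ #     nx_g.add_node(3)               # a degree-0 node
+ #     g = Graph.from_networkx(nx_g)
+ #     g.islands                      # => {3}
+ #     len(g)                         # => 4 (counts all nodes)
+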
 def __getattr__(self, __name: str) -> Any:
- # frm: TODO: Get rid of this eventually - it is very dangerous...
+ """
+ Delegate attribute lookups that Graph itself does not define to
+ the embedded NetworkX or RustworkX graph object.
+
+ :param __name: The name of the attribute being accessed
+ :type __name: str
+
+ :returns: The attribute's value on the embedded graph object
+ :rtype: Any
+ """
+ # frm: TODO: Code: Get rid of __getattr__ eventually - it is very dangerous...

 # frm: Interesting bug lurking if __name is "nx_graph". This occurs when legacy code
 # uses the default constructor, Graph(), and then references a built-in NX
@@ -1109,7 +1394,7 @@ def __getattr__(self, __name: str) -> Any:
 # It's very, very rare to use the default constructor, so I don't imagine that
 # people will really run into this.

- # frm: TODO: Fix this hack - see comment above...
+ # frm: TODO: Code: Fix this hack (in __getattr__) - see comment above...

 if (__name == "_nx_graph") or (__name == "_rx_graph"):
 return None
@@ -1120,10 +1405,24 @@
 elif (self.is_nx_graph()):
 return object.__getattribute__(self._nx_graph, __name)
 else:
- raise Exception("Graph.__getattribute__ - bad kind of graph object")
+ raise TypeError(
+ "Graph passed to '__getattr__()' is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )

 def __getitem__(self, __name: str) -> Any:
- # frm: ???: TODO: Does any of the code actually use this?
+ """
+ Provide dict-style (square bracket) access, delegated to the
+ embedded NetworkX graph. Not defined for RX-based graphs.
+
+ :param __name: The key being accessed
+ :type __name: str
+
+ :returns: The value the embedded NetworkX graph stores for the key
+ :rtype: Any
+ """
+ # frm: TODO: Code: Does any of the code actually use __getitem__ ?
+ #
 # It is a clever Python way to use square bracket
 # notation to access something (anything) you want.
 #
@@ -1141,23 +1440,54 @@
 self.verify_graph_is_valid()

 if (self.is_rx_graph()):
- # frm TODO:
- raise Exception("Graph.__getitem__() NYI for RX")
+ # frm TODO: Code: Decide if __getitem__() should work for RX
+ raise TypeError("Graph.__getitem__() is not defined for a rustworkx graph")
 elif (self.is_nx_graph()):
 return self._nx_graph[__name]
 else:
- raise Exception("Graph.__getitem__() - bad kind of graph object")
+ raise TypeError(
+ "Graph passed to '__getitem__()' is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )

 def __iter__(self) -> Iterable[Any]:
- # frm: TODO: Verify that this does the right thing...
- # It seems to do the right thing - iterating over node_ids which
- # works so long as NX uses integers for node_ids.
- # frm: TODO: Perhaps I should test for non-integer node_ids in NX graphs and issue a warning...
- # In any event, this deserves thought: what to do for NX graphs that do not use
- # integers for node_ids?
+ """
+ Iterate over the node_ids in the graph (via node_indices, which
+ works for both NX-based and RX-based graphs).
+
+ :rtype: Iterable[Any]
+ """
 yield from self.node_indices

 def subgraph(self, nodes: Iterable[Any]) -> "Graph":
+ """
+ Create a subgraph that contains the given nodes.
+
+ Note that creating a subgraph of a RustworkX (RX) graph
+ renumbers the nodes, so that a node that had node_id: 4
+ in the parent graph might have node_id: 2 in the subgraph.
+ This is a HUGE difference from the NX world where the
+ node_ids in a subgraph do not change from those in the
+ parent graph.
+
+ In order to make sense of the nodes in a subgraph in the
+ RX world, we need to maintain mappings from the node_ids
+ in the subgraph to the node_ids of the immediate parent
+ graph and to the "original" top-level graph that contains
+ all of the nodes. You will notice the creation of those
+ maps in the code below.
+
+ :param nodes: The nodes to be included in the subgraph
+ :type nodes: Iterable[Any]
+
+ :returns: A subgraph containing the given nodes.
+ :rtype: "Graph" + """ + """ frm: RX Documentation: @@ -1205,38 +1535,34 @@ def subgraph(self, nodes: Iterable[Any]) -> "Graph": _node_id_to_parent_node_id_map = {node: node for node in nodes} _node_id_to_original_nx_node_id_map = {node: node for node in nodes} elif (self.is_rx_graph()): - # frm TODO: Need to check logic below - not sure this works exactly correctly for RX... if isinstance(nodes, frozenset) or isinstance(nodes, set): nodes = list(nodes) + # For RX, the node_ids in the subgraph change, so we need a way to map subgraph node_ids # into parent graph node_ids. To do so, we add the parent node_id into the node data # so that in the subgraph we can find it and then create the map. - # frm: TODO: Be careful - node data is shared by subgraphs, so a subgraph of this - # subgraph will still have this field set - meaning that the field's - # value is not dependable over time - perhaps I should null it out - # after using it here... + # + # Note that this works because the node_data dict is shared by the nodes in both the + # parent graph and the subgraph, so we can set the "parent" node_id in the parent before + # creating the subgraph, and that value will be available in the subgraph even though + # the subgraph will have a different node_id for the same node. + # + # This value is removed from the node_data below after creating the subgraph. + # for node_id in nodes: self.node_data(node_id)["parent_node_id"] = node_id - # frm: TODO: Since data is shared by nodes in subgraphs, perhaps we could just set - # the "original_nx_node_id" in the beginning and rely on it forever... - # - # This actually gives me the heebie-jeebies - we have aliased data that we are overwriting. - # Either it is aliased data and we count on that and don't bother to update it, or else - # we ensure that it is NOT aliased data and then we can feel free to overwrite it and/or - # set it anew... - # - # In short, verify that data in subgraphs is shared (in RX) and then think to make sure - # that we do NOT need to set this value, because it is already set. - # - # A more general issue is whether it is OK for data in subgraphs to be shared. I think - # so because we do not store any temporary or context dependent information in node-data - # but it would be nice to validate that (and ideally to test it). At the very least - # if subgraph data (and in fact NX and RX node data) is shared, then that should be documented - # up the wazoo... + # It is also important for all RX graphs (subgraphs or top-level graphs) to have + # a mapping from RX node_id to the "original" NX node_id. However, we do not need + # to do what we do with the _node_id_to_parent_node_id_map and set the value of + # the "original" node_id now, because this value never changes for a node. It + # should already have been set for each node by the standard RX code that + # converts from NX to RX (which sets the "__networkx_node__" attribute to be + # the NX node_id). We just check to make sure that it is in fact set. 
 #
 for node_id in nodes:
- self.node_data(node_id)["original_nx_node_id"] = self._node_id_to_original_nx_node_id_map[node_id]
+ if "__networkx_node__" not in self.node_data(node_id):
+ raise Exception("subgraph: internal error: '__networkx_node__' not set")

 rx_subgraph = self._rx_graph.subgraph(nodes)
 new_subgraph = self.from_rustworkx(rx_subgraph)
@@ -1244,14 +1570,21 @@ def subgraph(self, nodes: Iterable[Any]) -> "Graph":
 # frm: Create the map from subgraph node_id to parent graph node_id
 _node_id_to_parent_node_id_map = {}
 for subgraph_node_id in new_subgraph.node_indices:
- _node_id_to_parent_node_id_map[subgraph_node_id] = new_subgraph.node_data(subgraph_node_id)["parent_node_id"]
+ _node_id_to_parent_node_id_map[subgraph_node_id] = \
+ new_subgraph.node_data(subgraph_node_id)["parent_node_id"]
+ # value no longer needed, so delete it
+ new_subgraph.node_data(subgraph_node_id).pop("parent_node_id")
+
 # frm: Create the map from subgraph node_id to the original graph's node_id
 _node_id_to_original_nx_node_id_map = {}
 for subgraph_node_id in new_subgraph.node_indices:
- _node_id_to_original_nx_node_id_map[subgraph_node_id] = new_subgraph.node_data(subgraph_node_id)["original_nx_node_id"]
-
+ _node_id_to_original_nx_node_id_map[subgraph_node_id] = \
+ new_subgraph.node_data(subgraph_node_id)["__networkx_node__"]
 else:
- raise Exception("Graph.subgraph - bad kind of graph object")
+ raise TypeError(
+ "Graph passed to 'subgraph()' is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )

 new_subgraph._is_a_subgraph = True
 new_subgraph._node_id_to_parent_node_id_map = _node_id_to_parent_node_id_map
@@ -1259,30 +1592,82 @@ def subgraph(self, nodes: Iterable[Any]) -> "Graph":
 return new_subgraph
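+ # frm: Example: a sketch of the resulting bookkeeping (hypothetical
+ # node_ids - the exact ids RX assigns are illustrative only):
+ #
+ #     sub = graph.subgraph({4, 7, 9})          # graph is RX-based here
+ #     # RX renumbers the subgraph's nodes (e.g. to 0, 1, 2), so:
+ #     sub._node_id_to_parent_node_id_map       # e.g. {0: 4, 1: 7, 2: 9}
+ #     sub.translate_subgraph_node_ids_for_flips({0: 1})   # => {4: 1}
+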
- def translate_subgraph_node_ids_for_flips(self, flips):
- # flips is a dictionary mapping node_ids to parts (districts).
- # This routine replaces the node_ids of the subgraph with the node_ids
- # for the same node in the parent graph. This routine is used to
- # when a computation is made on a subgraph but the resulting flips
- # being returned want to be the appropriate node_ids for the parent graph.
- translated_flips = {}
- for subgraph_node_id, part in flips.items():
- parent_node_id = self._node_id_to_parent_node_id_map[subgraph_node_id]
- translated_flips[parent_node_id] = part
+ # frm: TODO: Refactoring: Create abstract type name for "Flip" and "Flip_Dict".
+ #
+ # This is cosmetic, but it would (IMHO) make the code easier to understand, and it
+ # would provide a logical place to define WTF a flip is...

- return translated_flips
+ def translate_subgraph_node_ids_for_flips(self, flips: dict[Any, int]) -> dict[Any, int]:
+ """
+ Translate the given flips so that the subgraph node_ids in the
+ flips are replaced by the corresponding node_ids in the
+ parent graph.

- def translate_subgraph_node_ids_for_set_of_nodes(self, set_of_nodes):
- # This routine replaces the node_ids of the subgraph with the node_ids
- # for the same node in the parent graph. This routine is used to
- # when a computation is made on a subgraph but the resulting set of nodes
+ The flips parameter is a dict mapping node_ids to parts (districts).
+
+ This routine is used when a computation that creates flips is made
+ on a subgraph, but those flips need to be translated into the context
+ of the parent graph at the end of the computation.
+
+ For more details, refer to the larger comment on subgraphs...
+
+ :param flips: A dict containing "flips" which associate a node with
+ a new part in a partition (a "part" is the same as a district in
+ common parlance).
+ :type flips: dict[Any, int]
+
+ :returns: A dict containing "flips" that have been translated to have
+ node_ids appropriate for the parent graph
+ :rtype: dict[Any, int]
+ """
+
+ # frm: TODO: Documentation: Write an overall comment on subgraphs and node_id maps
+
+ translated_flips = {}
+ for subgraph_node_id, part in flips.items():
+ parent_node_id = self._node_id_to_parent_node_id_map[subgraph_node_id]
+ translated_flips[parent_node_id] = part
+
+ return translated_flips
+
+ def translate_subgraph_node_ids_for_set_of_nodes(self, set_of_nodes: set[Any]) -> set[Any]:
+ """
+ Translate the given set_of_nodes to have the appropriate
+ node_ids for the parent graph.
+
+ This routine is used when a computation that creates a set of nodes is made
+ on a subgraph, but those nodes need to be translated into the context
+ of the parent graph at the end of the computation.
+
+ For more details, refer to the larger comment on subgraphs...
+
+ :param set_of_nodes: A set of node_ids in a subgraph
+ :type set_of_nodes: set[Any]
+
+ :returns: A set of node_ids that have been translated to have
+ the node_ids appropriate for the parent graph
+ :rtype: set[Any]
+ """
+ # This routine replaces the node_ids of the subgraph with the node_ids
+ # for the same node in the parent graph. It is used
+ # when a computation is made on a subgraph but the resulting set of nodes
 # being returned needs to have the appropriate node_ids for the parent graph.
 translated_set_of_nodes = set()
 for node_id in set_of_nodes:
 translated_set_of_nodes.add(self._node_id_to_parent_node_id_map[node_id])
 return translated_set_of_nodes

- def generic_bfs_edges(self, source, neighbors=None, depth_limit=None):
+ def generic_bfs_edges(self, source, neighbors=None, depth_limit=None) -> Generator[tuple[Any, Any], None, None]:
+ """
+ Generate the edges of a breadth-first-search traversal of the
+ graph starting at the given source node.
+
+ :param source: The node_id at which to start the traversal
+ :param neighbors: A function that returns an iterator over a node's
+ neighbors (when None, the graph's own neighbors function is
+ used, as in the NetworkX original)
+ :param depth_limit: The maximum search depth. When None, the whole
+ connected component is traversed.
+
+ :returns: A generator of (parent, child) edge tuples in BFS order
+ :rtype: Generator[tuple[Any, Any], None, None]
+ """
 # frm: Code copied from GitHub:
 #
 # https://github.com/networkx/networkx/blob/main/networkx/algorithms/traversal/breadth_first_search.py
@@ -1374,14 +1759,24 @@ def generic_bfs_edges(self, source, neighbors=None, depth_limit=None):
 seen.add(child)
 # frm: add this node's children to list to be processed later...
 next_parents_children.append((child, neighbors(child)))
- yield parent, child
+ yield (parent, child)
 if len(seen) == n:
 return
 depth += 1

- # frm: TODO: Add tests for all of the new routines I have added...
+ # frm: TODO: Testing: Add tests for all of the new routines I have added...

- def generic_bfs_successors_generator(self, root_node_id):
+ def generic_bfs_successors_generator(self, root_node_id: Any) -> Generator[tuple[Any, Any], None, None]:
+ """
+ Generate (parent, children) tuples for a breadth-first-search
+ traversal starting at the given root node.
+
+ :param root_node_id: The node_id at which to start the traversal
+ :type root_node_id: Any
+
+ :returns: A generator of (parent, children) tuples, where children
+ is a list of the node_ids of the parent's children in the BFS tree
+ :rtype: Generator[tuple[Any, Any], None, None]
+ """
 # frm: Generate in sequence a tuple for the parent (node_id) and
 # the children of that node (list of node_ids).
 parent = root_node_id
@@ -1401,10 +1796,30 @@
 parent = p
 yield (parent, children)
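+ # frm: Example: on a path graph 0-1-2-3 rooted at 0, the generator above
+ # yields (0, [1]), (1, [2]), (2, [3]) - one (parent, children) tuple per
+ # parent that has children in the BFS tree. (Sketch; child ordering is
+ # not guaranteed.)
+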
- def generic_bfs_successors(self, root_node_id):
+ def generic_bfs_successors(self, root_node_id: Any) -> dict[Any, Any]:
+ """
+ Return a dict mapping each parent node_id to the list of its
+ children in a breadth-first-search tree rooted at the given node.
+
+ :param root_node_id: The node_id at which to start the traversal
+ :type root_node_id: Any
+
+ :returns: A dict mapping parent node_ids to lists of child node_ids
+ :rtype: dict[Any, Any]
+ """
 return dict(self.generic_bfs_successors_generator(root_node_id))

- def generic_bfs_predecessors(self, root_node_id):
+ def generic_bfs_predecessors(self, root_node_id: Any) -> dict[Any, Any]:
+ """
+ Return a dict mapping each node_id to its predecessor (parent) in
+ a breadth-first-search tree rooted at the given node.
+
+ :param root_node_id: The node_id at which to start the traversal
+ :type root_node_id: Any
+
+ :returns: A dict mapping each node_id to its BFS parent's node_id
+ :rtype: dict[Any, Any]
+ """
 # frm Note: We had to implement our own, because the built-in RX version only worked
 # for directed graphs.
 predecessors = []
@@ -1413,7 +1828,17 @@
 return dict(predecessors)

- def predecessors(self, root_node_id):
+ def predecessors(self, root_node_id: Any) -> dict[Any, Any]:
+ """
+ Return a dict mapping each node_id to its predecessor (parent) in
+ a breadth-first-search tree rooted at the given node. See the
+ longer discussion of predecessors() and successors() below.
+
+ :param root_node_id: The node_id at which to start the traversal
+ :type root_node_id: Any
+
+ :returns: A dict mapping each node_id to its BFS parent's node_id
+ :rtype: dict[Any, Any]
+ """
 """
 frm:

 It took me a while to grok what predecessors() and successors()
@@ -1448,13 +1873,16 @@ def predecessors(self, root_node_id):
 All of this is interesting, but I have not yet spent the time
 to figure out why it matters in the code.

- TODO: Decide if it makes sense to have different implementations
+ TODO: Code: predecessors(): Decide if it makes sense to have different implementations
 for NX and RX. The code below has the original definition
 from the pre-RX codebase, but the code for RX will work for
 NX too - so I think that there is no good reason to have
 different code for NX. Maybe no harm, but on the other
 hand, it seems like a needless difference and hence more
 complexity...
+
+ TODO: Performance: see if the performance of the built-in NX
+ version is significantly better than the generic one.
 """

 self.verify_graph_is_valid()
@@ -1464,9 +1892,22 @@ def predecessors(self, root_node_id):
 elif (self.is_nx_graph()):
 return {a: b for a, b in networkx.bfs_predecessors(self._nx_graph, root_node_id)}
 else:
- raise Exception("Graph.predecessors - bad kind of graph object")
+ raise TypeError(
+ "Graph passed to 'predecessors()' is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )

- def successors(self, root_node_id):
+ def successors(self, root_node_id: Any) -> dict[Any, Any]:
+ """
+ Return a dict mapping each parent node_id to the list of its
+ children in a breadth-first-search tree rooted at the given node.
+
+ :param root_node_id: The node_id at which to start the traversal
+ :type root_node_id: Any
+
+ :returns: A dict mapping parent node_ids to lists of child node_ids
+ :rtype: dict[Any, Any]
+ """
 self.verify_graph_is_valid()

 if (self.is_rx_graph()):
@@ -1474,43 +1915,83 @@ def successors(self, root_node_id):
 elif (self.is_nx_graph()):
 return {a: b for a, b in networkx.bfs_successors(self._nx_graph, root_node_id)}
 else:
- raise Exception("Graph.successors - bad kind of graph object")
+ raise TypeError(
+ "Graph passed to 'successors()' is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )
+
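+ # frm: Example: the shape of the two mappings on a path graph 0-1-2-3
+ # rooted at 0 (sketch; list ordering is not guaranteed):
+ #
+ #     graph.successors(0)     # => {0: [1], 1: [2], 2: [3]}  parent -> children
+ #     graph.predecessors(0)   # => {1: 0, 2: 1, 3: 2}        child -> parent
+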
+ def neighbors(self, node_id: Any) -> list[Any]:
+ """
+ Return a list of the node_ids of the nodes that are neighbors of
+ the given node - that is, all of the nodes that are directly
+ connected to the given node by an edge.

- def neighbors(self, node):
+ :param node_id: The ID of a node
+ :type node_id: Any
+
+ :returns: A list of neighbor node_ids
+ :rtype: list[Any]
+ """
 self.verify_graph_is_valid()

- # NX neighbors() returns a which iterates over the node_ids of neighbor nodes
- # RX neighbors() returns a NodeIndices object with the list of node_ids of neighbor nodes
- # However, the code outside graph.py only ever iterates over all neighbors so returning a list works...
 if (self.is_rx_graph()):
- # frm: TODO: Performance: Do not convert to a list
- #
- # The RX path (this path) is much more expensive than the NX path, and
- # I am assuming that it is the conversion to a list in the original code
- # that is expensive, so I am going to just not convert it to a list
- # since the return type of rx.neighbors() is a NodeIndices object which
- # is already essentially a Python list - in that it implements the Python sequence protocol
- #
- # Original code:
- # return list(self._rx_graph.neighbors(node))
- return self._rx_graph.neighbors(node)
+ return list(self._rx_graph.neighbors(node_id))
 elif (self.is_nx_graph()):
- return list(self._nx_graph.neighbors(node))
+ return list(self._nx_graph.neighbors(node_id))
 else:
- raise Exception("Graph.neighbors - bad kind of graph object")
+ raise TypeError(
+ "Graph passed to 'neighbors()' is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )
+
+ def degree(self, node_id: Any) -> int:
+ """
+ Return the degree of the given node, that is, the number
+ of other nodes directly connected to the given node.

- def degree(self, node: Any) -> int:
+ :param node_id: The ID of a node
+ :type node_id: Any
+
+ :returns: Number of nodes directly connected to the given node
+ :rtype: int
+ """
 self.verify_graph_is_valid()

 if (self.is_rx_graph()):
- return self._rx_graph.degree(node)
+ return self._rx_graph.degree(node_id)
 elif (self.is_nx_graph()):
- return self._nx_graph.degree(node)
+ return self._nx_graph.degree(node_id)
 else:
- raise Exception("Graph.degree - bad kind of graph object")
+ raise TypeError(
+ "Graph passed to 'degree()' is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )
+
+ def node_data(self, node_id: Any) -> dict[Any, Any]:
+ """
+ Return the data dictionary that contains the given node's data.
+
+ As documented elsewhere, in GerryChain code before the conversion
+ to RustworkX, users could access node data using the syntax:
+
+ graph.nodes[node_id][attribute_name]
+
+ This was because a GerryChain Graph object in that codebase was a
+ subclass of NetworkX.Graph, and NetworkX was clever and implemented
+ dict-like behavior for the syntax graph.nodes[]...
+
+ This Python cleverness was not carried over to the RustworkX
+ implementation, so in the current GerryChain Graph implementation
+ users need to access node data using the syntax:
+
+ graph.node_data(node_id)[attribute_name]

- def node_data(self, node_id):
- # This routine returns the data dictionary for the given node's data
+ :param node_id: The ID of a node
+ :type node_id: Any
+
+ :returns: Data dictionary containing the given node's data.
+ :rtype: dict[Any, Any]
+ """

 self.verify_graph_is_valid()
@@ -1519,53 +2000,77 @@ def node_data(self, node_id):
 elif (self.is_nx_graph()):
 data_dict = self._nx_graph.nodes[node_id]
 else:
- raise Exception("Graph.node_data - bad kind of graph object")
+ raise TypeError(
+ "Graph passed to 'node_data()' is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )

 if not isinstance(data_dict, dict):
- raise Exception("node data is not a dictionary");
+ raise TypeError("graph.node_data(): data for node is not a dict")

 return data_dict
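+ # frm: Example: the access-pattern change described above (hypothetical
+ # attribute name "TOTPOP"):
+ #
+ #     pop = graph.nodes[node_id]["TOTPOP"]      # pre-RX code (NX subclass)
+ #     pop = graph.node_data(node_id)["TOTPOP"]  # current API
+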
- def edge_data(self, edge_id):
- # This routine returns the data dictionary for the given edge's data
-
+ def edge_data(self, edge_id: Any) -> dict[Any, Any]:
 """
- CLEVERNESS ALERT!
-
- The type of the edge_id parameter will be a tuple in the case of an
- embedded NX graph but will be an integer in the case of an RX embedded
- graph.
-
+ Return the data dictionary that contains the data for the given edge.
+
+ Note that in NetworkX an edge_id can be almost anything, for instance,
+ a string or even a tuple. However, in RustworkX, an edge_id is
+ an integer. This code handles both kinds of edge_ids - hence the
+ type, Any.
+
+ :param edge_id: The ID of the edge
+ :type edge_id: Any
+
+ :returns: The data dictionary for the given edge's data
+ :rtype: dict[Any, Any]
 """

 self.verify_graph_is_valid()

 if (self.is_rx_graph()):
- # frm: TODO: Performance: use get_edge_data_by_index()
- #
- # Original code (pre October 2025) that indexed into all edges => slow
- # data_dict = self._rx_graph.edges()[edge_id]
 data_dict = self._rx_graph.get_edge_data_by_index(edge_id)
 elif (self.is_nx_graph()):
 data_dict = self._nx_graph.edges[edge_id]
 else:
- raise Exception("Graph.edge_data - bad kind of graph object")
+ raise TypeError(
+ "Graph passed to 'edge_data()' is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )

+ # Sanity check - RX edges are not required to have a data dict for edge data
+ #
+ # A GerryChain Graph object should always be constructed with a data dict
+ # for edge data, but it doesn't hurt to check.
 if not isinstance(data_dict, dict):
- raise Exception("edge data is not a dictionary");
+ raise TypeError("graph.edge_data(): data for edge is not a dict")

 return data_dict

- # frm: Note: I added the laplacian_matrix routines as methods of the Graph
+ # frm: TODO: Documentation: Note: I added the laplacian_matrix routines as methods of the Graph
 # class because they are only ever used on Graph objects. It
 # bloats the Graph class, but it still seems like the best
 # option.
+ #
+ # A goal is to encapsulate ALL NX dependencies in this file.
+
+ def laplacian_matrix(self) -> scipy.sparse.csr_array:
+ """
+ Return the Laplacian matrix of the graph: L = D - A, where D is
+ the diagonal matrix of node degrees and A is the adjacency matrix.

- def laplacian_matrix(self):
+ :returns: The Laplacian matrix of the graph
+ :rtype: scipy.sparse.csr_array
+ """
 # A local "gc" (as in GerryChain) version of the laplacian matrix

- # frm: TODO: The NX version returns a matrix of integer values while the
+ # frm: TODO: Code: laplacian_matrix(): should NX and RX return same type (float vs. int)?
+ #
+ # The NX version returns a matrix of integer values while the
 # RX version returns a matrix of floating point values. I
 # think the reason is that the RX.adjacency_matrix() call
 # returns an array of floats.
@@ -1588,13 +2093,36 @@ def laplacian_matrix(self):
 nx_graph = self._nx_graph
 laplacian_matrix = networkx.laplacian_matrix(nx_graph)
 else:
- raise Exception("laplacian_matrix: badly configured graph parameter")
+ raise TypeError(
+ "Graph passed into laplacian_matrix() is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )

 return laplacian_matrix

- def normalized_laplacian_matrix(self):
+ def normalized_laplacian_matrix(self) -> scipy.sparse.dia_array:
+ """
+ Return the normalized Laplacian matrix of the graph:
+ L_norm = D^(-1/2) (D - A) D^(-1/2), where D is the diagonal
+ matrix of node degrees and A is the adjacency matrix.
+
+ :returns: The normalized Laplacian matrix of the graph
+ :rtype: scipy.sparse.dia_array
+ """
+
+ def create_scipy_sparse_array_from_rx_graph(rx_graph: rustworkx.PyGraph) -> scipy.sparse.coo_matrix:
+ """
+ Build a scipy sparse adjacency matrix for the given RustworkX
+ graph, mirroring networkx.to_scipy_sparse_array().
+
+ :param rx_graph: The RustworkX graph to convert
+ :type rx_graph: rustworkx.PyGraph

- def create_scipy_sparse_array_from_rx_graph(rx_graph):
+ :returns: The adjacency matrix of the graph as a sparse array
+ :rtype: scipy.sparse.coo_matrix
+ """
 num_nodes = rx_graph.num_nodes()

 rows = []
@@ -1641,46 +2169,49 @@ def create_scipy_sparse_array_from_rx_graph(rx_graph):
 # that uses normalized_laplacian_matrix() now passes, but it is for a small 6x6 graph
 # and hence is not a real world test...
 #
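+ # The computation below follows the standard definition of the
+ # normalized Laplacian (a sketch of the math for future readers):
+ #
+ #     L_norm = D^(-1/2) (D - A) D^(-1/2)
+ #
+ # where A is the adjacency matrix and D is the diagonal matrix of
+ # node degrees (the row sums of A). A degree of 0 would divide by
+ # zero, so those entries of D^(-1/2) are zeroed out below.
+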
- # I have left debugging print statements (commented out) in case someone in the future
- # (probably me) wants to verify that the right things are happening...

- # Original NetworkX Code:
- # A = networkx.to_scipy_sparse_array(G, nodelist=nodelist, weight=weight, format="csr")
-
 A = create_scipy_sparse_array_from_rx_graph(rx_graph)

 n, _ = A.shape # shape() => dimensions of the array (rows, cols), so n = num_rows
- # print("")
- # print(f"normalized_laplacian_matrix: num_rows = {n}")
- # print(f"normalized_laplacian_matrix: sparse_array is: {A}")

 diags = A.sum(axis=1) # sum of values in each row => column vector
 diags = diags.T # convert to a row vector / 1D array
- # print(f"normalized_laplacian_matrix: diags is: \n{diags}")

 D = scipy.sparse.dia_array((diags, [0]), shape=(n, n)).tocsr()
- # print(f"normalized_laplacian_matrix: D is: \n{D}")

 L = D - A
- # print(f"normalized_laplacian_matrix: L is: \n{L}")

 with numpy.errstate(divide="ignore"):
 diags_sqrt = 1.0 / numpy.sqrt(diags)
 diags_sqrt[numpy.isinf(diags_sqrt)] = 0
- # print(f"normalized_laplacian_matrix: diags_sqrt is: \n{diags_sqrt}")
 DH = scipy.sparse.dia_array((diags_sqrt, 0), shape=(n, n)).tocsr()
- # print(f"normalized_laplacian_matrix: DH is: \n{DH}")

 normalized_laplacian = DH @ (L @ DH)
- # print(f"normalized_laplacian_matrix: normalized_laplacian is: \n{normalized_laplacian}")

 return normalized_laplacian

 elif self.is_nx_graph():
 nx_graph = self._nx_graph
 laplacian_matrix = networkx.normalized_laplacian_matrix(nx_graph)
 else:
- raise Exception("normalized_laplacian_matrix: badly configured graph parameter")
+ raise TypeError(
+ "Graph passed into normalized_laplacian_matrix() is neither "
+ "a networkx-based graph nor a rustworkx-based graph"
+ )

 return laplacian_matrix

- def subgraphs_for_connected_components(self):
- # Create a list of subgraphs - one for each subset of connected nodes in the graph
- #
- # This mirrors the nx.connected_components() routine in NetworkX
+ def subgraphs_for_connected_components(self) -> list["Graph"]:
+ """
+ Create and return a list of subgraphs, one for each set of
+ connected nodes in the given graph.
+
+ Note that a connected graph is one in which there is a path
+ from every node in the graph to every other node in the
+ graph.
+
+ Note also that each of the subgraphs returned is a maximal
+ connected subgraph (a connected component), meaning that it
+ is not contained in any larger connected subgraph.
+
+ :returns: A list of "maximal" subgraphs each of which
+ contains nodes that are connected.
+ :rtype: list["Graph"]
+ """

 if self.is_rx_graph():
 rx_graph = self.get_rx_graph()
@@ -1693,11 +2224,38 @@ def subgraphs_for_connected_components(self):
 self.subgraph(nodes) for nodes in networkx.connected_components(nx_graph)
 ]
 else:
- raise Exception("subgraphs_for_connected_components: Bad kind of Graph")
+ raise TypeError(
+ "Graph passed to 'subgraphs_for_connected_components()' is "
+ "neither a networkx-based graph nor a rustworkx-based graph"
+ )

 return subgraphs

- def num_connected_components(self):
+ def num_connected_components(self) -> int:
+ """
+ Return the number of connected components.
+
+ Note: A connected component is a maximal subgraph
+ where every vertex is reachable from every other vertex in
+ that same subgraph. In a graph that is not fully connected,
+ connected components are the separate, distinct "islands" of
+ connected nodes. Every node in a graph belongs to exactly
+ one connected component.
+ + :returns: The number of connected components + :rtype: int + """ + + # frm: TODO: Performance: num_connected_components(): do both NX and RX have builtins for this? + # + # NetworkX and RustworkX both have a routine number_connected_components(). + # I am guessing that it is more efficient to call these than it is + # to construct the connected components and then determine how many + # of them there are. + # + # So - should be a simple issue of trying it and running tests, but + # I will do that another day... + if self.is_rx_graph(): rx_graph = self.get_rx_graph() connected_components = rustworkx.connected_components(rx_graph) @@ -1705,12 +2263,30 @@ def num_connected_components(self): nx_graph = self.get_nx_graph() connected_components = list(networkx.connected_components(nx_graph)) else: - raise Exception("num_connected_components: Bad kind of Graph") + raise TypeError( + "Graph passed to 'num_connected_components()' is neither " + "a networkx-based graph nor a rustworkx-based graph" + ) num_cc = len(connected_components) return num_cc - def is_a_tree(self): + def is_a_tree(self) -> bool: + """ + Return whether the current graph is a tree - meaning that + it is connected and that it has no cycles. + + :returns: Whether the current graph is a tree + :rtype: bool + """ + + # frm: TODO: Refactor: is_a_tree() is only called in a test (test_tree.py) - delete it? + # + # Talk to Peter to see if there is any reason to keep this function. Does anyone + # use it externally? + # + # On the other hand, perhaps it is OK to keep it even if it is only ever used in a test... + if self.is_rx_graph(): rx_graph = self.get_rx_graph() num_nodes = rx_graph.num_nodes() @@ -1734,418 +2310,11 @@ def is_a_tree(self): nx_graph = self.get_nx_graph() return networkx.is_tree(nx_graph) else: - raise Exception("is_a_tree: Bad kind of Graph") - -###################################################### - -class OriginalGraph(networkx.Graph): - """ - frm: This is the original code for gerrychain.Graph before any RustworkX changes. - - It continues to exist so that I can write tests to verify that from the outside - the new Graph object behaves the same as the original Graph object. - - See the test in tests/frm_tests/test_frm_old_vs_new_graph.py - """ - - # frm: Original Graph code... - def __repr__(self): - return "".format(len(self.nodes), len(self.edges)) - - # frm: Original Graph code... - @classmethod - def from_networkx(cls, graph: networkx.Graph) -> "Graph": - """ - Create a Graph instance from a networkx.Graph object. - - :param graph: The networkx graph to be converted. - :type graph: networkx.Graph - - :returns: The converted graph as an instance of this class. - :rtype: Graph - """ - g = cls(graph) - return g - - # frm: Original Graph code... - @classmethod - def from_json(cls, json_file: str) -> "Graph": - """ - Load a graph from a JSON file in the NetworkX json_graph format. - - :param json_file: Path to JSON file. - :type json_file: str - - :returns: The loaded graph as an instance of this class. - :rtype: Graph - """ - with open(json_file) as f: - data = json.load(f) - g = json_graph.adjacency_graph(data) - graph = cls.from_networkx(g) - graph.issue_warnings() - return graph - - # frm: Original Graph code... - def to_json( - self, json_file: str, *, include_geometries_as_geojson: bool = False - ) -> None: - """ - Save a graph to a JSON file in the NetworkX json_graph format. - - :param json_file: Path to target JSON file. 
- :type json_file: str - :param bool include_geometry_as_geojson: Whether to include - any :mod:`shapely` geometry objects encountered in the graph's node - attributes as GeoJSON. The default (``False``) behavior is to remove - all geometry objects because they are not serializable. Including the - GeoJSON will result in a much larger JSON file. - :type include_geometries_as_geojson: bool, optional - - :returns: None - """ - data = json_graph.adjacency_data(self) - - if include_geometries_as_geojson: - convert_geometries_to_geojson(data) - else: - remove_geometries(data) - - with open(json_file, "w") as f: - json.dump(data, f, default=json_serialize) - - # frm: Original Graph code... - @classmethod - def from_file( - cls, - filename: str, - adjacency: str = "rook", - cols_to_add: Optional[List[str]] = None, - reproject: bool = False, - ignore_errors: bool = False, - ) -> "Graph": - """ - Create a :class:`Graph` from a shapefile (or GeoPackage, or GeoJSON, or - any other library that :mod:`geopandas` can read. See :meth:`from_geodataframe` - for more details. - - :param filename: Path to the shapefile / GeoPackage / GeoJSON / etc. - :type filename: str - :param adjacency: The adjacency type to use ("rook" or "queen"). Default is "rook" - :type adjacency: str, optional - :param cols_to_add: The names of the columns that you want to - add to the graph as node attributes. Default is None. - :type cols_to_add: Optional[List[str]], optional - :param reproject: Whether to reproject to a UTM projection before - creating the graph. Default is False. - :type reproject: bool, optional - :param ignore_errors: Whether to ignore all invalid geometries and try to continue - creating the graph. Default is False. - :type ignore_errors: bool, optional - - :returns: The Graph object of the geometries from `filename`. - :rtype: Graph - - .. Warning:: - - This method requires the optional ``geopandas`` dependency. - So please install ``gerrychain`` with the ``geo`` extra - via the command: - - .. code-block:: console - - pip install gerrychain[geo] - - or install ``geopandas`` separately. - """ - - df = gp.read_file(filename) - graph = cls.from_geodataframe( - df, - adjacency=adjacency, - cols_to_add=cols_to_add, - reproject=reproject, - ignore_errors=ignore_errors, - ) - - graph.graph["crs"] = df.crs.to_json() - return graph - - # frm: Original Graph code... - @classmethod - def from_geodataframe( - cls, - dataframe: pd.DataFrame, - adjacency: str = "rook", - cols_to_add: Optional[List[str]] = None, - reproject: bool = False, - ignore_errors: bool = False, - crs_override: Optional[Union[str, int]] = None, - ) -> "Graph": - """ - Creates the adjacency :class:`Graph` of geometries described by `dataframe`. - The areas of the polygons are included as node attributes (with key `area`). - The shared perimeter of neighboring polygons are included as edge attributes - (with key `shared_perim`). - Nodes corresponding to polygons on the boundary of the union of all the geometries - (e.g., the state, if your dataframe describes VTDs) have a `boundary_node` attribute - (set to `True`) and a `boundary_perim` attribute with the length of this "exterior" - boundary. - - By default, areas and lengths are computed in a UTM projection suitable for the - geometries. This prevents the bizarro area and perimeter values that show up when - you accidentally do computations in Longitude-Latitude coordinates. 
If the user - specifies `reproject=False`, then the areas and lengths will be computed in the - GeoDataFrame's current coordinate reference system. This option is for users who - have a preferred CRS they would like to use. - - :param dataframe: The GeoDateFrame to convert - :type dataframe: :class:`geopandas.GeoDataFrame` - :param adjacency: The adjacency type to use ("rook" or "queen"). - Default is "rook". - :type adjacency: str, optional - :param cols_to_add: The names of the columns that you want to - add to the graph as node attributes. Default is None. - :type cols_to_add: Optional[List[str]], optional - :param reproject: Whether to reproject to a UTM projection before - creating the graph. Default is ``False``. - :type reproject: bool, optional - :param ignore_errors: Whether to ignore all invalid geometries and - attept to create the graph anyway. Default is ``False``. - :type ignore_errors: bool, optional - :param crs_override: Value to override the CRS of the GeoDataFrame. - Default is None. - :type crs_override: Optional[Union[str,int]], optional - - :returns: The adjacency graph of the geometries from `dataframe`. - :rtype: Graph - """ - # Validate geometries before reprojection - if not ignore_errors: - invalid = invalid_geometries(dataframe) - if len(invalid) > 0: - raise GeometryError( - "Invalid geometries at rows {} before " - "reprojection. Consider repairing the affected geometries with " - "`.buffer(0)`, or pass `ignore_errors=True` to attempt to create " - "the graph anyways.".format(invalid) - ) - - # Project the dataframe to an appropriate UTM projection unless - # explicitly told not to. - if reproject: - df = reprojected(dataframe) - if ignore_errors: - invalid_reproj = invalid_geometries(df) - print(invalid_reproj) - if len(invalid_reproj) > 0: - raise GeometryError( - "Invalid geometries at rows {} after " - "reprojection. Consider reloading the GeoDataFrame with " - "`reproject=False` or repairing the affected geometries " - "with `.buffer(0)`.".format(invalid_reproj) - ) - else: - df = dataframe - - # Generate dict of dicts of dicts with shared perimeters according - # to the requested adjacency rule - adjacencies = neighbors(df, adjacency) # Note - this is adjacency.neighbors() - graph = cls(adjacencies) - - graph.geometry = df.geometry - - graph.issue_warnings() - - # Add "exterior" perimeters to the boundary nodes - add_boundary_perimeters(graph, df.geometry) - - # Add area data to the nodes - areas = df.geometry.area.to_dict() - networkx.set_node_attributes(graph, name="area", values=areas) - - graph.add_data(df, columns=cols_to_add) - - if crs_override is not None: - df.set_crs(crs_override, inplace=True) - - if df.crs is None: - warnings.warn( - "GeoDataFrame has no CRS. Did you forget to set it? " - "If you're sure this is correct, you can ignore this warning. " - "Otherwise, please set the CRS using the `crs_override` parameter. " - "Attempting to proceed without a CRS." - ) - graph.graph["crs"] = None - else: - graph.graph["crs"] = df.crs.to_json() - - return graph - - # frm: Original Graph code... - def lookup(self, node: Any, field: Any) -> Any: - """ - Lookup a node/field attribute. - - :param node: Node to look up. - :type node: Any - :param field: Field to look up. - :type field: Any - - :returns: The value of the attribute `field` at `node`. - :rtype: Any - """ - return self.nodes[node][field] - - # frm: Original Graph code... - @property - def node_indices(self): - return set(self.nodes) - - # frm: Original Graph code... 
- @property - def edge_indices(self): - return set(self.edges) - - # frm: Original Graph code... - def add_data( - self, df: pd.DataFrame, columns: Optional[Iterable[str]] = None - ) -> None: - """ - Add columns of a DataFrame to a graph as node attributes - by matching the DataFrame's index to node ids. - - :param df: Dataframe containing given columns. - :type df: :class:`pandas.DataFrame` - :param columns: List of dataframe column names to add. Default is None. - :type columns: Optional[Iterable[str]], optional - - :returns: None - """ - - if columns is None: - columns = list(df.columns) - - check_dataframe(df[columns]) - - column_dictionaries = df.to_dict("index") - networkx.set_node_attributes(self, column_dictionaries) - - if hasattr(self, "data"): - self.data[columns] = df[columns] # type: ignore - else: - self.data = df[columns] - - # frm: Original Graph code... - def join( - self, - dataframe: pd.DataFrame, - columns: Optional[List[str]] = None, - left_index: Optional[str] = None, - right_index: Optional[str] = None, - ) -> None: - """ - Add data from a dataframe to the graph, matching nodes to rows when - the node's `left_index` attribute equals the row's `right_index` value. - - This is the same as a "join" in SQL: - insert into - select <> from - where . - - :param dataframe: DataFrame. - :type dataframe: :class:`pandas.DataFrame` - :columns: The columns whose data you wish to add to the graph. - If not provided, all columns are added. Default is None. - :type columns: Optional[List[str]], optional - :left_index: The node attribute used to match nodes to rows. - If not provided, node IDs are used. Default is None. - :type left_index: Optional[str], optional - :right_index: The DataFrame column name to use to match rows - to nodes. If not provided, the DataFrame's index is used. Default is None. - :type right_index: Optional[str], optional - - :returns: None - """ - # frm: TODO: Implement this for RX. Note, however, that this is probably - # low priority since this routine is for building a graph - # which for now (summer 2025) will continue to be done using - # NetworkX. That is, this code will not be used after - # freezing the graph when we create a Parition... - if (not self.is_nx_graph()): - raise Exception("join(): Not supported for RX based Graph objects") - - if right_index is not None: - df = dataframe.set_index(right_index) - else: - df = dataframe - - if columns is not None: - df = df[columns] - - check_dataframe(df) - - # Transform the dataframe into a dict of dicts, where - # each column in the df is associated with a dict of - # : values. - column_dictionaries = df.to_dict() - - # Determine what data in the graph to sync up with the - # data in the dataframe. ids_to_index maps node_ids to - # values that select which row in dataframe should be - # associated with that node_id. - if left_index is not None: - # frm: TODO: Figure out how to make this work for RX... - ids_to_index = networkx.get_node_attributes(self, left_index) - else: - # When the left_index is node ID, the matching is just - # a redundant {node: node} dictionary - ids_to_index = dict(zip(self.nodes, self.nodes)) - - # For each column in the dataframe, extract the appropriate entry for - # the given node_id (using index) and wrap it all up in a dict. The - # result is a dict of (name, value) pairs of data from the dataframe - # for each node_id. 
- node_attributes = {
- node_id: {
- column: values[index] for column, values in column_dictionaries.items()
- }
- for node_id, index in ids_to_index.items()
- }
-
- # frm: TODO: Figure out how to make this work for RX...
- networkx.set_node_attributes(self, node_attributes)
-
- # frm: Original Graph code...
- @property
- def islands(self) -> Set:
- """
- :returns: The set of degree-0 nodes.
- :rtype: Set
- """
- return set(node for node in self if self.degree[node] == 0)
-
- # frm: Original Graph code...
- def warn_for_islands(self) -> None:
- """
- :returns: None
-
- :raises: UserWarning if the graph has any islands (degree-0 nodes).
- """
- islands = self.islands
- if len(self.islands) > 0:
- warnings.warn(
- "Found islands (degree-0 nodes). Indices of islands: {}".format(islands)
+ raise TypeError(
+ "Graph passed to 'is_a_tree()' is neither a "
+ "networkx-based graph nor a rustworkx-based graph"
 )

- # frm: Original Graph code...
- def issue_warnings(self) -> None:
- """
- :returns: None
-
- :raises: UserWarning if the graph has any red flags (right now, only islands).
- """
- self.warn_for_islands()

 def add_boundary_perimeters(nx_graph: networkx.Graph, geometries: pd.Series) -> None:
 """
@@ -2160,7 +2329,7 @@ def add_boundary_perimeters(nx_graph: networkx.Graph, geometries: pd.Series) ->
 :rtype: Graph
 """

- # frm: TODO: Think about whether it is reasonable to require this to work
+ # frm: TODO: add_boundary_perimeters(): Think about whether it is reasonable to require this to work
 # on a NetworkX.Graph object.

 # frm: The original code operated on the Graph object which was a subclass of
@@ -2169,7 +2338,10 @@ def add_boundary_perimeters(nx_graph: networkx.Graph, geometries: pd.Series) ->
 # and pass in the inner nx_graph data member.

 if not(isinstance(nx_graph, networkx.Graph)):
- raise Exception("add_boundary_permiters: Graph is not a NetworkX.Graph object")
+ raise TypeError(
+ "Graph passed into add_boundary_perimeters() "
+ "is not a networkx graph"
+ )

 prepared_boundary = prep(unary_union(geometries).boundary)
@@ -2263,7 +2435,7 @@ class FrozenGraph:
 The class uses `__slots__` for improved memory efficiency.
 """

- # frm: TODO: Rename the internal data member, "graph", to be something else.
+ # frm: TODO: Code: Rename the internal data member, "graph", to be something else.
 # The reason is that a NetworkX.Graph object already has an internal
 # data member named, "graph", which is just a dict for the data
 # associated with the Networkx.Graph object.
@@ -2294,16 +2466,38 @@ def __init__(self, graph: Graph) -> None:
 #
 # self.size = len(self.graph)

- # frm TODO: Add logic to have this work for RX.
+ # frm TODO: Code: Add logic to FrozenGraph so that it is indeed "frozen" (for both NX and RX)
+ #
+ # I think this just means redefining those methods that change the graph
+ # to return an error / exception if called.

 self.graph = graph
- # frm: TODO: Not sure this works for RX
- self.size = len(self.graph.nodes)
+ self.size = len(self.graph.node_indices)

 def __len__(self) -> int:
+ """
+ Return the number of nodes in the wrapped graph (computed once
+ in __init__ and cached in self.size).
+
+ :rtype: int
+ """
 return self.size
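+ # frm: Example: a sketch of the delegation behavior implemented below
+ # (g is any Graph instance; names are illustrative):
+ #
+ #     frozen = FrozenGraph(g)
+ #     len(frozen)            # uses the size cached in __init__
+ #     frozen.is_rx_graph()   # not defined on FrozenGraph -> falls through to g
+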
 def __getattribute__(self, __name: str) -> Any:
+ """
+ Look up the attribute on the FrozenGraph itself first; if it is
+ not found there, delegate the lookup to the wrapped Graph object.
+
+ :param __name: The name of the attribute being accessed
+ :type __name: str
+
+ :returns: The attribute's value
+ :rtype: Any
+ """
 try:
 return object.__getattribute__(self, __name)
 except AttributeError:
@@ -2311,15 +2505,33 @@ def __getattribute__(self, __name: str) -> Any:
 return self.graph.__getattribute__(__name)

 def __getitem__(self, __name: str) -> Any:
+ """
+ Delegate square-bracket access to the wrapped Graph object.
+
+ :param __name: The key being accessed
+ :type __name: str
+
+ :returns: The value the wrapped graph stores for the given key
+ :rtype: Any
+ """
 return self.graph[__name]

 def __iter__(self) -> Iterable[Any]:
+ """
+ Iterate over the node_ids of the wrapped graph.
+
+ :rtype: Iterable[Any]
+ """
 yield from self.node_indices

 @functools.lru_cache(16384)
- def neighbors(self, n: Any) -> Tuple[Any, ...]:
- # frm: Original Code:
- # return tuple(self.graph.neighbors(n))
+ def neighbors(self, n: Any) -> tuple[Any, ...]:
 return self.graph.neighbors(n)

 @functools.cached_property
@@ -2334,10 +2546,5 @@ def edge_indices(self) -> Iterable[Any]:
 def degree(self, n: Any) -> int:
 return self.graph.degree(n)

- @functools.lru_cache(65536)
- def lookup(self, node: Any, field: str) -> Any:
- # frm: Original Code: return self.graph.nodes[node][field]
- return self.node_data(node)[field]
-
 def subgraph(self, nodes: Iterable[Any]) -> "FrozenGraph":
 return FrozenGraph(self.graph.subgraph(nodes))
diff --git a/gerrychain/grid.py b/gerrychain/grid.py
index 52057f6a..688783ae 100644
--- a/gerrychain/grid.py
+++ b/gerrychain/grid.py
@@ -15,7 +15,7 @@ import math
 import networkx

-# frm TODO: Clarify what purpose grid.py serves.
+# frm TODO: Documentation: Clarify what purpose grid.py serves.
 #
 # It is a convenience module to help users create toy graphs. It leverages
 # NX to create graphs, but it returns new Graph objects. So, legacy user
@@ -128,7 +128,7 @@ def __init__(
 if not assignment:
 thresholds = tuple(math.floor(n / 2) for n in self.dimensions)
 assignment = {
- node: color_quadrants(node, thresholds) for node in graph.nodes # type: ignore
+ node_id: color_quadrants(node_id, thresholds) for node_id in graph.node_indices # type: ignore
 }

 if not updaters:
@@ -165,17 +165,19 @@ def as_list_of_lists(self):
 return [[self.assignment.mapping[(i, j)] for i in range(m)] for j in range(n)]

-# frm TODO: Is this intended to be callable / useful for external users?
-# For now, I am going to leave this as operating on NetworkX graphs, since
-# it appears to only be used internally in this Class. However, I may discover
-# that it has been used externally with the intention of returning a Graph object.
-# If so, then I will need to return a Graph object (from_networkx(nx_graphg)) and change
-# the call inside this class to expect a Graph object instead of a NetworkX.Graph object.
-
-# frm: TODO: Decide if I should change this to return a Graph object or not...
-
-# frm: Original Code - function signature:
-# def create_grid_graph(dimensions: Tuple[int, int], with_diagonals: bool) -> Graph:
+# frm: TODO: Documentation: Document what grid.py is intended to be used for
+#
+# I will need to do some research, but my guess is that there are two use
+# cases:
+#
+# 1) Testing - make it easy to create tests
+# 2) User Code - make it easy for users to play around.
+# +# For #1, it is OK to have some routines return NX-Graph objects and some to return new Graph +# objects, but that is probably confusing to users, so the todo list items are: +# +# * Decide whether to support returning NX-based Graphs sometimes and new Graphs others, +# * Document whatever we decide # def _create_grid_nx_graph(dimensions: Tuple[int, int], with_diagonals: bool) -> Graph: @@ -246,9 +248,8 @@ def give_constant_attribute(graph: Graph, attribute: Any, value: Any) -> None: :returns: None """ - for node in graph.nodes: - # frm original code: graph.nodes[node][attribute] = value - graph.node_data(node)[attribute] = value + for node_id in graph.node_indices: + graph.node_data(node_id)[attribute] = value def _tag_boundary_nodes(nx_graph: networkx.Graph, dimensions: Tuple[int, int]) -> None: diff --git a/gerrychain/metagraph.py b/gerrychain/metagraph.py index 9a1fe928..833af8c2 100644 --- a/gerrychain/metagraph.py +++ b/gerrychain/metagraph.py @@ -22,25 +22,19 @@ def all_cut_edge_flips(partition: Partition) -> Iterator[Dict]: Generate all possible flips of cut edges in a partition without any constraints. + This routine finds all edges on the boundary of + districts - those that are "cut edges" where one node + is in one district and the other node is in another + district. These are all of the places where you + could move the boundary between districts by moving a single + node. + :param partition: The partition object. :type partition: Partition :returns: An iterator that yields dictionaries representing the flipped edges. :rtype: Iterator[Dict] """ - # frm: TODO: Add some documentation so a future readef of this code - # will not be as confused as I was... - - # frm: TODO: Why is this an iterator instead of just a dict? - - # frm: For my own edification... It took me a while to understand why - # this routine made sense at a high level. It finds all edges - # on the boundary of districts - those that are "cut edges" - # where one node is in one district and the other node is in - # another district. These are all of the places where you - # could move the boundary between districts by moving a single - # node. Stated differently, these are all of the places where - # you can make a single flip without creating a disconnected - # graph. + for edge, index in product(partition.cut_edges, (0, 1)): yield {edge[index]: partition.assignment.mapping[edge[1 - index]]} diff --git a/gerrychain/metrics/partisan.py b/gerrychain/metrics/partisan.py index 6fc9e9ff..af75fb99 100644 --- a/gerrychain/metrics/partisan.py +++ b/gerrychain/metrics/partisan.py @@ -9,7 +9,7 @@ import numpy from typing import Tuple -# frm: TODO: Why are these not just included in the file that defines ElectionResults? +# frm: TODO: Refactoring: Why are these not just included in the file that defines ElectionResults? def mean_median(election_results) -> float: """ diff --git a/gerrychain/optimization/optimization.py b/gerrychain/optimization/optimization.py index 5b45e3c6..fd7a6a7e 100644 --- a/gerrychain/optimization/optimization.py +++ b/gerrychain/optimization/optimization.py @@ -505,7 +505,7 @@ def tilted_short_bursts( with_progress_bar=with_progress_bar, ) - # TODO: Maybe add a max_time variable so we don't run forever. + # TODO: Refactoring: Maybe add a max_time variable so we don't run forever. 
def variable_length_short_bursts( self, num_steps: int, diff --git a/gerrychain/partition/assignment.py b/gerrychain/partition/assignment.py index 3844399a..bdbe8ad7 100644 --- a/gerrychain/partition/assignment.py +++ b/gerrychain/partition/assignment.py @@ -156,7 +156,7 @@ def from_dict(cls, assignment: Dict) -> "Assignment": :rtype: Assignment """ - # frm: TODO: Clean up from_dict(). + # frm: TODO: Refactoring: Clean up from_dict(). # # A couple of things: # * It uses a routine, level_sets(), which is only ever used here, so @@ -239,7 +239,7 @@ def get_assignment( :raises TypeError: If the part_assignment is not a string or dictionary. """ - # frm: TODO: Think about whether to split this into two functions. AT + # frm: TODO: Refactoring: Think about whether to split this into two functions. AT # present, it does different things based on whether # the "part_assignment" parameter is a string, a dict, # or an assignment. Probably not worth the trouble (possible @@ -253,7 +253,6 @@ def get_assignment( "You must provide a graph when using a node attribute for the part_assignment" ) return Assignment.from_dict( - # frm: original code: {node: graph.nodes[node][part_assignment] for node in graph} {node: graph.node_data(node)[part_assignment] for node in graph} ) # Check if assignment is a dict or a mapping type diff --git a/gerrychain/partition/partition.py b/gerrychain/partition/partition.py index 0d48412b..d676130f 100644 --- a/gerrychain/partition/partition.py +++ b/gerrychain/partition/partition.py @@ -16,7 +16,7 @@ from ..tree import recursive_tree_part from typing import Any, Callable, Dict, Optional, Tuple -# frm TODO: Add documentation about how this all works. For instance, +# frm TODO: Documentation: Add documentation about how this all works. For instance, # what is computationally expensive and how does a FrozenGraph # help? Why do we need both assignments and parts? # @@ -133,9 +133,8 @@ def from_random_assignment( :returns: The partition created with a random assignment :rtype: Partition """ - # frm: ???: BUG: TODO: The param, flips, is never used in this routine... + # frm: TODO: BUG: The param, flips, is never used in this routine... - # frm: original code: total_pop = sum(graph.nodes[n][pop_col] for n in graph) total_pop = sum(graph.node_data(n)[pop_col] for n in graph) ideal_pop = total_pop / n_parts @@ -168,7 +167,7 @@ def _first_time(self, graph, assignment, updaters, use_default_updaters): # convert to RX - both for legacy compatibility, but also because NX provides # a really nice and easy way to create graphs. # - # TODO: update the documentation + # TODO: Documentation: update the documentation # to describe the use case of creating a graph using NX. That documentation # should also describe how to post-process results of a MarkovChain run # but I haven't figured that out yet... @@ -179,15 +178,21 @@ def _first_time(self, graph, assignment, updaters, use_default_updaters): # if a Graph object, make sure it is based on an embedded RustworkX.PyGraph if isinstance(graph, Graph): - # frm: TODO: Remove this short-term hack to do performance testing + # frm: TODO: Performance: Remove this short-term hack to do performance testing + # + # This "test_performance_using_NX_graph" hack just forces the partition + # to NOT convert the NX graph to be RX based. This allows me to + # compare RX performance to NX performance with the same code - so that + # whatever is different is crystal clear. 
test_performance_using_NX_graph = False if (graph.is_nx_graph()) and test_performance_using_NX_graph: self.assignment = get_assignment(assignment, graph) + print("=====================================================") print("Performance-Test: using NetworkX for Partition object") + print("=====================================================") elif (graph.is_nx_graph()): - print("Partition: converting NX to RX") # Get the assignment that would be appropriate for the NX-based graph old_nx_assignment = get_assignment(assignment, graph) @@ -208,7 +213,6 @@ def _first_time(self, graph, assignment, updaters, use_default_updaters): self.graph = FrozenGraph(graph) elif isinstance(graph, FrozenGraph): - # frm: TODO: Verify that the embedded graph is RX self.graph = graph self.assignment = get_assignment(assignment, graph) @@ -280,7 +284,7 @@ def flip(self, flips: Dict, use_original_nx_node_ids=False) -> "Partition": :rtype: Partition """ - # frm: TODO: Change comments above to document new optional parameter, use_original_nx_node_ids. + # frm: TODO: Documentation: Change comments above to document new optional parameter, use_original_nx_node_ids. # # This is a new issue that arises from the fact that node_ids in RX are different from those # in the original NX graph. In the pre-RX code, we did not need to distinguish between @@ -335,14 +339,19 @@ def __getitem__(self, key: str) -> Any: # if key not in self._cache: - # frm: TODO: Add code to check that the desired updater actually is - # defined in the list of updaters. If not, then this - # would produce a perhaps difficult to debug problem... + # frm: TODO: Testing: Add a test checking what happens if no updater defined + # + # This code checks that the desired updater actually is + # defined in the list of updaters. If not, then this + # would produce a perhaps difficult to debug problem... + if key not in self.updaters: + raise KeyError(f"__getitem__(): updater: {key} not defined in the updaters for the partition") + self._cache[key] = self.updaters[key](self) return self._cache[key] def __getattr__(self, key): - # frm TODO: Not sure it makes sense to allow two ways to accomplish the same thing... + # frm TODO: Refactor: Not sure it makes sense to allow two ways to accomplish the same thing... # # The code below allows Partition users to get the results of updaters by just # doing: partition. which is the same as doing: partition[""] @@ -403,7 +412,7 @@ def plot(self, geometries=None, **kwargs): else: raise Exception("Partition.plot: graph has no geometry data") - if set(geometries.index) != set(self.graph.nodes): + if set(geometries.index) != self.graph.node_indices: raise TypeError( "The provided geometries do not match the nodes of the graph." ) @@ -450,7 +459,6 @@ def from_districtr_file( id_column_key = districtr_plan["idColumn"]["key"] districtr_assignment = districtr_plan["assignment"] try: - # frm: original code: node_to_id = {node: str(graph.nodes[node][id_column_key]) for node in graph} node_to_id = {node: str(graph.node_data(node)[id_column_key]) for node in graph} except KeyError: raise TypeError( @@ -458,7 +466,8 @@ def from_districtr_file( "needed to match the Districtr assignment to the nodes of the graph." ) - # frm: TODO: NX vs. RX issues: does "node in graph" work for both NX and RX? 
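The new guard in __getitem__() above turns a silently missing updater into a KeyError; a hedged sketch of the test that the Testing TODO asks for (the toy graph, district labels, and updater name are invented, and Graph.from_networkx is the constructor used elsewhere in this diff):

import networkx as nx
import pytest
from gerrychain import Graph, Partition

def test_missing_updater_raises_key_error():
    # A 4-cycle split into two districts of two nodes each.
    nx_graph = nx.Graph([(0, 1), (1, 2), (2, 3), (3, 0)])
    graph = Graph.from_networkx(nx_graph)
    partition = Partition(graph, assignment={0: "A", 1: "A", 2: "B", 3: "B"})
    with pytest.raises(KeyError):
        partition["not_a_registered_updater"]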
- assignment = {node: districtr_assignment[node_to_id[node]] for node in graph} + # frm: TODO: Testing: Verify that there is a test for from_districtr_file() + + assignment = {node_id: districtr_assignment[node_to_id[node_id]] for node_id in graph.node_indices} return cls(graph, assignment, updaters) diff --git a/gerrychain/proposals/spectral_proposals.py b/gerrychain/proposals/spectral_proposals.py index ffc1c5e3..ee981e56 100644 --- a/gerrychain/proposals/spectral_proposals.py +++ b/gerrychain/proposals/spectral_proposals.py @@ -33,7 +33,7 @@ def spectral_cut( # the return value's node_ids need to be translated back into the appropriate # parent node_ids. - node_list = list(subgraph.nodes) + node_list = list(subgraph.node_indices) num_nodes = len(node_list) if weight_type == "random": @@ -47,7 +47,7 @@ def spectral_cut( else: laplacian_matrix = (subgraph.laplacian_matrix()).todense() - # frm TODO: Add a better explanation for why eigenvectors are useful + # frm TODO: Documentation: Add a better explanation for why eigenvectors are useful # for determining flips. Perhaps just a URL to an article # somewhere... # diff --git a/gerrychain/proposals/tree_proposals.py b/gerrychain/proposals/tree_proposals.py index 02a59043..d6e1568b 100644 --- a/gerrychain/proposals/tree_proposals.py +++ b/gerrychain/proposals/tree_proposals.py @@ -108,7 +108,7 @@ def recom( # find one that can be split, or you have tried all possible pairs # of adjacent districts... try: - # frm: TODO: see if there is some way to avoid a while True loop... + # frm: TODO: Refactoring: see if there is some way to avoid a while True loop... while True: edge = random.choice(tuple(partition["cut_edges"])) # Need to sort the tuple so that the order is consistent @@ -254,7 +254,7 @@ def bounded_balance_edge_fn(*args, **kwargs): for out_part in parts: for in_part in parts: dist_pairs.append((out_part, in_part)) - # frm: TODO: Grok why this code considers pairs that are the same part... + # frm: TODO: Code: ???: Grok why this code considers pairs that are the same part... # # For instance, if there are only two parts (districts), then this code will # produce four pairs: (0,0), (0,1), (1,0), (1,1). The code below tests @@ -274,7 +274,7 @@ def bounded_balance_edge_fn(*args, **kwargs): if random_pair[0] == random_pair[1] or not pair_edges: return partition # self-loop: no adjacency - # frm: TODO: Grok why it is OK to return the partition unchanged as the next step. + # frm: TODO: Code: ???: Grok why it is OK to return the partition unchanged as the next step. # # This runs the risk of running an entire chain without ever changing the partition. # I assume that the logic is that there is deliberate randomness introduced each time, @@ -303,11 +303,6 @@ def bounded_balance_edge_fn(*args, **kwargs): # the subgraph's node_ids afterwards. # - # frm: Original Code: - # num_possible_districts, nodes = bipartition_tree_random_reversible( - # partition.graph.subgraph(subgraph_nodes), - # pop_col=pop_col, pop_target=pop_target, epsilon=epsilon - # ) result = bipartition_tree_random_reversible( partition.graph.subgraph(subgraph_nodes), pop_col=pop_col, pop_target=pop_target, epsilon=epsilon @@ -318,7 +313,9 @@ def bounded_balance_edge_fn(*args, **kwargs): num_possible_districts, nodes = result remaining_nodes = subgraph_nodes - set(nodes) - # Note: the ** operator below merges the two dicts into a single dict. 
+ # Note: Clever way to create a single dictionary from + # two dictionaries - the ** operator unpacks each dictionary + # and then they get merged into a new dictionary. flips = { **{node: parts_to_merge[0] for node in nodes}, **{node: parts_to_merge[1] for node in remaining_nodes}, @@ -339,7 +336,7 @@ def bounded_balance_edge_fn(*args, **kwargs): return partition # self-loop -# frm TODO: I do not think that ReCom() is ever called. Note that it +# frm TODO: Refactoring: I do not think that ReCom() is ever called. Note that it # only defines a constructor and a __call__() which would allow # you to call the recom() function by creating a ReCom object and then # "calling" that object - why not just call the recom function? diff --git a/gerrychain/tree.py b/gerrychain/tree.py index 3d55837e..bf3beb3e 100644 --- a/gerrychain/tree.py +++ b/gerrychain/tree.py @@ -64,12 +64,12 @@ import rustworkx as rx import numpy as np from scipy.sparse import csr_array -# frm TODO: Remove import of networkx and rustworkx once we have moved networkx +# frm: TODO: Refactoring: Remove import of networkx and rustworkx once we have moved networkx # dependencies out of this file - see comments below on # spanning trees. import networkx.algorithms.tree as nxtree -# frm TODO: Remove import of "tree" from networkx.algorithms in this file +# frm: TODO: Refactoring: Remove import of "tree" from networkx.algorithms in this file # It is only used to get a spanning tree function: # # spanning_tree = nxtree.minimum_spanning_tree( @@ -104,23 +104,35 @@ # frm: import the new Graph object which encapsulates NX and RX Graph... from .graph import Graph -# frm TODO: Update function param docmentation to get rid of nx.Graph and use just Graph +# frm TODO: Documentation: Update function param documentation to get rid of nx.Graph and use just Graph + +# frm TODO: Documentation: Migration Guide: tree.py is no longer a general purpose module - it is GerryChain specific +# +# Before the work to integrate RX, many of the routines in tree.py +# operated on NetworkX Graph objects, which meant that the module +# was not bound to just GerryChain work - someone could conceivably +# have used it for a graph oriented project that had nothing to do +# with GerryChain or redistricting. +# +# That is no longer true, as the parameters to the routines have +# been changed to be GerryChain Graph objects which are not subclasses +# of NetworkX Graph objects. def random_spanning_tree( - graph: Graph, # frm: Original code: graph: x.Graph, + graph: Graph, region_surcharge: Optional[Dict] = None -) -> Graph: # frm: Original code: ) -> nx.Graph: +) -> Graph: """ Builds a spanning tree chosen by Kruskal's method using random weights. - :param graph: The input graph to build the spanning tree from. Should be a Networkx Graph. - :type graph: nx.Graph + :param graph: The input graph to build the spanning tree from. + :type graph: Graph :param region_surcharge: Dictionary of surcharges to add to the random weights used in region-aware variants. :type region_surcharge: Optional[Dict], optional - :returns: The maximal spanning tree represented as a Networkx Graph. - :rtype: nx.Graph + :returns: The maximal spanning tree represented as a GerryChain Graph. + :rtype: Graph """ # frm: TODO: Performance # This seems to me to be an expensive way to build a random spanning @@ -135,6 +147,9 @@ def random_spanning_tree( # would seem to be a much cheaper way to calculate a spanning tree. # # What am I missing??? 
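For readers weighing the performance question above, this is the shape of the weights-then-Kruskal approach that random_spanning_tree() uses, shown with plain NetworkX on a toy graph (illustrative only; the real code also folds region_surcharge into the weights before taking the minimum spanning tree):

import random
import networkx as nx

g = nx.Graph([(0, 1), (1, 2), (2, 3), (3, 0), (0, 2)])
for u, v in g.edges:
    # Fresh random weights make Kruskal pick a different tree each run.
    g.edges[u, v]["random_weight"] = random.random()

tree = nx.minimum_spanning_tree(g, weight="random_weight", algorithm="kruskal")
print(sorted(tree.edges))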
+ # + # The region_surcharge allows the caller to tweak the randomness + # which might be useful... """ frm: RX Documentation: @@ -149,7 +164,7 @@ node_ids for this function and all will be well... """ - # frm: TODO: WTF is up with region_surcharge being unset? The region_surcharge + # frm: TODO: Refactoring: WTF is up with region_surcharge being unset? The region_surcharge # is only ever accessed in this routine in the for-loop below to # increase the weight on the edge - setting it to be an empty dict # just prevents the code below from blowing up. Why not just put @@ -193,9 +208,6 @@ if region_surcharge is None: region_surcharge = dict() - # frm: Original Code: for edge in graph.edges(): - # Changed because in RX edge_ids are integers while edges are tuples - # Add a random weight to each edge in the graph with the goal of # causing the selection of a different (random) spanning tree based # on those weights. @@ -212,7 +224,7 @@ # spanning_tree algorithm to select other edges... which would have # the effect of prioritizing keeping regions intact. - # frm: TODO: Verify that the comment above about region_surcharge is accurate + # frm: TODO: Documentation: Verify that the comment above about region_surcharge is accurate # Add random weights to the edges in the graph so that the spanning tree # algorithm will select a different spanning tree each time. @@ -228,25 +240,15 @@ for key, value in region_surcharge.items(): # We surcharge edges that cross regions and those that are not in any region if ( - # frm: original code: graph.nodes[edge[0]][key] != graph.nodes[edge[1]][key] - # frm: original code: or graph.nodes[edge[0]][key] is None - # frm: original code: or graph.nodes[edge[1]][key] is None graph.node_data(edge[0])[key] != graph.node_data(edge[1])[key] or graph.node_data(edge[0])[key] is None or graph.node_data(edge[1])[key] is None ): weight += value - # frm: Original Code: graph.edges[edge]["random_weight"] = weight graph.edge_data(edge_id)["random_weight"] = weight - # frm: TODO: Think about (and at least document) the fact that edge_data (and node_data) - # is shared by all partitions. So, as we process a chain of partitions, we are - # accessing the same underlying graph, and if we muck with edge_data and node_data - # then we are changing that data for all partitions. Stated differently, - # edge_data and node_data should be considered temporary and not persistent... - - # frm: TODO: CROCK: (for the moment) + # frm: TODO: Refactoring: Code: CROCK: (for the moment) # We need to create a minimum spanning tree but the way to do so # is different for NX and RX. I am sure that there is a more elegant # way to do this, and in any event, this dependence on NX vs RX @@ -255,9 +257,9 @@ graph.verify_graph_is_valid() - # frm: TODO: Remove NX / RX dependency - maybe move to graph.py + # frm: TODO: Refactoring: Remove NX / RX dependency - maybe move to graph.py - # frm: TODO: Think a bit about original_nx_node_ids + # frm: TODO: Documentation: Think a bit about original_nx_node_ids # # Original node_ids refer to the node_ids used when a graph was created. 
# This mostly means remembering the NX node_ids when you create an RX @@ -287,7 +289,6 @@ def get_weight(edge_data): return spanningGraph def uniform_spanning_tree( - # frm: Original code: graph: nx.Graph, choice: Callable = random.choice graph: Graph, choice: Callable = random.choice ) -> Graph: @@ -295,13 +296,13 @@ def uniform_spanning_tree( Builds a spanning tree chosen uniformly from the space of all spanning trees of the graph. Uses Wilson's algorithm. - :param graph: Networkx Graph - :type graph: nx.Graph + :param graph: Graph + :type graph: Graph :param choice: :func:`random.choice`. Defaults to :func:`random.choice`. :type choice: Callable, optional :returns: A spanning tree of the graph chosen uniformly at random. - :rtype: nx.Graph + :rtype: Graph """ """ @@ -342,9 +343,8 @@ def uniform_spanning_tree( # frm DONE: To support RX, I added an add_edge() method to Graph. - # frm: TODO: Remove dependency on NX below + # frm: TODO: Refactoring: Remove dependency on NX below - # frm: Original code: G = nx.Graph() nx_graph = nx.Graph() G = Graph.from_networkx(nx_graph) @@ -355,17 +355,21 @@ def uniform_spanning_tree( return G -# frm TODO +# frm TODO: Documentation: PopulatedGraph - state that this only exists in tree.py +# +# I think that this is only ever used inside this module (except) +# for testing. # -# I think that this is only ever used inside this module (except) -# for testing. +# Decide if this is intended to only ever be used inside tree.py (and for testing), +# and if so: 1) document that fact and 2) see if there is any Pythonic convention +# for a class that is intended to NOT be used externally (like a leading underscore) # class PopulatedGraph: """ A class representing a graph with population information. :ivar graph: The underlying graph structure. - :type graph: nx.Graph + :type graph: Graph :ivar subsets: A dictionary mapping nodes to their subsets. :type subsets: Dict :ivar population: A dictionary mapping nodes to their populations. @@ -381,14 +385,14 @@ class PopulatedGraph: def __init__( self, - graph: Graph, # frm: Original code: graph: nx.Graph, + graph: Graph, populations: Dict, ideal_pop: Union[float, int], epsilon: float, ) -> None: """ :param graph: The underlying graph structure. - :type graph: nx.Graph + :type graph: Graph :param populations: A dictionary mapping nodes to their populations. :type populations: Dict :param ideal_pop: The ideal population for each district. @@ -398,38 +402,45 @@ def __init__( :type epsilon: float """ self.graph = graph - self.subsets = {node: {node} for node in graph.nodes} + self.subsets = {node_id: {node_id} for node_id in graph.node_indices} self.population = populations.copy() self.tot_pop = sum(self.population.values()) self.ideal_pop = ideal_pop self.epsilon = epsilon - self._degrees = {node: graph.degree(node) for node in graph.nodes} + self._degrees = {node_id: graph.degree(node_id) for node_id in graph.node_indices} + + # frm: TODO: Refactor: _degrees ??? Why separately store the degree of every node? + # + # The _degrees data member above is used to define a method below called "degree()" + # What is odd is that the implementation of this degree() method could just as + # easily have been self.graph.degree(node_id). 
And in fact, every call on the + # new degree function could be replaced with just .graph.degree(node_id) + # + # So unless there is a big performance gain (or some other reason), I would be + # in favor of deleting the degree() method below and just using + # .graph.degree(node_id) on the assumption that both NX and RX + # have an efficient implementation of degree()... + - # frm TODO: Verify that this does the right thing for the new Graph object def __iter__(self): - return iter(self.graph) + # Note: in the pre RustworkX code, this was implemented as: + # + # return iter(self.graph) + # + # But RustworkX does not support __iter__() - it is not iterable. + # + # The way to do this in the new RustworkX based code is to use + # the node_indices() method which is accessed as a property as in: + # + # for node_id in graph.node_indices: + # ...do something with the node_id + # + raise NotImplementedError("Graph is not iterable - use graph.node_indices instead") def degree(self, node) -> int: return self._degrees[node] - # frm: only ever used inside this file - # But maybe this is intended to be used externally... - # - # In PR: Peter said: - # - # We do use this external to the class when calling find_balance_edge_cuts_contraction - - - # frm: ???: What the fat does this do? Start with what a population is. It - # appears to be indexed by node. Also, what is a subset? GRRRR... def contract_node(self, node, parent) -> None: - # frm: ???: TODO: This routine is only used once, so why have a separate - # routine - why not just include this code inline where - # the function is now called? It would be simpler to read - # inline than having to go find this definition. - # - # Perhaps it is of use externally, but that seems doubtful... - self.population[parent] += self.population[node] self.subsets[parent] |= self.subsets[node] self._degrees[parent] -= 1 @@ -453,19 +464,21 @@ def has_ideal_population(self, node, one_sided_cut: bool = False) -> bool: :rtype: bool """ - # frm: ???: TODO: this logic is repeated several times in this file. Consider - # refactoring the code so that the logic lives in exactly - # one place. + # frm: TODO: Refactoring: Create a helper function for this + # + # This logic is repeated several times in this file. Consider + # refactoring the code so that the logic lives in exactly + # one place. # - # When thinking about refactoring, consider whether it makes - # sense to toggle what this routine does by the "one_sided_cut" - # parameter. Why not have two separate routines with - # similar but distinguishing names. I need to be absolutely - # clear about what the two cases are all about, but my current - # hypothesis is that when one_sided_cut == False, we are looking - # for the edge which when cut produces two districts of - # approximately equal size - so a bisect rather than a find all - # meaning... + # When thinking about refactoring, consider whether it makes + # sense to toggle what this routine does by the "one_sided_cut" + # parameter. Why not have two separate routines with + # similar but distinguishing names. I need to be absolutely + # clear about what the two cases are all about, but my current + # hypothesis is that when one_sided_cut == False, we are looking + # for the edge which when cut produces two districts of + # approximately equal size - so a bisect rather than a find all + # meaning... 
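One possible shape for the shared helper that the Refactoring TODO above proposes; the name and the strict inequality are assumptions, not the current code:

def _within_epsilon(pop, ideal_pop, epsilon):
    # True when pop is within epsilon * ideal_pop of ideal_pop.
    return abs(pop - ideal_pop) < epsilon * ideal_pop

# With ideal_pop=1000 and epsilon=0.05, populations in (950, 1050) pass:
assert _within_epsilon(980, 1000, 0.05)
assert not _within_epsilon(1051, 1000, 0.05)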
if one_sided_cut: return ( @@ -481,7 +494,7 @@ def has_ideal_population(self, node, one_sided_cut: bool = False) -> bool: def __repr__(self) -> str: graph_info = ( - f"Graph(nodes={len(self.graph.nodes)}, edges={len(self.graph.edges)})" + f"Graph(nodes={len(self.graph.node_indices)}, edges={len(self.graph.edges)})" ) return ( f"{self.__class__.__name__}(" @@ -505,7 +518,9 @@ def __repr__(self) -> str: "The (frozen) subset of nodes on one side of the cut. Defaults to None." ) -# frm: TODO: Not sure how this is used, and so I do not know whether it needs +# frm: TODO: Documentation: Document what Cut objects are used for +# +# Not sure how this is used, and so I do not know whether it needs # to translate node_ids to the parent_node_id context. I am assuming not... # # Here is an example of how it is used (in test_tree.py): @@ -545,9 +560,8 @@ def find_balanced_edge_cuts_contraction( :rtype: List[Cut] """ - root = choice([x for x in h if h.degree(x) > 1]) + root = choice([node_id for node_id in h.graph.node_indices if h.degree(node_id) > 1]) # BFS predecessors for iteratively contracting leaves - # frm: Original code: pred = predecessors(h.graph, root) pred = h.graph.predecessors(root) cuts = [] @@ -567,7 +581,7 @@ def find_balanced_edge_cuts_contraction( # that does something similar (perhaps exactly the same). # Need to figure out why there are more than one way to do this... - leaves = deque(x for x in h if h.degree(x) == 1) + leaves = deque(node_id for node_id in h.graph.node_indices if h.degree(node_id) == 1) while len(leaves) > 0: leaf = leaves.popleft() if h.has_ideal_population(leaf, one_sided_cut=one_sided_cut): @@ -578,7 +592,6 @@ def find_balanced_edge_cuts_contraction( cuts.append( Cut( edge=e, - # frm: Original Code: weight=h.graph.edges[e].get("random_weight", random.random()), weight=h.graph.edge_data( h.graph.get_edge_id_from_edge(e) ).get("random_weight", random.random()), @@ -684,11 +697,11 @@ def _part_nodes(start, succ): merely starts at the root of the subtree (start) and goes down the subtree, adding each node to a set. - frm: ???: TODO: Rename this to be more descriptive - perhaps ] + frm: TODO: Documentation: Rename this to be more descriptive - perhaps ] something like: _nodes_in_subtree() or _nodes_for_cut() - frm: TODO: Add the above explanation for what a Cut is and how + frm: TODO: Documentation: Add the above explanation for what a Cut is and how we find them by converting the graph to a DAG and then looking for subtrees to a block header at the top of this file. It will give the reader some @@ -758,9 +771,7 @@ def find_balanced_edge_cuts_memoization( # frm: ???: Why does a root have to have degree > 1? I would think that any node would do... - root = choice([x for x in h if h.degree(x) > 1]) - # frm: Original code: pred = predecessors(h.graph, root) - # frm: Original code: succ = successors(h.graph, root) + root = choice([node_id for node_id in h.graph.node_indices if h.degree(node_id) > 1]) pred = h.graph.predecessors(root) succ = h.graph.successors(root) total_pop = h.tot_pop @@ -782,8 +793,6 @@ def find_balanced_edge_cuts_memoization( cuts.append( Cut( edge=e, - # frm: Original Code: weight=h.graph.edges[e].get("random_weight", wt), - # frm: edges vs. 
edge_ids: edge_ids are wanted here (integers) weight=h.graph.edge_data( h.graph.get_edge_id_from_edge(e) ).get("random_weight", wt), @@ -798,18 +807,16 @@ cuts.append( Cut( edge=e, - # frm: Original Code: weight=h.graph.edges[e].get("random_weight", wt), - # frm: edges vs. edge_ids: edge_ids are wanted here (integers) weight=h.graph.edge_data( h.graph.get_edge_id_from_edge(e) ).get("random_weight", wt), - subset=frozenset(set(h.graph.nodes) - _part_nodes(node, succ)), + subset=frozenset(set(h.graph.node_indices) - _part_nodes(node, succ)), ) ) return cuts - # frm: TODO: Refactor this code to make its two use cases clearer: + # frm: TODO: Refactoring: Refactor this code to make its two use cases clearer: # # One use case is bisecting the graph (one_sided_cut is False). The # other use case is to peel off one part (district) with the appropriate @@ -832,12 +839,10 @@ cuts.append( Cut( edge=e, - # frm: Original Code: weight=h.graph.edges[e].get("random_weight", wt), - # frm: edges vs. edge_ids: edge_ids are wanted here (integers) weight=h.graph.edge_data( h.graph.get_edge_id_from_edge(e) ).get("random_weight", wt), - subset=frozenset(set(h.graph.nodes) - _part_nodes(node, succ)), + subset=frozenset(set(h.graph.node_indices) - _part_nodes(node, succ)), ) ) return cuts @@ -907,9 +912,11 @@ def _max_weight_choice(cut_edge_list: List[Cut]) -> Cut: return max(cut_edge_list, key=lambda cut: cut.weight) -# frm: ???: Only ever used once... -# frm: ???: TODO: Figure out what this does. There is no NX/RX issue here, I just +# frm: TODO: Documentation: document what _power_set_sorted_by_size_then_sum() does +# +# Figure out what this does. There is no NX/RX issue here, I just # don't yet know what it does or why... +# Note that this is only ever used once... def _power_set_sorted_by_size_then_sum(d): power_set = [ s for i in range(1, len(d) + 1) for s in itertools.combinations(d.keys(), i) ] @@ -999,8 +1006,6 @@ # } # key: ( - # frm: original code: populated_graph.graph.nodes[cut.edge[0]].get(key), - # frm: original code: populated_graph.graph.nodes[cut.edge[1]].get(key), populated_graph.graph.node_data(cut.edge[0]).get(key), populated_graph.graph.node_data(cut.edge[1]).get(key), ) @@ -1026,7 +1031,7 @@ return _max_weight_choice(cut_edge_list) -# frm TODO: def bipartition_tree( +# frm TODO: Refactoring: def bipartition_tree( # # This might get complicated depending on what kinds of functions # are used as parameters. 
That is, do the functions used as parameters @@ -1057,12 +1062,12 @@ def _region_preferred_max_weight_choice( # def bipartition_tree( - subgraph_to_split: Graph, # frm: Original code: graph: nx.Graph, + subgraph_to_split: Graph, pop_col: str, pop_target: Union[int, float], epsilon: float, node_repeats: int = 1, - spanning_tree: Optional[Graph] = None, # frm: Original code: spanning_tree: Optional[nx.Graph] = None, + spanning_tree: Optional[Graph] = None, spanning_tree_fn: Callable = random_spanning_tree, region_surcharge: Optional[Dict] = None, balance_edge_fn: Callable = find_balanced_edge_cuts_memoization, @@ -1073,7 +1078,7 @@ def bipartition_tree( allow_pair_reselection: bool = False, cut_choice: Callable = _region_preferred_max_weight_choice, ) -> Set: - # frm: TODO: Change the names of ALL function formal parameters to end in "_fn" - to make it clear + # frm: TODO: Refactoring: Change the names of ALL function formal parameters to end in "_fn" - to make it clear # that the paraemter is a function. This will make it easier to do a global search # to find all function parameters - as well as just being good coding practice... """ @@ -1086,7 +1091,7 @@ def bipartition_tree( is ``epsilon * pop_target`` away from ``pop_target``. :param graph: The graph to partition. - :type graph: nx.Graph + :type graph: Graph :param pop_col: The node attribute holding the population of each node. :type pop_col: str :param pop_target: The target population for the returned subset of nodes. @@ -1099,7 +1104,7 @@ def bipartition_tree( :type node_repeats: int, optional :param spanning_tree: The spanning tree for the algorithm to use (used when the algorithm chooses a new root and for testing). - :type spanning_tree: Optional[nx.Graph], optional + :type spanning_tree: Optional[Graph], optional :param spanning_tree_fn: The random spanning tree algorithm to use if a spanning tree is not provided. Defaults to :func:`random_spanning_tree`. :type spanning_tree_fn: Callable, optional @@ -1162,26 +1167,38 @@ def bipartition_tree( if "one_sided_cut" in signature(balance_edge_fn).parameters: balance_edge_fn = partial(balance_edge_fn, one_sided_cut=one_sided_cut) - # frm: original code: populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices} - populations = {node_id: subgraph_to_split.node_data(node_id)[pop_col] for node_id in subgraph_to_split.node_indices} + # dict of node_id: population for the nodes in the subgraph + populations = { + node_id: subgraph_to_split.node_data(node_id)[pop_col] + for node_id in subgraph_to_split.node_indices + } + + # frm: TODO: Debugging: Remove debugging code + # print(" ") + # print(f"bipartition_tree(): Entering...") + # print(f"bipartition_tree(): balance_edge_fn is: {balance_edge_fn}") + # print(f"bipartition_tree(): spanning_tree_fn is: {spanning_tree_fn}") + # print(f"bipartition_tree(): populations in subgraph are: {populations}") possible_cuts: List[Cut] = [] if spanning_tree is None: - # frm TODO: Make sure spanning_tree_fn operates on new Graph object spanning_tree = spanning_tree_fn(subgraph_to_split) + # print(" ") + # print(f"bipartition_tree(): subgraph edges: {subgraph_to_split.edges}") + # print(f"bipartition_tree(): initial spanning_tree edges: {spanning_tree.edges}") + restarts = 0 attempts = 0 while max_attempts is None or attempts < max_attempts: if restarts == node_repeats: - # frm TODO: Make sure spanning_tree_fn operates on new Graph object - # frm: ???: Not sure what this if-stmt is for... 
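The if-stmt that the removed comment above puzzled over implements the restart policy in bipartition_tree(): after node_repeats failed attempts on one spanning tree, a fresh tree is drawn. A simplified, hedged sketch of that control flow (draw_tree and try_split are stand-ins, not real names):

def split_with_restarts(draw_tree, try_split, node_repeats, max_attempts):
    tree = draw_tree()
    restarts = attempts = 0
    while attempts < max_attempts:
        if restarts == node_repeats:
            tree = draw_tree()  # give up on this tree and draw another
            restarts = 0
        cuts = try_split(tree)
        if cuts:
            return cuts
        restarts += 1
        attempts += 1
    raise RuntimeError(f"Could not find a possible cut after {max_attempts} attempts.")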
spanning_tree = spanning_tree_fn(subgraph_to_split) + # print(f"bipartition_tree(): new spanning_tree edges: {spanning_tree.edges}") restarts = 0 h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon) - # frm: ???: TODO: Again - we should NOT be changing semantics based + # frm: TODO: Refactoring: Again - we should NOT be changing semantics based # on the names in signatures... # Better approach is to have all of the poosible paramters exist # in ALL of the versions of the cut_choice() functions and to @@ -1220,6 +1237,9 @@ def bipartition_tree( # This returns a list of Cut objects with attributes edge and subset possible_cuts = balance_edge_fn(h, choice=choice) + # frm: TODO: Debugging: Remove debugging code below + # print(f"bipartition_tree(): possible_cuts = {possible_cuts}") + # frm: RX Subgraph if len(possible_cuts) != 0: chosen_cut = None @@ -1230,6 +1250,7 @@ def bipartition_tree( translated_nodes = subgraph_to_split.translate_subgraph_node_ids_for_set_of_nodes( chosen_cut.subset ) + # print(f"bipartition_tree(): translated_nodes = {translated_nodes}") # frm: Not sure if it is important that the returned set be a frozenset... return frozenset(translated_nodes) @@ -1254,48 +1275,32 @@ def bipartition_tree( raise RuntimeError(f"Could not find a possible cut after {max_attempts} attempts.") - -# frm TODO: Note: Re: _bipartition_tree_random_all() -# -# There were a couple of interesting issues surrounding this routine in the original code -# related to subgraphs. The question was whether or not to translate HERE - -# frm: WTF: TODO: This function has a leading underscore indicating that it is a private -# function, but in fact it is used in tree_proposals.py... It also returns -# Cuts which I had hoped would be an internal data structure, but... -# frm: RX-TODO This is called in tree_proposals.py with a subgraph, so it needs to -# return translated Cut objects. However, it is also called internally in -# this code. I need to make sure that I do not translate the node_ids to the -# parent_node_ids twice. At present, they are converted in this file by the -# caller, but that won't work in tree_proposals.py, because there it is called -# with a subgraph, so it would be too late to try to do it in the caller. -# -# Two options: 1) Have this routine do the translation and then comment the -# crap out of the call in this file to make sure we do NOT translate them again, or -# 2) figure out a way to get this OUT of tree_proposals.py where it seems it should -# not be in the first place... -# def _bipartition_tree_random_all( - # frm: Note: Changed the name from "graph" to "subgraph_to_split" to remind any future readers - # of the code that the graph passed in is not the partition's graph, and - # that any node_ids passed back should be translated into parent_node_ids. - subgraph_to_split: Graph, # frm: Original code: graph: nx.Graph, + # + # Note: Complexity Alert... _bipartition_tree_random_all does NOT translate node_ids to parent + # + # Unlike many/most of the routines in this module, _bipartition_tree_random_all() does + # not translate node_ids into the IDs of the parent, because calls to it are not made + # on subgraphs. That is, it returns possible Cuts using the same node_ids as the parent. + # It is up to the caller to translate node_ids (if appropriate). 
+ # + graph_to_split: Graph, pop_col: str, pop_target: Union[int, float], epsilon: float, node_repeats: int = 1, repeat_until_valid: bool = True, - spanning_tree: Optional[Graph] = None, # frm: Original code: spanning_tree: Optional[nx.Graph] = None, + spanning_tree: Optional[Graph] = None, spanning_tree_fn: Callable = random_spanning_tree, balance_edge_fn: Callable = find_balanced_edge_cuts_memoization, choice: Callable = random.choice, max_attempts: Optional[int] = 100000, -) -> List[Tuple[Hashable, Hashable]]: # frm: TODO: Change this to be a set of node_ids (ints) +) -> List[Tuple[Hashable, Hashable]]: # frm: TODO: Documentation: Change this to be a set of node_ids (ints) """ Randomly bipartitions a tree into two subgraphs until a valid bipartition is found. :param graph: The input graph. - :type graph: nx.Graph + :type graph: Graph :param pop_col: The name of the column in the graph nodes that contains the population data. :type pop_col: str :param pop_target: The target population for each subgraph. @@ -1310,7 +1315,7 @@ def _bipartition_tree_random_all( :type repeat_until_valid: bool, optional :param spanning_tree: The spanning tree to use for bipartitioning. If None, a random spanning tree will be generated. Defaults to None. - :type spanning_tree: Optional[nx.Graph], optional + :type spanning_tree: Optional[Graph], optional :param spanning_tree_fn: The function to generate a spanning tree. Defaults to random_spanning_tree. :type spanning_tree_fn: Callable, optional @@ -1330,31 +1335,27 @@ def _bipartition_tree_random_all( attempts. """ - # frm: original code: populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices} + # dict of node_id: population for the nodes in the subgraph populations = { - node_id: subgraph_to_split.node_data(node_id)[pop_col] - for node_id in subgraph_to_split.node_indices + node_id: graph_to_split.node_data(node_id)[pop_col] + for node_id in graph_to_split.node_indices } possible_cuts = [] if spanning_tree is None: - # frm TODO: Make sure spanning_tree_fn works on new Graph object - spanning_tree = spanning_tree_fn(subgraph_to_split) + spanning_tree = spanning_tree_fn(graph_to_split) restarts = 0 attempts = 0 while max_attempts is None or attempts < max_attempts: if restarts == node_repeats: - # frm TODO: Make sure spanning_tree_fn works on new Graph object - spanning_tree = spanning_tree_fn(subgraph_to_split) + spanning_tree = spanning_tree_fn(graph_to_split) restarts = 0 h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon) possible_cuts = balance_edge_fn(h, choice=choice) - # frm: RX-TODO: Translate cuts into node_id context of the parent. if not (repeat_until_valid and len(possible_cuts) == 0): - # frm: TODO: Remove deubgging code: return possible_cuts restarts += 1 @@ -1379,13 +1380,13 @@ def _bipartition_tree_random_all( # revisit when we decide a general code cleanup is in order... # def bipartition_tree_random_with_num_cuts( - graph: Graph, # frm: Original code: graph: nx.Graph, + graph: Graph, pop_col: str, pop_target: Union[int, float], epsilon: float, node_repeats: int = 1, repeat_until_valid: bool = True, - spanning_tree: Optional[Graph] = None, # frm: Original code: spanning_tree: Optional[nx.Graph] = None, + spanning_tree: Optional[Graph] = None, spanning_tree_fn: Callable = random_spanning_tree, balance_edge_fn: Callable = find_balanced_edge_cuts_memoization, one_sided_cut: bool = False, @@ -1405,7 +1406,7 @@ def bipartition_tree_random_with_num_cuts( is ``epsilon * pop_target`` away from ``pop_target``. 
:param graph: The graph to partition. - :type graph: nx.Graph + :type graph: Graph :param pop_col: The node attribute holding the population of each node. :type pop_col: str :param pop_target: The target population for the returned subset of nodes. @@ -1423,7 +1424,7 @@ def bipartition_tree_random_with_num_cuts( :type repeat_until_valid: bool, optional :param spanning_tree: The spanning tree for the algorithm to use (used when the algorithm chooses a new root and for testing). Defaults to None. - :type spanning_tree: Optional[nx.Graph], optional + :type spanning_tree: Optional[Graph], optional :param spanning_tree_fn: The random spanning tree algorithm to use if a spanning tree is not provided. Defaults to :func:`random_spanning_tree`. :type spanning_tree_fn: Callable, optional @@ -1449,12 +1450,12 @@ def bipartition_tree_random_with_num_cuts( :rtype: Union[Set[Any], None] """ - # frm: ???: TODO: Again - semantics should not depend on signatures... + # frm: TODO: Refactoring: Again - semantics should not depend on signatures... if "one_sided_cut" in signature(balance_edge_fn).parameters: balance_edge_fn = partial(balance_edge_fn, one_sided_cut=True) possible_cuts = _bipartition_tree_random_all( - subgraph_to_split=graph, + graph_to_split=graph, pop_col=pop_col, pop_target=pop_target, epsilon=epsilon, @@ -1472,19 +1473,18 @@ def bipartition_tree_random_with_num_cuts( parent_nodes = graph.translate_subgraph_node_ids_for_set_of_nodes(chosen_cut.subset) return num_cuts, frozenset(parent_nodes) # frm: Not sure if important that it be frozenset else: - # frm: TODO: Grok when this returns None and why and what the calling code does and why... return None ####################### -# frm TODO: RX version NYI... +# frm TODO: Testing: Check to make sure there is a test for this... def bipartition_tree_random( - subgraph_to_split: Graph, # frm: Original code: graph: nx.Graph, + subgraph_to_split: Graph, pop_col: str, pop_target: Union[int, float], epsilon: float, node_repeats: int = 1, repeat_until_valid: bool = True, - spanning_tree: Optional[Graph] = None, # frm: Original code: spanning_tree: Optional[nx.Graph] = None, + spanning_tree: Optional[Graph] = None, spanning_tree_fn: Callable = random_spanning_tree, balance_edge_fn: Callable = find_balanced_edge_cuts_memoization, one_sided_cut: bool = False, @@ -1504,7 +1504,7 @@ def bipartition_tree_random( is ``epsilon * pop_target`` away from ``pop_target``. :param graph: The graph to partition. - :type graph: nx.Graph + :type graph: Graph :param pop_col: The node attribute holding the population of each node. :type pop_col: str :param pop_target: The target population for the returned subset of nodes. @@ -1522,7 +1522,7 @@ def bipartition_tree_random( :type repeat_until_valid: bool, optional :param spanning_tree: The spanning tree for the algorithm to use (used when the algorithm chooses a new root and for testing). Defaults to None. - :type spanning_tree: Optional[nx.Graph], optional + :type spanning_tree: Optional[Graph], optional :param spanning_tree_fn: The random spanning tree algorithm to use if a spanning tree is not provided. Defaults to :func:`random_spanning_tree`. :type spanning_tree_fn: Callable, optional @@ -1548,7 +1548,7 @@ def bipartition_tree_random( :rtype: Union[Set[Any], None] """ - # frm: ???: TODO: Again - semantics should not depend on signatures... + # frm: TODO: Refactoring: Again - semantics should not depend on signatures... 
# # This is odd - there are two balance_edge_functions defined in tree.py but # both of them have a formal parameter with the name "one_sided_cut", so this @@ -1565,7 +1565,7 @@ def bipartition_tree_random( balance_edge_fn = partial(balance_edge_fn, one_sided_cut=True) possible_cuts = _bipartition_tree_random_all( - subgraph_to_split=subgraph_to_split, + graph_to_split=subgraph_to_split, pop_col=pop_col, pop_target=pop_target, epsilon=epsilon, @@ -1584,11 +1584,10 @@ def bipartition_tree_random( # frm: used in this file and in tree_proposals.py # But maybe this is intended to be used externally... -# frm TODO: RX version NYI... # frm: Note that this routine is only used in recom() def epsilon_tree_bipartition( - subgraph_to_split: Graph, # frm: Original code: graph: nx.Graph, + subgraph_to_split: Graph, parts: Sequence, pop_target: Union[float, int], pop_col: str, @@ -1601,7 +1600,7 @@ def epsilon_tree_bipartition( two parts of population ``pop_target`` (within ``epsilon``). :param graph: The graph to partition into two :math:`\varepsilon`-balanced parts. - :type graph: nx.Graph + :type graph: Graph :param parts: Iterable of part (district) labels (like ``[0,1,2]`` or ``range(4)``). :type parts: Sequence :param pop_target: Target population for each part of the partition. @@ -1654,7 +1653,6 @@ def epsilon_tree_bipartition( # so why use negative index values - why not just use # parts[0] and parts[1]? flips[node] = parts[-2] - # frm: original code: part_pop += graph.nodes[node][pop_col] part_pop += subgraph_to_split.node_data(node)[pop_col] if not check_pop(part_pop): @@ -1666,7 +1664,6 @@ def epsilon_tree_bipartition( part_pop = 0 for node in remaining_nodes: flips[node] = parts[-1] - # frm: original code: part_pop += graph.nodes[node][pop_col] part_pop += subgraph_to_split.node_data(node)[pop_col] if not check_pop(part_pop): @@ -1677,16 +1674,12 @@ def epsilon_tree_bipartition( return translated_flips - # frm: TODO: I think I need to translate flips elsewhere - need to check... - - -# TODO: Move these recursive partition functions to their own module. They are not +# frm: TODO: Refactoring: Move these recursive partition functions to their own module. They are not # central to the operation of the recom function despite being tree methods. # frm: defined here but only used in partition.py # But maybe this is intended to be used externally... -# frm TODO: RX version NYI... def recursive_tree_part( - graph: Graph, # frm: Original code: graph: nx.Graph, + graph: Graph, parts: Sequence, pop_target: Union[float, int], pop_col: str, @@ -1700,7 +1693,7 @@ def recursive_tree_part( generate initial seed plans or to implement ReCom-like "merge walk" proposals. :param graph: The graph to partition into ``len(parts)`` :math:`\varepsilon`-balanced parts. - :type graph: nx.Graph + :type graph: Graph :param parts: Iterable of part (district) labels (like ``[0,1,2]`` or ``range(4)``). :type parts: Sequence :param pop_target: Target population for each part of the partition. 
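The signature-inspection dispatch criticized in the comments above boils down to this stdlib pattern; the two stand-in functions are hypothetical:

from functools import partial
from inspect import signature

def balance_fn_plain(h, choice):
    return []

def balance_fn_sided(h, choice, one_sided_cut=False):
    return []

for fn in (balance_fn_plain, balance_fn_sided):
    # Pre-bind one_sided_cut only when the function advertises it.
    if "one_sided_cut" in signature(fn).parameters:
        fn = partial(fn, one_sided_cut=True)
    print(fn)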
@@ -1751,7 +1744,7 @@ def recursive_tree_part( new_pop_target = (min_pop + max_pop) / 2 try: - nodes = method( + node_ids = method( graph.subgraph(remaining_nodes), pop_col=pop_col, pop_target=new_pop_target, @@ -1762,20 +1755,19 @@ def recursive_tree_part( except Exception: raise - if nodes is None: + if node_ids is None: raise BalanceError() part_pop = 0 - for node in nodes: + for node in node_ids: flips[node] = part - # frm: original code: part_pop += graph.nodes[node][pop_col] part_pop += graph.node_data(node)[pop_col] if not check_pop(part_pop): raise PopulationBalanceError() debt += part_pop - pop_target - remaining_nodes -= nodes + remaining_nodes -= node_ids # After making n-2 districts, we need to make sure that the last # two districts are both balanced. @@ -1783,7 +1775,7 @@ def recursive_tree_part( # frm: For the last call to "method", set one_sided_cut=False to # request that "method" create two equal sized districts # with the given population goal by bisecting the graph. - nodes = method( + node_ids = method( graph.subgraph(remaining_nodes), pop_col=pop_col, pop_target=pop_target, @@ -1792,32 +1784,30 @@ def recursive_tree_part( one_sided_cut=False, ) - if nodes is None: + if node_ids is None: raise BalanceError() part_pop = 0 - for node in nodes: - flips[node] = parts[-2] - # frm: this code fragment: graph.nodes[node][pop_col] is used + for node_id in node_ids: + flips[node_id] = parts[-2] + # frm: this code fragment: graph.node_data(node_id)[pop_col] is used # many times and is a candidate for being wrapped with # a function that has a meaningful name, such as perhaps: - # get_population_for_node(node, pop_col). + # get_population_for_node(node_id, pop_col). # This is an example of code-bloat from the perspective of # code gurus, but it really helps a new code reviewer understand # WTF is going on... - # frm: original code: part_pop += graph.nodes[node][pop_col] - part_pop += graph.node_data(node)[pop_col] + part_pop += graph.node_data(node_id)[pop_col] if not check_pop(part_pop): raise PopulationBalanceError() - remaining_nodes -= nodes + remaining_nodes -= node_ids # All of the remaining nodes go in the last part part_pop = 0 for node in remaining_nodes: flips[node] = parts[-1] - # frm: original code: part_pop += graph.nodes[node][pop_col] part_pop += graph.node_data(node)[pop_col] if not check_pop(part_pop): @@ -1826,9 +1816,8 @@ def recursive_tree_part( return flips # frm: only used in this file, so I changed the name to have a leading underscore -# frm TODO: RX version NYI... def _get_seed_chunks( - graph: Graph, # frm: Original code: graph: nx.Graph, + graph: Graph, num_chunks: int, num_dists: int, pop_target: Union[int, float], @@ -1842,7 +1831,7 @@ def _get_seed_chunks( balanced within new_epsilon <= ``epsilon`` of a balanced target population. :param graph: The graph - :type graph: nx.Graph + :type graph: Graph :param num_chunks: The number of chunks to partition the graph into :type num_chunks: int :param num_dists: The number of districts @@ -1865,14 +1854,14 @@ def _get_seed_chunks( :rtype: List[List[int]] """ - # frm: ??? TODO: Change the name of num_chunks_left to instead be num_districts_per_chunk. + # frm: TODO: Refactoring: Change the name of num_chunks_left to instead be num_districts_per_chunk. # frm: ???: It is not clear to me when num_chunks will not evenly divide num_dists. In # the only place where _get_seed_chunks() is called, it is inside an if-stmt # branch that validates that num_chunks evenly divides num_dists... 
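The readability wrapper suggested in the hunk above (graph.node_data(node_id)[pop_col] appears many times in this file) could look like this; the name and parameters come from that comment, with the graph argument added so the sketch is self-contained:

def get_population_for_node(graph, node_id, pop_col):
    # Thin wrapper over the attribute accessor used throughout this file.
    return graph.node_data(node_id)[pop_col]

# part_pop += graph.node_data(node_id)[pop_col]  would then read:
# part_pop += get_population_for_node(graph, node_id, pop_col)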
# num_chunks_left = num_dists // num_chunks - # frm: ???: TODO: Change the name of parts below to be something / anything else. Normally + # frm: TODO: Refactoring: Change the name of parts below to be something / anything else. Normally # parts refers to districts, but here is is just a way to keep track of # sets of nodes for chunks. Yes - they eventually become districts when # this code gets to the base cases, but I found it confusing at this @@ -1887,15 +1876,14 @@ def _get_seed_chunks( chunk_pop = 0 for node in graph.node_indices: - # frm: original code: chunk_pop += graph.nodes[node][pop_col] chunk_pop += graph.node_data(node)[pop_col] - # frm: TODO: See if there is a better way to structure this instead of a while True loop... + # frm: TODO: Refactoring: See if there is a better way to structure this instead of a while True loop... while True: epsilon = abs(epsilon) flips = {} - remaining_nodes = set(graph.nodes) + remaining_nodes = graph.node_indices # frm: ??? What is the distinction between num_chunks and num_districts? # I think that a chunk is typically a multiple of districts, so @@ -1961,7 +1949,6 @@ def _get_seed_chunks( part_pop = 0 # frm: ???: Compute population total for remaining nodes. for node in remaining_nodes: - # frm: original code: part_pop += graph.nodes[node][pop_col] part_pop += graph.node_data(node)[pop_col] # frm: ???: Compute what the population total would be for each district in chunk part_pop_as_dist = part_pop / num_chunks_left @@ -2029,7 +2016,7 @@ def get_max_prime_factor_less_than(n: int, ceil: int) -> Optional[int]: return largest_factor def _recursive_seed_part_inner( - graph: Graph, # frm: Original code: graph: nx.Graph, + graph: Graph, num_dists: int, pop_target: Union[float, int], pop_col: str, @@ -2044,7 +2031,7 @@ def _recursive_seed_part_inner( Returns a partition with ``num_dists`` districts balanced within ``epsilon`` of ``pop_target``. - frm: ???: TODO: Correct the above statement that this function returns a partition. + frm: TODO: Documentation: Correct the above statement that this function returns a partition. In fact, it returns a list of sets of nodes, which is conceptually equivalent to a partition, but is not a Partition object. Each set of nodes constitutes a district, but the district does not @@ -2075,7 +2062,7 @@ def _recursive_seed_part_inner( Even so, why this particular strategy for divide and conquer? :param graph: The underlying graph structure. - :type graph: nx.Graph + :type graph: Graph :param num_dists: number of districts to partition the graph into :type num_dists: int :param pop_target: Target population for each part of the partition @@ -2188,7 +2175,7 @@ def _recursive_seed_part_inner( # one that is evenly divided by num_chunks and then # do chunk stuff... elif num_chunks is None or num_dists % num_chunks != 0: - remaining_nodes = set(graph.nodes) + remaining_nodes = graph.node_indices nodes = method( graph.subgraph(remaining_nodes), pop_col=pop_col, @@ -2212,7 +2199,7 @@ def _recursive_seed_part_inner( ) # split graph into num_chunks chunks, and recurse into each chunk - # frm: TODO: Add documentation for why a subgraph in call below + # frm: TODO: Documentation: Add documentation for why a subgraph in call below elif num_dists % num_chunks == 0: chunks = _get_seed_chunks( graph.subgraph(graph.node_indices), # needs to be a subgraph @@ -2257,10 +2244,10 @@ def _recursive_seed_part_inner( -# frm TODO: ???: This routine is never called - not in this file and not in any other GerryChain file. 
+# frm TODO: Refactoring: This routine is never called - not in this file and not in any other GerryChain file. # Is it intended to be used by end-users? And if so, for what purpose? def recursive_seed_part( - graph: Graph, # frm: Original code: graph: nx.Graph, + graph: Graph, parts: Sequence, pop_target: Union[float, int], pop_col: str, @@ -2275,7 +2262,7 @@ ``pop_target`` by recursively splitting graph using _recursive_seed_part_inner. :param graph: The graph - :type graph: nx.Graph + :type graph: Graph :param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)`` :type parts: Sequence :param pop_target: Target population for each part of the partition diff --git a/gerrychain/updaters/compactness.py b/gerrychain/updaters/compactness.py index 418fe27f..5683acb7 100644 --- a/gerrychain/updaters/compactness.py +++ b/gerrychain/updaters/compactness.py @@ -17,22 +17,17 @@ def boundary_nodes(partition, alias: str = "boundary_nodes") -> Set: :rtype: Set """ - # frm: TODO: Figure out what is going on with the "alias" parameter. - # It is used to get the value from the parent if there is - # a parent, but it is NOT used when computing the result - # for the first partition. Seems like a logic bug... - # - # I think it is used as the attribute name on the partition for the - # data stored by an updater that uses this routine... + # Note that the "alias" parameter is used as the attribute name + # on the partition - using this "alias" you can retrieve + # the data stored by an updater that uses this routine... if partition.parent: return partition.parent[alias] else: result = { - node - for node in partition.graph.nodes - # frm: original code: if partition.graph.nodes[node]["boundary_node"] - if partition.graph.node_data(node)["boundary_node"] + node_id + for node_id in partition.graph.node_indices + if partition.graph.node_data(node_id)["boundary_node"] } return result @@ -103,7 +98,6 @@ def initialize_exterior_boundaries(partition) -> Dict[int, float]: boundaries = collections.defaultdict(lambda: 0) for node in graph_boundary: part = partition.assignment.mapping[node] - # frm: original code: boundaries[part] += partition.graph.nodes[node]["boundary_perim"] boundaries[part] += partition.graph.node_data(node)["boundary_perim"] return boundaries @@ -131,12 +125,10 @@ def exterior_boundaries(partition, previous: Set, inflow: Set, outflow: Set) -> """ graph_boundary = partition["boundary_nodes"] added_perimeter = sum( - # frm: original code: partition.graph.nodes[node]["boundary_perim"] partition.graph.node_data(node)["boundary_perim"] for node in inflow & graph_boundary ) removed_perimeter = sum( - # frm: original code: partition.graph.nodes[node]["boundary_perim"] partition.graph.node_data(node)["boundary_perim"] for node in outflow & graph_boundary ) @@ -153,21 +145,13 @@ def initialize_interior_boundaries(partition): :rtype: Dict[int, float] """ - # frm: RustworkX Note: + # RustworkX Note: # - # The old NX code did not distinguish between edges and edge_ids - they were one - # and the same. However, in RX an edge is a tuple and an edge_id is an integer. 
+ # The edges stored in partition["cut_edges_by_part"] are edges (tuples), so + # we need to get the edge_id for each edge in order to access the data for the edge. - # frm: Original Code: - # return { - # part: sum( - # partition.graph.edges[edge]["shared_perim"] - # for edge in partition["cut_edges_by_part"][part] - # ) - # } - # Get edge_ids for each edge (tuple) edge_ids_for_part = { part: [ @@ -177,11 +161,6 @@ def initialize_interior_boundaries(partition): for part in partition.parts } - edge_to_edge_id_map = [ - (edge, partition.graph.get_edge_id_from_edge(edge)) - for edge in partition.graph.edges - ] - # Compute length of the shared perimeter of each part shared_perimeters_for_part = { part: sum( @@ -220,15 +199,11 @@ def interior_boundaries( """ added_perimeter = sum( - # frm: Original Code: partition.graph.edges[edge]["shared_perim"] for edge in new_edges - # frm: edges vs edge_ids: edge_ids are wanted here (integers) partition.graph.edge_data( partition.graph.get_edge_id_from_edge(edge) )["shared_perim"] for edge in new_edges ) removed_perimeter = sum( - # frm: Original Code: partition.graph.edges[edge]["shared_perim"] for edge in old_edges - # frm: edges vs edge_ids: edge_ids are wanted here (integers) partition.graph.edge_data( partition.graph.get_edge_id_from_edge(edge) )["shared_perim"] for edge in old_edges @@ -253,7 +228,7 @@ def perimeter_of_part(partition, part: int) -> float: """ Totals up the perimeter of the part in the partition. - .. Warning:: frm: TODO: Add code to enforce this warning... + .. Warning:: frm: TODO: Refactoring: Add code to enforce this warning... Requires that 'boundary_perim' be a node attribute, 'shared_perim' be an edge attribute, 'cut_edges' be an updater, and 'exterior_boundaries' be an updater. diff --git a/gerrychain/updaters/county_splits.py b/gerrychain/updaters/county_splits.py index e76ae78e..91350ec5 100644 --- a/gerrychain/updaters/county_splits.py +++ b/gerrychain/updaters/county_splits.py @@ -82,18 +82,18 @@ def compute_county_splits( county_dict = dict() - for node in partition.graph.node_indices: + for node_id in partition.graph.node_indices: # First figure get current status of the county's information - county = partition.graph.lookup(node, county_field) + county = partition.graph.node_data(node_id)[county_field] if county in county_dict: split, nodes, seen = county_dict[county] else: split, nodes, seen = CountySplit.NOT_SPLIT, [], set() - # Now update "nodes" and "seen" with this node and the part (district) from partition's assignment. - nodes.append(node) - seen.update(set([partition.assignment.mapping[node]])) + # Now update "nodes" and "seen" with this node_id and the part (district) from partition's assignment. + nodes.append(node_id) + seen.update(set([partition.assignment.mapping[node_id]])) # lastly, if we have "seen" more than one part (district), then the county is split across parts. 
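The edge-vs-edge_id translation described in the compactness hunk above, extracted into one hedged sketch (the accessors are the ones used throughout this diff; the function name is invented):

def shared_perimeter(partition, part):
    # Each cut edge is a (u, v) tuple; in RX, attribute access needs the
    # integer edge_id, hence the get_edge_id_from_edge() translation.
    return sum(
        partition.graph.edge_data(
            partition.graph.get_edge_id_from_edge(edge)
        )["shared_perim"]
        for edge in partition["cut_edges_by_part"][part]
    )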
if len(seen) > 1: @@ -108,7 +108,7 @@ parent = partition.parent for county, county_info in parent[partition_field].items(): - seen = set(partition.assignment.mapping[node] for node in county_info.nodes) + seen = set(partition.assignment.mapping[node_id] for node_id in county_info.nodes) split = CountySplit.NOT_SPLIT @@ -151,21 +151,16 @@ def _get_splits(partition): def total_reg_splits(partition, reg_attr): """Returns the total number of times that reg_attr is split in the partition.""" all_region_names = set( - # frm: original code: partition.graph.nodes[node][reg_attr] for node in partition.graph.nodes - partition.graph.node_data(node)[reg_attr] for node in partition.graph.nodes + partition.graph.node_data(node_id)[reg_attr] for node_id in partition.graph.node_indices ) split = {name: 0 for name in all_region_names} # Require that the cut_edges updater is attached to the partition for node1, node2 in partition["cut_edges"]: if ( partition.assignment[node1] != partition.assignment[node2] - # frm: original code: and partition.graph.nodes[node1][reg_attr] - # frm: original code: == partition.graph.nodes[node2][reg_attr] and partition.graph.node_data(node1)[reg_attr] == partition.graph.node_data(node2)[reg_attr] ): - # frm: original code: split[partition.graph.nodes[node1][reg_attr]] += 1 - # frm: original code: split[partition.graph.nodes[node2][reg_attr]] += 1 split[partition.graph.node_data(node1)[reg_attr]] += 1 split[partition.graph.node_data(node2)[reg_attr]] += 1 diff --git a/gerrychain/updaters/cut_edges.py b/gerrychain/updaters/cut_edges.py index 8cc85c4e..e7852df9 100644 --- a/gerrychain/updaters/cut_edges.py +++ b/gerrychain/updaters/cut_edges.py @@ -60,7 +60,7 @@ def initialize_cut_edges(partition): :param partition: A partition of a Graph :type partition: :class:`~gerrychain.partition.Partition` - frm: TODO: This description should be updated. Cut_edges are edges that touch + frm: TODO: Documentation: This description should be updated. Cut_edges are edges that touch two different parts (districts). They are the internal boundaries between parts (districts). This routine finds all of the cut_edges in the graph and then creates a dict that stores all of the cut_edges @@ -92,7 +92,7 @@ def cut_edges_by_part( partition, previous: Set[Tuple], new_edges: Set[Tuple], old_edges: Set[Tuple] ) -> Set[Tuple]: # - # frm TODO: Update / expand the documentation for this routine. + # frm TODO: Documentation: Update / expand the documentation for this routine. # # This only operates on cut-edges and not on all of the # edges in a partition. A "cut-edge" is an edge that spans two districts. diff --git a/gerrychain/updaters/election.py b/gerrychain/updaters/election.py index b29780c3..ab9efc66 100644 --- a/gerrychain/updaters/election.py +++ b/gerrychain/updaters/election.py @@ -117,28 +117,35 @@ def __init__( else: raise TypeError("Election expects party_names_to_node_attribute_names to be a dict or list") - # A DataTally can be created with a first parameter that is either a string - # or a dict. If a string, then the DataTally will interpret that string as - # the name of the node's attribute value that stores the data to be summed. - # However, if the first parameter is a node, then the DataTally will just - # access the value in the dict for a given node to get the data to be - # summed. 
The string approach makes it easy/convenient to sum data that - # is already stored as attribute values, while the dict approach makes - # it possible to sum computed values that are not stored as node attributes. - # - # However, after converting to using RustworkX for Graph objects in - # partitions, it no longer makes sense to use an explicit dict associating - # node_ids with data values for an election, because the node_ids given - # would need to be "original" node_ids derived from the NX-based graph - # that existed before creating the first partition, and those node_ids - # are useless once we convert the NX-based graph to RustworkX. - # - # So we disallow using a dict as a parameter to the DataTally below + # frm: TODO: Documentation: Migration: Using node_ids to vote tally maps... + # + # A DataTally used to support a first parameter that was either a string + # or a dict. # + # The idea was that in most cases, the values to be tallied would be present + # as the values of attributes associated with nodes, so it made sense to just + # provide the name of the attribute (a string) to identify what to tally. + # + # However, the code also supported providing an explicit mapping from node_id + # to the value to be tallied (a dict). This was useful for testing because + # it allowed for tallying values without having to implement an updater that + # would be based on a node's attribute. It provided a way to map values that + # were not part of the graph to vote totals. + # + # The problem was that when we started using RX for the embedded graph for + # partitions, the node_ids were no longer the same as the ones the user + # specified when creating the (NX) graph. This complicated the logic of + # having an explicit mapping from node_id to a value to be tallied - to + # make this work the code would have needed to translate the node_ids into + # the internal RX node_ids. + # + # The decision was made (Fred and Peter) that this extra complexity was not + # worth the trouble, so we now disallow passing in an explicit mapping (dict). + # for party in self.parties: if isinstance(self.party_names_to_node_attribute_names[party], dict): - raise Exception("Election: Using explicit node_id to vote totals maps is no longer permitted") + raise Exception("Election: Using a map from node_id to vote totals is no longer permitted") self.tallies = { party: DataTally(self.party_names_to_node_attribute_names[party], party) @@ -216,7 +223,7 @@ def get_previous_values(self, partition) -> Dict[str, Dict[int, float]]: return previous_totals_for_party -# frm: TODO: This routine, get_percents(), is only ever used inside ElectionResults. +# frm: TODO: Refactoring: This routine, get_percents(), is only ever used inside ElectionResults. # # Why is it not defined as an internal function inside ElectionResults? # diff --git a/gerrychain/updaters/flows.py b/gerrychain/updaters/flows.py index fe48fb08..0dd3f9f1 100644 --- a/gerrychain/updaters/flows.py +++ b/gerrychain/updaters/flows.py @@ -2,7 +2,7 @@ import functools from typing import Dict, Set, Tuple, Callable -# frm: TODO: This file needs documentation / comments!!! +# frm: TODO: Documentation: This file needs documentation / comments!!! # # Peter agrees... 
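# frm: Example: a minimal sketch of the "flows" idea, as a start on that
# documentation. The names below (sketch_flows, old_assignment, flips) are
# illustrative only, not this module's actual signatures:
#
#     from collections import defaultdict
#
#     def sketch_flows(old_assignment, flips):
#         # Each flipped node flows "out" of its old part and "in" to its
#         # new part. A flip that maps a node to the part it is already in
#         # changes nothing, which is why the real code tests
#         # source != target before recording a flow.
#         flows = defaultdict(lambda: {"in": set(), "out": set()})
#         for node, target in flips.items():
#             source = old_assignment[node]
#             if source != target:
#                 flows[target]["in"].add(node)
#                 flows[source]["out"].add(node)
#         return flows
#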
@@ -40,7 +40,7 @@ def flows_from_changes(old_partition, new_partition) -> Dict: :rtype: Dict """ - # frm: TODO: Grok why there is a test for: source != target + # frm: TODO: Code: ???: Grok why there is a test for: source != target # # It would seem to me that it would be a logic bug if there # was a "flip" that did not in fact change the partition mapping... diff --git a/gerrychain/updaters/locality_split_scores.py b/gerrychain/updaters/locality_split_scores.py index dc56f746..1a211ed7 100644 --- a/gerrychain/updaters/locality_split_scores.py +++ b/gerrychain/updaters/locality_split_scores.py @@ -1,13 +1,13 @@ # Imports from collections import defaultdict, Counter -# frm TODO: Remove dependence on NetworkX. +# frm: TODO: Refactoring: Remove dependence on NetworkX. # The only use is: # pieces += nx.number_connected_components(subgraph) import networkx as nx import math from typing import List -# frm: TODO: Do performance testing and improve performance of these routines. +# frm: TODO: Performance: Do performance testing and improve performance of these routines. # # Peter made the comment in a PR that we should make this code more efficient: # @@ -144,7 +144,7 @@ def __init__( def __call__(self, partition): - # frm: TODO: LocalitySplits: Figure out how this is intended to be used... + # frm: TODO: Refactoring: LocalitySplits: Figure out how this is intended to be used... # # Not quite sure why it is better to have a "__call__()" method instead of a # get_scores(self) method, but whatever... @@ -160,9 +160,6 @@ def __call__(self, partition): # if self.localities == []: - # frm: Original code: - # self.localitydict = dict(partition.graph.nodes(data=self.col_id)) - # self.localitydict = {} for node_id in partition.graph.node_indices: self.localitydict[node_id] = partition.graph.node_data(node_id)[self.col_id] @@ -185,17 +182,16 @@ def __call__(self, partition): allowed_pieces = {} totpop = 0 - for node in partition.graph.node_indices: - # frm: TODO: Once you have a partition, you cannot change the total population + for node_id in partition.graph.node_indices: + # frm: TODO: Refactoring: Once you have a partition, you cannot change the total population # in the Partition, so why don't we cache the total population as # a data member in Partition? # # Peter agreed that this would be a good thing to do - # frm: original code: totpop += partition.graph.nodes[node][self.pop_col] - totpop += partition.graph.node_data(node)[self.pop_col] + totpop += partition.graph.node_data(node_id)[self.pop_col] - # frm: TODO: Ditto with num_districts - isn't this a constant once you create a Partition? + # frm: TODO: Refactoring: Ditto with num_districts - isn't this a constant once you create a Partition? # # Peter agreed that this would be a good thing to do. @@ -203,33 +199,9 @@ def __call__(self, partition): # Compute the total population for each locality and then the number of "allowed pieces" for loc in self.localities: - # frm: TODO: The code below just calculates the total population for a set of nodes. + # frm: TODO: Refactoring: The code below just calculates the total population for a set of nodes. # This sounds like a good candidate for a utility function. See if this # logic is repeated elsewhere... - - # frm: I changed the original code for a couple of reasons: - # - # * There were NX depedencies in the original code. - # partition.graph.nodes(data=True) - # * Creating a subgraph just to get a subset of nodes seemed unnecessary - # and probably expensive.
- # * I found the code dense and it took me too long to figure out what it did. - - # frm: Original Code: - # - # sg = partition.graph.subgraph( - # for n, v in partition.graph.nodes(data=True) - # if v[self.col_id] == loc - # ) - # - # pop = 0 - # for n in sg.nodes(): - # # frm: TODO: I think this needs to change to work for RX... - # pop += sg.nodes[n][self.pop_col] - # - # allowed_pieces[loc] = math.ceil(pop / (totpop / num_districts)) - - # frm: new version of this code that is less clever... # Compute the population associated with each locality the_graph = partition.graph @@ -242,7 +214,7 @@ def __call__(self, partition): else: locality_population[locality_name] += locality_pop - # frm: TODO: Peter commented (in PR) that this is another thing that + # frm: TODO: Refactoring: Peter commented (in PR) that this is another thing that # could be cached so we didn't recompute it over and over... ideal_population_per_district = totpop / num_districts @@ -309,7 +281,6 @@ def num_pieces(self, partition) -> int: locality_intersections = {} for n in partition.graph.node_indices: - # frm: original code: locality = partition.graph.nodes[n][self.col_id] locality = partition.graph.node_data(n)[self.col_id] if locality not in locality_intersections: locality_intersections[locality] = set( @@ -325,14 +296,10 @@ def num_pieces(self, partition) -> int: [ x for x in partition.parts[d] - # frm: original code: if partition.graph.nodes[x][self.col_id] == locality if partition.graph.node_data(x)[self.col_id] == locality ] ) - # frm: Original Code: - # - # pieces += nx.number_connected_components(subgraph) pieces += subgraph.num_connected_components() return pieces @@ -466,9 +433,6 @@ def symmetric_entropy(self, partition) -> float: # IN PROGRESS vtds = district_dict[district] locality_pop = {k: 0 for k in self.localities} for vtd in vtds: - # frm: original code: locality_pop[self.localitydict[vtd]] += partition.graph.nodes[vtd][ - # frm: original code: self.pop_col - # frm: original code: ] locality_pop[self.localitydict[vtd]] += partition.graph.node_data(vtd)[ self.pop_col ] diff --git a/gerrychain/updaters/spanning_trees.py b/gerrychain/updaters/spanning_trees.py index 2f6cddce..9150151f 100644 --- a/gerrychain/updaters/spanning_trees.py +++ b/gerrychain/updaters/spanning_trees.py @@ -24,7 +24,6 @@ def _num_spanning_trees_in_district(partition, district: int) -> int: :rtype: int """ graph = partition.subgraphs[district] - # frm: Original Code: laplacian = networkx.laplacian_matrix(graph) laplacian = graph.laplacian_matrix() L = numpy.delete(numpy.delete(laplacian.todense(), 0, 0), 1, 1) return math.exp(numpy.linalg.slogdet(L)[1]) diff --git a/gerrychain/updaters/tally.py b/gerrychain/updaters/tally.py index 082bd675..91c05407 100644 --- a/gerrychain/updaters/tally.py +++ b/gerrychain/updaters/tally.py @@ -19,6 +19,16 @@ class DataTally: :type alias: str """ + # frm: TODO: Code: Check to see if DataTally is used for data that is NOT an attribute of a node + # + # The comment above indicates that you can use a DataTally for data that is not stored + # as an attribute of a node. Check to see if it is ever actually used that way. If so, + # then update the documentation above to state the use cases for adding up data that is + # NOT stored as a node attribute... + # + # It appears that some tests use the ability to specify tallies that do not involve a + # node attribute, but it is not clear if any "real" code does that...
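+    # frm: Example: a minimal usage sketch (the attribute name "TOTPOP" below is
+    # illustrative - any numeric node attribute works the same way):
+    #
+    #     tally = DataTally("TOTPOP", alias="population")
+    #     partition = Partition(graph, assignment="district",
+    #                           updaters={"population": tally})
+    #     partition["population"]   # {part: sum of "TOTPOP" over its nodes}
+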
+ __slots__ = ["data", "alias", "_call"] def __init__(self, data: Union[Dict, pandas.Series, str], alias: str) -> None: @@ -43,30 +53,23 @@ def initialize_tally(partition): # If not, then assume that the "data" passed in is already of the # form: {node_id: data_value} - # frm: TODO: Verify that if the "data" passed in is not a string that it - # is of the form: {node_id, data_value} - - if isinstance(self.data, str): - # frm: Original Code: - # nodes = partition.graph.nodes - # attribute = self.data - # self.data = {node: nodes[node][attribute] for node in nodes} + # if the "data" passed in was a string, then replace its value with + # a dict of {node_id: attribute_value of the node} graph = partition.graph node_ids = partition.graph.node_indices attribute = self.data self.data = {node_id: graph.node_data(node_id)[attribute] for node_id in node_ids} - # frm: TODO: Should probably check that the data for each node is numerical, since we - # are going to sum it below... - tally = collections.defaultdict(int) for node_id, part in partition.assignment.items(): add = self.data[node_id] - # frm: TODO: Should I also test that the "add" variable is a number or something - # that can be added? + # Note: math.isnan() will raise an exception if the value passed in is not + # numeric, so there is no need to do another check to ensure that the value + # is numeric - that test is implicit in math.isnan() + # if math.isnan(add): warnings.warn( "ignoring nan encountered at node_id '{}' for attribute '{}'".format( @@ -191,7 +194,7 @@ def _update_tally(self, partition): return new_tally def _get_tally_from_node(self, partition, node): - return sum(partition.graph.lookup(node, field) for field in self.fields) + return sum(partition.graph.node_data(node)[field] for field in self.fields) def compute_out_flow(graph, fields: Union[str, List[str]], flow: Dict) -> int: @@ -209,7 +212,7 @@ def compute_out_flow(graph, fields: Union[str, List[str]], flow: Dict) -> int: :returns: The sum of the "field" attribute of nodes in the "out" set of the flow. :rtype: int """ - return sum(graph.lookup(node, field) for node in flow["out"] for field in fields) + return sum(graph.node_data(node)[field] for node in flow["out"] for field in fields) def compute_in_flow(graph, fields: Union[str, List[str]], flow: Dict) -> int: @@ -227,4 +230,4 @@ def compute_in_flow(graph, fields: Union[str, List[str]], flow: Dict) -> int: :returns: The sum of the "field" attribute of nodes in the "in" set of the flow. :rtype: int """ - return sum(graph.lookup(node, field) for node in flow["in"] for field in fields) + return sum(graph.node_data(node)[field] for node in flow["in"] for field in fields) diff --git a/tests/README.txt b/tests/README.txt new file mode 100644 index 00000000..b2a0024e --- /dev/null +++ b/tests/README.txt @@ -0,0 +1,6 @@ +This folder contains tests (and subfolders that also contain tests). + +As a convention (at least for now - October 2025), tests that Fred +Mueller adds will be named test_frm_... Eventually the names should +be changed to have the "_frm" deleted, but for now it will help +identify big changes in testing diff --git a/tests/_foo/do_laplacian.py b/tests/_foo/do_laplacian.py index 40fb6601..110a15d6 100644 --- a/tests/_foo/do_laplacian.py +++ b/tests/_foo/do_laplacian.py @@ -19,28 +19,30 @@ # 3. 
Calculate the Laplacian matrix rx_laplacian_matrix = degree_matrix - adj_matrix -print("RX Adjacency Matrix:") -print(adj_matrix) +# frm: TODO: Debugging: Remove Debugging Code -print("\nRX Degree Matrix:") -print(degree_matrix) +# print("RX Adjacency Matrix:") +# print(adj_matrix) -print("\nRX Laplacian Matrix:") -print(rx_laplacian_matrix) +# print("\nRX Degree Matrix:") +# print(degree_matrix) -print("type of RX laplacian_matrix is: ", type(rx_laplacian_matrix)) +# print("\nRX Laplacian Matrix:") +# print(rx_laplacian_matrix) + +# print("type of RX laplacian_matrix is: ", type(rx_laplacian_matrix)) # Create an NX graph (replace with your graph data) nx_graph = nx.Graph([(0, 1), (0, 2), (1, 2), (2, 3)]) nx_laplacian_matrix = nx.laplacian_matrix(nx_graph) -print("\nNX Laplacian Matrix:") -print(nx_laplacian_matrix) +# print("\nNX Laplacian Matrix:") +# print(nx_laplacian_matrix) -print("type of NX laplacian_matrix is: ", type(nx_laplacian_matrix)) +# print("type of NX laplacian_matrix is: ", type(nx_laplacian_matrix)) gc_nx_graph = Graph.from_nx_graph(nx_graph) gc_rx_graph = Graph.from_rx_graph(rx_graph) -print("\ngc_laplacian(nx_graph) is: ", gctree.gc_laplacian_matrix(gc_nx_graph)) -print("\ngc_laplacian(rx_graph) is: ", gctree.gc_laplacian_matrix(gc_rx_graph)) +# print("\ngc_laplacian(nx_graph) is: ", gctree.gc_laplacian_matrix(gc_nx_graph)) +# print("\ngc_laplacian(rx_graph) is: ", gctree.gc_laplacian_matrix(gc_rx_graph)) diff --git a/tests/_perf_tests/perf_test.py b/tests/_perf_tests/perf_test.py new file mode 100644 index 00000000..40a2a8b4 --- /dev/null +++ b/tests/_perf_tests/perf_test.py @@ -0,0 +1,63 @@ +# Code copied from the GerryChain User Guide / Tutorial: + +import matplotlib.pyplot as plt +from gerrychain import (Partition, Graph, MarkovChain, + updaters, constraints, accept) +from gerrychain.proposals import recom +from gerrychain.constraints import contiguous +from functools import partial +import pandas + +import cProfile + +# Set the random seed so that the results are reproducible! +import random + +def main(): + + random.seed(2024) + graph = Graph.from_json("./gerrymandria.json") + + my_updaters = { + "population": updaters.Tally("TOTPOP"), + "cut_edges": updaters.cut_edges + } + + initial_partition = Partition( + graph, + assignment="district", + updaters=my_updaters + ) + + # This should be 8 since each district has 1 person in it. + # Note that the key "population" corresponds to the population updater + # that we defined above and not with the population column in the json file. 
+ ideal_population = sum(initial_partition["population"].values()) / len(initial_partition) + + proposal = partial( + recom, + pop_col="TOTPOP", + pop_target=ideal_population, + epsilon=0.01, + node_repeats=2 + ) + + recom_chain = MarkovChain( + proposal=proposal, + constraints=[contiguous], + accept=accept.always_accept, + initial_state=initial_partition, + total_steps=40 + ) + + assignments = list(recom_chain) + + assignment_list = [] + + for i, item in enumerate(recom_chain): + print(f"Finished step {i+1}/{len(recom_chain)}", end="\r") + assignment_list.append(item.assignment) + +if __name__ == "__main__": + cProfile.run('main()', sort='tottime') + diff --git a/tests/_perf_tests/perf_test2.py b/tests/_perf_tests/perf_test2.py new file mode 100644 index 00000000..7d2e78ae --- /dev/null +++ b/tests/_perf_tests/perf_test2.py @@ -0,0 +1,87 @@ + +import matplotlib.pyplot as plt +from gerrychain import (GeographicPartition, Partition, Graph, MarkovChain, + proposals, updaters, constraints, accept, Election) +from gerrychain.proposals import recom +from functools import partial +import pandas +import gerrychain + +import sys +import cProfile + +def main(): + + graph = Graph.from_json("./PA_VTDs.json") + + elections = [ + Election("SEN10", {"Democratic": "SEN10D", "Republican": "SEN10R"}), + Election("SEN12", {"Democratic": "USS12D", "Republican": "USS12R"}), + Election("SEN16", {"Democratic": "T16SEND", "Republican": "T16SENR"}), + Election("PRES12", {"Democratic": "PRES12D", "Republican": "PRES12R"}), + Election("PRES16", {"Democratic": "T16PRESD", "Republican": "T16PRESR"}) + ] + + # Population updater, for computing how close to equality the district + # populations are. "TOTPOP" is the population column from our shapefile. + my_updaters = {"population": updaters.Tally("TOT_POP", alias="population")} + + # Election updaters, for computing election results using the vote totals + # from our shapefile. + election_updaters = {election.name: election for election in elections} + my_updaters.update(election_updaters) + + initial_partition = GeographicPartition( + graph, + assignment="2011_PLA_1", # This identifies the district plan in 2011 + updaters=my_updaters + ) + + # The ReCom proposal needs to know the ideal population for the districts so that + # we can improve speed by bailing early on unbalanced partitions. + + ideal_population = sum(initial_partition["population"].values()) / len(initial_partition) + + # We use functools.partial to bind the extra parameters (pop_col, pop_target, epsilon, node_repeats) + # of the recom proposal. 
+ proposal = partial( + recom, + pop_col="TOT_POP", + pop_target=ideal_population, + epsilon=0.02, + node_repeats=2 + ) + + def cut_edges_length(p): + return len(p["cut_edges"]) + + compactness_bound = constraints.UpperBound( + cut_edges_length, + 2*len(initial_partition["cut_edges"]) + ) + + pop_constraint = constraints.within_percent_of_ideal_population(initial_partition, 0.02) + + print("About to call MarkovChain", file=sys.stderr) + + chain = MarkovChain( + proposal=proposal, + constraints=[ + pop_constraint, + compactness_bound + ], + accept=accept.always_accept, + initial_state=initial_partition, + total_steps=1000 + ) + + print("Done with calling MarkovChain", file=sys.stderr) + + print("About to get all assignments from the chain", file=sys.stderr) + assignments = list(chain) + print("Done getting all assignments from the chain", file=sys.stderr) + + +if __name__ == "__main__": + cProfile.run('main()', sort='tottime') + diff --git a/tests/conftest.py b/tests/conftest.py index e23125f8..e5797be9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -123,13 +123,12 @@ def factory(columns): return factory -# frm: TODO: This routine is only ever used immediately above in def factory(columns). +# frm: TODO: Refactoring: This routine is only ever used immediately above in def factory(columns). # Is it part of the external API? If not, then it should be moved inside # the graph_with_random_data_factory() routine def attach_random_data(graph, columns): for node in graph.nodes: for col in columns: - # frm: Original code: graph.nodes[node][col] = random.randint(1, 1000) graph.node_data(node)[col] = random.randint(1, 1000) diff --git a/tests/constraints/test_contiguity.py b/tests/constraints/test_contiguity.py index b5f7acde..87ade1d7 100644 --- a/tests/constraints/test_contiguity.py +++ b/tests/constraints/test_contiguity.py @@ -11,14 +11,6 @@ def test_contiguous_components(graph): assert len(components[1]) == 2 assert len(components[2]) == 1 - # frm: Original Code: - # - # assert set(frozenset(g.nodes) for g in components[1]) == { - # frozenset([0, 1, 2]), - # frozenset([6, 7, 8]), - # } - # assert set(components[2][0].nodes) == {3, 4, 5} - # Confirm that the appropriate connected subgraphs were found. Note that we need # to compare against the original node_ids, since RX node_ids change every time # you create a subgraph. 
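    # frm: Example: a sketch of that comparison idiom, assuming the
    # original_nx_node_ids_for_set() helper available on RX-based Graph
    # objects elsewhere in this test suite:
    #
    #     found = {
    #         frozenset(g.original_nx_node_ids_for_set(g.node_indices))
    #         for g in components[1]
    #     }
    #     assert found == {frozenset([0, 1, 2]), frozenset([6, 7, 8])}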
diff --git a/tests/constraints/test_validity.py b/tests/constraints/test_validity.py index 1c915683..c0817e6c 100644 --- a/tests/constraints/test_validity.py +++ b/tests/constraints/test_validity.py @@ -5,7 +5,7 @@ import pytest from gerrychain.constraints import (SelfConfiguringLowerBound, Validator, - contiguous, contiguous_bfs, + contiguous, districts_within_tolerance, no_vanishing_districts, single_flip_contiguous, @@ -50,16 +50,14 @@ def discontiguous_partition(discontiguous_partition_with_flips): def test_contiguous_with_contiguity_no_flips_is_true(contiguous_partition): assert contiguous(contiguous_partition) assert single_flip_contiguous(contiguous_partition) - assert contiguous_bfs(contiguous_partition) def test_contiguous_with_contiguity_flips_is_true(contiguous_partition_with_flips): contiguous_partition, test_flips = contiguous_partition_with_flips - # frm: TODO: Figure out whether test_flips are in original node_ids or internal RX node_ids + # frm: TODO: Testing: Figure out whether test_flips are in original node_ids or internal RX node_ids contiguous_partition2 = contiguous_partition.flip(test_flips) assert contiguous(contiguous_partition2) assert single_flip_contiguous(contiguous_partition2) - assert contiguous_bfs(contiguous_partition2) def test_discontiguous_with_contiguous_no_flips_is_false(discontiguous_partition): @@ -72,15 +72,15 @@ def test_discontiguous_with_single_flip_contiguous_no_flips_is_false( assert not single_flip_contiguous(discontiguous_partition) -def test_discontiguous_with_contiguous_bfs_no_flips_is_false(discontiguous_partition): - assert not contiguous_bfs(discontiguous_partition) +# frm: Note: the two contiguous_bfs "discontiguous" tests were deleted rather than +# renamed to use contiguous(), because the renamed versions would have duplicated +# (and silently shadowed) the identically named contiguous() tests in this file. def test_discontiguous_with_contiguous_flips_is_false( discontiguous_partition_with_flips ): part, test_flips = discontiguous_partition_with_flips - # frm: TODO: Figure out whether test_flips are in original node_ids or internal RX node_ids + # frm: TODO: Testing: Figure out whether test_flips are in original node_ids or internal RX node_ids discontiguous_partition2 = part.flip(test_flips) assert not contiguous(discontiguous_partition2) @@ -93,18 +93,18 @@ def test_discontiguous_with_single_flip_contiguous_flips_is_false( discontiguous_partition_with_flips ): part, test_flips = discontiguous_partition_with_flips - # frm: TODO: Figure out whether test_flips are in original node_ids or internal RX node_ids + # frm: TODO: Testing: Figure out whether test_flips are in original node_ids or internal RX node_ids discontiguous_partition2 = part.flip(test_flips) assert not single_flip_contiguous(discontiguous_partition2) -def test_discontiguous_with_contiguous_bfs_flips_is_false( - discontiguous_partition_with_flips - ): - part, test_flips = discontiguous_partition_with_flips - # frm: TODO: Figure out whether test_flips are in original node_ids or internal RX node_ids - discontiguous_partition2 = part.flip(test_flips) - assert not contiguous_bfs(discontiguous_partition2) def test_districts_within_tolerance_returns_false_if_districts_are_not_within_tolerance(): diff --git a/tests/frm_tests/test_frm_make_graph.py b/tests/frm_tests/test_frm_make_graph.py index 
c4f23bd9..cee8d27f 100644 --- a/tests/frm_tests/test_frm_make_graph.py +++ b/tests/frm_tests/test_frm_make_graph.py @@ -206,13 +206,6 @@ def edge_set_equal(set1, set2): def test_from_file_adds_all_data_by_default(shapefile): graph = Graph.from_file(shapefile) - # frm: Original Code: - # Get all of the data dictionaries for each node and verify that each - # of them contains data with the key "data" and "data2" - # - # assert all("data" in node_data for node_data in graph.nodes.values()) - # assert all("data2" in node_data for node_data in graph.nodes.values()) - # data dictionaries for all of the nodes all_node_data = [graph.node_data(node_id) for node_id in graph.node_indices] diff --git a/tests/frm_tests/test_frm_nx_rx_graph.py b/tests/frm_tests/test_frm_nx_rx_graph.py index 3cdec14e..84f0c32c 100644 --- a/tests/frm_tests/test_frm_nx_rx_graph.py +++ b/tests/frm_tests/test_frm_nx_rx_graph.py @@ -90,7 +90,7 @@ def test_nx_rx_node_indices_agree(gerrychain_nx_graph, gerrychain_rx_graph): assert nx_node_indices == rx_node_indices def test_nx_rx_edges_agree(gerrychain_nx_graph, gerrychain_rx_graph): - # TODO: Rethink this test. At the moment it relies on the edge_list() + # TODO: Testing: Rethink this test. At the moment it relies on the edge_list() # call which does not exist on a GerryChain Graph object # being handled by RX through clever __getattr__ stuff. # I think we should add an edge_list() method to GerryChain Graph @@ -106,14 +106,14 @@ def test_nx_rx_node_neighbors_agree(gerrychain_nx_graph, gerrychain_rx_graph): assert nx_neighbors == rx_neighbors def test_nx_rx_subgraphs_agree(gerrychain_nx_graph, gerrychain_rx_graph): - subgraph_nodes = [0,1,2,3,4,5] # TODO: make this a fixture dependent on JSON graph + subgraph_nodes = [0,1,2,3,4,5] # TODO: Testing: make this a fixture dependent on JSON graph nx_subgraph = gerrychain_nx_graph.subgraph(subgraph_nodes) rx_subgraph = gerrychain_rx_graph.subgraph(subgraph_nodes) for node_id in nx_subgraph: nx_node_data = nx_subgraph.node_data(node_id) rx_node_data = rx_subgraph.node_data(node_id) assert nx_node_data == rx_node_data - # frm: TODO: This does not test that the rx_subgraph has the exact same number of + # frm: TODO: Testing: This does not test that the rx_subgraph has the exact same number of # nodes as the nx_subgraph, and it does not test edge data... def test_nx_rx_degrees_agree(gerrychain_nx_graph, gerrychain_rx_graph): @@ -129,7 +129,7 @@ def test_nx_rx_degrees_agree(gerrychain_nx_graph, gerrychain_rx_graph): """ -frm: TODO: +frm: TODO: Testing: * Functions: * predecessors() diff --git a/tests/frm_tests/test_frm_old_vs_new_graph.py b/tests/frm_tests/test_frm_old_vs_new_graph.py deleted file mode 100644 index 8563980f..00000000 --- a/tests/frm_tests/test_frm_old_vs_new_graph.py +++ /dev/null @@ -1,127 +0,0 @@ -# -# This tests compatibility between the old/original version of -# the Graph object and the new version that encapsulates the -# graph as a data member - either nx_graph or rx_graph. -# - -import matplotlib.pyplot as plt -from gerrychain import (Partition, Graph, MarkovChain, - updaters, constraints, accept) -from gerrychain.graph import OriginalGraph -from gerrychain.proposals import recom -from gerrychain.constraints import contiguous -from functools import partial -import pandas - -import os -import rustworkx as rx -import networkx as nx - - -# Set the random seed so that the results are reproducible! 
-import random -random.seed(2024) - - -test_file_path = os.path.abspath(__file__) -cur_directory = os.path.dirname(test_file_path) -json_file_path = os.path.join(cur_directory, "gerrymandria.json") -print("json file is: ", json_file_path) - -new_graph = Graph.from_json(json_file_path) -old_graph = OriginalGraph.from_json(json_file_path) - - -print("Created old and new Graph objects from JSON") - -# frm: DEBUGGING: -# print("created new_graph") -# print("type of new_graph.nodes is: ", type(new_graph.nodes)) -new_graph_nodes = new_graph.nodes -old_graph_nodes = list(old_graph.nodes) -# print("new_graph nodes: ", list(new_graph.nodes)) -# print("new_graph edges: ", list(new_graph.edges)) -# print("") # newline -# print("created old_graph") -# print("type of old_graph.nodes is: ", type(old_graph.nodes)) -# print("old_graph nodes: ", list(old_graph.nodes)) -# print("old_graph edges: ", list(old_graph.edges)) - -print("testing that graph.nodes have same length") -assert(len(new_graph.nodes) == len(old_graph.nodes)), "lengths disagree" - -new_graph_edges = new_graph.edges -old_graph_edges = set(old_graph.edges) -print("testing that graph.edges have same length") -assert(len(new_graph_edges) == len(old_graph_edges)), "lengths disagree" - -node_subset = set([1,2,3,4,5]) -new_graph_subset = new_graph.subgraph(node_subset) -print("type of new_graph.subset is: ", type(new_graph_subset)) -print(new_graph_subset.edges) -old_graph_subset = old_graph.subgraph(node_subset) -print("type of old_graph.subset is: ", type(old_graph_subset)) -print(old_graph_subset.edges) - -# print("created frm_graph") -# print("FrmGraph nodes: ", list(frm_graph.nodes)) -# print("FrmGraph edges: ", list(frm_graph.edges)) - -print("About to test Graph.predecessors(root)") -pred = new_graph.predecessors(1) -print(list(pred)) - -# frm: TODO: Flesh out this test... - - -# -# The code below is from the regression test - maybe -# it will be useful in the future, maybe not... -# - -### my_updaters = { -### "population": updaters.Tally("TOTPOP"), -### "cut_edges": updaters.cut_edges -### } -### -### initial_partition = Partition( -### new_graph, -### assignment="district", -### updaters=my_updaters -### ) -### -### # This should be 8 since each district has 1 person in it. -### # Note that the key "population" corresponds to the population updater -### # that we defined above and not with the population column in the json file. 
-### ideal_population = sum(initial_partition["population"].values()) / len(initial_partition) -### -### proposal = partial( -### recom, -### pop_col="TOTPOP", -### pop_target=ideal_population, -### epsilon=0.01, -### node_repeats=2 -### ) -### -### print("Got proposal") -### -### recom_chain = MarkovChain( -### proposal=proposal, -### constraints=[contiguous], -### accept=accept.always_accept, -### initial_state=initial_partition, -### total_steps=40 -### ) -### -### print("Set up Markov Chain") -### -### assignment_list = [] -### -### for i, item in enumerate(recom_chain): -### print(f"Finished step {i+1}/{len(recom_chain)}") -### assignment_list.append(item.assignment) -### -### print("Enumerated the chain: number of entries in list is: ", len(assignment_list)) -### -### def test_success(): -### len(assignment_list) == 40 diff --git a/tests/frm_tests/test_frm_regression.py b/tests/frm_tests/test_frm_regression.py index cac1134d..29c4a897 100644 --- a/tests/frm_tests/test_frm_regression.py +++ b/tests/frm_tests/test_frm_regression.py @@ -30,17 +30,9 @@ test_file_path = os.path.abspath(__file__) cur_directory = os.path.dirname(test_file_path) json_file_path = os.path.join(cur_directory, "gerrymandria.json") -print("json file is: ", json_file_path) graph = Graph.from_json(json_file_path) -print("Created Graph from JSON") - -# frm: DEBUGGING: -# print("created graph") -# print("nodes: ", list(graph.nodes)) -# print("edges: ", list(graph.edges)) - my_updaters = { "population": updaters.Tally("TOTPOP"), "cut_edges": updaters.cut_edges @@ -65,8 +57,6 @@ node_repeats=2 ) -print("Got proposal") - recom_chain = MarkovChain( proposal=proposal, constraints=[contiguous], @@ -75,8 +65,6 @@ total_steps=40 ) -print("Set up Markov Chain") - assignment_list = [] for i, item in enumerate(recom_chain): diff --git a/tests/optimization/test_single_metric.py b/tests/optimization/test_single_metric.py index c6cdc968..e9868492 100644 --- a/tests/optimization/test_single_metric.py +++ b/tests/optimization/test_single_metric.py @@ -501,7 +501,6 @@ def opt_fn(partition): ): max_scores_sb[i] = optimizer.best_score - # frm: TODO: stmt below fails with 1.0 != 2 assert np.max(max_scores_sb) == 2 @@ -551,7 +550,6 @@ def opt_fn(partition): ): max_scores_anneal[i] = optimizer.best_score - # frm: TODO: stmt below fails. assert np.max(max_scores_anneal) == 2 diff --git a/tests/partition/test_partition.py b/tests/partition/test_partition.py index 77381254..ea7e380f 100644 --- a/tests/partition/test_partition.py +++ b/tests/partition/test_partition.py @@ -12,7 +12,7 @@ def test_Partition_can_be_flipped(example_partition): - # frm: TODO: Verify that this flip is in internal RX-based graph node_ids and not "original" NX node_ids + # frm: TODO: Testing: Verify that this flip is in internal RX-based graph node_ids and not "original" NX node_ids # # My guess is that this flip is intended to be in original node_ids but that the test works # anyways because the assertion uses the same numbers. It should probably be changed to use @@ -53,7 +53,7 @@ def test_Partition_knows_cut_edges_K3(example_partition): def test_propose_random_flip_proposes_a_partition(example_partition): partition = example_partition - # frm: TODO: Verify that propose_random_flip() to make sure it is doing the right thing + # frm: TODO: Testing: Verify that propose_random_flip() is doing the right thing # wrt RX-based node_ids vs. original node_ids.
proposal = propose_random_flip(partition) assert isinstance(proposal, partition.__class__) @@ -121,7 +121,7 @@ def test_can_be_created_from_a_districtr_file(graph, districtr_plan_file): for node in graph: graph.node_data(node)["area_num_1"] = node - # frm: TODO: NX vs. RX node_id issues here... + # frm: TODO: Testing: NX vs. RX node_id issues here... partition = Partition.from_districtr_file(graph, districtr_plan_file) diff --git a/tests/partition/test_plotting.py b/tests/partition/test_plotting.py index 20374d94..1f5bb393 100644 --- a/tests/partition/test_plotting.py +++ b/tests/partition/test_plotting.py @@ -69,8 +69,12 @@ def test_uses_graph_geometries_by_default(self, geodataframe): graph = Graph.from_geodataframe(geodataframe) partition = Partition(graph=graph, assignment={node: 0 for node in graph}) - # frm: TODO: the following statement blows up because we do not copy + # frm: TODO: Testing: how to handle geometry? + # + # Originally, the following statement blew up because we do not copy # geometry data from NX to RX when we convert to RX. + # + # I said at the time: # Need to grok what the right way to deal with geometry # data is (is it only an issue for from_geodataframe() or # are there other ways a geometry value might be set?) diff --git a/tests/test_frm_graph.py b/tests/test_frm_graph.py new file mode 100644 index 00000000..19b729aa --- /dev/null +++ b/tests/test_frm_graph.py @@ -0,0 +1,650 @@ +import pytest +import networkx as nx +import rustworkx as rx +from gerrychain import Graph + +############################################### +# This file contains tests for routines in graph.py +############################################### + +@pytest.fixture +def four_by_five_grid_nx(): + + # Create an NX Graph object with attributes + # + # This graph has the following properties + # which are important for the tests below: + # + # * The "nx_node_id" attribute serves as an + # effective "original" node_id so that we + # can track a node even when its internal + # node_id changes. + # + # * The graph has two "connected" components: + # the first two rows and the last two + # rows.
This is used in the connected + # components tests + + # nx_node_id + # + # 0 1 2 3 4 + # 5 6 7 8 9 + # 10 11 12 13 14 + # 15 16 17 18 19 + + # MVAP: + # + # 2 2 2 2 2 + # 2 2 2 2 2 + # 2 2 2 2 2 + # 2 2 2 2 2 + + nx_graph = nx.Graph() + nx_graph.add_nodes_from( + [ + (0, {"population": 10, "nx_node_id": 0, "MVAP": 2}), + (1, {"population": 10, "nx_node_id": 1, "MVAP": 2}), + (2, {"population": 10, "nx_node_id": 2, "MVAP": 2}), + (3, {"population": 10, "nx_node_id": 3, "MVAP": 2}), + (4, {"population": 10, "nx_node_id": 4, "MVAP": 2}), + (5, {"population": 10, "nx_node_id": 5, "MVAP": 2}), + (6, {"population": 10, "nx_node_id": 6, "MVAP": 2}), + (7, {"population": 10, "nx_node_id": 7, "MVAP": 2}), + (8, {"population": 10, "nx_node_id": 8, "MVAP": 2}), + (9, {"population": 10, "nx_node_id": 9, "MVAP": 2}), + (10, {"population": 10, "nx_node_id": 10, "MVAP": 2}), + (11, {"population": 10, "nx_node_id": 11, "MVAP": 2}), + (12, {"population": 10, "nx_node_id": 12, "MVAP": 2}), + (13, {"population": 10, "nx_node_id": 13, "MVAP": 2}), + (14, {"population": 10, "nx_node_id": 14, "MVAP": 2}), + (15, {"population": 10, "nx_node_id": 15, "MVAP": 2}), + (16, {"population": 10, "nx_node_id": 16, "MVAP": 2}), + (17, {"population": 10, "nx_node_id": 17, "MVAP": 2}), + (18, {"population": 10, "nx_node_id": 18, "MVAP": 2}), + (19, {"population": 10, "nx_node_id": 19, "MVAP": 2}), + ] + ) + + nx_graph.add_edges_from( + [ + (0, 1), + (0, 5), + (1, 2), + (1, 6), + (2, 3), + (2, 7), + (3, 4), + (3, 8), + (4, 9), + (5, 6), + # (5, 10), + (6, 7), + # (6, 11), + (7, 8), + # (7, 12), + (8, 9), + # (8, 13), + # (9, 14), + (10, 11), + (10, 15), + (11, 12), + (11, 16), + (12, 13), + (12, 17), + (13, 14), + (13, 18), + (14, 19), + (15, 16), + (16, 17), + (17, 18), + (18, 19), + ] + ) + + return nx_graph + +@pytest.fixture +def four_by_five_grid_rx(four_by_five_grid_nx): + # Create an RX Graph object with attributes + rx_graph = rx.networkx_converter(four_by_five_grid_nx, keep_attributes=True) + return rx_graph + +def top_level_graph_is_properly_configured(graph): + # This routine tests that top-level graphs (not a subgraph) + # are properly configured + assert graph._is_a_subgraph == False, \ + "Top-level graph _is_a_subgraph is True" + assert hasattr(graph, "_node_id_to_parent_node_id_map"), \ + "Graph._node_id_to_parent_node_id_map is not set" + assert hasattr(graph, "_node_id_to_original_nx_node_id_map"), \ + "Graph._node_id_to_original_nx_node_id_map is not set" + +def test_from_networkx(four_by_five_grid_nx): + graph = Graph.from_networkx(four_by_five_grid_nx) + assert len(graph.node_indices) == 20, \ + f"Expected 20 nodes but got {len(graph.node_indices)}" + assert len(graph.edge_indices) == 26, \ + f"Expected 26 edges but got {len(graph.edge_indices)}" + assert graph.node_data(1)["population"] == 10, \ + f"Expected population of 10 but got {graph.node_data(1)['population']}" + top_level_graph_is_properly_configured(graph) + +def test_from_rustworkx(four_by_five_grid_nx): + rx_graph = rx.networkx_converter(four_by_five_grid_nx, keep_attributes=True) + graph = Graph.from_rustworkx(rx_graph) + assert len(graph.node_indices) == 20, \ + f"Expected 20 nodes but got {len(graph.node_indices)}" + assert graph.node_data(1)["population"] == 10, \ + f"Expected population of 10 but got {graph.node_data(1)['population']}" + top_level_graph_is_properly_configured(graph) + +@pytest.fixture +def four_by_five_graph_nx(four_by_five_grid_nx): + # Create an NX Graph object with attributes + graph = 
Graph.from_networkx(four_by_five_grid_nx) + return graph + +@pytest.fixture +def four_by_five_graph_rx(four_by_five_grid_nx): + # Create an RX-based Graph object with attributes + # + # Instead of using from_rustworkx(), we use + # convert_from_nx_to_rx() because tests below + # depend on the node_id maps that are created + # by convert_from_nx_to_rx() + # + graph = Graph.from_networkx(four_by_five_grid_nx) + converted_graph = graph.convert_from_nx_to_rx() + return converted_graph + +def test_convert_from_nx_to_rx(four_by_five_graph_nx): + graph = four_by_five_graph_nx # more readable + converted_graph = graph.convert_from_nx_to_rx() + + # Same number of nodes + assert len(graph.node_indices) == 20, \ + f"Expected 20 nodes but got {len(graph.node_indices)}" + assert len(converted_graph.node_indices) == 20, \ + f"Expected 20 nodes but got {len(converted_graph.node_indices)}" + + # Same number of edges + assert len(graph.edge_indices) == 26, \ + f"Expected 26 edges but got {len(graph.edge_indices)}" + assert len(converted_graph.edge_indices) == 26, \ + f"Expected 26 edges but got {len(converted_graph.edge_indices)}" + + # Node data is the same + # frm: TODO: Refactoring: Do this the clever Python way and test ALL at the same time + for node_id in graph.node_indices: + assert graph.node_data(node_id)["population"] == 10, \ + f"Expected population of 10 but got {graph.node_data(node_id)['population']}" + assert graph.node_data(node_id)["nx_node_id"] == node_id, \ + f"Expected nx_node_id of {node_id} but got {graph.node_data(node_id)['nx_node_id']}" + assert graph.node_data(node_id)["MVAP"] == 2, \ + f"Expected MVAP of 2 but got {graph.node_data(node_id)['MVAP']}" + for node_id in converted_graph.node_indices: + assert converted_graph.node_data(node_id)["population"] == 10, \ + f"Expected population of 10 but got {converted_graph.node_data(node_id)['population']}" + # frm: TODO: Code: Need to use node_id map to get appropriate node_ids for RX graph + # assert converted_graph.node_data(node_id)["nx_node_id"] == node_id, \ + # f"Expected nx_node_id of {node_id} but got {converted_graph.node_data(node_id)['nx_node_id']}" + assert converted_graph.node_data(node_id)["MVAP"] == 2, \ + f"Expected MVAP of 2 but got {converted_graph.node_data(node_id)['MVAP']}" + + # Confirm that the node_id map to the "original" NX node_ids is correct + for node_id in converted_graph.nodes: + # get the "original" NX node_id + nx_node_id = converted_graph._node_id_to_original_nx_node_id_map[node_id] + # confirm that the converted node has "nx_node_id" set to the NX node_id. This + # is an artifact of the way the NX graph was constructed.
+ assert converted_graph.node_data(node_id)["nx_node_id"] == nx_node_id + +def test_get_edge_from_edge_id(four_by_five_graph_nx, four_by_five_graph_rx): + + # Test that get_edge_from_edge_id works for both NX and RX based Graph objects + + # NX edges and edge_ids are the same, so this first test is trivial + # + nx_edge_id = (0, 1) + nx_edge = four_by_five_graph_nx.get_edge_from_edge_id(nx_edge_id) + assert nx_edge == (0, 1) + + # RX edge_ids are assigned arbitrarily, so without using the nx_to_rx_node_id_map + # we can't know which edge got what edge_id, so this test just verifies that + # there is an edge tuple associated with edge_id, 0 + # + rx_edge_id = 0 # arbitrary id - but there is always an edge with id == 0 + rx_edge = four_by_five_graph_rx.get_edge_from_edge_id(rx_edge_id) + assert isinstance(rx_edge[0], int), "RX edge does not exist (0)" + assert isinstance(rx_edge[1], int), "RX edge does not exist (1)" + +def test_get_edge_id_from_edge(four_by_five_graph_nx, four_by_five_graph_rx): + + # Test that get_edge_id_from_edge works for both NX and RX based Graph objects + + # NX edges and edge_ids are the same, so this first test is trivial + # + nx_edge = (0, 1) + nx_edge_id = four_by_five_graph_nx.get_edge_id_from_edge(nx_edge) + assert nx_edge_id == (0, 1) + + # Test that get_edge_id_from_edge returns an integer value and that + # when that value is used to retrieve an edge tuple, we get the + # tuple value that is expected + # + rx_edge = (0, 1) + rx_edge_id = four_by_five_graph_rx.get_edge_id_from_edge(rx_edge) + assert isinstance(rx_edge_id, int), "Edge ID not found for edge" + found_rx_edge = four_by_five_graph_rx.get_edge_from_edge_id(rx_edge_id) + assert found_rx_edge == rx_edge, "Edge ID does not yield correct edge value" + +def test_add_edge(): + # At present (October 2025), there is nothing to test. The + # code just delegates to NetworkX or RustworkX to create + # the edge. + # + # However, it is conceivable that in the future, when users + # stop using NX altogether, there might be a reason for a + # test, so this is just a placeholder for that future test... + # + assert True + +def test_subgraph(four_by_five_graph_rx): + + """ + Subgraphs are one of the most dangerous areas of the code. + In NX, subgraphs preserve node_ids - that is, the node_id + in the subgraph is the same as the node_id in the parent. + However, in RX, that is not the case - RX always creates + new node_ids starting at 0 and increasing by one + sequentially, so in general a node in an RX subgraph + will have a different node_id than it has in the parent + graph. + + To deal with this, the code creates a map from the + node_id in a subgraph to the node_id in the parent + graph, _node_id_to_parent_node_id_map. This test verifies + that this map is properly created. + + In addition, all RX based graphs that came from an NX + based graph record the "original" NX node_ids in + another node_id map, _node_id_to_original_nx_node_id_map. + + When we create a subgraph, this map needs to be + established for the subgraph. This test verifies + that this map is properly created. + + Note that this test is only configured to work on + RX based Graph objects because the only uses of subgraph + in the gerrychain codebase are on RX based Graph objects.
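+
+    A tiny illustration of the intended mapping (hypothetical ids): if the
+    parent graph has nodes {2, 4, 5} and we take subgraph([4, 5]), RX
+    renumbers the subgraph's nodes to {0, 1}, so we would expect
+    _node_id_to_parent_node_id_map to be {0: 4, 1: 5}.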
+ """ + + # Create a subgraph for an arbitrary set of nodes: + subgraph_node_ids = [2, 4, 5, 8, 11, 13] + parent_graph_rx = four_by_five_graph_rx # make the code below clearer + subgraph_rx = parent_graph_rx.subgraph(subgraph_node_ids) + + assert \ + len(subgraph_node_ids) == len(subgraph_rx), \ + "Number of nodes do not agree" + + # verify that _node_id_to_parent_node_id_map is correct + for subgraph_node_id, parent_node_id in subgraph_rx._node_id_to_parent_node_id_map.items(): + # check that each node in subgraph has the same data (is the same node) + # as the node in the parent that it is mapped to + # + subgraph_stored_node_id = subgraph_rx.node_data(subgraph_node_id)["nx_node_id"] + subgraph_stored_node_id = subgraph_rx.node_data(subgraph_node_id)["nx_node_id"] + parent_stored_node_id = parent_graph_rx.node_data(parent_node_id)["nx_node_id"] + assert \ + parent_stored_node_id == subgraph_stored_node_id, \ + f"_node_id_to_parent_node_id_map is incorrect" + + # verify that _node_id_to_original_nx_node_id_map is correct + for subgraph_node_id, original_node_id in subgraph_rx._node_id_to_original_nx_node_id_map.items(): + subgraph_stored_node_id = subgraph_rx.node_data(subgraph_node_id)["nx_node_id"] + assert \ + subgraph_stored_node_id == original_node_id, \ + f"_node_id_to_original_nx_node_id_map is incorrect" + +def test_num_connected_components(four_by_five_graph_nx, four_by_five_graph_rx): + num_components_nx = four_by_five_graph_nx.num_connected_components() + num_components_rx = four_by_five_graph_rx.num_connected_components() + assert \ + num_components_nx == 2, \ + f"num_components: expected 2 but got {num_components_nx}" + assert \ + num_components_rx == 2, \ + f"num_components: expected 2 but got {num_components_rx}" + +def test_subgraphs_for_connected_components(four_by_five_graph_nx, four_by_five_graph_rx): + + subgraphs_nx = four_by_five_graph_nx.subgraphs_for_connected_components() + subgraphs_rx = four_by_five_graph_rx.subgraphs_for_connected_components() + + assert len(subgraphs_nx) == 2 + assert len(subgraphs_rx) == 2 + + assert len(subgraphs_nx[0]) == 10 + assert len(subgraphs_nx[1]) == 10 + assert len(subgraphs_rx[0]) == 10 + assert len(subgraphs_rx[1]) == 10 + + # Check that each subgraph (NX-based Graph) has correct nodes in it + node_ids_nx_0 = subgraphs_nx[0].node_indices + node_ids_nx_1 = subgraphs_nx[1].node_indices + assert node_ids_nx_0 == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + assert node_ids_nx_1 == {10, 11, 12, 13, 14, 15, 16, 17, 18, 19} + + # Check that each subgraph (RX-based Graph) has correct nodes in it + node_ids_rx_0 = subgraphs_rx[0].node_indices + node_ids_rx_1 = subgraphs_rx[1].node_indices + original_nx_node_ids_rx_0 = subgraphs_rx[0].original_nx_node_ids_for_set(node_ids_rx_0) + original_nx_node_ids_rx_1 = subgraphs_rx[1].original_nx_node_ids_for_set(node_ids_rx_1) + assert original_nx_node_ids_rx_0 == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + assert original_nx_node_ids_rx_1 == {10, 11, 12, 13, 14, 15, 16, 17, 18, 19} + +def test_to_networkx_graph(): + # There is already a test for this in another file + assert True + +def test_add_data(): + # This is already tested in test_make_graph.py + assert True + +######################################################## +# Long utility routine to determine if there is a cycle +# in a graph (with integer node_ids). +######################################################## + +def graph_has_cycle(set_of_edges): + + # + # Given a set of edges that define a graph, determine + # if the graph has cycles. 
+ # + # This will allow us to test that predecessors and + # successors are in fact trees with no cycles. + # + # The approach is to do a depth-first-search that + # remembers each node it has visited, and which + # signals that a cycle exists if it revisits a node + # it has already visited + # + # Note that this code assumes that the set of nodes + # is a sequential list starting at zero with no gaps + # in the sequence. This allows us to use a simplified + # adjacency matrix which is adequate for testing + # purposes. + # + # The adjacency matrix is just a 2D square matrix + # that has a 1 value for element (i,j) iff there + # is an edge from node i to node j. Note that + # because we deal with undirected graphs the matrix + # is symetrical - edges go both ways... + # + + def add_edge(adj_matrix, s, t): + # Add an edge to an adjacency matrix + adj_matrix[s][t] = 1 + adj_matrix[t][s] = 1 # Since it's an undirected graph + + def delete_edge(adj_matrix, s, t): + # Delete an edge from an adjacency matrix + adj_matrix[s][t] = 0 + adj_matrix[t][s] = 0 # Since it's an undirected graph + + def create_empty_adjacency_matrix(num_nodes): + # create 2D array, num_nodes x num_nodes + adj_matrix = [[0] * num_nodes for _ in range(num_nodes)] + return adj_matrix + + def create_adjacency_matrix_from_set_of_edges(set_of_edges): + + # determine num_nodes + # + set_of_nodes = set() + for edge in set_of_edges: + for node in edge: + set_of_nodes.add(node) + num_nodes = len(set_of_nodes) + list_of_nodes = list(set_of_nodes) + + # We need node_ids that start at zero and go + # up sequentially with no gaps, so create a + # map for new node_ids + new_node_id_map = {} + for index, node_id in enumerate(list_of_nodes): + new_node_id_map[node_id] = index + + # Now create a new set of edges with the new node_ids + new_set_of_edges = set() + for edge in set_of_edges: + new_edge = ( + new_node_id_map[edge[0]], new_node_id_map[edge[1]] + ) + new_set_of_edges.add(new_edge) + + # debugging: + + # create an empty adjacency matrix + # + adj_matrix = create_empty_adjacency_matrix(num_nodes) + + # add the edges to the adjacency matrix + # + for edge in new_set_of_edges: + add_edge(adj_matrix, edge[0], edge[1]) + + return adj_matrix + + def inner_has_cycle(adj_matrix, visited, s, visit_list): + # This routine does a depth first search looking + # for cycles - if it encounters a node that it has + # already seen then it returns True. + + # Record having visited this node + # + visited[s] = True + visit_list.append(s) + + # Recursively visit all adjacent vertices looking for cycles + # If we have already visited a node, then there is a cycle... + # + for i in range(len(adj_matrix)): + # Recurse on every adjacent / child node... + if adj_matrix[s][i] == 1: + if visited[i]: + return True + else: + # remove this edge from adjacency matrix so we + # don't follow link back to node, i. 
+ # + delete_edge(adj_matrix, s, i) + if inner_has_cycle(adj_matrix, visited, i, visit_list): + return True + return False + + adj_matrix = create_adjacency_matrix_from_set_of_edges(set_of_edges) + visited = [False] * len(adj_matrix) + visit_list = [] + root_node = 0 # arbitrary, but every graph has a node 0 + cycle_found = \ + inner_has_cycle(adj_matrix, visited, root_node, visit_list) + return cycle_found + +def test_graph_has_cycle(): + # Test to make sure the utility routine, graph_has_cycle, works + + # First try with no cycle + # Define the edges of the graph + set_of_edges = { + (11, 2), + (11, 0), + # (2, 0), # no cycle without this edge + (2, 3), + (2, 4) + } + the_graph_has_a_cycle = graph_has_cycle(set_of_edges) + assert the_graph_has_a_cycle == False + + # Now try with a cycle + # Define the edges of the graph + set_of_edges = { + (11, 2), + (11, 0), + (2, 0), # this edge creates a cycle + (2, 3), + (2, 4) + } + the_graph_has_a_cycle = graph_has_cycle(set_of_edges) + assert the_graph_has_a_cycle == True + +def test_generic_bfs_edges(four_by_five_graph_nx, four_by_five_graph_rx): + # + # The routine, generic_bfs_edges() returns an ordered list of + # edges from a breadth-first traversal of a graph, starting + # at the given node. + # + # For our graphs, there are two connected components (the first + # two rows and the last two rows) and each component is a + # grid: + # + # 0 - 1 - 2 - 3 - 4 + # | | | | | + # 5 - 6 - 7 - 8 - 9 + # + # 10 - 11 - 12 - 13 - 14 + # | | | | | + # 15 - 16 - 17 - 18 - 19 + # + # So, a BFS starting at 0 should produce something like: + # + # [ (0,5), (0,1), (1,6), (1,2), (2,7), (2,3), (3,8), (3,4), (4,9) ] + # + # However, the specific order that is returned depends on the + # internals of the algorithm. + # + + # + bfs_edges_nx_0 = set(four_by_five_graph_nx.generic_bfs_edges(0)) + expected_set_of_edges = { \ + (0,5), (0,1), (1,6), (1,2), (2,7), (2,3), (3,8), (3,4), (4,9) \ + } + # debugging: + assert bfs_edges_nx_0 == expected_set_of_edges + + # Check that generic_bfs_edges() does not produce a cycle + the_graph_has_a_cycle = graph_has_cycle(bfs_edges_nx_0) + assert the_graph_has_a_cycle == False + bfs_edges_nx_12 = set(four_by_five_graph_nx.generic_bfs_edges(12)) + the_graph_has_a_cycle = graph_has_cycle(bfs_edges_nx_12) + assert the_graph_has_a_cycle == False + + """ + TODO: Testing: + * Think about whether this test is actually appropriate. The + issue is that the expected_set_of_edges is the right set + for this particular graph, but I am not sure that this is + a good enough test. Think about other situations... + + * Think about whether to verify that the BFS returned + has no cycles. It doesn't in this particular case, + but perhaps we should have more cases that stress the test... + """ +def test_generic_bfs_successors_generator(): + # TODO: Testing: Write a test for this routine + # + # Note that the code for this routine is very straight-forward, so + # writing a test is not high-priority. The only reason I did not + # just go ahead and write one is because it was not immediately + # clear to me how to write the test - more work than doing a + # thorough code review... + # + assert True + +def test_generic_bfs_successors(): + # TODO: Testing: Write a test for this routine + # + # Code is trivial, but because this routine is important it + # deserves a test - just not clear off top of my head how + # to write the test... 
+ # + assert True + +def test_generic_bfs_predecessors(): + # TODO: Testing: Write a test for this routine + # + # Code is trivial, but because this routine is important it + # deserves a test - just not clear off top of my head how + # to write the test... + # + assert True + +def test_predecessors(): + # TODO: Testing: Write a test for this routine + # + # Code is trivial, but because this routine is important it + # deserves a test - just not clear off top of my head how + # to write the test... + # + assert True + +def test_successors(): + # TODO: Testing: Write a test for this routine + # + # Code is trivial, but because this routine is important it + # deserves a test - just not clear off top of my head how + # to write the test... + # + assert True + +def test_laplacian_matrix(): + # TODO: Testing: Write a test for this routine + # + # Not clear off the top of my head how + # to write the test... + # + assert True + +def test_normalized_laplacian_matrix(): + # TODO: Testing: Write a test for this routine + # + # This routine has not yet been implemented (as + # of October 2025), but when it is implemented + # we should add a test for it... + # + assert True + + +""" +============================================================= + +TODO: Code: ??? + + * Aliasing concerns: + + It occurs to me that the RX node_data is aliased with the NX node_data. + That is, the data dictionaries in the NX Graph are just retained + when the NX Graph is converted to be an RX Graph - so if you change + the data in the RX Graph, the NX Graph from which we created the RX + graph will also be changed. + + I believe that this is also true for subgraphs for both NX and RX, + meaning that the node_data in the subgraph is the exact same + data dictionary in both the parent graph and the subgraph. + + I am not sure if this is a problem or not, but it is something + to be tested / thought about... + + * NX allows node_ids to be almost anything - they can be integers, + strings, even tuples. I think that they just need to be hashable. + + I don't know if we need to test that non-integer NX node_ids + don't cause a problem. There are tests elsewhere that have + NX node_ids that are tuples, and that test passes, so I think + we are OK, but there are no tests specifically targeting this + issue that I know of. + +============================================================= +""" diff --git a/tests/test_laplacian.py b/tests/test_laplacian.py index 25ab2572..927da85d 100644 --- a/tests/test_laplacian.py +++ b/tests/test_laplacian.py @@ -21,7 +21,7 @@ to convert the NX version's result to have floating point values. """ -# frm: TODO: Add additional tests for laplacian matrix calculations, in +# frm: TODO: Testing: Add additional tests for laplacian matrix calculations, in # particular, add a test for normalized_laplacian_matrix() # once that routine has been implemented. @@ -71,8 +71,11 @@ def nx_graph(): @pytest.fixture def rx_graph(): this_rx_graph = rx.PyGraph() - this_rx_graph.add_nodes_from([0, 1, 2, 3]) - this_rx_graph.add_edges_from([(0, 1, "data"), (0, 2, "data"), (1, 2, "data"), (2, 3, "data")]) + # argument to add_nodes_from() is the data to be associated with each node.
+    # To be compatible with GerryChain, nodes need to have data values that are
+    # dictionaries, so we just use an empty dict for each node's data.
+    this_rx_graph.add_nodes_from([{}, {}, {}, {}])
+    this_rx_graph.add_edges_from([(0, 1, {}), (0, 2, {}), (1, 2, {}), (2, 3, {})])
     return this_rx_graph
 
diff --git a/tests/test_make_graph.py b/tests/test_make_graph.py
index 31bfb98b..c0efc9e2 100644
--- a/tests/test_make_graph.py
+++ b/tests/test_make_graph.py
@@ -204,12 +204,13 @@ def test_computes_boundary_perims(geodataframe_with_boundary):
 
 def edge_set_equal(set1, set2):
-    # Returns true if the two sets have the same edges. The complication is that
-    # for an edge, (1,2) is the same as (2,1), so to compare them you need to
-    # canonicalize the edges somehow. This code just takes set1 and set2 and
-    # creates a new set for each that has both edge pairs for each edge, and
-    # it then compares those new sets.
-    #
+    """
+    Returns True if the two sets contain the same edges.
+
+    The complication is that for an edge, (1,2) is the same as (2,1), so to compare
+    the sets you need to canonicalize the edges somehow.  This code creates a new set
+    from each input that contains both orderings of every edge, then compares those new sets.
+    """
     canonical_set1 = {(y, x) for x, y in set1} | set1
     canonical_set2 = {(y, x) for x, y in set2} | set2
     return canonical_set1 == canonical_set2
diff --git a/tests/test_metagraph.py b/tests/test_metagraph.py
index 32c58617..a4bad50d 100644
--- a/tests/test_metagraph.py
+++ b/tests/test_metagraph.py
@@ -13,7 +13,7 @@ def partition(graph):
 
 def test_all_cut_edge_flips(partition):
-    # frm: TODO: Maybe change all_cut_edge_flips to return a dict
+    # frm: TODO: Testing: Maybe change all_cut_edge_flips to return a dict
     #
     # At present, it returns an iterator, which makes the code below
     # more complicated than it needs to be.  If it just returned
@@ -53,7 +53,7 @@ def test_accepts_list_of_constraints(self, partition):
 
 def test_all_valid_flips(partition):
-    # frm: TODO: NX vs. RX node_id issues...
+    # frm: TODO: Testing: NX vs. RX node_id issues...
     def disallow_six_to_one(partition):
         for node, part in partition.flips.items():
             if node == 6 and part == 1:
@@ -62,7 +62,7 @@ def disallow_six_to_one(partition):
 
     constraints = [disallow_six_to_one]
 
-    # frm: TODO: If I created a utility routine to convert
+    # frm: TODO: Testing: If I created a utility routine to convert
     #            a list of flips to original node_ids,
     #            then I could use that here and then
     #            convert the resulting list to a set...
diff --git a/tests/test_region_aware.py b/tests/test_region_aware.py
index 50b22ac2..0fb1c305 100644
--- a/tests/test_region_aware.py
+++ b/tests/test_region_aware.py
@@ -161,13 +161,8 @@ def straddled_regions(partition, reg_attr, all_reg_names):
     """Returns the total number of district that straddle two regions in the partition."""
     split = {name: 0 for name in all_reg_names}
 
-    # frm: TODO: Grok what this tests - not clear to me at this time...
+    # frm: TODO: Testing: Grok what this tests - not clear to me at this time...
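+    #
+    # Mechanically, the loop below walks every edge that is NOT a cut edge
+    # (i.e., both endpoints are in the same district) and bumps the counter
+    # for each endpoint's region - so it appears to count same-district
+    # adjacencies per region rather than districts directly...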
- # frm: Original Code: - # for node1, node2 in set(partition.graph.edges() - partition["cut_edges"]): - # split[partition.graph.nodes[node1][reg_attr]] += 1 - # split[partition.graph.nodes[node2][reg_attr]] += 1 - # for node1, node2 in set(partition.graph.edges() - partition["cut_edges"]): split[partition.graph.node_data(node1)[reg_attr]] += 1 split[partition.graph.node_data(node2)[reg_attr]] += 1 @@ -244,10 +239,6 @@ def test_region_aware_muni_warning(): with pytest.warns(UserWarning) as record: # Random seed 2 should succeed, but drawing the # tree is hard, so we should get a warning - # frm: TODO: stmt below fails - saying too many attempts: - # - # raise RuntimeError(f"Could not find a possible cut after {max_attempts} attempts.") - # RuntimeError: Could not find a possible cut after 10000 attempts. run_chain_dual( seed=2, steps=1000, diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 85d2122c..49ccafec 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -58,7 +58,6 @@ def test_repeatable(three_by_three_grid): {3: 1}, ] flips = [partition.flips for partition in chain] - print(flips) assert flips == expected_flips diff --git a/tests/test_tally.py b/tests/test_tally.py index 46000848..24757cc3 100644 --- a/tests/test_tally.py +++ b/tests/test_tally.py @@ -15,6 +15,8 @@ def random_assignment(graph, num_districts): def test_data_tally_works_as_an_updater(three_by_three_grid): + # Simple test that a DataTally creates an attribute on a partition. + # Another test (below) checks that the computed "tally" is correct. assignment = random_assignment(three_by_three_grid, 4) data = {node: random.randint(1, 100) for node in three_by_three_grid.nodes} parts = tuple(set(assignment.values())) diff --git a/tests/test_tree.py b/tests/test_tree.py index 7635f52f..fad9741a 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -40,6 +40,8 @@ # testing both NX-based and RX-based Graph objects clean. 
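+#
+# (Working note for the nx_to_rx_node_id_map TODO below: the map appears to
+# take each original NX node_id to the node_id assigned when the graph is
+# converted to RustworkX; the reverse mapping is stored on each RX node under
+# the "__networkx_node__" key - see create_graphs_from_nx_edges() below.)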
# +# frm: TODO: Documentation: test_tree.py: explain nx_to_rx_node_id_map + @pytest.fixture def graph_with_pop_nx(three_by_three_grid): # NX-based Graph object @@ -169,12 +171,6 @@ def do_test_recursive_tree_part_returns_within_epsilon_of_target_pop(twelve_by_t partition = Partition( twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - # frm: Original Code: - # - # return all( - # abs(part_pop - ideal_pop) / ideal_pop < epsilon - # for part_pop in partition["pop"].values() - # ) assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() @@ -216,13 +212,6 @@ def do_test_recursive_tree_part_returns_within_epsilon_of_target_pop_using_contr partition = Partition( twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - # frm: Original Code: - # - # return all( - # abs(part_pop - ideal_pop) / ideal_pop < epsilon - # for part_pop in partition["pop"].values() - # ) - # assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() @@ -265,12 +254,6 @@ def do_test_recursive_seed_part_returns_within_epsilon_of_target_pop( partition = Partition( twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - # frm: Original Code: - # - # return all( - # abs(part_pop - ideal_pop) / ideal_pop < epsilon - # for part_pop in partition["pop"].values() - # ) assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() @@ -314,12 +297,6 @@ def do_test_recursive_seed_part_returns_within_epsilon_of_target_pop_using_contr partition = Partition( twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - # frm: Original Code: - # - # return all( - # abs(part_pop - ideal_pop) / ideal_pop < epsilon - # for part_pop in partition["pop"].values() - # ) assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() @@ -408,12 +385,6 @@ def do_test_recursive_seed_part_with_n_unspecified_within_epsilon( partition = Partition( twelve_by_twelve_with_pop_graph, result, updaters={"pop": Tally("pop")} ) - # frm: Original Code: - # - # return all( - # abs(part_pop - ideal_pop) / ideal_pop < epsilon - # for part_pop in partition["pop"].values() - # ) assert all( abs(part_pop - ideal_pop) / ideal_pop < epsilon for part_pop in partition["pop"].values() @@ -431,7 +402,6 @@ def test_recursive_seed_part_with_n_unspecified_within_epsilon( def do_test_random_spanning_tree_returns_tree_with_pop_attribute(graph): tree = random_spanning_tree(graph) - # frm: Original code: assert networkx.is_tree(tree) assert tree.is_a_tree() def test_random_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop_nx, graph_with_pop_rx): @@ -443,7 +413,6 @@ def test_random_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop_nx, def do_test_uniform_spanning_tree_returns_tree_with_pop_attribute(graph): tree = uniform_spanning_tree(graph) - # frm: Original code: assert networkx.is_tree(tree) assert tree.is_a_tree() def test_uniform_spanning_tree_returns_tree_with_pop_attribute(graph_with_pop_nx, graph_with_pop_rx): @@ -460,47 +429,51 @@ def do_test_bipartition_tree_returns_a_tree(graph, spanning_tree): graph, "pop", ideal_pop, 0.25, 10, spanning_tree, lambda x: 4 ) - # frm: Original code: - # assert networkx.is_tree(spanning_tree.subgraph(result)) - # assert networkx.is_tree( - # spanning_tree.subgraph({node for node in tree if node not in result}) - # ) assert spanning_tree.subgraph(result).is_a_tree() assert 
        spanning_tree.subgraph({node for node in spanning_tree if node not in result}).is_a_tree()
 
-def create_graphs_from_nx_edges(list_of_edges_nx, nx_to_rx_node_id_map):
+def create_graphs_from_nx_edges(num_nodes, list_of_edges_nx, nx_to_rx_node_id_map):
 
-    # NX is easy - just use the lsit of NX edges
+    # NX is easy - just use the list of NX edges
     graph_nx = Graph.from_networkx(networkx.Graph(list_of_edges_nx))
 
-    print(f"create_graphs_from_nx_edges: list_of_edges_nx is: {list_of_edges_nx}")
-    print(f"create_graphs_from_nx_edges: nx_to_rx_node_id_map: {nx_to_rx_node_id_map}")
-
-    # RX requires more work.  First we have to translate the node_ids used in the
+    # RX requires more work.
+    #
+    # First we create the RX graph and add its nodes.
+    #
+    # frm: TODO: Testing: Update test so that the number of nodes is not hard-coded...
+    #
+    # Then we need to create the appropriate RX edges - the ones that
+    # correspond to the NX edges but using the RX node_ids for the edges.
+    #
+    # To do that, we have to translate the node_ids used in the
     # list of edges to be the ones used in the RX graph using the
     # nx_to_rx_node_id_map.  Then we need to create a rustworkx.PyGraph and then
     # from that create a "new" Graph object.
 
+    # Create the RX graph
     rx_graph = rustworkx.PyGraph()
+    for i in range(num_nodes):
+        rx_graph.add_node({})   # empty data dict for node_data
+    # Verify that the nodes created have node_ids 0-(num_nodes-1)
+    assert set(rx_graph.node_indices()) == set(range(num_nodes))
+    # Set the attribute identifying the "original" NX node_id.
+    # This is normally set by the code that converts an NX graph to RX,
+    # but we are cobbling together stuff for a test and so have to
+    # just do it here...
+    rx_to_nx_node_id_map = {v: k for k, v in nx_to_rx_node_id_map.items()}
+    for node_id in rx_graph.node_indices():
+        rx_graph[node_id]["__networkx_node__"] = rx_to_nx_node_id_map[node_id]
 
     # translate the NX edges into the appropriate node_ids for the derived RX graph
-    empty_edge_data_dict = {}   # needed so we can attach edge data to each edge
     list_of_edges_rx = [
         (
             nx_to_rx_node_id_map[edge[0]],
             nx_to_rx_node_id_map[edge[1]],
-            empty_edge_data_dict
+            {}   # empty data dict for edge_data
         )
         for edge in list_of_edges_nx
     ]
 
-    rx_graph = rustworkx.PyGraph();
-
-    # Create the RX nodes
-    for i in range(9):
-        empty_node_data_dict = {}   # needed so we can attach node data to each node
-        rx_graph.add_node(empty_node_data_dict)
-    # Verify that the nodes created have node_ids 0-8
-    assert(set(rx_graph.node_indices()) == set(range(9)))
 
     # Add the RX edges
     rx_graph.add_edges_from(list_of_edges_rx)
@@ -515,6 +488,7 @@ def test_bipartition_tree_returns_a_tree(graph_with_pop_nx, graph_with_pop_rx):
 
     spanning_tree_nx, spanning_tree_rx = \
         create_graphs_from_nx_edges(
+            9,
             spanning_tree_edges_nx,
             graph_with_pop_rx.nx_to_rx_node_id_map
         )
@@ -573,7 +547,7 @@ def test_reversible_recom_works_as_a_proposal(partition_with_pop):
     #          the value of epsilon (of 10%) is never used, whatever...
     #
 
-    # frm: TODO: Grok this test - what is it trying to accomplish?
+    # frm: TODO: Testing: Grok this test - what is it trying to accomplish?
     #
     # The proposal uses reversible_recom() with the default value for the "repeat_until_valid"
     # parameter which is False.  This means that the call to try to combine and then split two
@@ -601,11 +575,11 @@ def test_reversible_recom_works_as_a_proposal(partition_with_pop):
     for state in chain:
         assert contiguous(state)
 
-# frm: TODO: Add more tests using MarkovChain...
+# frm: TODO: Testing: Add more tests using MarkovChain...
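+#
+# One possible shape for such a test - a sketch only, with a hypothetical test
+# name.  It assumes the MarkovChain, propose_random_flip, single_flip_contiguous,
+# and always_accept that the gerrychain package exports, plus the
+# partition_with_pop fixture used above:
+#
+#     def test_single_flip_chain_preserves_contiguity(partition_with_pop):
+#         from gerrychain import MarkovChain
+#         from gerrychain.accept import always_accept
+#         from gerrychain.constraints import single_flip_contiguous
+#         from gerrychain.proposals import propose_random_flip
+#
+#         chain = MarkovChain(
+#             proposal=propose_random_flip,
+#             constraints=[single_flip_contiguous],
+#             accept=always_accept,
+#             initial_state=partition_with_pop,
+#             total_steps=100,
+#         )
+#         # Every state the chain visits should still be contiguous
+#         for state in chain:
+#             assert contiguous(state)
+#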
def test_find_balanced_cuts_contraction(): - # frm: TODO: Add test for RX-based Graph object + # frm: TODO: Testing: Add test for RX-based Graph object tree = Graph.from_networkx( networkx.Graph([(0, 1), (1, 2), (1, 4), (3, 4), (4, 5), (3, 6), (6, 7), (6, 8)]) @@ -639,6 +613,7 @@ def test_no_balanced_cuts_contraction_when_one_side_okay(): tree_nx, tree_rx = \ create_graphs_from_nx_edges( + 5, list_of_nodes_nx, nx_to_rx_node_id_map ) @@ -672,6 +647,7 @@ def test_find_balanced_cuts_memo(): tree_nx, tree_rx = \ create_graphs_from_nx_edges( + 9, list_of_nodes_nx, nx_to_rx_node_id_map ) @@ -712,6 +688,7 @@ def test_no_balanced_cuts_memo_when_one_side_okay(): tree_nx, tree_rx = \ create_graphs_from_nx_edges( + 5, list_of_nodes_nx, nx_to_rx_node_id_map ) diff --git a/tests/updaters/test_perimeters.py b/tests/updaters/test_perimeters.py index f161246c..d98d1f5a 100644 --- a/tests/updaters/test_perimeters.py +++ b/tests/updaters/test_perimeters.py @@ -36,7 +36,7 @@ def test_interior_perimeter_handles_flips_with_a_simple_grid(): def test_cut_edges_by_part_handles_flips_with_a_simple_grid(): - # frm: TODO: Add a graphic here + # frm: TODO: Testing: Add a graphic here # # That will allow the person reading this code to make sense # of what it does... diff --git a/tests/updaters/test_split_scores.py b/tests/updaters/test_split_scores.py index bb0b15e8..27d37504 100644 --- a/tests/updaters/test_split_scores.py +++ b/tests/updaters/test_split_scores.py @@ -6,14 +6,6 @@ from gerrychain import Graph import networkx -# frm: TODO: This test fails due to NX dependencies in locality_split_scores.py -# -# There are lots of comments in that file about what needs to be fixed, but -# it is a low priority becauxe the code in locality_split_scores.py is not used -# in the gerrychain codebase - it is presumeably used by other users of GC, so -# this needs to be fixed sometime - but later... -# - @pytest.fixture def three_by_three_grid(): """Returns a graph that looks like this: @@ -79,8 +71,6 @@ def split_partition(graph_with_counties): ) return partition -# frm: TODO: NX vs. RX node_id issues here. - class TestSplittingScores: diff --git a/tests/updaters/test_splits.py b/tests/updaters/test_splits.py index 81e8b7b2..62f6e395 100644 --- a/tests/updaters/test_splits.py +++ b/tests/updaters/test_splits.py @@ -50,7 +50,7 @@ def test_describes_splits_for_all_counties(self, partition): def test_no_splits(self, graph_with_counties): - # frm: TODO: Why does this not just use "split_partition"? Isn't it the same? + # frm: TODO: Testing: Why does this not just use "split_partition"? Isn't it the same? partition = Partition(graph_with_counties, assignment="county") result = compute_county_splits(partition, "county", None) diff --git a/tests/updaters/test_updaters.py b/tests/updaters/test_updaters.py index 107ec40a..2855cb4d 100644 --- a/tests/updaters/test_updaters.py +++ b/tests/updaters/test_updaters.py @@ -34,19 +34,6 @@ def partition_with_election(graph_with_d_and_r_cols): graph = graph_with_d_and_r_cols assignment = random_assignment(graph, 3) - # Original Code - this was deprecated when we converted to using RustworkX - # based graphs in partitions. The problem was that the node_ids used below - # are NX-based "original" node_ids that are not valid for the derived RustworkX - # based graph. 
- # - # party_names_to_node_attribute_names = { - # "D": {node: graph.node_data(node)["D"] for node in graph.nodes}, - # "R": {node: graph.node_data(node)["R"] for node in graph.nodes}, - # } - # - # Instead we do the functionally equivalent operation which identifies where - # to find vote totals for each party using the attribute name for the party. - # party_names_to_node_attribute_names = ["D", "R"] election = Election("Mock Election", party_names_to_node_attribute_names) @@ -206,7 +193,7 @@ def test_election_result_has_a_cute_str_method(): def _convert_dict_of_set_of_rx_node_ids_to_set_of_nx_node_ids(dict_of_set_of_rx_nodes, nx_to_rx_node_id_map): - # frm: TODO: This way to convert node_ids is clumsy and inconvenient. Think of something better... + # frm: TODO: Testing: This way to convert node_ids is clumsy and inconvenient. Think of something better... # When we create a partition from an NX based Graph we convert it to be an # RX based Graph which changes the node_ids of the graph. If one wants @@ -249,7 +236,7 @@ def test_exterior_boundaries_as_a_set(three_by_three_grid): result = partition["exterior_boundaries_as_a_set"] - # frm: TOdO: Come up with a nice way to convert the result which uses + # frm: TODO: Testing: Come up with a nice way to convert the result which uses # RX based node_ids back to the original NX based node_ids... # If the original graph that the partition was based on was an NX graph @@ -336,7 +323,7 @@ def test_perimeter(three_by_three_grid): graph = three_by_three_grid for i in [0, 1, 2, 3, 5, 6, 7, 8]: graph.node_data(i)["boundary_node"] = True - # frm: TODO: Update test - boundary_perim should be 2 for corner nodes... + # frm: TODO: Testing: Update test - boundary_perim should be 2 for corner nodes... graph.node_data(i)["boundary_perim"] = 1 graph.node_data(4)["boundary_node"] = False
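+    # (Node 4 is the center of the 3x3 grid, so it is the only node that is
+    # not on the boundary.)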