diff --git a/debacl/level_set_tree.py b/debacl/level_set_tree.py index 99cbcba..782d252 100644 --- a/debacl/level_set_tree.py +++ b/debacl/level_set_tree.py @@ -8,6 +8,8 @@ from __future__ import print_function as _print_function from __future__ import absolute_import as _absolute_import +from sys import version + import logging as _logging import copy as _copy import pickle as _pickle @@ -271,7 +273,12 @@ def plot(self, form='mass', horizontal_spacing='uniform', color_nodes=[], split_coords = {} ## Find the root connected components and corresponding plot intervals - ix_root = _np.array([k for k, v in self.nodes.iteritems() + if version[0] == '2': + nodes_items = self.nodes.iteritems() + if version[0] == '3': + nodes_items = self.nodes.items() + + ix_root = _np.array([k for k, v in nodes_items if v.parent is None]) n_root = len(ix_root) census = _np.array([len(self.nodes[x].members) for x in ix_root], @@ -322,7 +329,12 @@ def plot(self, form='mass', horizontal_spacing='uniform', color_nodes=[], ax.yaxis.grid(color='gray') ax.set_yticks(primary_ticks) ax.set_yticklabels(primary_labels) - + + if version[0] == '2': + nodes_values = self.nodes.itervalues() + if version[0] == '3': + nodes_values = self.nodes.values() + ## Form-specific details if form == 'branch-mass': yrange = max(primary_ticks) @@ -331,15 +343,16 @@ def plot(self, form='mass', horizontal_spacing='uniform', color_nodes=[], elif form == 'density': ax.set_ylabel("density level") - ymin = min([v.start_level for v in self.nodes.itervalues()]) - ymax = max([v.end_level for v in self.nodes.itervalues()]) + + ymin = min([v.start_level for v in nodes_values]) + ymax = max([v.end_level for v in nodes_values]) yrange = ymax - ymin ax.set_ylim(ymin - gap * yrange, ymax + 0.05 * yrange) elif form == 'mass': ax.set_ylabel("mass level") - ymin = min([v.start_mass for v in self.nodes.itervalues()]) - ymax = max([v.end_mass for v in self.nodes.itervalues()]) + ymin = min([v.start_mass for v in nodes_values]) + ymax = max([v.end_mass for v in nodes_values]) yrange = ymax - ymin ax.set_ylim(ymin - gap * yrange, ymax + 0.05 * yrange) @@ -347,7 +360,12 @@ def plot(self, form='mass', horizontal_spacing='uniform', color_nodes=[], raise ValueError('Plot form not understood') ## Color the line segments. - node_colors = {k: [0.0, 0.0, 0.0, 1.0] for k, v in self.nodes.items()} + if version[0] == '2': + nodes_items = self.nodes.iteritems() + if version[0] == '3': + nodes_items = self.nodes.items() + + node_colors = {k: [0.0, 0.0, 0.0, 1.0] for k, v in nodes_items} palette = _plt.get_cmap(colormap) colorset = palette(_np.linspace(0, 1, len(color_nodes))) @@ -604,16 +622,32 @@ def _merge_by_size(self, threshold): tree.prune_threshold = threshold ## remove small root branches - small_roots = [k for k, v in tree.nodes.iteritems() + if version[0] == '2': + tree_nodes_items = tree.nodes.iteritems() + if version[0] == '3': + tree_nodes_items = tree.nodes.items() + + small_roots = [k for k, v in tree_nodes_items if v.parent is None and len(v.members) <= threshold] - + for root in small_roots: root_tree = tree._make_subtree(root) - for ix in root_tree.nodes.iterkeys(): + + if version[0] == '2': + root_tree_nodes_keys = root_tree.nodes.iterkeys() + if version[0] == '3': + root_tree_nodes_keys = root_tree.nodes.keys() + + for ix in root_tree_nodes_keys: del tree.nodes[ix] ## main pruning - parents = [k for k, v in tree.nodes.iteritems() + if version[0] == '2': + tree_nodes_items = tree.nodes.iteritems() + if version[0] == '3': + tree_nodes_items = tree.nodes.items() + + parents = [k for k, v in tree_nodes_items if len(v.children) >= 1] parents = _np.sort(parents)[::-1] @@ -622,9 +656,14 @@ def _merge_by_size(self, threshold): # get size of each child kid_size = {k: len(tree.nodes[k].members) for k in parent.children} - + # print(_np.array([(key,val) for (key,val) in kid_size.items()])) + # count children larger than 'threshold' - n_bigkid = sum(_np.array(kid_size.values()) >= threshold) + # print(_np.array(kid_size.values()), threshold) + # n_bigkid = sum(_np.array(kid_size.values()) >= threshold) + # n_bigkid = 0 + n_bigkid = sum([len(tree.nodes[k].members) for k in parent.children]) + # n_bigkid = sum(_np.array([(key,val) for (key,val) in kid_size.items()]) >= threshold) if n_bigkid == 0: # update parent's end level and end mass @@ -641,7 +680,12 @@ def _merge_by_size(self, threshold): elif n_bigkid == 1: pass # identify the big kid - ix_bigkid = [k for k, v in kid_size.iteritems() + if version[0] == '2': + kid_size_items = kid_size.iteritems() + if version[0] == '3': + kid_size_items = kid_size.items() + + ix_bigkid = [k for k, v in kid_size_items if v >= threshold][0] bigkid = tree.nodes[ix_bigkid] @@ -786,8 +830,13 @@ def _upper_set_cluster(self, threshold, form='mass'): form='density') else: + if version[0] == '2': + nodes_items = self.nodes.iteritems() + if version[0] == '3': + nodes_items = self.nodes.items() + upper_level_set = _np.where(_np.array(self.density) > threshold)[0] - active_nodes = [k for k, v in self.nodes.iteritems() + active_nodes = [k for k, v in nodes_items if (v.start_level <= threshold and v.end_level > threshold)] @@ -835,7 +884,7 @@ def _first_K_level_cluster(self, k): """ cut = self._find_K_cut(k) - nodes = [e for e, v in self.nodes.iteritems() + nodes = [e for e, v in self.nodesiteritems() if v.start_level <= cut and v.end_level > cut] points = [] @@ -896,25 +945,39 @@ def _find_K_cut(self, k): """ ## Find the lowest level to cut at that has k or more clusters - starts = [v.start_level for v in self.nodes.itervalues()] - ends = [v.end_level for v in self.nodes.itervalues()] + if version[0] == '2': + nodes_values = self.nodes.itervalues() + nodes_items = self.nodes.iteritems() + if version[0] == '3': + nodes_values = self.nodes.values() + nodes_values = self.nodes.items() + + starts = [v.start_level for v in nodes_values] + ends = [v.end_level for v in nodes_values] crits = _np.unique(starts + ends) nclust = {} for c in crits: - nclust[c] = len([e for e, v in self.nodes.iteritems() + nclust[c] = len([e for e, v in nodes_items if v.start_level <= c and v.end_level > c]) width = _np.max(nclust.values()) - + + if version[0] == '2': + nclust_items = nclust.iteritems() + nclust_values = nclust.itervalues() + if version[0] == '3': + nclust_items = nclust.items() + nclust_values = nclust.values() + if k in nclust.values(): - cut = _np.min([e for e, v in nclust.iteritems() if v == k]) + cut = _np.min([e for e, v in nclust_items if v == k]) else: if width < k: - cut = _np.min([e for e, v in nclust.iteritems() if v == width]) + cut = _np.min([e for e, v in nclust_items if v == width]) else: - ktemp = _np.min([v for v in nclust.itervalues() if v > k]) - cut = _np.min([e for e, v in nclust.iteritems() if v == ktemp]) + ktemp = _np.min([v for v in nclust_values if v > k]) + cut = _np.min([e for e, v in nclust_items if v == ktemp]) return cut @@ -1039,8 +1102,13 @@ def _construct_branch_map(self, ix, interval, form, horizontal_spacing, segmap += branch_segmap splitmap += branch_splitmap - splits = dict(splits.items() + branch_splits.items()) - segments = dict(segments.items() + branch_segs.items()) + # d = dict1.copy() + # d.update(dict2) + #splits = dict(splits.items() + branch_splits.items()) + splits.update(branch_splits.items()) + + segments.update(branch_segs.items()) + # segments = dict(segments.items() + branch_segs.items()) ## find the middle of the children's x-position and make vertical # segment ix @@ -1183,8 +1251,13 @@ def _construct_mass_map(self, ix, start_pile, interval, segmap += branch_segmap splitmap += branch_splitmap - splits = dict(splits.items() + branch_splits.items()) - segments = dict(segments.items() + branch_segs.items()) + # d = dict1.copy() + # d.update(dict2) + # splits = dict(splits.items() + branch_splits.items()) + splits.update(branch_splits.items()) + + segments.update(branch_segs.items()) + # segments = dict(segments.items() + branch_segs.items()) ## find the middle of the children's x-position and make vertical ## segment ix @@ -1379,15 +1452,26 @@ def construct_tree_from_graph(adjacency_list, density, prune_threshold=None, previous_level = level ## compute the mass after the current bg set is removed + if version[0] == '2': + T_subGraph_values = T._subgraphs.itervalues() + if version[0] == '3': + T_subGraph_values = T._subgraphs.values() + old_vcount = sum([x.number_of_nodes() - for x in T._subgraphs.itervalues()]) + for x in T_subGraph_values]) + current_mass = 1. - ((old_vcount - len(bg)) / n) # loop through active components, i.e. subgraphs deactivate_keys = [] # subgraphs to deactivate at the iter end activate_subgraphs = {} # new subgraphs to add at the end of the iter - - for (k, H) in T._subgraphs.iteritems(): + + if version[0] == '2': + T_subGraph_items = T._subgraphs.iteritems() + if version[0] == '3': + T_subGraph_items = T._subgraphs.items() + + for (k, H) in T_subGraph_items: ## remove nodes at the current level H.remove_nodes_from(bg)