Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 115 additions & 31 deletions debacl/level_set_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from __future__ import print_function as _print_function
from __future__ import absolute_import as _absolute_import

from sys import version

import logging as _logging
import copy as _copy
import pickle as _pickle
Expand Down Expand Up @@ -271,7 +273,12 @@ def plot(self, form='mass', horizontal_spacing='uniform', color_nodes=[],
split_coords = {}

## Find the root connected components and corresponding plot intervals
ix_root = _np.array([k for k, v in self.nodes.iteritems()
if version[0] == '2':
nodes_items = self.nodes.iteritems()
if version[0] == '3':
nodes_items = self.nodes.items()

ix_root = _np.array([k for k, v in nodes_items
if v.parent is None])
n_root = len(ix_root)
census = _np.array([len(self.nodes[x].members) for x in ix_root],
Expand Down Expand Up @@ -322,7 +329,12 @@ def plot(self, form='mass', horizontal_spacing='uniform', color_nodes=[],
ax.yaxis.grid(color='gray')
ax.set_yticks(primary_ticks)
ax.set_yticklabels(primary_labels)


if version[0] == '2':
nodes_values = self.nodes.itervalues()
if version[0] == '3':
nodes_values = self.nodes.values()

## Form-specific details
if form == 'branch-mass':
yrange = max(primary_ticks)
Expand All @@ -331,23 +343,29 @@ def plot(self, form='mass', horizontal_spacing='uniform', color_nodes=[],

elif form == 'density':
ax.set_ylabel("density level")
ymin = min([v.start_level for v in self.nodes.itervalues()])
ymax = max([v.end_level for v in self.nodes.itervalues()])

ymin = min([v.start_level for v in nodes_values])
ymax = max([v.end_level for v in nodes_values])
yrange = ymax - ymin
ax.set_ylim(ymin - gap * yrange, ymax + 0.05 * yrange)

elif form == 'mass':
ax.set_ylabel("mass level")
ymin = min([v.start_mass for v in self.nodes.itervalues()])
ymax = max([v.end_mass for v in self.nodes.itervalues()])
ymin = min([v.start_mass for v in nodes_values])
ymax = max([v.end_mass for v in nodes_values])
yrange = ymax - ymin
ax.set_ylim(ymin - gap * yrange, ymax + 0.05 * yrange)

else:
raise ValueError('Plot form not understood')

## Color the line segments.
node_colors = {k: [0.0, 0.0, 0.0, 1.0] for k, v in self.nodes.items()}
if version[0] == '2':
nodes_items = self.nodes.iteritems()
if version[0] == '3':
nodes_items = self.nodes.items()

node_colors = {k: [0.0, 0.0, 0.0, 1.0] for k, v in nodes_items}
palette = _plt.get_cmap(colormap)
colorset = palette(_np.linspace(0, 1, len(color_nodes)))

Expand Down Expand Up @@ -604,16 +622,32 @@ def _merge_by_size(self, threshold):
tree.prune_threshold = threshold

## remove small root branches
small_roots = [k for k, v in tree.nodes.iteritems()
if version[0] == '2':
tree_nodes_items = tree.nodes.iteritems()
if version[0] == '3':
tree_nodes_items = tree.nodes.items()

small_roots = [k for k, v in tree_nodes_items
if v.parent is None and len(v.members) <= threshold]

for root in small_roots:
root_tree = tree._make_subtree(root)
for ix in root_tree.nodes.iterkeys():

if version[0] == '2':
root_tree_nodes_keys = root_tree.nodes.iterkeys()
if version[0] == '3':
root_tree_nodes_keys = root_tree.nodes.keys()

for ix in root_tree_nodes_keys:
del tree.nodes[ix]

## main pruning
parents = [k for k, v in tree.nodes.iteritems()
if version[0] == '2':
tree_nodes_items = tree.nodes.iteritems()
if version[0] == '3':
tree_nodes_items = tree.nodes.items()

parents = [k for k, v in tree_nodes_items
if len(v.children) >= 1]
parents = _np.sort(parents)[::-1]

Expand All @@ -622,9 +656,14 @@ def _merge_by_size(self, threshold):

# get size of each child
kid_size = {k: len(tree.nodes[k].members) for k in parent.children}

# print(_np.array([(key,val) for (key,val) in kid_size.items()]))

# count children larger than 'threshold'
n_bigkid = sum(_np.array(kid_size.values()) >= threshold)
# print(_np.array(kid_size.values()), threshold)
# n_bigkid = sum(_np.array(kid_size.values()) >= threshold)
# n_bigkid = 0
n_bigkid = sum([len(tree.nodes[k].members) for k in parent.children])
# n_bigkid = sum(_np.array([(key,val) for (key,val) in kid_size.items()]) >= threshold)

if n_bigkid == 0:
# update parent's end level and end mass
Expand All @@ -641,7 +680,12 @@ def _merge_by_size(self, threshold):
elif n_bigkid == 1:
pass
# identify the big kid
ix_bigkid = [k for k, v in kid_size.iteritems()
if version[0] == '2':
kid_size_items = kid_size.iteritems()
if version[0] == '3':
kid_size_items = kid_size.items()

ix_bigkid = [k for k, v in kid_size_items
if v >= threshold][0]
bigkid = tree.nodes[ix_bigkid]

Expand Down Expand Up @@ -786,8 +830,13 @@ def _upper_set_cluster(self, threshold, form='mass'):
form='density')

else:
if version[0] == '2':
nodes_items = self.nodes.iteritems()
if version[0] == '3':
nodes_items = self.nodes.items()

upper_level_set = _np.where(_np.array(self.density) > threshold)[0]
active_nodes = [k for k, v in self.nodes.iteritems()
active_nodes = [k for k, v in nodes_items
if (v.start_level <= threshold and
v.end_level > threshold)]

Expand Down Expand Up @@ -835,7 +884,7 @@ def _first_K_level_cluster(self, k):
"""

cut = self._find_K_cut(k)
nodes = [e for e, v in self.nodes.iteritems()
nodes = [e for e, v in self.nodes.items()
if v.start_level <= cut and v.end_level > cut]

points = []
Expand Down Expand Up @@ -896,25 +945,39 @@ def _find_K_cut(self, k):
"""

## Find the lowest level to cut at that has k or more clusters
starts = [v.start_level for v in self.nodes.itervalues()]
ends = [v.end_level for v in self.nodes.itervalues()]
if version[0] == '2':
nodes_values = self.nodes.itervalues()
nodes_items = self.nodes.iteritems()
if version[0] == '3':
nodes_values = self.nodes.values()
nodes_items = self.nodes.items()

starts = [v.start_level for v in nodes_values]
ends = [v.end_level for v in nodes_values]
crits = _np.unique(starts + ends)
nclust = {}

for c in crits:
nclust[c] = len([e for e, v in self.nodes.iteritems()
nclust[c] = len([e for e, v in nodes_items
if v.start_level <= c and v.end_level > c])

width = _np.max(nclust.values())


if version[0] == '2':
nclust_items = nclust.iteritems()
nclust_values = nclust.itervalues()
if version[0] == '3':
nclust_items = nclust.items()
nclust_values = nclust.values()

if k in nclust.values():
cut = _np.min([e for e, v in nclust.iteritems() if v == k])
cut = _np.min([e for e, v in nclust_items if v == k])
else:
if width < k:
cut = _np.min([e for e, v in nclust.iteritems() if v == width])
cut = _np.min([e for e, v in nclust_items if v == width])
else:
ktemp = _np.min([v for v in nclust.itervalues() if v > k])
cut = _np.min([e for e, v in nclust.iteritems() if v == ktemp])
ktemp = _np.min([v for v in nclust_values if v > k])
cut = _np.min([e for e, v in nclust_items if v == ktemp])

return cut

Expand Down Expand Up @@ -1039,8 +1102,13 @@ def _construct_branch_map(self, ix, interval, form, horizontal_spacing,

segmap += branch_segmap
splitmap += branch_splitmap
splits = dict(splits.items() + branch_splits.items())
segments = dict(segments.items() + branch_segs.items())
# d = dict1.copy()
# d.update(dict2)
#splits = dict(splits.items() + branch_splits.items())
splits.update(branch_splits.items())

segments.update(branch_segs.items())
# segments = dict(segments.items() + branch_segs.items())

## find the middle of the children's x-position and make vertical
# segment ix
Expand Down Expand Up @@ -1183,8 +1251,13 @@ def _construct_mass_map(self, ix, start_pile, interval,

segmap += branch_segmap
splitmap += branch_splitmap
splits = dict(splits.items() + branch_splits.items())
segments = dict(segments.items() + branch_segs.items())
# d = dict1.copy()
# d.update(dict2)
# splits = dict(splits.items() + branch_splits.items())
splits.update(branch_splits.items())

segments.update(branch_segs.items())
# segments = dict(segments.items() + branch_segs.items())

## find the middle of the children's x-position and make vertical
## segment ix
Expand Down Expand Up @@ -1379,15 +1452,26 @@ def construct_tree_from_graph(adjacency_list, density, prune_threshold=None,
previous_level = level

## compute the mass after the current bg set is removed
if version[0] == '2':
T_subGraph_values = T._subgraphs.itervalues()
if version[0] == '3':
T_subGraph_values = T._subgraphs.values()

old_vcount = sum([x.number_of_nodes()
for x in T._subgraphs.itervalues()])
for x in T_subGraph_values])

current_mass = 1. - ((old_vcount - len(bg)) / n)

# loop through active components, i.e. subgraphs
deactivate_keys = [] # subgraphs to deactivate at the iter end
activate_subgraphs = {} # new subgraphs to add at the end of the iter

for (k, H) in T._subgraphs.iteritems():

if version[0] == '2':
T_subGraph_items = T._subgraphs.iteritems()
if version[0] == '3':
T_subGraph_items = T._subgraphs.items()

for (k, H) in T_subGraph_items:

## remove nodes at the current level
H.remove_nodes_from(bg)
Expand Down