Skip to content

Commit

Permalink
improve modularity
Browse files Browse the repository at this point in the history
  • Loading branch information
oaoni committed Aug 19, 2022
1 parent 1f90a76 commit 61463a5
Show file tree
Hide file tree
Showing 8 changed files with 433 additions and 54 deletions.
1 change: 1 addition & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .models import ActiveExplore
from .apps import ExploreML
from .data import loadActiveH5
1 change: 1 addition & 0 deletions apps/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .explore_predict import ExploreML
47 changes: 47 additions & 0 deletions apps/explore_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from exploreML.models import ActiveExplore
from exploreML.models import PredictMap
from bokeh.layouts import column, row

class ExploreML:


def __init__(self,grid_shape=(1,1),**kwargs):

self.grid_shape = grid_shape
self.layouts = []
#maybe map_p_high/low


def addActiveExplore(self,data_dict,sampling_dict,**kwargs):

active = ActiveExplore(data_dict, sampling_dict, **kwargs)

self.layouts += [active.layout]
self.toolbar = active.toolbar
self.clust_dict = active.clust_dict # Index clusters for heatmap
self.clust_methods = active.clust_methods
self.upper_source = active.upper_source
self.upper_dict = active.upper_dict
self.row_name = active.row_name
self.col_name = active.col_name
self.sample_sliders = active.sample_sliders
self.plot_size = active.plot_size
self.radio_button_group = active.radio_button_group
self.toggle = active.toggle
self.data_toggle = active.data_toggle
self.active_dim = active.active_dim


def addPredictMap(self,pred_df,map_dict,**kwargs):

predict = PredictMap(pred_df, map_dict, toolbar=self.toolbar,
clust_dict=self.clust_dict,clust_methods=self.clust_methods,
upper_source=self.upper_source,upper_dict=self.upper_dict,
row_name=self.row_name,col_name=self.col_name,
radio_button_group=self.radio_button_group,toggle=self.toggle,
sample_sliders=self.sample_sliders,plot_size=self.plot_size,
data_toggle=self.data_toggle,active_dim=self.active_dim,**kwargs)
self.layouts += [predict.layout]

def Layout(self):
self.layout = row(*self.layouts)
1 change: 1 addition & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .active_explore import ActiveExplore
from .predict_map import PredictMap
91 changes: 41 additions & 50 deletions models/active_explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import pandas as pd
import seaborn as sns
import math
import os
from bokeh.io import show, save
from bokeh.plotting import Figure, output_file, show, output_notebook
from bokeh.layouts import column, row
from bokeh.models import ColorBar, LinearColorMapper, BasicTicker, CustomJS, ColumnDataSource,\
Toggle, Slider, RadioButtonGroup, Select, Legend, ColorPicker, Panel, Tabs, RangeSlider,HoverTool
from bokeh.models.widgets import Div
from bokeh.palettes import all_palettes
from scipy.cluster.hierarchy import linkage, dendrogram
from exploreML.models.custom_tools import ResetTool
import itertools

with open('exploreML/models/active_explore_js/slider_callback.js','r') as f:
slider_callback_js = f.read()
Expand Down Expand Up @@ -45,15 +45,15 @@ def __init__(self, data_dict, sampling_dict, is_sym=False, num_line_plots=2,
row_name='dim1', col_name='dim2', val_name='entry_value',hide_heatmap_labels=False,
heatmap_colors = 'default', n_colors=10, inds_colors=[1,2,4,5,6,7,8],
color_palette='Category10',plot_size=700, line_width=500, line_height=300,
name='active explorer', url='active_explorer.html', plot_location='below'):
name='active explorer', url='active_explorer.html', plot_location='below',
file_output=True):

self.name = name
self.numLinePlots = num_line_plots
self.is_sym = is_sym

# Output file
output_file(filename=url, title=name)
# output_notebook()
self.row_name = row_name
self.col_name = col_name
self.plot_size = plot_size

# Load data
M = data_dict['M']
Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(self, data_dict, sampling_dict, is_sym=False, num_line_plots=2,
sampler_color = dict(zip(sampling_names, sampler_palette[:n_samplers]))

active_dim = sampling_dfs[0].shape[0]
self.active_dim = active_dim
if self.is_sym:
sampling_methods = [self._makeSymAL(sampler, row_coord, col_coord, active_x, batch_col)\
for sampler in sampling_dfs]
Expand All @@ -94,6 +95,7 @@ def __init__(self, data_dict, sampling_dict, is_sym=False, num_line_plots=2,
# Precompute the indices for the different clustering methods
clust_dict = {method:self._makeClustIndex(M, method) for method in clust_methods}
self.clust_dict = clust_dict
self.clust_methods = clust_methods

# Active learning data
sampling_data = {sample:self._addIndexCols(df, df_index, df_cols, meta_vars)\
Expand All @@ -108,6 +110,7 @@ def __init__(self, data_dict, sampling_dict, is_sym=False, num_line_plots=2,
# Column source data of mask for each clustering method
upper_dict = {method:self._makeGImatrix(self._makeMaskUpper(M.iloc[clust_dict[method][0],clust_dict[method][1]]),meta_vars)\
.dropna(axis=0).to_dict(orient='list') for method in clust_methods}
self.upper_dict = upper_dict

# Collecting the quantitative columns from the sampling data
quant_options = sampling_methods[0].describe().T.query('std > 0').index.to_list()
Expand All @@ -121,6 +124,7 @@ def __init__(self, data_dict, sampling_dict, is_sym=False, num_line_plots=2,
for sampler,data in sampling_data.items()}

upper_source = ColumnDataSource(data=upper_dict[init_clust])
self.upper_source = upper_source

# Initialize sources
# GI column source data
Expand Down Expand Up @@ -235,19 +239,14 @@ def __init__(self, data_dict, sampling_dict, is_sym=False, num_line_plots=2,
with open('exploreML/models/active_explore_js/radio_call.js','r') as f:
radio_call_js = f.read()

radio_call = CustomJS(args=dict(methods=clust_methods,clust_dict=clust_dict,plot=p,
up_source=upper_source, up_dict=upper_dict),code=radio_call_js)

sample_sliders = {sampler:Slider(start=0, end=active_dim, value=0, step=1,\
# TODO: Make step a variable
sample_sliders = {sampler:Slider(start=0, end=active_dim, value=0, step=120,\
title=title,max_width=300)\
for sampler,title in zip(sampling_names,sampling_titles)}

radio_button_group = RadioButtonGroup(labels=[x.capitalize() for x in clust_methods], active=0,max_width=300)

slider_js = {sampler:sample_sliders[sampler].js_on_change('value', self._slider_callback(sampler, active_sources, sampling_sources, symMult))\
for sampler in sampling_names}

radio_button_group.js_on_click(radio_call)
self.sample_sliders = sample_sliders

toggle = Toggle(label="Lower Triangle (Toggle)", button_type="primary",max_width=300)
toggle.js_on_click(CustomJS(args=dict(plot=upperFig),code="""
Expand All @@ -256,6 +255,7 @@ def __init__(self, data_dict, sampling_dict, is_sym=False, num_line_plots=2,
plot.visible = cb_obj.active;
"""))
self.toggle = toggle

#default, primary, success, warning, danger, light
#Toggle to show and hide training examples
Expand All @@ -270,6 +270,7 @@ def __init__(self, data_dict, sampling_dict, is_sym=False, num_line_plots=2,
data_toggle_js = f.read()

data_toggle = Toggle(label="Show All (Toggle)", button_type="primary",max_width=300)
self.data_toggle = data_toggle
data_toggle.js_on_click(CustomJS(args=dict(sliders=sample_sliders, active_dim=\
active_dim),
code=data_toggle_js))
Expand All @@ -282,19 +283,32 @@ def __init__(self, data_dict, sampling_dict, is_sym=False, num_line_plots=2,
train_picker = ColorPicker(width=50,color='white')
train_picker.js_link('color', trainFig.glyph, 'fill_color')

sliders = column(list(sample_sliders.values()))
sliders = column(list(sample_sliders.values()),name='slider_col')

linePlots = [column(*line) for line in list(zip(line_selects,line_tabs))]

heatmap_layout = row(column(radio_button_group,toggle,sliders,data_toggle,
row(train_toggle,train_picker),range_slider,select_colorbar,width=340),
p)
radio_call = CustomJS(args=dict(methods=clust_methods,clust_dict=clust_dict,plot=p,
up_source=upper_source, up_dict=upper_dict),code=radio_call_js)

radio_button_group = RadioButtonGroup(labels=[x.capitalize() for x in clust_methods], active=0, max_width=300)
radio_button_group.js_on_click(radio_call)
self.radio_button_group = radio_button_group

toolbar = column(radio_button_group,toggle,sliders,data_toggle,
row(train_toggle,train_picker),
range_slider,select_colorbar,
width=310,name='toolbar', margin=(0,10,0,10))
heatmap_layout = row(toolbar,p)

layout = self._make_layout(heatmap_layout,linePlots,plot_location)

self.toolbar = toolbar
self.layout = layout

layout = self._make_layout(heatmap_layout,linePlots, plot_location)
if file_output:

# show(layout)
# self.layout = layout
save(layout)
output_file(filename=url, title=name)
save(layout)

def _make_layout(self, heatmap_layout, linePlots, plot_location):

Expand All @@ -313,7 +327,8 @@ def _make_layout(self, heatmap_layout, linePlots, plot_location):
layout = column(line_layout, heatmap_layout)

elif plot_location == 'below': # Below
layout = column(heatmap_layout,line_layout)
# layout = column(heatmap_layout,Div(height=80),line_layout)
layout = column(heatmap_layout,Div(height=20),line_layout)

return layout

Expand All @@ -331,32 +346,8 @@ def _line_callback(self, Fig, Fig2, Figs1, Figs2, samplerCol_meta):
plot2=Figs2,
col_meta=samplerCol_meta,
yaxis=Fig.yaxis[0],
yaxis2=Fig2.yaxis[0]), code="""
console.log('select: value=' + this.value, this.toString())
var select = cb_obj.value;
const keys = Object.keys(plot);
var keysLength = keys.length;
for (var i = 0; i < keysLength; i++) {
plot[keys[i]].glyph.y.field = select;
plot[keys[i]].data_source.change.emit();
plot2[keys[i]].glyph.y.field = select;
plot2[keys[i]].data_source.change.emit();
}
var mx = col_meta[select].max + col_meta[select].max * 0.05;
var mn = col_meta[select].min - col_meta[select].max * 0.05;
fig.y_range.end = mx;
fig.y_range.start = mn;
yaxis.axis_label = select;
fig2.y_range.end = mx;
fig2.y_range.start = mn;
yaxis2.axis_label = select;
""")
yaxis2=Fig2.yaxis[0]),
code=line_callback_js)

return callback

Expand Down
11 changes: 7 additions & 4 deletions models/active_explore_js/line_callback.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ for (var i = 0; i < keysLength; i++) {
plot2[keys[i]].data_source.change.emit();
}

fig.y_range.end = col_meta[select].max;
fig.y_range.start = col_meta[select].min;
var mx = col_meta[select].max + col_meta[select].max * 0.05;
var mn = col_meta[select].min - col_meta[select].max * 0.05;

fig.y_range.end = mx;
fig.y_range.start = mn;
yaxis.axis_label = select;

fig2.y_range.end = col_meta[select].max;
fig2.y_range.start = col_meta[select].min;
fig2.y_range.end = mx;
fig2.y_range.start = mn;
yaxis2.axis_label = select;
43 changes: 43 additions & 0 deletions models/active_explore_js/radio_call_plots.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
console.log('radio_button_group: active=' + this.active, this.toString())

//Store radioVal for clustering methods
let radioVal = cb_obj.active;

let method = methods[radioVal];
let x_range = clust_dict[method][3];
let y_range = [].concat(clust_dict[method][2]).reverse();

plot1.x_range.factors = x_range;
plot1.y_range.factors = y_range;

plot2.x_range.factors = x_range;
plot2.y_range.factors = y_range;

let cdsLists = up_dict[method];

//Reorder mask
//Assign mask source to variable
var data = up_source.data;
var keys = Object.keys(data);
var keysLength = keys.length;

// Reorder mask
var mask_map = new Map();
var new_map = new Map();
for (var i = 0; i < keysLength; i++) {

//Assign mask source to map variable
mask_map.set(keys[i], data[keys[i]]);

// Reassign CDS
new_map.set(keys[i], cdsLists[keys[i]]);

// Clear old mask CDS
mask_map.get(keys[i]).splice(0, mask_map.get(keys[i]).length);

// Add new mask cds
mask_map.set(keys[i], mask_map.get(keys[i]).push(...new_map.get(keys[i])));

}

up_source.change.emit();
Loading

0 comments on commit 61463a5

Please sign in to comment.