Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 64 additions & 24 deletions gramex/apps/mlhandler/template.html
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
'DecisionTreeRegressor',
'RandomForestRegressor',
'MLPRegressor'] %}
{% set GRAINSIGHT_MODELS = ['TopCause'] %}
{% set tcol = handler.get_opt('target_col', False) %}
{% set CLASSIFICTION_METRICS = {
'Accuracy': 'accuracy',
Expand Down Expand Up @@ -121,7 +122,7 @@ <h3 class="text-center">Train the Model</h3>
<div class="row pb-3 pt-3">
<div class="col">
<label for="cats">Categorical Columns:</label>
<select id="cats" class="selectpicker form-control" multiple name="cats">
<select id="cats" class="selectpicker form-control tcdisable" multiple name="cats">
{% for col in columns %}
{% set selected = "selected" if col in handler.get_opt('cats', []) else "" %}
<option value="{{ col }}" {{ selected }}>{{ col }}</option>
Expand All @@ -130,7 +131,7 @@ <h3 class="text-center">Train the Model</h3>
</div>
<div class="col">
<label for="nums">Numerical Columns:</label>
<select id="nums" class="selectpicker form-control" multiple name="nums">
<select id="nums" class="selectpicker form-control tcdisable" multiple name="nums">
{% for col in columns %}
{% set selected = "selected" if col in handler.get_opt('nums', []) else "" %}
<option value="{{ col }}" {{ selected }}>{{ col }}</option>
Expand All @@ -141,12 +142,12 @@ <h3 class="text-center">Train the Model</h3>
<div class="row pb-3 pt-3">
<div class="col">
<label for="transform">Transform:</label>
<input class="form-control" id="transform" name="data.transform" type="text"
<input class="form-control tcdisable" id="transform" name="data.transform" type="text"
value="{{ handler.get_opt('transform', '') }}">
</div>
<div class="col">
<label for="metric">Choose a Metric:</label>
<select id="metric" class="form-control selectpicker" name="metric">
<select id="metric" class="form-control selectpicker tcdisable" name="metric">
{% if handler.get_opt('class') in CLASSIFICTION_MODELS %}
{% for i, (mname, metric) in enumerate(CLASSIFICTION_METRICS.items()) %}
{% set selected = "selected" if metric == "accuracy" else "" %}
Expand Down Expand Up @@ -181,6 +182,12 @@ <h3 class="text-center">Train the Model</h3>
<option value="{{ model }}" {{ selected }}>{{ model }}</option>
{% end %}
</optgroup>
<optgroup label="GrainSight">
{% for model in GRAINSIGHT_MODELS %}
{% set selected = "selected" if model == handler.get_opt('class') else "" %}
<option value="{{ model }}" {{ selected }}>{{ model }}</option>
{% end %}
</optgroup>
</select>
</div>
</div>
Expand All @@ -190,6 +197,7 @@ <h3 class="text-center">Train the Model</h3>
</div>
</form>
<div class="text-center divider">Results</div>
<div id="tcresult" class="container"></div>
<div class="container" id="resultcnt">
<div class="row">
<div class="col">
Expand Down Expand Up @@ -234,7 +242,7 @@ <h3 class="text-center">Train the Model</h3>
<h3 class="text-center">Make Predictions</h3>
<div class="row py-2">
<div class="col">
<button type="submit" form="predictform" class="btn btn-primary">Predict</button>
<button type="submit" form="predictform" class="btn btn-primary tcdisable">Predict</button>
</div>
<div class="col">
<h4 id="predResult"></h4>
Expand All @@ -246,7 +254,7 @@ <h4 id="predResult"></h4>
<% COLS.forEach(function(col) { %>
<div class="form-group row">
<label for="<%= col.name %>" class="col-md-6"><%= col.name %></label>
<input class="form-control col-md-6" type="<%= col.type %>" name="<%= col.name %>" value="<%=row[col.name]%>">
<input class="form-control col-md-6 tcdisable" type="<%= col.type %>" name="<%= col.name %>" value="<%=row[col.name]%>">
</div>
<% }) %>
</div>
Expand Down Expand Up @@ -276,34 +284,64 @@ <h4 id="predResult"></h4>
if (s > 90) { color = '#00f700' }
return color
}
const get_score = function() {
let url = g1.url.parse(window.location)
$.ajax({
url: url + '?_action=score&_metric=' + encodeURIComponent($('#metric').val()),
method: 'POST',
success: function(resp) {
let score = Number.parseFloat(JSON.parse(resp).score * 100).toPrecision(4)
score = Number.parseFloat(score)
$('.donut-segment').attr('stroke-dasharray', `${score} ${100 - score}`)
$('tspan').html(`${score}%`)
$('.donut-segment').attr('stroke', get_score_color(score))
$('#resultcnt').show()
let inputcols = fh_meta.meta.columns.filter((col) => !$('#exclude').val().concat($('#targetcol').val()).includes(col.name))
$('#predicttabtemplate').template({COLS: inputcols.map(e => ({name: e.name, type: e.type})), row: fh_meta.formdata[0]})
}
})
/**
 * Display the outcome of a training request.
 *
 * Used as the `success` callback of the retrain request issued by `post_train`.
 *
 * - When the selected model is `TopCause`, the `result_` attribute of the
 *   retrain response is rendered as a table inside `#tcresult`; no score is
 *   requested.
 * - Otherwise `#tcresult` is hidden, the score for the metric chosen in
 *   `#metric` is requested with a POST (`_action=score`), the donut chart is
 *   updated with the score as a percentage (4 significant digits), the result
 *   container is shown and the prediction form is re-rendered.
 *
 * @param {Object} resp - Response of the retrain request. Only `resp.result_`
 *   is read, and only for TopCause.
 */
const get_score = function(resp) {
  if ($('#modelchoice').val() == 'TopCause') {
    $('#tcresult').show()
    $('#tcresult').formhandler({data: resp.result_})
    return
  }

  $('#tcresult').hide()
  let url = g1.url.parse(window.location)
  // Drop the fragment before appending the query string, as the other
  // requests on this page do. Otherwise a location such as `/model#tab`
  // yields `/model#tab?_action=score...` and the query never reaches the server.
  url.hash = ''
  $.ajax({
    url: url + '?_action=score&_metric=' + encodeURIComponent($('#metric').val()),
    method: 'POST',
    success: function(scoreResp) {
      // Percentage with 4 significant digits, converted back to a number
      // so that trailing zeros are dropped when it is displayed.
      let score = (scoreResp.score * 100).toPrecision(4)
      score = Number.parseFloat(score)
      $('.donut-segment').attr('stroke-dasharray', `${score} ${100 - score}`)
      $('tspan').html(`${score}%`)
      $('.donut-segment').attr('stroke', get_score_color(score))
      $('#resultcnt').show()
      // Prediction inputs: every column except the excluded ones and the target.
      let inputcols = fh_meta.meta.columns.filter((col) => !$('#exclude').val().concat($('#targetcol').val()).includes(col.name))
      $('#predicttabtemplate')
        // Unbind first: this callback runs after every training, and binding
        // again each time would make modifyModelChoice fire once per past run.
        .off('template', modifyModelChoice)
        .on('template', modifyModelChoice)
        .template({COLS: inputcols.map(e => ({name: e.name, type: e.type})), row: fh_meta.formdata[0]})
    }
  })
}
/**
 * Ask the server to retrain the model and hand the response to `get_score`.
 *
 * Hides the result container, then POSTs to the current page URL (fragment
 * removed) with `_action=retrain` and the chosen target column. The metric
 * selected in `#metric` is sent as `_metric` for every model except
 * `TopCause`.
 *
 * @param {string} target_col - Name of the column to train against.
 */
const post_train = function (target_col) {
  $('#resultcnt').hide()
  let url = g1.url.parse(window.location)
  url.hash = ''
  url = url + '?_action=retrain&target_col=' + encodeURIComponent(target_col)
  if ($('#modelchoice').val() != 'TopCause') {
    url = url + '&_metric=' + encodeURIComponent($('#metric').val())
  }
  // A single `url` key: the request URL is exactly the one assembled above.
  $.ajax({
    url: url,
    method: 'POST',
    success: get_score
  })
}
/**
 * Bring the page in line with the model currently selected in `#modelchoice`.
 *
 * Every control carrying the `tcdisable` class is disabled while `TopCause`
 * is selected and enabled otherwise, and the select pickers are refreshed to
 * reflect that. For `TopCause`, the model parameters are fetched from the
 * current page URL (`?_params`, fragment removed) and the `result_` attribute
 * is rendered inside `#tcresult`; for any other model `#tcresult` is hidden.
 */
const modifyModelChoice = function() {
  const topCauseChosen = $('#modelchoice').val() == 'TopCause'

  $('.tcdisable').attr('disabled', topCauseChosen)
  $('.selectpicker').selectpicker('refresh')

  if (!topCauseChosen) {
    $('#tcresult').hide()
    return
  }

  const paramsUrl = g1.url.parse(window.location)
  paramsUrl.hash = ''
  $.getJSON(paramsUrl + '?_params').done(function(result) {
    const table = $('#tcresult')
    table.show()
    table.formhandler({data: result.attrs.result_})
  })
}
$(document).ready(function() {
// If TopCause is selected, enable only some inputs
$('#modelchoice').change(modifyModelChoice)
$('#modelchoice').trigger('change')

$('#downloadbtn').hide()
$('#resultcnt').hide()
let url = g1.url.parse(window.location)
Expand All @@ -312,7 +350,9 @@ <h4 id="predResult"></h4>
$('.formhandler').on('load', function(obj) {
fh_meta = obj
let inputcols = obj.meta.columns.filter((col) => !$('#exclude').val().concat($('#targetcol').val()).includes(col.name))
$('#predicttabtemplate').template({COLS: inputcols.map(e => ({name: e.name, type: e.type})), row: obj.formdata[0]})
$('#predicttabtemplate')
.on('template', modifyModelChoice)
.template({COLS: inputcols.map(e => ({name: e.name, type: e.type})), row: obj.formdata[0]})
}).formhandler({
pageSize: 5,
export: false
Expand Down
23 changes: 16 additions & 7 deletions gramex/handlers/mlhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from gramex.handlers import FormHandler
from gramex.http import NOT_FOUND, BAD_REQUEST
from gramex.install import _mkdir, safe_rmtree
from gramex.ml import TopCause
from gramex import cache
import joblib
import pandas as pd
Expand Down Expand Up @@ -224,12 +225,19 @@ def _filterrows(cls, data, **kwargs):

@classmethod
def _assemble_pipeline(cls, data, force=False, mclass='', params=None):
if params is None:
params = {}
# If the model exists, return it
if op.exists(cls.model_path) and not force:
return joblib.load(cls.model_path)

model_kwargs = cls.config_store.load('model', {})
if not mclass:
mclass = model_kwargs.get('class', False)

# No pipeline for TopCause
# If preprocessing is not enabled, return the root model
if not cls.get_opt('pipeline', True):
if mclass == 'TopCause' or not cls.get_opt('pipeline', True):
return search_modelclass(mclass)(**params)

# Else assemble the preprocessing pipeline
Expand All @@ -252,8 +260,6 @@ def _assemble_pipeline(cls, data, force=False, mclass='', params=None):
[('ohe', OneHotEncoder(sparse=False), categoricals),
('scaler', StandardScaler(), numericals)]
)
model_kwargs = cls.config_store.load('model', {})
mclass = model_kwargs.get('class', False)
if mclass:
model = search_modelclass(mclass)(**model_kwargs.get('params', {}))
cls.set_opt('params', model.get_params())
Expand Down Expand Up @@ -318,8 +324,9 @@ def get(self, *path_args, **path_kwargs):
}
try:
model = cache.open(self.model_path, joblib.load)
estimator = model[-1] if hasattr(model, '__getitem__') else model
attrs = {
k: v for k, v in vars(model[-1]).items() if re.search(r'[^_]+_$', k)
k: v for k, v in vars(estimator).items() if re.search(r'[^_]+_$', k)
}
except FileNotFoundError:
attrs = {}
Expand Down Expand Up @@ -367,17 +374,19 @@ def _train(self, data=None):
data = self._filtercols(data)
data = self._filterrows(data)
self.model = self._assemble_pipeline(data, force=True)
if not isinstance(self.model[-1], TransformerMixin):
estimator = self.model[-1] if hasattr(self.model, '__getitem__') else self.model
if not isinstance(estimator, (TransformerMixin, TopCause)):
target = data[target_col]
train = data[[c for c in data if c != target_col]]
_fit(self.model, train, target, self.model_path)
result = {'score': self.model.score(train, target)}
else:
_fit(self.model, data, path=self.model_path)
target = data[target_col] if isinstance(estimator, TopCause) else None
_fit(self.model, data, target, self.model_path)
# Note: Fitted sklearn estimators store their parameters
# in attributes whose names end in an underscore. E.g. in the case of PCA,
# attributes are named `explained_variance_`. The `_train` action returns them.
result = {k: v for k, v in vars(self.model[-1]).items() if re.search(r'[^_]+_$', k)}
result = {k: v for k, v in vars(estimator).items() if re.search(r'[^_]+_$', k)}
return result

def _retrain(self):
Expand Down
2 changes: 1 addition & 1 deletion gramex/topcause.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,6 @@ def fit(self, X, y, sample_weight=None): # noqa - capital X is a sklearn conv

results = pd.DataFrame(results).T
results.loc[results['p'] > self.max_p, ('value', 'gain')] = np.nan
self.result_ = results.sort_values('gain', ascending=False)
self.result_ = results.sort_values('gain', ascending=False).reset_index()

return self