Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 48 additions & 151 deletions gramex/handlers/mlhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,117 +5,64 @@

import gramex
from gramex import ml_api as ml
from gramex.transforms import build_transform
from gramex.config import app_log, CustomJSONEncoder, locate
from gramex import data as gdata
from gramex.handlers import FormHandler
from gramex.http import NOT_FOUND, BAD_REQUEST
from gramex.install import safe_rmtree
from gramex import cache

import numpy as np
import pandas as pd
import joblib
from sklearn.base import TransformerMixin
from sklearn.pipeline import Pipeline
from slugify import slugify
from tornado.gen import coroutine
from tornado.web import HTTPError
from sklearn.metrics import get_scorer

# TODO: Redesign the template for usecases
# MLHandler2 - API is more streamlined.

op = os.path

ACTIONS = ['predict', 'score', 'append', 'train', 'retrain']
DEFAULT_TEMPLATE = op.join(op.dirname(__file__), '..', 'apps', 'mlhandler', 'template.html')


def get_model(mclass: str, model_params: dict, **kwargs) -> ml.AbstractModel:
if not mclass:
def get_model(
data_config: dict = None,
model_config: dict = None,
store: str = None, **kwargs) -> ml.AbstractModel:
if data_config is None:
data_config = {}
if model_config is None:
model_config = {}
params = store.load('params', {}) # To repopulate after recreating the class
klass = model_config.pop('class', store.load('class'))
store.dump('class', klass)
store.dump('params', params)
try:
klass, wrapper = ml.search_modelclass(klass)
except ValueError:
app_log.warning('No model specification found.')
return
if mclass.endswith('.pkl'):
model = cache.open(mclass, joblib.load)
if isinstance(model, Pipeline):
_, wrapper = ml.search_modelclass(model[-1].__class__.__name__)
else:
_, wrapper = ml.search_modelclass(model.__class__.__name__)
else:
mclass, wrapper = ml.search_modelclass(mclass)
try:
model = mclass(**model_params)
except TypeError:
model = mclass
return locate(wrapper)(model, params=model_params, **kwargs)
model = locate(wrapper)(klass, store, data_config, **model_config)
return model


class MLHandler(FormHandler):

@classmethod
def setup(cls, data=None, model={}, config_dir='', template=DEFAULT_TEMPLATE, **kwargs):
def setup(cls, data={}, model={}, config_dir='', template=DEFAULT_TEMPLATE, **kwargs):
if not config_dir:
config_dir = op.join(gramex.config.variables['GRAMEXDATA'], 'apps', 'mlhandler',
slugify(cls.name))
cls.store = ml.ModelStore(config_dir)
cls.store = ml.ModelStore(config_dir, model)
cls.template = template
super(MLHandler, cls).setup(**kwargs)
index_col = None
try:
if 'transform' in data:
cls.store.dump('built_transform', data['transform'])
data['transform'] = build_transform(
{'function': data['transform']},
vars={'data': None, 'handler': None},
filename='MLHandler:data', iter=False)
cls._built_transform = staticmethod(data['transform'])
else:
cls._built_transform = staticmethod(lambda x: x)
index_col = data.get('index_col')
cls.store.dump('index_col', index_col)
data = gdata.filter(**data)
cls.store.store_data(data)
except TypeError:
app_log.warning('MLHandler could not find training data.')
data = None
cls._built_transform = staticmethod(lambda x: x)

# store the model kwargs from gramex.yaml into the store
for key in ml.TRANSFORMS:
cls.store.dump(key, model.get(key, cls.store.load(key)))
# Remove target_col if it appears anywhere in cats or nums
target_col = cls.store.load('target_col')
nums = list(set(cls.store.load('nums')) - {target_col})
cats = list(set(cls.store.load('cats')) - {target_col})
cls.store.dump('cats', cats)
cls.store.dump('nums', nums)
cls.data_config = data
cls.model_config = model
cls.model = get_model(data, model, cls.store, **kwargs)

mclass = model.get('class', cls.store.load('class', ''))
model_params = model.get('params', {})
cls.store.dump('class', mclass)
cls.store.dump('params', model_params)
if op.exists(cls.store.model_path): # If the pkl exists, load it
if op.isdir(cls.store.model_path):
mclass, wrapper = ml.search_modelclass(mclass)
cls.model = locate(wrapper).from_disk(mclass, cls.store.model_path)
else:
cls.model = get_model(cls.store.model_path, {})
elif data is not None:
data = cls._filtercols(data)
data = cls._filterrows(data)
cls.model = get_model(mclass, model_params, data=data, cats=cats,
nums=nums, target_col=target_col)
# train the model
if issubclass(cls.model.__class__, TransformerMixin):
target = None
train = data
else:
target = data[target_col]
train = data.drop([target_col], axis=1)
# Fit the model, if model and data exist
if cls.model:
gramex.service.threadpool.submit(
cls.model.fit, train, target,
model_path=cls.store.model_path, name=cls.name,
**cls.store.model_kwargs()
cls.model._init_fit, name=cls.name,
)

def _parse_multipart_form_data(self):
Expand Down Expand Up @@ -146,70 +93,29 @@ def _parse_data(self, _cache=True, append=False):
except ValueError:
app_log.warning('Could not read data from request, reading cached data.')
data = self.store.load_data()
data = self._built_transform(data)

if _cache:
self.store.store_data(data, append)
return data

@classmethod
def _filtercols(cls, data, **kwargs):
include = kwargs.get('include', cls.store.load('include', []))
if include:
include += [cls.store.load('target_col')]
data = data[include]
else:
exclude = kwargs.get('exclude', cls.store.load('exclude', []))
to_exclude = [c for c in exclude if c in data]
if to_exclude:
data = data.drop(to_exclude, axis=1)
return data

@classmethod
def _filterrows(cls, data, **kwargs):
for method in 'dropna drop_duplicates'.split():
action = kwargs.get(method, cls.store.load(method, True))
if action:
subset = action if isinstance(action, list) else None
data = getattr(data, method)(subset=subset)
return data

def _transform(self, data, **kwargs):
orgdata = self.store.load_data()
for col in np.intersect1d(data.columns, orgdata.columns):
data[col] = data[col].astype(orgdata[col].dtype)
data = self._filtercols(data, **kwargs)
data = self._filterrows(data, **kwargs)
return data

def _predict(self, data=None, score_col=''):
def _predict(self, data=None):
self._check_model_path()
metric = self.get_argument('_metric', False)
if metric:
scorer = get_scorer(metric)
if data is None:
data = self._parse_data(False)
data = self._transform(data, drop_duplicates=False)
try:
target = data.pop(score_col)
if metric:
return scorer(self.model, data, target)
return self.model.score(data, target)
except KeyError:
# Set data in the same order as the transformer requests
try:
tcol = self.store.load('target_col', '_prediction')
data = self.model.predict(data, target_col=tcol)
except Exception as exc:
app_log.exception(exc)
return data
tcol = self.store.load('target_col', '_prediction')
data = self.model.predict(data, target_col=tcol)
except Exception as exc:
app_log.exception(exc)
return data

def _check_model_path(self):
try:
klass, wrapper = ml.search_modelclass(self.store.load('class'))
self.model = locate(wrapper).from_disk(self.store.model_path, klass=klass)
self.model = locate(wrapper).from_disk(self.store, klass=klass)
except FileNotFoundError:
raise HTTPError(NOT_FOUND, f'No model found at {self.store.model_path}')
except ValueError:
raise HTTPError(NOT_FOUND, 'No model definition found.')

@coroutine
def prepare(self):
Expand All @@ -230,8 +136,10 @@ def get(self, *path_args, **path_kwargs):
'params': self.store.load('model')
}
try:
attrs = get_model(self.store.model_path, {}).get_attributes()
except (AttributeError, ImportError, FileNotFoundError):
self._check_model_path()
attrs = self.model.get_attributes()
except (AttributeError, ImportError, FileNotFoundError, HTTPError):
app_log.warning('No reasonable model found: either saved or defined in the spec.')
attrs = {}
params['attrs'] = attrs
self.write(json.dumps(params, indent=2, cls=CustomJSONEncoder))
Expand All @@ -258,7 +166,9 @@ def get(self, *path_args, **path_kwargs):
app_log.debug(err.msg)
data = []
if len(data) > 0:
data = data.drop([self.store.load('target_col')], axis=1, errors='ignore')
data = data.drop(
[self.store.load('target_col')], axis=1, errors='ignore'
)
prediction = yield gramex.service.threadpool.submit(
self._predict, data)
self.write(json.dumps(prediction, indent=2, cls=CustomJSONEncoder))
Expand All @@ -272,25 +182,11 @@ def _append(self):

def _train(self, data=None):
target_col = self.get_argument('target_col', self.store.load('target_col'))
index_col = self.get_argument('index_col', self.store.load('index_col'))
self.store.dump('target_col', target_col)
data = self._parse_data(False) if data is None else data
data = self._filtercols(data)
data = self._filterrows(data)
self.model = get_model(
self.store.load('class'), self.store.load('params'),
data=data, target_col=target_col,
nums=self.store.load('nums'), cats=self.store.load('cats')
)
if not isinstance(self.model, ml.SklearnTransformer):
target = data[target_col]
train = data[[c for c in data if c not in (target_col, index_col)]]
self.model.fit(train, target, self.store.model_path)
result = {'score': self.model.score(train, target)}
else:
self.model.fit(data, None, self.store.model_path)
result = self.model.get_attributes()
return result
self.model = get_model(store=self.store)
self.model.fit(data, self.store.model_path, self.name)
return {'score': self.model.score(data, target_col)}

def _retrain(self):
return self._train(self.store.load_data())
Expand All @@ -300,7 +196,8 @@ def _score(self):
data = self._parse_data(False)
target_col = self.get_argument('target_col', self.store.load('target_col'))
self.store.dump('target_col', target_col)
return {'score': self._predict(data, target_col)}
metric = self.get_argument('_metric', '')
return {'score': self.model.score(data, target_col, metric=metric)}

@coroutine
def post(self, *path_args, **path_kwargs):
Expand All @@ -323,7 +220,7 @@ def put(self, *path_args, **path_kwargs):
val = self.args.pop(opt)
self.store.dump(opt, val)
# The rest is params
params = self.store.load('params')
params = self.store.load('params', {})
for key, val in ml.coerce_model_params(mclass, self.args).items():
params[key] = val
self.store.dump('params', params)
Expand Down
Loading