Skip to content
Open
18 changes: 14 additions & 4 deletions gramex/handlers/mlhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,14 @@ def setup(cls, data=None, model={}, config_dir='', template=DEFAULT_TEMPLATE, **
if op.exists(cls.store.model_path): # If the pkl exists, load it
if op.isdir(cls.store.model_path):
mclass, wrapper = ml.search_modelclass(mclass)
cls.model = locate(wrapper).from_disk(mclass, cls.store.model_path)
cls.model = locate(wrapper).from_disk(cls.store.model_path, mclass)
else:
cls.model = get_model(cls.store.model_path, {})
try:
cls.model = get_model(cls.store.model_path, {})
except Exception as err:
app_log.warning(err)
mclass, wrapper = ml.search_modelclass(mclass)
cls.model = locate(wrapper).from_disk(cls.store.model_path, mclass)
elif data is not None:
data = cls._filtercols(data)
data = cls._filterrows(data)
Expand Down Expand Up @@ -172,7 +177,12 @@ def _filterrows(cls, data, **kwargs):
action = kwargs.get(method, cls.store.load(method, True))
if action:
subset = action if isinstance(action, list) else None
data = getattr(data, method)(subset=subset)
try:
data = getattr(data, method)(subset=subset)
except TypeError as exc:
# The label column for an NER dataset is a nested list.
# Can't do drop_duplicates on that.
app_log.warning(exc)
return data

def _transform(self, data, **kwargs):
Expand Down Expand Up @@ -200,7 +210,7 @@ def _predict(self, data=None, score_col=''):
# Set data in the same order as the transformer requests
try:
tcol = self.store.load('target_col', '_prediction')
data = self.model.predict(data, target_col=tcol)
data = self.model.predict(data, target_col=tcol, **self.args)
except Exception as exc:
app_log.exception(exc)
return data
Expand Down
9 changes: 7 additions & 2 deletions gramex/ml_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@
"sklearn.decomposition",
"gramex.ml",
],
"gramex.sm_api.StatsModel": [
"gramex.timeseries.StatsModel": [
"statsmodels.tsa.api",
"statsmodels.tsa.statespace.sarimax",
],
"gramex.timeseries.Prophet": ["prophet"],
"gramex.ml_api.HFTransformer": ["gramex.transformers"],
}

Expand Down Expand Up @@ -368,7 +369,11 @@ def predict(
"""
p = self._predict(X, **kwargs)
if target_col:
X[target_col] = p
try:
X[target_col] = p
except ValueError:
# This happens for NER: predictions of a single sample can be multiple entities.
X[target_col] = [p]
return X
return p

Expand Down
56 changes: 53 additions & 3 deletions gramex/sm_api.py → gramex/timeseries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import numpy as np
import joblib
from typing import Union
from gramex.config import app_log
from gramex import cache
from statsmodels import api as sm
Expand All @@ -10,7 +11,6 @@


class StatsModel(AbstractModel):

@classmethod
def from_disk(cls, path, **kwargs):
model = cache.open(path, joblib.load)
Expand Down Expand Up @@ -51,8 +51,14 @@ def _get_stl(self, endog):
return pd.Series(result, index=endog.index)

def fit(
self, X, y=None, model_path=None, name=None, index_col=None, target_col=None,
**kwargs
self,
X,
y=None,
model_path=None,
name=None,
index_col=None,
target_col=None,
**kwargs,
):
"""Only a dataframe is accepted. Index and target columns are both expected to be in it."""
params = self.params.copy()
Expand Down Expand Up @@ -106,3 +112,47 @@ def get_attributes(self):
if not result:
return {}
return result.summary().as_html()


class Prophet(StatsModel):
    """Forecasting model wrapper whose fitted state is persisted as JSON.

    Every method treats ``self.mclass`` as the underlying Prophet object.
    NOTE(review): ``self.mclass`` is presumably set by the ``StatsModel``
    constructor, which is not visible in this part of the file - confirm.
    """

    def fit(
        self,
        X: Union[pd.DataFrame, np.ndarray],
        y: Union[pd.Series, np.ndarray],
        model_path: str = "",
        name: str = "",
        **kwargs,
    ):
        """Fit the model on ``X`` and ``y`` and return the training score.

        Parameters
        ----------
        X
            Training data. It must support column assignment and contain a
            ``ds`` column, so in practice a DataFrame is required.
        y
            Target values; stored as the ``y`` column of the training frame.
        model_path
            Where to write the fitted model as JSON. When empty (the default),
            the model is fitted but not written to disk.
        name
            Unused here; accepted for signature compatibility.
        **kwargs
            Ignored.

        Returns
        -------
        The value of :meth:`score` computed on the training data.
        """
        # Work on a copy so the caller's data does not silently gain a "y" column.
        train = X.copy()
        train["y"] = y
        self.model = self.mclass.fit(train)
        # The default model_path is "", which open() cannot handle, so only
        # persist the model when a path was actually supplied.
        if model_path:
            # Imported lazily: prophet is only needed when this class is used.
            from prophet.serialize import model_to_json

            with open(model_path, "w", encoding="utf-8") as fout:
                fout.write(model_to_json(self.model))
        return self.score(train[["ds"]], y)

    @classmethod
    def from_disk(cls, path, *args, **kwargs):
        """Load a model previously saved as JSON at ``path`` and wrap it.

        Extra positional and keyword arguments are accepted and ignored.
        """
        # Imported lazily: prophet is only needed when this class is used.
        from prophet.serialize import model_from_json

        with open(path, "r", encoding="utf-8") as fin:
            model = model_from_json(fin.read())
        return cls(model, params={})

    def score(self, X, y_true, **kwargs):
        """Return the mean absolute error of the ``yhat`` predictions for ``X``."""
        return mean_absolute_error(y_true, self.mclass.predict(X)["yhat"])

    def predict(
        self,
        X: Union[pd.DataFrame, np.ndarray] = None,
        n_periods=None,
        include_history=False,
        **kwargs,
    ):
        """Predict for ``X``, or for ``n_periods`` future periods when given.

        Parameters
        ----------
        X
            Data to predict on. Ignored when ``n_periods`` is given.
        n_periods
            Number of future periods to forecast. It is passed through
            ``int()``, so numeric strings are accepted.
        include_history
            Forwarded to ``make_future_dataframe`` when ``n_periods`` is given.
        **kwargs
            Ignored.

        Returns
        -------
        Whatever the underlying model's ``predict`` returns.
        """
        if n_periods is not None:
            future = self.mclass.make_future_dataframe(
                periods=int(n_periods), include_history=include_history
            )
            return self.mclass.predict(future)
        return self.mclass.predict(X)
Loading