Skip to content

Commit dfb9781

Browse files
committed
Add regression metric reporter and update requirements
1 parent 9b7a0d7 commit dfb9781

File tree

8 files changed

+57
-256
lines changed

8 files changed

+57
-256
lines changed

datasist/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,3 @@
44
from . import timeseries
55
from . import visualizations
66
from . import model
7-
from . import nlp

datasist/model.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
'''
55
import platform
66

7-
from sklearn.metrics import roc_curve, confusion_matrix, precision_score, accuracy_score, recall_score, f1_score, make_scorer
7+
from sklearn.metrics import roc_curve, confusion_matrix, precision_score, accuracy_score, recall_score, f1_score, make_scorer, mean_absolute_error, mean_squared_error, r2_score, mean_squared_log_error
88
from sklearn.model_selection import KFold, cross_val_score
99
import numpy as np
1010
import pandas as pd
@@ -232,6 +232,55 @@ def get_classification_report(y_train=None, prediction=None, show_roc_plot=True,
232232
plt.savefig("roc_plot.png")
233233

234234

235+
def get_regression_report(y_true=None, prediction=None, show_r2_plot=True, save_plot=False):
236+
'''
237+
Generates performance report for a regression problem.
238+
239+
Parameters:
240+
------------------
241+
y_true: Array, series, list.
242+
243+
The truth/ground value from the train data set.
244+
245+
prediction: Array, series, list.
246+
247+
The predicted value by a trained model.
248+
249+
show_r2_plot: Bool, default True.
250+
251+
Show the r-squared curve.
252+
253+
save_plot: Bool, default True.
254+
255+
Save the plot to the current working directory.
256+
257+
'''
258+
mae = mean_absolute_error(y_true, prediction)
259+
mse = mean_squared_error(y_true, prediction)
260+
msle = precision_score(y_true, prediction)
261+
r2 = r2_score(y_true, prediction)
262+
263+
print("Mean Absolute Error: ", round(mae, 5))
264+
print("Mean Squared Error: ", round(mse, 5))
265+
print("Mean Squared Log Error: ", round(msle, 5))
266+
print("R-squared Error: ", round(r2, 5))
267+
print("*" * 100)
268+
269+
if show_r2_plot:
270+
plt.scatter(y_true,prediction)
271+
plt.xlabel('Truth values')
272+
plt.ylabel('Predicted values')
273+
plt.plot(np.unique(y_true), np.poly1d(np.polyfit(y_true, y_true, 1))(np.unique(y_true)))
274+
plt.text(0.7, 0.2, 'R-squared = %0.2f' % r2)
275+
plt.show()
276+
277+
if save_plot:
278+
plt.savefig("r2_plot.png")
279+
280+
281+
282+
283+
235284
def compare_model(models_list=None, x_train=None, y_train=None, scoring_metric=None, scoring_cv=3, silenced=True, plot=True):
236285
"""
237286
Train multiple user-defined model and display report based on defined metric. Enables user to pick the best base model for a problem.

datasist/tests/test_model.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from datasist import model
88

9+
910
def test_compare_model_classification():
1011
x_train, y_train = make_classification(
1112
n_samples=50,
@@ -24,6 +25,7 @@ def test_compare_model_classification():
2425
assert type(model_scores) is list
2526
assert hasattr(fitted_model[0], "predict")
2627

28+
2729
def test_compare_model_regression():
2830
x_train, y_train = make_classification(
2931
n_samples=50,

datasist/tests/test_nlp.py

-11
This file was deleted.

docs/index.html

+2-6
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ <h1 class="title">Module <code>datasist</code></h1>
2929
from . import timeseries
3030
from . import visualizations
3131
from . import model
32-
from . import nlp</code></pre>
32+
</code></pre>
3333
</details>
3434
</section>
3535
<section>
@@ -43,10 +43,7 @@ <h2 class="section-title" id="header-submodules">Sub-modules</h2>
4343
<dd>
4444
<section class="desc"><p>This module contains all functions relating to modeling in using sklearn library.</p></section>
4545
</dd>
46-
<dt><code class="name"><a title="datasist.nlp" href="nlp.html">datasist.nlp</a></code></dt>
47-
<dd>
48-
<section class="desc"><p>This module contains all functions relating to nlp</p></section>
49-
</dd>
46+
<dd></dd>
5047
<dt><code class="name"><a title="datasist.structdata" href="structdata.html">datasist.structdata</a></code></dt>
5148
<dd>
5249
<section class="desc"><p>This module contains all functions relating to the cleaning and exploration of structured data sets; mostly in pandas format</p></section>
@@ -80,7 +77,6 @@ <h1><img src="datasist.png" alt="logo"></h1>
8077
<ul>
8178
<li><code><a title="datasist.feature_engineering" href="feature_engineering.html">datasist.feature_engineering</a></code></li>
8279
<li><code><a title="datasist.model" href="model.html">datasist.model</a></code></li>
83-
<li><code><a title="datasist.nlp" href="nlp.html">datasist.nlp</a></code></li>
8480
<li><code><a title="datasist.structdata" href="structdata.html">datasist.structdata</a></code></li>
8581
<li><code><a title="datasist.timeseries" href="timeseries.html">datasist.timeseries</a></code></li>
8682
<li><code><a title="datasist.visualizations" href="visualizations.html">datasist.visualizations</a></code></li>

docs/nlp.html

-230
This file was deleted.

requirements.txt

+1-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,4 @@ pandas
44
matplotlib
55
seaborn
66
scikit-learn
7-
numpy
8-
spacy
9-
en
7+
numpy

0 commit comments

Comments
 (0)