Skip to content

Commit

Permalink
feat: lazify classification metrics and clean tests
Browse files Browse the repository at this point in the history
  • Loading branch information
IndexSeek committed Dec 20, 2024
1 parent b292a49 commit ff559bd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 43 deletions.
31 changes: 20 additions & 11 deletions ibis_ml/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def accuracy_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
Examples
--------
>>> from ibis_ml.metrics import accuracy_score
>>> import ibis
>>> from ibis_ml.metrics import accuracy_score
>>> ibis.options.interactive = True
>>> t = ibis.memtable(
... {
Expand All @@ -29,9 +29,11 @@ def accuracy_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
... }
... )
>>> accuracy_score(t.actual, t.prediction)
0.5833333333333334
┌──────────┐
│ 0.583333 │
└──────────┘
"""
return (y_true == y_pred).mean().to_pyarrow().as_py()
return (y_true == y_pred).mean() # .to_pyarrow().as_py()


def precision_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
Expand All @@ -51,8 +53,8 @@ def precision_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
Examples
--------
>>> from ibis_ml.metrics import precision_score
>>> import ibis
>>> from ibis_ml.metrics import precision_score
>>> ibis.options.interactive = True
>>> t = ibis.memtable(
... {
Expand All @@ -62,11 +64,13 @@ def precision_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
... }
... )
>>> precision_score(t.actual, t.prediction)
0.6666666666666666
┌──────────┐
│ 0.666667 │
└──────────┘
"""
true_positive = (y_true & y_pred).sum()
predicted_positive = y_pred.sum()
return (true_positive / predicted_positive).to_pyarrow().as_py()
return true_positive / predicted_positive


def recall_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
Expand All @@ -83,10 +87,11 @@ def recall_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
-------
float
The recall score, representing the fraction of true positive predictions.
Examples
--------
>>> from ibis_ml.metrics import recall_score
>>> import ibis
>>> from ibis_ml.metrics import recall_score
>>> ibis.options.interactive = True
>>> t = ibis.memtable(
... {
Expand All @@ -96,11 +101,13 @@ def recall_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
... }
... )
>>> recall_score(t.actual, t.prediction)
0.5714285714285714
┌──────────┐
│ 0.571429 │
└──────────┘
"""
true_positive = (y_true & y_pred).sum()
actual_positive = y_true.sum()
return (true_positive / actual_positive).to_pyarrow().as_py()
return true_positive / actual_positive


def f1_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
Expand All @@ -120,8 +127,8 @@ def f1_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
Examples
--------
>>> from ibis_ml.metrics import f1_score
>>> import ibis
>>> from ibis_ml.metrics import f1_score
>>> ibis.options.interactive = True
>>> t = ibis.memtable(
... {
Expand All @@ -131,7 +138,9 @@ def f1_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
... }
... )
>>> f1_score(t.actual, t.prediction)
0.6153846153846154
┌──────────┐
│ 0.615385 │
└──────────┘
"""
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
Expand Down
48 changes: 16 additions & 32 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import ibis
import pytest
from sklearn.metrics import accuracy_score as sk_accuracy_score
from sklearn.metrics import f1_score as sk_f1_score
from sklearn.metrics import precision_score as sk_precision_score
from sklearn.metrics import recall_score as sk_recall_score
import sklearn.metrics

from ibis_ml.metrics import accuracy_score, f1_score, precision_score, recall_score
import ibis_ml.metrics


@pytest.fixture
Expand All @@ -19,33 +16,20 @@ def results_table():
)


def test_accuracy_score(results_table):
@pytest.mark.parametrize(
"metric_name",
[
pytest.param("accuracy_score", id="accuracy_score"),
pytest.param("precision_score", id="precision_score"),
pytest.param("recall_score", id="recall_score"),
pytest.param("f1_score", id="f1_score"),
],
)
def test_classification_metrics(results_table, metric_name):
ibis_ml_func = getattr(ibis_ml.metrics, metric_name)
sklearn_func = getattr(sklearn.metrics, metric_name)
t = results_table
df = t.to_pandas()
result = accuracy_score(t.actual, t.prediction)
expected = sk_accuracy_score(df["actual"], df["prediction"])
assert result == pytest.approx(expected, abs=1e-4)


def test_precision_score(results_table):
t = results_table
df = t.to_pandas()
result = precision_score(t.actual, t.prediction)
expected = sk_precision_score(df["actual"], df["prediction"])
assert result == pytest.approx(expected, abs=1e-4)


def test_recall_score(results_table):
t = results_table
df = t.to_pandas()
result = recall_score(t.actual, t.prediction)
expected = sk_recall_score(df["actual"], df["prediction"])
assert result == pytest.approx(expected, abs=1e-4)


def test_f1_score(results_table):
t = results_table
df = t.to_pandas()
result = f1_score(t.actual, t.prediction)
expected = sk_f1_score(df["actual"], df["prediction"])
result = ibis_ml_func(t.actual, t.prediction).to_pyarrow().as_py()
expected = sklearn_func(df["actual"], df["prediction"])
assert result == pytest.approx(expected, abs=1e-4)

0 comments on commit ff559bd

Please sign in to comment.