Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Conformal Coherent Quantile Regression #3

Merged
merged 7 commits into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,3 @@ updates:
prefix: "ci"
prefix-development: "ci"
include: "scope"
- package-ecosystem: pip
directory: /
schedule:
interval: monthly
commit-message:
prefix: "build"
prefix-development: "build"
include: "scope"
versioning-strategy: lockfile-only
allow:
- dependency-type: "all"
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.10"

Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.10", "3.11"]

name: Python ${{ matrix.python-version }}

Expand All @@ -36,12 +36,12 @@ jobs:
PYTHON_VERSION=${{ matrix.python-version }} devcontainer up --workspace-folder .

- name: Lint package
run: devcontainer exec --workspace-folder . poe lint
run: devcontainer exec --remote-env CI=true --workspace-folder . poe lint

- name: Test package
run: devcontainer exec --workspace-folder . poe test
run: devcontainer exec --remote-env CI=true --workspace-folder . poe test

- name: Upload coverage
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
with:
files: reports/coverage.xml
117 changes: 115 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,124 @@

# 👖 Conformal Tights

A scikit-learn [meta-estimator](https://scikit-learn.org/stable/glossary.html#term-meta-estimator) for computing tight [conformal predictions](https://en.wikipedia.org/wiki/Conformal_prediction).
A [scikit-learn meta-estimator](https://scikit-learn.org/stable/glossary.html#term-meta-estimator) that adds [conformal prediction](https://en.wikipedia.org/wiki/Conformal_prediction) of coherent [quantiles](https://en.wikipedia.org/wiki/Quantile) and [intervals](https://en.wikipedia.org/wiki/Prediction_interval) to any [scikit-learn regressor](https://scikit-learn.org/stable/glossary.html#term-regressor). Features:

1. 🍬 *Meta-estimator*: add prediction of quantiles and intervals to any scikit-learn regressor
2. 🌡️ *Conformally calibrated:* accurate quantiles, and intervals with reliable [coverage](https://en.wikipedia.org/wiki/Coverage_probability)
3. 🚦 *Coherent quantiles:* quantiles increase monotonically instead of [crossing](https://github.com/dmlc/xgboost/issues/9848) [each other](https://github.com/microsoft/LightGBM/issues/3447)
4. 👖 *Tight quantiles:* selects the lowest [dispersion](https://en.wikipedia.org/wiki/Statistical_dispersion) that provides the desired coverage
5. 🎁 *Data efficient:* requires only a small number of calibration examples to fit
6. 🐼 *Pandas support:* optionally predict on DataFrames and receive DataFrame output

## Using

To add and install this package as a dependency of your project, run `poetry add conformal-tights`.
### Installing

First, install this package with:

```sh
pip install conformal-tights
```

### Predicting quantiles

Conformal Tights exposes a meta-estimator called `ConformalCoherentQuantileRegressor` that you can use to wrap any scikit-learn regressor, after which you can use `predict_quantiles` predict conformally calibrated quantiles. Example usage:

```python
from conformal_tights import ConformalCoherentQuantileRegressor
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

# Fetch dataset and split in train and test
X, y = fetch_openml("ames_housing", version=1, return_X_y=True, as_frame=True, parser="auto")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, random_state=42)

# Create a regressor, wrap it, and fit on the train set
my_regressor = XGBRegressor(objective="reg:absoluteerror")
conformal_predictor = ConformalCoherentQuantileRegressor(estimator=my_regressor)
conformal_predictor.fit(X_train, y_train)

# Predict with the wrapped regressor
ŷ_test = conformal_predictor.predict(X_test)

# Predict quantiles with the conformal wrapper
ŷ_test_quantiles = conformal_predictor.predict_quantiles(X_test, quantiles=(0.025, 0.05, 0.1, 0.9, 0.95, 0.975))
```

When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_quantiles` yields:

| house_id | 0.025 | 0.05 | 0.1 | 0.9 | 0.95 | 0.975 |
|-----------:|--------:|-------:|-------:|-------:|-------:|--------:|
| 1357 | 121557 | 130272 | 139913 | 189399 | 211177 | 237309 |
| 2367 | 86005 | 92617 | 98591 | 130236 | 145686 | 164766 |
| 2822 | 116523 | 121711 | 134993 | 175583 | 194964 | 216891 |
| 2126 | 105712 | 113784 | 122145 | 164330 | 183352 | 206224 |
| 1544 | 85920 | 92311 | 99130 | 133228 | 148895 | 167969 |

Let's visualize the predicted quantiles on the test set:

<img src="https://github.com/radix-ai/conformal-tights/assets/4543654/b02b3797-de6a-4e0d-b457-ed8e50e3f42c" width="512">

<details>
<summary>Expand to see the code that generated the graph above</summary>

```python
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
%config InlineBackend.figure_format = "retina"
plt.rcParams["font.size"] = 8
idx = (-ŷ_test.sample(50, random_state=42)).sort_values().index
y_ticks = list(range(1, len(idx) + 1))
plt.figure(figsize=(4, 5))
for j in range(3):
end = ŷ_test_quantiles.shape[1] - 1 - j
coverage = round(100 * (ŷ_test_quantiles.columns[end] - ŷ_test_quantiles.columns[j]))
plt.barh(
y_ticks,
ŷ_test_quantiles.loc[idx].iloc[:, end] - ŷ_test_quantiles.loc[idx].iloc[:, j],
left=ŷ_test_quantiles.loc[idx].iloc[:, j],
label=f"{coverage}% Prediction interval",
color=["#b3d9ff", "#86bfff", "#4da6ff"][j],
)
plt.plot(y_test.loc[idx], y_ticks, "s", markersize=3, markerfacecolor="none", markeredgecolor="#e74c3c", label="Actual value")
plt.plot(ŷ_test.loc[idx], y_ticks, "s", color="blue", markersize=0.6, label="Predicted value")
plt.xlabel("House price")
plt.ylabel("Test house index")
plt.yticks(y_ticks, y_ticks)
plt.tick_params(axis="y", labelsize=6)
plt.grid(axis="x", color="lightsteelblue", linestyle=":", linewidth=0.5)
plt.gca().xaxis.set_major_formatter(ticker.StrMethodFormatter("${x:,.0f}"))
plt.gca().spines["top"].set_visible(False)
plt.gca().spines["right"].set_visible(False)
plt.legend()
plt.tight_layout()
plt.show()
```
</details>

### Predicting intervals

In addition to quantile prediction, you can use `predict_interval` to predict conformally calibrated prediction intervals. Compared to quantiles, these focus on reliable coverage over quantile accuracy. Example usage:

```python
# Predict an interval for each example with the conformal wrapper
ŷ_test_interval = conformal_predictor.predict_interval(X_test, coverage=0.95)

# Measure the coverage of the prediction intervals on the test set
coverage = ((ŷ_test_interval.iloc[:, 0] <= y_test) & (y_test <= ŷ_test_interval.iloc[:, 1])).mean()
print(coverage) # 96.6%
```

When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_interval` yields:

| house_id | 0.025 | 0.975 |
|-----------:|--------:|--------:|
| 1357 | 108489 | 238396 |
| 2367 | 76043 | 165189 |
| 2822 | 101319 | 220247 |
| 2126 | 94238 | 207501 |
| 1544 | 75976 | 168741 |

## Contributing

Expand Down
Loading
Loading