Skip to content

Commit a5da8f5

Browse files
committed
add: AIDS dataset quantile regression
1 parent e157a1b commit a5da8f5

File tree

3 files changed

+1633
-0
lines changed

3 files changed

+1633
-0
lines changed

03-nonlinear/aids.py

+49
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,49 @@
1+
# %%
2+
import numpy as np
3+
import pandas as pd
4+
import seaborn as sns
5+
from sksurv.datasets import load_aids
6+
from sklearn.linear_model import QuantileRegressor
7+
import scienceplots
8+
import matplotlib.pyplot as plt
9+
10+
# %%
# Load the AIDS data set and cast every feature column to float so the
# linear model is given a purely numeric design matrix.
X, y_frame = load_aids()
X = X.astype(float)

# %%
# The regression target is the second column of the outcome table.
# NOTE(review): presumably the first column is the event/censoring indicator
# and the second the observed time; the first column is discarded here, so
# censoring is not accounted for -- confirm this is intended.
y_data = pd.DataFrame(y_frame)
y = y_data.iloc[:, 1].to_numpy(dtype=float)

# %%
# Fit one L1-penalised quantile regression per quantile level and collect
# the coefficient vector of each fit, keyed by its quantile.
#
# The HiGHS solver is requested explicitly: scikit-learn releases before 1.4
# default to "interior-point", which SciPy >= 1.11 no longer provides, so
# relying on the default makes the fit depend on the installed versions.
quantiles = np.linspace(0.1, 0.9, 20)
coeffs = {
    q: QuantileRegressor(quantile=q, alpha=0.01, solver="highs").fit(X, y).coef_
    for q in quantiles
}

# %%
# Rows are features, columns are quantile levels.
df_coeffs = pd.DataFrame(coeffs, index=X.columns)

# Rank features by the sum of their absolute coefficients over all quantiles
# and keep the four with the largest totals.
features_important = (
    df_coeffs.abs().sum(axis=1).sort_values(ascending=False).head(4)
)

# Restrict to the selected features and reshape to long format
# (one row per feature/quantile pair), which is what seaborn expects.
df_coeffs = df_coeffs.loc[features_important.index].reset_index(names="feature")
df_coeffs = df_coeffs.melt("feature", var_name="quantile", value_name="coefficient")

# %%
# Reset to matplotlib's defaults, then layer the "no-latex" style on top.
# NOTE(review): "no-latex" is presumably registered by the `scienceplots`
# import at the top of the file -- keep that import even though the name is
# otherwise unused.
plt.style.use(["default", "no-latex"])

# One line panel per selected feature, two panels per row.
plot = sns.relplot(
    data=df_coeffs,
    x="quantile",
    y="coefficient",
    col="feature",
    col_wrap=2,
    kind="line",
    height=2,
    aspect=4 / 3,
)

# %%
plot.savefig("aids.svg")

0 commit comments

Comments
 (0)