Skip to content

Commit f5a244f

Browse files
committed
update lineax nb
1 parent fffc3c2 commit f5a244f

File tree

1 file changed

+146
-45
lines changed

1 file changed

+146
-45
lines changed

jax_linear.ipynb

Lines changed: 146 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,17 @@
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 2,
13+
"execution_count": 1,
1414
"id": "ee58a952-6a96-494f-8067-0ea6036c2572",
15-
"metadata": {},
15+
"metadata": {
16+
"execution": {
17+
"iopub.execute_input": "2024-08-07T20:10:24.908831Z",
18+
"iopub.status.busy": "2024-08-07T20:10:24.908521Z",
19+
"iopub.status.idle": "2024-08-07T20:10:25.960003Z",
20+
"shell.execute_reply": "2024-08-07T20:10:25.959395Z",
21+
"shell.execute_reply.started": "2024-08-07T20:10:24.908808Z"
22+
}
23+
},
1624
"outputs": [],
1725
"source": [
1826
"import numpy as np\n",
@@ -24,11 +32,26 @@
2432
"from statsmodels.api import OLS\n"
2533
]
2634
},
35+
{
36+
"cell_type": "markdown",
37+
"id": "b600e803-0dc3-466f-a46c-8b1fa329d97d",
38+
"metadata": {},
39+
"source": [
40+
"$p>n$ DGP (more regressors than observations), OLS solution not unique"
41+
]
42+
},
2743
{
2844
"cell_type": "code",
29-
"execution_count": 3,
45+
"execution_count": 2,
3046
"id": "8b778f4c-7ea7-4823-97d1-98a3d118e9fd",
3147
"metadata": {
48+
"execution": {
49+
"iopub.execute_input": "2024-08-07T20:10:27.002244Z",
50+
"iopub.status.busy": "2024-08-07T20:10:27.001851Z",
51+
"iopub.status.idle": "2024-08-07T20:10:30.681932Z",
52+
"shell.execute_reply": "2024-08-07T20:10:30.681240Z",
53+
"shell.execute_reply.started": "2024-08-07T20:10:27.002225Z"
54+
},
3255
"tags": []
3356
},
3457
"outputs": [],
@@ -52,18 +75,34 @@
5275
"y, X = sparse_dgp()\n"
5376
]
5477
},
78+
{
79+
"cell_type": "markdown",
80+
"id": "bc4600f1-21c9-4d56-a838-238a449e6622",
81+
"metadata": {},
82+
"source": [
83+
"### statsmodels"
84+
]
85+
},
5586
{
5687
"cell_type": "code",
57-
"execution_count": 4,
88+
"execution_count": 3,
5889
"id": "f55c4df2",
59-
"metadata": {},
90+
"metadata": {
91+
"execution": {
92+
"iopub.execute_input": "2024-08-07T20:10:50.994546Z",
93+
"iopub.status.busy": "2024-08-07T20:10:50.994120Z",
94+
"iopub.status.idle": "2024-08-07T20:12:57.233971Z",
95+
"shell.execute_reply": "2024-08-07T20:12:57.233090Z",
96+
"shell.execute_reply.started": "2024-08-07T20:10:50.994526Z"
97+
}
98+
},
6099
"outputs": [
61100
{
62101
"name": "stdout",
63102
"output_type": "stream",
64103
"text": [
65-
"CPU times: user 2h 23min 9s, sys: 58min 50s, total: 3h 22min\n",
66-
"Wall time: 14min 32s\n"
104+
"CPU times: user 1h 24min 14s, sys: 34min 40s, total: 1h 58min 55s\n",
105+
"Wall time: 2min 6s\n"
67106
]
68107
}
69108
],
@@ -74,50 +113,75 @@
74113
},
75114
{
76115
"cell_type": "code",
77-
"execution_count": 6,
116+
"execution_count": 4,
78117
"id": "ab2705b3",
79-
"metadata": {},
118+
"metadata": {
119+
"execution": {
120+
"iopub.execute_input": "2024-08-07T20:12:57.235703Z",
121+
"iopub.status.busy": "2024-08-07T20:12:57.235326Z",
122+
"iopub.status.idle": "2024-08-07T20:12:57.240783Z",
123+
"shell.execute_reply": "2024-08-07T20:12:57.240344Z",
124+
"shell.execute_reply.started": "2024-08-07T20:12:57.235682Z"
125+
}
126+
},
80127
"outputs": [
81128
{
82129
"data": {
83130
"text/plain": [
84-
"31.803339628159765"
131+
"32.06474491647644"
85132
]
86133
},
87-
"execution_count": 6,
134+
"execution_count": 4,
88135
"metadata": {},
89136
"output_type": "execute_result"
90137
}
91138
],
92139
"source": [
93-
"\n",
94140
"np.linalg.norm(smols.params)"
95141
]
96142
},
97143
{
98144
"cell_type": "markdown",
145+
"id": "4e49c36b-af05-4e09-896f-6895f6207d66",
99146
"metadata": {},
100147
"source": [
101-
"Very fast least squares solver (including for minimum norm interpolation problems). \n"
148+
"Statsmodels is very slow with such problems."
149+
]
150+
},
151+
{
152+
"cell_type": "markdown",
153+
"id": "7d19a866-7360-4fa2-8eb8-3f2f6a538e58",
154+
"metadata": {},
155+
"source": [
156+
"### scikit-learn"
102157
]
103158
},
104159
{
105160
"cell_type": "code",
106161
"execution_count": 7,
107-
"metadata": {},
162+
"id": "82c630be",
163+
"metadata": {
164+
"execution": {
165+
"iopub.execute_input": "2024-08-07T20:13:16.255033Z",
166+
"iopub.status.busy": "2024-08-07T20:13:16.254911Z",
167+
"iopub.status.idle": "2024-08-07T20:14:07.937802Z",
168+
"shell.execute_reply": "2024-08-07T20:14:07.937238Z",
169+
"shell.execute_reply.started": "2024-08-07T20:13:16.255020Z"
170+
}
171+
},
108172
"outputs": [
109173
{
110174
"name": "stdout",
111175
"output_type": "stream",
112176
"text": [
113-
"CPU times: user 1h 24min 3s, sys: 1.51 s, total: 1h 24min 5s\n",
114-
"Wall time: 6min 14s\n"
177+
"CPU times: user 35min 26s, sys: 13min 45s, total: 49min 11s\n",
178+
"Wall time: 51.5 s\n"
115179
]
116180
},
117181
{
118182
"data": {
119183
"text/plain": [
120-
"Array(0.0001564, dtype=float32)"
184+
"1.794120407794253e-12"
121185
]
122186
},
123187
"execution_count": 7,
@@ -127,26 +191,29 @@
127191
],
128192
"source": [
129193
"%%time\n",
130-
"sol = lx.linear_solve( # solve # Ax = b\n",
131-
" operator = lx.MatrixLinearOperator(jnp.array(X)), # A\n",
132-
" vector = jnp.array(y), # b\n",
133-
" solver=lx.AutoLinearSolver(well_posed=None), # auto solver with no well-posedness check\n",
134-
" )\n",
135-
"\n",
136-
"betahat = sol.value\n",
137-
"# does it interpolate\n",
138-
"(y - X @ betahat).max()\n"
194+
"m = LinearRegression()\n",
195+
"m.fit(X, y)\n",
196+
"(y - m.predict(X)).max()\n"
139197
]
140198
},
141199
{
142200
"cell_type": "code",
143201
"execution_count": 8,
144-
"metadata": {},
202+
"id": "47998a93",
203+
"metadata": {
204+
"execution": {
205+
"iopub.execute_input": "2024-08-07T20:14:07.939157Z",
206+
"iopub.status.busy": "2024-08-07T20:14:07.938685Z",
207+
"iopub.status.idle": "2024-08-07T20:14:07.942731Z",
208+
"shell.execute_reply": "2024-08-07T20:14:07.942369Z",
209+
"shell.execute_reply.started": "2024-08-07T20:14:07.939133Z"
210+
}
211+
},
145212
"outputs": [
146213
{
147214
"data": {
148215
"text/plain": [
149-
"31.80334"
216+
"32.063915612235505"
150217
]
151218
},
152219
"execution_count": 8,
@@ -155,66 +222,100 @@
155222
}
156223
],
157224
"source": [
158-
"np.linalg.norm(betahat)\n"
225+
"np.linalg.norm(m.coef_)\n"
159226
]
160227
},
161228
{
162-
"cell_type": "code",
163-
"execution_count": 9,
229+
"cell_type": "markdown",
230+
"id": "2d8a87e2-ea14-4cb8-b9fa-7d261c741251",
164231
"metadata": {},
232+
"source": [
233+
"### lineax\n",
234+
"\n",
235+
"Very fast least squares solver (including for minimum norm interpolation problems). \n"
236+
]
237+
},
238+
{
239+
"cell_type": "code",
240+
"execution_count": 5,
241+
"id": "3207d070-779f-4107-9763-d0cda1a311e2",
242+
"metadata": {
243+
"execution": {
244+
"iopub.execute_input": "2024-08-07T20:12:57.241449Z",
245+
"iopub.status.busy": "2024-08-07T20:12:57.241317Z",
246+
"iopub.status.idle": "2024-08-07T20:13:16.249450Z",
247+
"shell.execute_reply": "2024-08-07T20:13:16.248802Z",
248+
"shell.execute_reply.started": "2024-08-07T20:12:57.241436Z"
249+
}
250+
},
165251
"outputs": [
166252
{
167253
"name": "stdout",
168254
"output_type": "stream",
169255
"text": [
170-
"CPU times: user 3h 1min 55s, sys: 0 ns, total: 3h 1min 55s\n",
171-
"Wall time: 13min 44s\n"
256+
"CPU times: user 10min 31s, sys: 3min 35s, total: 14min 6s\n",
257+
"Wall time: 18.9 s\n"
172258
]
173259
},
174260
{
175261
"data": {
176262
"text/plain": [
177-
"1.538325022920617e-12"
263+
"Array(0.00014114, dtype=float32)"
178264
]
179265
},
180-
"execution_count": 9,
266+
"execution_count": 5,
181267
"metadata": {},
182268
"output_type": "execute_result"
183269
}
184270
],
185271
"source": [
186272
"%%time\n",
187-
"m = LinearRegression()\n",
188-
"m.fit(X, y)\n",
189-
"(y - m.predict(X)).max()\n"
273+
"sol = lx.linear_solve( # solve # Ax = b\n",
274+
" operator = lx.MatrixLinearOperator(jnp.array(X)), # A\n",
275+
" vector = jnp.array(y), # b\n",
276+
" solver=lx.AutoLinearSolver(well_posed=None), \n",
277+
" )\n",
278+
"\n",
279+
"betahat = sol.value\n",
280+
"# does it interpolate\n",
281+
"(y - X @ betahat).max()"
190282
]
191283
},
192284
{
193285
"cell_type": "code",
194-
"execution_count": 10,
195-
"metadata": {},
286+
"execution_count": 6,
287+
"id": "24e09278-3a2c-4bec-b0db-9547278d51cc",
288+
"metadata": {
289+
"execution": {
290+
"iopub.execute_input": "2024-08-07T20:13:16.250977Z",
291+
"iopub.status.busy": "2024-08-07T20:13:16.250825Z",
292+
"iopub.status.idle": "2024-08-07T20:13:16.254344Z",
293+
"shell.execute_reply": "2024-08-07T20:13:16.253966Z",
294+
"shell.execute_reply.started": "2024-08-07T20:13:16.250962Z"
295+
}
296+
},
196297
"outputs": [
197298
{
198299
"data": {
199300
"text/plain": [
200-
"31.8032580188364"
301+
"32.064747"
201302
]
202303
},
203-
"execution_count": 10,
304+
"execution_count": 6,
204305
"metadata": {},
205306
"output_type": "execute_result"
206307
}
207308
],
208309
"source": [
209-
"np.linalg.norm(m.coef_)\n"
310+
"np.linalg.norm(betahat)\n"
210311
]
211312
}
212313
],
213314
"metadata": {
214315
"kernelspec": {
215-
"display_name": "metrics",
316+
"display_name": "Python 3.10 (recommended)",
216317
"language": "python",
217-
"name": "python3"
318+
"name": "python310"
218319
},
219320
"language_info": {
220321
"codemirror_mode": {

0 commit comments

Comments
 (0)