|
10 | 10 | }, |
11 | 11 | { |
12 | 12 | "cell_type": "code", |
13 | | - "execution_count": 2, |
| 13 | + "execution_count": 1, |
14 | 14 | "id": "ee58a952-6a96-494f-8067-0ea6036c2572", |
15 | | - "metadata": {}, |
| 15 | + "metadata": { |
| 16 | + "execution": { |
| 17 | + "iopub.execute_input": "2024-08-07T20:10:24.908831Z", |
| 18 | + "iopub.status.busy": "2024-08-07T20:10:24.908521Z", |
| 19 | + "iopub.status.idle": "2024-08-07T20:10:25.960003Z", |
| 20 | + "shell.execute_reply": "2024-08-07T20:10:25.959395Z", |
| 21 | + "shell.execute_reply.started": "2024-08-07T20:10:24.908808Z" |
| 22 | + } |
| 23 | + }, |
16 | 24 | "outputs": [], |
17 | 25 | "source": [ |
18 | 26 | "import numpy as np\n", |
|
24 | 32 | "from statsmodels.api import OLS\n" |
25 | 33 | ] |
26 | 34 | }, |
| 35 | + { |
| 36 | + "cell_type": "markdown", |
| 37 | + "id": "b600e803-0dc3-466f-a46c-8b1fa329d97d", |
| 38 | + "metadata": {}, |
| 39 | + "source": [ |
| 40 | + "$p>n$ DGP (more regressors than observations), so the OLS solution is not unique" |
| 41 | + ] |
| 42 | + }, |
27 | 43 | { |
28 | 44 | "cell_type": "code", |
29 | | - "execution_count": 3, |
| 45 | + "execution_count": 2, |
30 | 46 | "id": "8b778f4c-7ea7-4823-97d1-98a3d118e9fd", |
31 | 47 | "metadata": { |
| 48 | + "execution": { |
| 49 | + "iopub.execute_input": "2024-08-07T20:10:27.002244Z", |
| 50 | + "iopub.status.busy": "2024-08-07T20:10:27.001851Z", |
| 51 | + "iopub.status.idle": "2024-08-07T20:10:30.681932Z", |
| 52 | + "shell.execute_reply": "2024-08-07T20:10:30.681240Z", |
| 53 | + "shell.execute_reply.started": "2024-08-07T20:10:27.002225Z" |
| 54 | + }, |
32 | 55 | "tags": [] |
33 | 56 | }, |
34 | 57 | "outputs": [], |
|
52 | 75 | "y, X = sparse_dgp()\n" |
53 | 76 | ] |
54 | 77 | }, |
| 78 | + { |
| 79 | + "cell_type": "markdown", |
| 80 | + "id": "bc4600f1-21c9-4d56-a838-238a449e6622", |
| 81 | + "metadata": {}, |
| 82 | + "source": [ |
| 83 | + "### statsmodels" |
| 84 | + ] |
| 85 | + }, |
55 | 86 | { |
56 | 87 | "cell_type": "code", |
57 | | - "execution_count": 4, |
| 88 | + "execution_count": 3, |
58 | 89 | "id": "f55c4df2", |
59 | | - "metadata": {}, |
| 90 | + "metadata": { |
| 91 | + "execution": { |
| 92 | + "iopub.execute_input": "2024-08-07T20:10:50.994546Z", |
| 93 | + "iopub.status.busy": "2024-08-07T20:10:50.994120Z", |
| 94 | + "iopub.status.idle": "2024-08-07T20:12:57.233971Z", |
| 95 | + "shell.execute_reply": "2024-08-07T20:12:57.233090Z", |
| 96 | + "shell.execute_reply.started": "2024-08-07T20:10:50.994526Z" |
| 97 | + } |
| 98 | + }, |
60 | 99 | "outputs": [ |
61 | 100 | { |
62 | 101 | "name": "stdout", |
63 | 102 | "output_type": "stream", |
64 | 103 | "text": [ |
65 | | - "CPU times: user 2h 23min 9s, sys: 58min 50s, total: 3h 22min\n", |
66 | | - "Wall time: 14min 32s\n" |
| 104 | + "CPU times: user 1h 24min 14s, sys: 34min 40s, total: 1h 58min 55s\n", |
| 105 | + "Wall time: 2min 6s\n" |
67 | 106 | ] |
68 | 107 | } |
69 | 108 | ], |
|
74 | 113 | }, |
75 | 114 | { |
76 | 115 | "cell_type": "code", |
77 | | - "execution_count": 6, |
| 116 | + "execution_count": 4, |
78 | 117 | "id": "ab2705b3", |
79 | | - "metadata": {}, |
| 118 | + "metadata": { |
| 119 | + "execution": { |
| 120 | + "iopub.execute_input": "2024-08-07T20:12:57.235703Z", |
| 121 | + "iopub.status.busy": "2024-08-07T20:12:57.235326Z", |
| 122 | + "iopub.status.idle": "2024-08-07T20:12:57.240783Z", |
| 123 | + "shell.execute_reply": "2024-08-07T20:12:57.240344Z", |
| 124 | + "shell.execute_reply.started": "2024-08-07T20:12:57.235682Z" |
| 125 | + } |
| 126 | + }, |
80 | 127 | "outputs": [ |
81 | 128 | { |
82 | 129 | "data": { |
83 | 130 | "text/plain": [ |
84 | | - "31.803339628159765" |
| 131 | + "32.06474491647644" |
85 | 132 | ] |
86 | 133 | }, |
87 | | - "execution_count": 6, |
| 134 | + "execution_count": 4, |
88 | 135 | "metadata": {}, |
89 | 136 | "output_type": "execute_result" |
90 | 137 | } |
91 | 138 | ], |
92 | 139 | "source": [ |
93 | | - "\n", |
94 | 140 | "np.linalg.norm(smols.params)" |
95 | 141 | ] |
96 | 142 | }, |
97 | 143 | { |
98 | 144 | "cell_type": "markdown", |
| 145 | + "id": "4e49c36b-af05-4e09-896f-6895f6207d66", |
99 | 146 | "metadata": {}, |
100 | 147 | "source": [ |
101 | | - "Very fast least squares solver (including for minimum norm interpolation problems). \n" |
| 148 | + "Statsmodels is very slow with such problems." |
| 149 | + ] |
| 150 | + }, |
| 151 | + { |
| 152 | + "cell_type": "markdown", |
| 153 | + "id": "7d19a866-7360-4fa2-8eb8-3f2f6a538e58", |
| 154 | + "metadata": {}, |
| 155 | + "source": [ |
| 156 | + "### scikit-learn" |
102 | 157 | ] |
103 | 158 | }, |
104 | 159 | { |
105 | 160 | "cell_type": "code", |
106 | 161 | "execution_count": 7, |
107 | | - "metadata": {}, |
| 162 | + "id": "82c630be", |
| 163 | + "metadata": { |
| 164 | + "execution": { |
| 165 | + "iopub.execute_input": "2024-08-07T20:13:16.255033Z", |
| 166 | + "iopub.status.busy": "2024-08-07T20:13:16.254911Z", |
| 167 | + "iopub.status.idle": "2024-08-07T20:14:07.937802Z", |
| 168 | + "shell.execute_reply": "2024-08-07T20:14:07.937238Z", |
| 169 | + "shell.execute_reply.started": "2024-08-07T20:13:16.255020Z" |
| 170 | + } |
| 171 | + }, |
108 | 172 | "outputs": [ |
109 | 173 | { |
110 | 174 | "name": "stdout", |
111 | 175 | "output_type": "stream", |
112 | 176 | "text": [ |
113 | | - "CPU times: user 1h 24min 3s, sys: 1.51 s, total: 1h 24min 5s\n", |
114 | | - "Wall time: 6min 14s\n" |
| 177 | + "CPU times: user 35min 26s, sys: 13min 45s, total: 49min 11s\n", |
| 178 | + "Wall time: 51.5 s\n" |
115 | 179 | ] |
116 | 180 | }, |
117 | 181 | { |
118 | 182 | "data": { |
119 | 183 | "text/plain": [ |
120 | | - "Array(0.0001564, dtype=float32)" |
| 184 | + "1.794120407794253e-12" |
121 | 185 | ] |
122 | 186 | }, |
123 | 187 | "execution_count": 7, |
|
127 | 191 | ], |
128 | 192 | "source": [ |
129 | 193 | "%%time\n", |
130 | | - "sol = lx.linear_solve( # solve # Ax = b\n", |
131 | | - " operator = lx.MatrixLinearOperator(jnp.array(X)), # A\n", |
132 | | - " vector = jnp.array(y), # b\n", |
133 | | - " solver=lx.AutoLinearSolver(well_posed=None), # auto solver with no well-posedness check\n", |
134 | | - " )\n", |
135 | | - "\n", |
136 | | - "betahat = sol.value\n", |
137 | | - "# does it interpolate\n", |
138 | | - "(y - X @ betahat).max()\n" |
| 194 | + "m = LinearRegression()\n", |
| 195 | + "m.fit(X, y)\n", |
| 196 | + "(y - m.predict(X)).max()\n" |
139 | 197 | ] |
140 | 198 | }, |
141 | 199 | { |
142 | 200 | "cell_type": "code", |
143 | 201 | "execution_count": 8, |
144 | | - "metadata": {}, |
| 202 | + "id": "47998a93", |
| 203 | + "metadata": { |
| 204 | + "execution": { |
| 205 | + "iopub.execute_input": "2024-08-07T20:14:07.939157Z", |
| 206 | + "iopub.status.busy": "2024-08-07T20:14:07.938685Z", |
| 207 | + "iopub.status.idle": "2024-08-07T20:14:07.942731Z", |
| 208 | + "shell.execute_reply": "2024-08-07T20:14:07.942369Z", |
| 209 | + "shell.execute_reply.started": "2024-08-07T20:14:07.939133Z" |
| 210 | + } |
| 211 | + }, |
145 | 212 | "outputs": [ |
146 | 213 | { |
147 | 214 | "data": { |
148 | 215 | "text/plain": [ |
149 | | - "31.80334" |
| 216 | + "32.063915612235505" |
150 | 217 | ] |
151 | 218 | }, |
152 | 219 | "execution_count": 8, |
|
155 | 222 | } |
156 | 223 | ], |
157 | 224 | "source": [ |
158 | | - "np.linalg.norm(betahat)\n" |
| 225 | + "np.linalg.norm(m.coef_)\n" |
159 | 226 | ] |
160 | 227 | }, |
161 | 228 | { |
162 | | - "cell_type": "code", |
163 | | - "execution_count": 9, |
| 229 | + "cell_type": "markdown", |
| 230 | + "id": "2d8a87e2-ea14-4cb8-b9fa-7d261c741251", |
164 | 231 | "metadata": {}, |
| 232 | + "source": [ |
| 233 | + "### lineax\n", |
| 234 | + "\n", |
| 235 | + "Very fast least squares solver (including for minimum norm interpolation problems). \n" |
| 236 | + ] |
| 237 | + }, |
| 238 | + { |
| 239 | + "cell_type": "code", |
| 240 | + "execution_count": 5, |
| 241 | + "id": "3207d070-779f-4107-9763-d0cda1a311e2", |
| 242 | + "metadata": { |
| 243 | + "execution": { |
| 244 | + "iopub.execute_input": "2024-08-07T20:12:57.241449Z", |
| 245 | + "iopub.status.busy": "2024-08-07T20:12:57.241317Z", |
| 246 | + "iopub.status.idle": "2024-08-07T20:13:16.249450Z", |
| 247 | + "shell.execute_reply": "2024-08-07T20:13:16.248802Z", |
| 248 | + "shell.execute_reply.started": "2024-08-07T20:12:57.241436Z" |
| 249 | + } |
| 250 | + }, |
165 | 251 | "outputs": [ |
166 | 252 | { |
167 | 253 | "name": "stdout", |
168 | 254 | "output_type": "stream", |
169 | 255 | "text": [ |
170 | | - "CPU times: user 3h 1min 55s, sys: 0 ns, total: 3h 1min 55s\n", |
171 | | - "Wall time: 13min 44s\n" |
| 256 | + "CPU times: user 10min 31s, sys: 3min 35s, total: 14min 6s\n", |
| 257 | + "Wall time: 18.9 s\n" |
172 | 258 | ] |
173 | 259 | }, |
174 | 260 | { |
175 | 261 | "data": { |
176 | 262 | "text/plain": [ |
177 | | - "1.538325022920617e-12" |
| 263 | + "Array(0.00014114, dtype=float32)" |
178 | 264 | ] |
179 | 265 | }, |
180 | | - "execution_count": 9, |
| 266 | + "execution_count": 5, |
181 | 267 | "metadata": {}, |
182 | 268 | "output_type": "execute_result" |
183 | 269 | } |
184 | 270 | ], |
185 | 271 | "source": [ |
186 | 272 | "%%time\n", |
187 | | - "m = LinearRegression()\n", |
188 | | - "m.fit(X, y)\n", |
189 | | - "(y - m.predict(X)).max()\n" |
| 273 | + "sol = lx.linear_solve( # solve Ax = b\n", |
| 274 | + " operator = lx.MatrixLinearOperator(jnp.array(X)), # A\n", |
| 275 | + " vector = jnp.array(y), # b\n", |
| 276 | + " solver=lx.AutoLinearSolver(well_posed=None), \n", |
| 277 | + " )\n", |
| 278 | + "\n", |
| 279 | + "betahat = sol.value\n", |
| 280 | + "# does it interpolate\n", |
| 281 | + "(y - X @ betahat).max()" |
190 | 282 | ] |
191 | 283 | }, |
192 | 284 | { |
193 | 285 | "cell_type": "code", |
194 | | - "execution_count": 10, |
195 | | - "metadata": {}, |
| 286 | + "execution_count": 6, |
| 287 | + "id": "24e09278-3a2c-4bec-b0db-9547278d51cc", |
| 288 | + "metadata": { |
| 289 | + "execution": { |
| 290 | + "iopub.execute_input": "2024-08-07T20:13:16.250977Z", |
| 291 | + "iopub.status.busy": "2024-08-07T20:13:16.250825Z", |
| 292 | + "iopub.status.idle": "2024-08-07T20:13:16.254344Z", |
| 293 | + "shell.execute_reply": "2024-08-07T20:13:16.253966Z", |
| 294 | + "shell.execute_reply.started": "2024-08-07T20:13:16.250962Z" |
| 295 | + } |
| 296 | + }, |
196 | 297 | "outputs": [ |
197 | 298 | { |
198 | 299 | "data": { |
199 | 300 | "text/plain": [ |
200 | | - "31.8032580188364" |
| 301 | + "32.064747" |
201 | 302 | ] |
202 | 303 | }, |
203 | | - "execution_count": 10, |
| 304 | + "execution_count": 6, |
204 | 305 | "metadata": {}, |
205 | 306 | "output_type": "execute_result" |
206 | 307 | } |
207 | 308 | ], |
208 | 309 | "source": [ |
209 | | - "np.linalg.norm(m.coef_)\n" |
| 310 | + "np.linalg.norm(betahat)\n" |
210 | 311 | ] |
211 | 312 | } |
212 | 313 | ], |
213 | 314 | "metadata": { |
214 | 315 | "kernelspec": { |
215 | | - "display_name": "metrics", |
| 316 | + "display_name": "Python 3.10 (recommended)", |
216 | 317 | "language": "python", |
217 | | - "name": "python3" |
| 318 | + "name": "python310" |
218 | 319 | }, |
219 | 320 | "language_info": { |
220 | 321 | "codemirror_mode": { |
|
0 commit comments