@@ -294,30 +294,128 @@ get_value = jax.jit(get_value, static_argnums=(2,))
We use successive approximation for VFI.

```{code-cell} ipython3
- :load: _static/lecture_specific/successive_approx.py
+ def successive_approx_jax(T,                     # Operator (callable)
+                           x_0,                   # Initial condition
+                           tol=1e-6,              # Error tolerance
+                           max_iter=10_000):      # Max iteration bound
+     def body_fun(k_x_err):
+         k, x, error = k_x_err
+         x_new = T(x)
+         error = jnp.max(jnp.abs(x_new - x))
+         return k + 1, x_new, error
+
+     def cond_fun(k_x_err):
+         k, x, error = k_x_err
+         return jnp.logical_and(error > tol, k < max_iter)
+
+     k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tol + 1))
+     return x
+
+ successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(0,))
+ ```
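+
+ As a quick sanity check (a toy sketch, not part of the investment model), `successive_approx_jax` accepts any jittable map `T`; for example, the affine contraction below has fixed point 2:
+
+ ```{code-cell} ipython3
+ # Toy illustration (hypothetical example): T(x) = 0.5 x + 1 is a contraction
+ # with fixed point 2, so the iterates should converge to an array of (roughly) 2s.
+ T_toy = lambda x: 0.5 * x + 1
+ print(successive_approx_jax(T_toy, jnp.zeros(3)))
+ ```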
+
+ For OPI we'll add a compiled routine that computes $T_σ^m v$.
+
+ ```{code-cell} ipython3
+ def iterate_policy_operator(σ, v, m, params, sizes, arrays):
+
+     def update(i, v):
+         v = T_σ(v, σ, params, sizes, arrays)
+         return v
+
+     v = jax.lax.fori_loop(0, m, update, v)
+     return v
+
+ iterate_policy_operator = jax.jit(iterate_policy_operator,
+                                   static_argnums=(4,))
```
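+
+ Note that `sizes` is marked static in the `jax.jit` call, so JAX treats it as a compile-time constant that can be used to build array shapes. A minimal sketch of that pattern, using a hypothetical helper unrelated to the model:
+
+ ```{code-cell} ipython3
+ # Hypothetical helper: `shape` must be static because it fixes the output shape.
+ # Each distinct shape triggers one compilation; repeated calls with the same
+ # shape reuse the compiled version.
+ def make_grid(scale, shape):
+     return scale * jnp.ones(shape)
+
+ make_grid = jax.jit(make_grid, static_argnums=(1,))
+ make_grid(2.0, (3, 4)).shape
+ ```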

Finally, we introduce the solvers that implement VFI, HPI and OPI.

```{code-cell} ipython3
- :load: _static/lecture_specific/vfi.py
+ def value_function_iteration(model, tol=1e-5):
+     """
+     Implements value function iteration.
+     """
+     params, sizes, arrays = model
+     vz = jnp.zeros(sizes)
+     _T = lambda v: T(v, params, sizes, arrays)
+     v_star = successive_approx_jax(_T, vz, tol=tol)
+     return get_greedy(v_star, params, sizes, arrays)
```

+ For OPI we will use a compiled JAX `lax.while_loop` operation to speed execution.
+
+
```{code-cell} ipython3
- :load: _static/lecture_specific/hpi.py
+ def opi_loop(params, sizes, arrays, m, tol, max_iter):
+     """
+     Implements optimistic policy iteration (see dp.quantecon.org) with
+     step size m.
+
+     """
+     v_init = jnp.zeros(sizes)
+
+     def condition_function(inputs):
+         i, v, error = inputs
+         return jnp.logical_and(error > tol, i < max_iter)
+
+     def update(inputs):
+         i, v, error = inputs
+         last_v = v
+         σ = get_greedy(v, params, sizes, arrays)
+         v = iterate_policy_operator(σ, v, m, params, sizes, arrays)
+         error = jnp.max(jnp.abs(v - last_v))
+         i += 1
+         return i, v, error
+
+     num_iter, v, error = jax.lax.while_loop(condition_function,
+                                             update,
+                                             (0, v_init, tol + 1))
+
+     return get_greedy(v, params, sizes, arrays)
+
+ opi_loop = jax.jit(opi_loop, static_argnums=(1,))
```

+ Here's a friendly interface to OPI
+
```{code-cell} ipython3
- :load: _static/lecture_specific/opi.py
+ def optimistic_policy_iteration(model, m=10, tol=1e-5, max_iter=10_000):
+     params, sizes, arrays = model
+     σ_star = opi_loop(params, sizes, arrays, m, tol, max_iter)
+     return σ_star
```
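+
+ As a usage sketch (reusing `create_investment_model` from above), the step size `m` moves OPI from VFI-like behavior (`m=1`) toward HPI-like behavior (large `m`):
+
+ ```{code-cell} ipython3
+ model = create_investment_model()
+ σ_low = optimistic_policy_iteration(model, m=1)     # essentially VFI
+ σ_high = optimistic_policy_iteration(model, m=500)  # close to HPI
+ # Both runs should recover (essentially) the same optimal policy
+ print(jnp.max(jnp.abs(σ_low - σ_high)))
+ ```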

+ Here's HPI
+
+
+ ```{code-cell} ipython3
+ def howard_policy_iteration(model, maxiter=250):
+     """
+     Implements Howard policy iteration (see dp.quantecon.org)
+     """
+     params, sizes, arrays = model
+     σ = jnp.zeros(sizes, dtype=int)
+     i, error = 0, 1.0
+     while error > 0 and i < maxiter:
+         v_σ = get_value(σ, params, sizes, arrays)
+         σ_new = get_greedy(v_σ, params, sizes, arrays)
+         error = jnp.max(jnp.abs(σ_new - σ))
+         σ = σ_new
+         i = i + 1
+         print(f"Concluded loop {i} with error {error}.")
+     return σ
+ ```
+
+
```{code-cell} ipython3
:tags: [hide-output]

model = create_investment_model()
print("Starting HPI.")
qe.tic()
- out = policy_iteration(model)
+ out = howard_policy_iteration(model)
elapsed = qe.toc()
print(out)
print(f"HPI completed in {elapsed} seconds.")
@@ -328,7 +426,7 @@ print(f"HPI completed in {elapsed} seconds.")

print("Starting VFI.")
qe.tic()
- out = value_iteration(model)
+ out = value_function_iteration(model)
elapsed = qe.toc()
print(out)
print(f"VFI completed in {elapsed} seconds.")
@@ -356,7 +454,7 @@ y_grid, z_grid, Q = arrays
```

```{code-cell} ipython3
- σ_star = policy_iteration(model)
+ σ_star = howard_policy_iteration(model)

fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(y_grid, y_grid, "k--", label="45")
@@ -376,15 +474,15 @@ m_vals = range(5, 600, 40)
model = create_investment_model()
print("Running Howard policy iteration.")
qe.tic()
- σ_pi = policy_iteration(model)
+ σ_pi = howard_policy_iteration(model)
pi_time = qe.toc()
```

```{code-cell} ipython3
print(f"PI completed in {pi_time} seconds.")
print("Running value function iteration.")
qe.tic()
- σ_vfi = value_iteration(model, tol=1e-5)
+ σ_vfi = value_function_iteration(model, tol=1e-5)
vfi_time = qe.toc()
print(f"VFI completed in {vfi_time} seconds.")
```