
Commit ce4ef38

Edits to opt savings and opt invest (#154)
* misc
* misc
* misc
* misc
1 parent 91efc04 commit ce4ef38

7 files changed (+313, -107 lines)


lectures/_static/lecture_specific/hpi.py

Lines changed: 7 additions & 6 deletions
@@ -1,12 +1,13 @@
-# Implements HPI-Howard policy iteration routine
-
-def policy_iteration(model, maxiter=250):
-    constants, sizes, arrays = model
+def howard_policy_iteration(model, maxiter=250):
+    """
+    Implements Howard policy iteration (see dp.quantecon.org)
+    """
+    params, sizes, arrays = model
     σ = jnp.zeros(sizes, dtype=int)
     i, error = 0, 1.0
     while error > 0 and i < maxiter:
-        v_σ = get_value(σ, constants, sizes, arrays)
-        σ_new = get_greedy(v_σ, constants, sizes, arrays)
+        v_σ = get_value(σ, params, sizes, arrays)
+        σ_new = get_greedy(v_σ, params, sizes, arrays)
         error = jnp.max(jnp.abs(σ_new - σ))
         σ = σ_new
         i = i + 1
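The loop above stops as soon as the greedy policy repeats, since Howard policy iteration reaches the exact optimal policy after finitely many steps. As a rough illustration of the same loop structure, here is a minimal, self-contained sketch on a hypothetical two-state, two-action problem. The `get_value` and `get_greedy` helpers below are toy stand-ins for the lecture's jitted versions, not code from this commit.

```python
# Toy sketch of Howard policy iteration (illustration only, not lecture code)
import jax.numpy as jnp

β = 0.95
R = jnp.array([[1.0, 0.0],     # reward R[s, a]
               [0.0, 2.0]])
P = jnp.array([[[0.9, 0.1],    # transition probabilities P[s, a, s']
                [0.1, 0.9]],
               [[0.8, 0.2],
                [0.2, 0.8]]])

def get_value(σ):
    # Solve v = r_σ + β P_σ v exactly for the current policy σ
    r_σ = R[jnp.arange(2), σ]
    P_σ = P[jnp.arange(2), σ, :]
    return jnp.linalg.solve(jnp.eye(2) - β * P_σ, r_σ)

def get_greedy(v):
    # One-step lookahead: argmax over actions of r + β E[v]
    return jnp.argmax(R + β * (P @ v), axis=1)

σ = jnp.zeros(2, dtype=int)
i, error = 0, 1.0
while error > 0 and i < 250:
    v_σ = get_value(σ)
    σ_new = get_greedy(v_σ)
    error = jnp.max(jnp.abs(σ_new - σ))
    σ, i = σ_new, i + 1

print(σ, get_value(σ))
```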
Lines changed: 7 additions & 6 deletions
@@ -1,13 +1,14 @@
-# Implements the OPI-Optimal policy Iteration routine
-
 def optimistic_policy_iteration(model, tol=1e-5, m=10):
-    constants, sizes, arrays = model
+    """
+    Implements optimistic policy iteration (see dp.quantecon.org)
+    """
+    params, sizes, arrays = model
     v = jnp.zeros(sizes)
     error = tol + 1
     while error > tol:
         last_v = v
-        σ = get_greedy(v, constants, sizes, arrays)
+        σ = get_greedy(v, params, sizes, arrays)
         for _ in range(m):
-            v = T_σ(v, σ, constants, sizes, arrays)
+            v = T_σ(v, σ, params, sizes, arrays)
         error = jnp.max(jnp.abs(v - last_v))
-    return get_greedy(v, constants, sizes, arrays)
+    return get_greedy(v, params, sizes, arrays)
Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 def successive_approx_jax(T,                     # Operator (callable)
                           x_0,                   # Initial condition
-                          tolerance=1e-6,        # Error tolerance
+                          tol=1e-6,              # Error tolerance
                           max_iter=10_000):      # Max iteration bound
     def body_fun(k_x_err):
         k, x, error = k_x_err
@@ -10,9 +10,9 @@ def body_fun(k_x_err):
 
     def cond_fun(k_x_err):
         k, x, error = k_x_err
-        return jnp.logical_and(error > tolerance, k < max_iter)
+        return jnp.logical_and(error > tol, k < max_iter)
 
-    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tolerance + 1))
+    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tol + 1))
     return x
 
 
 successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(0,))
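With the rename, callers pass `tol=` rather than `tolerance=`. Here is a small usage sketch that repeats the definition from the diff above so it runs standalone and applies it to a toy contraction x ↦ 0.5x + 1 (fixed point 2); the toy map is illustrative only, not part of this commit.

```python
import jax
import jax.numpy as jnp

# Definition repeated from the diff above so this sketch is self-contained
def successive_approx_jax(T,                 # Operator (callable)
                          x_0,               # Initial condition
                          tol=1e-6,          # Error tolerance
                          max_iter=10_000):  # Max iteration bound
    def body_fun(k_x_err):
        k, x, error = k_x_err
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        return k + 1, x_new, error

    def cond_fun(k_x_err):
        k, x, error = k_x_err
        return jnp.logical_and(error > tol, k < max_iter)

    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tol + 1))
    return x

successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(0,))

# Toy contraction whose fixed point is 2, called with the renamed keyword
f = lambda x: 0.5 * x + 1.0
x_star = successive_approx_jax(f, jnp.zeros(3), tol=1e-6)
print(x_star)   # approximately [2., 2., 2.]
```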
Lines changed: 8 additions & 6 deletions
@@ -1,8 +1,10 @@
-# Implements VFI-Value Function iteration
-
-def value_iteration(model, tol=1e-5):
-    constants, sizes, arrays = model
+def value_function_iteration(model, tol=1e-5):
+    """
+    Implements value function iteration.
+    """
+    params, sizes, arrays = model
     vz = jnp.zeros(sizes)
-    _T = lambda v: T(v, constants, sizes, arrays)
+    _T = lambda v: T(v, params, sizes, arrays)
     v_star = successive_approx_jax(_T, vz, tolerance=tol)
-    return get_greedy(v_star, constants, sizes, arrays)
+    return get_greedy(v_star, params, sizes, arrays)
+
lectures/opt_invest.md

Lines changed: 107 additions & 9 deletions
@@ -294,30 +294,128 @@ get_value = jax.jit(get_value, static_argnums=(2,))
 We use successive approximation for VFI.
 
 ```{code-cell} ipython3
-:load: _static/lecture_specific/successive_approx.py
+def successive_approx_jax(T,                     # Operator (callable)
+                          x_0,                   # Initial condition
+                          tol=1e-6,              # Error tolerance
+                          max_iter=10_000):      # Max iteration bound
+    def body_fun(k_x_err):
+        k, x, error = k_x_err
+        x_new = T(x)
+        error = jnp.max(jnp.abs(x_new - x))
+        return k + 1, x_new, error
+
+    def cond_fun(k_x_err):
+        k, x, error = k_x_err
+        return jnp.logical_and(error > tol, k < max_iter)
+
+    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tol + 1))
+    return x
+
+successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(0,))
+```
+
+For OPI we'll add a compiled routine that computes $T_σ^m v$.
+
+```{code-cell} ipython3
+def iterate_policy_operator(σ, v, m, params, sizes, arrays):
+
+    def update(i, v):
+        v = T_σ(v, σ, params, sizes, arrays)
+        return v
+
+    v = jax.lax.fori_loop(0, m, update, v)
+    return v
+
+iterate_policy_operator = jax.jit(iterate_policy_operator,
+                                  static_argnums=(4,))
 ```
 
 Finally, we introduce the solvers that implement VFI, HPI and OPI.
 
 ```{code-cell} ipython3
-:load: _static/lecture_specific/vfi.py
+def value_function_iteration(model, tol=1e-5):
+    """
+    Implements value function iteration.
+    """
+    params, sizes, arrays = model
+    vz = jnp.zeros(sizes)
+    _T = lambda v: T(v, params, sizes, arrays)
+    v_star = successive_approx_jax(_T, vz, tol=tol)
+    return get_greedy(v_star, params, sizes, arrays)
 ```
 
+For OPI we will use a compiled JAX `lax.while_loop` operation to speed execution.
+
+
 ```{code-cell} ipython3
-:load: _static/lecture_specific/hpi.py
+def opi_loop(params, sizes, arrays, m, tol, max_iter):
+    """
+    Implements optimistic policy iteration (see dp.quantecon.org) with
+    step size m.
+
+    """
+    v_init = jnp.zeros(sizes)
+
+    def condition_function(inputs):
+        i, v, error = inputs
+        return jnp.logical_and(error > tol, i < max_iter)
+
+    def update(inputs):
+        i, v, error = inputs
+        last_v = v
+        σ = get_greedy(v, params, sizes, arrays)
+        v = iterate_policy_operator(σ, v, m, params, sizes, arrays)
+        error = jnp.max(jnp.abs(v - last_v))
+        i += 1
+        return i, v, error
+
+    num_iter, v, error = jax.lax.while_loop(condition_function,
+                                            update,
+                                            (0, v_init, tol + 1))
+
+    return get_greedy(v, params, sizes, arrays)
+
+opi_loop = jax.jit(opi_loop, static_argnums=(1,))
 ```
 
+Here's a friendly interface to OPI
+
 ```{code-cell} ipython3
-:load: _static/lecture_specific/opi.py
+def optimistic_policy_iteration(model, m=10, tol=1e-5, max_iter=10_000):
+    params, sizes, arrays = model
+    σ_star = opi_loop(params, sizes, arrays, m, tol, max_iter)
+    return σ_star
 ```
 
+Here's HPI
+
+
+```{code-cell} ipython3
+def howard_policy_iteration(model, maxiter=250):
+    """
+    Implements Howard policy iteration (see dp.quantecon.org)
+    """
+    params, sizes, arrays = model
+    σ = jnp.zeros(sizes, dtype=int)
+    i, error = 0, 1.0
+    while error > 0 and i < maxiter:
+        v_σ = get_value(σ, params, sizes, arrays)
+        σ_new = get_greedy(v_σ, params, sizes, arrays)
+        error = jnp.max(jnp.abs(σ_new - σ))
+        σ = σ_new
+        i = i + 1
+        print(f"Concluded loop {i} with error {error}.")
+    return σ
+```
+
+
 ```{code-cell} ipython3
 :tags: [hide-output]
 
 model = create_investment_model()
 print("Starting HPI.")
 qe.tic()
-out = policy_iteration(model)
+out = howard_policy_iteration(model)
 elapsed = qe.toc()
 print(out)
 print(f"HPI completed in {elapsed} seconds.")
@@ -328,7 +426,7 @@ print(f"HPI completed in {elapsed} seconds.")
 
 print("Starting VFI.")
 qe.tic()
-out = value_iteration(model)
+out = value_function_iteration(model)
 elapsed = qe.toc()
 print(out)
 print(f"VFI completed in {elapsed} seconds.")
@@ -356,7 +454,7 @@ y_grid, z_grid, Q = arrays
 ```
 
 ```{code-cell} ipython3
-σ_star = policy_iteration(model)
+σ_star = howard_policy_iteration(model)
 
 fig, ax = plt.subplots(figsize=(9, 5))
 ax.plot(y_grid, y_grid, "k--", label="45")
@@ -376,15 +474,15 @@ m_vals = range(5, 600, 40)
 model = create_investment_model()
 print("Running Howard policy iteration.")
 qe.tic()
-σ_pi = policy_iteration(model)
+σ_pi = howard_policy_iteration(model)
 pi_time = qe.toc()
 ```
 
 ```{code-cell} ipython3
 print(f"PI completed in {pi_time} seconds.")
 print("Running value function iteration.")
 qe.tic()
-σ_vfi = value_iteration(model, tol=1e-5)
+σ_vfi = value_function_iteration(model, tol=1e-5)
 vfi_time = qe.toc()
 print(f"VFI completed in {vfi_time} seconds.")
 ```
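The new `iterate_policy_operator` and `opi_loop` keep the whole OPI iteration inside compiled code via `jax.lax.fori_loop` and `jax.lax.while_loop`. Here is a minimal sketch of that pattern on a toy operator (x ↦ 0.5x + 1, not lecture code): a fixed number of applications with `fori_loop`, and iteration to convergence with `while_loop`.

```python
# Sketch of the fori_loop / while_loop pattern used by iterate_policy_operator
# and opi_loop -- toy operator only, not part of this commit.
import jax
import jax.numpy as jnp

def apply_m_times(v, m):
    # Apply the toy contraction m times inside compiled code (cf. T_σ^m v)
    def update(i, v):
        return 0.5 * v + 1.0
    return jax.lax.fori_loop(0, m, update, v)

def iterate_to_convergence(v_init, tol=1e-6, max_iter=10_000):
    # Repeat the update until it stops changing, inside compiled code
    def cond(state):
        i, v, error = state
        return jnp.logical_and(error > tol, i < max_iter)

    def body(state):
        i, v, error = state
        v_new = 0.5 * v + 1.0
        return i + 1, v_new, jnp.max(jnp.abs(v_new - v))

    i, v, error = jax.lax.while_loop(cond, body, (0, v_init, tol + 1.0))
    return v

apply_m_times = jax.jit(apply_m_times, static_argnums=(1,))
iterate_to_convergence = jax.jit(iterate_to_convergence)

print(apply_m_times(jnp.zeros(3), 50))        # approximately [2., 2., 2.]
print(iterate_to_convergence(jnp.zeros(3)))   # approximately [2., 2., 2.]
```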
