
Commit 5bab20e

Fixes for investment lecture (#122)
* misc * misc * misc * misc * misc
1 parent 9445f8d commit 5bab20e

File tree: lectures/opt_invest.md, lectures/opt_savings.md

2 files changed: +62 additions, −77 deletions

lectures/opt_invest.md

Lines changed: 27 additions & 42 deletions
@@ -25,8 +25,7 @@ We require the following library to be installed.
 !pip install --upgrade quantecon
 ```
 
-A monopolist faces inverse demand
-curve
+We study a monopolist who faces the inverse demand curve
 
 $$
 P_t = a_0 - a_1 Y_t + Z_t,
@@ -38,7 +37,7 @@ where
 * $Y_t$ is output and
 * $Z_t$ is a demand shock.
 
-We assume that $Z_t$ is a discretized AR(1) process.
+We assume that $Z_t$ is a discretized AR(1) process, specified below.
 
 Current profits are
 
@@ -116,10 +115,10 @@ def create_investment_model(
 
 
 Let's re-write the vectorized version of the right-hand side of the
-Bellman equation (before maximization), which is a 3D array representing:
+Bellman equation (before maximization), which is a 3D array representing
 
 $$
-B(y, z, y') = r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z')
+B(y, z, y') = r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z')
 $$
 
 for all $(y, z, y')$.
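Aside: the broadcasting pattern behind such a 3D array can be sketched in a few lines of JAX. The sizes and stand-in arrays below are assumptions for illustration, not the lecture's `B` (which unpacks `constants`, `sizes` and `arrays`):

```python
import jax.numpy as jnp

# Toy sizes and stand-in arrays (assumptions, not the lecture's model)
y_size, z_size, β = 5, 4, 0.95
r = jnp.zeros((y_size, z_size, y_size))     # r[i, j, ip] plays r(y, z, y')
v = jnp.ones((y_size, z_size))              # v[ip, jp]   plays v(y', z')
Q = jnp.full((z_size, z_size), 1 / z_size)  # Q[j, jp]    plays Q(z, z')

# EV[j, ip] = Σ_jp Q[j, jp] v[ip, jp], i.e. Σ_{z'} v(y', z') Q(z, z')
EV = Q @ v.T

# Broadcast EV over the y axis: B[i, j, ip] = r[i, j, ip] + β EV[j, ip]
B = r + β * EV[None, :, :]
print(B.shape)   # (y_size, z_size, y_size)
```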
@@ -154,8 +153,10 @@ def B(v, constants, sizes, arrays):
 B = jax.jit(B, static_argnums=(2,))
 ```
 
+We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$,
+which is defined as the vector
 
-Define a function to compute the current rewards given policy $\sigma$.
+$$ r_\sigma(y, z) := r(y, z, \sigma(y, z)) $$
 
 ```{code-cell} ipython3
 def compute_r_σ(σ, constants, sizes, arrays):
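Aside: one way to picture the extraction $r_\sigma(y, z) = r(y, z, \sigma(y, z))$ is fancy indexing over a precomputed reward array; the lecture's `compute_r_σ` may instead build rewards directly from model primitives, so treat the shapes and names here as hypothetical:

```python
import jax.numpy as jnp

# Hypothetical full reward array r[i, j, ip] and a policy σ[i, j]
# holding indices into the y grid (illustration only)
y_size, z_size = 5, 4
r = jnp.arange(y_size * z_size * y_size, dtype=jnp.float32)
r = r.reshape(y_size, z_size, y_size)
σ = jnp.zeros((y_size, z_size), dtype=jnp.int32)

# r_σ[i, j] = r[i, j, σ[i, j]] via broadcasting index arrays
i = jnp.arange(y_size)[:, None]   # shape (y_size, 1)
j = jnp.arange(z_size)[None, :]   # shape (1, z_size)
r_σ = r[i, j, σ]                  # shape (y_size, z_size)
```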
@@ -238,47 +239,32 @@ T_σ = jax.jit(T_σ, static_argnums=(3,))
 
 Next, we want to compute the lifetime value of following policy $\sigma$.
 
-The basic problem is to solve the linear system
+This lifetime value is a function $v_\sigma$ that satisfies
 
-$$ v(y, z) = r(y, z, \sigma(y, z)) + \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z) $$
+$$ v_\sigma(y, z) = r_\sigma(y, z) + \beta \sum_{z'} v_\sigma(\sigma(y, z), z') Q(z, z') $$
 
-for $v$.
+We wish to solve this equation for $v_\sigma$.
 
-It turns out to be helpful to rewrite this as
+Suppose we define the linear operator $L_\sigma$ by
 
-$$ v(y, z) = r(y, z, \sigma(y, z)) + \beta \sum_{y', z'} v(y', z') P_\sigma(y, z, y', z') $$
+$$ (L_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z') $$
 
-where $P_\sigma(y, z, y', z') = 1\{y' = \sigma(y, z)\} Q(z, z')$.
-
-We want to write this as $v = r_\sigma + \beta P_\sigma v$ and then solve for $v$
-
-Note, however, that $v$ is a multi-index array, rather than a vector.
-
-
-The value $v_{\sigma}$ of a policy $\sigma$ is defined as
+With this notation, the problem is to solve for $v$ via
 
 $$
-v_{\sigma} = (I - \beta P_{\sigma})^{-1} r_{\sigma}
+(L_{\sigma} v)(y, z) = r_\sigma(y, z)
 $$
 
-Here we set up the linear map $v \mapsto R_{\sigma} v$,
-
-where $R_{\sigma} := I - \beta P_{\sigma}$
-
-In the investment problem, this map can be expressed as
-
-$$
-(R_{\sigma} v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z')
-$$
+In vector form this is $L_\sigma v = r_\sigma$, which tells us that the function
+we seek is
 
-Defining the map as above works in a more intuitive multi-index setting
-(e.g. working with $v[i, j]$ rather than flattening v to a one-dimensional
-array) and avoids instantiating the large matrix $P_{\sigma}$.
+$$ v_\sigma = L_\sigma^{-1} r_\sigma $$
 
-Let's define the function $R_{\sigma}$.
+JAX allows us to solve linear systems defined in terms of operators; the first
+step is to define the function $L_{\sigma}$.
 
 ```{code-cell} ipython3
-def R_σ(v, σ, constants, sizes, arrays):
+def L_σ(v, σ, constants, sizes, arrays):
 
     β, a_0, a_1, γ, c = constants
     y_size, z_size = sizes
@@ -296,12 +282,11 @@ def R_σ(v, σ, constants, sizes, arrays):
     # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
     return v - β * jnp.sum(V * Q, axis=2)
 
-R_σ = jax.jit(R_σ, static_argnums=(3,))
+L_σ = jax.jit(L_σ, static_argnums=(3,))
 ```
 
+Now we can define a function to compute $v_{\sigma}$.
 
-Define a function to get the value $v_{\sigma}$ of policy
-$\sigma$ by inverting the linear map $R_{\sigma}$.
 
 ```{code-cell} ipython3
 def get_value(σ, constants, sizes, arrays):
@@ -313,16 +298,16 @@ def get_value(σ, constants, sizes, arrays):
 
     r_σ = compute_r_σ(σ, constants, sizes, arrays)
 
-    # Reduce R_σ to a function in v
-    partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
+    # Reduce L_σ to a function in v
+    partial_L_σ = lambda v: L_σ(v, σ, constants, sizes, arrays)
 
-    return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
+    return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]
 
 get_value = jax.jit(get_value, static_argnums=(2,))
 ```
 
 
-Now we define the solvers, which implement VFI, HPI and OPI.
+Finally, we introduce the solvers that implement VFI, HPI and OPI.
 
 ```{code-cell} ipython3
 :load: _static/lecture_specific/vfi.py
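Aside: the matrix-free solve in the hunk above relies on `jax.scipy.sparse.linalg.bicgstab` accepting a function in place of a matrix, so $L_\sigma$ is never instantiated as a matrix. A self-contained toy sketch, with random stand-ins replacing the lecture's model:

```python
import jax
import jax.numpy as jnp

# Toy dimensions and stand-in data (assumptions for illustration)
y_size, z_size, β = 5, 4, 0.9
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
Q = jax.random.uniform(k1, (z_size, z_size))
Q = Q / Q.sum(axis=1, keepdims=True)    # make each row sum to one
σ = jax.random.randint(k2, (y_size, z_size), 0, y_size)
r_σ = jax.random.uniform(k3, (y_size, z_size))

def L_σ(v):
    # W[i, j, jp] = v[σ[i, j], jp], so summing W * Q over axis 2
    # gives Σ_jp v[σ[i, j], jp] Q[j, jp]
    W = v[σ]
    return v - β * jnp.sum(W * Q, axis=2)

# bicgstab takes the linear map itself; v_σ keeps its 2D shape throughout
v_σ, _ = jax.scipy.sparse.linalg.bicgstab(L_σ, r_σ)
print(jnp.max(jnp.abs(L_σ(v_σ) - r_σ)))   # residual should be near zero
```

Because the rows of `Q` sum to one and $\beta < 1$, the operator $I - \beta P_\sigma$ is invertible and well conditioned, so `bicgstab` converges quickly.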
@@ -396,7 +381,7 @@ plt.show()
 Let's plot the time taken by each of the solvers and compare them.
 
 ```{code-cell} ipython3
-m_vals = range(5, 3000, 100)
+m_vals = range(5, 600, 40)
 ```
 
 ```{code-cell} ipython3

lectures/opt_savings.md

Lines changed: 35 additions & 35 deletions
@@ -133,7 +133,14 @@ def B(v, constants, sizes, arrays):
 
 ## Operators
 
-Now we define the policy operator $T_\sigma$
+
+We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$,
+which is defined as the vector
+
+$$ r_\sigma(w, y) := r(w, y, \sigma(w, y)) $$
+
+
+
 
 ```{code-cell} ipython3
 def compute_r_σ(σ, constants, sizes, arrays):
157164
return r_σ
158165
```
159166

167+
Now we define the policy operator $T_\sigma$
168+
160169
```{code-cell} ipython3
161170
def T_σ(v, σ, constants, sizes, arrays):
162171
"The σ-policy operator."
@@ -201,47 +210,36 @@ def get_greedy(v, constants, sizes, arrays):
 
 The function below computes the value $v_\sigma$ of following policy $\sigma$.
 
-The basic problem is to solve the linear system
-
-$$ v(w,y ) = u(Rw + y - \sigma(w, y)) + β \sum_{y'} v(\sigma(w, y), y') Q(y, y) $$
+This lifetime value is a function $v_\sigma$ that satisfies
 
-for $v$.
+$$ v_\sigma(w, y) = r_\sigma(w, y) + \beta \sum_{y'} v_\sigma(\sigma(w, y), y') Q(y, y') $$
 
-It turns out to be helpful to rewrite this as
+We wish to solve this equation for $v_\sigma$.
 
-$$ v(w,y) = r(w, y, \sigma(w, y)) + β \sum_{w', y'} v(w', y') P_\sigma(w, y, w', y') $$
+Suppose we define the linear operator $L_\sigma$ by
 
-where $P_\sigma(w, y, w', y') = 1\{w' = \sigma(w, y)\} Q(y, y')$.
+$$ (L_\sigma v)(w, y) = v(w, y) - \beta \sum_{y'} v(\sigma(w, y), y') Q(y, y') $$
 
-We want to write this as $v = r_\sigma + P_\sigma v$ and then solve for $v$
+With this notation, the problem is to solve for $v$ via
 
-Note, however,
+$$
+(L_{\sigma} v)(w, y) = r_\sigma(w, y)
+$$
 
-* $v$ is a 2 index array, rather than a single vector.
-* $P_\sigma$ has four indices rather than 2
+In vector form this is $L_\sigma v = r_\sigma$, which tells us that the function
+we seek is
 
-The code below
+$$ v_\sigma = L_\sigma^{-1} r_\sigma $$
 
-1. reshapes $v$ and $r_\sigma$ to 1D arrays and $P_\sigma$ to a matrix
-2. solves the linear system
-3. converts back to multi-index arrays.
+JAX allows us to solve linear systems defined in terms of operators; the first
+step is to define the function $L_{\sigma}$.
 
 ```{code-cell} ipython3
-def R_σ(v, σ, constants, sizes, arrays):
+def L_σ(v, σ, constants, sizes, arrays):
     """
-    The value v_σ of a policy σ is defined as
-
-        v_σ = (I - β P_σ)^{-1} r_σ
-
-    Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ.
+    Here we set up the linear map v -> L_σ v, where
 
-    In the consumption problem, this map can be expressed as
-
-        (R_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′)
-
-    Defining the map as above works in a more intuitive multi-index setting
-    (e.g. working with v[i, j] rather than flattening v to a one-dimensional
-    array) and avoids instantiating the large matrix P_σ.
+        (L_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′)
 
     """
@@ -262,9 +260,11 @@ def R_σ(v, σ, constants, sizes, arrays):
     return v - β * jnp.sum(V * Q, axis=2)
 ```
 
+Now we can define a function to compute $v_{\sigma}$.
+
 ```{code-cell} ipython3
 def get_value(σ, constants, sizes, arrays):
-    "Get the value v_σ of policy σ by inverting the linear map R_σ."
+    "Get the value v_σ of policy σ by inverting the linear map L_σ."
 
     # Unpack
     β, R, γ = constants
@@ -273,10 +273,10 @@ def get_value(σ, constants, sizes, arrays):
 
     r_σ = compute_r_σ(σ, constants, sizes, arrays)
 
-    # Reduce R_σ to a function in v
-    partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
+    # Reduce L_σ to a function in v
+    partial_L_σ = lambda v: L_σ(v, σ, constants, sizes, arrays)
 
-    return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
+    return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]
 ```
 
 ## JIT compiled versions
@@ -288,7 +288,7 @@ T = jax.jit(T, static_argnums=(2,))
 get_greedy = jax.jit(get_greedy, static_argnums=(2,))
 get_value = jax.jit(get_value, static_argnums=(2,))
 T_σ = jax.jit(T_σ, static_argnums=(3,))
-R_σ = jax.jit(R_σ, static_argnums=(3,))
+L_σ = jax.jit(L_σ, static_argnums=(3,))
 ```
 
 ## Solvers
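Aside: with the jitted functions in scope, one natural sanity check is that the solution returned by `get_value` satisfies the defining identity $L_\sigma v_\sigma = r_\sigma$ up to solver tolerance. This usage sketch assumes `σ`, `constants`, `sizes` and `arrays` are unpacked from the model as in the lecture:

```python
import jax.numpy as jnp

# Assumes compute_r_σ, get_value, L_σ and a policy σ from the lecture
r_σ = compute_r_σ(σ, constants, sizes, arrays)
v_σ = get_value(σ, constants, sizes, arrays)
residual = jnp.max(jnp.abs(L_σ(v_σ, σ, constants, sizes, arrays) - r_σ))
print(residual)   # should be close to zero (bicgstab tolerance)
```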
@@ -379,7 +379,7 @@ model = create_consumption_model()
 σ_pi, pi_time = run_algorithm(policy_iteration, model)
 σ_vfi, vfi_time = run_algorithm(value_iteration, model, tol=1e-5)
 
-m_vals = range(5, 3000, 100)
+m_vals = range(5, 600, 40)
 opi_times = []
 for m in m_vals:
     σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, model, m=m, tol=1e-5)
