@@ -25,8 +25,7 @@ We require the following library to be installed.
25
25
!pip install --upgrade quantecon
26
26
```
27
27
28
- A monopolist faces inverse demand
29
- curve
28
+ We study a monopolist who faces inverse demand curve
30
29
31
30
$$
32
31
P_t = a_0 - a_1 Y_t + Z_t,
38
37
* $Y_t$ is output and
39
38
* $Z_t$ is a demand shock.
40
39
41
- We assume that $Z_t$ is a discretized AR(1) process.
40
+ We assume that $Z_t$ is a discretized AR(1) process, specified below.
42
41
43
42
Current profits are
44
43
@@ -116,10 +115,10 @@ def create_investment_model(
116
115
117
116
118
117
Let's re-write the vectorized version of the right-hand side of the
119
- Bellman equation (before maximization), which is a 3D array representing:
118
+ Bellman equation (before maximization), which is a 3D array representing
120
119
121
120
$$
122
- B(y, z, y') = r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z')
121
+ B(y, z, y') = r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z')
123
122
$$
124
123
125
124
for all $(y, z, y')$.
@@ -154,8 +153,10 @@ def B(v, constants, sizes, arrays):
154
153
B = jax.jit(B, static_argnums=(2,))
155
154
```
156
155
156
+ We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$,
157
+ which is defined as the vector
157
158
158
- Define a function to compute the current rewards given policy $\sigma$.
159
+ $$ r_\sigma(y, z) := r(y, z, \sigma(y, z)) $$
159
160
160
161
``` {code-cell} ipython3
161
162
def compute_r_σ(σ, constants, sizes, arrays):
@@ -238,47 +239,32 @@ T_σ = jax.jit(T_σ, static_argnums=(3,))
238
239
239
240
Next, we want to compute the lifetime value of following policy $\sigma$.
240
241
241
- The basic problem is to solve the linear system
242
+ This lifetime value is a function $v_\sigma$ that satisfies
242
243
243
- $$ v (y, z) = r(y, z, \sigma(y, z)) + \beta \sum_{z'} v (\sigma(y, z), z') Q(z, z) $$
244
+ $$ v_\sigma (y, z) = r_\sigma(y, z) + \beta \sum_{z'} v_\sigma (\sigma(y, z), z') Q(z, z') $$
244
245
245
- for $v $.
246
+ We wish to solve this equation for $v_\sigma$.
246
247
247
- It turns out to be helpful to rewrite this as
248
+ Suppose we define the linear operator $L_\sigma$ by
248
249
249
- $$ v (y, z) = r (y, z, \sigma(y, z)) + \beta \sum_{y', z'} v(y', z') P_ \sigma(y, z, y' , z') $$
250
+ $$ (L_\sigma v) (y, z) = v (y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z') $$
250
251
251
- where $P_ \sigma(y, z, y', z') = 1\{ y' = \sigma(y, z)\} Q(z, z')$.
252
-
253
- We want to write this as $v = r_ \sigma + \beta P_ \sigma v$ and then solve for $v$
254
-
255
- Note, however, that $v$ is a multi-index array, rather than a vector.
256
-
257
-
258
- The value $v_ {\sigma}$ of a policy $\sigma$ is defined as
252
+ With this notation, the problem is to solve for $v$ via
259
253
260
254
$$
261
- v_ {\sigma} = (I - \beta P_{\sigma})^{-1} r_{ \sigma}
255
+ (L_{\sigma} v)(y, z) = r_\sigma(y, z)
262
256
$$
263
257
264
- Here we set up the linear map $v \mapsto R_ {\sigma} v$,
265
-
266
- where $R_ {\sigma} := I - \beta P_ {\sigma}$
267
-
268
- In the investment problem, this map can be expressed as
269
-
270
- $$
271
- (R_{\sigma} v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z')
272
- $$
258
+ In vector form, this is $L_\sigma v = r_\sigma$, which tells us that the function
259
+ we seek is
273
260
274
- Defining the map as above works in a more intuitive multi-index setting
275
- (e.g. working with $v[ i, j] $ rather than flattening v to a one-dimensional
276
- array) and avoids instantiating the large matrix $P_ {\sigma}$.
261
+ $$ v_\sigma = L_\sigma^{-1} r_\sigma $$
277
262
278
- Let's define the function $R_ {\sigma}$.
263
+ JAX allows us to solve linear systems defined in terms of operators; the first
264
+ step is to define the function $L_{\sigma}$.
279
265
280
266
``` {code-cell} ipython3
281
- def R_σ (v, σ, constants, sizes, arrays):
267
+ def L_σ (v, σ, constants, sizes, arrays):
282
268
283
269
β, a_0, a_1, γ, c = constants
284
270
y_size, z_size = sizes
@@ -296,12 +282,11 @@ def R_σ(v, σ, constants, sizes, arrays):
296
282
# Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
297
283
return v - β * jnp.sum(V * Q, axis=2)
298
284
299
- R_σ = jax.jit(R_σ , static_argnums=(3,))
285
+ L_σ = jax.jit(L_σ , static_argnums=(3,))
300
286
```
301
287
288
+ Now we can define a function to compute $v_{\sigma}$.
302
289
303
- Define a function to get the value $v_ {\sigma}$ of policy
304
- $\sigma$ by inverting the linear map $R_ {\sigma}$.
305
290
306
291
``` {code-cell} ipython3
307
292
def get_value(σ, constants, sizes, arrays):
@@ -313,16 +298,16 @@ def get_value(σ, constants, sizes, arrays):
313
298
314
299
r_σ = compute_r_σ(σ, constants, sizes, arrays)
315
300
316
- # Reduce R_σ to a function in v
317
- partial_R_σ = lambda v: R_σ (v, σ, constants, sizes, arrays)
301
+ # Reduce L_σ to a function in v
302
+ partial_L_σ = lambda v: L_σ (v, σ, constants, sizes, arrays)
318
303
319
- return jax.scipy.sparse.linalg.bicgstab(partial_R_σ , r_σ)[0]
304
+ return jax.scipy.sparse.linalg.bicgstab(partial_L_σ , r_σ)[0]
320
305
321
306
get_value = jax.jit(get_value, static_argnums=(2,))
322
307
```
323
308
324
309
325
- Now we define the solvers, which implement VFI, HPI and OPI.
310
+ Finally, we introduce the solvers that implement VFI, HPI and OPI.
326
311
327
312
``` {code-cell} ipython3
328
313
:load: _static/lecture_specific/vfi.py
@@ -396,7 +381,7 @@ plt.show()
396
381
Let's plot the time taken by each of the solvers and compare them.
397
382
398
383
``` {code-cell} ipython3
399
- m_vals = range(5, 3000, 100 )
384
+ m_vals = range(5, 600, 40 )
400
385
```
401
386
402
387
``` {code-cell} ipython3
0 commit comments