Skip to content

Commit 77a5f1f

Browse files
authored
update jax shape discussion (#413)
1 parent a383d0e commit 77a5f1f

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

lectures/jax_intro.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,18 @@ for A in matrices:
274274
print(A)
275275
```
276276

277-
One point to remember is that JAX expects tuples to describe array shapes, even for flat arrays. Hence, to get a one-dimensional array of normal random draws we use `(len, )` for the shape, as in
277+
To get a one-dimensional array of normal random draws, we can either use `(len, )` for the shape, as in
278278

279279
```{code-cell} ipython3
280280
random.normal(key, (5, ))
281281
```
282282

283+
or simply use `5` as the shape argument:
284+
285+
```{code-cell} ipython3
286+
random.normal(key, 5)
287+
```
288+
283289
## JIT compilation
284290

285291
The JAX just-in-time (JIT) compiler accelerates logic within functions by fusing linear

0 commit comments

Comments
 (0)