Skip to content

Commit 1bb3b89

Browse files
Jake VanderPlascopybara-github
Jake VanderPlas
authored andcommitted
Remove references to deprecated jax.ShapedArray
This is deprecated as of jax-ml/jax#15263: most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion. PiperOrigin-RevId: 520189916
1 parent e962caa commit 1bb3b89

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

trax/tf_numpy/jax_tests/lax_numpy_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2881,7 +2881,7 @@ def body(i, xy):
28812881
f = lambda y: lax.fori_loop(0, 5, body, (y, y))
28822882
wrapped = linear_util.wrap_init(f)
28832883
pv = partial_eval.PartialVal(
2884-
(jax.ShapedArray((3, 4), onp.float32), jax.core.unit))
2884+
(jax.core.ShapedArray((3, 4), onp.float32), jax.core.unit))
28852885
_, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv])
28862886
self.assertFalse(
28872887
any(onp.array_equal(x, onp.full((3, 4), 2., dtype=onp.float32))

0 commit comments

Comments
 (0)