diff --git a/examples/stein_dmm.py b/examples/stein_dmm.py index a3e997869..f8aa88dc0 100644 --- a/examples/stein_dmm.py +++ b/examples/stein_dmm.py @@ -39,7 +39,7 @@ def _reverse_padded(padded, lengths): def _reverse_single(p, length): new = jnp.zeros_like(p) reverse = jnp.roll(p[::-1], length, axis=0) - return jax.ops.index_update(new, jax.ops.index[:], reverse) + return new.at[:].set(reverse) return jax.vmap(_reverse_single)(padded, lengths) diff --git a/test/contrib/test_funsor.py b/test/contrib/test_funsor.py index 02fe148bd..795dbcc49 100644 --- a/test/contrib/test_funsor.py +++ b/test/contrib/test_funsor.py @@ -256,11 +256,11 @@ def transition_fn(x, y): actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[ 0 ] - assert_allclose(actual_log_joint, expected_log_joint) + assert_allclose(actual_log_joint, expected_log_joint, rtol=1e-6) actual_last_x = enum(config_enumerate(fun_model))(data) expected_last_x = enum(config_enumerate(model))(data) - assert_allclose(actual_last_x, expected_last_x) + assert_allclose(actual_last_x, expected_last_x, rtol=1e-6) def test_scan_enum_plate(): diff --git a/test/test_distributions.py b/test/test_distributions.py index 8cfdc7564..395f39802 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1816,7 +1816,7 @@ def test_biject_to(constraint, shape): # test inv z = transform.inv(y) - assert_allclose(x, z, atol=1e-6, rtol=1e-6) + assert_allclose(x, z, atol=1e-5, rtol=1e-5) # test domain, currently all is constraints.real or constraints.real_vector assert_array_equal(transform.domain(z), jnp.ones(batch_shape)) @@ -1881,8 +1881,8 @@ def inv_vec_transform(y): expected = jnp.log(jnp.abs(grad(transform)(x))) inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y))) - assert_allclose(actual, expected, atol=1e-6, rtol=1e-6) - assert_allclose(actual, -inv_expected, atol=1e-6, rtol=1e-6) + assert_allclose(actual, expected, atol=1e-5, rtol=1e-5) + assert_allclose(actual, -inv_expected, atol=1e-5, rtol=1e-5) # NB: skip transforms which are tested in `test_biject_to`