Skip to content

Commit

Permalink
Remove the deprecated jax.ops.index_update (pyro-ppl#1371)
Browse files Browse the repository at this point in the history
* Fix deprecated ops.index

* fix further numerical issues
  • Loading branch information
fehiepsi authored Mar 20, 2022
1 parent f333e91 commit 37ac4b3
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/stein_dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _reverse_padded(padded, lengths):
def _reverse_single(p, length):
new = jnp.zeros_like(p)
reverse = jnp.roll(p[::-1], length, axis=0)
return jax.ops.index_update(new, jax.ops.index[:], reverse)
return new.at[:].set(reverse)

return jax.vmap(_reverse_single)(padded, lengths)

Expand Down
4 changes: 2 additions & 2 deletions test/contrib/test_funsor.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,11 @@ def transition_fn(x, y):
actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[
0
]
assert_allclose(actual_log_joint, expected_log_joint)
assert_allclose(actual_log_joint, expected_log_joint, rtol=1e-6)

actual_last_x = enum(config_enumerate(fun_model))(data)
expected_last_x = enum(config_enumerate(model))(data)
assert_allclose(actual_last_x, expected_last_x)
assert_allclose(actual_last_x, expected_last_x, rtol=1e-6)


def test_scan_enum_plate():
Expand Down
6 changes: 3 additions & 3 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1816,7 +1816,7 @@ def test_biject_to(constraint, shape):

# test inv
z = transform.inv(y)
assert_allclose(x, z, atol=1e-6, rtol=1e-6)
assert_allclose(x, z, atol=1e-5, rtol=1e-5)

# test domain, currently all is constraints.real or constraints.real_vector
assert_array_equal(transform.domain(z), jnp.ones(batch_shape))
Expand Down Expand Up @@ -1881,8 +1881,8 @@ def inv_vec_transform(y):
expected = jnp.log(jnp.abs(grad(transform)(x)))
inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))

assert_allclose(actual, expected, atol=1e-6, rtol=1e-6)
assert_allclose(actual, -inv_expected, atol=1e-6, rtol=1e-6)
assert_allclose(actual, expected, atol=1e-5, rtol=1e-5)
assert_allclose(actual, -inv_expected, atol=1e-5, rtol=1e-5)


# NB: skip transforms which are tested in `test_biject_to`
Expand Down

0 comments on commit 37ac4b3

Please sign in to comment.