Skip to content

Commit

Permalink
Add a warning for calling inverse of a non-bijective transform (pyro-…
Browse files Browse the repository at this point in the history
…ppl#1269)

* Add a warning for inverse transform

* address comment for adding stacklevel
  • Loading branch information
fehiepsi authored Jan 2, 2022
1 parent 0b2eb76 commit fd99ec5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
5 changes: 5 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def __call__(self, x):
return jnp.abs(x)

def _inverse(self, y):
warnings.warn(
"AbsTransform is not a bijective transform."
" The inverse of `y` will be `y`.",
stacklevel=find_stack_level(),
)
return y


Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class NeuTraReparam(Reparam):
"""
Neural Transport reparameterizer [1] of multiple latent variables.
This uses a trained :class:`~pyro.contrib.autoguide.AutoContinuous`
This uses a trained :class:`~numpyro.infer.autoguide.AutoContinuous`
guide to alter the geometry of a model, typically for use e.g. in MCMC.
Example usage::
Expand Down
4 changes: 2 additions & 2 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,9 @@ def format_shapes(
def model(*args, **kwargs):
...
with numpyro.handlers.seed(rng_key=1):
with numpyro.handlers.seed(rng_seed=1):
trace = numpyro.handlers.trace(model).get_trace(*args, **kwargs)
numpyro.util.format_shapes(trace)
print(numpyro.util.format_shapes(trace))
"""
if not trace.keys():
return title
Expand Down

0 comments on commit fd99ec5

Please sign in to comment.