From ed2161d86d2c66353cd1a618972fba72aa3c3547 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 4 Dec 2022 23:01:49 -0600 Subject: [PATCH] Add support for power transforms --- aeppl/transforms.py | 42 ++++++++++++++++++++++++++++++++-------- tests/test_transforms.py | 19 ++++++++++++++++++ 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/aeppl/transforms.py b/aeppl/transforms.py index 6abe62b6..dde42529 100644 --- a/aeppl/transforms.py +++ b/aeppl/transforms.py @@ -10,7 +10,7 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter -from aesara.tensor.math import add, exp, log, mul, reciprocal, sub, true_div +from aesara.tensor.math import add, exp, log, mul, pow, reciprocal, sub, true_div from aesara.tensor.rewriting.basic import ( register_specialize, register_stabilize, @@ -422,8 +422,20 @@ def transform(measurable_input, *other_inputs): def measurable_reciprocal(fgraph, node): """Rewrite a `reciprocal` node to a `MeasurableVariable`.""" - def transform(measurable_input, *other_inputs): - return ReciprocalTransform(), (measurable_input,) + new_node = at.power(node.inputs[0], at.as_tensor(-1)).owner + return measurable_pow.transform(fgraph, new_node) + + +@register_measurable_ir +@node_rewriter([pow]) +def measurable_pow(fgraph, node): + """Rewrite a `pow` node to a `MeasurableVariable`.""" + + def transform(measurable_input, *args): + return PowerTransform(transform_args_fn=lambda *inputs: inputs[-1]), ( + measurable_input, + *args, + ) return construct_elemwise_transform(fgraph, node, transform) @@ -579,17 +591,31 @@ def log_jac_det(self, value, *inputs): return -at.log(value) -class ReciprocalTransform(RVTransform): - name = "reciprocal" +class PowerTransform(RVTransform): + name = "power" + + def __init__(self, transform_args_fn): + self.transform_args_fn = transform_args_fn def forward(self, value, *inputs): - return at.reciprocal(value) + power = self.transform_args_fn(*inputs) + return at.power(value, power) def backward(self, value, *inputs): - return at.reciprocal(value) + power = self.transform_args_fn(*inputs) + + inv_power = at.reciprocal(power) + return at.switch( + at.eq(at.mod(power, 2), 0), + at.power(value, inv_power), + at.sgn(value) * at.power(at.abs(value), inv_power), + ) def log_jac_det(self, value, *inputs): - return -2 * at.log(value) + from aeppl.logprob import xlogy0 + + power = self.transform_args_fn(*inputs) + return at.log(at.abs(power)) + xlogy0((power - 1), at.abs(value)) class IntervalTransform(RVTransform): diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6670404f..968b6a11 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -763,3 +763,22 @@ def test_transform_measurable_sub(): with pytest.raises(RuntimeError, match="The logprob terms"): joint_logprob(Z_rv, X_rv) + + +@pytest.mark.parametrize( + "pow_fn, exp_val_fn", + [ + (lambda x: x**2, lambda z: sp.stats.chi2(df=1).logpdf(z)) + # TODO: Add more cases. + ], +) +def test_transform_measurable_pow(pow_fn, exp_val_fn): + X_rv = at.random.normal(0, 1, name="X") + Z_rv = pow_fn(X_rv) + Z_rv.name = "Z" + + z_logp, (z_vv,) = conditional_logprob(Z_rv) + z_logp_fn = aesara.function([z_vv], z_logp[Z_rv]) + + z_val = 0.5 + assert np.allclose(z_logp_fn(z_val), exp_val_fn(z_val))