Skip to content

Commit 473c1e6

Browse files
Refactor MeasurableElemwise and automatic random variable transforms
These changes make it possible for users to extend coverage of measurable `Elemwise` transforms. They also put these rewrites under the management of Aesara's rewrite system, which means they can be applied more efficiently and their use can customized and tracked with more granularity. `sub`, `neg`, and `true_div` support is added, as well.
1 parent f03a820 commit 473c1e6

12 files changed

+371
-164
lines changed

aeppl/abstract.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import abc
22
from copy import copy
33
from functools import singledispatch
4-
from typing import Callable, List, Tuple
4+
from typing import Callable, List
55

66
from aesara.graph.basic import Apply, Variable
77
from aesara.graph.op import Op
@@ -122,15 +122,5 @@ def assign_custom_measurable_outputs(
122122
class MeasurableElemwise(Elemwise):
123123
"""Base class for Measurable Elemwise variables"""
124124

125-
valid_scalar_types: Tuple[MetaType, ...] = ()
126-
127-
def __init__(self, scalar_op, *args, **kwargs):
128-
if not isinstance(scalar_op, self.valid_scalar_types):
129-
raise TypeError(
130-
f"scalar_op {scalar_op} is not valid for class {self.__class__}. "
131-
f"Acceptable types are {self.valid_scalar_types}"
132-
)
133-
super().__init__(scalar_op, *args, **kwargs)
134-
135125

136126
MeasurableVariable.register(MeasurableElemwise)

aeppl/censoring.py

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
from typing import List, Optional
1+
from typing import TYPE_CHECKING, List, Optional
22

33
import aesara.tensor as at
44
import numpy as np
55
from aesara.graph.basic import Node
66
from aesara.graph.fg import FunctionGraph
77
from aesara.graph.rewriting.basic import node_rewriter
8-
from aesara.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
8+
from aesara.scalar.basic import ceil as scalar_ceil
99
from aesara.scalar.basic import clip as scalar_clip
10-
from aesara.tensor.elemwise import Elemwise
10+
from aesara.scalar.basic import floor as scalar_floor
11+
from aesara.scalar.basic import round_half_to_even as scalar_round_half_to_even
12+
from aesara.tensor.math import ceil, clip, floor, round_half_to_even
1113
from aesara.tensor.var import TensorConstant
1214

1315
from aeppl.abstract import (
@@ -18,32 +20,27 @@
1820
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp
1921
from aeppl.rewriting import measurable_ir_rewrites_db
2022

23+
if TYPE_CHECKING:
24+
from aesara.graph.basic import Op, Variable
25+
2126

2227
class MeasurableClip(MeasurableElemwise):
2328
"""A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""
2429

25-
valid_scalar_types = (Clip,)
26-
2730

2831
measurable_clip = MeasurableClip(scalar_clip)
2932

3033

31-
@node_rewriter(tracks=[Elemwise])
34+
@node_rewriter([clip])
3235
def find_measurable_clips(
3336
fgraph: FunctionGraph, node: Node
34-
) -> Optional[List[MeasurableClip]]:
37+
) -> Optional[List["Variable"]]:
3538
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
3639

3740
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
3841
if rv_map_feature is None:
3942
return None # pragma: no cover
4043

41-
if isinstance(node.op, MeasurableClip):
42-
return None # pragma: no cover
43-
44-
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
45-
return None
46-
4744
clipped_var = node.outputs[0]
4845
base_var, lower_bound, upper_bound = node.inputs
4946

@@ -75,7 +72,6 @@ def find_measurable_clips(
7572
measurable_ir_rewrites_db.register(
7673
"find_measurable_clips",
7774
find_measurable_clips,
78-
0,
7975
"basic",
8076
"censoring",
8177
)
@@ -147,27 +143,55 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
147143
class MeasurableRound(MeasurableElemwise):
148144
"""A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""
149145

150-
valid_scalar_types = (RoundHalfToEven, Floor, Ceil)
151146

147+
measurable_ceil = MeasurableRound(scalar_ceil)
148+
measurable_floor = MeasurableRound(scalar_floor)
149+
measurable_round_half_to_even = MeasurableRound(scalar_round_half_to_even)
152150

153-
@node_rewriter(tracks=[Elemwise])
154-
def find_measurable_roundings(
155-
fgraph: FunctionGraph, node: Node
156-
) -> Optional[List[MeasurableRound]]:
151+
152+
@node_rewriter([ceil])
153+
def find_measurable_ceil(fgraph: FunctionGraph, node: Node):
154+
return construct_measurable_rounding(fgraph, node, measurable_ceil)
155+
156+
157+
@node_rewriter([floor])
158+
def find_measurable_floor(fgraph: FunctionGraph, node: Node):
159+
return construct_measurable_rounding(fgraph, node, measurable_floor)
160+
161+
162+
@node_rewriter([round_half_to_even])
163+
def find_measurable_round_half_to_even(fgraph: FunctionGraph, node: Node):
164+
return construct_measurable_rounding(fgraph, node, measurable_round_half_to_even)
165+
166+
167+
measurable_ir_rewrites_db.register(
168+
"find_measurable_ceil",
169+
find_measurable_ceil,
170+
"basic",
171+
"censoring",
172+
)
173+
measurable_ir_rewrites_db.register(
174+
"find_measurable_floor",
175+
find_measurable_floor,
176+
"basic",
177+
"censoring",
178+
)
179+
measurable_ir_rewrites_db.register(
180+
"find_measurable_round_half_to_even",
181+
find_measurable_round_half_to_even,
182+
"basic",
183+
"censoring",
184+
)
185+
186+
187+
def construct_measurable_rounding(
188+
fgraph: FunctionGraph, node: Node, rounded_op: "Op"
189+
) -> Optional[List["Variable"]]:
157190

158191
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
159192
if rv_map_feature is None:
160193
return None # pragma: no cover
161194

162-
if isinstance(node.op, MeasurableRound):
163-
return None # pragma: no cover
164-
165-
if not (
166-
isinstance(node.op, Elemwise)
167-
and isinstance(node.op.scalar_op, MeasurableRound.valid_scalar_types)
168-
):
169-
return None
170-
171195
(rounded_var,) = node.outputs
172196
(base_var,) = node.inputs
173197

@@ -183,21 +207,11 @@ def find_measurable_roundings(
183207
# Make base_var unmeasurable
184208
unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
185209

186-
rounded_op = MeasurableRound(node.op.scalar_op)
187210
rounded_rv = rounded_op.make_node(unmeasurable_base_var).default_output()
188211
rounded_rv.name = rounded_var.name
189212
return [rounded_rv]
190213

191214

192-
measurable_ir_rewrites_db.register(
193-
"find_measurable_roundings",
194-
find_measurable_roundings,
195-
0,
196-
"basic",
197-
"censoring",
198-
)
199-
200-
201215
@_logprob.register(MeasurableRound)
202216
def round_logprob(op, values, base_rv, **kwargs):
203217
r"""Logprob of a rounded censored distribution
@@ -226,15 +240,15 @@ def round_logprob(op, values, base_rv, **kwargs):
226240
"""
227241
(value,) = values
228242

229-
if isinstance(op.scalar_op, RoundHalfToEven):
243+
if op == measurable_round_half_to_even:
230244
value = at.round(value)
231245
value_upper = value + 0.5
232246
value_lower = value - 0.5
233-
elif isinstance(op.scalar_op, Floor):
247+
elif op == measurable_floor:
234248
value = at.floor(value)
235249
value_upper = value + 1.0
236250
value_lower = value
237-
elif isinstance(op.scalar_op, Ceil):
251+
elif op == measurable_ceil:
238252
value = at.ceil(value)
239253
value_upper = value
240254
value_lower = value - 1.0

aeppl/cumsum.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
8383
measurable_ir_rewrites_db.register(
8484
"find_measurable_cumsums",
8585
find_measurable_cumsums,
86-
0,
8786
"basic",
8887
"cumsum",
8988
)

aeppl/mixture.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,6 @@ def logprob_MixtureRV(
423423
[mixture_replace, switch_mixture_replace],
424424
max_use_ratio=aesara.config.optdb__max_use_ratio,
425425
),
426-
0,
427426
"basic",
428427
"mixture",
429428
)

aeppl/rewriting.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,7 @@ def incsubtensor_rv_replace(fgraph, node):
218218

219219
logprob_rewrites_db = SequenceDB()
220220
logprob_rewrites_db.name = "logprob_rewrites_db"
221-
logprob_rewrites_db.register(
222-
"pre-canonicalize", optdb.query("+canonicalize"), -10, "basic"
223-
)
221+
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
224222

225223
# These rewrites convert un-measurable variables into their measurable forms,
226224
# but they need to be reapplied, because some of the measurable forms require
@@ -229,22 +227,18 @@ def incsubtensor_rv_replace(fgraph, node):
229227
measurable_ir_rewrites_db.name = "measurable_ir_rewrites_db"
230228

231229
logprob_rewrites_db.register(
232-
"measurable_ir_rewrites", measurable_ir_rewrites_db, -10, "basic"
230+
"measurable_ir_rewrites", measurable_ir_rewrites_db, "basic"
233231
)
234232

235233
# These rewrites push random/measurable variables "down", making them closer to
236234
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
237235
# "up" through the random/measurable variables and into their inputs.
236+
measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic")
238237
measurable_ir_rewrites_db.register(
239-
"subtensor_lift", local_subtensor_rv_lift, -5, "basic"
240-
)
241-
measurable_ir_rewrites_db.register(
242-
"incsubtensor_lift", incsubtensor_rv_replace, -5, "basic"
238+
"incsubtensor_lift", incsubtensor_rv_replace, "basic"
243239
)
244240

245-
logprob_rewrites_db.register(
246-
"post-canonicalize", optdb.query("+canonicalize"), 10, "basic"
247-
)
241+
logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")
248242

249243

250244
def construct_ir_fgraph(

aeppl/scan.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -513,19 +513,17 @@ def _get_measurable_outputs_MeasurableScan(op, node):
513513
# out2in(
514514
# add_opts_to_inner_graphs, name="add_opts_to_inner_graphs", ignore_newtrees=True
515515
# ),
516-
-100,
517516
"basic",
518517
"scan",
519518
)
520519

521520
measurable_ir_rewrites_db.register(
522521
"find_measurable_scans",
523522
find_measurable_scans,
524-
0,
525523
"basic",
526524
"scan",
527525
)
528526

529527
# Add scan canonicalizations that aren't in the canonicalization DB
530-
logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, -9, "basic", "scan")
531-
logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, -9, "basic", "scan")
528+
logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, "basic", "scan")
529+
logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, "basic", "scan")

aeppl/tensor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,25 +273,24 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf
273273

274274

275275
measurable_ir_rewrites_db.register(
276-
"dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic", "tensor"
276+
"dimshuffle_lift", local_dimshuffle_rv_lift, "basic", "tensor"
277277
)
278278

279279

280280
# We register this later than `dimshuffle_lift` so that it is only applied as a fallback
281281
measurable_ir_rewrites_db.register(
282-
"find_measurable_dimshuffles", find_measurable_dimshuffles, 0, "basic", "tensor"
282+
"find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor"
283283
)
284284

285285

286286
measurable_ir_rewrites_db.register(
287-
"broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor"
287+
"broadcast_to_lift", naive_bcast_rv_lift, "basic", "tensor"
288288
)
289289

290290

291291
measurable_ir_rewrites_db.register(
292292
"find_measurable_stacks",
293293
find_measurable_stacks,
294-
0,
295294
"basic",
296295
"tensor",
297296
)

0 commit comments

Comments
 (0)