1
- from typing import List , Optional
1
+ from typing import TYPE_CHECKING , List , Optional
2
2
3
3
import aesara .tensor as at
4
4
import numpy as np
5
5
from aesara .graph .basic import Node
6
6
from aesara .graph .fg import FunctionGraph
7
7
from aesara .graph .rewriting .basic import node_rewriter
8
- from aesara .scalar .basic import Ceil , Clip , Floor , RoundHalfToEven
8
+ from aesara .scalar .basic import ceil as scalar_ceil
9
9
from aesara .scalar .basic import clip as scalar_clip
10
- from aesara .tensor .elemwise import Elemwise
10
+ from aesara .scalar .basic import floor as scalar_floor
11
+ from aesara .scalar .basic import round_half_to_even as scalar_round_half_to_even
12
+ from aesara .tensor .math import ceil , clip , floor , round_half_to_even
11
13
from aesara .tensor .var import TensorConstant
12
14
13
15
from aeppl .abstract import (
18
20
from aeppl .logprob import CheckParameterValue , _logcdf , _logprob , logdiffexp
19
21
from aeppl .rewriting import measurable_ir_rewrites_db
20
22
23
+ if TYPE_CHECKING :
24
+ from aesara .graph .basic import Op , Variable
25
+
21
26
22
27
class MeasurableClip (MeasurableElemwise ):
23
28
"""A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""
24
29
25
- valid_scalar_types = (Clip ,)
26
-
27
30
28
31
measurable_clip = MeasurableClip (scalar_clip )
29
32
30
33
31
- @node_rewriter (tracks = [ Elemwise ])
34
+ @node_rewriter ([ clip ])
32
35
def find_measurable_clips (
33
36
fgraph : FunctionGraph , node : Node
34
- ) -> Optional [List [MeasurableClip ]]:
37
+ ) -> Optional [List ["Variable" ]]:
35
38
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
36
39
37
40
rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
38
41
if rv_map_feature is None :
39
42
return None # pragma: no cover
40
43
41
- if isinstance (node .op , MeasurableClip ):
42
- return None # pragma: no cover
43
-
44
- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , Clip )):
45
- return None
46
-
47
44
clipped_var = node .outputs [0 ]
48
45
base_var , lower_bound , upper_bound = node .inputs
49
46
@@ -75,7 +72,6 @@ def find_measurable_clips(
75
72
measurable_ir_rewrites_db .register (
76
73
"find_measurable_clips" ,
77
74
find_measurable_clips ,
78
- 0 ,
79
75
"basic" ,
80
76
"censoring" ,
81
77
)
@@ -147,27 +143,55 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
147
143
class MeasurableRound (MeasurableElemwise ):
148
144
"""A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""
149
145
150
- valid_scalar_types = (RoundHalfToEven , Floor , Ceil )
151
146
147
+ measurable_ceil = MeasurableRound (scalar_ceil )
148
+ measurable_floor = MeasurableRound (scalar_floor )
149
+ measurable_round_half_to_even = MeasurableRound (scalar_round_half_to_even )
152
150
153
- @node_rewriter (tracks = [Elemwise ])
154
- def find_measurable_roundings (
155
- fgraph : FunctionGraph , node : Node
156
- ) -> Optional [List [MeasurableRound ]]:
151
+
152
+ @node_rewriter ([ceil ])
153
+ def find_measurable_ceil (fgraph : FunctionGraph , node : Node ):
154
+ return construct_measurable_rounding (fgraph , node , measurable_ceil )
155
+
156
+
157
+ @node_rewriter ([floor ])
158
+ def find_measurable_floor (fgraph : FunctionGraph , node : Node ):
159
+ return construct_measurable_rounding (fgraph , node , measurable_floor )
160
+
161
+
162
+ @node_rewriter ([round_half_to_even ])
163
+ def find_measurable_round_half_to_even (fgraph : FunctionGraph , node : Node ):
164
+ return construct_measurable_rounding (fgraph , node , measurable_round_half_to_even )
165
+
166
+
167
+ measurable_ir_rewrites_db .register (
168
+ "find_measurable_ceil" ,
169
+ find_measurable_ceil ,
170
+ "basic" ,
171
+ "censoring" ,
172
+ )
173
+ measurable_ir_rewrites_db .register (
174
+ "find_measurable_floor" ,
175
+ find_measurable_floor ,
176
+ "basic" ,
177
+ "censoring" ,
178
+ )
179
+ measurable_ir_rewrites_db .register (
180
+ "find_measurable_round_half_to_even" ,
181
+ find_measurable_round_half_to_even ,
182
+ "basic" ,
183
+ "censoring" ,
184
+ )
185
+
186
+
187
+ def construct_measurable_rounding (
188
+ fgraph : FunctionGraph , node : Node , rounded_op : "Op"
189
+ ) -> Optional [List ["Variable" ]]:
157
190
158
191
rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
159
192
if rv_map_feature is None :
160
193
return None # pragma: no cover
161
194
162
- if isinstance (node .op , MeasurableRound ):
163
- return None # pragma: no cover
164
-
165
- if not (
166
- isinstance (node .op , Elemwise )
167
- and isinstance (node .op .scalar_op , MeasurableRound .valid_scalar_types )
168
- ):
169
- return None
170
-
171
195
(rounded_var ,) = node .outputs
172
196
(base_var ,) = node .inputs
173
197
@@ -183,21 +207,11 @@ def find_measurable_roundings(
183
207
# Make base_var unmeasurable
184
208
unmeasurable_base_var = assign_custom_measurable_outputs (base_var .owner )
185
209
186
- rounded_op = MeasurableRound (node .op .scalar_op )
187
210
rounded_rv = rounded_op .make_node (unmeasurable_base_var ).default_output ()
188
211
rounded_rv .name = rounded_var .name
189
212
return [rounded_rv ]
190
213
191
214
192
- measurable_ir_rewrites_db .register (
193
- "find_measurable_roundings" ,
194
- find_measurable_roundings ,
195
- 0 ,
196
- "basic" ,
197
- "censoring" ,
198
- )
199
-
200
-
201
215
@_logprob .register (MeasurableRound )
202
216
def round_logprob (op , values , base_rv , ** kwargs ):
203
217
r"""Logprob of a rounded censored distribution
@@ -226,15 +240,15 @@ def round_logprob(op, values, base_rv, **kwargs):
226
240
"""
227
241
(value ,) = values
228
242
229
- if isinstance ( op . scalar_op , RoundHalfToEven ) :
243
+ if op == measurable_round_half_to_even :
230
244
value = at .round (value )
231
245
value_upper = value + 0.5
232
246
value_lower = value - 0.5
233
- elif isinstance ( op . scalar_op , Floor ) :
247
+ elif op == measurable_floor :
234
248
value = at .floor (value )
235
249
value_upper = value + 1.0
236
250
value_lower = value
237
- elif isinstance ( op . scalar_op , Ceil ) :
251
+ elif op == measurable_ceil :
238
252
value = at .ceil (value )
239
253
value_upper = value
240
254
value_lower = value - 1.0
0 commit comments