File tree Expand file tree Collapse file tree 1 file changed +15
-3
lines changed Expand file tree Collapse file tree 1 file changed +15
-3
lines changed Original file line number Diff line number Diff line change 33
33
zip , unsafe_zip = util .safe_zip , zip
34
34
35
35
36
- def all_eqns (jaxpr : core .Jaxpr ) -> Iterator [tuple [core .Jaxpr , core .JaxprEqn ]]:
36
+ def _all_eqns (
37
+ jaxpr : core .Jaxpr , visited : set [core .Jaxpr ] | None ,
38
+ ) -> Iterator [tuple [core .Jaxpr , core .JaxprEqn ]]:
37
39
for eqn in jaxpr .eqns :
38
40
yield (jaxpr , eqn )
39
41
for subjaxpr in core .subjaxprs (jaxpr ):
40
- yield from all_eqns (subjaxpr )
42
+ if visited is None :
43
+ yield from _all_eqns (subjaxpr , visited )
44
+ elif subjaxpr not in visited :
45
+ visited .add (subjaxpr )
46
+ yield from _all_eqns (subjaxpr , visited )
47
+
48
+ def all_eqns (
49
+ jaxpr : core .Jaxpr , revisit_inner_jaxprs : bool = True
50
+ ) -> Iterator [tuple [core .Jaxpr , core .JaxprEqn ]]:
51
+ yield from _all_eqns (jaxpr , None if revisit_inner_jaxprs else set ())
52
+
41
53
42
54
def collect_eqns (jaxpr : core .Jaxpr , key : Callable ):
43
55
d = defaultdict (list )
@@ -206,7 +218,7 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes:
206
218
"""
207
219
d = Counter (
208
220
(eqn .source_info .traceback , eqn .primitive )
209
- for _ , eqn in all_eqns (jaxpr )
221
+ for _ , eqn in all_eqns (jaxpr , revisit_inner_jaxprs = False )
210
222
)
211
223
return _pprof_profile (d )
212
224
You can’t perform that action at this time.
0 commit comments