Skip to content

Commit c126ee3

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Don't revisit shared subjaxprs in jaxpr_util.pprof_equation_profile.
It is probably a more useful default behavior not to implicitly inline everything. PiperOrigin-RevId: 769452443
1 parent 28c31b8 commit c126ee3

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

jax/_src/jaxpr_util.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,23 @@
3333
zip, unsafe_zip = util.safe_zip, zip
3434

3535

36-
def all_eqns(jaxpr: core.Jaxpr) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]:
36+
def _all_eqns(
37+
jaxpr: core.Jaxpr, visited: set[core.Jaxpr] | None,
38+
) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]:
3739
for eqn in jaxpr.eqns:
3840
yield (jaxpr, eqn)
3941
for subjaxpr in core.subjaxprs(jaxpr):
40-
yield from all_eqns(subjaxpr)
42+
if visited is None:
43+
yield from _all_eqns(subjaxpr, visited)
44+
elif subjaxpr not in visited:
45+
visited.add(subjaxpr)
46+
yield from _all_eqns(subjaxpr, visited)
47+
48+
def all_eqns(
49+
jaxpr: core.Jaxpr, revisit_inner_jaxprs: bool = True
50+
) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]:
51+
yield from _all_eqns(jaxpr, None if revisit_inner_jaxprs else set())
52+
4153

4254
def collect_eqns(jaxpr: core.Jaxpr, key: Callable):
4355
d = defaultdict(list)
@@ -206,7 +218,7 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes:
206218
"""
207219
d = Counter(
208220
(eqn.source_info.traceback, eqn.primitive)
209-
for _, eqn in all_eqns(jaxpr)
221+
for _, eqn in all_eqns(jaxpr, revisit_inner_jaxprs=False)
210222
)
211223
return _pprof_profile(d)
212224

0 commit comments

Comments
 (0)