[ca] add compiled autograd to CompileId (pytorch#141907)
tlparse PR: pytorch/tlparse#83

Pull Request resolved: pytorch#141907
Approved by: https://github.com/ezyang
xmfan authored and pytorchmergebot committed Dec 21, 2024
1 parent 0ce233b commit 4ee166b
Showing 8 changed files with 226 additions and 29 deletions.
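For reference, the new CompileId has three string forms, which the tests added below exercise. A minimal round-trip sketch (illustrative only; it assumes a torch build that already contains this patch):

# Sketch of the CompileId string forms introduced by this PR
# (assumes a torch build that includes this change).
from torch._guards import CompileId

# Ordinary dynamo compile: "frame_id/frame_compile_id"
print(CompileId(frame_id=0, frame_compile_id=1))  # 0/1

# Dynamo compile nested under compiled autograd: "!ca_id/frame_id/frame_compile_id"
print(CompileId(compiled_autograd_id=2, frame_id=0, frame_compile_id=1))  # !2/0/1

# The compiled autograd graph capture itself: "!ca_id"
print(CompileId(compiled_autograd_id=2, frame_id=None, frame_compile_id=None))  # !2

# Round-trip through the parser added in torch/_guards.py below
cid = CompileId.from_string("!2/0/1")
assert (cid.compiled_autograd_id, cid.frame_id, cid.frame_compile_id) == (2, 0, 1)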
4 changes: 2 additions & 2 deletions test/dynamo/test_frame_init.py
@@ -93,15 +93,15 @@ def callback1(frame, cache_entry, frame_state):
if frame.f_code in code_map1:
transformed_code = code_map1[frame.f_code]
return torch._dynamo.types.GuardedCode(
transformed_code, empty_guard_manager, CompileId(0, 0)
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
)
return None

def callback2(frame, cache_entry, frame_state):
if frame.f_code in code_map2:
transformed_code = code_map2[frame.f_code]
return torch._dynamo.types.GuardedCode(
transformed_code, empty_guard_manager, CompileId(0, 0)
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
)
return None

115 changes: 114 additions & 1 deletion test/dynamo/test_structured_trace.py
@@ -208,6 +208,31 @@ def assertParses(self):
finally:
shutil.rmtree(out, ignore_errors=True)

def test_compile_id_serialization_deserialization(self):
cid = torch._guards.CompileId(
frame_id=1,
frame_compile_id=2,
)
assert cid == torch._guards.CompileId.from_string(str(cid))

cid = torch._guards.CompileId(
compiled_autograd_id=1,
frame_id=2,
frame_compile_id=3,
)
assert cid == torch._guards.CompileId.from_string(str(cid))

cid = torch._guards.CompileId(
compiled_autograd_id=1,
frame_id=None,
frame_compile_id=None,
)
assert cid == torch._guards.CompileId.from_string(str(cid))

for bad_cid in ["-/-", "-/1", "1/-", "!1/2", "!1/-/-"]:
with self.assertRaises(ValueError):
torch._guards.CompileId.from_string(bad_cid)

@requires_cuda
def test_schedule(self):
fn_opt = torch.compile(inductor_schedule_fn, backend="inductor")
@@ -873,9 +898,97 @@ def fn(a):
fn_opt(x)
# Should print twice, including inductor_output_code
self.assertParses()
chromium_event = '{"chromium_event": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}'
chromium_event = (
'{"chromium_event": {}, "frame_id": 0, "frame_compile_id": 0, '
'"attempt": 0, "has_payload": "HASH"}'
)
self.assertTrue(chromium_event in self.buffer.getvalue())

@requires_tlparse
@torch._dynamo.config.patch("compiled_autograd", True)
@torch._inductor.config.patch("fx_graph_cache", True)
@show_chrome_events
def test_compiled_autograd_id(self):
def fn(a):
return a.sin().sum().backward()

x = torch.tensor([1.0], requires_grad=True)
fn_opt = torch._dynamo.optimize("inductor")(fn)
fn_opt(x)
torch._dynamo.reset()
# Trigger a cache hit
fn_opt(x)
# Should print twice, including inductor_output_code
self.assertParses()
chromium_events = [
(
'{"chromium_event": {}, "frame_id": 0, "frame_compile_id": 0, '
'"attempt": 0, "has_payload": "HASH"}'
),
(
'{"compiled_autograd_graph": {}, "compiled_autograd_id": 0, '
'"attempt": 0, "has_payload": "HASH"}'
),
(
'{"chromium_event": {}, "compiled_autograd_id": 0, "frame_id": 1, "frame_compile_id": 0, '
'"attempt": 0, "has_payload": "HASH"}'
),
]
logs = self.buffer.getvalue()
self.assertTrue(all(event in logs for event in chromium_events))

@requires_tlparse
@torch._dynamo.config.patch("compiled_autograd", True)
def test_compiled_autograd_attribution(self):
# multiple dynamo recompiles should still be attributed to the parent compiled autograd id
def fn():
class MySin(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.sin(x)

@staticmethod
def backward(ctx, gO):
print("graph break")
(x,) = ctx.saved_tensors
print("graph break")
return gO * torch.cos(x)

grads = []
for i in [10, 100, 10, 15, 20, 25]:
x = torch.arange(0.0, i, requires_grad=True)
out = MySin.apply(x)
loss = out.sum()
loss.backward()
grads.append(x.grad)

return grads

fn_opt = torch.compile(fn)
fn_opt()
self.assertParses()
expected = [
'{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 6, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 7, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 8, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 1, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 9, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 5, "frame_compile_id": 1, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 6, "frame_compile_id": 1, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 10, "frame_compile_id": 0, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 9, "frame_compile_id": 1, "attempt": 0}',
'{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 10, "frame_compile_id": 1, "attempt": 0}',
]
logs = self.buffer.getvalue()
self.assertTrue(all(event in logs for event in expected))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
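The expected entries above show the structured-trace JSON format that now carries compiled_autograd_id alongside frame_id and frame_compile_id. As a purely hypothetical post-processing sketch (not tlparse's actual implementation), grouping such line-delimited JSON entries by their compiled autograd id could look like:

import json
from collections import defaultdict

def group_by_compiled_autograd(log_text):
    # One JSON object per line, as in the expected strings above; entries
    # without a compiled_autograd_id are plain dynamo compiles and land under None.
    groups = defaultdict(list)
    for line in log_text.splitlines():
        line = line.strip()
        if line:
            entry = json.loads(line)
            groups[entry.get("compiled_autograd_id")].append(entry)
    return groups

sample = "\n".join([
    '{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}',
    '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 0}',
])
print({k: len(v) for k, v in group_by_compiled_autograd(sample).items()})  # {None: 1, 0: 1}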
1 change: 1 addition & 0 deletions torch/_dynamo/cache_size.py
@@ -180,6 +180,7 @@ def exceeds_cache_size_limit(
# e.g. due to guarded objects being freed. This technically makes the
# will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the
# check in case we have a better fix in the future.
assert compile_id.frame_compile_id is not None
if compile_id.frame_compile_id >= config.accumulated_cache_size_limit:
return True, "accumulated_cache_size_limit"
return False, ""
18 changes: 17 additions & 1 deletion torch/_dynamo/compiled_autograd.py
@@ -13,6 +13,7 @@
)
from torch._dynamo.source import GetItemSource, LocalSource
from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal
from torch._guards import compile_context, CompileContext, CompileId
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import clone_preserve_strides
from torch._subclasses import FakeTensorMode
@@ -70,6 +71,18 @@ def maybe_clone(x):
COMPILE_COUNTER = itertools.count()


def make_compile_context(compiled_autograd_id):
return compile_context(
CompileContext(
CompileId(
compiled_autograd_id=compiled_autograd_id,
frame_id=None,
frame_compile_id=None,
)
)
)


class AutogradCompilerInstance:
def __init__(self, compiler_fn) -> None:
self.compiler_fn = compiler_fn
@@ -102,6 +115,8 @@ def begin_capture(
):
counters["compiled_autograd"]["captures"] += 1
self.id = next(COMPILE_COUNTER)
self.compile_context = make_compile_context(self.id)
self.compile_context.__enter__()
self.aot_graph_cls_name: Optional[str] = None
self.aot_graph_infos: Dict[int, Dict[str, Any]] = {}
self.fx_tracer.root = torch.nn.Module()
@@ -401,11 +416,12 @@ def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
for i in runtime_inputs_to_move:
inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

with _disable():
with _disable(), make_compile_context(self.id):
return compiled_fn(inputs, sizes, scalars, hooks)
finally:
in_compiled_autograd_region = False

self.compile_context.__exit__(None, None, None)
return runtime_wrapper, self.compiler_fn(graph)

def rename_aot_dispatcher_nodes(self):
11 changes: 9 additions & 2 deletions torch/_dynamo/convert_frame.py
@@ -533,7 +533,14 @@ def __call__(
frame_compile_id = FRAME_COMPILE_COUNTER[frame_id]
FRAME_COMPILE_COUNTER[frame_id] += 1

compile_id = CompileId(frame_id, frame_compile_id)
compiled_autograd_id = None
if prior := CompileContext.current_compile_id():
compiled_autograd_id = prior.compiled_autograd_id
compile_id = CompileId(
compiled_autograd_id=compiled_autograd_id,
frame_id=frame_id,
frame_compile_id=frame_compile_id,
)

signpost_event(
"dynamo",
@@ -1292,7 +1299,7 @@ def replay(filename: str) -> None:
cache_entry=None,
frame=None,
frame_state={},
compile_id=CompileId(42, 999),
compile_id=CompileId(frame_id=42, frame_compile_id=999),
)
finally:
config.replay_record_enabled = original_replay_val
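Together with the make_compile_context helper added in compiled_autograd.py above, the CompileContext.current_compile_id() lookup here is what lets a nested dynamo compile inherit its parent's compiled autograd id. A minimal sketch of that flow (illustrative only; the id values are made up, and it assumes a torch build that includes this patch):

from torch._guards import compile_context, CompileContext, CompileId

# Enter a compiled-autograd-only compile context, as make_compile_context does.
with compile_context(
    CompileContext(
        CompileId(compiled_autograd_id=7, frame_id=None, frame_compile_id=None)
    )
):
    # Inside the region, a recursive dynamo compile can query the active id and
    # carry its compiled_autograd_id over, as convert_frame.py does above.
    prior = CompileContext.current_compile_id()
    assert prior is not None and prior.compiled_autograd_id == 7
    nested = CompileId(
        compiled_autograd_id=prior.compiled_autograd_id,
        frame_id=2,
        frame_compile_id=0,
    )
    print(nested)  # !7/2/0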
6 changes: 5 additions & 1 deletion torch/_dynamo/testing.py
@@ -207,7 +207,11 @@ def insert_nops(instructions: List[Any], code_options: Any) -> None:
torch_function_mode_stack=[],
)

return GuardedCode(code, CheckFunctionManager(frame.f_code, graph).guard_manager, CompileId(0, 0)) # type: ignore[arg-type]
return GuardedCode(
code,
CheckFunctionManager(frame.f_code, graph).guard_manager, # type: ignore[arg-type]
CompileId(frame_id=0, frame_compile_id=0),
)


class CompileCounter:
58 changes: 52 additions & 6 deletions torch/_guards.py
@@ -6,12 +6,14 @@
import enum
import functools
import logging
import re
import threading
import traceback
import unittest.mock
import weakref
from abc import abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
Any,
Callable,
@@ -48,29 +50,72 @@
and no guard installation notions here.
"""

COMPILE_ID_PATTERN = re.compile(r"^(?P<frame_id>\d+)/(?P<frame_compile_id>\d+)$")
CA_COMPILE_ID_PATTERN = re.compile(
r"^!(?P<compiled_autograd_id>\d+)(?:/(?P<frame_id>\d+)/(?P<frame_compile_id>\d+))?$"
)

class CompileId(NamedTuple):
frame_id: int
# [Note: Updating CompileId]
#
# CompileId represents a unique program-level identifier, and we want to keep that
# property as the codebase evolves. This property is relied on even outside of the pytorch
# repo, e.g. tlparse or other internal tooling. The in-memory format can be freely changed,
# as those dependencies only consume the string serialization.
#
# The string form should be:
# 1. Program-level uid: CompileId can uniquely identify a compiled graph.
# 2. Storage efficient: This object is logged in nearly every entry. We should elide symbols when possible.
# 3. Compact: The string form is directly displayed by some tools. Special symbols are okay.


# TODO: mark as kw_only=True once we drop support for Python < 3.10
@dataclass(frozen=True)
class CompileId:
frame_id: Optional[int]
# This id is per-frame, and counts how many times we've compiled this
# frame. This could have been a global id but having this be per-frame
# gives you a better intuitive sense for how many recompiles have occurred
# so far.
frame_compile_id: int
frame_compile_id: Optional[int]

# torch.compiling a compiled autograd graph
compiled_autograd_id: Optional[int] = None

# TODO: consider also tracking the recompilation count
# See Note: Updating CompileId

def __str__(self):
return f"{self.frame_id}/{self.frame_compile_id}"
# NOTE: Keep this in sync with both from_string and the tlparse repo
if self.compiled_autograd_id is not None:
assert (self.frame_id is None) == (self.frame_compile_id is None)
frame_str = ""
if self.frame_id is not None:
frame_str = f"/{self.frame_id}/{self.frame_compile_id}"

return f"!{self.compiled_autograd_id}{frame_str}"
else:
assert self.frame_id is not None and self.frame_compile_id is not None
return f"{self.frame_id}/{self.frame_compile_id}"

@classmethod
def from_string(cls, compile_id: Optional[str]):
"""
Factory method that creates a CompileId from its string representation.
Keep this in sync with the __str__ method.
"""
if compile_id is None:
return None
try:
frame_id, frame_compile_id = compile_id.split("/")
return cls(int(frame_id), int(frame_compile_id))
for pattern in (COMPILE_ID_PATTERN, CA_COMPILE_ID_PATTERN):
if match := pattern.match(compile_id):
groups = match.groupdict()
for k, v in groups.items():
if v is not None:
groups[k] = int(v)
return cls(**groups) # type: ignore[arg-type]
else:
raise ValueError

except Exception as e:
raise ValueError(f"Invalid compile_id '{compile_id}'") from e

@@ -82,6 +127,7 @@ class TraceId(NamedTuple):
attempt: int

def __str__(self):
# Keep this in sync with tlparse repo
if self.attempt == 0:
return str(self.compile_id)
else:
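As a standalone sanity check of the grammar encoded by the two patterns above (a sketch that only reuses the regexes from this diff and the strings exercised by the new serialization test):

import re

COMPILE_ID_PATTERN = re.compile(r"^(?P<frame_id>\d+)/(?P<frame_compile_id>\d+)$")
CA_COMPILE_ID_PATTERN = re.compile(
    r"^!(?P<compiled_autograd_id>\d+)(?:/(?P<frame_id>\d+)/(?P<frame_compile_id>\d+))?$"
)

# Accepted: plain dynamo ids and compiled-autograd-prefixed ids.
for s in ["1/2", "!1", "!1/2/3"]:
    assert COMPILE_ID_PATTERN.match(s) or CA_COMPILE_ID_PATTERN.match(s), s

# Rejected: the malformed ids from the new test above.
for s in ["-/-", "-/1", "1/-", "!1/2", "!1/-/-"]:
    assert not (COMPILE_ID_PATTERN.match(s) or CA_COMPILE_ID_PATTERN.match(s)), s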