Skip to content

Commit

Permalink
[compiled autograd] log compilation time to perfetto (pytorch#140964)
Browse files Browse the repository at this point in the history
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmprli4iy/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=100
```
[
  {
    "args": {
      "compile_id": "0/-/-",
      "graph_id": 0
    },
    "cat": "dynamo_timed",
    "name": "compiled_autograd",
    "ph": "B",
    "pid": 0,
    "tid": 0,
    "ts": 1733886868992655.8
  },
  {
    "args": {
      "compile_id": "0/-/-",
      "graph_id": 0
    },
    "cat": "dynamo_timed",
    "name": "compiled_autograd",
    "ph": "E",
    "pid": 0,
    "tid": 0,
    "ts": 1733886869130681.0
  },
  {
    "args": {
      "compile_id": "0/0/0"
    },
    "cat": "dynamo_timed",
    "name": "dynamo",
    "ph": "B",
    "pid": 0,
    "tid": 0,
    "ts": 1733886869134350.5
  },
  {
```

Pull Request resolved: pytorch#140964
Approved by: https://github.com/masnesral
ghstack dependencies: pytorch#141907, pytorch#143175
  • Loading branch information
xmfan authored and pytorchmergebot committed Dec 21, 2024
1 parent c7d7eff commit a8953c3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
25 changes: 25 additions & 0 deletions test/dynamo/test_structured_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,31 @@ def backward(ctx, gO):
logs = self.buffer.getvalue()
self.assertTrue(all(event in logs for event in expected))

@requires_tlparse
@show_chrome_events
def test_compiled_autograd_chromium(self):
with torch._dynamo.compiled_autograd._enable(torch.compile):
for i in [10, 100, 10, 15, 20, 25]:
x = torch.arange(0.0, i, requires_grad=True)
loss = x.sum()
loss.backward()

self.assertParses()
expected = [
'{"chromium_event": {}, "compiled_autograd_id": 0, "attempt": 0, "has_payload": "HASH"}',
'{"chromium_event": {}, "compiled_autograd_id": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, '
'"has_payload": "HASH"}',
'{"chromium_event": {}, "compiled_autograd_id": 1, "frame_id": 1, "frame_compile_id": 1, "attempt": 0, '
'"has_payload": "HASH"}',
'{"chromium_event": {}, "compiled_autograd_id": 1, "attempt": 0, "has_payload": "HASH"}',
'{"chromium_event": {}, "compiled_autograd_id": 1, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, '
'"has_payload": "HASH"}',
'{"chromium_event": {}, "compiled_autograd_id": 1, "frame_id": 1, "frame_compile_id": 1, "attempt": 0, '
'"has_payload": "HASH"}',
]
logs = self.buffer.getvalue()
self.assertTrue(all(event in logs for event in expected))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
23 changes: 22 additions & 1 deletion torch/_dynamo/compiled_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import itertools
import operator
import time
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
Expand All @@ -12,7 +13,12 @@
FakeCompiledAutogradEngine,
)
from torch._dynamo.source import GetItemSource, LocalSource
from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal
from torch._dynamo.utils import (
counters,
get_chromium_event_logger,
lazy_format_graph_code,
set_locals_to_steal,
)
from torch._guards import compile_context, CompileContext, CompileId
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import clone_preserve_strides
Expand Down Expand Up @@ -117,6 +123,14 @@ def begin_capture(
self.id = next(COMPILE_COUNTER)
self.compile_context = make_compile_context(self.id)
self.compile_context.__enter__()
self.start_time_ns = time.time_ns()
get_chromium_event_logger().log_event_start(
"compiled_autograd",
self.start_time_ns,
{"graph_id": self.id},
log_pt2_compile_event=True,
)

self.aot_graph_cls_name: Optional[str] = None
self.aot_graph_infos: Dict[int, Dict[str, Any]] = {}
self.fx_tracer.root = torch.nn.Module()
Expand Down Expand Up @@ -421,6 +435,13 @@ def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
finally:
in_compiled_autograd_region = False

get_chromium_event_logger().log_event_end(
"compiled_autograd",
time.time_ns(),
{"graph_id": self.id},
self.start_time_ns,
log_pt2_compile_event=True,
)
self.compile_context.__exit__(None, None, None)
return runtime_wrapper, self.compiler_fn(graph)

Expand Down

0 comments on commit a8953c3

Please sign in to comment.