
Commit 9ada615

Delay Inductor until we get real input tensors (#2689)
Co-authored-by: Masato Shinokawa <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f8648aa commit 9ada615

File tree

thunder/dynamo/compiler.py
thunder/dynamo/report.py
thunder/dynamo/splitter.py
thunder/dynamo/utils.py
thunder/tests/test_dynamo.py

5 files changed: +69 -64 lines


thunder/dynamo/compiler.py

Lines changed: 0 additions & 2 deletions
@@ -147,8 +147,6 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
         split_module, subgraph_info = _splitter(
             gm,
             partial(jit, **thunder_options),
-            torch._inductor.compile,
-            sample_args,
             thunder_options,
         )
         self.subgraph_infos.append(subgraph_info)

thunder/dynamo/report.py

Lines changed: 1 addition & 1 deletion
@@ -1107,7 +1107,7 @@ def foo(x):
             thunder_options[k] = v

         thunder_jit = partial(jit, **thunder_options, nv_save_fake_inputs=True)
-        _, subgraph_info = _splitter(gm, thunder_jit, torch._inductor.compile, _unused_sample_args=None)
+        _, subgraph_info = _splitter(gm, thunder_jit)

         thunder_module_names = [f"{report.graph_name}_{name}" for name in get_thunder_module_names(subgraph_info)]
         original_modules_to_thunder_modules = (

thunder/dynamo/splitter.py

Lines changed: 5 additions & 31 deletions
@@ -3,10 +3,8 @@
 from typing import TYPE_CHECKING
 import copy
 from functools import partial
-import warnings

 import torch
-from torch._subclasses.fake_tensor import DynamicOutputShapeException
 from torch.fx.passes.split_module import split_module

 from thunder.core import baseutils
@@ -16,12 +14,12 @@
     CompilerType,
     SplitReason,
     SplitReasonType,
+    LazyInductorModule,
     is_node_supported_by_thunder,
     get_nodes_in_unsupported_ctx_regions,
     update_node_and_submodule,
     recompile_graph,
     checkpoint_converter,
-    make_fake_arguments,
     _get_example_inputs_from_placeholder,
     _ThunderSplitGraphModule,
     translate_dtensor_ops,
@@ -35,8 +33,6 @@
 def _splitter(
     gm: torch.fx.GraphModule,
     thunder_jit: Callable,
-    torch_inductor: Callable,
-    _unused_sample_args: list[torch.SymInt, torch.Tensor],
     thunder_options: dict[str, Any] | None = None,
 ) -> tuple[torch.fx.GraphModule, SubgraphInfo]:
     """
@@ -226,32 +222,10 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
         elif node.name.startswith("submod"): # For inductor
             graph_module = getattr(split_gm, node.name)

-            class ModuleWrapper(torch.nn.Module):
-                def __init__(self, fn):
-                    super().__init__()
-                    self.fn = fn
-
-                def forward(self, *args, **kwargs):
-                    return self.fn(*args, **kwargs)
-
-            def fallback_eager(reason: str) -> torch.nn.Module:
-                warnings.warn(f"{reason} Falling back to eager.")
-                # TODO: Use torch.compile here. Investigate its behavior and ensure correctness.
-                return graph_module
-
-            fake_args = make_fake_arguments(graph_module)
-            if fake_args is None:
-                jit_fn = fallback_eager("Example values for arguments are not available.")
-            else:
-                try:
-                    # torch._inductor.compile returns a function, but update_node_and_submodule expects a Module
-                    jit_fn = ModuleWrapper(torch_inductor(graph_module, fake_args))
-                except DynamicOutputShapeException as e:
-                    # This exception is meant to be handled by Dynamo, which is responsible for graph break
-                    jit_fn = fallback_eager(f"Dynamic output shape operator encountered: {e}.")
-
-            # This is for ease of debugging. We add graph attribute so GraphModule.print_readable will print it
-            jit_fn.graph = graph_module.graph
+            fake_mode = torch._guards.detect_fake_mode()
+            # Delay Inductor compilation until invocation with real tensors,
+            # because we do not know the strides of tensors that Thunder-compiled submodules return.
+            jit_fn = LazyInductorModule(graph_module, fake_mode)

             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
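The splitter no longer compiles Inductor submodules eagerly against fake arguments; it hands each one to LazyInductorModule (defined in thunder/dynamo/utils.py below). A minimal standalone sketch of the underlying idea, using illustrative names that are not part of this commit, is:

import torch


class _DeferredInductor(torch.nn.Module):  # hypothetical name, for illustration only
    def __init__(self, graph_module: torch.fx.GraphModule):
        super().__init__()
        self.graph_module = graph_module
        self.compiled_fn = None

    def forward(self, *args):
        if self.compiled_fn is None:
            # torch._inductor.compile returns a callable specialized to these inputs,
            # so compilation happens only once real tensors (and their strides) exist.
            self.compiled_fn = torch._inductor.compile(self.graph_module, list(args))
        return self.compiled_fn(*args)


def _f(x):
    return torch.relu(x) + 1


wrapper = _DeferredInductor(torch.fx.symbolic_trace(_f))
out = wrapper(torch.randn(4, 4))  # Inductor runs here, on the first call with real tensors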

thunder/dynamo/utils.py

Lines changed: 56 additions & 21 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from collections.abc import Callable, Sequence
+from contextlib import contextmanager
 from enum import Enum, auto
 from typing import TYPE_CHECKING
 import dataclasses
@@ -8,12 +9,15 @@
 from types import NoneType
 from collections import defaultdict
 from collections import namedtuple
+import warnings
+
+from looseversion import LooseVersion

 import torch
 from torch.nn.modules.module import _addindent
 from torch.utils.weak import TensorWeakRef
-from torch._guards import detect_fake_mode
-from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
+from torch._guards import tracing, TracingContext
+from torch._subclasses.fake_tensor import DynamicOutputShapeException

 if torch.distributed.is_available():
     from torch.distributed.tensor import DTensor
@@ -154,6 +158,56 @@ def is_thunder_supported_partition(self, node: torch.fx.Node) -> bool:
         return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in self.supported_indexes


+class LazyInductorModule(torch.nn.Module):
+    def __init__(self, graph_module, fake_mode):
+        super().__init__()
+        self.graph_module = graph_module
+        self.compiled_fn = None
+        self.fake_mode = fake_mode
+
+        # For ease of debugging, we add graph attribute so GraphModule.print_readable will print it
+        self.graph = graph_module.graph
+
+    # TODO: Remove this once we drop support for PyTorch 2.7.x
+    @contextmanager
+    def _maybe_patch_increment_toplevel(self):
+        # In PyTorch before 2.8.0, FXGraphCache assumes that it is run behind Dynamo
+        # and tries to update metrics_context.
+        # See https://github.com/pytorch/pytorch/pull/150423
+        if LooseVersion(torch.__version__) >= LooseVersion("2.8.0"):
+            yield
+            return
+
+        from torch._dynamo.utils import CompileEventLogger
+
+        def fake_increment_toplevel(*args, **kwargs):
+            metrics_context = torch._dynamo.utils.get_metrics_context()
+            assert not metrics_context.in_progress()
+            return
+
+        original = CompileEventLogger.increment_toplevel
+        CompileEventLogger.increment_toplevel = fake_increment_toplevel
+        try:
+            yield
+        finally:
+            CompileEventLogger.increment_toplevel = original
+
+    def forward(self, *args):
+        if self.compiled_fn is None:
+            with self._maybe_patch_increment_toplevel():
+                # Inductor needs fake_mode, particularly its shape_env, to handle SymInts
+                with tracing(TracingContext(fake_mode=self.fake_mode)):
+                    try:
+                        self.compiled_fn = torch._inductor.compile(self.graph_module, args)
+                    except DynamicOutputShapeException as e:
+                        # This exception is meant to be handled by Dynamo, which is responsible for graph break
+                        # TODO: Use torch.compile for fallback. Ensure its correctness.
+                        warnings.warn(f"Dynamic output shape operator encountered: {e}. Falling back to eager.")
+                        self.compiled_fn = self.graph_module
+
+        return self.compiled_fn(*args)
+
+
 @dataclasses.dataclass()
 class ProfileStats:
     """
@@ -1064,25 +1118,6 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn):
     return sorted_compiled_gm_to_measurement[0].compiled_fn


-def make_fake_arguments(gm: torch.fx.GraphModule) -> list[FakeTensor] | None:
-    fake_mode = detect_fake_mode()
-    if fake_mode is None:
-        fake_mode = FakeTensorMode()
-    args = []
-    for node in gm.graph.nodes:
-        if node.op == "placeholder":
-            meta_val = node.meta.get("example_value")
-            if meta_val is None:
-                # We observed Dynamo creating nodes without `example_value` on Tensor.tolist().
-                # This no longer happens in PyTorch 2.10 (see https://github.com/pytorch/pytorch/pull/163807).
-                return None
-            if isinstance(meta_val, torch.Tensor):
-                # Tie to the currently enabled fake mode
-                meta_val = fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, meta_val)
-            args.append(meta_val)
-    return args
-
-
 def translate_dtensor_ops(gm: torch.fx.GraphModule) -> None:
     # We need this function because:
     #
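_maybe_patch_increment_toplevel is an instance of a common temporary-monkeypatch pattern: swap an attribute for the duration of a block, restore it in a finally clause, and skip the patch entirely on PyTorch versions that do not need the workaround. A generic sketch of that pattern, with hypothetical names, might look like:

from contextlib import contextmanager

import torch
from looseversion import LooseVersion


@contextmanager
def _maybe_patch(obj, attr_name, replacement, fixed_in="2.8.0"):
    # Skip the workaround entirely on versions where it is no longer needed.
    if LooseVersion(torch.__version__) >= LooseVersion(fixed_in):
        yield
        return

    original = getattr(obj, attr_name)
    setattr(obj, attr_name, replacement)
    try:
        yield
    finally:
        # Restore the original attribute even if the body raised.
        setattr(obj, attr_name, original)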

thunder/tests/test_dynamo.py

Lines changed: 7 additions & 9 deletions
@@ -998,7 +998,7 @@ def forward(self, x):
     n = gm.graph.find_nodes(op="output")
     gm.graph.erase_node(n[0])

-    _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, torch._inductor.compile, [])
+    _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit)
     original_split_gm = subgraph_info.original_split_graph_module.split_graph_module
     assert original_split_gm.graph.find_nodes(op="output")
     for subm in original_split_gm.children():
@@ -1768,22 +1768,21 @@ def forward(self, x):
     torch.testing.assert_close(org_m.fc.weight.grad, thunder_m.fc.weight.grad)


-def test_thunderfx_node_with_no_example_value():
+def test_thunderfx_tolist():
     def test_fn(x):
         y = x + 10
         z = y.tolist()[0]
         return z + 2

     x = torch.tensor([1, 2, 3, 4, 5])
     # Without this patch, tolist() would cause graph break. See https://github.com/pytorch/pytorch/pull/163807
-    with patch("torch._dynamo.config.capture_scalar_outputs", True):
-        with pytest.warns(match="Example values for arguments are not available"):
-            actual = thunderfx(test_fn)(x)
+    with torch._dynamo.config.patch(capture_scalar_outputs=True):
+        actual = thunderfx(test_fn)(x)
     expected = test_fn(x)
     torch.testing.assert_close(actual, expected)


-def test_thunderfx_no_example_value_and_autocast():
+def test_thunderfx_tolist_autocast():
     def fn(x):
         with torch.autocast("cpu"):
             y = x + 10
@@ -1792,9 +1791,8 @@ def fn(x):

     x = torch.tensor([1, 2, 3, 4, 5])
     # Without this patch, tolist() would cause graph break. See https://github.com/pytorch/pytorch/pull/163807
-    with patch("torch._dynamo.config.capture_scalar_outputs", True):
-        with pytest.warns(match="Example values for arguments are not available"):
-            actual = thunderfx(fn)(x)
+    with torch._dynamo.config.patch(capture_scalar_outputs=True):
+        actual = thunderfx(fn)(x)
     expected = fn(x)
     torch.testing.assert_close(actual, expected)
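The tests now rely on torch._dynamo.config.patch, Dynamo's own context manager for temporarily overriding a config flag and restoring it on exit. A small self-contained sketch of that usage, independent of the tests above, is:

import torch


def fn(x):
    # .item() produces a scalar output; capturing it avoids a graph break under compile
    return x.sum().item() + 2.0


x = torch.ones(3)
with torch._dynamo.config.patch(capture_scalar_outputs=True):
    actual = torch.compile(fn)(x)
assert actual == fn(x)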
