diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index c5d784a783..11ba9eca0e 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -28,6 +28,23 @@ if TYPE_CHECKING:
     from collections.abc import Callable
     from typing import Any
+    from torch.fx import GraphModule
+
+
+# TODO: investigate whether there is a cleaner way to prevent the error
+# ``cannot extract sympy expressions from ``.
+def _preprocess_cuda_stream_objects(gm: GraphModule) -> None:
+    """Preprocess the graph to handle :class:`torch.cuda.Stream` objects.
+
+    Since :class:`torch.cuda.Stream` objects apparently have no sympy expression,
+    manually set their metadata to :obj:`None` to avoid errors such as
+    ``cannot extract sympy expressions from ``.
+    """
+    for node in gm.graph.nodes:
+        if hasattr(node, "meta") and "example_value" in node.meta:
+            example_value = node.meta["example_value"]
+            if isinstance(example_value, torch.cuda.Stream):
+                node.meta["example_value"] = None
 
 
 def _splitter(
@@ -166,6 +183,7 @@ def callback(node) -> int:
     for n in functionctx_nodes_to_del:
         gm.graph.erase_node(n)
     gm.recompile()
+    _preprocess_cuda_stream_objects(gm)
 
     # `split_module` iterates over nodes and determines the partition to place them based on the callback.
     split_gm: torch.fx.GraphModule = split_module(
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index e0d0610fcc..8c643cee71 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -220,6 +220,11 @@ def forward(self, *args):
                 self.graph_module.graph = original_graph
                 self.graph_module.recompile()
                 self.compiled_fn = self.graph_module
+            except (NotImplementedError, AssertionError) as e:
+                warnings.warn(f"torch._inductor.compile failed: {e}. Falling back to eager.")
+                self.graph_module.graph = original_graph
+                self.graph_module.recompile()
+                self.compiled_fn = self.graph_module
 
         return self.compiled_fn(*args)
 
@@ -495,6 +500,12 @@ def is_node_supported_by_thunder(
     target = node.target  # Target is the function to call.
     if node.op == "call_method":
         target = getattr(torch.Tensor, node.target, None)
+        if target is None and hasattr(torch.cuda.Stream, node.target):
+            split_reason = SplitReason(
+                SplitReasonType.MISSING_OP_SUPPORT,
+                f"node with name {node.name} and target {node.target} is a `torch.cuda.Stream` method which is not supported by Thunder.",
+            )
+            return False, split_reason
         assert target is not None, f"Failed to find method {node.target}"
 
     # If the operation has automatic registration, we mark it as unsupported as `inductor` might be
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 90b750e712..ee0004ca47 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -1919,3 +1919,19 @@ def fake_compile(*args, **kwargs):
     with patch("torch._inductor.compile", side_effect=fake_compile):
         cfunc = thunderfx(func)
         cfunc(x)
+
+
+@requiresCUDA
+def test_stream_op():
+    def fn():
+        cuda = torch.device("cuda")
+        s = torch.cuda.streams.Stream()
+        s.wait_stream(torch.cuda.current_stream(cuda))
+
+    cfunc = thunderfx(fn)
+    cfunc()
+    split_reasons = cfunc._backend.subgraph_infos[0].split_reasons
+    assert any(
+        "is a `torch.cuda.Stream` method which is not supported by Thunder" in getattr(reason, "info", "")
+        for reason in split_reasons
+    )
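
Not part of the patch: a minimal, hedged sketch of what `_preprocess_cuda_stream_objects` does to FX node metadata. It assumes a CUDA-capable machine (constructing `torch.cuda.Stream()` needs an initialized CUDA device) and that the helper is importable from `thunder.dynamo.splitter` once the change above lands.

import torch
import torch.fx

from thunder.dynamo.splitter import _preprocess_cuda_stream_objects


def f(x):
    return x + 1


gm = torch.fx.symbolic_trace(f)

# Simulate dynamo attaching a torch.cuda.Stream as a node's example_value,
# the situation the docstring above says leads to
# ``cannot extract sympy expressions from `` during splitting.
node = next(iter(gm.graph.nodes))
node.meta["example_value"] = torch.cuda.Stream()

_preprocess_cuda_stream_objects(gm)

# The Stream metadata is cleared, so split_module can proceed.
assert node.meta["example_value"] is None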