18 changes: 18 additions & 0 deletions thunder/dynamo/splitter.py
@@ -28,6 +28,23 @@
if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any
    from torch.fx import GraphModule


# TODO: investigate whether there is a cleaner way to prevent the error
# ``cannot extract sympy expressions from <torch.cuda.Stream device=cuda:0 cuda_stream=0x0> <class 'torch.cuda.streams.Stream'>``.
def _preprocess_cuda_stream_objects(gm: GraphModule) -> None:
    """Preprocess the graph to handle :class:`torch.cuda.Stream` objects.

    Since :class:`torch.cuda.Stream` apparently has no sympy representation,
    manually set its ``example_value`` metadata to :obj:`None` to avoid errors such as
    ``cannot extract sympy expressions from <torch.cuda.Stream device=cuda:0 cuda_stream=0x0> <class 'torch.cuda.streams.Stream'>``.
    """
    for node in gm.graph.nodes:
        if hasattr(node, "meta") and "example_value" in node.meta:
            example_value = node.meta["example_value"]
            if isinstance(example_value, torch.cuda.Stream):
                node.meta["example_value"] = None


def _splitter(
@@ -166,6 +183,7 @@ def callback(node) -> int:
    for n in functionctx_nodes_to_del:
        gm.graph.erase_node(n)
    gm.recompile()
    _preprocess_cuda_stream_objects(gm)

    # `split_module` iterates over nodes and determines the partition to place them based on the callback.
    split_gm: torch.fx.GraphModule = split_module(
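A minimal sketch (not part of the diff) of what the new helper does, using a hand-built GraphModule with the kind of ``example_value`` metadata TorchDynamo records. It assumes a CUDA device is available and that the helper is importable from thunder/dynamo/splitter.py as added above.

# Illustrative sketch only.
import torch
import torch.fx
from thunder.dynamo.splitter import _preprocess_cuda_stream_objects

graph = torch.fx.Graph()
x = graph.placeholder("x")
graph.output(x)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

x.meta["example_value"] = torch.cuda.Stream()  # simulate the sample value Dynamo attaches
_preprocess_cuda_stream_objects(gm)
assert x.meta["example_value"] is None         # stream metadata cleared before split_module runs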
11 changes: 11 additions & 0 deletions thunder/dynamo/utils.py
@@ -220,6 +220,11 @@ def forward(self, *args):
            self.graph_module.graph = original_graph
            self.graph_module.recompile()
            self.compiled_fn = self.graph_module
        except (NotImplementedError, AssertionError) as e:
            warnings.warn(f"torch._inductor.compile failed: {e}. Falling back to eager.")
            self.graph_module.graph = original_graph
            self.graph_module.recompile()
            self.compiled_fn = self.graph_module

        return self.compiled_fn(*args)
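For context (not part of the diff), a self-contained sketch of the compile-then-fall-back-to-eager pattern that the new ``except`` clause extends. The class and method names (``InductorWrapper``, ``run``) are made up for illustration and the graph-restoration details are omitted.

# Illustrative sketch only: lazy inductor compilation with an eager fallback.
import warnings

import torch
import torch.fx
import torch._inductor


class InductorWrapper:
    def __init__(self, graph_module: torch.fx.GraphModule):
        self.graph_module = graph_module
        self.compiled_fn = None

    def run(self, *args):
        if self.compiled_fn is None:
            try:
                self.compiled_fn = torch._inductor.compile(self.graph_module, list(args))
            except (NotImplementedError, AssertionError) as e:
                warnings.warn(f"torch._inductor.compile failed: {e}. Falling back to eager.")
                self.compiled_fn = self.graph_module  # run the FX graph eagerly instead
        return self.compiled_fn(*args)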

@@ -495,6 +500,12 @@ def is_node_supported_by_thunder(
    target = node.target  # Target is the function to call.
    if node.op == "call_method":
        target = getattr(torch.Tensor, node.target, None)
        if target is None and hasattr(torch.cuda.Stream, node.target):
            split_reason = SplitReason(
                SplitReasonType.MISSING_OP_SUPPORT,
                f"node with name {node.name} and target {node.target} is a `torch.cuda.Stream` method which is not supported by Thunder.",
            )
            return False, split_reason
        assert target is not None, f"Failed to find method {node.target}"

    # If the operation has automatic registration, we mark it as unsupported as `inductor` might be
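For reference (not part of the diff), the check the new branch performs for a ``call_method`` node such as the one Dynamo records for ``s.wait_stream(...)``: the method name exists on ``torch.cuda.Stream`` but has no ``torch.Tensor`` counterpart, so the node is reported as unsupported.

# Illustrative sketch only.
import torch

method_name = "wait_stream"                                # would be node.target
assert getattr(torch.Tensor, method_name, None) is None    # not a Tensor method...
assert hasattr(torch.cuda.Stream, method_name)             # ...but a Stream method -> split reason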
16 changes: 16 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -1919,3 +1919,19 @@ def fake_compile(*args, **kwargs):
    with patch("torch._inductor.compile", side_effect=fake_compile):
        cfunc = thunderfx(func)
        cfunc(x)


@requiresCUDA
def test_stream_op():
    def fn():
        cuda = torch.device("cuda")
        s = torch.cuda.streams.Stream()
        s.wait_stream(torch.cuda.current_stream(cuda))

    cfunc = thunderfx(fn)
    cfunc()
    split_reasons = cfunc._backend.subgraph_infos[0].split_reasons
    assert any(
        "is a `torch.cuda.Stream` method which is not supported by Thunder" in getattr(reason, "info", "")
        for reason in split_reasons
    )
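Assuming a CUDA-capable environment, the new test can be run in isolation with:

pytest thunder/tests/test_dynamo.py -k test_stream_op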