17 changes: 17 additions & 0 deletions thunder/dynamo/splitter.py
@@ -28,6 +28,22 @@
if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any
    from torch.fx import GraphModule


def _preprocess_cuda_stream_objects(gm: GraphModule) -> None:
    """Preprocess the graph to handle :class:`torch.cuda.Stream` objects.

    :class:`torch.cuda.Stream` objects apparently have no sympy expression, so their
    ``example_value`` metadata is manually set to :obj:`None` to avoid errors such as
    ``cannot extract sympy expressions from <torch.cuda.Stream device=cuda:0 cuda_stream=0x0> <class 'torch.cuda.streams.Stream'>``.
    """
    for node in gm.graph.nodes:
        if hasattr(node, "meta") and "example_value" in node.meta:
            example_value = node.meta["example_value"]
            if isinstance(example_value, torch.cuda.Stream):
                node.meta["example_value"] = None
                node.meta["_original_stream_type"] = type(example_value).__name__
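
For illustration, a minimal sketch (not part of the diff) of what this pass does to a node's metadata; the traced module is a toy stand-in, a CUDA device is assumed, and _preprocess_cuda_stream_objects is assumed to be importable as defined above:

import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = torch.fx.symbolic_trace(Toy())
node = next(iter(gm.graph.nodes))
node.meta["example_value"] = torch.cuda.Stream()  # simulate dynamo-style metadata

_preprocess_cuda_stream_objects(gm)
assert node.meta["example_value"] is None
assert node.meta["_original_stream_type"] == "Stream"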
Collaborator

What is this variable used for?

Collaborator Author

I guess this variable is something Torch will check, so we set it manually. Maybe @crcrpar knows more about it?

Collaborator

Ask Cursor

@kshitij12345 (Collaborator), Nov 24, 2025

Good question. I tried finding a reference to _original_stream_type in PyTorch and found nothing; I think this is dead code.

Also, I don't think we need _preprocess_cuda_stream_objects, as at the moment we want to pass it to the fallback (which should just be able to handle it).

With _preprocess_cuda_stream_objects removed, I don't think we even need the new fallback path for NotImplementedError and AssertionError in LazyInductorModule.

Collaborator Author

> With _preprocess_cuda_stream_objects removed, I don't think we even need the new fallback path for NotImplementedError and AssertionError in LazyInductorModule.

Did you try it with the sglang models? I get the following error when I comment out _preprocess_cuda_stream_objects while running gpt-oss (the same error @crcrpar hit when he added the function):

  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 414, in __init__
    raise Exception(
Exception: Capture cuda graph failed: backend='<thunder.dynamo.compiler.ThunderCompiler object at 0xffda8e9ab470>' raised:
AssertionError: cannot extract sympy expressions from <torch.cuda.Stream device=cuda:0 cuda_stream=0x1552ee20> <class 'torch.cuda.streams.Stream'>

Collaborator

I tried with the new test in the PR, not via sglang. In that case, it makes sense to keep _preprocess_cuda_stream_objects. Thank you for checking.



def _splitter(
@@ -166,6 +182,7 @@ def callback(node) -> int:
    for n in functionctx_nodes_to_del:
        gm.graph.erase_node(n)
    gm.recompile()
    _preprocess_cuda_stream_objects(gm)

    # `split_module` iterates over nodes and determines the partition to place them based on the callback.
    split_gm: torch.fx.GraphModule = split_module(
11 changes: 11 additions & 0 deletions thunder/dynamo/utils.py
@@ -220,6 +220,11 @@ def forward(self, *args):
                self.graph_module.graph = original_graph
                self.graph_module.recompile()
                self.compiled_fn = self.graph_module
            except (NotImplementedError, AssertionError) as e:
                warnings.warn(f"torch._inductor.compile failed: {e}. Falling back to eager.")
                self.graph_module.graph = original_graph
                self.graph_module.recompile()
                self.compiled_fn = self.graph_module

        return self.compiled_fn(*args)
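
A standalone sketch of the fallback pattern added above (the module and failing_backend_compile are hypothetical stand-ins, not the real LazyInductorModule or torch._inductor.compile): if the backend compile step raises NotImplementedError or AssertionError, the original GraphModule is used directly, since an fx.GraphModule is itself callable.

import warnings
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1

gm = torch.fx.symbolic_trace(M())

def failing_backend_compile(gm, example_inputs):
    # Stand-in for a backend compile call that rejects the graph.
    raise AssertionError("cannot extract sympy expressions from <torch.cuda.Stream ...>")

x = torch.randn(4)
try:
    compiled_fn = failing_backend_compile(gm, [x])
except (NotImplementedError, AssertionError) as e:
    warnings.warn(f"compile failed: {e}. Falling back to eager.")
    compiled_fn = gm  # fall back to running the original GraphModule eagerly

print(compiled_fn(x))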

@@ -495,6 +500,12 @@ def is_node_supported_by_thunder(
    target = node.target  # Target is the function to call.
    if node.op == "call_method":
        target = getattr(torch.Tensor, node.target, None)
        if target is None and hasattr(torch.cuda.Stream, node.target):
            split_reason = SplitReason(
                SplitReasonType.MISSING_OP_SUPPORT,
                f"node with name {node.name} and target {node.target} is a `torch.cuda.Stream` method which is not supported by Thunder.",
            )
            return False, split_reason
        assert target is not None, f"Failed to find method {node.target}"

    # If the operation has automatic registration, we mark it as unsupported as `inductor` might be
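
A quick illustration of why the hasattr(torch.cuda.Stream, node.target) check above identifies these nodes (a rough sketch; it runs without a GPU because it only inspects class attributes):

import torch

# "wait_stream" is a torch.cuda.Stream method but not a torch.Tensor method, so a
# call_method node with this target reaches the new branch and yields a
# MISSING_OP_SUPPORT split reason instead of tripping the assert below it.
print(getattr(torch.Tensor, "wait_stream", None))  # None
print(hasattr(torch.cuda.Stream, "wait_stream"))   # True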
16 changes: 16 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -1919,3 +1919,19 @@ def fake_compile(*args, **kwargs):
    with patch("torch._inductor.compile", side_effect=fake_compile):
        cfunc = thunderfx(func)
        cfunc(x)


@requiresCUDA
def test_stream_op():
    def fn():
        cuda = torch.device("cuda")
        s = torch.cuda.streams.Stream()
        s.wait_stream(torch.cuda.current_stream(cuda))

    cfunc = thunderfx(fn)
    cfunc()
    split_reasons = cfunc._backend.subgraph_infos[0].split_reasons
    assert any(
        "is a `torch.cuda.Stream` method which is not supported by Thunder" in getattr(reason, "info", "")
        for reason in split_reasons
    )