
Conversation

kiya00 (Collaborator) commented Nov 20, 2025

Before submitting
  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #2332.

Support CUDA stream operators in ThunderFX. This PR adds the patch mentioned in #2332 and a follow-up fix from #2750 (comment).
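For context, a minimal repro in the spirit of #2332 might look like the sketch below. This is illustrative only, not the exact test added in this PR; it assumes the ThunderCompiler backend that appears later in this discussion.

import torch
from thunder.dynamo import ThunderCompiler

def fn(x):
    # A CUDA stream object created inside the traced function flows
    # through the FX graph, which is what previously broke ThunderFX.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        y = x * 2
    torch.cuda.current_stream().wait_stream(s)
    return y

cfn = torch.compile(fn, backend=ThunderCompiler())
print(cfn(torch.randn(4, device="cuda")))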

crcrpar (Collaborator) left a comment:

I wouldn't be an appropriate reviewer for this, as I wrote part of it.

riccardofelluga (Collaborator) left a comment:

Is there a small repro that we can use to test this PR?

example_value = node.meta["example_value"]
if isinstance(example_value, torch.cuda.Stream):
node.meta["example_value"] = None
node.meta["_original_stream_type"] = type(example_value).__name__
Collaborator:

What is this variable used for?

kiya00 (Author):

I guess this variable is something Torch checks, so we set it manually. Maybe @crcrpar knows more about it?


kshitij12345 (Collaborator) commented Nov 24, 2025:

Good question. I tried to find a reference to _original_stream_type in PyTorch and found nothing; I think this is dead code.

Also, I don't think we need _preprocess_cuda_stream_objects, as at the moment we want to pass the stream to the fallback (which should just be able to handle it).

With _preprocess_cuda_stream_objects removed, I don't think we even need the new fallback path for NotImplementedError and AssertionError in LazyInductorModule.

kiya00 (Author):

With _preprocess_cuda_stream_objects removed, I don't think we even need the new fallback path for NotImplementedError and AssertionError in LazyInductorModule.

Did you try it with the sglang models? I seem to get the following error when _preprocess_cuda_stream_objects is commented out while running gpt-oss (the same error @crcrpar hit when he added the function):

  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 414, in __init__
    raise Exception(
Exception: Capture cuda graph failed: backend='<thunder.dynamo.compiler.ThunderCompiler object at 0xffda8e9ab470>' raised:
AssertionError: cannot extract sympy expressions from <torch.cuda.Stream device=cuda:0 cuda_stream=0x1552ee20> <class 'torch.cuda.streams.Stream'>

Collaborator:

I tried with the new test in the PR, not via sglang. In that case, it makes sense to keep _preprocess_cuda_stream_objects. Thank you for checking.

kiya00 (Author) commented Nov 21, 2025

Is there a small repro that we can use to test this PR?

Yeah, I added a test case for it.
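For reference, a hedged sketch of what such a test might look like; the names and exact assertions are illustrative, not necessarily the test added in this PR.

import pytest
import torch
from thunder.dynamo import ThunderCompiler

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_cuda_stream_wait_stream():
    def fn(x):
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            y = x + 1
        torch.cuda.current_stream().wait_stream(s)
        return y

    cfn = torch.compile(fn, backend=ThunderCompiler())
    x = torch.randn(8, device="cuda")
    # The compiled function should match eager execution here.
    torch.testing.assert_close(cfn(x), fn(x))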

kshitij12345 (Collaborator) left a comment:

I think just making sure we cause a split when we encounter a torch.cuda.Stream object should let the test case (and other relevant code/models) pass.
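A minimal sketch of that idea, assuming a node-support predicate in the splitter; the function name is illustrative, not Thunder's actual API.

import torch

def node_requires_split(node: torch.fx.Node) -> bool:
    # Illustrative check: treat any node whose example value is a CUDA
    # stream as unsupported, so the splitter routes it (and its users)
    # to the Inductor/eager fallback instead of the Thunder submodule.
    example_value = node.meta.get("example_value")
    return isinstance(example_value, torch.cuda.Stream)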


kshitij12345 (Collaborator) left a comment:

LGTM, thanks @kiya00

kiya00 (Author) commented Nov 25, 2025

Hi @KaelanDt, could you help review this?


Development

Successfully merging this pull request may close these issues.

thunderfx fails on torch.cuda.Stream.wait_stream
