Support CUDA stream operators in ThunderFX #2761
Conversation
crcrpar left a comment:
I wouldn't be the appropriate person to review this, as I wrote part of it.
riccardofelluga left a comment:
Is there a small repro that we can use to test this PR?
thunder/dynamo/splitter.py (outdated)

```python
example_value = node.meta["example_value"]
if isinstance(example_value, torch.cuda.Stream):
    node.meta["example_value"] = None
    node.meta["_original_stream_type"] = type(example_value).__name__
```
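For context, a minimal sketch of how this hunk might sit inside a graph pass (the wrapping function and loop are assumptions; the thread below refers to the helper as `_preprocess_cuda_stream_objects`):

```python
import torch

def _preprocess_cuda_stream_objects(gm: torch.fx.GraphModule) -> None:
    # Assumed wrapper: walk every FX node and scrub CUDA stream example
    # values, since downstream metadata handling cannot process Stream
    # objects (see the sympy-extraction AssertionError quoted later in
    # this thread).
    for node in gm.graph.nodes:
        example_value = node.meta.get("example_value")
        if isinstance(example_value, torch.cuda.Stream):
            node.meta["example_value"] = None
            node.meta["_original_stream_type"] = type(example_value).__name__
```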
What is this variable used for?
I guess this variable is something Torch checks, so we set it manually. Maybe @crcrpar knows more about it?
Good question. I tried to find a reference to `_original_stream_type` in PyTorch and found nothing; I think it is dead code.
Also, I don't think we need `_preprocess_cuda_stream_objects`, since at the moment we want to pass the stream object to the fallback (which should be able to handle it as-is).
With `_preprocess_cuda_stream_objects` removed, I don't think we even need the new fallback path for `NotImplementedError` and `AssertionError` in `LazyInductorModule`.
> With `_preprocess_cuda_stream_objects` removed, I don't think we even need the new fallback path for `NotImplementedError` and `AssertionError` in `LazyInductorModule`.

Did you try it with the sglang models? I get the following error when `_preprocess_cuda_stream_objects` is commented out while running gpt-oss (the same error @crcrpar hit when he added the function):

```
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 414, in __init__
    raise Exception(
Exception: Capture cuda graph failed: backend='<thunder.dynamo.compiler.ThunderCompiler object at 0xffda8e9ab470>' raised:
AssertionError: cannot extract sympy expressions from <torch.cuda.Stream device=cuda:0 cuda_stream=0x1552ee20> <class 'torch.cuda.streams.Stream'>
```
I tried it with the new test in this PR, not via sglang. In that case, it makes sense to keep `_preprocess_cuda_stream_objects`. Thank you for checking.

Yeah, I added a case to test it.
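For reference, a minimal repro/test along these lines (hypothetical function and shapes; assumes a CUDA device and ThunderFX used as the `torch.compile` backend) might look like:

```python
import torch
from thunder.dynamo import ThunderCompiler

def fn(x):
    # Creating and synchronizing a CUDA stream makes Dynamo record
    # torch.cuda.Stream nodes in the FX graph, exercising the new path.
    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        y = x * 2
    torch.cuda.current_stream().wait_stream(s)
    return y + 1

backend = ThunderCompiler()
cfn = torch.compile(fn, backend=backend)
out = cfn(torch.randn(4, device="cuda"))
```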
I think just making sure we cause a graph split when we encounter a `torch.cuda.Stream` object should let the test case (and other relevant code/models) pass.
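A minimal sketch of such a check (hypothetical helper name; the real split heuristics live in thunder/dynamo/splitter.py):

```python
import torch

def forces_split(node: torch.fx.Node) -> bool:
    # Assumed predicate: a node whose traced example_value is a CUDA stream
    # is unsupported by Thunder, so the splitter should cut the graph here
    # and let the fallback backend execute the stream op natively.
    example_value = node.meta.get("example_value")
    return isinstance(example_value, torch.cuda.Stream)
```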
kshitij12345 left a comment:
LGTM, thanks @kiya00
Hi @KaelanDt, could you help review this?
What does this PR do?
Fixes #2332.
Support CUDA stream operators in ThunderFX.
This PR adds the patch mentioned in #2332 and a follow-up fix from #2750 (comment).