Add cutlass decode kernel to TritonBench #4853
This pull request was exported from Phabricator. Differential Revision: D80041532
Summary: D80992628 introduced SWA FWD kernel changes that did not support decode kernels (i.e., they support sm100_fmha_fwd but not sm100_fmha_gen). Similarly, softmax_scale, introduced in D82788784, does not support decode kernels either. In blackwell_fmha_test, these parameters are dropped during decode kernel selection (https://www.internalfb.com/code/fbsource/[cd7066706035]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py?lines=182). To avoid confusion, do not run test_decode with the ignored parameters.
Differential Revision: D82991496
Summary:
1. Reduce the number of pipeline stages to avoid exceeding the smem limit.
2. Add a static_assert so that an smem capacity violation is raised at compile time rather than at runtime.
3. Select the TMEM intrinsics based on sizeof(Element).
4. Update the unit test to include bf16.
5. Label decode kernel test names with their corresponding test parameters.
Differential Revision: D82991495
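For readers unfamiliar with the pattern in items 2 and 3, here is a minimal host-side C++ sketch of a compile-time shared-memory budget check and an element-size-based selection. It is not the code from this diff. All names (`DecodePipelineSketch`, `kSmemCapacityBytes`, `TmemAccessWidth`, `select_tmem_access`), the tile shape, and the 227 KiB capacity are assumptions chosen for illustration, and `std::uint16_t` is only a 2-byte stand-in for bf16.

```cpp
#include <cstddef>
#include <cstdint>

// ASSUMPTION: illustrative per-block shared-memory budget in bytes.
// The real limit depends on the target architecture and launch configuration.
constexpr std::size_t kSmemCapacityBytes = 227 * 1024;

// Hypothetical pipeline description: each stage buffers one K tile and one
// V tile of TileRows x TileCols elements.
template <typename Element, int TileRows, int TileCols, int Stages>
struct DecodePipelineSketch {
  static constexpr std::size_t kBytesPerStage =
      2 * static_cast<std::size_t>(TileRows) * TileCols * sizeof(Element);
  static constexpr std::size_t kSmemBytes = kBytesPerStage * Stages;

  // Fails the build when the configuration cannot fit, instead of failing
  // at kernel launch.
  static_assert(kSmemBytes <= kSmemCapacityBytes,
                "pipeline stages exceed the shared-memory capacity");
};

// Hypothetical tag naming which access-width variant to use.
enum class TmemAccessWidth { k8Bit, k16Bit, k32Bit };

// Picks the variant from the element size alone, so adding a new 2-byte
// type (for example bf16) needs no extra specialization.
template <typename Element>
constexpr TmemAccessWidth select_tmem_access() {
  return sizeof(Element) == 1   ? TmemAccessWidth::k8Bit
         : sizeof(Element) == 2 ? TmemAccessWidth::k16Bit
                                : TmemAccessWidth::k32Bit;
}

// Three stages of 128x128 2-byte tiles fit the assumed budget.
using Bf16Pipeline3 = DecodePipelineSketch<std::uint16_t, 128, 128, 3>;

// Four stages would not; naming a member of this type would trip the
// static_assert at compile time:
// using Bf16Pipeline4 = DecodePipelineSketch<std::uint16_t, 128, 128, 4>;
```

Under these assumed numbers, each stage takes 64 KiB, so dropping from four stages to three brings the total from 256 KiB to 192 KiB, which is the kind of reduction item 1 describes.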
@Aya-ZIbra has exported this pull request. If you are a Meta employee, you can view the originating diff in D80041532.
8bead29 to 3b2e442
Summary: Pull Request resolved: pytorch#4853 X-link: facebookresearch/FBGEMM#1875 Add CUTLASS Blackwell FMHA decode kernel implementation to the TritonBench benchmarking suite. Reviewed By: sryap Differential Revision: D80041532
3b2e442 to 68081f5
Summary: X-link: pytorch/FBGEMM#4853 X-link: facebookresearch/FBGEMM#1875 Add CUTLASS Blackwell FMHA decode kernel implementation to the TritonBench benchmarking suite. Reviewed By: sryap Differential Revision: D80041532
Summary: X-link: meta-pytorch/tritonbench#376 Pull Request resolved: pytorch#4853 X-link: facebookresearch/FBGEMM#1875 Add CUTLASS Blackwell FMHA decode kernel implementation to the TritonBench benchmarking suite. Reviewed By: sryap Differential Revision: D80041532
68081f5 to 5589cbb
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1875
Add CUTLASS Blackwell FMHA decode kernel implementation to the TritonBench benchmarking suite.
Reviewed By: sryap
Differential Revision: D80041532