[Common] Support scaled & clamped swiglu, srelu for BF16 #3132
zhongbozhu wants to merge 6 commits into
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Greptile Summary
This PR adds six new CUDA kernels (
Confidence Score: 4/5
Safe to merge; the new kernels are mathematically consistent with the existing utility functions and the test suite covers the primary code paths for both contiguous and interleaved GLU layouts. The core kernel math, alignment dispatch, and block reduction are correct. The only items worth addressing before shipping are: FP16 is absent from the test dtype sweep even though the dispatch macro includes it, the one-block-per-row launch casts
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
API_FWD["nvte_scaled_swiglu /\nnvte_scaled_clamped_swiglu /\nnvte_scaled_srelu"]
API_BWD["nvte_scaled_dswiglu /\nnvte_scaled_clamped_dswiglu /\nnvte_scaled_dsrelu"]
API_FWD --> CHK_GATED_FWD{Gated?}
CHK_GATED_FWD -- "SwiGLU / ClampedSwiGLU" --> ALIGN_FWD[check alignment & segment layout]
CHK_GATED_FWD -- "SReLU" --> ALIGN_SRELU_FWD[check alignment]
ALIGN_FWD -- "aligned" --> KFG_VEC["scaled_gated_forward_kernel nvec>1"]
ALIGN_FWD -- "unaligned" --> KFG_SCAL["scaled_gated_forward_kernel nvec=1"]
ALIGN_SRELU_FWD -- "aligned" --> KSF_VEC["scaled_srelu_forward_kernel nvec>1"]
ALIGN_SRELU_FWD -- "unaligned" --> KSF_SCAL["scaled_srelu_forward_kernel nvec=1"]
API_BWD --> CHK_GATED_BWD{Gated?}
CHK_GATED_BWD -- "SwiGLU / ClampedSwiGLU" --> CHK_SCALE_G[grad_act_scales?]
CHK_GATED_BWD -- "SReLU" --> CHK_SCALE_S[grad_act_scales?]
CHK_SCALE_G -- "null" --> KGB_FLAT["scaled_gated_backward_kernel flat grid"]
CHK_SCALE_G -- "present" --> KGB_RED["scaled_gated_backward_with_scale_grad_kernel one block per row + warp reduction"]
CHK_SCALE_S -- "null" --> KSB_FLAT["scaled_srelu_backward_kernel flat grid"]
CHK_SCALE_S -- "present" --> KSB_RED["scaled_srelu_backward_with_scale_grad_kernel one block per row + warp reduction"]
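The dispatch decisions in the flowchart can be sketched on the host side as follows. This is a minimal illustration, not the PR's code: `can_vectorize`, `pick_nvec`, `pick_backward_path`, and the 16-byte vector width are all assumptions made for the example, and the extra segment-layout check for gated activations is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: true when `ptr` is aligned for `nvec`-wide vector access
// and every row holds a whole number of vectors.
bool can_vectorize(const void* ptr, std::size_t row_elems, std::size_t elem_bytes,
                   std::size_t nvec) {
  const auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  return addr % (nvec * elem_bytes) == 0 && row_elems % nvec == 0;
}

// Forward path: choose the vectorized kernel (nvec > 1) when both buffers allow it,
// otherwise fall back to the scalar instantiation (nvec = 1).
// Assumption: vector accesses are 16 bytes wide.
std::size_t pick_nvec(const void* in, const void* out, std::size_t row_elems,
                      std::size_t elem_bytes) {
  constexpr std::size_t kVecBytes = 16;
  const std::size_t nvec = kVecBytes / elem_bytes;
  const bool aligned = can_vectorize(in, row_elems, elem_bytes, nvec) &&
                       can_vectorize(out, row_elems, elem_bytes, nvec);
  return aligned ? nvec : 1;
}

// Backward path: a flat grid suffices when no scale gradient is requested;
// otherwise each row needs its own block so partial sums can be reduced.
enum class BackwardPath { FlatGrid, OneBlockPerRow };

BackwardPath pick_backward_path(const float* grad_act_scales) {
  return grad_act_scales == nullptr ? BackwardPath::FlatGrid
                                    : BackwardPath::OneBlockPerRow;
}
```

With 2-byte BF16 elements the sketch selects 8 elements per vector when the pointers are 16-byte aligned and the row length is a multiple of 8, and 1 otherwise.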
gated_unscaled computes unscaled on line 170, but gated_grads unconditionally writes *unscaled on line 171, overwriting it. The first call is dead code — every gated_grads case sets *unscaled before returning, so the result of gated_unscaled is never observed. This should simply be removed.
This comment is either outdated or imprecise.
You're right that "dead code" was imprecise. The more accurate characterization: gated_unscaled is a redundant computation — its return value is assigned to unscaled, but gated_grads immediately overwrites *unscaled in all three switch cases before returning, so the first assignment is discarded. The function itself is reachable and executes, but its result is never observed.
The substance still holds: the call to gated_unscaled on line 170 can be removed without any behavioral change, since gated_grads on line 171 already computes and sets unscaled as a side-output needed for the scale gradient. Keeping both creates a mild readability issue — a reader might assume the two calls serve different purposes.
A cleaner form would be:
}
}
gated_grads(activation, act_in, linear_in, &dact, &dlinear, &unscaled);
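The redundancy described in this thread can be reproduced with a small host-side sketch. It assumes plain SwiGLU only; `gated_unscaled_sketch` and `gated_grads_sketch` are hypothetical stand-ins for the functions named in the review, with the `activation` selector, the upstream gradient, and the per-row scale left out.

```cpp
#include <cassert>
#include <cmath>

float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Hypothetical stand-in for gated_unscaled: SwiGLU output before scaling.
float gated_unscaled_sketch(float act_in, float linear_in) {
  return act_in * sigmoid(act_in) * linear_in;
}

// Hypothetical stand-in for gated_grads: writes both partial derivatives and,
// as a side output, the unscaled activation needed for the scale gradient.
void gated_grads_sketch(float act_in, float linear_in, float* dact, float* dlinear,
                        float* unscaled) {
  const float s = sigmoid(act_in);
  const float act = act_in * s;                           // silu(act_in)
  const float dsilu = s * (1.0f + act_in * (1.0f - s));   // d silu / d act_in
  *dact = dsilu * linear_in;
  *dlinear = act;
  *unscaled = act * linear_in;  // always written, so any earlier value is lost
}

struct Grads {
  float dact, dlinear, unscaled;
};

// Mirrors the pattern flagged in the review: the first assignment is discarded.
Grads with_redundant_call(float act_in, float linear_in) {
  Grads g{};
  g.unscaled = gated_unscaled_sketch(act_in, linear_in);
  gated_grads_sketch(act_in, linear_in, &g.dact, &g.dlinear, &g.unscaled);
  return g;
}

// Mirrors the suggested form: a single call produces all three values.
Grads without_redundant_call(float act_in, float linear_in) {
  Grads g{};
  gated_grads_sketch(act_in, linear_in, &g.dact, &g.dlinear, &g.unscaled);
  return g;
}
```

Both variants return identical values for any input, which is why dropping the first call leaves behavior unchanged in this simplified model.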
/te-ci pytorch
Description
Supports Mega-C++ with the cuBLAS BF16 grouped GEMM backend: #3099
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: