[PyTorch] Support for cuDNN-backed flex attention #2984
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Greptile Summary
This PR adds experimental cuDNN-backed Flex Attention to
Confidence Score: 5/5
The PR is safe to merge. No code path executes incorrect attention computation under the supported configurations, and the gate assertions in DotProductAttention block every unsupported combination before reaching cuDNN.
The core cuDNN graph construction, execution, caching, and autograd wiring are all correct. The two findings are both non-blocking style/documentation concerns: the flash-attention master-switch helper relies on a late AND-reduction that works correctly today, and the silent gradient drop for requires_grad tensors in score_mod dicts is a footgun to document rather than a runtime defect for the described use cases.
flex_attention.py (requires_grad footgun for score_mod_tensors) and utils.py (_disable_all_flash_attention robustness) are the two files worth a second look before follow-on work extends the backend-selection filter chain.
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant DPA as DotProductAttention
participant BS as get_attention_backend
participant FA as FusedAttention
participant Func as FusedAttentionWithScoreModFunc
participant Cache as _cudnn_score_mod_graph_cache
participant cuDNN as cuDNN Frontend
User->>DPA: "forward(q,k,v, score_mod=..., score_mod_tensors=...)"
DPA->>BS: "get_attention_backend(has_score_mod=True)"
BS-->>DPA: "use_fused_attention=True (F16_arbitrary_seqlen only)"
DPA->>FA: "forward(q,k,v, score_mod=..., score_mod_tensors=...)"
FA->>Func: apply(is_training, q,k,v, score_mod, ...)
Func->>Cache: _get_cudnn_score_mod_fwd_graph(key)
alt Cache miss
Cache->>cuDNN: "build pygraph + sdpa(score_mod=wrapped_cb)"
cuDNN->>User: score_mod(graph, score_tensor, tensors) to score
cuDNN-->>Cache: compiled graph entry
Cache-->>Func: _CudnnScoreModFwdGraphEntry
else Cache hit
Cache-->>Func: cached entry
end
Func->>cuDNN: "execute(variant_pack={q,k,v,output,stats,score_mod_tensors})"
cuDNN-->>Func: output, stats
Func-->>User: output
User->>Func: backward(d_out)
Func->>Cache: _get_cudnn_score_mod_bwd_graph(key)
alt Cache miss
Cache->>cuDNN: build pygraph + sdpa_backward(score_mod, score_mod_bprop)
cuDNN->>User: score_mod_bprop(graph, dP, tensors) to dP
cuDNN-->>Cache: compiled backward graph entry
end
Func->>cuDNN: "execute(variant_pack={q,k,v,o,dO,stats,dq,dk,dv})"
cuDNN-->>Func: dq, dk, dv
Func-->>User: dq, dk, dv
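The cache-miss / cache-hit branches in the diagram can be sketched in plain Python. This is only an illustration: `GraphCache`, `make_key`, and the key layout are hypothetical names chosen here, and they are not the PR's `_cudnn_score_mod_graph_cache` implementation.

```python
from typing import Any, Callable, Dict, Hashable, Tuple


class GraphCache:
    """Hypothetical stand-in for a compiled-graph cache keyed by configuration."""

    def __init__(self) -> None:
        self._entries: Dict[Hashable, Any] = {}
        self.builds = 0  # number of cache misses, i.e. graph builds

    def get(self, key: Hashable, build_graph: Callable[[], Any]) -> Any:
        entry = self._entries.get(key)
        if entry is None:
            # Cache miss: build (and, in a real backend, compile) the graph once.
            entry = build_graph()
            self._entries[key] = entry
            self.builds += 1
        return entry


def make_key(shape: Tuple[int, ...], dtype: str, score_mod: Callable) -> Hashable:
    # Assumption: shape, dtype, and callback identity are enough to tell graphs apart.
    return (shape, dtype, score_mod)


def identity_score_mod(graph, score, tensors):
    # Placeholder callback; it is only used as part of the cache key here.
    return score


cache = GraphCache()
key = make_key((2, 8, 128, 64), "bf16", identity_score_mod)
first = cache.get(key, lambda: {"graph": "compiled"})
second = cache.get(key, lambda: {"graph": "compiled"})  # served from the cache
```

Keying on the callback object itself only works when the callback is stateless; a callback whose behavior depends on mutable state would need an explicit key, which appears to be the role of `score_mod_graph_cache_key()` in this PR.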
Reviews (13): Last reviewed commit: "Address Flex Attention review comments"
Thanks for creating this PR, @vcherepanov-nv. I was curious about:
Thanks for the thorough review!
I haven't done any benchmarking. Reportedly (from a Slack thread), score_mod can lead to significant perf gains if it avoids mask materialization. For causal, I think I observed cuDNN choosing exactly the same kernel with score_mod as with the explicit causal flag.
Sure, thanks for linking!
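To make the mask-materialization point concrete, here is a small pure-Python reference, not the cuDNN path. It assumes an index-based callback `score_mod(score, q_idx, kv_idx)`, which is a hypothetical signature used only for illustration; the PR's callback takes `(graph, score, tensors)` and operates on graph tensors. The sketch shows that a causal rule applied per score gives the same probabilities as a fully materialized boolean mask.

```python
import math
from typing import Callable, List

NEG_INF = float("-inf")


def softmax(row: List[float]) -> List[float]:
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def probs_with_mask(scores: List[List[float]], mask: List[List[bool]]) -> List[List[float]]:
    # Needs a full [seq_q, seq_kv] boolean mask to exist in memory.
    return [
        softmax([s if keep else NEG_INF for s, keep in zip(row, mask_row)])
        for row, mask_row in zip(scores, mask)
    ]


def probs_with_score_mod(
    scores: List[List[float]], score_mod: Callable[[float, int, int], float]
) -> List[List[float]]:
    # The masking rule is evaluated per score from the indices; no mask is stored.
    return [
        softmax([score_mod(s, i, j) for j, s in enumerate(row)])
        for i, row in enumerate(scores)
    ]


def causal(score: float, q_idx: int, kv_idx: int) -> float:
    return score if kv_idx <= q_idx else NEG_INF


scores = [[0.1, 0.5, -0.3], [1.0, 0.2, 0.7], [0.0, -1.0, 2.0]]
mask = [[j <= i for j in range(3)] for i in range(3)]
masked = probs_with_mask(scores, mask)
modded = probs_with_score_mod(scores, causal)
```

Both variants compute identical probabilities; the difference is only whether the mask has to be allocated and read as a separate tensor.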
Thanks for the PR! A few comments:
0. Agree with all the comments from @KshitijLakhani and @cyanguwa, so just +1ed them.
- A user doc specifying the design choices and the building blocks of graph caching would be valuable.
- score_mod seems like an argument more than a feature, so the error messaging could use something more substantial like "(TE/cuDNN) Flex Attention".
- New arguments of the form has_* in AttentionParams could be avoided. If passing score_mod, score_mod_tensors (which are hefty) is the blocker, could we create an encapsulating dataclass and pass that instead?
- user_supplied_seqlens is a bit vague; it seems like just a derived variable. Does it degenerate to mean pad_between_seqs=True?
Among other nits.
or cu_seqlens_q_padded is not None
or cu_seqlens_kv_padded is not None
)
This can be removed and replaced by checking whether "padding" is in attention_mask_type, because that is when those tensors are used (i.e. THD, or non-THD with a padding_xxx mask).
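A minimal sketch of the suggested check, assuming the mask type is a string such as "no_mask", "causal", "padding", or "padding_causal". The helper name `mask_uses_seqlens` is hypothetical and does not appear in the PR.

```python
def mask_uses_seqlens(attn_mask_type: str) -> bool:
    """Hypothetical helper: True when the mask type implies cu_seqlens-style tensors are used."""
    return "padding" in attn_mask_type


# A score_mod gate could then reject any padding mask in one place,
# instead of inspecting each cu_seqlens tensor for None.
checks = {
    name: mask_uses_seqlens(name)
    for name in ("no_mask", "causal", "padding", "padding_causal")
}
```

This relies on the assumption stated in the comment above: THD inputs always come with a padding mask type, so one string check covers both cases.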
use_flash_attention = False
use_flash_attention_2 = False
use_flash_attention_3 = False
use_flash_attention_4 = False
Wouldn't use_flash_attention=False disable all use_flash_attention_x? I thought the relationship was use_flash_attention=True when one of use_flash_attention_x is True, but when we set use_flash_attention=False, we're effectively disabling all use_flash_attention_x.
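The relationship in question can be sketched as a master flag that is AND-ed with the per-version flags at the end of backend selection. This is a hypothetical model for discussion: `FlashFlags` and `resolved()` are names invented here, and the actual reduction in utils.py may differ.

```python
from dataclasses import dataclass


@dataclass
class FlashFlags:
    """Hypothetical model of the master switch and per-version flags."""

    use_flash_attention: bool = True
    use_flash_attention_2: bool = True
    use_flash_attention_3: bool = True
    use_flash_attention_4: bool = True

    def resolved(self) -> "FlashFlags":
        # Late AND-reduction: a version is usable only if the master switch is on,
        # and the master switch stays on only if some version survives.
        master = self.use_flash_attention
        v2 = master and self.use_flash_attention_2
        v3 = master and self.use_flash_attention_3
        v4 = master and self.use_flash_attention_4
        return FlashFlags(master and (v2 or v3 or v4), v2, v3, v4)


only_master_off = FlashFlags(use_flash_attention=False)
after_reduction = only_master_off.resolved()
```

Under this model, clearing only the master flag is sufficient once the reduction has run, but any code that reads `use_flash_attention_2` before the reduction still sees `True`. Clearing every flag explicitly removes that ordering dependency.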
if any(cudnn_frontend_package.glob("_compiled_module*")):
    if cudnn_frontend_path not in sys.path:
        sys.path.insert(0, cudnn_frontend_path)
    return importlib.import_module("cudnn")
Do we expect users to compile in the 3rdparty/cudnn-frontend directory first before using this feature? i.e. how do they get the _compiled_module? Do we need to set this up in our setup.py file so users won't have this issue?
This was just a question of submodule FE vs system FE that we discussed on Slack.
Will investigate this in a follow-up PR.
    stats: Optional[torch.Tensor],
) -> _CudnnScoreModFwdGraphEntry:
    """Build a cached cuDNN frontend graph for score_mod fprop."""
    cudnn = _import_cudnn_frontend()
We're calling this in pretty much every function in this file. We could do this once at the top of the file. Thanks.
So what's your concern here? Python caches repeated importlib.import_module("cudnn").
I think it's written like this because backends.py imports flex_attention.py eagerly, and cudnn python package might be optional?
CPU overhead. I think this could be called once at the beginning of the file?
Will investigate this in a follow-up PR.
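For reference, the two positions in this thread can be illustrated with the standard library. The sketch uses `json` as a stand-in for the optional `cudnn` package, and `import_optional` is a hypothetical helper, not the PR's `_import_cudnn_frontend`.

```python
import functools
import importlib
import sys


@functools.lru_cache(maxsize=None)
def import_optional(name: str):
    """Hypothetical lazy import: resolve once, return None if the package is missing."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# Repeated import_module calls return the module object already held in sys.modules,
# so the module body is not executed again.
first = importlib.import_module("json")
second = importlib.import_module("json")
same_object = first is second and sys.modules["json"] is first

# The memoized wrapper keeps the import lazy (safe for an optional dependency)
# while reducing later calls to a single cache lookup.
lazy = import_optional("json")
missing = import_optional("definitely_not_installed_module_xyz")
```

A repeated `import_module` call is not free, since it still goes through the import machinery's lookup in `sys.modules`, but it does not re-import. Memoizing the helper is one way to keep the dependency optional and still avoid most of the per-call overhead; whether that overhead matters here has not been measured in this thread.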
Also, if there's any way to make the feature easier for users to understand.
/te-ci PyTorch L0
Description
Adds experimental PyTorch support for cuDNN-backed flex attention in DotProductAttention via a new score_mod callback path.
Users can pass:
- score_mod(graph, score, tensors) -> score for forward score modification
- score_mod_bprop(graph, dP, tensors) -> dP for backward
When score_mod_bprop is supplied, it is the user's responsibility to make it mathematically consistent with score_mod. TE forwards this callback to cuDNN as provided and does not derive or validate the backward score transformation automatically.
Supported score_mod configuration
The current cuDNN-backed Flex Attention path supports:
- DotProductAttention / FusedAttention
- torch.Tensor Q/K/V inputs
- attn_mask_type="no_mask"
- core_attention_bias_type="no_bias" with no explicit bias tensor
- attention_dropout=0.0
- num_splits=1
The path is currently not supported with FP8, fp8_output, THD format, explicit cu_seqlens inputs, pad_between_seqs, attention masks, attention bias, ALiBi, sliding-window attention, sink attention, dropout, KV cache, context parallelism, CUDA graph capture, checkpointed core attention, or return_max_logit.
For deterministic execution, TE passes the deterministic setting through backend selection and forwards it to cuDNN Frontend sdpa_backward as use_deterministic_algorithm. The score_mod forward sdpa call does not take a separate deterministic flag.
Fixes #2492.
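Since TE does not validate that score_mod_bprop matches score_mod, a quick numerical check outside the graph API can catch mistakes early. The sketch below is a scalar stand-in, not the cuDNN callback form: the soft-cap transform, the `CAP` value, and the function signatures (which take plain floats, and pass the score to the backward function explicitly) are all assumptions made for illustration.

```python
import math

CAP = 30.0  # hypothetical soft-cap value


def score_mod(score: float) -> float:
    """Forward: soft-cap the attention score."""
    return CAP * math.tanh(score / CAP)


def score_mod_bprop(d_out: float, score: float) -> float:
    """Backward: chain rule through the soft-cap, d/ds = 1 - tanh(s / CAP)^2."""
    t = math.tanh(score / CAP)
    return d_out * (1.0 - t * t)


def finite_difference(fn, x: float, eps: float = 1e-4) -> float:
    # Central difference approximation of fn'(x).
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)


# Compare the hand-written backward against the numerical derivative of the forward.
max_error = max(
    abs(score_mod_bprop(1.0, s) - finite_difference(score_mod, s))
    for s in (-50.0, -5.0, 0.0, 3.0, 40.0)
)
```

If the backward transform were written inconsistently, for example by omitting the `1 - tanh^2` factor, `max_error` would be large rather than near zero, which is the kind of mismatch that would otherwise silently produce wrong gradients.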
Type of change
Changes
- FusedAttentionWithScoreModFunc, a cuDNN frontend Python graph path for SDPA forward/backward with score_mod and score_mod_bprop.
- DotProductAttention / FusedAttention APIs with score_mod, score_mod_bprop, score_mod_tensors, and score_mod_bprop_tensors.
- score_mod only selects supported cuDNN fused attention configurations.
- score_mod_graph_cache_key() for stateful callbacks, while leaving unsafe unkeyed bound methods uncached.
Checklist: