
[PyTorch] Support for cuDNN-backed flex attention #2984

Merged
vcherepanov-nv merged 30 commits into NVIDIA:main from vcherepanov-nv:cudnn-score-mod-3 on Jun 4, 2026

Conversation

@vcherepanov-nv

@vcherepanov-nv vcherepanov-nv commented May 13, 2026

Collaborator

Description

Adds experimental PyTorch support for cuDNN-backed flex attention in DotProductAttention via a new score_mod callback path.

Users can pass:

  • score_mod(graph, score, tensors) -> score for forward score modification
  • optional score_mod_bprop(graph, dP, tensors) -> dP for backward
  • optional runtime tensor dictionaries for forward/backward score-mod graph inputs

When score_mod_bprop is supplied, it is the user's responsibility to make it mathematically consistent with score_mod. TE forwards this callback to cuDNN as provided and does not derive or validate the backward score transformation automatically.
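As a rough illustration of what "mathematically consistent" means here, the sketch below checks a softcap-style score modification against its hand-written derivative using finite differences. It is only a scalar reference in plain Python: the real callbacks operate on cuDNN graph tensors, and the names `softcap_forward`, `softcap_backward`, and the cap value are assumptions made for this example, not part of the PR.

```python
import math

CAP = 30.0  # illustrative softcap value, not taken from the PR


def softcap_forward(score: float, cap: float = CAP) -> float:
    """Forward score modification: squash the score into (-cap, cap)."""
    return cap * math.tanh(score / cap)


def softcap_backward(d_modified: float, score: float, cap: float = CAP) -> float:
    """Backward counterpart: chain rule through cap * tanh(score / cap)."""
    t = math.tanh(score / cap)
    return d_modified * (1.0 - t * t)


def finite_difference(fn, x: float, eps: float = 1e-5) -> float:
    """Central-difference estimate of d fn / d x."""
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)


def is_consistent(score: float, tol: float = 1e-6) -> bool:
    """True when the analytic backward matches the numeric derivative."""
    analytic = softcap_backward(1.0, score)
    numeric = finite_difference(softcap_forward, score)
    return abs(analytic - numeric) <= tol


if __name__ == "__main__":
    for s in (-50.0, -1.0, 0.0, 2.5, 80.0):
        print(s, is_consistent(s))
```

A check of this kind can be run on a CPU reference before wiring the equivalent graph operations into the two callbacks, since a mismatch would otherwise surface only as silently wrong gradients.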

Supported score_mod configuration

The current cuDNN-backed Flex Attention path supports:

  • PyTorch DotProductAttention / FusedAttention
  • FP16 or BF16 unquantized torch.Tensor Q/K/V inputs
  • SBHD or BSHD Q/K/V layouts
  • cuDNN F16/BF16 arbitrary-seqlen fused attention backend
  • attn_mask_type="no_mask"
  • core_attention_bias_type="no_bias" with no explicit bias tensor
  • vanilla softmax
  • attention_dropout=0.0
  • num_splits=1

The path is currently not supported with FP8, fp8_output, THD format, explicit cu_seqlens inputs, pad_between_seqs, attention masks, attention bias, ALiBi, sliding-window attention, sink attention, dropout, KV cache, context parallelism, CUDA graph capture, checkpointed core attention, or return_max_logit.
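The supported set above can be read as a conjunction of simple checks. The following sketch collects the reasons a request would be rejected; it covers only a subset of the listed restrictions. `FlexAttentionRequest`, `unsupported_reasons`, and the `softmax_type` field are hypothetical names for this illustration and do not reflect how TE structures its own validation.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FlexAttentionRequest:
    """Hypothetical bundle of some options named in the PR description."""

    dtype: str = "bf16"  # "fp16" or "bf16"
    qkv_format: str = "bshd"  # "sbhd", "bshd", or "thd"
    attn_mask_type: str = "no_mask"
    core_attention_bias_type: str = "no_bias"
    softmax_type: str = "vanilla"  # assumed field name
    attention_dropout: float = 0.0
    num_splits: int = 1
    fp8: bool = False
    context_parallel: bool = False


def unsupported_reasons(req: FlexAttentionRequest) -> List[str]:
    """Collect every reason the request falls outside the supported set."""
    reasons: List[str] = []
    if req.fp8 or req.dtype not in ("fp16", "bf16"):
        reasons.append("inputs must be unquantized FP16 or BF16")
    if req.qkv_format not in ("sbhd", "bshd"):
        reasons.append("layout must be SBHD or BSHD")
    if req.attn_mask_type != "no_mask":
        reasons.append('attn_mask_type must be "no_mask"')
    if req.core_attention_bias_type != "no_bias":
        reasons.append('core_attention_bias_type must be "no_bias"')
    if req.softmax_type != "vanilla":
        reasons.append("softmax must be vanilla")
    if req.attention_dropout != 0.0:
        reasons.append("attention_dropout must be 0.0")
    if req.num_splits != 1:
        reasons.append("num_splits must be 1")
    if req.context_parallel:
        reasons.append("context parallelism is not supported")
    return reasons
```

Returning all reasons at once, rather than failing on the first, is a design choice for this sketch; it makes it easier for a caller to see every setting that needs to change.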

For deterministic execution, TE passes the deterministic setting through backend selection and forwards it to cuDNN Frontend sdpa_backward as use_deterministic_algorithm. The score_mod forward sdpa call does not take a separate deterministic flag.

Fixes #2492.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Adds FusedAttentionWithScoreModFunc, a cuDNN frontend Python graph path for SDPA forward/backward with score_mod and score_mod_bprop.
  • Extends DotProductAttention / FusedAttention APIs with score_mod, score_mod_bprop, score_mod_tensors, and score_mod_bprop_tensors.
  • Adds backend-selection filtering so score_mod only selects supported cuDNN fused attention configurations.
  • Adds execution-plan caching for forward and backward score-mod graphs, keyed by tensor metadata, layout, scale, callback topology, and runtime tensor metadata.
  • Supports explicit score_mod_graph_cache_key() for stateful callbacks, while leaving unsafe unkeyed bound methods uncached.
  • Executes cuDNN graphs on PyTorch's current CUDA stream and preserves SBHD/BSHD layouts without extra BHSD copies.
  • Adds validation for unsupported combinations including FP8, context parallelism, THD, KV cache, explicit masks/biases, dropout, non-vanilla softmax, CUDA graph capture, and checkpointed core attention.
  • Adds tests for cache-key behavior, unsafe callback caching, runtime tensor version checking, and CUDA correctness cases covering causal masking, softcap, and post-scale-bias-style score modification.
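To make the caching bullets more concrete, here is a minimal sketch of one way a callback-aware cache key could be derived. It is not the PR's implementation: `callback_cache_key`, `get_or_build`, and the example callbacks are hypothetical, and only the method name `score_mod_graph_cache_key` comes from the description above. The assumed policy is that plain functions key on the function object, callbacks that expose an explicit key use it, and anything else that might carry hidden state is never cached.

```python
import inspect
from typing import Any, Callable, Dict, Hashable, Optional, Tuple

_graph_cache: Dict[Hashable, Any] = {}


def callback_cache_key(callback: Callable[..., Any]) -> Optional[Hashable]:
    """Return a hashable key for the callback, or None when caching is unsafe."""
    owner = callback.__self__ if inspect.ismethod(callback) else callback
    explicit = getattr(owner, "score_mod_graph_cache_key", None)
    if callable(explicit):
        # Stateful callback that opted in: trust its own description of its state.
        return ("explicit", type(owner), getattr(callback, "__func__", None), explicit())
    if inspect.isfunction(callback):
        # Plain functions and lambdas: the function object itself is the key.
        return ("function", callback)
    # Bound methods or callable instances without an explicit key may hide state.
    return None


def get_or_build(metadata: Tuple[Any, ...], callback: Callable[..., Any],
                 build: Callable[[], Any]) -> Any:
    """Reuse a built graph when the key is safe; otherwise rebuild every time."""
    cb_key = callback_cache_key(callback)
    if cb_key is None:
        return build()
    key = (metadata, cb_key)
    if key not in _graph_cache:
        _graph_cache[key] = build()
    return _graph_cache[key]


def identity_mod(graph, score, tensors):
    """Stateless example callback."""
    return score


class Softcap:
    """Stateful example callback that declares its own cache key."""

    def __init__(self, cap: float) -> None:
        self.cap = cap

    def __call__(self, graph, score, tensors):
        return score  # placeholder body; a real callback would add graph ops

    def score_mod_graph_cache_key(self) -> Hashable:
        return ("softcap", self.cap)


class Unkeyed:
    """Example of a bound method with no explicit key (left uncached)."""

    def mod(self, graph, score, tensors):
        return score
```

Under these assumptions, two `Softcap(30.0)` instances share a cache entry while `Softcap(50.0)` gets its own, and `Unkeyed().mod` triggers a rebuild on every call.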

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps

greptile-apps Bot commented May 13, 2026

Contributor

Greptile Summary

This PR adds experimental cuDNN-backed Flex Attention to DotProductAttention and FusedAttention via a score_mod callback path. A new FusedAttentionWithScoreModFunc autograd Function builds and caches cuDNN frontend Python graphs for forward and backward, with a sophisticated cache-key scheme that safely handles lambdas, bound methods, and stateful callable instances.

  • Introduces flex_attention.py with cuDNN graph construction, execution, and cache-key logic; extends DotProductAttention / FusedAttention APIs with score_mod, score_mod_bprop, score_mod_tensors, and score_mod_bprop_tensors parameters.
  • Adds backend-selection filtering in get_attention_backend gating score_mod to the F16_arbitrary_seqlen fused-attention sub-backend only, and disabling FlashAttention and unfused paths.
  • Ships a comprehensive test suite covering cache-key correctness, version-counter safety, and CUDA correctness for causal, softcap, and post-scale-bias score modifications.

Confidence Score: 5/5

The PR is safe to merge. No code path executes incorrect attention computation under the supported configurations, and the gate assertions in DotProductAttention block every unsupported combination before reaching cuDNN.

The core cuDNN graph construction, execution, caching, and autograd wiring are all correct. The two findings are both non-blocking style/documentation concerns: the flash-attention master-switch helper relies on a late AND-reduction that works correctly today, and the silent gradient drop for requires_grad tensors in score_mod dicts is a footgun to document rather than a runtime defect for the described use cases.

flex_attention.py (requires_grad footgun for score_mod_tensors) and utils.py (_disable_all_flash_attention robustness) are the two files worth a second look before follow-on work extends the backend-selection filter chain.
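For the requires_grad concern, a caller-side guard might look like the sketch below. It is a hypothetical helper, not something the PR is known to include: it duck-types on a `requires_grad` attribute so the example runs without PyTorch, and uses `SimpleNamespace` objects as stand-ins for tensors.

```python
from types import SimpleNamespace
from typing import Any, List, Mapping, Optional


def tensors_requiring_grad(tensors: Optional[Mapping[str, Any]]) -> List[str]:
    """Return the names of dictionary entries that ask for gradients."""
    if not tensors:
        return []
    return sorted(
        name for name, tensor in tensors.items()
        if getattr(tensor, "requires_grad", False)
    )


if __name__ == "__main__":
    # Stand-ins for tensors; with PyTorch these would be torch.Tensor objects.
    score_mod_tensors = {
        "bias": SimpleNamespace(requires_grad=True),
        "scale": SimpleNamespace(requires_grad=False),
    }
    flagged = tensors_requiring_grad(score_mod_tensors)
    if flagged:
        print("no gradients will flow to:", ", ".join(flagged))
```

Running such a check before the forward call would turn the silent gradient drop into a visible warning.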

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/flex_attention.py | New cuDNN frontend Python SDPA path: graph construction, caching (with safe lambda/bound-method keying), execution, and autograd Function. Core logic is sound; requires_grad on score_mod_tensors silently produces no gradients. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Adds score_mod backend filtering and new AttentionParams fields; fixes fp8_meta None guard in `__eq__`. _disable_all_flash_attention only directly disables the master flag, relying on a late AND-reduction; this works today but is fragile. |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Adds score_mod dispatch branch in FusedAttention.forward; correctly placed as elif before the standard FusedAttnFunc path. |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Extends DotProductAttention API with score_mod parameters; validation block correctly asserts structural constraints including score_mod_bprop_tensors requiring score_mod_bprop. |
| tests/pytorch/attention/test_flex_attention.py | New test file; covers cache key correctness, version-counter safety, and CUDA correctness for causal/softcap/post-scale-bias cases. CPU tensors passed to CUDA cuDNN score_mod_tensors (noted in prior threads). |
| tests/pytorch/utils.py | Adds has_score_mod/has_score_mod_bprop flags to get_available_attention_backends; straightforward addition. |
| qa/L0_pytorch_unittest/test.sh | Adds test_flex_attention.py to the CI test suite. |

Sequence Diagram

sequenceDiagram
    participant User
    participant DPA as DotProductAttention
    participant BS as get_attention_backend
    participant FA as FusedAttention
    participant Func as FusedAttentionWithScoreModFunc
    participant Cache as _cudnn_score_mod_graph_cache
    participant cuDNN as cuDNN Frontend

    User->>DPA: "forward(q,k,v, score_mod=..., score_mod_tensors=...)"
    DPA->>BS: "get_attention_backend(has_score_mod=True)"
    BS-->>DPA: "use_fused_attention=True (F16_arbitrary_seqlen only)"
    DPA->>FA: "forward(q,k,v, score_mod=..., score_mod_tensors=...)"
    FA->>Func: apply(is_training, q,k,v, score_mod, ...)

    Func->>Cache: _get_cudnn_score_mod_fwd_graph(key)
    alt Cache miss
        Cache->>cuDNN: "build pygraph + sdpa(score_mod=wrapped_cb)"
        cuDNN->>User: score_mod(graph, score_tensor, tensors) to score
        cuDNN-->>Cache: compiled graph entry
        Cache-->>Func: _CudnnScoreModFwdGraphEntry
    else Cache hit
        Cache-->>Func: cached entry
    end

    Func->>cuDNN: "execute(variant_pack={q,k,v,output,stats,score_mod_tensors})"
    cuDNN-->>Func: output, stats
    Func-->>User: output

    User->>Func: backward(d_out)
    Func->>Cache: _get_cudnn_score_mod_bwd_graph(key)
    alt Cache miss
        Cache->>cuDNN: build pygraph + sdpa_backward(score_mod, score_mod_bprop)
        cuDNN->>User: score_mod_bprop(graph, dP, tensors) to dP
        cuDNN-->>Cache: compiled backward graph entry
    end
    Func->>cuDNN: "execute(variant_pack={q,k,v,o,dO,stats,dq,dk,dv})"
    cuDNN-->>Func: dq, dk, dv
    Func-->>User: dq, dk, dv

Reviews (13): Last reviewed commit: "Address Flex Attention review comments"

Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread tests/pytorch/attention/test_attention.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread tests/pytorch/attention/test_attention.py Outdated
vcherepanov-nv and others added 3 commits May 15, 2026 00:48
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread tests/pytorch/attention/test_attention.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread tests/pytorch/attention/test_attention.py Outdated
@KshitijLakhani

Collaborator

Thanks for creating this PR, @vcherepanov-nv.
This is great!

I was curious about:

  1. Do you have benchmark numbers based on any toy test cases you might have run? It would be good to have those in here for users of the API.
    1. native PyT flex vs TE PyT flex
    2. traditional causal TE via cuDNN vs flex-expressed causal TE via cuDNN
  2. I've linked the GH issue in the PR description. Could you please update or close it as appropriate when this PR is merged?
    Thanks!

vcherepanov-nv and others added 2 commits May 19, 2026 21:17
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread tests/pytorch/attention/test_flex_attention.py
Comment thread tests/pytorch/attention/test_flex_attention.py
@vcherepanov-nv

Collaborator Author

Thanks for the thorough review!

  1. Do you have benchmark numbers based on any toy test cases you might have run? It would be good to have those in here for users of the API.

    1. native PyT flex vs TE PyT flex
    2. traditional causal TE via cuDNN vs flex expressed causal TE via cuDNN

I haven't done any benchmarking. Reportedly (from a Slack thread), score_mod can lead to significant perf gains when it avoids mask materialization. For causal, I think I observed cuDNN choosing exactly the same kernel with score_mod as with the explicit causal flag.
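To illustrate what "avoiding mask materialization" refers to, the pure-Python sketch below computes causal attention probabilities two ways: by adding a fully built additive mask, and by applying an index-based rule to each score on the fly. The per-element signature `causal_score_mod(score, q_idx, kv_idx)` is a simplification chosen for this example; it differs from the PR's `score_mod(graph, score, tensors)` callback, and nothing here measures performance.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]


def softmax_rows(scores: Matrix) -> Matrix:
    """Numerically stable softmax applied to each row."""
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def with_materialized_mask(scores: Matrix) -> Matrix:
    """Build the full additive causal mask, then add it to the scores."""
    n_q, n_kv = len(scores), len(scores[0])
    mask = [[0.0 if kv <= q else -math.inf for kv in range(n_kv)] for q in range(n_q)]
    masked = [[s + m for s, m in zip(s_row, m_row)] for s_row, m_row in zip(scores, mask)]
    return softmax_rows(masked)


def causal_score_mod(score: float, q_idx: int, kv_idx: int) -> float:
    """Index-based rule: keep the score only for keys at or before the query."""
    return score if kv_idx <= q_idx else -math.inf


def with_score_mod(scores: Matrix, mod: Callable[[float, int, int], float]) -> Matrix:
    """Apply the rule element by element; no mask tensor is ever stored."""
    modified = [[mod(s, q, kv) for kv, s in enumerate(row)] for q, row in enumerate(scores)]
    return softmax_rows(modified)


if __name__ == "__main__":
    scores = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
    print(with_materialized_mask(scores) == with_score_mod(scores, causal_score_mod))
```

Both routes give the same probabilities; the difference is that the second never allocates a mask of size queries by keys.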

  2. I've linked the GH issue in the PR description. Could you please update / close it appropriately when this PR is merged?

Sure, thanks for linking!

@sudhakarsingh27 sudhakarsingh27 left a comment

Member

Thanks for the PR! A few comments:
0. Agree with all the comments from @KshitijLakhani and @cyanguwa, so I just +1'd them.

  1. A user doc specifying the design choices and the building blocks of graph caching would be valuable.
  2. score_mod seems more like an argument than a feature, so the error messages could use something more substantial, like "(TE/cuDNN) Flex Attention".
  3. New arguments of the form has_* in AttentionParams could be avoided. If passing score_mod and score_mod_tensors (which are hefty) is the blocker, could we create an encapsulating dataclass and pass that instead?
  4. user_supplied_seqlens is a bit vague; it seems like just a derived variable. Does it degenerate to mean pad_between_seqs=True?
    Among other nits.

Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread tests/pytorch/attention/test_flex_attention.py
Comment thread tests/pytorch/attention/test_flex_attention.py
Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
Comment thread tests/pytorch/attention/test_flex_attention.py
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@github-actions github-actions Bot added the community-contribution label May 21, 2026
or cu_seqlens_q_padded is not None
or cu_seqlens_kv_padded is not None
)

Collaborator

This can be removed and replaced by checking whether "padding" is in attention_mask_type, because that's when those tensors are used (i.e. THD, or non-THD with a padding_xxx mask).

use_flash_attention = False
use_flash_attention_2 = False
use_flash_attention_3 = False
use_flash_attention_4 = False

Collaborator

Wouldn't use_flash_attention=False disable all use_flash_attention_x? I thought the relationship was use_flash_attention=True when one of use_flash_attention_x is True, but when we set use_flash_attention=False, we're effectively disabling all use_flash_attention_x.
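The relationship being asked about can be sketched as follows. The flag names come from the snippet under review, while `BackendFlags`, `disable_master_only`, `reduce_with_master`, and `disable_all` are hypothetical helpers written for this illustration. The point is that clearing only the master flag leaves the per-version flags set until a later AND-reduction runs.

```python
from dataclasses import dataclass

_SUB_FLAGS = ("use_flash_attention_2", "use_flash_attention_3", "use_flash_attention_4")


@dataclass
class BackendFlags:
    """Hypothetical holder for the master flag and per-version flags."""

    use_flash_attention: bool = True
    use_flash_attention_2: bool = True
    use_flash_attention_3: bool = True
    use_flash_attention_4: bool = True


def disable_master_only(flags: BackendFlags) -> None:
    """Clear the master flag and leave the per-version flags untouched."""
    flags.use_flash_attention = False


def reduce_with_master(flags: BackendFlags) -> None:
    """Late AND-reduction: a per-version flag survives only if the master is set."""
    for name in _SUB_FLAGS:
        setattr(flags, name, getattr(flags, name) and flags.use_flash_attention)


def disable_all(flags: BackendFlags) -> None:
    """Clear every flag explicitly, with no dependence on a later reduction."""
    flags.use_flash_attention = False
    for name in _SUB_FLAGS:
        setattr(flags, name, False)


if __name__ == "__main__":
    flags = BackendFlags()
    disable_master_only(flags)
    print(flags.use_flash_attention_3)  # still set before the reduction
    reduce_with_master(flags)
    print(flags.use_flash_attention_3)  # cleared after the reduction
```

Any code that reads a per-version flag between the two steps would see a stale value, which is why setting all four flags explicitly is the more defensive option.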

Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/flex_attention.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/flex_attention.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@vcherepanov-nv vcherepanov-nv requested a review from cyanguwa June 2, 2026 20:19
Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
if any(cudnn_frontend_package.glob("_compiled_module*")):
if cudnn_frontend_path not in sys.path:
sys.path.insert(0, cudnn_frontend_path)
return importlib.import_module("cudnn")

@cyanguwa cyanguwa Jun 3, 2026

Collaborator

Do we expect users to compile in the 3rdparty/cudnn-frontend directory first before using this feature? i.e. how do they get the _compiled_module? Do we need to set this up in our setup.py file so users won't have this issue?

Collaborator

This was just a question of submodule FE vs system FE that we discussed on Slack.

Collaborator

Will investigate this in a follow-up PR.

    stats: Optional[torch.Tensor],
) -> _CudnnScoreModFwdGraphEntry:
    """Build a cached cuDNN frontend graph for score_mod fprop."""
    cudnn = _import_cudnn_frontend()

Collaborator

We're calling this in pretty much every function in this file. We could do this once at the top of the file. Thanks.

Collaborator Author

So what's your concern here? Python caches repeated importlib.import_module("cudnn").

Collaborator Author

I think it's written like this because backends.py imports flex_attention.py eagerly, and cudnn python package might be optional?

Collaborator

CPU overhead. I think this could be called once at the beginning of the file?

Collaborator

Will investigate this in a follow-up PR.

Also, it would help if there's any way to make the feature easier for users to understand.
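The trade-off discussed in this thread (a lazy import so the optional package is not required at module load, versus the per-call overhead of going through the import machinery) can be sketched with a memoized helper. `lazy_import` is a hypothetical name, and the standard-library `json` module stands in for the optional `cudnn` package so the example runs anywhere.

```python
import functools
import importlib
import sys
from types import ModuleType


@functools.lru_cache(maxsize=None)
def lazy_import(name: str) -> ModuleType:
    """Import the module on first use and return the memoized object afterwards."""
    return importlib.import_module(name)


if __name__ == "__main__":
    first = lazy_import("json")   # performs the import (or finds it in sys.modules)
    second = lazy_import("json")  # served from the lru_cache without an import call
    print(first is second, first is sys.modules["json"])
    print(lazy_import.cache_info().hits)
```

`importlib.import_module` already returns the entry in `sys.modules` on repeated calls, so the memoization only trims the remaining lookup cost while keeping the import deferred until the feature is actually used.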

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@cyanguwa

cyanguwa commented Jun 4, 2026

Collaborator

/te-ci PyTorch L0

@KshitijLakhani KshitijLakhani left a comment

Collaborator

Deferring to @cyanguwa's thorough review of this PR, and to @vlad having perused and addressed my earlier review comments.
Approving the PR so as not to hold it back.
Good to merge once CI passes.

@vcherepanov-nv vcherepanov-nv merged commit 97a9bfe into NVIDIA:main Jun 4, 2026
12 of 15 checks passed
Development

Successfully merging this pull request may close these issues.

[PyTorch/Jax/common] Flex attention via cuDNN

4 participants