
Commit 4e29f01

drisspg authored and pytorchmergebot committed
Remove sdp_kernel and replace with sdpa_kernel in attention namespace (pytorch#114689)
# Summary: Simplification of Backend Selection

This PR deprecates the `torch.backends.cuda.sdp_kernel` context manager and replaces it with a new context manager, `torch.nn.attention.sdpa_kernel`. The new context manager also changes the API: with `sdp_kernel`, one selected a backend by negation, disabling the kernels one did *not* want to run. That manager was intended only as a debugging tool: "turn off the math backend" and see whether one of the fused implementations can run.

Problems:
- The old pattern makes sense if the majority of users don't care to know anything about the backends that can be run. However, users who reach for this context manager are explicitly trying to run a specific backend.
- It is not scalable. We are working on adding the cuDNN backend, and with this API every additional implementation is one more thing a user has to turn off to run a given backend explicitly.
- Discoverability. It is somewhat unintuitive that this backend manager lives in `torch/backends/cuda/__init__.py` when it now also controls the CPU fused-kernel behavior. Centralizing it in the attention namespace should help.

Other concerns:
- Typically, backends (kernels) for operators are implementation details of the framework, entirely hidden from users. We have already exposed this to users, albeit not by default and with beta warnings. Does making backend choices even more explicit lead to problems when we eventually want to remove existing backends (for example, when newer backends come to cover the same input shapes)?

A nice side effect: now that `test_transformers` no longer uses the `BACKEND_MAP`, many, many dynamo failures pass for the CPU tests.

Pull Request resolved: pytorch#114689
Approved by: https://github.com/cpuhrsch
1 parent 77186af commit 4e29f01
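
For readers comparing the two APIs, the following is a minimal sketch (not part of the commit itself) of the old negation-style selection next to the new allow-list style. It assumes a CUDA device on which the flash kernel is eligible:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

    # Deprecated: ask for flash attention by turning everything else off.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        out = F.scaled_dot_product_attention(q, k, v)

    # New: name the backend(s) you actually want to allow.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v)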

File tree

13 files changed: +225 -1073 lines

@@ -0,0 +1,11 @@
+.. role:: hidden
+    :class: hidden-section
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+    :members:
+
+.. autogenerated from source/_templates/autosummary/class.rst

docs/source/backends.rst (-2)

@@ -68,8 +68,6 @@ torch.backends.cuda
 
 .. autofunction:: torch.backends.cuda.preferred_linalg_library
 
-.. autoclass:: torch.backends.cuda.SDPBackend
-
 .. autoclass:: torch.backends.cuda.SDPAParams
 
 .. autofunction:: torch.backends.cuda.flash_sdp_enabled

docs/source/index.rst (+1 -1)

@@ -93,7 +93,7 @@ Features described in this documentation are classified by release status:
    torch.package <package>
    profiler
    nn.init
-   nn.attention.bias
+   nn.attention
    onnx
    optim
    complex_numbers

docs/source/nn.attention.bias.rst (+7 -1)

@@ -10,7 +10,13 @@ torch.nn.attention.bias
 CausalBias
 ==========
 
-.. autoclass:: CausalBias
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classnoinheritance.rst
+
+    CausalBias
+
 
 .. autosummary::
     :toctree: generated

docs/source/nn.attention.rst (+28)

@@ -0,0 +1,28 @@
+.. role:: hidden
+    :class: hidden-section
+
+torch.nn.attention
+==================
+
+.. automodule:: torch.nn.attention
+
+Utils
+-------------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    sdpa_kernel
+    SDPBackend
+
+Submodules
+----------
+.. autosummary::
+    :nosignatures:
+
+    bias
+
+.. toctree::
+    :hidden:
+
+    nn.attention.bias

docs/source/nn.rst (-1)

@@ -527,7 +527,6 @@ Lazy Modules Initialization
 
 .. This module needs to be documented. Adding here in the meantime
 .. for tracking purposes
-.. py:module:: torch.nn.attention
 .. py:module:: torch.nn.backends
 .. py:module:: torch.nn.utils.stateless
 .. py:module:: torch.nn.backends.thnn

test/test_transformers.py (+66 -75)

Large diffs are not rendered by default.

torch/backends/cuda/__init__.py (+26 -15)

@@ -1,4 +1,5 @@
 import contextlib
+import warnings
 
 from typing import Union
 
@@ -13,7 +14,6 @@
     "preferred_linalg_library",
     "cufft_plan_cache",
     "matmul",
-    "SDPBackend",
     "SDPAParams",
     "enable_flash_sdp",
     "flash_sdp_enabled",
@@ -204,10 +204,9 @@ def preferred_linalg_library(
     return torch._C._get_linalg_preferred_backend()
 
 
-from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend
+from torch._C import _SDPAParams as SDPAParams
 
 # Set the __module__ attribute
-SDPBackend.__module__ = "torch.backends.cuda"
 SDPAParams.__module__ = "torch.backends.cuda"
 SDPAParams.__name__ = "SDPAParams"
 
@@ -318,18 +317,30 @@ def sdp_kernel(
     This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention.
     Upon exiting the context manager, the previous state of the flags will be restored.
     """
-    previous_flash: bool = flash_sdp_enabled()
-    previous_mem_efficient: bool = mem_efficient_sdp_enabled()
-    previous_math: bool = math_sdp_enabled()
-    try:
-        enable_flash_sdp(enable_flash)
-        enable_mem_efficient_sdp(enable_mem_efficient)
-        enable_math_sdp(enable_math)
-        yield {}
-    finally:
-        enable_flash_sdp(previous_flash)
-        enable_mem_efficient_sdp(previous_mem_efficient)
-        enable_math_sdp(previous_math)
+    warnings.warn(
+        (
+            "torch.backends.cuda.sdp_kernel() "
+            "is deprecated. In the future, this context manager will be removed. "
+            "Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated "
+            "signature."
+        ),
+        FutureWarning,
+    )
+    from torch.nn.attention import sdpa_kernel, SDPBackend
+
+    backend_list = []
+    if enable_flash:
+        backend_list.append(SDPBackend.FLASH_ATTENTION)
+    if enable_mem_efficient:
+        backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
+    if enable_math:
+        backend_list.append(SDPBackend.MATH)
+
+    with sdpa_kernel(backend_list) as context:
+        try:
+            yield context
+        finally:
+            pass
 
 
 cufft_plan_cache = cuFFTPlanCacheManager()
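
The deprecated `sdp_kernel` is now a thin shim: it emits a `FutureWarning`, translates its boolean flags into a list of `SDPBackend` values, and delegates to `torch.nn.attention.sdpa_kernel`. A small sketch of the observable behavior after this change (illustrative, not from the diff):

    import warnings
    import torch

    # Entering the deprecated manager raises a FutureWarning and delegates to
    # torch.nn.attention.sdpa_kernel with the equivalent backend list.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        ):
            pass  # equivalent to sdpa_kernel([SDPBackend.MATH])

    assert any(issubclass(w.category, FutureWarning) for w in caught)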

torch/csrc/Module.cpp (+3 -1)

@@ -1809,7 +1809,9 @@ Call this whenever a new thread is created in order to propagate values from
   py::enum_<sdp::SDPBackend>(
       py_module,
       "_SDPBackend",
-      "Enum class for the scaled dot product attention backends\n\n... warning:: This class is in beta and subject to change.")
+      "An enum-like class that contains the different backends for scaled dot product attention.\n\n... warning:: This class is in beta and subject to change.\n\n"
+      "This backend class is designed to be used with the sdpa_kernel context manager."
+      "See :func: torch.nn.attention.sdpa_kernel for more details.")
       .value("ERROR", sdp::SDPBackend::error)
       .value("MATH", sdp::SDPBackend::math)
       .value("FLASH_ATTENTION", sdp::SDPBackend::flash_attention)

torch/nested/_internal/sdpa.py (+2 -1)

@@ -12,9 +12,10 @@
     math_sdp_enabled,
     mem_efficient_sdp_enabled,
     SDPAParams,
-    SDPBackend,
 )
 
+from torch.nn.attention import SDPBackend
+
 from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer
 
 log = logging.getLogger(__name__)

torch/nn/attention/__init__.py (+62 -3)

@@ -1,16 +1,24 @@
-from typing import List
+""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
+import contextlib
+from typing import List, Union
 from warnings import warn
 
 from torch.backends.cuda import (
     can_use_efficient_attention,
     can_use_flash_attention,
+    enable_flash_sdp,
+    enable_math_sdp,
+    enable_mem_efficient_sdp,
+    flash_sdp_enabled,
+    math_sdp_enabled,
+    mem_efficient_sdp_enabled,
     SDPAParams,
 )
 
-__all__: List[str] = []
+__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
 
 # Note: [SDPA warnings]
-# TODO: Consider using this to sdpa regardless of subclasses
+# TODO: Consider using this for sdpa regardless of subclasses
 # This only effects users of bias subclasses
 # If this is set to True, we will warn the user if they are not using the fused kernels
 # As well, it will raise warnings for all the reasons why the fused kernels can't be run.
@@ -19,6 +27,21 @@
 WARN_FOR_UNFUSED_KERNELS = False
 
 
+from torch._C import _SDPBackend as SDPBackend
+
+# Hacks for Sphinx documentation:
+# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
+SDPBackend = SDPBackend
+r"""An enum-like class that contains the different backends for scaled dot product attention.
+This backend class is designed to be used with the sdpa_kernel context manager.
+See :func: torch.nn.attention.sdpa_kernel for more details.
+
+... warning:: This class is in beta and subject to change.
+"""
+SDPBackend.__module__ = __name__
+SDPBackend.__name__ = "SDPBackend"
+
+
 def _raise_kernel_warnings(params: SDPAParams) -> None:
     """
     If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings
@@ -31,3 +54,39 @@ def _raise_kernel_warnings(params: SDPAParams) -> None:
         if not can_use_flash_attention(params):
             warn("Flash attention can't be used because:")
             can_use_flash_attention(params, True)
+
+
+@contextlib.contextmanager
+def sdpa_kernel(backends: List[SDPBackend]):
+    r"""
+    Context manager to select which backend to use for scaled dot product attention.
+
+    .. warning:: This function is beta and subject to change.
+
+    Args:
+        backend (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention.
+
+    This context manager can be used to select which backend to use for scaled dot product attention.
+    Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends.
+    """
+    assert backends is None or isinstance(
+        backends, list
+    ), "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
+
+    backends = set(backends)
+    previous_flash: bool = flash_sdp_enabled()
+    previous_mem_efficient: bool = mem_efficient_sdp_enabled()
+    previous_math: bool = math_sdp_enabled()
+    try:
+        enable_flash = SDPBackend.FLASH_ATTENTION in backends
+        enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends
+        enable_math = SDPBackend.MATH in backends
+
+        enable_flash_sdp(enable_flash)
+        enable_mem_efficient_sdp(enable_mem_efficient)
+        enable_math_sdp(enable_math)
+        yield {}
+    finally:
+        enable_flash_sdp(previous_flash)
+        enable_mem_efficient_sdp(previous_mem_efficient)
+        enable_math_sdp(previous_math)
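
The new `sdpa_kernel` maps the allowed-backend list onto the existing `enable_*_sdp` flags and restores the previous flag state on exit. A short usage sketch, assuming the default configuration in which all backends start enabled (the final assert is illustrative under that assumption):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 4, 32, 16)

    # Only the math (reference) backend is allowed inside the block;
    # flash and memory-efficient attention are disabled for its duration.
    with sdpa_kernel([SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v)

    # On exit the previous flag state is restored, so with default settings
    # the flash backend is enabled again.
    assert torch.backends.cuda.flash_sdp_enabled()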

torch/nn/attention/bias.py (+1 -1)

@@ -1,4 +1,4 @@
-"""Defines utilities for interacting with scaled_dot_product_attention"""
+"""Defines bias subclasses that work with scaled_dot_product_attention"""
 from enum import auto, IntEnum
 from typing import Optional
 from warnings import warn
