Remove sdp_kernel and replace with sdpa_kernel in attention namespace (pytorch#114689)
# Summary
Simplification of Backend Selection
This PR deprecates the `torch.backends.cuda.sdp_kernel` context manager and replaces it with a new context manager, `torch.nn.attention.sdpa_kernel`. The new context manager also changes the API for selecting backends.
With `sdp_kernel`, one specified a backend choice by negation: disabling every kernel except the one to run. The purpose of that backend manager was only to be a debugging tool, i.e. "turn off the math backend" and see if one of the fused implementations can run.
Problems:
- This pattern makes sense if the majority of users don't care to know anything about the backends that can be run. However, users who reach for this context manager are explicitly trying to run a specific backend.
- It is not scalable. We are working on adding the cuDNN backend, and under this API every additional implementation must be turned off whenever a user wants to run a given backend explicitly.
- Discoverability of the current context manager. It is somewhat unintuitive that this backend manager lives in `torch/backends/cuda/__init__.py` when it now also controls the CPU fused-kernel behavior. Centralizing it in the attention namespace should help.
Other concerns:
- Typically, backends (kernels) for operators are hidden from users entirely as implementation details of the framework. We have already exposed this choice, albeit not by default and with beta warnings. Making backend choices even more explicit may lead to problems when we eventually want to remove existing backends (for example, when input shapes become covered by newer backends).
A nice side effect: now that `test_transformers` no longer uses the `BACKEND_MAP`, many, many dynamo failures are passing for CPU tests.
Pull Request resolved: pytorch#114689
Approved by: https://github.com/cpuhrsch
`torch/csrc/Module.cpp` (+3, -1)
```diff
@@ -1809,7 +1809,9 @@ Call this whenever a new thread is created in order to propagate values from
   py::enum_<sdp::SDPBackend>(
       py_module,
       "_SDPBackend",
-      "Enum class for the scaled dot product attention backends\n\n... warning:: This class is in beta and subject to change.")
+      "An enum-like class that contains the different backends for scaled dot product attention.\n\n... warning:: This class is in beta and subject to change.\n\n"
+      "This backend class is designed to be used with the sdpa_kernel context manager."
+      "See :func: torch.nn.attention.sdpa_kernel for more details.")
```