
Commit ca777fb

albanD authored and pytorchmergebot committed
Add Accelerator device and shell hooks (pytorch#119329)
This adds a concept of Accelerator that points to one of our devices. See DeviceAccelerator.h in this PR for details: https://github.com/pytorch/pytorch/pull/119329/files#diff-83cc748bed5df1a453c272cc5ecc7e572d4eb694c5125384d8fbd17a0b5f50c8

It also adds scaffolding for a shared C++ API to allow generic feature implementation. This PR in particular updates the autograd engine to use this generic API.

Pull Request resolved: pytorch#119329
Approved by: https://github.com/ezyang, https://github.com/huydhn
1 parent e9b78f2 commit ca777fb

16 files changed: +186 -103 lines

Diff for: aten/src/ATen/Context.h (+18)

@@ -1,11 +1,13 @@
 #pragma once
 
 #include <ATen/CPUGeneratorImpl.h>
+#include <ATen/DeviceAccelerator.h>
 #include <ATen/LinalgBackend.h>
 #include <ATen/core/ATenGeneral.h>
 #include <ATen/core/DeprecatedTypeProperties.h>
 #include <ATen/core/Generator.h>
 #include <ATen/core/LegacyTypeDispatch.h>
+#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/detail/HIPHooksInterface.h>
 #include <ATen/detail/IPUHooksInterface.h>
@@ -56,6 +58,22 @@ class TORCH_API Context {
       AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
     }
   }
+  const AcceleratorHooksInterface& getAcceleratorHooksInterface(
+      c10::optional<c10::DeviceType> opt_device_type = c10::nullopt) {
+    c10::DeviceType device_type = opt_device_type.has_value()
+        ? opt_device_type.value()
+        : at::getAccelerator(true).value();
+    if (device_type == at::kCUDA) {
+      return at::detail::getCUDAHooks();
+    } else if (device_type == at::kMPS) {
+      return at::detail::getMPSHooks();
+    } else if (device_type == at::kPrivateUse1) {
+      return at::detail::getPrivateUse1Hooks();
+    } else {
+      AT_ERROR(
+          c10::DeviceTypeName(device_type), " device type not an accelerator.");
+    }
+  }
   Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
     initCUDAIfNeeded(device_type);
     initHIPIfNeeded(device_type);
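
For illustration, a small usage sketch of the new accessor (mine, not part of the diff), assuming an ATen translation unit that includes <ATen/Context.h>. With no argument, the accessor resolves the device type via at::getAccelerator(true), so it throws when no accelerator is present; an explicit non-accelerator type hits the AT_ERROR branch instead.

#include <ATen/Context.h>

// Hypothetical helper: check device readiness without naming a backend.
bool accelerator_device_ready(c10::DeviceIndex index) {
  // No argument: device type falls back to at::getAccelerator(true),
  // which errors out if this build/process has no accelerator.
  const at::AcceleratorHooksInterface& hooks =
      at::globalContext().getAcceleratorHooksInterface();
  return hooks.hasPrimaryContext(index);
}

// Explicit device type: dispatches to the CUDA hooks; a type outside
// {kCUDA, kMPS, kPrivateUse1} would raise the AT_ERROR above instead.
bool cuda_device_ready(c10::DeviceIndex index) {
  return at::globalContext()
      .getAcceleratorHooksInterface(at::kCUDA)
      .hasPrimaryContext(index);
}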

Diff for: aten/src/ATen/DeviceAccelerator.cpp (+31)

@@ -0,0 +1,31 @@
+#include <ATen/DeviceAccelerator.h>
+#include <ATen/Context.h>
+
+namespace at {
+
+C10_API std::optional<DeviceType> getAccelerator(bool checked) {
+#define CHECK_NO_CUDA \
+  TORCH_CHECK(!at::hasCUDA(), "Cannot have both CUDA and PrivateUse1");
+
+#define CHECK_NO_PU1 \
+  TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1");
+
+  if (is_privateuse1_backend_registered()) {
+    // We explicitly allow PrivateUse1 and another device at the same time
+    // as we use this for testing.
+    // Whenever a PrivateUse1 device is registered, use it first.
+    return kPrivateUse1;
+  } else if (at::hasCUDA()) {
+    CHECK_NO_PU1
+    return kCUDA;
+  } else {
+    TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.")
+    return std::nullopt;
+  }
+
+#undef CHECK_NO_CUDA
+#undef CHECK_NO_PU1
+}
+
+
+} // namespace at

Diff for: aten/src/ATen/DeviceAccelerator.h (+27)

@@ -0,0 +1,27 @@
+#pragma once
+
+#include <c10/core/DeviceType.h>
+#include <c10/macros/Macros.h>
+
+#include <ATen/detail/MTIAHooksInterface.h>
+#include <optional>
+
+// This file defines the top level Accelerator concept for PyTorch.
+// A device is an accelerator per the definition here if:
+// - It is mutually exclusive with all other accelerators
+// - It performs asynchronous compute via a Stream/Event system
+// - It provides a set of common APIs as defined by AcceleratorHooksInterface
+//
+// As of today, accelerator devices are (in no particular order):
+// CUDA, MTIA, PrivateUse1
+// We want to add once all the proper APIs are supported and tested:
+// HIP, MPS, XPU
+
+namespace at {
+
+// Ensures that only one accelerator is available (at
+// compile time if possible) and return it.
+// When checked is true, the returned optional always has a value.
+TORCH_API std::optional<DeviceType> getAccelerator(bool checked = false);
+
+} // namespace at
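
As a usage sketch (assumed from the implementation above, not part of the PR), the checked flag trades an optional for a guarantee:

#include <ATen/DeviceAccelerator.h>

void probe_accelerator() {
  // checked = false (default): returns std::nullopt on a build with
  // no accelerator instead of throwing, so callers can branch.
  if (auto accel = at::getAccelerator()) {
    // PrivateUse1 takes precedence when registered; otherwise this
    // is kCUDA when CUDA is available.
    at::DeviceType device_type = *accel;
    (void)device_type;
  }

  // checked = true: the optional is guaranteed to hold a value, and
  // TORCH_CHECK fires instead when no accelerator is available.
  at::DeviceType required = at::getAccelerator(/*checked=*/true).value();
  (void)required;
}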

Diff for: aten/src/ATen/detail/AcceleratorHooksInterface.h (+21)

@@ -0,0 +1,21 @@
+#pragma once
+
+#include <c10/core/Device.h>
+
+namespace at {
+
+// AcceleratorHooksInterface is a shared interface provided by all
+// accelerators to allow generic code.
+// This interface is hook-based as it corresponds to all the functions
+// that are going to be called in a generic way from the CPU code.
+
+struct TORCH_API AcceleratorHooksInterface {
+  // This should never actually be implemented, but it is used to
+  // squelch -Werror=non-virtual-dtor
+  virtual ~AcceleratorHooksInterface() = default;
+
+  // Whether the device at device_index is fully initialized or not.
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;
+};
+
+} // namespace at
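
To make the contract concrete, here is a minimal hypothetical backend satisfying the interface (my sketch, not from the PR; real integrations also register their hooks with ATen, which is omitted here):

#include <ATen/detail/AcceleratorHooksInterface.h>

// Hypothetical out-of-tree accelerator implementing the one pure
// virtual hook. A real backend would query its runtime instead of
// hard-coding the answer.
struct ToyAcceleratorHooks final : at::AcceleratorHooksInterface {
  bool hasPrimaryContext(at::DeviceIndex device_index) const override {
    // Pretend only device 0 exists and is always initialized.
    return device_index == 0;
  }
};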

Diff for: aten/src/ATen/detail/CUDAHooksInterface.h (+4 -2)

@@ -4,6 +4,8 @@
 #include <c10/util/Exception.h>
 #include <c10/util/Registry.h>
 
+#include <ATen/detail/AcceleratorHooksInterface.h>
+
 // Forward-declares at::Generator and at::cuda::NVRTC
 namespace at {
 struct Generator;
@@ -57,7 +59,7 @@ constexpr const char* CUDA_HELP =
 // TODO: Consider putting the stub definitions in another class, so that one
 // never forgets to implement each virtual function in the real implementation
 // in CUDAHooks. This probably doesn't buy us much though.
-struct TORCH_API CUDAHooksInterface {
+struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
   // This should never actually be implemented, but it is used to
   // squelch -Werror=non-virtual-dtor
   virtual ~CUDAHooksInterface() = default;
@@ -107,7 +109,7 @@ struct TORCH_API CUDAHooksInterface {
     TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
   }
 
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const {
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
    TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
   }
 

Diff for: aten/src/ATen/detail/MPSHooksInterface.h (+5 -2)

@@ -4,14 +4,15 @@
 
 #include <c10/core/Allocator.h>
 #include <ATen/core/Generator.h>
+#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Registry.h>
 
 #include <cstddef>
 
 namespace at {
 
-struct TORCH_API MPSHooksInterface {
+struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
   // this fails the implementation if MPSHooks functions are called, but
   // MPS backend is not present.
   #define FAIL_MPSHOOKS_FUNC(func) \
@@ -86,7 +87,9 @@ struct TORCH_API MPSHooksInterface {
   virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
     FAIL_MPSHOOKS_FUNC(__func__);
   }
-
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
   #undef FAIL_MPSHOOKS_FUNC
 };
 

Diff for: aten/src/ATen/detail/MTIAHooksInterface.h (+11 -1)

@@ -4,6 +4,8 @@
 
 #include <c10/util/Registry.h>
 
+#include <ATen/detail/AcceleratorHooksInterface.h>
+
 #include <string>
 
 namespace at {
@@ -17,7 +19,7 @@ constexpr const char* MTIA_HELP =
     "this error has occurred because you are trying "
     "to use some MTIA's functionality without MTIA extension included.";
 
-struct TORCH_API MTIAHooksInterface {
+struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
   virtual ~MTIAHooksInterface() = default;
 
   virtual void initMTIA() const {
@@ -37,6 +39,14 @@ struct TORCH_API MTIAHooksInterface {
         "Cannot query detailed MTIA version without MTIA Extension for PyTorch.",
         MTIA_HELP);
   }
+
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+    TORCH_CHECK(
+        false,
+        "Cannot check MTIA primary context without MTIA Extension for PyTorch.",
+        MTIA_HELP);
+  }
+
 };
 
 struct TORCH_API MTIAHooksArgs {};

Diff for: aten/src/ATen/detail/PrivateUse1HooksInterface.cpp (+11)

@@ -22,4 +22,15 @@ TORCH_API bool isPrivateUse1HooksRegistered() {
   return privateuse1_hooks != nullptr;
 }
 
+namespace detail {
+
+TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks() {
+  TORCH_CHECK(
+      privateuse1_hooks != nullptr,
+      "Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first.");
+  return *privateuse1_hooks;
 }
+
+} // namespace detail
+
+} // namespace at

Diff for: aten/src/ATen/detail/PrivateUse1HooksInterface.h (+9 -2)

@@ -1,13 +1,14 @@
 #pragma once
 
 #include <ATen/core/Generator.h>
+#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <c10/core/Allocator.h>
 #include <c10/core/Device.h>
 #include <c10/core/Storage.h>
 #include <c10/util/Exception.h>
 namespace at {
 
-struct TORCH_API PrivateUse1HooksInterface {
+struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
   virtual ~PrivateUse1HooksInterface() = default;
   virtual const at::Generator& getDefaultGenerator(
       c10::DeviceIndex device_index) {
@@ -28,7 +29,7 @@ struct TORCH_API PrivateUse1HooksInterface {
         "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`.");
   }
 
-  virtual bool hasPrimaryContext(DeviceIndex device_index) const {
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
     TORCH_CHECK_NOT_IMPLEMENTED(
         false,
         "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
@@ -51,4 +52,10 @@ TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface();
 
 TORCH_API bool isPrivateUse1HooksRegistered();
 
+namespace detail {
+
+TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();
+
+} // namespace detail
+
 } // namespace at
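
A registration sketch to tie these pieces together (my assumptions: `RegisterPrivateUse1HooksInterface`, which the error strings above reference, takes a pointer to a long-lived hooks object; its exact signature is not shown in this diff):

#include <ATen/detail/PrivateUse1HooksInterface.h>

// Hypothetical PrivateUse1 backend: only hasPrimaryContext is
// overridden for this sketch; the other virtuals keep their
// "not registered" stub behavior from the base class.
struct ToyPrivateUse1Hooks : at::PrivateUse1HooksInterface {
  bool hasPrimaryContext(at::DeviceIndex device_index) const override {
    return true; // assume the single device is always initialized
  }
};

void register_toy_backend() {
  // The object must outlive every use: after registration,
  // at::detail::getPrivateUse1Hooks() returns it instead of throwing.
  static ToyPrivateUse1Hooks hooks;
  at::RegisterPrivateUse1HooksInterface(&hooks);
}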

Diff for: aten/src/ATen/mps/MPSHooks.h (+6)

@@ -46,6 +46,12 @@ struct MPSHooks : public at::MPSHooksInterface {
   void synchronizeEvent(uint32_t event_id) const override;
   bool queryEvent(uint32_t event_id) const override;
   double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
+
+  // Compatibility with Accelerator API
+  bool hasPrimaryContext(DeviceIndex device_index) const override {
+    // When MPS is available, it is always in use for the one device.
+    return true;
+  }
 };
 
 } // namespace at::mps

Diff for: build_variables.bzl (+2)

@@ -956,6 +956,7 @@ aten_cpu_non_globed_sources = [
 aten_cpu_non_globed_headers = [
     "aten/src/ATen/CPUGeneratorImpl.h",
     "aten/src/ATen/NumericUtils.h",
+    "aten/src/ATen/detail/AcceleratorHooksInterface.h",
     "aten/src/ATen/detail/CUDAHooksInterface.h",
     "aten/src/ATen/detail/MPSHooksInterface.h",
     "aten/src/ATen/detail/HIPHooksInterface.h",
@@ -970,6 +971,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/AccumulateType.cpp",
     "aten/src/ATen/LegacyBatchedTensorImpl.cpp",
     "aten/src/ATen/CPUGeneratorImpl.cpp",
+    "aten/src/ATen/DeviceAccelerator.cpp",
     "aten/src/ATen/Context.cpp",
     "aten/src/ATen/DLConvertor.cpp",
     "aten/src/ATen/EmptyTensor.cpp",
