Skip to content

Commit ee3e48d

Browse files
gchananfacebook-github-bot
authored andcommitted
Move Backend, Layout, ATenGeneral, Deprecated, Generator to ATen/core. (pytorch#10740)
Summary: I included "legacy" includes in the old spots for Backend, Generator, Layout; it seemed unlikely that the other ones had direct user includes. This is another step on the path to move Type/Tensor to ATen/core. Pull Request resolved: pytorch#10740 Reviewed By: ezyang Differential Revision: D9435888 Pulled By: gchanan fbshipit-source-id: 89f4f0f445d4498a059d3a79069ba641b22bbcac
1 parent 5ca2713 commit ee3e48d

34 files changed

+267
-262
lines changed

aten/src/ATen/ATen.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include "ATen/ATenGeneral.h"
3+
#include "ATen/core/ATenGeneral.h"
44
#include "ATen/Allocator.h"
55
#include "ATen/CPUGeneral.h"
66
#include "ATen/CUDAGuard.h"
@@ -11,8 +11,8 @@
1111
#include "ATen/Dispatch.h"
1212
#include "ATen/Formatting.h"
1313
#include "ATen/Functions.h"
14-
#include "ATen/Generator.h"
15-
#include "ATen/Layout.h"
14+
#include "ATen/core/Generator.h"
15+
#include "ATen/core/Layout.h"
1616
#include "ATen/OptionsGuard.h"
1717
#include "ATen/Scalar.h"
1818
#include "ATen/Storage.h"

aten/src/ATen/ATenGeneral.cpp

-1
This file was deleted.

aten/src/ATen/Backend.h

+1-167
Original file line numberDiff line numberDiff line change
@@ -1,168 +1,2 @@
11
#pragma once
2-
3-
#include <ATen/core/TensorTypeId.h>
4-
#include <ATen/core/TensorTypeIdRegistration.h>
5-
#include <ATen/core/Error.h>
6-
#include <ATen/core/DeviceType.h>
7-
8-
#include <stdexcept>
9-
10-
namespace at {
11-
12-
/**
13-
* This legacy enum class defines the set of backends supported by
14-
* old school, code generated Type-based ATen. The reason we are
15-
* sunsetting this enum class is because it doesn't allow for
16-
* open registration of backends. TensorTypeId is the replacement
17-
* for Backend which supports open registration.
18-
*
19-
* ARE YOU SURE YOU WANT TO USE THIS TYPE? Think about if SparseCPU/SparseCUDA
20-
* would make sense in your use case. If it doesn't make sense, maybe
21-
* you want DeviceType.
22-
*/
23-
enum class Backend { CPU, CUDA, SparseCPU, SparseCUDA, Undefined, NumOptions };
24-
25-
static inline Backend toSparse(Backend b) {
26-
switch (b) {
27-
case Backend::CPU:
28-
return Backend::SparseCPU;
29-
case Backend::CUDA:
30-
return Backend::SparseCUDA;
31-
case Backend::SparseCPU:
32-
return Backend::SparseCPU;
33-
case Backend::SparseCUDA:
34-
return Backend::SparseCUDA;
35-
default:
36-
throw std::runtime_error("Unknown backend");
37-
}
38-
}
39-
40-
static inline Backend toDense(Backend b) {
41-
switch (b) {
42-
case Backend::CPU:
43-
return Backend::CPU;
44-
case Backend::CUDA:
45-
return Backend::CUDA;
46-
case Backend::SparseCPU:
47-
return Backend::CPU;
48-
case Backend::SparseCUDA:
49-
return Backend::CUDA;
50-
default:
51-
throw std::runtime_error("Unknown backend");
52-
}
53-
}
54-
55-
static inline Backend tensorTypeIdToBackend(TensorTypeId t) {
56-
if (t == CPUTensorId()) {
57-
return Backend::CPU;
58-
} else if (t == CUDATensorId()) {
59-
return Backend::CUDA;
60-
} else if (t == SparseCPUTensorId()) {
61-
return Backend::SparseCPU;
62-
} else if (t == SparseCUDATensorId()) {
63-
return Backend::SparseCUDA;
64-
} else if (t == UndefinedTensorId()) {
65-
return Backend::Undefined;
66-
} else {
67-
AT_ERROR("Unrecognized tensor type ID: ", t);
68-
}
69-
}
70-
71-
static inline TensorTypeId backendToTensorTypeId(Backend b) {
72-
switch (b) {
73-
case Backend::CPU:
74-
return CPUTensorId();
75-
case Backend::CUDA:
76-
return CUDATensorId();
77-
case Backend::SparseCPU:
78-
return SparseCPUTensorId();
79-
case Backend::SparseCUDA:
80-
return SparseCUDATensorId();
81-
case Backend::Undefined:
82-
return UndefinedTensorId();
83-
default:
84-
throw std::runtime_error("Unknown backend");
85-
}
86-
}
87-
88-
static inline DeviceType backendToDeviceType(Backend b) {
89-
switch (b) {
90-
case Backend::CPU:
91-
return DeviceType::CPU;
92-
case Backend::CUDA:
93-
return DeviceType::CUDA;
94-
case Backend::SparseCPU:
95-
return DeviceType::CPU;
96-
case Backend::SparseCUDA:
97-
return DeviceType::CUDA;
98-
case Backend::Undefined:
99-
AT_ERROR("Undefined backend is not a valid device type");
100-
default:
101-
AT_ERROR("Unknown backend");
102-
}
103-
}
104-
105-
static inline Backend deviceTypeToBackend(DeviceType d) {
106-
switch (d) {
107-
case DeviceType::CPU:
108-
return Backend::CPU;
109-
case DeviceType::CUDA:
110-
return Backend::CUDA;
111-
default:
112-
AT_ERROR("Unknown device type ", d);
113-
}
114-
}
115-
116-
static inline Backend backendToCPU(Backend b) {
117-
switch (b) {
118-
case Backend::CPU:
119-
return Backend::CPU;
120-
case Backend::CUDA:
121-
return Backend::CPU;
122-
case Backend::SparseCPU:
123-
return Backend::SparseCPU;
124-
case Backend::SparseCUDA:
125-
return Backend::SparseCPU;
126-
case Backend::Undefined:
127-
return Backend::Undefined;
128-
default:
129-
AT_ERROR("Unknown backend");
130-
}
131-
}
132-
133-
static inline Backend backendToCUDA(Backend b) {
134-
switch (b) {
135-
case Backend::CPU:
136-
return Backend::CUDA;
137-
case Backend::CUDA:
138-
return Backend::CUDA;
139-
case Backend::SparseCPU:
140-
return Backend::SparseCUDA;
141-
case Backend::SparseCUDA:
142-
return Backend::SparseCUDA;
143-
case Backend::Undefined:
144-
return Backend::Undefined;
145-
default:
146-
AT_ERROR("Unknown backend");
147-
}
148-
}
149-
150-
constexpr DeviceType kCPU = DeviceType::CPU;
151-
constexpr DeviceType kCUDA = DeviceType::CUDA;
152-
153-
static inline const char* toString(Backend b) {
154-
switch (b) {
155-
case Backend::CPU:
156-
return "CPU";
157-
case Backend::CUDA:
158-
return "CUDA";
159-
case Backend::SparseCPU:
160-
return "SparseCPU";
161-
case Backend::SparseCUDA:
162-
return "SparseCUDA";
163-
default:
164-
return "UNKNOWN_BACKEND";
165-
}
166-
}
167-
168-
} // namespace at
2+
#include <ATen/core/Backend.h>

aten/src/ATen/CPUGeneral.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// linking errors using MSVC
55
// See https://msdn.microsoft.com/en-us/library/a90k134d.aspx
66
// This header adds this if using AT_API
7-
#include "ATen/ATenGeneral.h"
7+
#include "ATen/core/ATenGeneral.h"
88

99
namespace at {
1010
AT_API void set_num_threads(int);

aten/src/ATen/CheckGenerator.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include "ATen/Generator.h"
3+
#include "ATen/core/Generator.h"
44
#include "ATen/Utils.h"
55
#include "ATen/core/Error.h"
66

aten/src/ATen/Context.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#pragma once
22

33
#include <ATen/CPUGeneral.h>
4-
#include "ATen/ATenGeneral.h"
4+
#include "ATen/core/ATenGeneral.h"
55
#include "ATen/CUDAStream.h"
6-
#include "ATen/Generator.h"
6+
#include "ATen/core/Generator.h"
77
#include "ATen/Type.h"
88
#include "ATen/Utils.h"
99
#include "ATen/core/Error.h"

aten/src/ATen/Device.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#pragma once
22

3-
#include <ATen/ATenGeneral.h>
3+
#include <ATen/core/ATenGeneral.h>
44
#include <ATen/core/Error.h>
55
#include <ATen/core/DeviceType.h>
66
#include <ATen/core/Error.h>
7-
#include <ATen/Backend.h>
7+
#include <ATen/core/Backend.h>
88

99
#include <cstddef>
1010
#include <iosfwd>

aten/src/ATen/Generator.h

+1-24
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,2 @@
11
#pragma once
2-
3-
#include <ATen/ATenGeneral.h>
4-
5-
#include <stdint.h>
6-
7-
namespace at {
8-
9-
struct AT_API Generator {
10-
Generator() {};
11-
Generator(const Generator& other) = delete;
12-
Generator(Generator&& other) = delete;
13-
virtual ~Generator() {};
14-
15-
virtual Generator& copy(const Generator& other) = 0;
16-
virtual Generator& free() = 0;
17-
18-
virtual uint64_t seed() = 0;
19-
virtual uint64_t initialSeed() = 0;
20-
virtual Generator& manualSeed(uint64_t seed) = 0;
21-
virtual Generator& manualSeedAll(uint64_t seed) = 0;
22-
virtual void * unsafeGetTH() = 0;
23-
};
24-
25-
} // namespace at
2+
#include <ATen/core/Generator.h>

aten/src/ATen/Layout.h

+1-34
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,2 @@
11
#pragma once
2-
3-
#include <ATen/Backend.h>
4-
#include <ATen/core/Error.h>
5-
6-
#include <iostream>
7-
8-
namespace at {
9-
enum class Layout { Strided, Sparse };
10-
11-
constexpr auto kStrided = Layout::Strided;
12-
constexpr auto kSparse = Layout::Sparse;
13-
14-
inline Layout layout_from_backend(Backend backend) {
15-
switch (backend) {
16-
case Backend::SparseCPU:
17-
case Backend::SparseCUDA:
18-
return Layout::Sparse;
19-
default:
20-
return Layout::Strided;
21-
}
22-
}
23-
24-
inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) {
25-
switch (layout) {
26-
case at::kStrided:
27-
return stream << "Strided";
28-
case at::kSparse:
29-
return stream << "Sparse";
30-
default:
31-
AT_ERROR("Unknown layout");
32-
}
33-
}
34-
35-
} // namespace at
2+
#include <ATen/core/Layout.h>

aten/src/ATen/OptionsGuard.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

33
#include <ATen/Device.h>
4-
#include <ATen/Layout.h>
4+
#include <ATen/core/Layout.h>
55
#include <ATen/ScalarType.h>
66
#include <ATen/TensorOptions.h>
77
#include <ATen/core/optional.h>

aten/src/ATen/Registry.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include <string>
1919
#include <vector>
2020

21-
#include <ATen/ATenGeneral.h>
21+
#include <ATen/core/ATenGeneral.h>
2222
#include <ATen/core/Backtrace.h>
2323

2424
namespace at {

aten/src/ATen/Scalar.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#include <string>
77
#include <utility>
88

9-
#include "ATen/ATenGeneral.h"
10-
#include "ATen/ScalarType.h"
9+
#include "ATen/core/ATenGeneral.h"
10+
#include "ATen/core/ScalarType.h"
1111
#include "ATen/TensorBase.h"
1212
#include "ATen/core/Half.h"
1313

aten/src/ATen/ScalarType.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#pragma once
2-
#include <ATen/ATenGeneral.h> // for BC reasons
3-
#include <ATen/Backend.h>
2+
#include <ATen/core/ATenGeneral.h> // for BC reasons
3+
#include <ATen/core/Backend.h>
44
#include <ATen/core/ScalarType.h>

aten/src/ATen/TensorImpl.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <ATen/Tensor.h>
55
#include <ATen/core/optional.h>
66
#include <ATen/Context.h>
7-
#include <ATen/Backend.h>
7+
#include <ATen/core/Backend.h>
88

99
#include <ATen/detail/VariableHooksInterface.h>
1010

aten/src/ATen/TensorOptions.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include <ATen/TensorOptions.h>
22

33
#include <ATen/Device.h>
4-
#include <ATen/Layout.h>
4+
#include <ATen/core/Layout.h>
55
#include <ATen/OptionsGuard.h>
66
#include <ATen/ScalarType.h>
77
#include <ATen/core/optional.h>

aten/src/ATen/TensorOptions.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#pragma once
22

3-
#include <ATen/Backend.h>
3+
#include <ATen/core/Backend.h>
44
#include <ATen/Context.h>
55
#include <ATen/Device.h>
66
#include <ATen/DeviceGuard.h>
7-
#include <ATen/Layout.h>
7+
#include <ATen/core/Layout.h>
88
#include <ATen/ScalarType.h>
99
#include <ATen/Tensor.h>
1010
#include <ATen/Type.h>

aten/src/ATen/Utils.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include "ATen/ATenGeneral.h"
3+
#include "ATen/core/ATenGeneral.h"
44
#include "ATen/StorageImpl.h"
55
#include "ATen/UndefinedTensor.h"
66

aten/src/ATen/core/ATenGeneral.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#include <ATen/core/ATenGeneral.h>
File renamed without changes.

0 commit comments

Comments
 (0)