Skip to content

Commit 7164d03

Browse files
authored
Use nullptr to represent fallback kernels
Differential Revision: D76201866 Pull Request resolved: #11484
1 parent e6440a0 commit 7164d03

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

runtime/kernel/operator_registry.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@ constexpr uint32_t kMaxRegisteredKernels = kMaxOperators * kMaxKernelsPerOp;
3535
// require constructing them at init time. Since we don't care about the values
3636
// until we add each entry to the table, allocate static zeroed memory instead
3737
// and point the table at it.
38+
struct alignas(Kernel) KernelBuffer {
39+
uint8_t data[sizeof(Kernel)];
40+
};
41+
3842
// @lint-ignore CLANGTIDY facebook-hte-CArray
39-
alignas(sizeof(Kernel)) uint8_t
40-
registered_kernels_data[kMaxRegisteredKernels * sizeof(Kernel)];
43+
KernelBuffer registered_kernels_data[kMaxRegisteredKernels];
4144

4245
/// Global table of registered kernels.
4346
Kernel* registered_kernels = reinterpret_cast<Kernel*>(registered_kernels_data);

runtime/kernel/operator_registry.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,15 @@ struct KernelKey {
123123
* for all input tensor dtypes and dim orders if the specialized kernel is not
124124
* registered.
125125
*/
126-
KernelKey() : is_fallback_(true) {}
126+
KernelKey() = default;
127127

128128
/**
129129
* Creates a specialized (non-fallback) kernel key that matches a specific
130130
* set of input tensor dtypes and dim orders. See the class comment for the
131131
* expected format of `kernel_key_data`.
132132
*/
133133
/* implicit */ KernelKey(const char* kernel_key_data)
134-
: kernel_key_data_(kernel_key_data), is_fallback_(false) {}
134+
: kernel_key_data_(kernel_key_data) {}
135135

136136
bool operator==(const KernelKey& other) const {
137137
return this->equals(other);
@@ -142,17 +142,17 @@ struct KernelKey {
142142
}
143143

144144
bool equals(const KernelKey& other) const {
145-
if (is_fallback_ != other.is_fallback_) {
145+
if (is_fallback() != other.is_fallback()) {
146146
return false;
147147
}
148-
if (is_fallback_) {
148+
if (is_fallback()) {
149149
return true;
150150
}
151151
return strcmp(kernel_key_data_, other.kernel_key_data_) == 0;
152152
}
153153

154154
bool is_fallback() const {
155-
return is_fallback_;
155+
return kernel_key_data_ == nullptr;
156156
}
157157

158158
const char* data() const {
@@ -168,7 +168,6 @@ struct KernelKey {
168168

169169
private:
170170
const char* kernel_key_data_ = nullptr;
171-
bool is_fallback_;
172171
};
173172

174173
/**

0 commit comments

Comments
 (0)