Skip to content

Commit 50b3f9d

Browse files
ifedanfacebook-github-bot
authored andcommitted
Allow use cpu_serial_kernel with void-lambda (pytorch#27370)
Summary: pytorch#27271 Pull Request resolved: pytorch#27370 Differential Revision: D17763265 Pulled By: ifedan fbshipit-source-id: d670560dfc555db529b18c01aa42f0ccb2127889
1 parent 19ab538 commit 50b3f9d

3 files changed

Lines changed: 108 additions & 18 deletions

File tree

aten/src/ATen/native/cpu/IsContiguous.h

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,58 @@ namespace at { namespace native { namespace {
55
// n: number of function arguments (arity)
66
// traits: function_traits (see FunctionTraits.h)
77
// s: index of scalar argument or -1
8-
template <int n, typename traits, int s=-1>
8+
template <int n, int stride_index, typename traits, int s=-1>
99
struct IsContiguous {
1010
static bool eval(const int64_t* strides) {
1111
using type = typename traits::template arg<n - 1>::type;
12-
return strides[n] == (s == n ? 0 : sizeof(type)) &&
13-
IsContiguous<n - 1, traits, s>::eval(strides);
12+
return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
13+
IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
1414
}
1515
};
1616

17+
// will be called when there is an output exists
1718
template <typename traits, int s>
18-
struct IsContiguous<0, traits, s> {
19+
struct IsContiguous<0, 0, traits, s> {
1920
static bool eval(const int64_t* strides) {
2021
return strides[0] == sizeof(typename traits::result_type);
2122
}
2223
};
2324

25+
// will be called when there is no output
26+
template <typename traits, int s>
27+
struct IsContiguous<0, -1, traits, s> {
28+
static bool eval(const int64_t* strides) {
29+
return true;
30+
}
31+
};
32+
2433
// output and all inputs are contiguous
25-
template <typename traits>
34+
template <typename traits,
35+
typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
2636
static inline bool is_contiguous(const int64_t* strides) {
27-
return IsContiguous<traits::arity, traits>::eval(strides);
37+
return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
38+
}
39+
40+
template <typename traits,
41+
typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
42+
static inline bool is_contiguous(const int64_t* strides) {
43+
return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
2844
}
2945

3046
// input at `s` is scalar (stride 0); output and other inputs are contiguous
3147
// NB: output is typically at strides[0] so first input corresponds to s=1
32-
template <typename traits, int s>
48+
template <typename traits, int s,
49+
typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
50+
static inline bool is_contiguous_scalar(const int64_t* strides) {
51+
static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
52+
return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
53+
}
54+
55+
template <typename traits, int s,
56+
typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
3357
static inline bool is_contiguous_scalar(const int64_t* strides) {
3458
static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
35-
return IsContiguous<traits::arity, traits, s>::eval(strides);
59+
return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
3660
}
3761

3862
}}}

aten/src/ATen/native/cpu/Loops.h

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,40 @@ dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& o
8080
return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
8181
}
8282

83+
template <typename func_t,
84+
typename std::enable_if<!std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
85+
static inline void
86+
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t op) {
87+
using traits = function_traits<func_t>;
88+
using result_type = typename traits::result_type;
89+
for (; i < n; i++) {
90+
result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
91+
*out_ptr = c10::guts::apply(op, dereference<traits>(
92+
&data[1],
93+
&strides[1],
94+
i));
95+
}
96+
}
97+
98+
template <typename func_t,
99+
typename std::enable_if<std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
100+
static inline void
101+
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t op) {
102+
using traits = function_traits<func_t>;
103+
for (; i < n; i++) {
104+
c10::guts::apply(op, dereference<traits>(
105+
&data[0],
106+
&strides[0],
107+
i));
108+
}
109+
}
110+
83111
// Basic loop operation (one output, N inputs). May be auto-vectorized
84112
// by the compiler. Supports inputs and outputs of different types.
85113
template <typename func_t>
86114
static inline void
87115
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t op) {
88116
using traits = function_traits<func_t>;
89-
using result_type = typename traits::result_type;
90117
constexpr int ntensors = traits::arity + 1;
91118

92119
// Copying strides to temporary array helps auto vectorization in older GCC
@@ -96,13 +123,7 @@ basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_
96123
strides[arg] = strides_[arg];
97124
}
98125

99-
for (; i < n; i++) {
100-
result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
101-
*out_ptr = c10::guts::apply(op, dereference<traits>(
102-
&data[1],
103-
&strides[1],
104-
i));
105-
}
126+
execute_op(data, strides, i, n, op);
106127
}
107128

108129
// Explicitly vectorized loop implementation. All inputs and outputs must be
@@ -205,7 +226,8 @@ void cpu_kernel_vec(TensorIterator& iter, func_t op, vec_func_t vop) {
205226
template <typename func_t>
206227
void cpu_serial_kernel(TensorIterator& iter, func_t op) {
207228
using traits = function_traits<func_t>;
208-
TORCH_INTERNAL_ASSERT(iter.ntensors() >= traits::arity + 1);
229+
TORCH_INTERNAL_ASSERT((std::is_void<typename traits::result_type>::value &&
230+
iter.noutputs() == 0 && iter.ntensors() == traits::arity) || (iter.ntensors() >= traits::arity + 1));
209231

210232
iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
211233
if (is_contiguous<traits>(strides)) {
@@ -217,6 +239,7 @@ void cpu_serial_kernel(TensorIterator& iter, func_t op) {
217239
});
218240
}
219241
}, {0, iter.numel()});
242+
iter.cast_outputs();
220243
}
221244

222245
}}} // namespace at::native::<anonymous>

aten/src/ATen/test/tensor_iterator_test.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,41 @@ TEST(TensorIteratorTest, SerialLoopUnary_##name) { \
6161
ASSERT_ANY_THROW(out.equal(expected)); \
6262
}
6363

64+
#define NO_OUTPUT_UNARY_TEST_ITER_FOR_TYPE(ctype,name) \
65+
TEST(TensorIteratorTest, SerialLoopUnaryNoOutput_##name) { \
66+
auto in = random_tensor_for_type(k##name); \
67+
auto iter = at::TensorIterator(); \
68+
iter.add_input(in); \
69+
iter.build(); \
70+
int64_t acc = 0; \
71+
at::native::cpu_serial_kernel(iter, [&](ctype a) -> void { acc++; }); \
72+
EXPECT_TRUE(acc == in.numel()); \
73+
}
74+
6475
#define BINARY_TEST_ITER_FOR_TYPE(ctype,name) \
6576
TEST(TensorIteratorTest, SerialLoopBinary_##name) { \
6677
Tensor out; \
6778
auto in1 = random_tensor_for_type(k##name); \
6879
auto in2 = random_tensor_for_type(k##name); \
6980
auto expected = in1.add(in2); \
70-
auto iter = TensorIterator::binary_op(out, in1, in2); \
81+
auto iter = TensorIterator::binary_op(out, in1, in2); \
7182
at::native::cpu_serial_kernel(iter, [=](ctype a, ctype b) -> int { return a + b; }); \
7283
ASSERT_ANY_THROW(out.equal(expected)); \
7384
}
7485

86+
#define NO_OUTPUT_BINARY_TEST_ITER_FOR_TYPE(ctype,name) \
87+
TEST(TensorIteratorTest, SerialLoopBinaryNoOutput_##name) { \
88+
auto in1 = random_tensor_for_type(k##name); \
89+
auto in2 = random_tensor_for_type(k##name); \
90+
auto iter = at::TensorIterator(); \
91+
iter.add_input(in1); \
92+
iter.add_input(in2); \
93+
iter.build(); \
94+
int64_t acc = 0; \
95+
at::native::cpu_serial_kernel(iter, [&](ctype a, ctype b) -> void { acc++; }); \
96+
EXPECT_TRUE(acc == in1.numel()); \
97+
}
98+
7599
#define POINTWISE_TEST_ITER_FOR_TYPE(ctype,name) \
76100
TEST(TensorIteratorTest, SerialLoopPointwise_##name) { \
77101
Tensor out; \
@@ -89,6 +113,21 @@ TEST(TensorIteratorTest, SerialLoopPointwise_##name) {
89113
ASSERT_ANY_THROW(out.equal(expected)); \
90114
}
91115

116+
#define NO_OUTPUT_POINTWISE_TEST_ITER_FOR_TYPE(ctype,name) \
117+
TEST(TensorIteratorTest, SerialLoopPoinwiseNoOutput_##name) { \
118+
auto in1 = random_tensor_for_type(k##name); \
119+
auto in2 = random_tensor_for_type(k##name); \
120+
auto in3 = random_tensor_for_type(k##name); \
121+
auto iter = at::TensorIterator(); \
122+
iter.add_input(in1); \
123+
iter.add_input(in2); \
124+
iter.add_input(in3); \
125+
iter.build(); \
126+
int64_t acc = 0; \
127+
at::native::cpu_serial_kernel(iter, [&](ctype a, ctype b, ctype c) -> void { acc++; }); \
128+
EXPECT_TRUE(acc == in1.numel()); \
129+
}
130+
92131
// The alternative way to calculate a < b is (b - a).clamp(0).toBool()
93132
// To prevent an overflow in subtraction (b - a) for unsigned types(unit, bool)
94133
// we will convert in to int first
@@ -112,6 +151,9 @@ TEST(TensorIteratorTest, ComparisonLoopBinary_##name) {
112151
AT_FORALL_SCALAR_TYPES(UNARY_TEST_ITER_FOR_TYPE)
113152
AT_FORALL_SCALAR_TYPES(BINARY_TEST_ITER_FOR_TYPE)
114153
AT_FORALL_SCALAR_TYPES(POINTWISE_TEST_ITER_FOR_TYPE)
154+
AT_FORALL_SCALAR_TYPES(NO_OUTPUT_UNARY_TEST_ITER_FOR_TYPE)
155+
AT_FORALL_SCALAR_TYPES(NO_OUTPUT_BINARY_TEST_ITER_FOR_TYPE)
156+
AT_FORALL_SCALAR_TYPES(NO_OUTPUT_POINTWISE_TEST_ITER_FOR_TYPE)
115157
AT_FORALL_SCALAR_TYPES_AND(Bool, COMPARISON_TEST_ITER_FOR_TYPE)
116158

117159
TEST(TensorIteratorTest, SerialLoopSingleThread) {
@@ -172,3 +214,4 @@ TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) {
172214
iter.compute_common_dtype_only_for_inputs();
173215
ASSERT_ANY_THROW(iter.build());
174216
}
217+

0 commit comments

Comments
 (0)