Skip to content

Commit f301733

Browse files
Fix segmentation fault and calculation error in AveragePool2dKernel (#2091)
Fixed the following issues found by test/test_nn.py::TestNNDeviceTypeXPU::test_avg_pool_large_tensor2_xpu 1. A segmentation fault caused by a data type conversion error that invalidated the memory address. 2. A calculation error caused by data overflow. --------- Co-authored-by: Cui, Yifeng <[email protected]>
1 parent d5a81e0 commit f301733

File tree

1 file changed

+31
-36
lines changed

1 file changed

+31
-36
lines changed

src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <ATen/native/Pool.h>
66

77
#include <ATen/native/xpu/sycl/AveragePool2dKernels.h>
8+
#include <ATen/native/xpu/sycl/KernelUtils.h>
89
#include <comm/Runtime.h>
910
#include <comm/SYCLContext.h>
1011
#include <comm/SYCLHelpers.h>
@@ -25,9 +26,7 @@ inline int max(int a, int b) {
2526
template <typename scalar_t, typename accscalar_t, typename index_t>
2627
struct AvgPool2dKernelFunctor {
2728
void operator()(sycl::nd_item<1> item) const {
28-
index_t index = item.get_global_linear_id();
29-
30-
if (index < total_elements_) {
29+
XPU_KERNEL_LOOP(item, index, total_elements_) {
3130
const int pw = index % pooled_width_;
3231
const int ph = (index / pooled_width_) % pooled_height_;
3332
const int c = (index / pooled_width_ / pooled_height_) % channels_;
@@ -73,19 +72,19 @@ struct AvgPool2dKernelFunctor {
7372
AvgPool2dKernelFunctor(
7473
scalar_t* top_data,
7574
const scalar_t* bottom_data,
76-
index_t total_elements,
77-
index_t channels,
78-
index_t height,
79-
index_t width,
80-
int pooled_height,
81-
int pooled_width,
82-
int kernel_h,
83-
int kernel_w,
84-
int stride_h,
85-
int stride_w,
86-
int pad_h,
87-
int pad_w,
88-
int divisor_override,
75+
const int total_elements,
76+
const int64_t channels,
77+
const int64_t height,
78+
const int64_t width,
79+
const int64_t pooled_height,
80+
const int pooled_width,
81+
const int kernel_h,
82+
const int kernel_w,
83+
const int stride_h,
84+
const int stride_w,
85+
const int pad_h,
86+
const int pad_w,
87+
const int divisor_override,
8988
bool count_include_pad,
9089
bool use_divisor)
9190
: top_data_(top_data),
@@ -109,29 +108,27 @@ struct AvgPool2dKernelFunctor {
109108
private:
110109
scalar_t* top_data_;
111110
const scalar_t* bottom_data_;
112-
index_t total_elements_;
113-
index_t channels_;
114-
index_t height_;
115-
index_t width_;
116-
int pooled_height_;
117-
int pooled_width_;
118-
int kernel_h_;
119-
int kernel_w_;
120-
int stride_h_;
121-
int stride_w_;
122-
int pad_h_;
123-
int pad_w_;
124-
int divisor_override_;
111+
const int total_elements_;
112+
const int64_t channels_;
113+
const int64_t height_;
114+
const int64_t width_;
115+
const int64_t pooled_height_;
116+
const int pooled_width_;
117+
const int kernel_h_;
118+
const int kernel_w_;
119+
const int stride_h_;
120+
const int stride_w_;
121+
const int pad_h_;
122+
const int pad_w_;
123+
const int divisor_override_;
125124
bool count_include_pad_;
126125
bool use_divisor_;
127126
};
128127

129128
template <typename scalar_t, typename accscalar_t, typename index_t>
130129
struct AvgPool2dChannelsLastKernelFunctor {
131130
void operator()(sycl::nd_item<1> item) const {
132-
index_t index = item.get_global_linear_id();
133-
134-
if (index < total_elements_) {
131+
XPU_KERNEL_LOOP(item, index, total_elements_) {
135132
const int c = index % channels_;
136133
const int pw = (index / channels_) % pooled_width_;
137134
const int ph = (index / channels_ / pooled_width_) % pooled_height_;
@@ -327,8 +324,7 @@ void launch_avg_pool2d_kernel(
327324
template <typename scalar_t, typename accscalar_t, typename index_t>
328325
struct AvgPool2dChannelsLastBackwardKernelFunctor {
329326
void operator()(sycl::nd_item<1> item) const {
330-
index_t index = item.get_global_linear_id();
331-
if (index < total_elements_) {
327+
XPU_KERNEL_LOOP_TYPE(item, index, total_elements_, index_t) {
332328
const int c = index % channels_;
333329
const int w = (index / channels_) % width_ + pad_w_;
334330
const int h = (index / channels_ / width_) % height_ + pad_h_;
@@ -431,8 +427,7 @@ struct AvgPool2dChannelsLastBackwardKernelFunctor {
431427
template <typename scalar_t, typename accscalar_t, typename index_t>
432428
struct AvgPool2dBackwarKernelFunctor {
433429
void operator()(sycl::nd_item<1> item) const {
434-
index_t index = item.get_global_linear_id();
435-
if (index < total_elements_) {
430+
XPU_KERNEL_LOOP_TYPE(item, index, total_elements_, index_t) {
436431
// find out the local index
437432
// find out the local offset
438433
const int w = index % width_ + pad_w_;

0 commit comments

Comments
 (0)