5
5
#include < ATen/native/Pool.h>
6
6
7
7
#include < ATen/native/xpu/sycl/AveragePool2dKernels.h>
8
+ #include < ATen/native/xpu/sycl/KernelUtils.h>
8
9
#include < comm/Runtime.h>
9
10
#include < comm/SYCLContext.h>
10
11
#include < comm/SYCLHelpers.h>
@@ -25,9 +26,7 @@ inline int max(int a, int b) {
25
26
template <typename scalar_t , typename accscalar_t , typename index_t >
26
27
struct AvgPool2dKernelFunctor {
27
28
void operator ()(sycl::nd_item<1 > item) const {
28
- index_t index = item.get_global_linear_id ();
29
-
30
- if (index < total_elements_) {
29
+ XPU_KERNEL_LOOP (item, index, total_elements_) {
31
30
const int pw = index % pooled_width_;
32
31
const int ph = (index / pooled_width_) % pooled_height_;
33
32
const int c = (index / pooled_width_ / pooled_height_) % channels_;
@@ -73,19 +72,19 @@ struct AvgPool2dKernelFunctor {
73
72
AvgPool2dKernelFunctor (
74
73
scalar_t * top_data,
75
74
const scalar_t * bottom_data,
76
- index_t total_elements,
77
- index_t channels,
78
- index_t height,
79
- index_t width,
80
- int pooled_height,
81
- int pooled_width,
82
- int kernel_h,
83
- int kernel_w,
84
- int stride_h,
85
- int stride_w,
86
- int pad_h,
87
- int pad_w,
88
- int divisor_override,
75
+ const int total_elements,
76
+ const int64_t channels,
77
+ const int64_t height,
78
+ const int64_t width,
79
+ const int64_t pooled_height,
80
+ const int pooled_width,
81
+ const int kernel_h,
82
+ const int kernel_w,
83
+ const int stride_h,
84
+ const int stride_w,
85
+ const int pad_h,
86
+ const int pad_w,
87
+ const int divisor_override,
89
88
bool count_include_pad,
90
89
bool use_divisor)
91
90
: top_data_(top_data),
@@ -109,29 +108,27 @@ struct AvgPool2dKernelFunctor {
109
108
private:
110
109
scalar_t * top_data_;
111
110
const scalar_t * bottom_data_;
112
- index_t total_elements_;
113
- index_t channels_;
114
- index_t height_;
115
- index_t width_;
116
- int pooled_height_;
117
- int pooled_width_;
118
- int kernel_h_;
119
- int kernel_w_;
120
- int stride_h_;
121
- int stride_w_;
122
- int pad_h_;
123
- int pad_w_;
124
- int divisor_override_;
111
+ const int total_elements_;
112
+ const int64_t channels_;
113
+ const int64_t height_;
114
+ const int64_t width_;
115
+ const int64_t pooled_height_;
116
+ const int pooled_width_;
117
+ const int kernel_h_;
118
+ const int kernel_w_;
119
+ const int stride_h_;
120
+ const int stride_w_;
121
+ const int pad_h_;
122
+ const int pad_w_;
123
+ const int divisor_override_;
125
124
bool count_include_pad_;
126
125
bool use_divisor_;
127
126
};
128
127
129
128
template <typename scalar_t , typename accscalar_t , typename index_t >
130
129
struct AvgPool2dChannelsLastKernelFunctor {
131
130
void operator ()(sycl::nd_item<1 > item) const {
132
- index_t index = item.get_global_linear_id ();
133
-
134
- if (index < total_elements_) {
131
+ XPU_KERNEL_LOOP (item, index, total_elements_) {
135
132
const int c = index % channels_;
136
133
const int pw = (index / channels_) % pooled_width_;
137
134
const int ph = (index / channels_ / pooled_width_) % pooled_height_;
@@ -327,8 +324,7 @@ void launch_avg_pool2d_kernel(
327
324
template <typename scalar_t , typename accscalar_t , typename index_t >
328
325
struct AvgPool2dChannelsLastBackwardKernelFunctor {
329
326
void operator ()(sycl::nd_item<1 > item) const {
330
- index_t index = item.get_global_linear_id ();
331
- if (index < total_elements_) {
327
+ XPU_KERNEL_LOOP_TYPE (item, index, total_elements_, index_t ) {
332
328
const int c = index % channels_;
333
329
const int w = (index / channels_) % width_ + pad_w_;
334
330
const int h = (index / channels_ / width_) % height_ + pad_h_;
@@ -431,8 +427,7 @@ struct AvgPool2dChannelsLastBackwardKernelFunctor {
431
427
template <typename scalar_t , typename accscalar_t , typename index_t >
432
428
struct AvgPool2dBackwarKernelFunctor {
433
429
void operator ()(sycl::nd_item<1 > item) const {
434
- index_t index = item.get_global_linear_id ();
435
- if (index < total_elements_) {
430
+ XPU_KERNEL_LOOP_TYPE (item, index, total_elements_, index_t ) {
436
431
// find out the local index
437
432
// find out the local offset
438
433
const int w = index % width_ + pad_w_;
0 commit comments