@@ -363,7 +363,7 @@ static void silu_sycl(const T *x, T *dst, const int k,
363
363
template <typename T>
364
364
static void sgn_sycl (const T * x, T * dst, const int k, queue_ptr stream) {
365
365
// hard code for now
366
- const int num_blocks = (k + 256 - 1 ) / 256 ;
366
+ const int num_blocks = ceil_div (k, 256 ) ;
367
367
stream->parallel_for (
368
368
sycl::nd_range<3 >((sycl::range<3 >(1 , 1 , num_blocks) * sycl::range (1 , 1 , 256 )), sycl::range (1 , 1 , 256 )), [=](sycl::nd_item<3 > item_ct1) {
369
369
sgn (x, dst, k, item_ct1);
@@ -373,7 +373,7 @@ static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
373
373
template <typename T>
374
374
static void abs_sycl (const T * x, T * dst, const int k, queue_ptr stream) {
375
375
// hard code for now
376
- const int num_blocks = (k + 256 - 1 ) / 256 ;
376
+ const int num_blocks = ceil_div (k, 256 ) ;
377
377
stream->parallel_for (
378
378
sycl::nd_range<3 >((sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , 256 )), sycl::range<3 >(1 , 1 , 256 )), [=](sycl::nd_item<3 > item_ct1) {
379
379
abs_op (x, dst, k, item_ct1);
@@ -384,7 +384,7 @@ static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
384
384
template <typename T>
385
385
static void elu_sycl (const T * x, T * dst, const int k, queue_ptr stream) {
386
386
// hard code for now
387
- const int num_blocks = (k + 256 - 1 ) / 256 ;
387
+ const int num_blocks = ceil_div (k, 256 ) ;
388
388
stream->parallel_for (
389
389
sycl::nd_range<3 >((sycl::range<3 >(1 , 1 , num_blocks) * sycl::range<3 >(1 , 1 , 256 )), sycl::range<3 >(1 , 1 , 256 )), [=](sycl::nd_item<3 > item_ct1) {
390
390
elu_op (x, dst, k, item_ct1);
0 commit comments