Skip to content

Commit beed9b3

Browse files
committed
use ceil_div helper for num_blocks
ggml-ci
1 parent 6c47511 commit beed9b3

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

ggml/src/ggml-sycl/common.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -494,5 +494,9 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
494494

495495
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
496496

497+
constexpr size_t ceil_div(const size_t m, const size_t n) {
498+
return (m + n - 1) / n;
499+
}
500+
497501
bool gpu_has_xmx(sycl::device &dev);
498502
#endif // GGML_SYCL_COMMON_HPP

ggml/src/ggml-sycl/element_wise.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ static void silu_sycl(const T *x, T *dst, const int k,
363363
template<typename T>
364364
static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
365365
// hard code for now
366-
const int num_blocks = (k + 256 - 1) / 256;
366+
const int num_blocks = ceil_div(k, 256);
367367
stream->parallel_for(
368368
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
369369
sgn(x, dst, k, item_ct1);
@@ -373,7 +373,7 @@ static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
373373
template<typename T>
374374
static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
375375
// hard code for now
376-
const int num_blocks = (k + 256 - 1) / 256;
376+
const int num_blocks = ceil_div(k, 256);
377377
stream->parallel_for(
378378
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
379379
abs_op(x, dst, k, item_ct1);
@@ -384,7 +384,7 @@ static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
384384
template<typename T>
385385
static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
386386
// hard code for now
387-
const int num_blocks = (k + 256 - 1) / 256;
387+
const int num_blocks = ceil_div(k, 256);
388388
stream->parallel_for(
389389
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
390390
elu_op(x, dst, k, item_ct1);

0 commit comments

Comments
 (0)