Skip to content

Commit 6c47511

Browse files
committed
decouple kernel launch range from data size using strided loop
1 parent 3493787 commit 6c47511

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

ggml/src/ggml-sycl/element_wise.cpp

+6-12
Original file line numberDiff line numberDiff line change
@@ -27,29 +27,23 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
2727

2828
template<typename T>
2929
static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
30-
const int i = item_ct1.get_global_id(2);
31-
if (i >= k) {
32-
return;
30+
for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
31+
dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
3332
}
34-
dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
3533
}
3634

3735
template<typename T>
3836
static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
39-
const int i = item_ct1.get_global_id(2);
40-
if (i >= k) {
41-
return;
37+
for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
38+
dst[i] = sycl::fabs(x[i]);
4239
}
43-
dst[i] = sycl::fabs(x[i]);
4440
}
4541

4642
template<typename T>
4743
static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
48-
const int i = item_ct1.get_global_id(2);
49-
if (i >= k) {
50-
return;
44+
for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
45+
dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
5146
}
52-
dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
5347
}
5448

5549
template<typename T>

0 commit comments

Comments
 (0)