
Commit a4c340f

SYCL: Add all missing unary kernels (#13074)
* SYCL: Add all missing unary kernels (ggml-ci)
* decouple kernel launch range from data size using strided loop
* use ceil_div helper for num_blocks (ggml-ci)
* clean up auto-imported header files
1 parent d0a417f commit a4c340f

File tree: 4 files changed (+191, -0)

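The second and third bullets above describe the launch strategy shared by all three new kernels: round k up to whole 256-wide work-groups with ceil_div, then let each work-item walk the data with a grid-stride loop so the surplus work-items past k do no work. A minimal self-contained sketch of that pattern, assuming a USM-capable SYCL device; the 1-D range and the sgn payload here are illustrative, not the commit's exact 3-D launch:

    #include <sycl/sycl.hpp>
    #include <cstdio>

    constexpr size_t ceil_div(const size_t m, const size_t n) { return (m + n - 1) / n; }

    int main() {
        constexpr int k = 1000;       // element count, deliberately not a multiple of 256
        constexpr int block = 256;    // work-group size, hard-coded as in the commit
        sycl::queue q;                // assumes a device with USM shared allocations
        float *x   = sycl::malloc_shared<float>(k, q);
        float *dst = sycl::malloc_shared<float>(k, q);
        for (int i = 0; i < k; ++i) x[i] = (i % 3) - 1.0f;   // -1, 0, 1 test pattern

        const size_t num_blocks = ceil_div(k, block);        // 4 groups -> 1024 work-items
        q.parallel_for(
            sycl::nd_range<1>(num_blocks * block, block),
            [=](sycl::nd_item<1> it) {
                // grid-stride loop: the launch range is decoupled from k, and
                // work-items 1000..1023 never enter the loop body
                for (size_t i = it.get_global_id(0); i < (size_t)k; i += it.get_global_range(0)) {
                    dst[i] = x[i] > 0.f ? 1.f : (x[i] < 0.f ? -1.f : 0.f);   // sgn
                }
            }).wait();

        printf("dst[0..2] = %g %g %g\n", dst[0], dst[1], dst[2]);
        sycl::free(x, q);
        sycl::free(dst, q);
    }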

ggml/src/ggml-sycl/common.hpp (+4)
@@ -493,5 +493,9 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
 
 int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
 
+constexpr size_t ceil_div(const size_t m, const size_t n) {
+    return (m + n - 1) / n;
+}
+
 bool gpu_has_xmx(sycl::device &dev);
 #endif // GGML_SYCL_COMMON_HPP
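Because ceil_div is constexpr, its rounding behaviour can be checked at compile time; a few hypothetical sanity checks, not part of the commit:

    // compile-time checks against the ceil_div defined in common.hpp above
    static_assert(ceil_div(1000, 256) == 4, "1000 elements round up to 4 blocks");
    static_assert(ceil_div(1024, 256) == 4, "exact multiples are not over-allocated");
    static_assert(ceil_div(1, 256) == 1, "non-empty input gets at least one block");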

ggml/src/ggml-sycl/element_wise.cpp (+169)
@@ -21,6 +21,27 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     }
 }
 
+template<typename T>
+static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
+    }
+}
+
+template<typename T>
+static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = sycl::fabs(x[i]);
+    }
+}
+
+template<typename T>
+static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
+    }
+}
+
 template<typename T>
 static void gelu(const T * x, T * dst, const int k,
                  const sycl::nd_item<3> &item_ct1) {
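For reference, the three device functions above compute the same element-wise maps as these plain scalar definitions; ELU is the α = 1 variant, e^x − 1 on the negative side, which is exactly what sycl::expm1 evaluates. A hypothetical host-side reference, not in the commit, that a test could compare against:

    #include <cmath>

    // scalar references for the three new kernels (float case)
    static float ref_sgn(float x) { return x > 0.f ? 1.f : (x < 0.f ? -1.f : 0.f); }
    static float ref_abs(float x) { return std::fabs(x); }
    static float ref_elu(float x) { return x > 0.f ? x : std::expm1f(x); }  // expm1 avoids cancellation near 0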
@@ -335,6 +356,37 @@ static void silu_sycl(const T *x, T *dst, const int k,
     });
 }
 
+template<typename T>
+static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard code for now
+    const int num_blocks = ceil_div(k, 256);
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+            sgn(x, dst, k, item_ct1);
+        });
+}
+
+template<typename T>
+static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard code for now
+    const int num_blocks = ceil_div(k, 256);
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+            abs_op(x, dst, k, item_ct1);
+        });
+}
+
+
+template<typename T>
+static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard code for now
+    const int num_blocks = ceil_div(k, 256);
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+            elu_op(x, dst, k, item_ct1);
+        });
+}
+
 template<typename T>
 static void gelu_quick_sycl(const T *x, T *dst, const int k,
                             queue_ptr stream) {
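The nd_range idiom in these launchers multiplies two ranges component-wise: with k = 1000, ceil_div yields num_blocks = 4, so (1, 1, 4) * (1, 1, 256) gives a global range of (1, 1, 1024) over a local range of (1, 1, 256), and the strided loop in the kernel absorbs the 24 surplus work-items. A small check of that range arithmetic, illustrative only:

    #include <sycl/sycl.hpp>
    #include <cassert>

    int main() {
        // range multiplication is component-wise, so this is exactly the
        // global range the launchers build: 4 blocks of 256 along dimension 2
        sycl::range<3> global = sycl::range<3>(1, 1, 4) * sycl::range<3>(1, 1, 256);
        assert(global.get(2) == 1024);
        assert(global.size() == 1024);   // total work-items in the nd_range
    }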
@@ -574,6 +626,106 @@ static void clamp_sycl(const T *x, T *dst, const float min,
     });
 }
 
+inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
+inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
+
+inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
 inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1388,3 +1540,20 @@ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_sgn(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_abs(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_elu(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}

ggml/src/ggml-sycl/element_wise.hpp (+5)
@@ -66,5 +66,10 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
+void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 #endif // GGML_SYCL_ELEMENTWISE_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp (+13)
@@ -38,6 +38,7 @@
 
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/common.hpp"
+#include "ggml-sycl/element_wise.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
@@ -3355,6 +3356,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
             case GGML_UNARY_OP_EXP:
                 ggml_sycl_exp(ctx, dst);
                 break;
+            case GGML_UNARY_OP_SGN:
+                ggml_sycl_sgn(ctx, dst);
+                break;
+            case GGML_UNARY_OP_ABS:
+                ggml_sycl_abs(ctx, dst);
+                break;
+            case GGML_UNARY_OP_ELU:
+                ggml_sycl_elu(ctx, dst);
+                break;
             default:
                 return false;
         }
@@ -3837,6 +3847,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_EXP:
+        case GGML_UNARY_OP_SGN:
+        case GGML_UNARY_OP_ABS:
+        case GGML_UNARY_OP_ELU:
 #if defined (GGML_SYCL_F16)
             return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
 #else