Skip to content

Commit 3493787

Browse files
committed
SYCL: Add all missing unary kernels
ggml-ci
1 parent 13be08d commit 3493787

File tree

3 files changed

+197
-0
lines changed

3 files changed

+197
-0
lines changed

ggml/src/ggml-sycl/element_wise.cpp

+179
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#include "common.hpp"
22
#include "ggml.h"
33
#include "element_wise.hpp"
4+
#include <sycl/detail/builtins/builtins.hpp>
5+
#include <sycl/nd_item.hpp>
6+
#include <sycl/nd_range.hpp>
7+
#include <sycl/range.hpp>
48

59
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
610
const int ne10, const int ne11, const int ne12,
@@ -21,6 +25,33 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
2125
}
2226
}
2327

28+
template<typename T>
29+
static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
30+
const int i = item_ct1.get_global_id(2);
31+
if (i >= k) {
32+
return;
33+
}
34+
dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
35+
}
36+
37+
template<typename T>
38+
static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
39+
const int i = item_ct1.get_global_id(2);
40+
if (i >= k) {
41+
return;
42+
}
43+
dst[i] = sycl::fabs(x[i]);
44+
}
45+
46+
template<typename T>
47+
static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
48+
const int i = item_ct1.get_global_id(2);
49+
if (i >= k) {
50+
return;
51+
}
52+
dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
53+
}
54+
2455
template<typename T>
2556
static void gelu(const T * x, T * dst, const int k,
2657
const sycl::nd_item<3> &item_ct1) {
@@ -335,6 +366,37 @@ static void silu_sycl(const T *x, T *dst, const int k,
335366
});
336367
}
337368

369+
template<typename T>
370+
static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
371+
// hard code for now
372+
const int num_blocks = (k + 256 - 1) / 256;
373+
stream->parallel_for(
374+
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
375+
sgn(x, dst, k, item_ct1);
376+
});
377+
}
378+
379+
template<typename T>
380+
static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
381+
// hard code for now
382+
const int num_blocks = (k + 256 - 1) / 256;
383+
stream->parallel_for(
384+
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
385+
abs_op(x, dst, k, item_ct1);
386+
});
387+
}
388+
389+
390+
template<typename T>
391+
static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
392+
// hard code for now
393+
const int num_blocks = (k + 256 - 1) / 256;
394+
stream->parallel_for(
395+
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
396+
elu_op(x, dst, k, item_ct1);
397+
});
398+
}
399+
338400
template<typename T>
339401
static void gelu_quick_sycl(const T *x, T *dst, const int k,
340402
queue_ptr stream) {
@@ -574,6 +636,106 @@ static void clamp_sycl(const T *x, T *dst, const float min,
574636
});
575637
}
576638

639+
inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
640+
#if defined (GGML_SYCL_F16)
641+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
642+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
643+
644+
#else
645+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
646+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
647+
#endif
648+
GGML_ASSERT(dst->src[0]->type == dst->type);
649+
dpct::queue_ptr main_stream = ctx.stream();
650+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
651+
switch (dst->type) {
652+
#if defined (GGML_SYCL_F16)
653+
case GGML_TYPE_F16:
654+
{
655+
auto data_pts = cast_data<sycl::half>(dst);
656+
sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
657+
break;
658+
}
659+
#endif
660+
case GGML_TYPE_F32:
661+
{
662+
auto data_pts = cast_data<float>(dst);
663+
sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
664+
break;
665+
}
666+
default:
667+
GGML_ABORT("GGML tensor type not supported!\n");
668+
break;
669+
}
670+
}
671+
672+
inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
673+
#if defined (GGML_SYCL_F16)
674+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
675+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
676+
677+
#else
678+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
679+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
680+
#endif
681+
GGML_ASSERT(dst->src[0]->type == dst->type);
682+
dpct::queue_ptr main_stream = ctx.stream();
683+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
684+
switch (dst->type) {
685+
#if defined (GGML_SYCL_F16)
686+
case GGML_TYPE_F16:
687+
{
688+
auto data_pts = cast_data<sycl::half>(dst);
689+
abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
690+
break;
691+
}
692+
#endif
693+
case GGML_TYPE_F32:
694+
{
695+
auto data_pts = cast_data<float>(dst);
696+
abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
697+
break;
698+
}
699+
default:
700+
GGML_ABORT("GGML tensor type not supported!\n");
701+
break;
702+
}
703+
}
704+
705+
706+
inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
707+
#if defined (GGML_SYCL_F16)
708+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
709+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
710+
711+
#else
712+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
713+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
714+
#endif
715+
GGML_ASSERT(dst->src[0]->type == dst->type);
716+
dpct::queue_ptr main_stream = ctx.stream();
717+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
718+
switch (dst->type) {
719+
#if defined (GGML_SYCL_F16)
720+
case GGML_TYPE_F16:
721+
{
722+
auto data_pts = cast_data<sycl::half>(dst);
723+
elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
724+
break;
725+
}
726+
#endif
727+
case GGML_TYPE_F32:
728+
{
729+
auto data_pts = cast_data<float>(dst);
730+
elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
731+
break;
732+
}
733+
default:
734+
GGML_ABORT("GGML tensor type not supported!\n");
735+
break;
736+
}
737+
}
738+
577739
inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
578740
#if defined (GGML_SYCL_F16)
579741
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1388,3 +1550,20 @@ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
13881550
GGML_SYCL_DEBUG("call %s done\n", __func__);
13891551
}
13901552

1553+
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1554+
GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
1555+
ggml_sycl_op_sgn(ctx, dst);
1556+
GGML_SYCL_DEBUG("call %s done\n", __func__);
1557+
}
1558+
1559+
void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1560+
GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
1561+
ggml_sycl_op_abs(ctx, dst);
1562+
GGML_SYCL_DEBUG("call %s done\n", __func__);
1563+
}
1564+
1565+
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1566+
GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
1567+
ggml_sycl_op_elu(ctx, dst);
1568+
GGML_SYCL_DEBUG("call %s done\n", __func__);
1569+
}

ggml/src/ggml-sycl/element_wise.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,10 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
6666

6767
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
6868

69+
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
70+
71+
void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
72+
73+
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
6974
#endif // GGML_SYCL_ELEMENTWISE_HPP
7075

ggml/src/ggml-sycl/ggml-sycl.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
#include "ggml-sycl/backend.hpp"
4040
#include "ggml-sycl/common.hpp"
41+
#include "ggml-sycl/element_wise.hpp"
4142
#include "ggml-sycl/presets.hpp"
4243
#include "ggml-sycl/gemm.hpp"
4344
#include "ggml-sycl/sycl_hw.hpp"
@@ -3295,6 +3296,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
32953296
case GGML_UNARY_OP_EXP:
32963297
ggml_sycl_exp(ctx, dst);
32973298
break;
3299+
case GGML_UNARY_OP_SGN:
3300+
ggml_sycl_sgn(ctx, dst);
3301+
break;
3302+
case GGML_UNARY_OP_ABS:
3303+
ggml_sycl_abs(ctx, dst);
3304+
break;
3305+
case GGML_UNARY_OP_ELU:
3306+
ggml_sycl_elu(ctx, dst);
3307+
break;
32983308
default:
32993309
return false;
33003310
}
@@ -3840,6 +3850,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
38403850
case GGML_UNARY_OP_GELU_QUICK:
38413851
case GGML_UNARY_OP_TANH:
38423852
case GGML_UNARY_OP_EXP:
3853+
case GGML_UNARY_OP_SGN:
3854+
case GGML_UNARY_OP_ABS:
3855+
case GGML_UNARY_OP_ELU:
38433856
#if defined (GGML_SYCL_F16)
38443857
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
38453858
#else

0 commit comments

Comments
 (0)