#include "common.hpp"
#include "ggml.h"
#include "element_wise.hpp"
+ #include <sycl/detail/builtins/builtins.hpp>
+ #include <sycl/nd_item.hpp>
+ #include <sycl/nd_range.hpp>
+ #include <sycl/range.hpp>

static void acc_f32(const float * x, const float * y, float * dst, const int ne,
                    const int ne10, const int ne11, const int ne12,
@@ -21,6 +25,33 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
    }
}

+ template <typename T>
+ static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+     const int i = item_ct1.get_global_id(2);
+     if (i >= k) {
+         return;
+     }
+     dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
+ }
+
+ template <typename T>
+ static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+     const int i = item_ct1.get_global_id(2);
+     if (i >= k) {
+         return;
+     }
+     dst[i] = sycl::fabs(x[i]);
+ }
+
+ template <typename T>
+ static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+     const int i = item_ct1.get_global_id(2);
+     if (i >= k) {
+         return;
+     }
+     dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
+ }
+
template <typename T>
static void gelu(const T * x, T * dst, const int k,
                 const sycl::nd_item<3> &item_ct1) {
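Note (not part of the diff): the three new kernels are plain element-wise maps, one work-item per output element with a bounds check for the tail work-group. A scalar reference of the math they compute, as a sketch with hypothetical helper names; ELU is written with expm1 so that e^x - 1 stays accurate for x near 0:

#include <cmath>

// Scalar reference of the new kernels (hypothetical helpers, illustration only).
static float sgn_ref(float x) { return x > 0.f ? 1.f : (x < 0.f ? -1.f : 0.f); }
static float abs_ref(float x) { return std::fabs(x); }
static float elu_ref(float x) { return x > 0.f ? x : std::expm1(x); }  // expm1(x) = e^x - 1, accurate near 0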
@@ -335,6 +366,37 @@ static void silu_sycl(const T *x, T *dst, const int k,
    });
}

+ template <typename T>
+ static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+     // hard code for now
+     const int num_blocks = (k + 256 - 1) / 256;
+     stream->parallel_for(
+         sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+             sgn(x, dst, k, item_ct1);
+         });
+ }
+
+ template <typename T>
+ static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+     // hard code for now
+     const int num_blocks = (k + 256 - 1) / 256;
+     stream->parallel_for(
+         sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+             abs_op(x, dst, k, item_ct1);
+         });
+ }
+
+
+ template <typename T>
+ static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+     // hard code for now
+     const int num_blocks = (k + 256 - 1) / 256;
+     stream->parallel_for(
+         sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+             elu_op(x, dst, k, item_ct1);
+         });
+ }
+
template <typename T>
static void gelu_quick_sycl(const T *x, T *dst, const int k,
                            queue_ptr stream) {
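Note (not part of the diff): all three launchers share the same launch geometry, a hard-coded work-group size of 256 in the innermost dimension and a ceil-divided grid so every element is covered; surplus work-items in the last group exit early through the `i >= k` check in the kernels. A minimal sketch of that arithmetic (the helper name `unary_launch_range` is hypothetical):

#include <sycl/sycl.hpp>

// Sketch: build the nd_range used by the new *_sycl wrappers for k elements.
static sycl::nd_range<3> unary_launch_range(int k) {
    const int block_size = 256;                                  // hard-coded work-group size
    const int num_blocks = (k + block_size - 1) / block_size;    // ceil(k / 256)
    return sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, block_size),
                             sycl::range<3>(1, 1, block_size));
}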
@@ -574,6 +636,106 @@ static void clamp_sycl(const T *x, T *dst, const float min,
    });
}

+ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ #if defined(GGML_SYCL_F16)
+     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+ #else
+     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+     GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ #endif
+     GGML_ASSERT(dst->src[0]->type == dst->type);
+     dpct::queue_ptr main_stream = ctx.stream();
+     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+     switch (dst->type) {
+ #if defined(GGML_SYCL_F16)
+         case GGML_TYPE_F16:
+             {
+                 auto data_pts = cast_data<sycl::half>(dst);
+                 sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                 break;
+             }
+ #endif
+         case GGML_TYPE_F32:
+             {
+                 auto data_pts = cast_data<float>(dst);
+                 sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                 break;
+             }
+         default:
+             GGML_ABORT("GGML tensor type not supported!\n");
+             break;
+     }
+ }
+
+ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ #if defined(GGML_SYCL_F16)
+     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+ #else
+     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+     GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ #endif
+     GGML_ASSERT(dst->src[0]->type == dst->type);
+     dpct::queue_ptr main_stream = ctx.stream();
+     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+     switch (dst->type) {
+ #if defined(GGML_SYCL_F16)
+         case GGML_TYPE_F16:
+             {
+                 auto data_pts = cast_data<sycl::half>(dst);
+                 abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                 break;
+             }
+ #endif
+         case GGML_TYPE_F32:
+             {
+                 auto data_pts = cast_data<float>(dst);
+                 abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                 break;
+             }
+         default:
+             GGML_ABORT("GGML tensor type not supported!\n");
+             break;
+     }
+ }
+
+
+ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ #if defined(GGML_SYCL_F16)
+     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+ #else
+     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+     GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ #endif
+     GGML_ASSERT(dst->src[0]->type == dst->type);
+     dpct::queue_ptr main_stream = ctx.stream();
+     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+     switch (dst->type) {
+ #if defined(GGML_SYCL_F16)
+         case GGML_TYPE_F16:
+             {
+                 auto data_pts = cast_data<sycl::half>(dst);
+                 elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                 break;
+             }
+ #endif
+         case GGML_TYPE_F32:
+             {
+                 auto data_pts = cast_data<float>(dst);
+                 elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                 break;
+             }
+         default:
+             GGML_ABORT("GGML tensor type not supported!\n");
+             break;
+     }
+ }
+
inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
#if defined(GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1388,3 +1550,20 @@ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}

+ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+     GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+     ggml_sycl_op_sgn(ctx, dst);
+     GGML_SYCL_DEBUG("call %s done\n", __func__);
+ }
+
+ void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+     GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+     ggml_sycl_op_abs(ctx, dst);
+     GGML_SYCL_DEBUG("call %s done\n", __func__);
+ }
+
+ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+     GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+     ggml_sycl_op_elu(ctx, dst);
+     GGML_SYCL_DEBUG("call %s done\n", __func__);
+ }
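Note (not part of the diff): these entry points back the existing unary ops in the public ggml API; an illustrative sketch of how a graph would reach them, assuming the SYCL backend executes the graph:

#include "ggml.h"

// Illustrative only: building the three unary ops on an f32 tensor `a`.
// When the SYCL backend runs the graph, these nodes are expected to be
// handled by ggml_sycl_sgn, ggml_sycl_abs and ggml_sycl_elu respectively.
static void build_new_unary_ops(struct ggml_context * ctx, struct ggml_tensor * a) {
    struct ggml_tensor * s = ggml_sgn(ctx, a);
    struct ggml_tensor * b = ggml_abs(ctx, a);
    struct ggml_tensor * e = ggml_elu(ctx, a);
    (void) s; (void) b; (void) e;
}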