@@ -21,6 +21,27 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     }
 }
 
+template <typename T>
+static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
+    }
+}
+
+template <typename T>
+static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = sycl::fabs(x[i]);
+    }
+}
+
+template <typename T>
+static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
+    }
+}
+
 template <typename T>
 static void gelu(const T * x, T * dst, const int k,
                  const sycl::nd_item<3> &item_ct1) {
@@ -335,6 +356,37 @@ static void silu_sycl(const T *x, T *dst, const int k,
         });
 }
 
+template <typename T>
+static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard code for now
+    const int num_blocks = ceil_div(k, 256);
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+            sgn(x, dst, k, item_ct1);
+        });
+}
+
+template <typename T>
+static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard code for now
+    const int num_blocks = ceil_div(k, 256);
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+            abs_op(x, dst, k, item_ct1);
+        });
+}
+
+
+template <typename T>
+static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard code for now
+    const int num_blocks = ceil_div(k, 256);
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+            elu_op(x, dst, k, item_ct1);
+        });
+}
+
 template <typename T>
 static void gelu_quick_sycl(const T *x, T *dst, const int k,
                             queue_ptr stream) {
@@ -574,6 +626,106 @@ static void clamp_sycl(const T *x, T *dst, const float min,
         });
 }
 
+inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
+inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
+
+inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
 inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1388,3 +1540,20 @@ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_sgn(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_abs(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_elu(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
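For context, here is a minimal, self-contained SYCL sketch of the launch pattern the new kernels rely on: `ceil_div(k, 256)` work-groups of 256 work-items, with each work-item grid-striding over the flattened tensor along dimension 2. The queue setup, `malloc_shared` buffers, the hard-coded element count, and the inlined ELU body are assumptions made for the sketch; they are not part of this patch.

```cpp
#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
    // Element count deliberately not a multiple of 256 to exercise the stride loop.
    const int k = 1000;
    sycl::queue q;  // default device; the backend uses the queue from the ggml SYCL context instead

    float * x   = sycl::malloc_shared<float>(k, q);
    float * dst = sycl::malloc_shared<float>(k, q);
    for (int i = 0; i < k; ++i) {
        x[i] = static_cast<float>(i - k / 2);
    }

    // Same rounding as ceil_div(k, 256) in the patch.
    const int num_blocks = (k + 255) / 256;

    q.parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256),
                          sycl::range<3>(1, 1, 256)),
        [=](sycl::nd_item<3> item_ct1) {
            // Grid-stride loop over dimension 2, as in sgn / abs_op / elu_op above.
            for (auto i = item_ct1.get_global_id(2); i < (size_t) k; i += item_ct1.get_global_range(2)) {
                dst[i] = (x[i] > 0.f) ? x[i] : sycl::expm1(x[i]);  // ELU with alpha = 1, inlined for illustration
            }
        }).wait();

    std::printf("dst[0] = %f  dst[k-1] = %f\n", dst[0], dst[k - 1]);

    sycl::free(x, q);
    sycl::free(dst, q);
    return 0;
}
```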