@@ -3847,21 +3847,27 @@ static void concat_f32(const float *x,const float *y, float *dst, const int ne
3847
3847
}
3848
3848
}
3849
3849
3850
- static void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor,
3851
- const sycl::nd_item<3> &item_ct1) {
3852
- int ne0 = ne00 * scale_factor;
3853
- int nidx = item_ct1.get_local_id(2) +
3854
- item_ct1.get_group(2) * item_ct1.get_local_range(2);
3855
- if (nidx >= ne0) {
3850
+ static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
3851
+ const int nb02, const int nb03, const int ne10, const int ne11,
3852
+ const int ne12, const int ne13, const float sf0, const float sf1,
3853
+ const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
3854
+ int index = item_ct1.get_local_id(0) +
3855
+ item_ct1.get_group(0) * item_ct1.get_local_range(0);
3856
+ if (index >= ne10 * ne11 * ne12 * ne13) {
3856
3857
return;
3857
3858
}
3858
3859
// operation
3859
- int i00 = nidx / scale_factor;
3860
- int i01 = item_ct1.get_group(1) / scale_factor;
3861
- int offset_src = i00 + i01 * ne00 + item_ct1.get_group(0) * nb02;
3862
- int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
3863
- item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
3864
- dst[offset_dst] = x[offset_src];
3860
+ int i10 = index % ne10;
3861
+ int i11 = (index / ne10) % ne11;
3862
+ int i12 = (index / (ne10 * ne11)) % ne12;
3863
+ int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
3864
+
3865
+ int i00 = i10 / sf0;
3866
+ int i01 = i11 / sf1;
3867
+ int i02 = i12 / sf2;
3868
+ int i03 = i13 / sf3;
3869
+
3870
+ dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
3865
3871
}
3866
3872
3867
3873
static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
@@ -10085,18 +10091,17 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
10085
10091
});
10086
10092
}
10087
10093
10088
- static void upscale_f32_sycl(const float *x, float *dst, const int ne00,
10089
- const int ne01, const int ne02,
10090
- const int scale_factor, dpct::queue_ptr stream) {
10091
- int ne0 = (ne00 * scale_factor);
10092
- int num_blocks = (ne0 + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
10093
- sycl::range<3> gridDim(ne02, (ne01 * scale_factor), num_blocks);
10094
+ static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
10095
+ const int nb02, const int nb03, const int ne10, const int ne11,
10096
+ const int ne12, const int ne13, const float sf0, const float sf1,
10097
+ const float sf2, const float sf3, dpct::queue_ptr stream) {
10098
+ int dst_size = ne10 * ne11 * ne12 * ne13;
10099
+ int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
10100
+ sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
10094
10101
stream->parallel_for(
10095
- sycl::nd_range<3>(gridDim *
10096
- sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
10097
- sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
10098
- [=](sycl::nd_item<3> item_ct1) {
10099
- upscale_f32(x, dst, ne00, ne00 * ne01, scale_factor, item_ct1);
10102
+ sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
10103
+ [=](sycl::nd_item<1> item_ct1) {
10104
+ upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
10100
10105
});
10101
10106
}
10102
10107
@@ -13985,15 +13990,15 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0,
13985
13990
13986
13991
GGML_ASSERT(src0->type == GGML_TYPE_F32);
13987
13992
GGML_ASSERT(dst->type == GGML_TYPE_F32);
13988
- GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
13989
-
13990
- #pragma message("TODO: generalize upscale operator")
13991
- #pragma message(" https://github.com/ggerganov/ggml/pull/814")
13992
- GGML_ASSERT(false && "TODO: generalize upscale operator");
13993
13993
13994
- const int scale_factor = dst->op_params[0];
13994
+ const float sf0 = (float)dst->ne[0]/src0->ne[0];
13995
+ const float sf1 = (float)dst->ne[1]/src0->ne[1];
13996
+ const float sf2 = (float)dst->ne[2]/src0->ne[2];
13997
+ const float sf3 = (float)dst->ne[3]/src0->ne[3];
13995
13998
13996
- upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
13999
+ upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
14000
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
14001
+ main_stream);
13997
14002
13998
14003
(void) src1;
13999
14004
(void) dst;
0 commit comments