Skip to content

Commit 983f734

Browse files
committed
xe: ocl: refactor ref_lrn to use ocl_io.h
1 parent 00c2388 commit 983f734

File tree

2 files changed

+21
-24
lines changed

2 files changed

+21
-24
lines changed

src/gpu/intel/ocl/ref_lrn.cl

+18-21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2019-2024 Intel Corporation
2+
* Copyright 2019-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
1515
*******************************************************************************/
1616

1717
#include "gpu/intel/ocl/dispatch.h"
18+
#include "gpu/intel/ocl/ocl_io.h"
1819
#include "gpu/intel/ocl/ocl_types.h"
1920

2021
#if IS_FWD == 1
@@ -39,9 +40,8 @@ __kernel void ref_lrn_fwd(__global const DATA_T *src,
3940
for (int j = 0; j < LOCAL_SIZE; j++) {
4041
const int z_idx = (j + ic - PADDING);
4142
bool zero = (z_idx < 0 || z_idx >= IC);
42-
DEF_ACC_DATA_T val = zero
43-
? 0.0f
44-
: TO_DEF_ACC_DATA_T(src[SRC_OFF(mb, z_idx, id, ih, iw)]);
43+
DEF_ACC_DATA_T val = 0;
44+
if (!zero) load(&val, src + SRC_OFF(mb, z_idx, id, ih, iw));
4545
sum += val * val;
4646
}
4747
#else
@@ -60,8 +60,7 @@ __kernel void ref_lrn_fwd(__global const DATA_T *src,
6060
for (int k = d_start; k < d_end; ++k) {
6161
for (int j = h_start; j < h_end; ++j) {
6262
for (int i = w_start; i < w_end; ++i) {
63-
DEF_ACC_DATA_T val
64-
= TO_DEF_ACC_DATA_T(src[SRC_OFF(mb, ic, k, j, i)]);
63+
DEF_ACC_DATA_T val = load(val, src + SRC_OFF(mb, ic, k, j, i));
6564
sum += val * val;
6665
}
6766
}
@@ -75,12 +74,12 @@ __kernel void ref_lrn_fwd(__global const DATA_T *src,
7574
const DEF_ACC_DATA_T normalization_factor
7675
= native_powr(base, (DEF_ACC_DATA_T)(-LRN_BETA));
7776

78-
const DEF_ACC_DATA_T val = TO_DEF_ACC_DATA_T(src[src_index]);
77+
const DEF_ACC_DATA_T val = load(val, src + src_index);
7978
const DEF_ACC_DATA_T normres = val * normalization_factor;
8079
#if IS_TRAINING == 1
8180
ws[dst_index] = base;
8281
#endif
83-
dst[dst_index] = TO_DATA_T(normres);
82+
write(dst + dst_index, normres);
8483
}
8584
#endif
8685

@@ -106,13 +105,12 @@ __kernel void ref_lrn_bwd(__global const DATA_T *src,
106105
bool zero = (z_idx < 0 || z_idx >= IC);
107106
if (!zero) {
108107
DEF_ACC_DATA_T val
109-
= TO_DEF_ACC_DATA_T(src[SRC_OFF(mb, z_idx, id, ih, iw)]);
110-
DEF_ACC_DATA_T omega = ws[SRC_OFF(mb, z_idx, id, ih, iw)];
108+
= load(val, src + SRC_OFF(mb, z_idx, id, ih, iw));
109+
DEF_ACC_DATA_T omega
110+
= load(omega, ws + SRC_OFF(mb, z_idx, id, ih, iw));
111111
DEF_ACC_DATA_T tmp = (DEF_ACC_DATA_T)1.0f
112112
/ native_powr(omega, (DEF_ACC_DATA_T)LRN_BETA + 1);
113-
B += tmp * val
114-
* TO_DEF_ACC_DATA_T(
115-
diff_dst[DST_OFF(mb, z_idx, id, ih, iw)]);
113+
B += tmp * val * load(B, diff_dst + DST_OFF(mb, z_idx, id, ih, iw));
116114
}
117115
}
118116
#else
@@ -130,21 +128,20 @@ __kernel void ref_lrn_bwd(__global const DATA_T *src,
130128
for (int j = h_start; j < h_end; ++j) {
131129
for (int i = w_start; i < w_end; ++i) {
132130
int data_off = SRC_OFF(mb, ic, k, j, i);
133-
DEF_ACC_DATA_T val = TO_DEF_ACC_DATA_T(src[data_off]);
131+
DEF_ACC_DATA_T val = load(val, src + data_off);
134132
DEF_ACC_DATA_T omega = ws[data_off];
135133
DEF_ACC_DATA_T tmp = (DEF_ACC_DATA_T)1.0f
136134
/ native_powr(omega, (DEF_ACC_DATA_T)(LRN_BETA + 1));
137-
B += tmp * val * TO_DEF_ACC_DATA_T(diff_dst[data_off]);
135+
B += tmp * val * load(B, diff_dst + data_off);
138136
}
139137
}
140138
}
141139
#endif
142-
const DEF_ACC_DATA_T A
143-
= native_powr(ws[src_index], (DEF_ACC_DATA_T)-LRN_BETA)
144-
* TO_DEF_ACC_DATA_T(diff_dst[dst_index]);
140+
DEF_ACC_DATA_T A = native_powr(ws[src_index], (DEF_ACC_DATA_T)-LRN_BETA)
141+
* load(A, diff_dst + dst_index);
145142

146-
diff_src[src_index] = TO_DATA_T(A
147-
- TO_DEF_ACC_DATA_T(src[src_index]) * 2 * (DEF_ACC_DATA_T)LRN_ALPHA
148-
* (DEF_ACC_DATA_T)LRN_BETA * num_elements_div * B);
143+
A -= load(A, src + src_index) * 2 * (DEF_ACC_DATA_T)LRN_ALPHA
144+
* (DEF_ACC_DATA_T)LRN_BETA * num_elements_div * B;
145+
write(diff_src + src_index, A);
149146
}
150147
#endif

src/gpu/intel/ocl/ref_lrn.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2019-2024 Intel Corporation
2+
* Copyright 2019-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -89,7 +89,7 @@ struct ref_lrn_fwd_t : public gpu_primitive_t {
8989
status_t status = status::success;
9090
const auto *desc = pd()->desc();
9191

92-
kernel_ctx.set_data_type(desc->src_desc.data_type);
92+
kernel_ctx.set_data_type(desc->src_desc.data_type, false);
9393

9494
kernel_ctx.define_int("IS_FWD", 1);
9595

@@ -214,7 +214,7 @@ struct ref_lrn_bwd_t : public gpu_primitive_t {
214214
status_t status = status::success;
215215
const auto *desc = pd()->desc();
216216

217-
kernel_ctx.set_data_type(desc->src_desc.data_type);
217+
kernel_ctx.set_data_type(desc->src_desc.data_type, false);
218218

219219
kernel_ctx.define_int("IS_BWD", 1);
220220

0 commit comments

Comments
 (0)