Skip to content

Commit

Permalink
xe: ocl: refactor ref_lrn to use ocl_io.h
Browse files: browse the repository at this point in the history
  • Loading branch information
rjoursler committed Feb 11, 2025
1 parent b78eed0 commit 873c9c5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 19 deletions.
31 changes: 14 additions & 17 deletions src/gpu/intel/ocl/ref_lrn.cl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
*******************************************************************************/

#include "gpu/intel/ocl/dispatch.h"
#include "gpu/intel/ocl/ocl_io.h"
#include "gpu/intel/ocl/ocl_types.h"

#if IS_FWD == 1
Expand All @@ -39,9 +40,8 @@ __kernel void ref_lrn_fwd(__global const DATA_T *src,
for (int j = 0; j < LOCAL_SIZE; j++) {
const int z_idx = (j + ic - PADDING);
bool zero = (z_idx < 0 || z_idx >= IC);
DEF_ACC_DATA_T val = zero
? 0.0f
: TO_DEF_ACC_DATA_T(src[SRC_OFF(mb, z_idx, id, ih, iw)]);
DEF_ACC_DATA_T val
= zero ? 0.0f : load(val, src + SRC_OFF(mb, z_idx, id, ih, iw));
sum += val * val;
}
#else
Expand All @@ -60,8 +60,7 @@ __kernel void ref_lrn_fwd(__global const DATA_T *src,
for (int k = d_start; k < d_end; ++k) {
for (int j = h_start; j < h_end; ++j) {
for (int i = w_start; i < w_end; ++i) {
DEF_ACC_DATA_T val
= TO_DEF_ACC_DATA_T(src[SRC_OFF(mb, ic, k, j, i)]);
DEF_ACC_DATA_T val = load(val, src + SRC_OFF(mb, ic, k, j, i));
sum += val * val;
}
}
Expand All @@ -75,12 +74,12 @@ __kernel void ref_lrn_fwd(__global const DATA_T *src,
const DEF_ACC_DATA_T normalization_factor
= native_powr(base, (DEF_ACC_DATA_T)(-LRN_BETA));

const DEF_ACC_DATA_T val = TO_DEF_ACC_DATA_T(src[src_index]);
const DEF_ACC_DATA_T val = load(val, src + src_index);
const DEF_ACC_DATA_T normres = val * normalization_factor;
#if IS_TRAINING == 1
ws[dst_index] = base;
#endif
dst[dst_index] = TO_DATA_T(normres);
write(dst + dst_index, normres);
}
#endif

Expand Down Expand Up @@ -110,9 +109,7 @@ __kernel void ref_lrn_bwd(__global const DATA_T *src,
DEF_ACC_DATA_T omega = ws[SRC_OFF(mb, z_idx, id, ih, iw)];
DEF_ACC_DATA_T tmp = (DEF_ACC_DATA_T)1.0f
/ native_powr(omega, (DEF_ACC_DATA_T)LRN_BETA + 1);
B += tmp * val
* TO_DEF_ACC_DATA_T(
diff_dst[DST_OFF(mb, z_idx, id, ih, iw)]);
B += tmp * val * load(B, diff_dst + DST_OFF(mb, z_idx, id, ih, iw));
}
}
#else
Expand All @@ -130,21 +127,21 @@ __kernel void ref_lrn_bwd(__global const DATA_T *src,
for (int j = h_start; j < h_end; ++j) {
for (int i = w_start; i < w_end; ++i) {
int data_off = SRC_OFF(mb, ic, k, j, i);
DEF_ACC_DATA_T val = TO_DEF_ACC_DATA_T(src[data_off]);
DEF_ACC_DATA_T val = load(val, src + data_off);
DEF_ACC_DATA_T omega = ws[data_off];
DEF_ACC_DATA_T tmp = (DEF_ACC_DATA_T)1.0f
/ native_powr(omega, (DEF_ACC_DATA_T)(LRN_BETA + 1));
B += tmp * val * TO_DEF_ACC_DATA_T(diff_dst[data_off]);
B += tmp * val * load(B, diff_dst + data_off);
}
}
}
#endif
const DEF_ACC_DATA_T A
= native_powr(ws[src_index], (DEF_ACC_DATA_T)-LRN_BETA)
* TO_DEF_ACC_DATA_T(diff_dst[dst_index]);
* load(A, diff_dst + dst_index);

diff_src[src_index] = TO_DATA_T(A
- TO_DEF_ACC_DATA_T(src[src_index]) * 2 * (DEF_ACC_DATA_T)LRN_ALPHA
* (DEF_ACC_DATA_T)LRN_BETA * num_elements_div * B);
A -= load(A, src + src_index) * 2 * (DEF_ACC_DATA_T)LRN_ALPHA
* (DEF_ACC_DATA_T)LRN_BETA * num_elements_div * B);
write(diff_src + src_index, A);
}
#endif
4 changes: 2 additions & 2 deletions src/gpu/intel/ocl/ref_lrn.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -89,7 +89,7 @@ struct ref_lrn_fwd_t : public gpu_primitive_t {
status_t status = status::success;
const auto *desc = pd()->desc();

kernel_ctx.set_data_type(desc->src_desc.data_type);
kernel_ctx.set_data_type(desc->src_desc.data_type, false);

kernel_ctx.define_int("IS_FWD", 1);

Expand Down

0 comments on commit 873c9c5

Please sign in to comment.