1
1
/*******************************************************************************
2
- * Copyright 2019-2024 Intel Corporation
2
+ * Copyright 2019-2025 Intel Corporation
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
15
15
*******************************************************************************/
16
16
17
17
#include "gpu/intel/ocl/dispatch.h"
18
+ #include "gpu/intel/ocl/ocl_io.h"
18
19
#include "gpu/intel/ocl/ocl_types.h"
19
20
20
21
#if IS_FWD == 1
@@ -39,9 +40,8 @@ __kernel void ref_lrn_fwd(__global const DATA_T *src,
39
40
for (int j = 0 ; j < LOCAL_SIZE ; j ++ ) {
40
41
const int z_idx = (j + ic - PADDING );
41
42
bool zero = (z_idx < 0 || z_idx >= IC );
42
- DEF_ACC_DATA_T val = zero
43
- ? 0.0f
44
- : TO_DEF_ACC_DATA_T (src [SRC_OFF (mb , z_idx , id , ih , iw )]);
43
+ DEF_ACC_DATA_T val = 0 ;
44
+ if (!zero ) load (& val , src + SRC_OFF (mb , z_idx , id , ih , iw ));
45
45
sum += val * val ;
46
46
}
47
47
#else
@@ -60,8 +60,7 @@ __kernel void ref_lrn_fwd(__global const DATA_T *src,
60
60
for (int k = d_start ; k < d_end ; ++ k ) {
61
61
for (int j = h_start ; j < h_end ; ++ j ) {
62
62
for (int i = w_start ; i < w_end ; ++ i ) {
63
- DEF_ACC_DATA_T val
64
- = TO_DEF_ACC_DATA_T (src [SRC_OFF (mb , ic , k , j , i )]);
63
+ DEF_ACC_DATA_T val = load (val , src + SRC_OFF (mb , ic , k , j , i ));
65
64
sum += val * val ;
66
65
}
67
66
}
@@ -75,12 +74,12 @@ __kernel void ref_lrn_fwd(__global const DATA_T *src,
75
74
const DEF_ACC_DATA_T normalization_factor
76
75
= native_powr (base , (DEF_ACC_DATA_T )(- LRN_BETA ));
77
76
78
- const DEF_ACC_DATA_T val = TO_DEF_ACC_DATA_T ( src [ src_index ] );
77
+ const DEF_ACC_DATA_T val = load ( val , src + src_index );
79
78
const DEF_ACC_DATA_T normres = val * normalization_factor ;
80
79
#if IS_TRAINING == 1
81
80
ws [dst_index ] = base ;
82
81
#endif
83
- dst [ dst_index ] = TO_DATA_T ( normres );
82
+ write ( dst + dst_index , normres );
84
83
}
85
84
#endif
86
85
@@ -106,13 +105,12 @@ __kernel void ref_lrn_bwd(__global const DATA_T *src,
106
105
bool zero = (z_idx < 0 || z_idx >= IC );
107
106
if (!zero ) {
108
107
DEF_ACC_DATA_T val
109
- = TO_DEF_ACC_DATA_T (src [SRC_OFF (mb , z_idx , id , ih , iw )]);
110
- DEF_ACC_DATA_T omega = ws [SRC_OFF (mb , z_idx , id , ih , iw )];
108
+ = load (val , src + SRC_OFF (mb , z_idx , id , ih , iw ));
109
+ DEF_ACC_DATA_T omega
110
+ = load (omega , ws + SRC_OFF (mb , z_idx , id , ih , iw ));
111
111
DEF_ACC_DATA_T tmp = (DEF_ACC_DATA_T )1.0f
112
112
/ native_powr (omega , (DEF_ACC_DATA_T )LRN_BETA + 1 );
113
- B += tmp * val
114
- * TO_DEF_ACC_DATA_T (
115
- diff_dst [DST_OFF (mb , z_idx , id , ih , iw )]);
113
+ B += tmp * val * load (B , diff_dst + DST_OFF (mb , z_idx , id , ih , iw ));
116
114
}
117
115
}
118
116
#else
@@ -130,21 +128,20 @@ __kernel void ref_lrn_bwd(__global const DATA_T *src,
130
128
for (int j = h_start ; j < h_end ; ++ j ) {
131
129
for (int i = w_start ; i < w_end ; ++ i ) {
132
130
int data_off = SRC_OFF (mb , ic , k , j , i );
133
- DEF_ACC_DATA_T val = TO_DEF_ACC_DATA_T ( src [ data_off ] );
131
+ DEF_ACC_DATA_T val = load ( val , src + data_off );
134
132
DEF_ACC_DATA_T omega = ws [data_off ];
135
133
DEF_ACC_DATA_T tmp = (DEF_ACC_DATA_T )1.0f
136
134
/ native_powr (omega , (DEF_ACC_DATA_T )(LRN_BETA + 1 ));
137
- B += tmp * val * TO_DEF_ACC_DATA_T ( diff_dst [ data_off ] );
135
+ B += tmp * val * load ( B , diff_dst + data_off );
138
136
}
139
137
}
140
138
}
141
139
#endif
142
- const DEF_ACC_DATA_T A
143
- = native_powr (ws [src_index ], (DEF_ACC_DATA_T )- LRN_BETA )
144
- * TO_DEF_ACC_DATA_T (diff_dst [dst_index ]);
140
+ DEF_ACC_DATA_T A = native_powr (ws [src_index ], (DEF_ACC_DATA_T )- LRN_BETA )
141
+ * load (A , diff_dst + dst_index );
145
142
146
- diff_src [ src_index ] = TO_DATA_T ( A
147
- - TO_DEF_ACC_DATA_T ( src [ src_index ]) * 2 * ( DEF_ACC_DATA_T ) LRN_ALPHA
148
- * ( DEF_ACC_DATA_T ) LRN_BETA * num_elements_div * B );
143
+ A -= load ( A , src + src_index ) * 2 * ( DEF_ACC_DATA_T ) LRN_ALPHA
144
+ * ( DEF_ACC_DATA_T ) LRN_BETA * num_elements_div * B ;
145
+ write ( diff_src + src_index , A );
149
146
}
150
147
#endif
0 commit comments