|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#version 450 core |
| 10 | + |
| 11 | +${define_required_extensions("buffer", DTYPE)} |
| 12 | + |
| 13 | +#extension GL_EXT_control_flow_attributes : require |
| 14 | +#extension GL_EXT_integer_dot_product : require |
| 15 | + |
| 16 | +#define PRECISION ${PRECISION} |
| 17 | +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} |
| 18 | +#define T ${texel_load_component_type(DTYPE, "buffer")} |
| 19 | + |
| 20 | +${define_active_storage_type("buffer")} |
| 21 | + |
| 22 | +// corresponds to input/output width dim |
| 23 | +#define TILE_M4 1 |
| 24 | +// corresponds to input channels dim |
| 25 | +#define TILE_K4 1 |
| 26 | +// corresponds to output channels dim |
| 27 | +#define TILE_N4 2 |
| 28 | + |
| 29 | +#define TILE_M 4 |
| 30 | +#define TILE_K 4 |
| 31 | +#define TILE_N 8 |
| 32 | + |
| 33 | +layout(std430) buffer; |
| 34 | + |
| 35 | +#include "indexing.glslh" |
| 36 | +#include "common.glslh" |
| 37 | +#include "block_indexing.glslh" |
| 38 | + |
| 39 | +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)} |
| 40 | +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=False)} |
| 41 | +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", "texture2d", is_scalar_array=False)} |
| 42 | +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} |
| 43 | +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} |
| 44 | +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} |
| 45 | + |
| 46 | +// Metadata for input/output tensors |
| 47 | +${layout_declare_ubo(B, "BufferMetadata", "outp")} |
| 48 | +${layout_declare_ubo(B, "BufferMetadata", "inp")} |
| 49 | + |
| 50 | +layout(push_constant) uniform restrict Block { |
| 51 | + float input_scale; |
| 52 | + int input_zp; |
| 53 | + float output_inv_scale; |
| 54 | + int output_zp; |
| 55 | +}; |
| 56 | + |
| 57 | +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; |
| 58 | + |
| 59 | +${layout_declare_spec_const(C, "int", "apply_bias", "1")} |
| 60 | +${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")} |
| 61 | + |
| 62 | +// Layout specialization constants |
| 63 | +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} |
| 64 | +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} |
| 65 | + |
/*
 * Computes the flat buffer index of a 4-wide output block at
 * (w_block_idx, h_idx, c_block_idx).
 *
 * NOTE(review): two indexing modes are supported, selected by the outer
 * packed-dim block size of outp_layout. When it is 1, the strides appear to
 * be per-element so only the width block index is scaled to element space;
 * otherwise the strides appear to be in block units and the whole linear
 * index is scaled by 4 — confirm against the layout packing code.
 */
int compute_outp_buffer_idx(
    const int w_block_idx,
    const int h_idx,
    const int c_block_idx) {
  const int w_stride = int(outp.strides[0][0]);
  const int h_stride = int(outp.strides[0][1]);
  const int c_stride = int(outp.strides[0][2]);

  if (get_outer_packed_dim_block_size(outp_layout) == 1) {
    return h_idx * h_stride + mul_4(w_block_idx) * w_stride +
        c_block_idx * c_stride;
  }
  return mul_4(
      h_idx * h_stride + w_block_idx * w_stride + c_block_idx * c_stride);
}
| 82 | + |
/*
 * Quantized int8 matmul/conv kernel: accumulates packed-int8 dot products,
 * dequantizes with per-output-channel weight scales, optionally adds bias,
 * and requantizes the result back to packed int8.
 *
 * Thread mapping: each thread produces TILE_M (4) width positions ×
 * TILE_N (8) output channels:
 *   gl_GlobalInvocationID.x → output channel blocks (TILE_N4 = 2 blocks of 4)
 *   gl_GlobalInvocationID.y → width blocks (TILE_M4 = 1 block of 4)
 *   gl_GlobalInvocationID.z → height index
 */
void main() {
  const int oc_block_idx = int(gl_GlobalInvocationID.x) * TILE_N4;
  const int ow_block_idx = int(gl_GlobalInvocationID.y) * TILE_M4;
  const int oh = int(gl_GlobalInvocationID.z);

  // Output extents; packed dimensions are counted in blocks of 4.
  const int W = int(outp.sizes[0][0]);
  const int W4 = div_up_4(W);
  const int H = int(outp.sizes[0][1]);
  const int OC4 = div_up_4(int(outp.sizes[0][2]));

  // Bounds check in block space.
  if (ow_block_idx >= W4 ||
      oh >= H ||
      oc_block_idx >= OC4) {
    return;
  }

  // Input strides used to locate this thread's first input block.
  const int inp_w_stride = int(inp.strides[0][0]);
  const int inp_h_stride = int(inp.strides[0][1]);

  // int32 accumulator for the output tile.
  ivec4 out_accum[TILE_M][TILE_N4];
  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
      out_accum[m][n4] = ivec4(0);
    }
  }

  // Index of the first input block; the K (channel) blocks are walked by
  // incrementing the index by 1 each iteration.
  // NOTE(review): this assumes the input's channel-block stride is 1 (e.g. a
  // 4W4C layout where buffer_idx = ... + w4 * C4 + c4) — confirm against the
  // packing implied by inp_layout, which is otherwise unused here.
  int input_idx = oh * inp_h_stride + ow_block_idx * inp_w_stride;

  // Main accumulation loop over the (grouped) K dimension.
  for (int k4 = 0; k4 < conv2d_params_K4_per_group; k4++) {
    // One packed int8 input block (TILE_M4 = 1, TILE_K4 = 1): each int packs
    // 4 int8 channel values; one vector component per width position.
    const ivec4 int8_input_tile = t_packed_int8_input[input_idx];

    // Packed int8 weight blocks for TILE_N4 output-channel blocks.
    // NOTE(review): this fetch is not guarded by OC4, so the weight texture
    // is presumably padded to a multiple of TILE_N4 blocks — verify in the
    // weight prepacking code.
    ivec4 int8_weight_tile[TILE_N4];
    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
      int8_weight_tile[n4] = texelFetch(
          t_packed_int8_weight,
          ivec2(oc_block_idx + n4, k4),
          0);
    }

    // Accumulate 4-element int8 dot products:
    //   m   → width position within the tile
    //   n4  → output channel block
    //   n4i → channel within the block
    [[unroll]] for (int m = 0; m < TILE_M; ++m) {
      [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
        [[unroll]] for (int n4i = 0; n4i < 4; ++n4i) {
          out_accum[m][n4][n4i] = dotPacked4x8AccSatEXT(
              int8_input_tile[m],
              int8_weight_tile[n4][n4i],
              out_accum[m][n4][n4i]);
        }
      }
    }

    input_idx++;
  }

  // Per-output-channel dequantization scales.
  VEC4_T weight_scales[TILE_N4];
  [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
    weight_scales[n4] = t_weight_scales[oc_block_idx + n4];
  }

  // Per-output-channel weight sums, used to fold the input zero point into
  // the accumulator: sum_k (x_k - zp) * w_k = sum_k x_k * w_k - zp * sum_k w_k.
  ivec4 weight_sums[TILE_N4];
  [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
    weight_sums[n4] = t_weight_sums[oc_block_idx + n4];
  }

  // Packed int8 output tile.
  ivec4 int8_out_tile[TILE_M4][TILE_N4];
  [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
      int8_out_tile[m4][n4] = ivec4(0);
    }
  }

  // Dequantize, (optionally) add bias, and requantize to int8.
  const ivec4 input_zp_vec = ivec4(-input_zp);

  if (apply_bias > 0) {
    // Per-output-channel bias.
    VEC4_T bias[TILE_N4];
    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
      bias[n4] = t_bias[oc_block_idx + n4];
    }

    [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
      [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) {
        [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
          const int m = mul_4(m4) + m4i;
          // Fold the input zero point, then scale back to float and add bias.
          const ivec4 accum_adjusted =
              input_zp_vec * weight_sums[n4] + out_accum[m][n4];
          vec4 float_out_texel =
              fma(vec4(accum_adjusted),
                  vec4(weight_scales[n4]) * input_scale,
                  vec4(bias[n4]));
          // Requantize to int8.
          float_out_texel =
              round(float_out_texel * output_inv_scale) + output_zp;
          const ivec4 quantized_out_texel =
              clamp(ivec4(float_out_texel), -128, 127);

          int8_out_tile[m4][n4][m4i] = pack_into_int32(quantized_out_texel);
        }
      }
    }
  } else {
    [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
      [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) {
        [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
          const int m = mul_4(m4) + m4i;
          // Fold the input zero point, then scale back to float (no bias).
          const ivec4 accum_adjusted =
              input_zp_vec * weight_sums[n4] + out_accum[m][n4];
          vec4 float_out_texel =
              vec4(accum_adjusted) * vec4(weight_scales[n4] * input_scale);
          // Requantize to int8.
          float_out_texel =
              round(float_out_texel * output_inv_scale) + output_zp;
          const ivec4 quantized_out_texel =
              clamp(ivec4(float_out_texel), -128, 127);

          int8_out_tile[m4][n4][m4i] = pack_into_int32(quantized_out_texel);
        }
      }
    }
  }

  const int outp_w_stride = int(outp.strides[0][0]);

  // Store the packed int8 output tile, skipping out-of-range channel blocks
  // and masking the partial width block at the right edge of the tensor.
  [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) {
    [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) {
      const int base_outp_buffer_idx = compute_outp_buffer_idx(
          ow_block_idx + m4,
          oh,
          oc_block_idx + n4);
      if (oc_block_idx + n4 < OC4) {
        // Store individual ints from the ivec4; with an outer packed-dim
        // block size of 1 the consecutive width elements are outp_w_stride
        // apart, otherwise they are contiguous within the block.
        const int subtile_w_limit = min(4, W - mul_4(ow_block_idx + m4));
        [[unroll]] for (int subtile_w = 0; subtile_w < subtile_w_limit;
                        ++subtile_w) {
          if (get_outer_packed_dim_block_size(outp_layout) == 1) {
            t_packed_int8_output[base_outp_buffer_idx +
                subtile_w * outp_w_stride] = int8_out_tile[m4][n4][subtile_w];
          } else {
            t_packed_int8_output[base_outp_buffer_idx + subtile_w] =
                int8_out_tile[m4][n4][subtile_w];
          }
        }
      }
    }
  }
}