Skip to content

Commit b1d0159

Browse files
authored
[ET-VK][qconv] Add flexible layout impl for quantized pointwise conv
Differential Revision: D92307253
Pull Request resolved: #17221
1 parent adbec9a commit b1d0159

7 files changed

Lines changed: 721 additions & 6 deletions

File tree

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
${define_required_extensions("buffer", DTYPE)}
12+
13+
#extension GL_EXT_control_flow_attributes : require
14+
#extension GL_EXT_integer_dot_product : require
15+
16+
#define PRECISION ${PRECISION}
17+
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
18+
#define T ${texel_load_component_type(DTYPE, "buffer")}
19+
20+
${define_active_storage_type("buffer")}
21+
22+
// corresponds to input/output width dim
23+
#define TILE_M4 1
24+
// corresponds to input channels dim
25+
#define TILE_K4 1
26+
// corresponds to output channels dim
27+
#define TILE_N4 2
28+
29+
#define TILE_M 4
30+
#define TILE_K 4
31+
#define TILE_N 8
32+
33+
layout(std430) buffer;
34+
35+
#include "indexing.glslh"
36+
#include "common.glslh"
37+
#include "block_indexing.glslh"
38+
39+
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)}
40+
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=False)}
41+
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", "texture2d", is_scalar_array=False)}
42+
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
43+
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
44+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}
45+
46+
// Metadata for input/output tensors
47+
${layout_declare_ubo(B, "BufferMetadata", "outp")}
48+
${layout_declare_ubo(B, "BufferMetadata", "inp")}
49+
50+
layout(push_constant) uniform restrict Block {
51+
float input_scale;
52+
int input_zp;
53+
float output_inv_scale;
54+
int output_zp;
55+
};
56+
57+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
58+
59+
${layout_declare_spec_const(C, "int", "apply_bias", "1")}
60+
${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")}
61+
62+
// Layout specialization constants
63+
${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
64+
${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
65+
66+
int compute_outp_buffer_idx(
67+
const int w_block_idx,
68+
const int h_idx,
69+
const int c_block_idx) {
70+
if (get_outer_packed_dim_block_size(outp_layout) == 1) {
71+
return h_idx * int(outp.strides[0][1])
72+
+ mul_4(w_block_idx) * int(outp.strides[0][0])
73+
+ c_block_idx * int(outp.strides[0][2]);
74+
} else {
75+
return mul_4(
76+
h_idx * int(outp.strides[0][1])
77+
+ w_block_idx * int(outp.strides[0][0])
78+
+ c_block_idx * int(outp.strides[0][2]));
79+
}
80+
81+
}
82+
83+
void main() {
84+
// Thread mapping: each thread handles TILE_M (4) widths × TILE_N (8) output channels
85+
// gl_GlobalInvocationID.x → output channel blocks (TILE_N4 = 2 blocks of 4 channels)
86+
// gl_GlobalInvocationID.y → width blocks (TILE_M4 = 1 block of 4 widths)
87+
// gl_GlobalInvocationID.z → batch (or height * batch combined)
88+
const int oc_block_idx = int(gl_GlobalInvocationID.x) * TILE_N4;
89+
const int ow_block_idx = int(gl_GlobalInvocationID.y) * TILE_M4;
90+
const int oh = int(gl_GlobalInvocationID.z);
91+
92+
// Get output extents in block space (div_up_4 for packed dimensions)
93+
const int W = int(outp.sizes[0][0]);
94+
const int W4 = div_up_4(int(outp.sizes[0][0]));
95+
const int H = int(outp.sizes[0][1]);
96+
const int OC4 = div_up_4(int(outp.sizes[0][2]));
97+
98+
// Bounds check in block space
99+
if (ow_block_idx >= W4 ||
100+
oh >= H ||
101+
oc_block_idx >= OC4) {
102+
return;
103+
}
104+
105+
// Get input extents in block space
106+
const int inp_W4 = div_up_4(int(inp.sizes[0][0]));
107+
const int inp_IC4 = div_up_4(int(inp.sizes[0][2]));
108+
109+
// Precompute stride products for indexing
110+
// For 4W4C layout: buffer_idx = batch * (W4 * C4) + w4 * C4 + c4
111+
const int inp_w_stride = int(inp.strides[0][0]);
112+
const int inp_h_stride = int(inp.strides[0][1]);
113+
const int inp_c_stride = int(inp.strides[0][2]);
114+
115+
// Initialize int32 accumulator
116+
ivec4 out_accum[TILE_M][TILE_N4];
117+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
118+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
119+
out_accum[m][n4] = ivec4(0);
120+
}
121+
}
122+
123+
// Compute initial input tile index
124+
// Input has same spatial layout, channel dimension iterates from 0
125+
int input_idx = oh * inp_h_stride + ow_block_idx * inp_w_stride;
126+
127+
// Main accumulation loop over K dimension
128+
for (int k4 = 0; k4 < conv2d_params_K4_per_group; k4++) {
129+
// Load packed int8 input tile (TILE_M4=1, TILE_K4=1)
130+
// Each int contains 4 packed int8s (one per width position in the tile)
131+
ivec4 int8_input_tile = t_packed_int8_input[input_idx];
132+
133+
// Load int8 weight tile (TILE_K4=1, TILE_N4=2)
134+
ivec4 int8_weight_tile[TILE_N4];
135+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
136+
int8_weight_tile[n4] = texelFetch(
137+
t_packed_int8_weight,
138+
ivec2(oc_block_idx + n4, k4),
139+
0);
140+
}
141+
142+
// Accumulate using int8 dot product
143+
// Input tile indexed as input[m] where m is the width index within tile
144+
// Weight tile indexed as weight[n4][n4i] where n4i is the channel index within block
145+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
146+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
147+
[[unroll]] for (int n4i = 0; n4i < 4; ++n4i) {
148+
out_accum[m][n4][n4i] = dotPacked4x8AccSatEXT(
149+
int8_input_tile[m],
150+
int8_weight_tile[n4][n4i],
151+
out_accum[m][n4][n4i]);
152+
}
153+
}
154+
}
155+
156+
input_idx++;
157+
}
158+
159+
// Load weight scales tile
160+
VEC4_T weight_scales[TILE_N4];
161+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
162+
weight_scales[n4] = t_weight_scales[oc_block_idx + n4];
163+
}
164+
165+
// Load weight sums tile
166+
ivec4 weight_sums[TILE_N4];
167+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
168+
weight_sums[n4] = ivec4(t_weight_sums[oc_block_idx + n4]);
169+
}
170+
171+
// Initialize int8 output tile
172+
ivec4 int8_out_tile[TILE_M4][TILE_N4];
173+
[[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
174+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
175+
int8_out_tile[m4][n4] = ivec4(0);
176+
}
177+
}
178+
179+
// Compute int8 output tile from int32 accumulator
180+
ivec4 input_zp_vec = ivec4(-input_zp);
181+
182+
if (apply_bias > 0) {
183+
// Load bias tile
184+
VEC4_T bias[TILE_N4];
185+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
186+
bias[n4] = t_bias[oc_block_idx + n4];
187+
}
188+
189+
[[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
190+
[[unroll]] for (int m4i = 0; m4i < 4; ++m4i) {
191+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
192+
const int m = mul_4(m4) + m4i;
193+
// Compute floating point output values
194+
ivec4 accum_adjusted =
195+
input_zp_vec * weight_sums[n4] + out_accum[m][n4];
196+
vec4 float_out_texel =
197+
fma(vec4(accum_adjusted),
198+
vec4(weight_scales[n4]) * input_scale,
199+
vec4(bias[n4]));
200+
// Requantize to int8
201+
float_out_texel =
202+
round(float_out_texel * output_inv_scale) + output_zp;
203+
ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127);
204+
205+
int8_out_tile[m4][n4][m4i] = pack_into_int32(quantized_out_texel);
206+
}
207+
}
208+
}
209+
} else {
210+
[[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
211+
[[unroll]] for (int m4i = 0; m4i < 4; ++m4i) {
212+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
213+
const int m = mul_4(m4) + m4i;
214+
// Compute floating point output values
215+
ivec4 accum_adjusted =
216+
input_zp_vec * weight_sums[n4] + out_accum[m][n4];
217+
vec4 float_out_texel =
218+
vec4(accum_adjusted) * vec4(weight_scales[n4] * input_scale);
219+
// Requantize to int8
220+
float_out_texel =
221+
round(float_out_texel * output_inv_scale) + output_zp;
222+
ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127);
223+
224+
int8_out_tile[m4][n4][m4i] = pack_into_int32(quantized_out_texel);
225+
}
226+
}
227+
}
228+
}
229+
230+
const int outp_w_stride = int(outp.strides[0][0]);
231+
232+
// Store packed int8 output tile
233+
[[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) {
234+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) {
235+
const int base_outp_buffer_idx = compute_outp_buffer_idx(
236+
ow_block_idx + m4,
237+
oh,
238+
oc_block_idx + n4);
239+
if (oc_block_idx + n4 < OC4) {
240+
// Store individual ints from the ivec4
241+
const int subtile_w_limit = min(4, W - mul_4(ow_block_idx + m4));
242+
[[unroll]] for (int subtile_w = 0; subtile_w < subtile_w_limit; ++subtile_w) {
243+
if (get_outer_packed_dim_block_size(outp_layout) == 1) {
244+
t_packed_int8_output[base_outp_buffer_idx + subtile_w * outp_w_stride] = int8_out_tile[m4][n4][subtile_w];
245+
} else {
246+
t_packed_int8_output[base_outp_buffer_idx + subtile_w] = int8_out_tile[m4][n4][subtile_w];
247+
}
248+
}
249+
}
250+
}
251+
}
252+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q8ta_conv2d_pw:
  parameter_names_with_default_values:
    DTYPE: float
  generate_variant_forall:
    DTYPE:
      - VALUE: float
  shader_variants:
    - NAME: q8ta_conv2d_pw

backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info) {
2525
info.outer_packed_dim_block_size == 4);
2626
}
2727

28+
bool q8ta_conv2d_check_4w4c_packed_dim_info(const api::PackedDimInfo& info) {
29+
return info.packed_dim == WHCN::kChannelsDim &&
30+
info.packed_dim_block_size == 4 &&
31+
info.outer_packed_dim == WHCN::kWidthDim &&
32+
info.outer_packed_dim_block_size == 4;
33+
}
34+
2835
//
2936
// Workgroup size selection functions
3037
//

backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ namespace vkcompute {
1515

1616
bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info);
1717

18+
bool q8ta_conv2d_check_4w4c_packed_dim_info(const api::PackedDimInfo& info);
19+
1820
ValueRef prepack_quantized_conv2d_weight(
1921
ComputeGraph& graph,
2022
const QuantizationConfig& weight_quant_config,

0 commit comments

Comments
 (0)