From 2695f21350570e49e3ee11d07807bf3a7a1d9d06 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 5 Feb 2026 08:33:09 -0800 Subject: [PATCH] [ET-VK][qconv] Add flexible layout impl for im2col This implements an im2col-based approach for quantized conv2d, which transforms convolution into matrix multiplication. The im2col transformation extracts sliding windows from the input tensor and reshapes them into a 2D matrix, enabling reuse of the optimized pointwise convolution shader for the compute-intensive portion. Two im2col shaders are added: - `q8ta_im2col.glsl`: Generic shader with layout-agnostic input access via BufferMetadata and specialization constants - `q8ta_im2col_4w4c.glsl`: Optimized shader for 4W4C input layout that exploits the alignment between consecutive width positions and packed channel values The im2col output is always stored in 4W4C layout to match the expected input format of the pointwise convolution shader. The operator is registered as `etvk.q8ta_conv2d_im2col.default` and currently supports non-grouped convolutions where input channels is a multiple of 4. Authored with assistance from Claude. Differential Revision: [D92407723](https://our.internmc.facebook.com/intern/diff/D92407723/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/q8ta_im2col.glsl | 164 +++++++++++ .../runtime/graph/ops/glsl/q8ta_im2col.yaml | 11 + .../graph/ops/glsl/q8ta_im2col_4w4c.glsl | 130 +++++++++ .../graph/ops/glsl/q8ta_im2col_4w4c.yaml | 11 + .../runtime/graph/ops/impl/Q8taConv2d.h | 14 + .../graph/ops/impl/Q8taConv2dIm2Col.cpp | 275 ++++++++++++++++++ .../test/custom_ops/impl/TestQ8taConv2d.cpp | 3 + .../test/custom_ops/test_q8ta_conv2d.cpp | 23 ++ 8 files changed, 631 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl new file mode 100644 index 00000000000..4c9ca6b5728 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl @@ -0,0 +1,164 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define PACKED_INT8_OUTPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +layout(std430) buffer; + +#include "indexing.glslh" +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)} + +// Metadata for im2col output and input tensors (layout-agnostic) +${layout_declare_ubo(B, "BufferMetadata", "im2col_outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +// Layout specialization constants +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "im2col_outp_layout", "CONTIG_LAYOUT_INT")} + +layout(push_constant) uniform restrict Block { + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_int8_output_tile_store.glslh" + +// Compute input tensor index from im2col coordinates +TensorIndex4D get_input_tidx( + const int im2col_w, + const int im2col_h, + const int k_in_group, + const int group_idx) { + TensorIndex4D tidx; + tidx.data.w = 0; + + const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; + const int row = k_in_group / conv2d_params.in_channels_per_group; + const int kernel_x = row % conv2d_params.kernel_size.x; + const int kernel_y = row / conv2d_params.kernel_size.x; + + tidx.data.z = group_idx * conv2d_params.in_channels_per_group + c_in_group; + + tidx.data.x = (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x + + (kernel_x * conv2d_params.dilation.x); + tidx.data.y = (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y + + (kernel_y * conv2d_params.dilation.y); + + return tidx; +} + +// Load a single int8 value from the input tensor using layout-agnostic indexing +int load_input_element(const TensorIndex4D tidx, const int input_zp) { + // Bounds checking + if (any(lessThan(tidx.data, ivec4(0))) || + any(greaterThanEqual(tidx.data, ivec4(inp.sizes[0])))) { + return input_zp; + } + + // Use layout-agnostic indexing to get buffer position + int texel_idx; + if (get_outer_packed_dim_block_size(inp_layout) == 1) { + // For 4C or 4C1W layouts: use tensor4d_idx_to_texel_idx + texel_idx = tensor4d_idx_to_texel_idx(inp, tidx, inp_layout); + } else { + // For 4W4C layout: compute index directly + const int w4 = div_4(tidx.data[0]); + const int c4 = div_4(tidx.data[2]); + const int h_stride = int(inp.strides[0][1]); + const int w_stride = int(inp.strides[0][0]); + texel_idx = (tidx.data[1] * h_stride + w4 * w_stride + c4) * 4 + mod_4(tidx.data[0]); + } + + // Load packed int32 containing 4 int8 values + const int packed_input = t_packed_int8_input[texel_idx]; + + // Extract the appropriate int8 value based on channel offset within texel + const int c_offset = mod_4(tidx.data[2]); + return extract_8bit_from_packed_int_le(packed_input, c_offset); +} + +// Load a 4x4 im2col block (4 widths × 4 channels) +ivec4 load_im2col_block( + const int im2col_w_start, + const int im2col_h, + const int k_in_group_start, + const int group_idx) { + ivec4 im2col_block; + + for (int r = 0; r < 4; r++) { + const int im2col_w = im2col_w_start + r; + ivec4 row_values; + for (int c = 0; c < 4; c++) { + const int k_in_group = k_in_group_start + c; + + if (k_in_group >= conv2d_params.logical_K_per_group) { + row_values[c] = zp; + continue; + } + + TensorIndex4D input_tidx = + get_input_tidx(im2col_w, im2col_h, k_in_group, group_idx); + + row_values[c] = load_input_element(input_tidx, zp); + } + + im2col_block[r] = pack_into_int32(row_values); + } + return im2col_block; +} + +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); + + const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]); + Conv2dBlockExtents im2col_block_extents = make_block_extents(im2col_sizes); + + Conv2dBlockIndex im2col_block_idx = linear_idx_to_block_idx( + out_buf_idx, im2col_block_extents); + + if (block_idx_out_of_bounds(im2col_block_idx, im2col_block_extents)) { + return; + } + + // Convert block index to im2col coordinates + const int im2col_w = mul_4(im2col_block_idx.data.x); + const int im2col_h = im2col_block_idx.data.y; + const int im2col_k = mul_4(im2col_block_idx.data.z); + + // Compute group and k offset within group + const int group_idx = im2col_k / conv2d_params.K_per_group; + const int k_in_group = im2col_k % conv2d_params.K_per_group; + + // Load the im2col block using layout-agnostic input access + Int8OutTile int8_im2col_tile; + int8_im2col_tile.data[0][0] = load_im2col_block( + im2col_w, im2col_h, k_in_group, group_idx); + + // Store to output (4W4C format) + store_packed_int8_output_tile( + int8_im2col_tile, im2col_block_idx, im2col_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml new file mode 100644 index 00000000000..08ce5d59d35 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_im2col: + parameter_names_with_default_values: + DTYPE: float + shader_variants: + - NAME: q8ta_im2col diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl new file mode 100644 index 00000000000..9c5e0ee9066 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define PACKED_INT8_OUTPUT_BUFFER + +layout(std430) buffer; + +#include "indexing.glslh" +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)} + +// Metadata for im2col output and input tensors (layout-agnostic) +${layout_declare_ubo(B, "BufferMetadata", "im2col_outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +// Layout specialization constants +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} + +layout(push_constant) uniform restrict Block { + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); + + // Extract sizes from BufferMetadata + const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]); + const ivec4 input_sizes = ivec4(inp.sizes[0]); + + // im2col block extents + const int im2col_W4 = div_up_4(im2col_sizes.x); + const int im2col_H = im2col_sizes.y; + const int im2col_Z4 = div_up_4(im2col_sizes.z); + + // im2col block index from linear output buffer index + const int c4_idx = out_buf_idx % im2col_Z4; + const int row = out_buf_idx / im2col_Z4; + const int w4_idx = row % im2col_W4; + const int h_idx = row / im2col_W4; + + // out of bounds check + if (w4_idx >= im2col_W4 || h_idx >= im2col_H || c4_idx >= im2col_Z4) { + return; + } + + const int im2col_w = mul_4(w4_idx); + const int im2col_h = h_idx; + const int im2col_k = mul_4(c4_idx); + + const int group_idx = im2col_k / conv2d_params.K_per_group; + const int k_in_group = im2col_k % conv2d_params.K_per_group; + + const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; + const int krow = k_in_group / conv2d_params.in_channels_per_group; + const int kernel_x = krow % conv2d_params.kernel_size.x; + const int kernel_y = krow / conv2d_params.kernel_size.x; + + // Base input position + const int input_x_base = + (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x + + (kernel_x * conv2d_params.dilation.x); + const int input_y = + (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y + + (kernel_y * conv2d_params.dilation.y); + const int input_z = + group_idx * conv2d_params.in_channels_per_group + c_in_group; + + // Input tensor extents + const int input_W = input_sizes.x; + const int input_H = input_sizes.y; + const int input_Z4 = div_up_4(input_sizes.z); + + const int zp_packed = pack_into_int32(ivec4(zp)); + const int z4 = div_4(input_z); + + // Check if y and z are in bounds (constant for all 4 width elements) + const bool y_z_in_bounds = + (input_y >= 0 && input_y < input_H && z4 >= 0 && z4 < input_Z4); + + // Load 4 elements from input, one for each output width position. + // Each loaded int contains 4 packed int8 channel values. + ivec4 im2col_block; + for (int i = 0; i < 4; i++) { + const int x = input_x_base + i; + if (!y_z_in_bounds || x < 0 || x >= input_W) { + im2col_block[i] = zp_packed; + } else { + const int x4 = div_4(x); + const int x_mod = mod_4(x); + int scalar_idx; + if (get_outer_packed_dim_block_size(inp_layout) == 1) { + scalar_idx = input_y * int(inp.strides[0][1]) + + x * int(inp.strides[0][0]) + + z4 * int(inp.strides[0][2]); + } else { + scalar_idx = mul_4( + input_y * int(inp.strides[0][1]) + + x4 * int(inp.strides[0][0]) + + z4) + x_mod; + } + im2col_block[i] = t_packed_int8_input[scalar_idx]; + } + } + + // store_packed_int8_output_tile (with TILE_M4=1, TILE_N4=1) + const int buffer_idx = h_idx * int(im2col_outp.strides[0][1]) + + w4_idx * int(im2col_outp.strides[0][0]) + + c4_idx; + + if (w4_idx < im2col_W4 && c4_idx < im2col_Z4) { + t_packed_int8_output[buffer_idx] = im2col_block; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml new file mode 100644 index 00000000000..0de3d97f324 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_im2col_4w4c: + parameter_names_with_default_values: + DTYPE: float + shader_variants: + - NAME: q8ta_im2col_4w4c diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 2c22f35ad13..53a5aa15fe6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -99,4 +99,18 @@ void add_q8ta_conv2d_node( const ValueRef groups, const ValueRef packed_int8_output); +void add_q8ta_conv2d_pw_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef packed_int8_output); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp new file mode 100644 index 00000000000..4bbcc16e43d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp @@ -0,0 +1,275 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +namespace vkcompute { + +// +// Shader dispatch utilities +// + +utils::uvec3 pick_q8ta_im2col_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef im2col_output = args.at(0).refs.at(0); + + std::vector im2col_sizes = graph->sizes_of(im2col_output); + const uint32_t K = utils::safe_downcast(im2col_sizes[0]); + const uint32_t H = utils::safe_downcast(im2col_sizes[1]); + const uint32_t W = utils::safe_downcast(im2col_sizes[2]); + + const uint32_t K4 = utils::div_up_4(K); + const uint32_t W4 = utils::div_up_4(W); + + // Each thread handles one 4x4 block in the output + return {K4 * W4 * H, 1, 1}; +} + +utils::uvec3 pick_q8ta_im2col_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)args; + (void)resize_args; + (void)global_workgroup_size; + + return {64, 1, 1}; +} + +// +// Im2col calculation utilities +// + +std::vector calculate_q8ta_im2col_sizes( + ComputeGraph* graph, + const ValueRef& input, + const ValueRef& output, + const ValueRef& kernel_size, + const ValueRef& groups) { + std::vector in_sizes = graph->sizes_of(input); + const int64_t in_channels = utils::val_at(-3, in_sizes); + + std::vector out_sizes = graph->sizes_of(output); + const int64_t out_height = utils::val_at(-2, out_sizes); + const int64_t out_width = utils::val_at(-1, out_sizes); + + const int64_t groups_val = graph->extract_scalar(groups); + const int64_t in_channels_per_group = in_channels / groups_val; + + const auto kernel_size_list = graph->get_int_list(kernel_size); + + // Align to next multiple of 4 to ensure data loads align nicely with + // texel boundaries + const int64_t flattened_kernel_len = utils::align_up_4( + in_channels_per_group * kernel_size_list->at(0) * + kernel_size_list->at(1)); + + // K -> flattened convolution window (repeated for each group) + const int64_t K = flattened_kernel_len * groups_val; + // M -> number of elements in 2D output plane + const int64_t W = utils::align_up_4(out_width); + const int64_t H = out_height; + + return {K, H, W}; +} + +// +// Dispatch nodes +// + +void add_q8ta_im2col_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef packed_int8_output, + const ValueRef packed_int8_im2col, + const int32_t zp) { + // Validate packed dim info for input and output tensors + VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_input))); + // The the output tensor must be in 4W4C layout + VK_CHECK_COND(q8ta_conv2d_check_4w4c_packed_dim_info( + graph.packed_dim_info_of(packed_int8_im2col))); + + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); + + // At the moment, the im2col path only supports non-grouped convolutions + VK_CHECK_COND(conv_params.groups == 1); + // The implementation also requires that input channels is a multiple of 4 + VK_CHECK_COND(conv_params.in_channels_per_group % 4 == 0); + + std::string kernel_name = "q8ta_im2col_4w4c"; + + vkapi::ParamsBindList param_buffers = { + graph.buffer_meta_ubo(packed_int8_im2col), + graph.buffer_meta_ubo(packed_int8_input), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + // Build spec constants: apply_bias + layout constants (for generic shader) + vkapi::SpecVarList spec_constants = { + 1u, + graph.hashed_layout_of(packed_int8_im2col), + graph.hashed_layout_of(packed_int8_input), + }; + + // // Add layout specialization constants (only for generic shader) + // if (!use_4w4c_path) { + // spec_constants.append(graph.hashed_layout_of(packed_int8_input)); + // spec_constants.append(graph.hashed_layout_of(packed_int8_im2col)); + // } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_q8ta_im2col_global_wg_size, + pick_q8ta_im2col_local_wg_size, + // Inputs and Outputs + {{packed_int8_im2col, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + spec_constants, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void q8ta_conv2d_im2col( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + QuantizationConfig weight_quant_config(8, kPerChannel, {}); + + // Prepack weight using linear weight packing (for im2col approach) + ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + + ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // Create dummy tensor to fill bias binding slot if not provided + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(weight_scales_data), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + // Calculate im2col output sizes + std::vector im2col_sizes = calculate_q8ta_im2col_sizes( + &graph, packed_int8_input, packed_int8_output, kernel_size, groups); + + // Create temporary tensor for im2col output (4W4C layout) + TmpTensor packed_int8_im2col( + &graph, + im2col_sizes, + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + int32_t zp = graph.extract_scalar(input_zp); + + // Step 1: Perform im2col transformation + add_q8ta_im2col_node( + graph, + packed_int8_input, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output, + packed_int8_im2col, + zp); + + // Step 2: Perform pointwise convolution on the im2col result + add_q8ta_conv2d_pw_node( + graph, + packed_int8_im2col, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(etvk.q8ta_conv2d_im2col.default, q8ta_conv2d_im2col); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp index 13e9c5b5b67..acb5a3d03f5 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp @@ -154,6 +154,9 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { if (impl_selector == "legacy_4w4c") { // Use the general quantized conv2d operator for legacy path VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args); + } else if (impl_selector == "im2col") { + // Use the im2col-based conv2d operator + VK_GET_OP_FN("etvk.q8ta_conv2d_im2col.default")(graph, conv_args); } else { // Use the new general q8ta_conv2d operator VK_GET_OP_FN("etvk.q8ta_conv2d.default")(graph, conv_args); diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index 5d5ad356122..8f445ab7230 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp @@ -235,6 +235,17 @@ std::vector generate_quantized_conv2d_easy_cases() { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, fp_storage_type, int8_memory_layout)); + // Test im2col implementation for non-grouped convolutions with input + // channels that are a multiple of 4 and stride_w == 1 + if (config.groups == 1 && config.channels.in % 4 == 0 && + config.stride.w == 1) { + test_cases.push_back(create_test_case_from_config( + config, + vkapi::kFloat, + fp_storage_type, + int8_memory_layout, + /*impl_selector=*/"im2col")); + } // For 4W4C layout, also test the legacy implementation if (int8_memory_layout == utils::kPackedInt8_4W4C) { test_cases.push_back(create_test_case_from_config( @@ -403,6 +414,18 @@ static std::vector generate_quantized_conv2d_test_cases() { int8_memory_layout, /*impl_selector=*/"legacy_4w4c")); } + + // Test im2col implementation for non-grouped convolutions with input + // channels that are a multiple of 4 and stride_w == 1 + if (config.groups == 1 && config.channels.in % 4 == 0 && + config.stride.w == 1) { + test_cases.push_back(create_test_case_from_config( + config, + vkapi::kFloat, + fp_storage_type, + int8_memory_layout, + /*impl_selector=*/"im2col")); + } } } }