diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl new file mode 100644 index 00000000000..4c9ca6b5728 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl @@ -0,0 +1,164 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define PACKED_INT8_OUTPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +layout(std430) buffer; + +#include "indexing.glslh" +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)} + +// Metadata for im2col output and input tensors (layout-agnostic) +${layout_declare_ubo(B, "BufferMetadata", "im2col_outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +// Layout specialization constants +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "im2col_outp_layout", "CONTIG_LAYOUT_INT")} + +layout(push_constant) uniform restrict Block { + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_int8_output_tile_store.glslh" + +// Compute input tensor index from im2col coordinates +TensorIndex4D get_input_tidx( + const int im2col_w, + const int im2col_h, + const int k_in_group, + const int group_idx) { + TensorIndex4D tidx; + tidx.data.w = 0; + + const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; + const int row = k_in_group / conv2d_params.in_channels_per_group; + const int kernel_x = row % conv2d_params.kernel_size.x; + const int kernel_y = row / conv2d_params.kernel_size.x; + + tidx.data.z = group_idx * conv2d_params.in_channels_per_group + c_in_group; + + tidx.data.x = (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x + + (kernel_x * conv2d_params.dilation.x); + tidx.data.y = (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y + + (kernel_y * conv2d_params.dilation.y); + + return tidx; +} + +// Load a single int8 value from the input tensor using layout-agnostic indexing +int load_input_element(const TensorIndex4D tidx, const int input_zp) { + // Bounds checking + if (any(lessThan(tidx.data, ivec4(0))) || + any(greaterThanEqual(tidx.data, ivec4(inp.sizes[0])))) { + return input_zp; + } + + // Use layout-agnostic indexing to get buffer position + int texel_idx; + if (get_outer_packed_dim_block_size(inp_layout) == 1) { + // For 4C or 4C1W layouts: use tensor4d_idx_to_texel_idx + texel_idx = tensor4d_idx_to_texel_idx(inp, tidx, inp_layout); + } else { + // For 4W4C layout: compute index directly + const int w4 = div_4(tidx.data[0]); + const int c4 = div_4(tidx.data[2]); + const int h_stride = int(inp.strides[0][1]); + const int w_stride = int(inp.strides[0][0]); + texel_idx = (tidx.data[1] * h_stride + w4 * w_stride + c4) * 4 + mod_4(tidx.data[0]); + } + + // Load packed int32 containing 4 int8 values + const int packed_input = t_packed_int8_input[texel_idx]; + + // Extract the appropriate int8 value based on channel offset within texel + const int c_offset = mod_4(tidx.data[2]); + return extract_8bit_from_packed_int_le(packed_input, c_offset); +} + +// Load a 4x4 im2col block (4 widths × 4 channels) +ivec4 load_im2col_block( + const int im2col_w_start, + const int im2col_h, + const int k_in_group_start, + const int group_idx) { + ivec4 im2col_block; + + for (int r = 0; r < 4; r++) { + const int im2col_w = im2col_w_start + r; + ivec4 row_values; + for (int c = 0; c < 4; c++) { + const int k_in_group = k_in_group_start + c; + + if (k_in_group >= conv2d_params.logical_K_per_group) { + row_values[c] = zp; + continue; + } + + TensorIndex4D input_tidx = + get_input_tidx(im2col_w, im2col_h, k_in_group, group_idx); + + row_values[c] = load_input_element(input_tidx, zp); + } + + im2col_block[r] = pack_into_int32(row_values); + } + return im2col_block; +} + +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); + + const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]); + Conv2dBlockExtents im2col_block_extents = make_block_extents(im2col_sizes); + + Conv2dBlockIndex im2col_block_idx = linear_idx_to_block_idx( + out_buf_idx, im2col_block_extents); + + if (block_idx_out_of_bounds(im2col_block_idx, im2col_block_extents)) { + return; + } + + // Convert block index to im2col coordinates + const int im2col_w = mul_4(im2col_block_idx.data.x); + const int im2col_h = im2col_block_idx.data.y; + const int im2col_k = mul_4(im2col_block_idx.data.z); + + // Compute group and k offset within group + const int group_idx = im2col_k / conv2d_params.K_per_group; + const int k_in_group = im2col_k % conv2d_params.K_per_group; + + // Load the im2col block using layout-agnostic input access + Int8OutTile int8_im2col_tile; + int8_im2col_tile.data[0][0] = load_im2col_block( + im2col_w, im2col_h, k_in_group, group_idx); + + // Store to output (4W4C format) + store_packed_int8_output_tile( + int8_im2col_tile, im2col_block_idx, im2col_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml new file mode 100644 index 00000000000..08ce5d59d35 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_im2col: + parameter_names_with_default_values: + DTYPE: float + shader_variants: + - NAME: q8ta_im2col diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl new file mode 100644 index 00000000000..9c5e0ee9066 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define PACKED_INT8_OUTPUT_BUFFER + +layout(std430) buffer; + +#include "indexing.glslh" +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)} + +// Metadata for im2col output and input tensors (layout-agnostic) +${layout_declare_ubo(B, "BufferMetadata", "im2col_outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +// Layout specialization constants +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} + +layout(push_constant) uniform restrict Block { + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); + + // Extract sizes from BufferMetadata + const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]); + const ivec4 input_sizes = ivec4(inp.sizes[0]); + + // im2col block extents + const int im2col_W4 = div_up_4(im2col_sizes.x); + const int im2col_H = im2col_sizes.y; + const int im2col_Z4 = div_up_4(im2col_sizes.z); + + // im2col block index from linear output buffer index + const int c4_idx = out_buf_idx % im2col_Z4; + const int row = out_buf_idx / im2col_Z4; + const int w4_idx = row % im2col_W4; + const int h_idx = row / im2col_W4; + + // out of bounds check + if (w4_idx >= im2col_W4 || h_idx >= im2col_H || c4_idx >= im2col_Z4) { + return; + } + + const int im2col_w = mul_4(w4_idx); + const int im2col_h = h_idx; + const int im2col_k = mul_4(c4_idx); + + const int group_idx = im2col_k / conv2d_params.K_per_group; + const int k_in_group = im2col_k % conv2d_params.K_per_group; + + const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; + const int krow = k_in_group / conv2d_params.in_channels_per_group; + const int kernel_x = krow % conv2d_params.kernel_size.x; + const int kernel_y = krow / conv2d_params.kernel_size.x; + + // Base input position + const int input_x_base = + (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x + + (kernel_x * conv2d_params.dilation.x); + const int input_y = + (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y + + (kernel_y * conv2d_params.dilation.y); + const int input_z = + group_idx * conv2d_params.in_channels_per_group + c_in_group; + + // Input tensor extents + const int input_W = input_sizes.x; + const int input_H = input_sizes.y; + const int input_Z4 = div_up_4(input_sizes.z); + + const int zp_packed = pack_into_int32(ivec4(zp)); + const int z4 = div_4(input_z); + + // Check if y and z are in bounds (constant for all 4 width elements) + const bool y_z_in_bounds = + (input_y >= 0 && input_y < input_H && z4 >= 0 && z4 < input_Z4); + + // Load 4 elements from input, one for each output width position. + // Each loaded int contains 4 packed int8 channel values. + ivec4 im2col_block; + for (int i = 0; i < 4; i++) { + const int x = input_x_base + i; + if (!y_z_in_bounds || x < 0 || x >= input_W) { + im2col_block[i] = zp_packed; + } else { + const int x4 = div_4(x); + const int x_mod = mod_4(x); + int scalar_idx; + if (get_outer_packed_dim_block_size(inp_layout) == 1) { + scalar_idx = input_y * int(inp.strides[0][1]) + + x * int(inp.strides[0][0]) + + z4 * int(inp.strides[0][2]); + } else { + scalar_idx = mul_4( + input_y * int(inp.strides[0][1]) + + x4 * int(inp.strides[0][0]) + + z4) + x_mod; + } + im2col_block[i] = t_packed_int8_input[scalar_idx]; + } + } + + // store_packed_int8_output_tile (with TILE_M4=1, TILE_N4=1) + const int buffer_idx = h_idx * int(im2col_outp.strides[0][1]) + + w4_idx * int(im2col_outp.strides[0][0]) + + c4_idx; + + if (w4_idx < im2col_W4 && c4_idx < im2col_Z4) { + t_packed_int8_output[buffer_idx] = im2col_block; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml new file mode 100644 index 00000000000..0de3d97f324 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_im2col_4w4c: + parameter_names_with_default_values: + DTYPE: float + shader_variants: + - NAME: q8ta_im2col_4w4c diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 2c22f35ad13..53a5aa15fe6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -99,4 +99,18 @@ void add_q8ta_conv2d_node( const ValueRef groups, const ValueRef packed_int8_output); +void add_q8ta_conv2d_pw_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef packed_int8_output); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp new file mode 100644 index 00000000000..4bbcc16e43d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp @@ -0,0 +1,275 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +namespace vkcompute { + +// +// Shader dispatch utilities +// + +utils::uvec3 pick_q8ta_im2col_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef im2col_output = args.at(0).refs.at(0); + + std::vector im2col_sizes = graph->sizes_of(im2col_output); + const uint32_t K = utils::safe_downcast(im2col_sizes[0]); + const uint32_t H = utils::safe_downcast(im2col_sizes[1]); + const uint32_t W = utils::safe_downcast(im2col_sizes[2]); + + const uint32_t K4 = utils::div_up_4(K); + const uint32_t W4 = utils::div_up_4(W); + + // Each thread handles one 4x4 block in the output + return {K4 * W4 * H, 1, 1}; +} + +utils::uvec3 pick_q8ta_im2col_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)args; + (void)resize_args; + (void)global_workgroup_size; + + return {64, 1, 1}; +} + +// +// Im2col calculation utilities +// + +std::vector calculate_q8ta_im2col_sizes( + ComputeGraph* graph, + const ValueRef& input, + const ValueRef& output, + const ValueRef& kernel_size, + const ValueRef& groups) { + std::vector in_sizes = graph->sizes_of(input); + const int64_t in_channels = utils::val_at(-3, in_sizes); + + std::vector out_sizes = graph->sizes_of(output); + const int64_t out_height = utils::val_at(-2, out_sizes); + const int64_t out_width = utils::val_at(-1, out_sizes); + + const int64_t groups_val = graph->extract_scalar(groups); + const int64_t in_channels_per_group = in_channels / groups_val; + + const auto kernel_size_list = graph->get_int_list(kernel_size); + + // Align to next multiple of 4 to ensure data loads align nicely with + // texel boundaries + const int64_t flattened_kernel_len = utils::align_up_4( + in_channels_per_group * kernel_size_list->at(0) * + kernel_size_list->at(1)); + + // K -> flattened convolution window (repeated for each group) + const int64_t K = flattened_kernel_len * groups_val; + // M -> number of elements in 2D output plane + const int64_t W = utils::align_up_4(out_width); + const int64_t H = out_height; + + return {K, H, W}; +} + +// +// Dispatch nodes +// + +void add_q8ta_im2col_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef packed_int8_output, + const ValueRef packed_int8_im2col, + const int32_t zp) { + // Validate packed dim info for input and output tensors + VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_input))); + // The the output tensor must be in 4W4C layout + VK_CHECK_COND(q8ta_conv2d_check_4w4c_packed_dim_info( + graph.packed_dim_info_of(packed_int8_im2col))); + + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); + + // At the moment, the im2col path only supports non-grouped convolutions + VK_CHECK_COND(conv_params.groups == 1); + // The implementation also requires that input channels is a multiple of 4 + VK_CHECK_COND(conv_params.in_channels_per_group % 4 == 0); + + std::string kernel_name = "q8ta_im2col_4w4c"; + + vkapi::ParamsBindList param_buffers = { + graph.buffer_meta_ubo(packed_int8_im2col), + graph.buffer_meta_ubo(packed_int8_input), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + // Build spec constants: apply_bias + layout constants (for generic shader) + vkapi::SpecVarList spec_constants = { + 1u, + graph.hashed_layout_of(packed_int8_im2col), + graph.hashed_layout_of(packed_int8_input), + }; + + // // Add layout specialization constants (only for generic shader) + // if (!use_4w4c_path) { + // spec_constants.append(graph.hashed_layout_of(packed_int8_input)); + // spec_constants.append(graph.hashed_layout_of(packed_int8_im2col)); + // } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_q8ta_im2col_global_wg_size, + pick_q8ta_im2col_local_wg_size, + // Inputs and Outputs + {{packed_int8_im2col, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + spec_constants, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void q8ta_conv2d_im2col( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + QuantizationConfig weight_quant_config(8, kPerChannel, {}); + + // Prepack weight using linear weight packing (for im2col approach) + ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + + ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // Create dummy tensor to fill bias binding slot if not provided + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(weight_scales_data), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + // Calculate im2col output sizes + std::vector im2col_sizes = calculate_q8ta_im2col_sizes( + &graph, packed_int8_input, packed_int8_output, kernel_size, groups); + + // Create temporary tensor for im2col output (4W4C layout) + TmpTensor packed_int8_im2col( + &graph, + im2col_sizes, + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + int32_t zp = graph.extract_scalar(input_zp); + + // Step 1: Perform im2col transformation + add_q8ta_im2col_node( + graph, + packed_int8_input, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output, + packed_int8_im2col, + zp); + + // Step 2: Perform pointwise convolution on the im2col result + add_q8ta_conv2d_pw_node( + graph, + packed_int8_im2col, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(etvk.q8ta_conv2d_im2col.default, q8ta_conv2d_im2col); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp index 13e9c5b5b67..acb5a3d03f5 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp @@ -154,6 +154,9 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { if (impl_selector == "legacy_4w4c") { // Use the general quantized conv2d operator for legacy path VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args); + } else if (impl_selector == "im2col") { + // Use the im2col-based conv2d operator + VK_GET_OP_FN("etvk.q8ta_conv2d_im2col.default")(graph, conv_args); } else { // Use the new general q8ta_conv2d operator VK_GET_OP_FN("etvk.q8ta_conv2d.default")(graph, conv_args); diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index 5d5ad356122..8f445ab7230 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp @@ -235,6 +235,17 @@ std::vector generate_quantized_conv2d_easy_cases() { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, fp_storage_type, int8_memory_layout)); + // Test im2col implementation for non-grouped convolutions with input + // channels that are a multiple of 4 and stride_w == 1 + if (config.groups == 1 && config.channels.in % 4 == 0 && + config.stride.w == 1) { + test_cases.push_back(create_test_case_from_config( + config, + vkapi::kFloat, + fp_storage_type, + int8_memory_layout, + /*impl_selector=*/"im2col")); + } // For 4W4C layout, also test the legacy implementation if (int8_memory_layout == utils::kPackedInt8_4W4C) { test_cases.push_back(create_test_case_from_config( @@ -403,6 +414,18 @@ static std::vector generate_quantized_conv2d_test_cases() { int8_memory_layout, /*impl_selector=*/"legacy_4w4c")); } + + // Test im2col implementation for non-grouped convolutions with input + // channels that are a multiple of 4 and stride_w == 1 + if (config.groups == 1 && config.channels.in % 4 == 0 && + config.stride.w == 1) { + test_cases.push_back(create_test_case_from_config( + config, + vkapi::kFloat, + fp_storage_type, + int8_memory_layout, + /*impl_selector=*/"im2col")); + } } } }