diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
new file mode 100644
index 00000000000..e0963dfcf48
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
@@ -0,0 +1,252 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+${define_required_extensions("buffer", DTYPE)}
+
+#extension GL_EXT_control_flow_attributes : require
+#extension GL_EXT_integer_dot_product : require
+
+#define PRECISION ${PRECISION}
+#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
+#define T ${texel_load_component_type(DTYPE, "buffer")}
+
+${define_active_storage_type("buffer")}
+
+// corresponds to input/output width dim
+#define TILE_M4 1
+// corresponds to input channels dim
+#define TILE_K4 1
+// corresponds to output channels dim
+#define TILE_N4 2
+
+#define TILE_M 4
+#define TILE_K 4
+#define TILE_N 8
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+#include "common.glslh"
+#include "block_indexing.glslh"
+
+${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)}
+${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", "texture2d", is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}
+
+// Metadata for input/output tensors
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+
+layout(push_constant) uniform restrict Block {
+  float input_scale;
+  int input_zp;
+  float output_inv_scale;
+  int output_zp;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "apply_bias", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")}
+
+// Layout specialization constants
+${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
+${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
+
+int compute_outp_buffer_idx(
+    const int w_block_idx,
+    const int h_idx,
+    const int c_block_idx) {
+  if (get_outer_packed_dim_block_size(outp_layout) == 1) {
+    return h_idx * int(outp.strides[0][1]) +
+        mul_4(w_block_idx) * int(outp.strides[0][0]) +
+        c_block_idx * int(outp.strides[0][2]);
+  } else {
+    return mul_4(
+        h_idx * int(outp.strides[0][1]) +
+        w_block_idx * int(outp.strides[0][0]) +
+        c_block_idx * int(outp.strides[0][2]));
+  }
+}
+
+void main() {
+  // Thread mapping: each thread handles TILE_M (4) widths × TILE_N (8) output channels
+  // gl_GlobalInvocationID.x → output channel blocks (TILE_N4 = 2 blocks of 4 channels)
+  // gl_GlobalInvocationID.y → width blocks (TILE_M4 = 1 block of 4 widths)
+  // gl_GlobalInvocationID.z → batch (or height * batch combined)
+  const int oc_block_idx = int(gl_GlobalInvocationID.x) * TILE_N4;
+  const int ow_block_idx = int(gl_GlobalInvocationID.y) * TILE_M4;
+  const int oh = int(gl_GlobalInvocationID.z);
+
+  // Get output extents (div_up_4 gives block counts for the packed dimensions)
+  const int W = int(outp.sizes[0][0]);
+  const int W4 = div_up_4(int(outp.sizes[0][0]));
+  const int H = int(outp.sizes[0][1]);
+  const int OC4 = div_up_4(int(outp.sizes[0][2]));
+
+  // Bounds check in block space
+  if (ow_block_idx >= W4 ||
+      oh >= H ||
+      oc_block_idx >= OC4) {
+    return;
+  }
+
+  // Get input extents in block space
+  const int inp_W4 = div_up_4(int(inp.sizes[0][0]));
+  const int inp_IC4 = div_up_4(int(inp.sizes[0][2]));
+
+  // Precompute stride products for indexing
+  // For 4W4C layout: buffer_idx = batch * (W4 * C4) + w4 * C4 + c4
+  const int inp_w_stride = int(inp.strides[0][0]);
+  const int inp_h_stride = int(inp.strides[0][1]);
+  const int inp_c_stride = int(inp.strides[0][2]);
+
+  // Initialize int32 accumulator
+  ivec4 out_accum[TILE_M][TILE_N4];
+  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
+    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+      out_accum[m][n4] = ivec4(0);
+    }
+  }
+
+  // Compute initial input tile index
+  // Input has same spatial layout, channel dimension iterates from 0
+  int input_idx = oh * inp_h_stride + ow_block_idx * inp_w_stride;
+
+  // Main accumulation loop over K dimension
+  for (int k4 = 0; k4 < conv2d_params_K4_per_group; k4++) {
+    // Load packed int8 input tile (TILE_M4=1, TILE_K4=1)
+    // Each int packs 4 int8 channel values; the ivec4 holds one int per width
+    // position in the tile
+    ivec4 int8_input_tile = t_packed_int8_input[input_idx];
+
+    // Load int8 weight tile (TILE_K4=1, TILE_N4=2)
+    ivec4 int8_weight_tile[TILE_N4];
+    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+      int8_weight_tile[n4] = texelFetch(
+          t_packed_int8_weight,
+          ivec2(oc_block_idx + n4, k4),
+          0);
+    }
+
+    // Accumulate using int8 dot product
+    // Input tile indexed as input[m] where m is the width index within tile
+    // Weight tile indexed as weight[n4][n4i] where n4i is the channel index within block
+    [[unroll]] for (int m = 0; m < TILE_M; ++m) {
+      [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+        [[unroll]] for (int n4i = 0; n4i < 4; ++n4i) {
+          out_accum[m][n4][n4i] = dotPacked4x8AccSatEXT(
+              int8_input_tile[m],
+              int8_weight_tile[n4][n4i],
+              out_accum[m][n4][n4i]);
+        }
+      }
+    }
+
+    input_idx++;
+  }
+
+  // Load weight scales tile
+  VEC4_T weight_scales[TILE_N4];
+  [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+    weight_scales[n4] = t_weight_scales[oc_block_idx + n4];
+  }
+
+  // Load weight sums tile
+  ivec4 weight_sums[TILE_N4];
+  [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+    weight_sums[n4] = ivec4(t_weight_sums[oc_block_idx + n4]);
+  }
+
+  // Initialize int8 output tile
+  ivec4 int8_out_tile[TILE_M4][TILE_N4];
+  [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
+    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+      int8_out_tile[m4][n4] = ivec4(0);
+    }
+  }
+
+  // Compute int8 output tile from int32 accumulator
+  ivec4 input_zp_vec = ivec4(-input_zp);
+
+  if (apply_bias > 0) {
+    // Load bias tile
+    VEC4_T bias[TILE_N4];
+    [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+      bias[n4] = t_bias[oc_block_idx + n4];
+    }
+
+    [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
+      [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) {
+        [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+          const int m = mul_4(m4) + m4i;
+          // Compute floating point output values
+          ivec4 accum_adjusted =
+              input_zp_vec * weight_sums[n4] + out_accum[m][n4];
+          vec4 float_out_texel =
+              fma(vec4(accum_adjusted),
+                  vec4(weight_scales[n4]) * input_scale,
+                  vec4(bias[n4]));
+          // Requantize to int8
+          float_out_texel =
+              round(float_out_texel * output_inv_scale) + output_zp;
+          ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127);
+
+          int8_out_tile[m4][n4][m4i] = pack_into_int32(quantized_out_texel);
+        }
+      }
+    }
+  } else {
+    [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
+      [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) {
+        [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
+          const int m = mul_4(m4) + m4i;
+          // Compute floating point output values
+          ivec4 accum_adjusted =
+              input_zp_vec * weight_sums[n4] + out_accum[m][n4];
+          vec4 float_out_texel =
+              vec4(accum_adjusted) * vec4(weight_scales[n4] * input_scale);
+          // Requantize to int8
+          float_out_texel =
+              round(float_out_texel * output_inv_scale) + output_zp;
+          ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127);
+
+          int8_out_tile[m4][n4][m4i] = pack_into_int32(quantized_out_texel);
+        }
+      }
+    }
+  }
+
+  const int outp_w_stride = int(outp.strides[0][0]);
+
+  // Store packed int8 output tile
+  [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) {
+    [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) {
+      const int base_outp_buffer_idx = compute_outp_buffer_idx(
+          ow_block_idx + m4,
+          oh,
+          oc_block_idx + n4);
+      if (oc_block_idx + n4 < OC4) {
+        // Store individual ints from the ivec4
+        const int subtile_w_limit = min(4, W - mul_4(ow_block_idx + m4));
+        [[unroll]] for (int subtile_w = 0; subtile_w < subtile_w_limit; ++subtile_w) {
+          if (get_outer_packed_dim_block_size(outp_layout) == 1) {
+            t_packed_int8_output[base_outp_buffer_idx + subtile_w * outp_w_stride] = int8_out_tile[m4][n4][subtile_w];
+          } else {
+            t_packed_int8_output[base_outp_buffer_idx + subtile_w] = int8_out_tile[m4][n4][subtile_w];
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml
new file mode 100644
index 00000000000..b7b8c42bf14
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+q8ta_conv2d_pw:
+  parameter_names_with_default_values:
+    DTYPE: float
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: float
+  shader_variants:
+    - NAME: q8ta_conv2d_pw
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
index 2ee57551235..aa4e1e47d27 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
@@ -25,6 +25,13 @@ bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info) {
       info.outer_packed_dim_block_size == 4);
 }
 
+bool q8ta_conv2d_check_4w4c_packed_dim_info(const api::PackedDimInfo& info) {
+  return info.packed_dim == WHCN::kChannelsDim &&
+      info.packed_dim_block_size == 4 &&
+      info.outer_packed_dim == WHCN::kWidthDim &&
+      info.outer_packed_dim_block_size == 4;
+}
+
 //
 // Workgroup size selection functions
 //
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
index 2c66537acc5..2c22f35ad13 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
@@ -15,6 +15,8 @@ namespace vkcompute {
 
 bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info);
 
+bool q8ta_conv2d_check_4w4c_packed_dim_info(const api::PackedDimInfo& info);
+
 ValueRef prepack_quantized_conv2d_weight(
     ComputeGraph& graph,
     const QuantizationConfig& weight_quant_config,
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
new file mode 100644
index 00000000000..5ff69dac63b
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
@@ -0,0 +1,352 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+namespace vkcompute {
+
+//
+// Shader dispatch utilities
+//
+
+utils::uvec3 pick_q8ta_conv2d_pw_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  (void)resize_args;
+
+  const ValueRef output = args.at(0).refs.at(0);
+
+  const uint32_t W = graph->size_at<uint32_t>(-1, output);
+  const uint32_t H = graph->size_at<uint32_t>(-2, output);
+  const uint32_t C = graph->size_at<uint32_t>(-3, output);
+
+  // The 4W4C shader processes tiles of:
+  // - TILE_N4=2 groups of 4 output channels (8 channels per thread)
+  // - TILE_M4=1 group of 4 widths (4 widths per thread)
+  // - 1 height per thread
+  constexpr uint32_t TILE_N4 = 2;
+  constexpr uint32_t TILE_M4 = 1;
+
+  const uint32_t C4 = utils::div_up_4(C);
+  const uint32_t W4 = utils::div_up_4(W);
+
+  // Global workgroup size:
+  // x = output channels / (TILE_N4 * 4) = C4 / TILE_N4
+  // y = width / (TILE_M4 * 4) = W4 / TILE_M4
+  // z = height
+  return {utils::div_up(C4, TILE_N4), utils::div_up(W4, TILE_M4), H};
+}
+
+utils::uvec3 pick_q8ta_conv2d_pw_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  return pick_hw_square_wg_size(
+      graph, shader, global_workgroup_size, args, resize_args);
+}
+
+//
+// 4W4C shader dispatch utilities
+//
+
+utils::uvec3 pick_q8ta_conv2d_pw_4w4c_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  (void)resize_args;
+
+  const ValueRef output = args.at(0).refs.at(0);
+
+  const uint32_t W = graph->size_at<uint32_t>(-1, output);
+  const uint32_t H = graph->size_at<uint32_t>(-2, output);
+  const uint32_t C = graph->size_at<uint32_t>(-3, output);
+
+  // The 4W4C shader processes tiles of:
+  // - TILE_N4=2 groups of 4 output channels (8 channels per thread)
+  // - TILE_M4=1 group of 4 widths (4 widths per thread)
+  // - 1 height per thread
+  constexpr uint32_t TILE_N4 = 2;
+  constexpr uint32_t TILE_M4 = 1;
+
+  const uint32_t C4 = utils::div_up_4(C);
+  const uint32_t W4 = utils::div_up_4(W);
+
+  // Global workgroup size:
+  // x = output channels / (TILE_N4 * 4) = C4 / TILE_N4
+  // y = width / (TILE_M4 * 4) = W4 / TILE_M4
+  // z = height
+  return {utils::div_up(C4, TILE_N4), utils::div_up(W4, TILE_M4), H};
+}
+
+utils::uvec3 pick_q8ta_conv2d_pw_4w4c_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  return pick_hw_square_wg_size(
+      graph, shader, global_workgroup_size, args, resize_args);
+}
+
+//
+// Prepack nodes
+//
+
+ValueRef prepack_quantized_conv2d_pw_weight(
+    ComputeGraph& graph,
+    const QuantizationConfig& weight_quant_config,
+    const ValueRef weight_data,
+    const ValueRef input,
+    const ValueRef output) {
+  VK_CHECK_COND(weight_quant_config.nbits == 8);
+  VK_CHECK_COND(weight_quant_config.is_symmetric);
+
+  const int64_t OC = graph.size_at<int64_t>(-3, output);
+  const int64_t IC = graph.size_at<int64_t>(-3, input);
+
+  // For pointwise convolution, kernel_size = 1x1
+  const int64_t K_h = 1;
+  const int64_t K_w = 1;
+
+  const int64_t num_blocks_OC = utils::div_up_4(OC);
+  const int64_t num_blocks_IC = utils::div_up_4(IC);
+
+  const int64_t num_blocks_y = num_blocks_IC * K_h;
+  const int64_t num_blocks_x = K_w * num_blocks_OC;
+
+  // The packed weight tensor arranges blocks as [IC_blocks * K_h, K_w * OC_blocks]
+  // (height, width), matching the texelFetch(x = OC block, y = K block) access
+  // pattern in the shader
+  const int64_t output_height = num_blocks_y;
+  const int64_t output_width = num_blocks_x * 4;
+
+  // Store the original sizes of the weight data to pass to the shader
+  utils::ivec4 orig_sizes = {
+      utils::safe_downcast<int32_t>(OC),
+      utils::safe_downcast<int32_t>(K_h),
+      utils::safe_downcast<int32_t>(K_w),
+      utils::safe_downcast<int32_t>(IC)};
+
+  std::vector<int64_t> packed_weight_sizes{output_height, output_width};
+
+  utils::StorageType storage_type = utils::kTexture2D;
+  uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim();
+  if (output_width > max_extent * 4 || output_height > max_extent) {
+    storage_type = utils::kBuffer;
+  }
+
+  ValueRef packed_weight = graph.add_tensor(
+      packed_weight_sizes,
+      vkcompute::vkapi::kInt,
+      storage_type,
+      utils::kWidthPacked);
+
+  utils::uvec3 global_wg_size = {
+      utils::safe_downcast<uint32_t>(num_blocks_x),
+      utils::safe_downcast<uint32_t>(num_blocks_y),
+      1u};
+
+  std::string kernel_name = "pack_q8_conv2d_weights";
+  add_storage_type_suffix(kernel_name, storage_type);
+
+  graph.prepack_nodes().emplace_back(new PrepackNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_wg_size,
+      graph.create_local_wg_size(global_wg_size),
+      // Inputs and Outputs
+      weight_data,
+      packed_weight,
+      // UBOs
+      {},
+      // Specialization Constants
+      {},
+      // Push Constants
+      {graph.sizes_pc_of(packed_weight),
+       PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec4))}));
+
+  return packed_weight;
+}
+
+//
+// Dispatch nodes
+//
+
+void add_q8ta_conv2d_pw_node(
+    ComputeGraph& graph,
+    const ValueRef packed_int8_input,
+    const ValueRef input_scale,
+    const ValueRef input_zp,
+    const ValueRef packed_weight,
+    const ValueRef packed_weight_sums,
+    const ValueRef packed_weight_scales,
+    const ValueRef output_scale,
+    const ValueRef output_zp,
+    const ValueRef bias_data,
+    const ValueRef packed_bias,
+    const ValueRef packed_int8_output) {
+  // Validate packed dim info for input and output tensors
+  // To maximize performance, the input tensor must be in 4W4C layout
+  VK_CHECK_COND(q8ta_conv2d_check_4w4c_packed_dim_info(
+      graph.packed_dim_info_of(packed_int8_input)));
+  // However, the requirements for the output tensor layout are flexible
+  VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info(
+      graph.packed_dim_info_of(packed_int8_output)));
+
+  // Validate dtype is kInt8x4
+  VK_CHECK_COND(graph.dtype_of(packed_int8_input) == vkapi::kInt8x4);
+  VK_CHECK_COND(graph.dtype_of(packed_int8_output) == vkapi::kInt8x4);
+
+  float input_scale_val = graph.extract_scalar<float>(input_scale);
+  int32_t input_zp_val = graph.extract_scalar<int32_t>(input_zp);
+
+  float output_inv_scale_val = 1.0f / graph.extract_scalar<float>(output_scale);
+  int32_t output_zp_val = graph.extract_scalar<int32_t>(output_zp);
+
+  uint32_t apply_bias = 1;
+  if (graph.val_is_none(bias_data)) {
+    apply_bias = 0;
+  }
+
+  // Get input channel count for K4_per_group
+  const uint32_t IC = graph.size_at<uint32_t>(-3, packed_int8_input);
+  const uint32_t K4_per_group = utils::div_up_4(IC);
+
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)),
+      PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)),
+      PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)),
+      PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)),
+  };
+
+  std::string kernel_name = "q8ta_conv2d_pw";
+  add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales));
+
+  // Pass metadata for both output and input tensors
+  vkapi::ParamsBindList param_buffers = {
+      graph.buffer_meta_ubo(packed_int8_output),
+      graph.buffer_meta_ubo(packed_int8_input)};
+
+  // Build spec constants: apply_bias, K4_per_group, and the layout constants
+  vkapi::SpecVarList spec_constants = {
+      apply_bias,
+      K4_per_group,
+      // Layout specialization constants
+      graph.hashed_layout_of(packed_int8_output),
+      graph.hashed_layout_of(packed_int8_input),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      pick_q8ta_conv2d_pw_global_wg_size,
+      pick_q8ta_conv2d_pw_local_wg_size,
+      // Inputs and Outputs
+      {{packed_int8_output, vkapi::kWrite},
+       {{packed_int8_input,
+         packed_weight,
+         packed_weight_sums,
+         packed_weight_scales,
+         packed_bias},
+        vkapi::kRead}},
+      // Shader params buffers
+      param_buffers,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      spec_constants,
+      // Resize args
+      {}));
+}
+
+//
+// High level operator impl
+//
+
+void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef packed_int8_input = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zp = args.at(idx++);
+  const ValueRef weight_data = args.at(idx++);
+  const ValueRef weight_sums_data = args.at(idx++);
+  const ValueRef weight_scales_data = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef bias_data = args.at(idx++);
+  // Accept but ignore conv params - pointwise has fixed kernel=1x1, stride=1,
+  // padding=0, dilation=1, groups=1
+  (void)args.at(idx++); // kernel_size
+  (void)args.at(idx++); // stride
+  (void)args.at(idx++); // padding
+  (void)args.at(idx++); // dilation
+  (void)args.at(idx++); // groups
+  const ValueRef packed_int8_output = args.at(idx++);
+
+  QuantizationConfig weight_quant_config(8, kPerChannel, {});
+
+  // Prepack weight using pointwise-specific packing
+  ValueRef packed_weight = prepack_quantized_conv2d_pw_weight(
+      graph,
+      weight_quant_config,
+      weight_data,
+      packed_int8_input,
+      packed_int8_output);
+
+  ValueRef packed_weight_sums = prepack_standard(
+      graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked);
+
+  ValueRef packed_weight_scales = prepack_standard(
+      graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked);
+
+  // Create a dummy tensor to fill the binding slot of the bias tensor if it is
+  // not provided. This helps simplify dispatch logic and makes it so that
+  // fewer shader variants need to be generated.
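+  // When no bias is provided, add_q8ta_conv2d_pw_node dispatches the shader
+  // with apply_bias = 0, so this placeholder is bound but never read.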
+  TmpTensor dummy_bias(
+      &graph,
+      {},
+      graph.dtype_of(weight_scales_data),
+      utils::kBuffer,
+      utils::kWidthPacked);
+
+  ValueRef packed_bias = dummy_bias.vref;
+  if (graph.val_is_not_none(bias_data)) {
+    packed_bias =
+        prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked);
+  }
+
+  add_q8ta_conv2d_pw_node(
+      graph,
+      packed_int8_input,
+      input_scale,
+      input_zp,
+      packed_weight,
+      packed_weight_sums,
+      packed_weight_scales,
+      output_scale,
+      output_zp,
+      bias_data,
+      packed_bias,
+      packed_int8_output);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(etvk.q8ta_conv2d_pw.default, q8ta_conv2d_pw);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
index fca82ef3eee..13e9c5b5b67 100644
--- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
+++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
@@ -9,6 +9,7 @@
 #include 
 #include 
+#include 
 #include 
 
 namespace vkcompute {
@@ -163,9 +164,88 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       graph, packed_int8_output, output_scale, output_zp, fp_output);
 }
 
+void test_q8ta_conv2d_pw(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef fp_input = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zp = args.at(idx++);
+  const ValueRef weight_data = args.at(idx++);
+  const ValueRef weight_sums_data = args.at(idx++);
+  const ValueRef weight_scales_data = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef bias_data = args.at(idx++);
+  const ValueRef kernel_size = args.at(idx++);
+  const ValueRef stride = args.at(idx++);
+  const ValueRef padding = args.at(idx++);
+  const ValueRef dilation = args.at(idx++);
+  const ValueRef groups = args.at(idx++);
+  const ValueRef layout_int = args.at(idx++);
+  const ValueRef impl_selector_str = args.at(idx++);
+  const ValueRef fp_output = args.at(idx++);
+
+  // Extract the layout parameter and cast to GPUMemoryLayout
+  int32_t layout_value = graph.extract_scalar<int32_t>(layout_int);
+  utils::GPUMemoryLayout layout =
+      static_cast<utils::GPUMemoryLayout>(layout_value);
+
+  // Extract the impl_selector string
+  std::string impl_selector = graph.extract_string(impl_selector_str);
+
+  // Create temporary packed int8 tensors for input and output
+  TmpTensor packed_int8_input(
+      &graph,
+      graph.sizes_of(fp_input),
+      vkapi::kInt8x4,
+      utils::kBuffer,
+      utils::kPackedInt8_4W4C);
+
+  TmpTensor packed_int8_output(
+      &graph,
+      graph.sizes_of(fp_output),
+      vkapi::kInt8x4,
+      utils::kBuffer,
+      layout);
+
+  // Quantize floating point input to packed int8
+  add_q8ta_quantize_node(
+      graph, fp_input, input_scale, input_zp, packed_int8_input);
+
+  // Build args for conv operator
+  std::vector<ValueRef> conv_args = {
+      packed_int8_input,
+      input_scale,
+      input_zp,
+      weight_data,
+      weight_sums_data,
+      weight_scales_data,
+      output_scale,
+      output_zp,
+      bias_data,
+      kernel_size,
+      stride,
+      padding,
+      dilation,
+      groups,
+      packed_int8_output};
+
+  if (impl_selector == "legacy_4w4c") {
+    VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
+  } else {
+    VK_GET_OP_FN("etvk.q8ta_conv2d_pw.default")(graph, conv_args);
+  }
+
+  // Dequantize packed int8 output to floating point
+  add_q8ta_dequantize_node(
+      graph, packed_int8_output, output_scale, output_zp, fp_output);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(test_etvk.test_q8ta_conv2d_dw.default, test_q8ta_conv2d_dw);
   VK_REGISTER_OP(test_etvk.test_q8ta_conv2d.default, test_q8ta_conv2d);
+  VK_REGISTER_OP(test_etvk.test_q8ta_conv2d_pw.default, test_q8ta_conv2d_pw);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
index 1cd6151071d..6c7b2d94ecd 100644
--- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
+++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
@@ -60,8 +60,9 @@ static TestCase create_test_case_from_config(
   }
   test_case.set_name(test_name);
 
-  // Set the operator name for the test case - use the unified test operator
-  std::string operator_name = "test_etvk.test_q8ta_conv2d.default";
+  // Set the operator name for the test case - use the dedicated pointwise test
+  // operator
+  std::string operator_name = "test_etvk.test_q8ta_conv2d_pw.default";
   test_case.set_operator_name(operator_name);
 
   ValueSpec input_tensor(
@@ -163,7 +164,7 @@ static TestCase create_test_case_from_config(
       fp_memory_layout,
       DataGenType::ZEROS);
 
-  // Add all specs to test case for q8ta_q8csw_q8to operation
+  // Add all specs to test case for q8ta_conv2d_pw operation
   test_case.add_input_spec(input_tensor);
   test_case.add_input_spec(input_scale);
   test_case.add_input_spec(input_zero_point);
@@ -275,6 +276,13 @@ static std::vector generate_quantized_conv2d_pw_test_cases() {
        Padding(0, 0),
        Dilation(1, 1),
        1},
+      {OutInChannels(64, 576),
+       InputSize2D(128, 128),
+       KernelSize(1, 1),
+       Stride(1, 1),
+       Padding(0, 0),
+       Dilation(1, 1),
+       1},
   };
 
   // Test with different storage types and memory layouts
@@ -521,7 +529,7 @@ static void reference_impl(TestCase& test_case) {
   conv2d_q8ta_q8csw_q8to_reference_impl(test_case);
 }
 
-// Custom FLOP calculator for quantized conv2d operation
+// Custom FLOP calculator for quantized pointwise conv2d operation
 static int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) {
   int kernel_idx = 9; // kernel_size is at index 9 for q8ta_q8csw_q8to
 
@@ -582,8 +590,8 @@ int main(int argc, char* argv[]) {
       0,
       1,
 #else
-      3,
-      10,
+      5,
+      25,
 #endif
       ref_fn);
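
For review purposes, the per-element arithmetic implemented by q8ta_conv2d_pw.glsl reduces to the scalar sketch below. It is illustrative only and not part of the patch; the function name, containers, and includes are assumptions made for a self-contained example.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Scalar model of one output element of the 1x1 quantized conv:
    // int8 dot product -> input zero-point correction -> rescale -> requantize.
    int8_t q8ta_pw_reference_element(
        const std::vector<int8_t>& q_x, // input channels at one (h, w) position
        const std::vector<int8_t>& q_w, // weights for one output channel
        float input_scale,
        int32_t input_zp,
        float weight_scale, // per-output-channel scale
        float bias,
        float output_scale,
        int32_t output_zp) {
      int32_t acc = 0;
      int32_t weight_sum = 0; // corresponds to t_weight_sums in the shader
      for (size_t ic = 0; ic < q_x.size(); ++ic) {
        acc += int32_t(q_w[ic]) * int32_t(q_x[ic]);
        weight_sum += int32_t(q_w[ic]);
      }
      // sum_k w * (x - zp_x) = sum_k w * x - zp_x * sum_k w
      const int32_t adjusted = acc - input_zp * weight_sum;
      const float real = float(adjusted) * weight_scale * input_scale + bias;
      // Requantize with the output scale / zero-point and clamp to int8 range.
      const int32_t q_out =
          int32_t(std::lround(real / output_scale)) + output_zp;
      return int8_t(std::min(std::max(q_out, -128), 127));
    }

This mirrors the accum_adjusted / weight_sums / output_inv_scale sequence in the shader, which performs the same computation on 4-wide vectors using dotPacked4x8AccSatEXT.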