diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl index 0b2cd7fef5a..8dd1ac1bdbe 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl @@ -30,7 +30,7 @@ layout(std430) buffer; #include "common.glslh" ${layout_declare_tensor(B, "w", "t_scales", DTYPE, "texture3d")} -${layout_declare_tensor(B, "w", "t_zps", "int8", "texture3d")} +${layout_declare_tensor(B, "w", "t_zps", "int8" if ZP_DTYPE_MODE == "zpint8" else DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "t_input", DTYPE, STORAGE, is_scalar_array=False)} ${layout_declare_ubo(B, "ivec4", "input_sizes")} @@ -196,7 +196,10 @@ void main() { if (worker_id == 0) { imageStore(t_scales, ivec3(output_y4, 0, 0), scales_out); - imageStore(t_zps, ivec3(output_y4, 0, 0), zps_out); + $if ZP_DTYPE_MODE == "zpint8": + imageStore(t_zps, ivec3(output_y4, 0, 0), zps_out); + $else: + imageStore(t_zps, ivec3(output_y4, 0, 0), VEC4_T(zps_out)); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml index 5dbf3d7adaa..f90ce6a8394 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml @@ -8,6 +8,7 @@ choose_qparams_per_row: parameter_names_with_default_values: DTYPE: float STORAGE: texture3d + ZP_DTYPE_MODE: zpint8 generate_variant_forall: STORAGE: - VALUE: texture3d @@ -15,5 +16,8 @@ choose_qparams_per_row: DTYPE: - VALUE: float - VALUE: half + ZP_DTYPE_MODE: + - VALUE: zpint8 + - VALUE: zpinherit shader_variants: - NAME: choose_qparams_per_row diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl index fa0129b65a5..a68a56c3713 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl @@ -46,7 +46,7 @@ ${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=Fa ${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_int8_input_sums", "int", "buffer", is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_int8_input_scales", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8", "texture3d")} +${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8" if ZP_DTYPE_MODE == "zpint8" else DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml index a252055ed40..f88008f488d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml @@ -13,10 +13,14 @@ linear_dq8ca_q4gsw_tiled: TILE_M4: 1 TILE_K4: 1 TILE_N8: 1 + ZP_DTYPE_MODE: zpint8 generate_variant_forall: DTYPE: - VALUE: float - VALUE: half + ZP_DTYPE_MODE: + - VALUE: zpint8 + - VALUE: zpinherit shader_variants: - NAME: linear_dq8ca_q4gsw_tiled_texture3d_texture2d - NAME: linear_dq8ca_q4gsw_tiled_texture3d_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh index e1a570622c2..9b178d5c6c0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh @@ -20,7 +20,8 @@ void load_int8_input_scales_and_zps( [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) { scales.data[m4] = VEC4_T(texelFetch(t_int8_input_scales, ivec3(m4_start + m4, 0, 0), 0)); - zps.data[m4] = texelFetch(t_int8_input_zps, ivec3(m4_start + m4, 0, 0), 0); + zps.data[m4] = + ivec4(texelFetch(t_int8_input_zps, ivec3(m4_start + m4, 0, 0), 0)); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl index 053f27d6c9b..505fc3d0009 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl @@ -40,7 +40,7 @@ $if DYNAMIC_QUANT_VARIANT: ${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INPUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_int_input_sums", "int", "buffer", is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_input_scale", DTYPE, "texture3d")} - ${layout_declare_tensor(B, "r", "t_input_zp", "int", "texture3d")} + ${layout_declare_tensor(B, "r", "t_input_zp", "int8" if ZP_DTYPE_MODE == "zpint8" else DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml index 2c5001fdd17..e63075e9eb1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml @@ -15,6 +15,7 @@ linear_q4gsw_coop: TILE_N8: 1 WGS: 64 DYNAMIC_QUANT_VARIANT: false + ZP_DTYPE_MODE: zpint8 generate_variant_forall: DTYPE: - VALUE: float @@ -30,14 +31,42 @@ linear_q4gsw_coop: WEIGHT_STORAGE: buffer - NAME: linear_dq8ca_q4gsw_coop_texture3d_texture2d DYNAMIC_QUANT_VARIANT: true + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + ZP_DTYPE_MODE: + - VALUE: zpint8 + - VALUE: zpinherit - NAME: linear_dq8ca_q4gsw_coop_texture3d_buffer WEIGHT_STORAGE: buffer DYNAMIC_QUANT_VARIANT: true + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + ZP_DTYPE_MODE: + - VALUE: zpint8 + - VALUE: zpinherit - NAME: linear_dq8ca_q4gsw_coop_buffer_texture2d IO_STORAGE: buffer WEIGHT_STORAGE: texture2d DYNAMIC_QUANT_VARIANT: true + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + ZP_DTYPE_MODE: + - VALUE: zpint8 + - VALUE: zpinherit - NAME: linear_dq8ca_q4gsw_coop_buffer_buffer IO_STORAGE: buffer WEIGHT_STORAGE: buffer DYNAMIC_QUANT_VARIANT: true + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + ZP_DTYPE_MODE: + - VALUE: zpint8 + - VALUE: zpinherit diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl index e4d211a95f5..d5f14f70b62 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl @@ -33,7 +33,7 @@ ${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is ${layout_declare_tensor(B, "w", "t_int8_input_sums", "int", "buffer", is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_int8_input_scales", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8", "texture3d")} +${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8" if ZP_DTYPE_MODE == "zpint8" else DTYPE, "texture3d")} ${layout_declare_ubo(B, "ivec4", "input_sizes")} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml index bdbc81c59d7..5e98e2e318b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml @@ -11,10 +11,14 @@ quantize_and_pack_4h4w_with_group_sums: INPUT_STORAGE: texture3d NUM_GROUPS_PER_WG: 2 NUM_WORKERS_PER_GROUP: 32 + ZP_DTYPE_MODE: zpint8 generate_variant_forall: DTYPE: - VALUE: half - VALUE: float + ZP_DTYPE_MODE: + - VALUE: zpint8 + - VALUE: zpinherit shader_variants: - NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_texture3d - NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index 5b8615e0a70..fdacce0236c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -41,10 +41,12 @@ vkapi::ShaderInfo pick_choose_qparams_per_row_shader( (void)resize_args; const ValueRef input = args.at(1).refs.at(0); + const ValueRef input_zps = args.at(0).refs.at(1); std::string kernel_name = "choose_qparams_per_row"; add_storage_type_suffix(kernel_name, graph->storage_type_of(input)); add_dtype_suffix(kernel_name, graph->dtype_of(input)); + add_zp_dtype_mode_suffix(kernel_name, graph->dtype_of(input_zps)); return VK_KERNEL_FROM_STR(kernel_name); } diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp index e02a42f60e1..98f97eab572 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp @@ -66,6 +66,7 @@ vkapi::ShaderInfo pick_quantize_and_pack_4h4w_with_group_sums_shader( const std::vector& resize_args) { const ValueRef packed_int_input = args.at(0).refs.at(0); const ValueRef fp_input = args.at(1).refs.at(0); + const ValueRef packed_input_zps = args.at(1).refs.at(2); const ValueRef group_size = resize_args.at(0); const int64_t group_size_val = graph->extract_scalar(group_size); @@ -81,6 +82,7 @@ vkapi::ShaderInfo pick_quantize_and_pack_4h4w_with_group_sums_shader( shader_name, graph->storage_type_of(packed_int_input)); add_storage_type_suffix(shader_name, graph->storage_type_of(fp_input)); add_dtype_suffix(shader_name, graph->dtype_of(fp_input)); + add_zp_dtype_mode_suffix(shader_name, graph->dtype_of(packed_input_zps)); return VK_KERNEL_FROM_STR(shader_name); } diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 62aa5cd9fb9..db09c5585d2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -145,6 +145,7 @@ vkapi::ShaderInfo pick_linear_dqa_qw_shader( const ValueRef fp_input = args.at(1).refs.at(0); const ValueRef int_input = args.at(1).refs.at(1); (void)int_input; + const ValueRef input_zp = args.at(1).refs.at(4); const ValueRef int_weight = args.at(1).refs.at(5); const bool weight_is_4bit = resize_args.at(0) != kDummyValueRef; @@ -165,6 +166,7 @@ vkapi::ShaderInfo pick_linear_dqa_qw_shader( add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); add_storage_type_suffix(kernel_name, graph->storage_type_of(int_weight)); add_dtype_suffix(kernel_name, graph->dtype_of(out)); + add_zp_dtype_mode_suffix(kernel_name, graph->dtype_of(input_zp)); return VK_KERNEL_FROM_STR(kernel_name); } diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index 59a9d79a6e3..a7b6b246f84 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -72,6 +72,22 @@ void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { } } +void add_zp_dtype_mode_suffix( + std::string& kernel_name, + const vkapi::ScalarType zp_dtype) { + switch (zp_dtype) { + case vkapi::kChar: + kernel_name += "_zpint8"; + break; + case vkapi::kHalf: + case vkapi::kFloat: + kernel_name += "_zpinherit"; + break; + default: + VK_THROW("Unsupported per-token zero-point dtype for dq8ca"); + } +} + void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim) { switch (packed_dim) { case WHCN::kWidthDim: diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h index 4a2fddb5cf2..feb21f36b56 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h @@ -22,6 +22,15 @@ void add_storage_type_suffix( void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype); +// Selects the per-token zero-point shader binding variant by the dtype the +// zero-point tensor was allocated with: "_zpint8" when the tensor is int8 +// (rgba8i integer image), "_zpinherit" when it follows the inference float +// dtype (rgba32f/rgba16f, matching the scale). Matches the ZP_DTYPE_MODE +// codegen axis used by the dq8ca qparams shaders. +void add_zp_dtype_mode_suffix( + std::string& kernel_name, + const vkapi::ScalarType zp_dtype); + void add_ndim_suffix(std::string& kernel_name, const size_t ndim); void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim);