Commit 8f409a5

Author: ssjia
[ET-VK][quantized] Select dq8ca zero-point binding by its allocated dtype
Pull Request resolved: #20491

The per-token dynamic-activation-quant (`dq8ca`) zero-point image must be bound in the shader with the same dtype the tensor was allocated with; a binding-vs-allocation dtype mismatch corrupts the per-token zero-point. The allocation dtype differs by export path: standard `export_llama -qmode 8da4w` models (e.g. Qwen3-0.6B) serialize the zero-point as `int8`, while the Llama4-mini TISO backbone (torchao `per_token_dynamic_quant` / `Int8DynamicActivationIntxWeightConfig` with an explicit fp32 `zero_point_dtype`) serializes it as float, which `vulkan_graph_builder.get_effective_dtype` downcasts to `half` under `force_fp16`. A single fixed binding dtype cannot satisfy both paths.

Binding the zero-point as `int8` (`rgba8i`) corrupts the float-allocated TISO zero-point on ARM Mali (Valhall) -- negative values come back as garbage, garbling the 8da4w TTS backbone. Conversely, binding it as the codegen `DTYPE` (matching the scale's float dtype) corrupts the int8-allocated zero-point: under fp16 inference the `rgba8i` image is read and written as `rgba16f`, saturating the per-token zero-point to the int8 floor/ceiling and garbling standard fp16 8da4w models such as Qwen3-0.6B.

This change makes the zero-point binding a codegen variant so it always matches the tensor's allocation. A new `ZP_DTYPE_MODE` axis emits two variants of every dq8ca shader that binds the per-token zero-point: `zpint8` (binding declared `int8`, an `rgba8i` integer image) and `zpinherit` (binding declared with the codegen `DTYPE`, inheriting the inference float dtype to match the scale -- `rgba32f`, or `rgba16f` under `USE_VULKAN_FP16_INFERENCE`). The C++ shader pickers select the variant from `graph.dtype_of(zero_point)` (`kChar` -> `zpint8`; `kHalf` / `kFloat` -> `zpinherit`), so the shader binding matches the tensor's allocation regardless of how the model was exported.

The shared read helper is unchanged: `ivec4(texelFetch(t_int8_input_zps, ...))` already reads both an integer image (identity) and a float image (exact truncation of the integer-valued zero-point in `[-128, 127]`). Affected shaders: `choose_qparams_per_row` (writes the zero-point, storing `ivec4` or `VEC4_T` per variant), `quantize_and_pack_4h4w_with_group_sums`, `linear_dq8ca_q4gsw_tiled`, and the dq8ca `linear_q4gsw_coop` variants (read the zero-point).

This fixes the fp16 8da4w regression for standard int8 zero-point exports while preserving the float zero-point path that the TISO backbone and the original Mali fix depend on. Only the runtime shader binding changes, so existing `.pte` files are handled correctly with no re-export.

Authored with Claude Code.

ghstack-source-id: 397279874
@exported-using-ghexport

Differential Revision: [D109595977](https://our.internmc.facebook.com/intern/diff/D109595977/)
1 parent fd00fa7 commit 8f409a5

14 files changed

Lines changed: 82 additions & 6 deletions

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl

Lines changed: 5 additions & 2 deletions
@@ -30,7 +30,7 @@ layout(std430) buffer;
 #include "common.glslh"

 ${layout_declare_tensor(B, "w", "t_scales", DTYPE, "texture3d")}
-${layout_declare_tensor(B, "w", "t_zps", "int8", "texture3d")}
+${layout_declare_tensor(B, "w", "t_zps", "int8" if ZP_DTYPE_MODE == "zpint8" else DTYPE, "texture3d")}
 ${layout_declare_tensor(B, "r", "t_input", DTYPE, STORAGE, is_scalar_array=False)}

 ${layout_declare_ubo(B, "ivec4", "input_sizes")}
@@ -196,7 +196,10 @@ void main() {

   if (worker_id == 0) {
     imageStore(t_scales, ivec3(output_y4, 0, 0), scales_out);
-    imageStore(t_zps, ivec3(output_y4, 0, 0), zps_out);
+    $if ZP_DTYPE_MODE == "zpint8":
+      imageStore(t_zps, ivec3(output_y4, 0, 0), zps_out);
+    $else:
+      imageStore(t_zps, ivec3(output_y4, 0, 0), VEC4_T(zps_out));
   }

 }
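The binding expression in this hunk resolves, for each variant, to a declared dtype and therefore to an image format. The following is a minimal C++ sketch of that resolution, assuming only the three formats the commit message names (`rgba8i`, `rgba32f`, `rgba16f`). The function names `zp_binding_dtype` and `zp_image_format` are hypothetical and are not part of the shader codegen.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical model of the template expression
//   "int8" if ZP_DTYPE_MODE == "zpint8" else DTYPE
// used in the layout_declare_tensor call for the zero-point binding.
std::string zp_binding_dtype(const std::string& zp_dtype_mode,
                             const std::string& dtype) {
  if (zp_dtype_mode == "zpint8") {
    return "int8";  // integer image, independent of the inference dtype
  }
  if (zp_dtype_mode == "zpinherit") {
    return dtype;  // inherit the codegen DTYPE, matching the scale
  }
  throw std::invalid_argument("unknown ZP_DTYPE_MODE: " + zp_dtype_mode);
}

// Image format for the declared dtype, limited to the cases the commit
// message describes. Any other dtype is rejected (an assumption of this sketch).
std::string zp_image_format(const std::string& binding_dtype) {
  if (binding_dtype == "int8") return "rgba8i";
  if (binding_dtype == "float") return "rgba32f";
  if (binding_dtype == "half") return "rgba16f";
  throw std::invalid_argument("unsupported binding dtype: " + binding_dtype);
}
```

Under this model the `zpint8` variant always binds `rgba8i`, while `zpinherit` follows `DTYPE`, which is why the write path converts `zps_out` with `VEC4_T` only in the `$else` branch.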

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml

Lines changed: 4 additions & 0 deletions
@@ -8,12 +8,16 @@ choose_qparams_per_row:
   parameter_names_with_default_values:
     DTYPE: float
     STORAGE: texture3d
+    ZP_DTYPE_MODE: zpint8
   generate_variant_forall:
     STORAGE:
       - VALUE: texture3d
       - VALUE: buffer
     DTYPE:
       - VALUE: float
       - VALUE: half
+    ZP_DTYPE_MODE:
+      - VALUE: zpint8
+      - VALUE: zpinherit
   shader_variants:
     - NAME: choose_qparams_per_row
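Adding `ZP_DTYPE_MODE` to `generate_variant_forall` doubles the number of generated shaders, since every axis is crossed with every other. A small C++ sketch of that expansion for `choose_qparams_per_row` follows. It assumes variant names are the base name plus underscore-joined axis values in the order storage, dtype, zero-point mode, which mirrors the suffix order in the C++ picker but is not confirmed against the codegen itself. `enumerate_variants` is a hypothetical helper.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: Cartesian product of the axes, appended to the base
// shader name as "_<value>" suffixes (naming scheme is an assumption).
std::vector<std::string> enumerate_variants(
    const std::string& base,
    const std::vector<std::vector<std::string>>& axes) {
  std::vector<std::string> names{base};
  for (const auto& axis : axes) {
    std::vector<std::string> next;
    for (const auto& prefix : names) {
      for (const auto& value : axis) {
        next.push_back(prefix + "_" + value);
      }
    }
    names = std::move(next);
  }
  return names;
}

// The three axes listed in choose_qparams_per_row.yaml after this change.
std::vector<std::string> choose_qparams_variants() {
  return enumerate_variants(
      "choose_qparams_per_row",
      {{"texture3d", "buffer"}, {"float", "half"}, {"zpint8", "zpinherit"}});
}
```

With two values on each of the three axes, the sketch yields eight names, four more than before the `ZP_DTYPE_MODE` axis was added.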

backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ ${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=Fa
 ${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_int8_input_sums", "int", "buffer", is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_int8_input_scales", DTYPE, "texture3d")}
-${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8", "texture3d")}
+${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8" if ZP_DTYPE_MODE == "zpint8" else DTYPE, "texture3d")}
 ${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}

backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml

Lines changed: 4 additions & 0 deletions
@@ -13,10 +13,14 @@ linear_dq8ca_q4gsw_tiled:
     TILE_M4: 1
     TILE_K4: 1
     TILE_N8: 1
+    ZP_DTYPE_MODE: zpint8
   generate_variant_forall:
     DTYPE:
       - VALUE: float
       - VALUE: half
+    ZP_DTYPE_MODE:
+      - VALUE: zpint8
+      - VALUE: zpinherit
   shader_variants:
     - NAME: linear_dq8ca_q4gsw_tiled_texture3d_texture2d
     - NAME: linear_dq8ca_q4gsw_tiled_texture3d_buffer

backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@ void load_int8_input_scales_and_zps(
   [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) {
     scales.data[m4] =
         VEC4_T(texelFetch(t_int8_input_scales, ivec3(m4_start + m4, 0, 0), 0));
-    zps.data[m4] = texelFetch(t_int8_input_zps, ivec3(m4_start + m4, 0, 0), 0);
+    zps.data[m4] =
+        ivec4(texelFetch(t_int8_input_zps, ivec3(m4_start + m4, 0, 0), 0));
   }
 }
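The `ivec4(...)` wrapper works for both variants because every zero-point in `[-128, 127]` survives a float round trip unchanged. The C++ sketch below checks this on the CPU. It models the GLSL float-to-int constructor as truncation toward zero, and it uses the fact that a binary float with a p-bit significand represents every integer of magnitude up to 2^p exactly (p = 11 for half, p = 24 for float). All function names are hypothetical, and the sketch does not model GPU image loads.

```cpp
#include <cassert>
#include <cstdlib>

// Sufficient condition for an integer to be exactly representable in a
// binary float with `precision_bits` of significand (implicit bit included).
bool fits_significand(int n, int precision_bits) {
  return std::abs(n) <= (1 << precision_bits);
}

// Model of one component of ivec4(texelFetch(...)) on a float image:
// the conversion drops the fractional part (truncation toward zero).
int read_zp_from_float(float stored) {
  return static_cast<int>(stored);
}

// Checks every int8 zero-point: written as a float (as the zpinherit
// variant does via VEC4_T), read back as an int, and small enough to be
// exact even in half precision.
bool all_int8_zps_roundtrip() {
  for (int zp = -128; zp <= 127; ++zp) {
    const float stored = static_cast<float>(zp);
    if (read_zp_from_float(stored) != zp) return false;
    if (!fits_significand(zp, 11)) return false;  // half: 11-bit significand
  }
  return true;
}
```

Since the stored values are integer-valued, truncation never discards information, so the same helper can serve the `zpint8` and `zpinherit` bindings.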

backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ $if DYNAMIC_QUANT_VARIANT:
   ${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INPUT_STORAGE, is_scalar_array=False)}
   ${layout_declare_tensor(B, "r", "t_int_input_sums", "int", "buffer", is_scalar_array=False)}
   ${layout_declare_tensor(B, "r", "t_input_scale", DTYPE, "texture3d")}
-  ${layout_declare_tensor(B, "r", "t_input_zp", "int", "texture3d")}
+  ${layout_declare_tensor(B, "r", "t_input_zp", "int8" if ZP_DTYPE_MODE == "zpint8" else DTYPE, "texture3d")}
 ${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}

backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml

Lines changed: 29 additions & 0 deletions
@@ -15,6 +15,7 @@ linear_q4gsw_coop:
     TILE_N8: 1
     WGS: 64
     DYNAMIC_QUANT_VARIANT: false
+    ZP_DTYPE_MODE: zpint8
   generate_variant_forall:
     DTYPE:
       - VALUE: float
@@ -30,14 +31,42 @@ linear_q4gsw_coop:
       WEIGHT_STORAGE: buffer
     - NAME: linear_dq8ca_q4gsw_coop_texture3d_texture2d
       DYNAMIC_QUANT_VARIANT: true
+      generate_variant_forall:
+        DTYPE:
+          - VALUE: float
+          - VALUE: half
+        ZP_DTYPE_MODE:
+          - VALUE: zpint8
+          - VALUE: zpinherit
     - NAME: linear_dq8ca_q4gsw_coop_texture3d_buffer
       WEIGHT_STORAGE: buffer
       DYNAMIC_QUANT_VARIANT: true
+      generate_variant_forall:
+        DTYPE:
+          - VALUE: float
+          - VALUE: half
+        ZP_DTYPE_MODE:
+          - VALUE: zpint8
+          - VALUE: zpinherit
     - NAME: linear_dq8ca_q4gsw_coop_buffer_texture2d
       IO_STORAGE: buffer
       WEIGHT_STORAGE: texture2d
       DYNAMIC_QUANT_VARIANT: true
+      generate_variant_forall:
+        DTYPE:
+          - VALUE: float
+          - VALUE: half
+        ZP_DTYPE_MODE:
+          - VALUE: zpint8
+          - VALUE: zpinherit
     - NAME: linear_dq8ca_q4gsw_coop_buffer_buffer
       IO_STORAGE: buffer
       WEIGHT_STORAGE: buffer
       DYNAMIC_QUANT_VARIANT: true
+      generate_variant_forall:
+        DTYPE:
+          - VALUE: float
+          - VALUE: half
+        ZP_DTYPE_MODE:
+          - VALUE: zpint8
+          - VALUE: zpinherit

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ ${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is
 ${layout_declare_tensor(B, "w", "t_int8_input_sums", "int", "buffer", is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_int8_input_scales", DTYPE, "texture3d")}
-${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8", "texture3d")}
+${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8" if ZP_DTYPE_MODE == "zpint8" else DTYPE, "texture3d")}

 ${layout_declare_ubo(B, "ivec4", "input_sizes")}

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml

Lines changed: 4 additions & 0 deletions
@@ -11,10 +11,14 @@ quantize_and_pack_4h4w_with_group_sums:
     INPUT_STORAGE: texture3d
     NUM_GROUPS_PER_WG: 2
     NUM_WORKERS_PER_GROUP: 32
+    ZP_DTYPE_MODE: zpint8
   generate_variant_forall:
     DTYPE:
       - VALUE: half
       - VALUE: float
+    ZP_DTYPE_MODE:
+      - VALUE: zpint8
+      - VALUE: zpinherit
   shader_variants:
     - NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_texture3d
     - NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_buffer

backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp

Lines changed: 2 additions & 0 deletions
@@ -41,10 +41,12 @@ vkapi::ShaderInfo pick_choose_qparams_per_row_shader(
   (void)resize_args;

   const ValueRef input = args.at(1).refs.at(0);
+  const ValueRef input_zps = args.at(0).refs.at(1);

   std::string kernel_name = "choose_qparams_per_row";
   add_storage_type_suffix(kernel_name, graph->storage_type_of(input));
   add_dtype_suffix(kernel_name, graph->dtype_of(input));
+  add_zp_dtype_mode_suffix(kernel_name, graph->dtype_of(input_zps));

   return VK_KERNEL_FROM_STR(kernel_name);
 }
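The definition of `add_zp_dtype_mode_suffix` is not part of the hunks shown here. A self-contained C++ sketch of the mapping the commit message describes (`kChar` -> `zpint8`; `kHalf` / `kFloat` -> `zpinherit`) might look like the following. The `ScalarType` enum is a local stand-in for the backend's scalar type, and the `_sketch` suffix marks the function as illustrative. The underscore-joined suffix format and the exception for other dtypes are assumptions of this sketch, not behavior confirmed by the diff.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Local stand-in for the backend's scalar type enum (hypothetical).
enum class ScalarType { Char, Half, Float, Int };

// Maps the zero-point tensor's allocation dtype to the shader variant name,
// following the mapping given in the commit message.
std::string zp_dtype_mode(ScalarType zp_dtype) {
  switch (zp_dtype) {
    case ScalarType::Char:
      return "zpint8";  // int8-allocated zero-point: rgba8i binding
    case ScalarType::Half:
    case ScalarType::Float:
      return "zpinherit";  // float-allocated zero-point: binding follows DTYPE
    default:
      // Assumption: other dtypes are rejected; the real helper may differ.
      throw std::invalid_argument("unsupported zero-point dtype");
  }
}

// Illustrative counterpart of add_zp_dtype_mode_suffix: appends the variant
// name to the kernel name after the storage and dtype suffixes.
void add_zp_dtype_mode_suffix_sketch(std::string& kernel_name,
                                     ScalarType zp_dtype) {
  kernel_name += "_" + zp_dtype_mode(zp_dtype);
}
```

Because the suffix is derived from the zero-point tensor's own dtype rather than from the input's dtype, the chosen shader tracks the allocation regardless of which export path produced the model.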
