Skip to content

Commit 76deb35

Browse files
author
morelos
committed
[ET-VK][Ops] choose_qparams op shaders and impl
Pull Request resolved: #11557 Creating the choose_qparams per_tensor and per_token logic shaders and impl which are linked with the testing framework ghstack-source-id: 289992011 @exported-using-ghexport Differential Revision: [D76436933](https://our.internmc.facebook.com/intern/diff/D76436933/)
1 parent 9278eb0 commit 76deb35

File tree

7 files changed

+1001
-0
lines changed

7 files changed

+1001
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef CHOOSE_QPARAMS_GLSLH
10+
#define CHOOSE_QPARAMS_GLSLH
11+
12+
// Smallest scale this header will produce: any smaller computed scale is raised
// to this value. Equivalent of the eps defined in the CPU implementation.
13+
#define SMALL_SCALE_THRESHOLD 6.1e-5
14+
15+
// Calculate scale and zero point from min and max values
16+
/*
 * Derives affine quantization parameters from an observed value range.
 *
 * Inputs:
 *   min_val, max_val - observed range; it is widened to include 0.0 before use.
 *   qmin, qmax       - bounds of the quantized integer range.
 * Outputs:
 *   scale_val        - (max - min) / (qmax - qmin), replaced by 0.1 when it is
 *                      zero or its reciprocal is infinite, and never smaller
 *                      than SMALL_SCALE_THRESHOLD.
 *   zero_point_val   - zero point rounded to an integer and limited to
 *                      [qmin, qmax].
 */
void calculate_scale_and_zero_point(
    float min_val,
    float max_val,
    int qmin,
    int qmax,
    out float scale_val,
    out int zero_point_val) {
  // Working copies of the range, widened so that zero is always inside it.
  float lo = min(min_val, 0.0);
  float hi = max(max_val, 0.0);

  // Width of the quantized range, reused below.
  float q_span = float(qmax - qmin);

  float scale = (hi - lo) / q_span;

  // A zero scale, or one whose reciprocal overflows, falls back to 0.1.
  bool degenerate = (scale == 0.0) || isinf(1.0 / scale);
  if (degenerate) {
    scale = 0.1;
  }

  // Scales below the threshold are raised to it, and the range is stretched
  // so that it stays consistent with the new scale.
  if (scale < SMALL_SCALE_THRESHOLD) {
    float unclamped_scale = scale;
    scale = SMALL_SCALE_THRESHOLD;

    if (lo == 0.0) {
      hi = SMALL_SCALE_THRESHOLD * q_span;
    } else if (hi == 0.0) {
      lo = -SMALL_SCALE_THRESHOLD * q_span;
    } else {
      float stretch = SMALL_SCALE_THRESHOLD / unclamped_scale;
      lo *= stretch;
      hi *= stretch;
    }
  }

  // Range endpoints expressed in quantized steps.
  float lo_steps = lo / scale;
  float hi_steps = hi / scale;

  // Two candidate zero points, one anchored at each end of the range, each
  // with the figure used to choose between them.
  float zp_from_lo = float(qmin) - lo_steps;
  float zp_from_hi = float(qmax) - hi_steps;
  float lo_error = abs(float(qmin)) - abs(lo_steps);
  float hi_error = abs(float(qmax)) - abs(hi_steps);

  // The max-anchored candidate is used unless the min-anchored one has the
  // strictly smaller figure.
  float zp = zp_from_hi;
  if (lo_error < hi_error) {
    zp = zp_from_lo;
  }

  scale_val = scale;

  // Nudge to an integer: out-of-range candidates snap to the nearest bound,
  // everything else is rounded.
  if (zp < float(qmin)) {
    zero_point_val = qmin;
    return;
  }
  if (zp > float(qmax)) {
    zero_point_val = qmax;
    return;
  }
  zero_point_val = int(round(zp));
}
69+
70+
#endif // CHOOSE_QPARAMS_GLSLH
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
15+
${define_active_storage_type("buffer")}
16+
${define_required_extensions(IN_DTYPE)}
17+
18+
#extension GL_EXT_control_flow_attributes : require
19+
20+
layout(std430) buffer;
21+
22+
// Storage buffers: main() reads t_in and writes its results to t_scale and
// t_zero_point.
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
23+
${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")}
24+
${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")}
25+
26+
// Push constants. The per_tensor variant carries only the quantized range;
// every other MODE additionally receives num_tokens as the first member.
// NOTE(review): member order presumably has to match what the host pushes --
// confirm against the dispatch code.
$if MODE == "per_tensor":
27+
layout(push_constant) uniform restrict Block {
28+
int quant_min;
29+
int quant_max;
30+
};
31+
$else:
32+
layout(push_constant) uniform restrict Block {
33+
int num_tokens;
34+
int quant_min;
35+
int quant_max;
36+
};
37+
38+
// Size and stride metadata for each tensor. main() only reads t_in_sizes.
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
39+
${layout_declare_ubo(B, "ivec4", "t_in_strides")}
40+
${layout_declare_ubo(B, "ivec4", "t_scale_sizes")}
41+
${layout_declare_ubo(B, "ivec4", "t_scale_strides")}
42+
${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")}
43+
${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")}
44+
45+
#include "indexing_utils.h"
46+
#include "choose_qparams.glslh"
47+
48+
// Local workgroup dimensions are supplied through specialization constants
// 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
49+
50+
#define NWORKERS 64
51+
52+
// Per-invocation partial min/max used by the reduction in main(). Both arrays
// are indexed by gl_LocalInvocationID.x, so the x local size must not exceed
// NWORKERS.
53+
shared float shared_min[NWORKERS];
54+
shared float shared_max[NWORKERS];
55+
56+
/*
 * Entry point. Reduces the finite values of t_in to a (min, max) pair within
 * one workgroup and turns that pair into a scale and zero point through
 * calculate_scale_and_zero_point().
 *
 *   per_tensor: one (scale, zero_point) pair for the whole input, written to
 *               element 0 of t_scale / t_zero_point. Only workgroup 0 does
 *               any work.
 *   per_token:  one pair per token, written to element token_id. Each
 *               workgroup handles a contiguous run of
 *               ceil(num_tokens / gl_NumWorkGroups.x) tokens.
 *
 * NaN and +/-infinity inputs are skipped. If no finite value is seen, the
 * +/-infinity sentinels are passed on unchanged.
 *
 * NOTE(review): the halving reduction only visits every slot when
 * gl_WorkGroupSize.x is a power of two, and shared_min / shared_max hold
 * NWORKERS entries -- confirm both against the dispatch code.
 */
void main() {
  $if MODE == "per_tensor":
    uint local_id = gl_LocalInvocationID.x;
    uint group_id = gl_WorkGroupID.x;

    // The reduction below only combines values inside one workgroup and the
    // result always lands in element 0, so exactly one workgroup may own the
    // tensor. Letting every workgroup reduce a slice and write element 0 would
    // race and publish a partial min/max. Extra workgroups therefore exit
    // here; the condition is uniform across a workgroup, so nobody is left
    // waiting at a barrier.
    if (group_id != 0u) {
      return;
    }

    uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);

    // Each invocation scans every gl_WorkGroupSize.x-th element
    float thread_min = 1.0/0.0;  // +infinity
    float thread_max = -1.0/0.0; // -infinity
    bool found_valid = false;

    for (uint i = local_id; i < total_elements; i += gl_WorkGroupSize.x) {
      float val = t_in[i];
      if (!isnan(val) && !isinf(val)) {
        if (!found_valid) {
          thread_min = val;
          thread_max = val;
          found_valid = true;
        } else {
          thread_min = min(thread_min, val);
          thread_max = max(thread_max, val);
        }
      }
    }

    // Publish the per-invocation partial results
    shared_min[local_id] = thread_min;
    shared_max[local_id] = thread_max;
    barrier();

    // Tree reduction within the workgroup. An infinite entry marks a slot that
    // saw no finite value, so it never replaces a finite one.
    for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) {
      if (local_id < stride) {
        float other_min = shared_min[local_id + stride];
        float other_max = shared_max[local_id + stride];

        if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
          shared_min[local_id] = other_min;
        }
        if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) {
          shared_max[local_id] = other_max;
        }
      }
      barrier();
    }

    // Slot 0 now holds the range of the whole tensor
    if (local_id == 0) {
      float global_min = shared_min[0];
      float global_max = shared_max[0];

      float scale_val;
      int zero_point_val;
      calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);

      t_scale[0] = scale_val;
      t_zero_point[0] = zero_point_val;
    }

  $if MODE == "per_token":
    uint local_id = gl_LocalInvocationID.x;
    uint group_id = gl_WorkGroupID.x;
    uint total_workgroups = gl_NumWorkGroups.x;

    uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);
    uint token_size = total_elements / uint(num_tokens);

    // Calculate how many tokens each workgroup should process
    // This handles the case where we have more tokens than workgroups
    uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups;

    // Calculate which tokens this workgroup is responsible for
    uint start_token = group_id * tokens_per_workgroup;
    uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens));

    // Early exit if this workgroup has no tokens to process. The test depends
    // only on per-workgroup values, so the whole workgroup leaves together.
    if (start_token >= uint(num_tokens)) {
      return;
    }

    // Process each token assigned to this workgroup
    for (uint token_id = start_token; token_id < end_token; token_id++) {
      // Calculate the start and end indices for this token
      uint token_start = token_id * token_size;
      uint token_end = token_start + token_size;

      // Each invocation scans every gl_WorkGroupSize.x-th element of the token
      float thread_min = 1.0/0.0;  // +infinity
      float thread_max = -1.0/0.0; // -infinity
      bool found_valid = false;

      // Process elements within this token only
      for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) {
        float val = t_in[i];
        if (!isnan(val) && !isinf(val)) {
          if (!found_valid) {
            thread_min = val;
            thread_max = val;
            found_valid = true;
          } else {
            thread_min = min(thread_min, val);
            thread_max = max(thread_max, val);
          }
        }
      }

      // Publish the per-invocation partial results
      shared_min[local_id] = thread_min;
      shared_max[local_id] = thread_max;
      barrier();

      // Tree reduction within the workgroup. An infinite entry marks a slot
      // that saw no finite value, so it never replaces a finite one.
      for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) {
        if (local_id < stride) {
          float other_min = shared_min[local_id + stride];
          float other_max = shared_max[local_id + stride];

          if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) {
            shared_min[local_id] = other_min;
          }
          if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) {
            shared_max[local_id] = other_max;
          }
        }
        barrier();
      }

      // Final calculation for this token
      if (local_id == 0) {
        float token_min = shared_min[0];
        float token_max = shared_max[0];

        float scale_val;
        int zero_point_val;
        calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);

        t_scale[token_id] = scale_val;
        t_zero_point[token_id] = zero_point_val;
      }

      // Synchronize before the shared arrays are overwritten for the next token
      barrier();
    }
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Variant table for the choose_qparams_buffer shader template.
# NOTE(review): presumably consumed by the shader codegen, which substitutes
# IN_DTYPE and MODE into the template -- confirm against the build scripts.
choose_qparams_buffer:
2+
parameter_names_with_default_values:
3+
IN_DTYPE: float
4+
MODE: per_tensor
5+
generate_variant_forall:
6+
IN_DTYPE:
7+
- VALUE: float
8+
# One entry per generated shader; MODE matches the values tested by the
# `$if MODE == ...` directives in the template.
shader_variants:
9+
- NAME: choose_qparams_tensor_buffer
10+
MODE: per_tensor
11+
- NAME: choose_qparams_per_token_asymmetric_buffer
12+
MODE: per_token

0 commit comments

Comments
 (0)