Skip to content

Commit 23f9021

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK] Add dynamic-shape resize to q8ta ops
Pull Request resolved: #20312 The q8ta (quantized int8) op `DynamicDispatchNode`s were constructed with an empty resize-args list and no resize function, so their output tensors were never `virtual_resize`d on `trigger_resize()`. On a dynamic-shape graph this froze the q8ta outputs at the build-time upper-bound shape — the same failure mode the fp32 ops already avoid. Concretely, in a quantized Vulkan-delegated graph the terminal pointwise conv produces the graph output, so a smaller input (e.g. 238 rows fed into a graph allocated at the 241-row upper bound) left stale rows that propagate downstream, where GroupNorm's global per-group statistics smear them across the whole tensor. Add resize functions across the q8ta op family, each matching that op's output-shape semantics (mirroring the corresponding fp32 op's resize): - `q8ta_conv2d` / `q8ta_conv2d_dw`: output H/W recomputed from the input via `calc_out_sizes_hw`. - `q8ta_conv2d_pw`: 1x1 conv preserves spatial dims (out H/W == in H/W). - `q8ta_conv2d_transposed`: transposed output formula via `calc_out_sizes_hw(transposed=true)` (threads `output_padding` through the dispatch, which was previously dropped). - `q8ta` im2col scratch: flattened-window `K` from channels/kernel/groups, `H_out`/`W_out` from the current input. - `q8ta_linear`: `[*input.shape[:-1], out_features]`. - `q8ta` binary: `broadcast(in_a, in_b)`. - `q8ta` quantize / dequantize: elementwise, output shape == input shape. The quantized conv/quant path now honors dynamic input shapes like the fp32 path. ghstack-source-id: 394480015 @exported-using-ghexport Differential Revision: [D108788845](https://our.internmc.facebook.com/intern/diff/D108788845/)
1 parent 0c763d2 commit 23f9021

11 files changed

Lines changed: 411 additions & 22 deletions

File tree

backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,36 @@
99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1314

1415
namespace vkcompute {
1516

17+
//
18+
// Resize
19+
//
20+
21+
// resize_args = { block_config_ref } (unused here)
22+
//
23+
// Elementwise binary with broadcasting: output = broadcast(in_a, in_b). Without
24+
// this the DynamicDispatchNode freezes the output at the build-time upper
25+
// bound. Mirrors the fp32 resize_binary_op_node (same arg-group layout: inputs
26+
// are args[1].refs[0] and [1]).
27+
void resize_q8ta_binary_node(
28+
ComputeGraph* graph,
29+
const std::vector<ArgGroup>& args,
30+
const std::vector<ValueRef>& resize_args) {
31+
(void)resize_args;
32+
const ValueRef out = args.at(0).refs.at(0);
33+
const ValueRef in_a = args.at(1).refs.at(0);
34+
const ValueRef in_b = args.at(1).refs.at(1);
35+
36+
const std::vector<int64_t> a_sizes = graph->sizes_of(in_a);
37+
const std::vector<int64_t> b_sizes = graph->sizes_of(in_b);
38+
graph->virtual_resize(
39+
out, calculate_broadcasted_output_size(a_sizes, b_sizes));
40+
}
41+
1642
//
1743
// Dispatch nodes
1844
//
@@ -111,7 +137,7 @@ void add_q8ta_binary_node(
111137
// Resize args
112138
{block_config_ref},
113139
// Resizing Logic
114-
nullptr));
140+
resize_q8ta_binary_node));
115141
}
116142

117143
//

backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/ConvolutionUtils.h>
1515
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
16+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1617
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1718

1819
namespace vkcompute {
@@ -218,6 +219,51 @@ ValueRef prepack_quantized_conv2d_weight(
218219
return packed_weight;
219220
}
220221

222+
//
223+
// Resize
224+
//
225+
226+
// resize_args = { input, kernel_size, stride, padding, dilation }
227+
//
228+
// The q8ta_conv2d output is statically allocated at the build-time upper-bound
229+
// shape. Without this resize function the DynamicDispatchNode would never
230+
// virtual_resize the output on trigger_resize(), so a dynamic-shape graph would
231+
// freeze the conv output at its upper bound — feeding e.g. a 238-row input into
232+
// a 241-row buffer leaves garbage rows that GroupNorm's global statistics then
233+
// smear across the whole tensor. Recompute H/W from the current input (N and C
234+
// are shape-independent and stay as currently allocated).
235+
void resize_q8ta_conv2d_node(
236+
ComputeGraph* graph,
237+
const std::vector<ArgGroup>& args,
238+
const std::vector<ValueRef>& resize_args) {
239+
const ValueRef out = args.at(0).refs.at(0);
240+
const ValueRef in = resize_args.at(0);
241+
const ValueRef kernel_size = resize_args.at(1);
242+
const ValueRef stride = resize_args.at(2);
243+
const ValueRef padding = resize_args.at(3);
244+
const ValueRef dilation = resize_args.at(4);
245+
246+
const std::vector<int64_t> in_sizes = graph->sizes_of(in);
247+
248+
// H/W from the current input via the shared conv-output helper. kernel dims
249+
// come from the kernel_size IntList (kernel_size_only=true); the args[3] slot
250+
// is consulted only as an optional ceil_mode and dilation (non-bool) resolves
251+
// it to false. transposed=false.
252+
const std::vector<int64_t> out_hw = calc_out_sizes_hw(
253+
*graph,
254+
in_sizes,
255+
kernel_size,
256+
/*kernel_size_only=*/true,
257+
{stride, padding, dilation, dilation},
258+
/*transposed=*/false);
259+
260+
std::vector<int64_t> new_sizes = graph->sizes_of(out);
261+
const size_t ndim = new_sizes.size();
262+
new_sizes.at(ndim - 2) = out_hw.at(0);
263+
new_sizes.at(ndim - 1) = out_hw.at(1);
264+
graph->virtual_resize(out, new_sizes);
265+
}
266+
221267
//
222268
// Dispatch nodes
223269
//
@@ -327,8 +373,10 @@ void add_q8ta_conv2d_node(
327373
push_constants,
328374
// Specialization Constants
329375
spec_constants,
330-
// Resize args
331-
{}));
376+
// Resize args: { input, kernel_size, stride, padding, dilation }
377+
{packed_int8_input, kernel_size, stride, padding, dilation},
378+
// Resize function: propagate dynamic H/W to the output.
379+
resize_q8ta_conv2d_node));
332380
}
333381

334382
//

backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,12 @@ void add_q8ta_conv2d_pw_node(
123123
const ValueRef packed_bias,
124124
const uint32_t activation_type,
125125
const ValueRef packed_int8_output,
126-
const int32_t groups = 1);
126+
const int32_t groups = 1,
127+
const ValueRef conv_input = kDummyValueRef,
128+
const ValueRef kernel_size = kDummyValueRef,
129+
const ValueRef stride = kDummyValueRef,
130+
const ValueRef padding = kDummyValueRef,
131+
const ValueRef dilation = kDummyValueRef);
127132

128133
std::vector<int64_t> calculate_q8ta_im2col_sizes(
129134
ComputeGraph* graph,

backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/ConvolutionUtils.h>
1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1516
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1617

1718
namespace vkcompute {
@@ -172,6 +173,45 @@ ValueRef prepack_quantized_conv2d_dw_weight(
172173
return packed_weight;
173174
}
174175

176+
//
177+
// Resize
178+
//
179+
180+
// resize_args = { input, kernel_size, stride, padding, dilation }
181+
//
182+
// Depthwise conv output H/W follows the same formula as a regular conv (channel
183+
// count is unchanged: groups == in_channels == out_channels). Without this the
184+
// DynamicDispatchNode freezes the output at the build-time upper bound. N/C are
185+
// shape-independent and stay as currently allocated. Mirrors the regular q8ta
186+
// conv resize (resize_q8ta_conv2d_node).
187+
void resize_q8ta_conv2d_dw_node(
188+
ComputeGraph* graph,
189+
const std::vector<ArgGroup>& args,
190+
const std::vector<ValueRef>& resize_args) {
191+
const ValueRef out = args.at(0).refs.at(0);
192+
const ValueRef in = resize_args.at(0);
193+
const ValueRef kernel_size = resize_args.at(1);
194+
const ValueRef stride = resize_args.at(2);
195+
const ValueRef padding = resize_args.at(3);
196+
const ValueRef dilation = resize_args.at(4);
197+
198+
const std::vector<int64_t> in_sizes = graph->sizes_of(in);
199+
200+
const std::vector<int64_t> out_hw = calc_out_sizes_hw(
201+
*graph,
202+
in_sizes,
203+
kernel_size,
204+
/*kernel_size_only=*/true,
205+
{stride, padding, dilation, dilation},
206+
/*transposed=*/false);
207+
208+
std::vector<int64_t> new_sizes = graph->sizes_of(out);
209+
const size_t ndim = new_sizes.size();
210+
new_sizes.at(ndim - 2) = out_hw.at(0);
211+
new_sizes.at(ndim - 1) = out_hw.at(1);
212+
graph->virtual_resize(out, new_sizes);
213+
}
214+
175215
//
176216
// Dispatch nodes
177217
//
@@ -258,10 +298,10 @@ void add_conv2d_dw_q8ta_q8csw_q8to_4w4c_node(
258298
push_constants,
259299
// Specialization Constants
260300
spec_constants,
261-
// Resize args
262-
{},
301+
// Resize args: { input, kernel_size, stride, padding, dilation }
302+
{packed_int8_input, kernel_size, stride, padding, dilation},
263303
// Resizing Logic
264-
nullptr));
304+
resize_q8ta_conv2d_dw_node));
265305
}
266306

267307
void add_q8ta_conv2d_dw_node(
@@ -363,8 +403,10 @@ void add_q8ta_conv2d_dw_node(
363403
push_constants,
364404
// Specialization Constants
365405
spec_constants,
366-
// Resize args
367-
{}));
406+
// Resize args: { input, kernel_size, stride, padding, dilation }
407+
{packed_int8_input, kernel_size, stride, padding, dilation},
408+
// Resizing Logic
409+
resize_q8ta_conv2d_dw_node));
368410
}
369411

370412
//

backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/ConvolutionUtils.h>
1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h>
1515
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
16+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1617
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1718

1819
namespace vkcompute {
@@ -95,6 +96,59 @@ std::vector<int64_t> calculate_q8ta_im2col_sizes(
9596
return {K, H, W};
9697
}
9798

99+
//
100+
// Resize
101+
//
102+
103+
// resize_args = { input, kernel_size, stride, padding, dilation, groups }
104+
//
105+
// The im2col scratch tensor is [K, H_out, align_up_4(W_out)] where K (the
106+
// flattened conv window, channel/kernel-derived) is shape-independent and
107+
// H_out/W_out are the conv output spatial dims. The downstream PW GEMM that
108+
// consumes this scratch is resized separately (it preserves H/W). Without this,
109+
// the scratch freezes at the build-time upper bound and feeds garbage rows into
110+
// the GEMM. Recompute H_out/W_out from the CURRENT input (NOT the conv output
111+
// tensor, which may itself still be frozen at this point in the resize order).
112+
void resize_q8ta_im2col_node(
113+
ComputeGraph* graph,
114+
const std::vector<ArgGroup>& args,
115+
const std::vector<ValueRef>& resize_args) {
116+
const ValueRef im2col_out = args.at(0).refs.at(0);
117+
const ValueRef in = resize_args.at(0);
118+
const ValueRef kernel_size = resize_args.at(1);
119+
const ValueRef stride = resize_args.at(2);
120+
const ValueRef padding = resize_args.at(3);
121+
const ValueRef dilation = resize_args.at(4);
122+
const ValueRef groups = resize_args.at(5);
123+
124+
const std::vector<int64_t> in_sizes = graph->sizes_of(in);
125+
126+
// Conv output H/W from the current input.
127+
const std::vector<int64_t> out_hw = calc_out_sizes_hw(
128+
*graph,
129+
in_sizes,
130+
kernel_size,
131+
/*kernel_size_only=*/true,
132+
{stride, padding, dilation, dilation},
133+
/*transposed=*/false);
134+
const int64_t out_height = out_hw.at(0);
135+
const int64_t out_width = out_hw.at(1);
136+
137+
// K (flattened conv window) is shape-independent — recompute from channels +
138+
// kernel exactly as calculate_q8ta_im2col_sizes does.
139+
const int64_t in_channels = utils::val_at(-3, in_sizes);
140+
const int64_t groups_val = graph->extract_scalar<int64_t>(groups);
141+
const int64_t in_channels_per_group = in_channels / groups_val;
142+
const auto kernel_size_list = graph->get_int_list(kernel_size);
143+
const int64_t flattened_kernel_len = utils::align_up_4(
144+
in_channels_per_group * kernel_size_list->at(0) *
145+
kernel_size_list->at(1));
146+
const int64_t K = flattened_kernel_len * groups_val;
147+
const int64_t W = utils::align_up_4(out_width);
148+
149+
graph->virtual_resize(im2col_out, {K, out_height, W});
150+
}
151+
98152
//
99153
// Dispatch nodes
100154
//
@@ -168,10 +222,11 @@ void add_q8ta_im2col_node(
168222
push_constants,
169223
// Specialization Constants
170224
spec_constants,
171-
// Resize args
172-
{},
173-
// Resizing Logic
174-
nullptr));
225+
// Resize args: { input, kernel_size, stride, padding, dilation, groups }
226+
{packed_int8_input, kernel_size, stride, padding, dilation, groups},
227+
// Resizing Logic: recompute the im2col scratch dims from the current
228+
// input
229+
resize_q8ta_im2col_node));
175230
}
176231

177232
//
@@ -272,7 +327,14 @@ void q8ta_conv2d_im2col(
272327
packed_bias,
273328
activation_type_val,
274329
packed_int8_output,
275-
groups_val);
330+
groups_val,
331+
// Original activation + conv geometry so the PW output H/W is recomputed
332+
// from the true conv result, not the width-padded im2col scratch.
333+
packed_int8_input,
334+
kernel_size,
335+
stride,
336+
padding,
337+
dilation);
276338
}
277339

278340
REGISTER_OPERATORS {

0 commit comments

Comments
 (0)