|
13 | 13 | #include <executorch/backends/vulkan/runtime/graph/ops/impl/ConvolutionUtils.h> |
14 | 14 | #include <executorch/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h> |
15 | 15 | #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h> |
| 16 | +#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h> |
16 | 17 | #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h> |
17 | 18 |
|
18 | 19 | namespace vkcompute { |
@@ -95,6 +96,59 @@ std::vector<int64_t> calculate_q8ta_im2col_sizes( |
95 | 96 | return {K, H, W}; |
96 | 97 | } |
97 | 98 |
|
| 99 | +// |
| 100 | +// Resize |
| 101 | +// |
| 102 | + |
| 103 | +// resize_args = { input, kernel_size, stride, padding, dilation, groups } |
| 104 | +// |
| 105 | +// The im2col scratch tensor is [K, H_out, align_up_4(W_out)] where K (the |
| 106 | +// flattened conv window, channel/kernel-derived) is shape-independent and |
| 107 | +// H_out/W_out are the conv output spatial dims. The downstream PW GEMM that |
| 108 | +// consumes this scratch is resized separately (it preserves H/W). Without this, |
| 109 | +// the scratch freezes at the build-time upper bound and feeds garbage rows into |
| 110 | +// the GEMM. Recompute H_out/W_out from the CURRENT input (NOT the conv output |
| 111 | +// tensor, which may itself still be frozen at this point in the resize order). |
| 112 | +void resize_q8ta_im2col_node( |
| 113 | + ComputeGraph* graph, |
| 114 | + const std::vector<ArgGroup>& args, |
| 115 | + const std::vector<ValueRef>& resize_args) { |
| 116 | + const ValueRef im2col_out = args.at(0).refs.at(0); |
| 117 | + const ValueRef in = resize_args.at(0); |
| 118 | + const ValueRef kernel_size = resize_args.at(1); |
| 119 | + const ValueRef stride = resize_args.at(2); |
| 120 | + const ValueRef padding = resize_args.at(3); |
| 121 | + const ValueRef dilation = resize_args.at(4); |
| 122 | + const ValueRef groups = resize_args.at(5); |
| 123 | + |
| 124 | + const std::vector<int64_t> in_sizes = graph->sizes_of(in); |
| 125 | + |
| 126 | + // Conv output H/W from the current input. |
| 127 | + const std::vector<int64_t> out_hw = calc_out_sizes_hw( |
| 128 | + *graph, |
| 129 | + in_sizes, |
| 130 | + kernel_size, |
| 131 | + /*kernel_size_only=*/true, |
| 132 | + {stride, padding, dilation, dilation}, |
| 133 | + /*transposed=*/false); |
| 134 | + const int64_t out_height = out_hw.at(0); |
| 135 | + const int64_t out_width = out_hw.at(1); |
| 136 | + |
| 137 | + // K (flattened conv window) is shape-independent — recompute from channels + |
| 138 | + // kernel exactly as calculate_q8ta_im2col_sizes does. |
| 139 | + const int64_t in_channels = utils::val_at(-3, in_sizes); |
| 140 | + const int64_t groups_val = graph->extract_scalar<int64_t>(groups); |
| 141 | + const int64_t in_channels_per_group = in_channels / groups_val; |
| 142 | + const auto kernel_size_list = graph->get_int_list(kernel_size); |
| 143 | + const int64_t flattened_kernel_len = utils::align_up_4( |
| 144 | + in_channels_per_group * kernel_size_list->at(0) * |
| 145 | + kernel_size_list->at(1)); |
| 146 | + const int64_t K = flattened_kernel_len * groups_val; |
| 147 | + const int64_t W = utils::align_up_4(out_width); |
| 148 | + |
| 149 | + graph->virtual_resize(im2col_out, {K, out_height, W}); |
| 150 | +} |
| 151 | + |
98 | 152 | // |
99 | 153 | // Dispatch nodes |
100 | 154 | // |
@@ -168,10 +222,11 @@ void add_q8ta_im2col_node( |
168 | 222 | push_constants, |
169 | 223 | // Specialization Constants |
170 | 224 | spec_constants, |
171 | | - // Resize args |
172 | | - {}, |
173 | | - // Resizing Logic |
174 | | - nullptr)); |
| 225 | + // Resize args: { input, kernel_size, stride, padding, dilation, groups } |
| 226 | + {packed_int8_input, kernel_size, stride, padding, dilation, groups}, |
| 227 | + // Resizing Logic: recompute the im2col scratch dims from the current |
| 228 | + // input |
| 229 | + resize_q8ta_im2col_node)); |
175 | 230 | } |
176 | 231 |
|
177 | 232 | // |
@@ -272,7 +327,14 @@ void q8ta_conv2d_im2col( |
272 | 327 | packed_bias, |
273 | 328 | activation_type_val, |
274 | 329 | packed_int8_output, |
275 | | - groups_val); |
| 330 | + groups_val, |
| 331 | + // Original activation + conv geometry so the PW output H/W is recomputed |
| 332 | + // from the true conv result, not the width-padded im2col scratch. |
| 333 | + packed_int8_input, |
| 334 | + kernel_size, |
| 335 | + stride, |
| 336 | + padding, |
| 337 | + dilation); |
276 | 338 | } |
277 | 339 |
|
278 | 340 | REGISTER_OPERATORS { |
|
0 commit comments