Commit d9034b7
Update fused quant broadcast logic (#20171)
Summary:
Unifies QuantParamsStruct (sas_compiler's central quant-params abstraction) onto a single affine-quantization representation and drops the axis argument from every fused-quant op interface.
Core change (ops.py): scale/zero_point are now either a singleton (per-tensor, auto-expanded internally) or a full-rank tensor whose shape encodes the affine block layout: block_size[i] = tensor.shape[i] // scale.shape[i]. This one representation covers per-tensor, per-channel, per-group, and blockwise uniformly. quantize/dequantize delegate to torch.ops.torchao.(de)quantize_affine. The axis field is removed from QuantParamsStruct, all ~60 op-schema fields, _make_qp, and the _lib.define strings. is_per_tensor/is_per_channel/is_per_group and a new channel_axis() helper are now derived from scale shape (channel_axis() returns 0 if all dims are unary, the single non-unary dim if there is exactly one, else None).
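The shape rules in the core change can be illustrated with a small shape-only sketch. It is not the code from ops.py: the function names mirror the description, the divisibility check is an added assumption, and a per-tensor singleton is assumed to have already been expanded to an all-ones full-rank shape.

```python
from typing import Optional, Sequence, Tuple


def block_size(tensor_shape: Sequence[int], scale_shape: Sequence[int]) -> Tuple[int, ...]:
    """block_size[i] = tensor.shape[i] // scale.shape[i] for a full-rank scale."""
    if len(tensor_shape) != len(scale_shape):
        raise ValueError("scale must have the same rank as the tensor")
    for t, s in zip(tensor_shape, scale_shape):
        # Assumption for this sketch: every scale dim evenly divides the tensor dim.
        if s <= 0 or t % s != 0:
            raise ValueError(f"scale dim {s} does not evenly divide tensor dim {t}")
    return tuple(t // s for t, s in zip(tensor_shape, scale_shape))


def channel_axis(scale_shape: Sequence[int]) -> Optional[int]:
    """0 if all dims are unary, the single non-unary dim if exactly one, else None."""
    non_unary = [i for i, d in enumerate(scale_shape) if d != 1]
    if not non_unary:
        return 0
    if len(non_unary) == 1:
        return non_unary[0]
    return None


weight = (64, 128)
per_tensor = block_size(weight, (1, 1))    # one block spans the whole tensor
per_channel = block_size(weight, (64, 1))  # one block per output row
per_group = block_size(weight, (64, 4))    # four groups of 32 along dim 1
blockwise = block_size(weight, (2, 4))     # 32 x 32 tiles
```

With these shapes, channel_axis((64, 1)) is 0 and channel_axis((1, 3, 1, 1)) is 1, while the per-group and blockwise scale shapes have two non-unary dims and yield None.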
Fusion (fusion_pass.py, fusion_passes/utils.py): the flat qparams block shrinks from a 6-tuple to a 5-tuple; the per-channel branch inserts an aten.view to make 1-D scales full-rank [1, …, C, …, 1] so their shape encodes the block layout.
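As a rough illustration of the per-channel branch, the target shape for such a view could be computed as below. The helper name full_rank_scale_shape and its arguments are hypothetical and do not come from the fusion pass; the sketch only shows how a 1-D scale of length C maps to a full-rank shape with C at the channel dim.

```python
from typing import List


def full_rank_scale_shape(rank: int, axis: int, num_channels: int) -> List[int]:
    """Shape [1, ..., C, ..., 1] that places a 1-D scale of length C at `axis`."""
    if rank < 1:
        raise ValueError("rank must be at least 1")
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    axis %= rank  # normalize negative axes
    shape = [1] * rank
    shape[axis] = num_channels
    return shape


conv_weight_scale = full_rank_scale_shape(rank=4, axis=0, num_channels=16)
activation_scale = full_rank_scale_shape(rank=4, axis=1, num_channels=3)
last_dim_scale = full_rank_scale_shape(rank=2, axis=-1, num_channels=8)
```

Because the channel position is now visible in the shape itself, a separate axis argument carries no extra information for the per-channel case.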
Lowering boundary (graph_utils.py, lower_to_turing_linear.py, lower_to_turing_conv_no_nlu_params.py): Helios ParameterExtraction wants the compact scale form ([K] / [K, num_groups]), but the fused op now carries full-rank scales. New compact_scale_node() squeezes size-1 dims at the lowering boundary; the inserted view_copy folds via ConstantPropPass before extraction. Lowering asserts channel_axis() is not None.
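The effect of squeezing size-1 dims can be shown on shapes alone. This is a sketch of the shape transformation only, not of compact_scale_node() itself, which operates on graph nodes; the name compact_scale_shape is hypothetical.

```python
from typing import List, Sequence


def compact_scale_shape(scale_shape: Sequence[int]) -> List[int]:
    """Drop every size-1 dim, leaving the compact form of a full-rank scale shape."""
    return [d for d in scale_shape if d != 1]


K, num_groups = 64, 4
linear_per_channel = compact_scale_shape([K, 1])        # full-rank [K, 1]
conv_per_channel = compact_scale_shape([K, 1, 1, 1])    # full-rank [K, 1, 1, 1]
linear_per_group = compact_scale_shape([K, num_groups])  # no size-1 dims to drop
```

In this sketch an all-ones shape collapses to an empty list; how the real helper treats the per-tensor case is not stated in the commit message.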
Broadcast fix (fuse_mul_into_linear.py): channel_scale is an activation-space [K] vector (out-features is the trailing dim of the mul constant), whereas the weight scale is now full-rank [K, 1] (out-features at dim 0). Reshape the channel factor to [-1, 1, …] so it broadcasts along the weight scale's channel axis instead of producing a [K, K] outer product. The 1-D bias multiply is unchanged (bias is already [K]).
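The shape arithmetic behind this fix can be checked with a small pure-Python model of the standard trailing-dim broadcasting rule. The helpers broadcast_shape and channel_factor_shape are illustrative names, not functions from fuse_mul_into_linear.py.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Result shape of broadcasting a against b, aligning dims from the right."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} do not broadcast")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def channel_factor_shape(k: int, weight_scale_rank: int) -> Tuple[int, ...]:
    """Explicit form of the [-1, 1, ...] reshape target for a [K] channel factor."""
    return (k,) + (1,) * (weight_scale_rank - 1)


K = 4
weight_scale = (K, 1)  # full-rank weight scale, out-features at dim 0
channel_scale = (K,)   # activation-space vector, out-features trailing

# Multiplying directly aligns K with the trailing size-1 dim: an outer product.
buggy = broadcast_shape(weight_scale, channel_scale)

# Reshaping the factor first keeps it on the weight scale's channel axis.
fixed = broadcast_shape(weight_scale, channel_factor_shape(K, len(weight_scale)))
```

Here buggy is (4, 4) while fixed stays at (4, 1), which is the shape the weight scale had before the multiply.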
Misc consumers: quant_absorption.py per-tensor check is now out_scale.numel() == 1; BUCK adds the torchao dep.
Reviewed By: ethansfng
Differential Revision: D1080655881
parent 92e6a4c commit d9034b7
16 files changed
Lines changed: 105 additions & 236 deletions