|
14 | 14 | tl_round = tl.math.round |
15 | 15 |
|
16 | 16 |
|
17 | | -def per_channel_quant(x: torch.Tensor, n_bits: int, dtype: torch.dtype): |
| 17 | +def per_channel_quant(x: torch.Tensor, dtype: torch.dtype): |
18 | 18 | """Quantize the input tensor 'x' channel-wise using the given number of |
19 | 19 | bits. |
20 | 20 |
|
21 | 21 | Args: |
22 | 22 | x (torch.Tensor): The input tensor to be quantized. Must be a |
23 | 23 | 2-dimensional tensor. |
24 | | - n_bits (int): The number of bits to use for quantization. |
25 | 24 | dtype (torch.dtype): The data type to which the quantized tensor should |
26 | 25 | be converted. |
27 | 26 |
|
@@ -527,7 +526,7 @@ def linear_torch(x, b): |
527 | 526 | return F.linear(x, b) |
528 | 527 |
|
529 | 528 | linear_weight_quant, linear_scale = per_channel_quant( |
530 | | - linear_weight, 8, quant_dtype) |
| 529 | + linear_weight, quant_dtype) |
531 | 530 |
|
532 | 531 | rms_out, rms_scale = rms_norm_dynamic_quant(x, |
533 | 532 | rms_weight, |
@@ -627,7 +626,7 @@ def y_fwd(): |
627 | 626 | quant_dtype = torch.float8_e5m2 |
628 | 627 |
|
629 | 628 | linear_weight_quant, linear_scale = per_channel_quant( |
630 | | - linear_weight, 8, quant_dtype) |
| 629 | + linear_weight, quant_dtype) |
631 | 630 |
|
632 | 631 | alpha = max(x.max().abs(), x.min().abs()) |
633 | 632 | if quant_dtype.is_floating_point: |
|
0 commit comments