1515import torch
1616from auto_round .data_type .utils import round_ste , reshape_pad_tensor_by_group_size , revert_tensor_by_pad , logger
1717from auto_round .data_type .register import register_dtype
18+ from auto_round .utils import get_reciprocal
1819
1920
2021@register_dtype ("int_sym_dq" )
@@ -69,7 +70,7 @@ def quant_tensor_sym_dq(
6970
7071 scale = scale .view (- 1 , 1 )
7172 zp = torch .full_like (scale , maxq ) # pylint: disable=E1130
72- int_w = torch . where ( scale != 0 , round_ste (tensor / scale + v ), 0 )
73+ int_w = round_ste (tensor * get_reciprocal ( scale ) + v )
7374 q = torch .clamp (int_w + zp , 0 , 2 ** bits - 1 )
7475 qdq_result = (scale * (q - zp )).to (tensor .dtype )
7576 qdq_result = revert_tensor_by_pad (qdq_result , orig_shape = orig_shape , pad_len = pad_len )
@@ -140,7 +141,7 @@ def double_quant_tensor(tensor, bits):
140141 scale = wmax / maxq
141142 scale = scale .view (- 1 , 1 )
142143 # inverse_scale = torch.where(scale == 0, 0, 1 / scale)
143- inverse_scale = torch . where ( wmax > 0 , maxq / wmax , 0 ).view (- 1 , 1 )
144+ inverse_scale = ( maxq * get_reciprocal ( wmax )). clamp ( min = 0 ).view (- 1 , 1 )
144145 qdq_tensor = torch .clamp (round_ste (tensor * inverse_scale ), max = maxq ) * scale
145146 return qdq_tensor , scale
146147
@@ -151,7 +152,7 @@ def double_quant_tensor_sym(tensor, bits):
151152 imax = abs (tensor ).argmax (axis = - 1 , keepdims = True )
152153 wmax = torch .take_along_dim (tensor , imax , dim = - 1 )
153154 scale = wmax / - maxq
154- inverse_scale = torch . where (scale == 0 , 0 , 1 / scale ) ##1e-40
155+ inverse_scale = get_reciprocal (scale )
155156 qdq_tensor = torch .clip ((round_ste (tensor * inverse_scale )), - maxq , maxq - 1 ) * scale
156157 return qdq_tensor , scale
157158
@@ -161,7 +162,7 @@ def make_qp_quants(nmax, data, quant_weights):
161162 quant_weights = quant_weights .to (torch .float32 )
162163 group_max = torch .max (data , dim = - 1 , keepdim = True )[0 ]
163164 scale = group_max / nmax
164- iscale = torch . where ( scale == 0 , 0 , 1 / scale )
165+ iscale = get_reciprocal ( scale )
165166
166167 L = torch .round (iscale * data )
167168 diffs = data - scale * L
@@ -171,7 +172,7 @@ def make_qp_quants(nmax, data, quant_weights):
171172 if _is == 0 :
172173 continue
173174 scale_is = group_max / (0.1 * _is + nmax )
174- iscale_is = torch . where ( scale_is == 0 , 0 , 1 / scale_is )
175+ iscale_is = get_reciprocal ( scale_is )
175176
176177 tmp_L = torch .round (iscale_is * data ).clip (max = nmax )
177178 diffs = data - scale_is * tmp_L
@@ -328,11 +329,11 @@ def quant_tensor_gguf_asym_dq(
328329 use_mad = params ["use_mad" ], weights = quant_weights
329330 )
330331 scale = scale .to (scale_dtype )
331- scale = torch .where (torch .abs (scale ) < 1e-30 , 0 , scale )
332+ scale = torch .where (torch .abs (scale ) < 1e-30 , torch . zeros_like ( scale ) , scale )
332333 scale = scale .reshape (- 1 , super_group_size )
333334 wmin = wmin .reshape (- 1 , super_group_size )
334335 scale , d_scale = double_quant_tensor (scale , super_bits )
335- wmin = torch .where (torch .abs (wmin ) < 1e-30 , 0 , wmin )
336+ wmin = torch .where (torch .abs (wmin ) < 1e-30 , torch . zeros_like ( wmin ) , wmin )
336337 wmin , d_wmin = double_quant_tensor (wmin , super_bits )
337338 wmin = wmin .view (- 1 , 1 )
338339 scale = scale .view (- 1 , 1 )
@@ -386,7 +387,7 @@ def quant_tensor_gguf_asym_dq(
386387 use_mad = params ["use_mad" ], weights = quant_weights
387388 )
388389 scale = scale .to (scale_dtype )
389- scale = torch .where (torch .abs (scale ) < 1e-30 , 0 , scale )
390+ scale = torch .where (torch .abs (scale ) < 1e-30 , torch . zeros_like ( scale ) , scale )
390391 nmax = 2 ** super_bits - 1
391392 scale = scale .reshape (- 1 , super_group_size )
392393 wmin = wmin_0 .reshape (- 1 , super_group_size )
@@ -399,7 +400,7 @@ def quant_tensor_gguf_asym_dq(
399400 d_wmin = d_wmin .unsqueeze (- 1 )
400401 scale = (d_scale * q_scale ).view (- 1 , 1 )
401402 wmin = (d_wmin * q_wmin ).view (- 1 , 1 )
402- inverse_scale = torch . where ( scale == 0 , 0 , 1 / scale )
403+ inverse_scale = get_reciprocal ( scale )
403404
404405 int_w = torch .clamp (round_ste ((tensor + wmin ) * inverse_scale + v ), 0 , maxq )
405406 qdq_result = (scale * int_w - wmin ).to (orig_dtype )
@@ -436,7 +437,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
436437
437438 # scale = 1 / ((maxq - minq) / (rmax - rmin + 1e-8))
438439 scale = (rmax - rmin ) / (maxq - minq )
439- iscale = torch . where ( scale == 0 , 0 , 1 / scale )
440+ iscale = get_reciprocal ( scale )
440441 # quant_data = torch.clamp(torch.round((maxq - minq) / (rmax - rmin + 1e-8) * (data - rmin)), minq, maxq)
441442 quant_data = torch .clamp (torch .round (iscale * (data - rmin )), minq , maxq )
442443 diff = scale * quant_data + rmin - data
@@ -447,7 +448,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
447448 factor = rrmin + rdelta * is_ + maxq - minq
448449 # iscale_new = factor / (rmax - rmin + 1e-8)
449450 scale_new = (rmax - rmin ) / factor
450- iscale_new = torch . where ( scale_new == 0 , 0 , 1 / scale_new )
451+ iscale_new = get_reciprocal ( scale_new )
451452 quant_data_new = torch .clamp (torch .round (iscale_new * (data - rmin )), minq , maxq )
452453
453454 mul_weights_quant_data = weights * quant_data_new
@@ -460,7 +461,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
460461 this_min = (sum_l2 * sum_x - sum_l * sum_xl ) / D
461462 this_min [this_min > 0 ] = 0
462463 this_scale [this_min > 0 ] = (sum_xl / sum_l2 )[this_min > 0 ]
463- reverse_this_scale = torch . where ( this_scale == 0 , 0 , 1 / this_scale )
464+ reverse_this_scale = get_reciprocal ( this_scale )
464465
465466 quant_data = torch .clamp (torch .round (reverse_this_scale * (data - this_min )), minq , maxq )
466467 diff = this_scale * quant_data + this_min - data
@@ -569,13 +570,13 @@ def quant_tensor_gguf_sym_dq(
569570 quant_weights [mean_replace_index , :] = tmp_quant_weights [mean_replace_index , :]
570571
571572 scale , int_w = make_qx_quants (tensor , bits = bits , rmse_type = 1 , qw = quant_weights )
572- scale = torch .where (torch .abs (scale ) < 1e-30 , 0 , scale )
573+ scale = torch .where (torch .abs (scale ) < 1e-30 , torch . zeros_like ( scale ) , scale )
573574 # conduct double quant
574575 scale , d_scale = double_quant_tensor_sym (scale , super_bits )
575576
576577 scale = scale .unsqueeze (- 1 )
577578 zp = torch .full_like (scale , maxq ) # pylint: disable=E1130
578- inverse_scale = torch . where ( scale == 0 , 0 , 1.0 / scale )
579+ inverse_scale = get_reciprocal ( scale )
579580 int_w = torch .round (tensor * inverse_scale ).clip (- maxq , maxq - 1 ) + maxq
580581 qdq_result = (scale * (int_w - zp )).to (orig_dtype )
581582 qdq_result = revert_tensor_by_pad (qdq_result , orig_shape = orig_shape , pad_len = pad_len )
0 commit comments