@@ -118,6 +118,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
118
118
for i in range (iters ):
119
119
A1 = torch .randn (1024 , 1024 , device = device , dtype = dtype )
120
120
C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested )
121
+ if i == 0 :
122
+ d = S .as_dict ()
123
+ S = F .QuantState .from_dict (d , device = torch .device (device ))
121
124
A2 = F .dequantize_blockwise (C , S )
122
125
diff = torch .abs (A1 - A2 ).float ()
123
126
reldiff = diff / torch .abs (A1 .float () + 1e-8 )
@@ -134,6 +137,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
134
137
for i in range (iters ):
135
138
A1 = torch .rand (1024 , 1024 , device = device , dtype = dtype )
136
139
C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested , code = code )
140
+ if i == 0 :
141
+ d = S .as_dict ()
142
+ S = F .QuantState .from_dict (d , device = torch .device (device ))
137
143
A2 = F .dequantize_blockwise (C , S )
138
144
diff = torch .abs (A1 - A2 ).float ()
139
145
reldiff = diff / torch .abs (A1 .float () + 1e-8 )
@@ -243,6 +249,9 @@ def test_fp8_quant(self, device):
243
249
for i in range (10 ):
244
250
A1 = torch .randn (1024 , 1024 , device = device )
245
251
C , SC = F .quantize_blockwise (A1 , code = code )
252
+ if i == 0 :
253
+ d = SC .as_dict ()
254
+ SC = F .QuantState .from_dict (d , device = torch .device (device ))
246
255
A2 = F .dequantize_blockwise (C , SC )
247
256
diff = torch .abs (A1 - A2 )
248
257
reldiff = diff / torch .abs (A1 + 1e-8 )
@@ -1116,6 +1125,8 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
1116
1125
1117
1126
A1 = torch .randn (1024 , 1024 , device = device , dtype = dtype )
1118
1127
qa , SA = F .quantize_4bit (A1 , blocksize = blocksize , quant_type = quant_type )
1128
+ d = SA .as_dict ()
1129
+ SA = F .QuantState .from_dict (d , device = torch .device (device ))
1119
1130
A2 = F .dequantize_4bit (qa , SA , blocksize = blocksize , quant_type = quant_type )
1120
1131
1121
1132
err = (A1 - A2 ).abs ().float ()
0 commit comments