Skip to content

Commit b169b03

Browse files
committed
Add test
Signed-off-by: cyy <[email protected]>
1 parent d53c8f6 commit b169b03

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/test_functional.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
118118
for i in range(iters):
119119
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
120120
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
121+
if i == 0:
122+
d = S.as_dict()
123+
S = F.QuantState.from_dict(d, device=torch.device(device))
121124
A2 = F.dequantize_blockwise(C, S)
122125
diff = torch.abs(A1 - A2).float()
123126
reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -134,6 +137,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
134137
for i in range(iters):
135138
A1 = torch.rand(1024, 1024, device=device, dtype=dtype)
136139
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
140+
if i == 0:
141+
d = S.as_dict()
142+
S = F.QuantState.from_dict(d, device=torch.device(device))
137143
A2 = F.dequantize_blockwise(C, S)
138144
diff = torch.abs(A1 - A2).float()
139145
reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -243,6 +249,9 @@ def test_fp8_quant(self, device):
243249
for i in range(10):
244250
A1 = torch.randn(1024, 1024, device=device)
245251
C, SC = F.quantize_blockwise(A1, code=code)
252+
if i == 0:
253+
d = SC.as_dict()
254+
SC = F.QuantState.from_dict(d, device=torch.device(device))
246255
A2 = F.dequantize_blockwise(C, SC)
247256
diff = torch.abs(A1 - A2)
248257
reldiff = diff / torch.abs(A1 + 1e-8)
@@ -1116,6 +1125,8 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11161125

11171126
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
11181127
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
1128+
d = SA.as_dict()
1129+
SA = F.QuantState.from_dict(d, device=torch.device(device))
11191130
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
11201131

11211132
err = (A1 - A2).abs().float()

0 commit comments

Comments
 (0)