|
18 | 18 | MXFP8Quantizer, |
19 | 19 | NVFP4Quantizer, |
20 | 20 | Float8Tensor, |
| 21 | + Float8BlockwiseQTensor, |
21 | 22 | MXFP8Tensor, |
22 | 23 | NVFP4Tensor, |
23 | 24 | QuantizedTensor, |
@@ -657,6 +658,50 @@ def test_chunk( |
657 | 658 | y_test = y_test.to(dtype=torch.float64, device="cpu") |
658 | 659 | torch.testing.assert_close(y_test, y_ref, **tols) |
659 | 660 |
|
| 661 | + @pytest.mark.parametrize("quantization", _quantization_list) |
| 662 | + def test_shape_with_none_data( |
| 663 | + self, |
| 664 | + *, |
| 665 | + quantization: str, |
| 666 | + shape: Iterable[int] = (128, 128), |
| 667 | + dtype: torch.dtype = torch.bfloat16, |
| 668 | + ) -> None: |
| 669 | + """Test that shape is accessible after internal data tensors are set to None. |
| 670 | +
|
| 671 | + During CPU offloading, both data and transpose tensors can be None. |
| 672 | + The shape should still be available via the wrapper subclass metadata. |
| 673 | + """ |
| 674 | + |
| 675 | + _, x_test = make_reference_and_test_tensors( |
| 676 | + shape=shape, |
| 677 | + quantization=quantization, |
| 678 | + test_dtype=dtype, |
| 679 | + requires_grad=False, |
| 680 | + ) |
| 681 | + |
| 682 | + # Verify shape before clearing data |
| 683 | + assert x_test.shape == torch.Size(shape) |
| 684 | + |
| 685 | + # Simulate CPU offloading: None out all internal data |
| 686 | + if isinstance(x_test, Float8Tensor): |
| 687 | + x_test._data = None |
| 688 | + x_test._transpose = None |
| 689 | + elif isinstance(x_test, MXFP8Tensor): |
| 690 | + x_test._rowwise_data = None |
| 691 | + x_test._columnwise_data = None |
| 692 | + elif isinstance(x_test, NVFP4Tensor): |
| 693 | + x_test._rowwise_data = None |
| 694 | + x_test._columnwise_data = None |
| 695 | + elif isinstance(x_test, Float8BlockwiseQTensor): |
| 696 | + x_test._rowwise_data = None |
| 697 | + x_test._columnwise_data = None |
| 698 | + |
| 699 | + # Shape must still be correct after data is cleared |
| 700 | + assert x_test.shape == torch.Size(shape), ( |
| 701 | + f"Expected shape {shape} but got {x_test.shape} " |
| 702 | + f"after setting data to None on {type(x_test).__name__}" |
| 703 | + ) |
| 704 | + |
660 | 705 |
|
661 | 706 | @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) |
662 | 707 | class TestMXFP8Tensor: |
|
0 commit comments