Skip to content

Commit 5f9550f

Browse files
CPU offloading fix: if data and transpose are both None, fall back to the base torch.Tensor class for the shape (#2841)
* fix
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 86edac4 commit 5f9550f

5 files changed

Lines changed: 49 additions & 4 deletions

File tree

tests/pytorch/test_quantized_tensor.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
MXFP8Quantizer,
1919
NVFP4Quantizer,
2020
Float8Tensor,
21+
Float8BlockwiseQTensor,
2122
MXFP8Tensor,
2223
NVFP4Tensor,
2324
QuantizedTensor,
@@ -657,6 +658,50 @@ def test_chunk(
657658
y_test = y_test.to(dtype=torch.float64, device="cpu")
658659
torch.testing.assert_close(y_test, y_ref, **tols)
659660

661+
@pytest.mark.parametrize("quantization", _quantization_list)
662+
def test_shape_with_none_data(
663+
self,
664+
*,
665+
quantization: str,
666+
shape: Iterable[int] = (128, 128),
667+
dtype: torch.dtype = torch.bfloat16,
668+
) -> None:
669+
"""Test that shape is accessible after internal data tensors are set to None.
670+
671+
During CPU offloading, both data and transpose tensors can be None.
672+
The shape should still be available via the wrapper subclass metadata.
673+
"""
674+
675+
_, x_test = make_reference_and_test_tensors(
676+
shape=shape,
677+
quantization=quantization,
678+
test_dtype=dtype,
679+
requires_grad=False,
680+
)
681+
682+
# Verify shape before clearing data
683+
assert x_test.shape == torch.Size(shape)
684+
685+
# Simulate CPU offloading: None out all internal data
686+
if isinstance(x_test, Float8Tensor):
687+
x_test._data = None
688+
x_test._transpose = None
689+
elif isinstance(x_test, MXFP8Tensor):
690+
x_test._rowwise_data = None
691+
x_test._columnwise_data = None
692+
elif isinstance(x_test, NVFP4Tensor):
693+
x_test._rowwise_data = None
694+
x_test._columnwise_data = None
695+
elif isinstance(x_test, Float8BlockwiseQTensor):
696+
x_test._rowwise_data = None
697+
x_test._columnwise_data = None
698+
699+
# Shape must still be correct after data is cleared
700+
assert x_test.shape == torch.Size(shape), (
701+
f"Expected shape {shape} but got {x_test.shape} "
702+
f"after setting data to None on {type(x_test).__name__}"
703+
)
704+
660705

661706
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
662707
class TestMXFP8Tensor:

transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def shape(self):
598598
return self._rowwise_data.shape
599599
if self._columnwise_data is not None:
600600
return self._columnwise_data.shape
601-
raise RuntimeError("Float8BlockwiseQTensor has no data!")
601+
return torch.Tensor.size(self)
602602

603603
@property
604604
def is_cuda(self):

transformer_engine/pytorch/tensor/float8_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,7 @@ def shape(self):
967967
if self._transpose is not None:
968968
transpose_shape = self._transpose.shape
969969
return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))
970-
raise RuntimeError("Both data and transpose are None")
970+
return torch.Tensor.size(self)
971971

972972
@property
973973
def is_cuda(self):

transformer_engine/pytorch/tensor/mxfp8_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ def shape(self):
884884
return self._rowwise_data.shape
885885
if self._columnwise_data is not None:
886886
return self._columnwise_data.shape
887-
raise RuntimeError("MXFP8Tensor has no data!")
887+
return torch.Tensor.size(self)
888888

889889
@property
890890
def is_cuda(self):

transformer_engine/pytorch/tensor/nvfp4_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def shape(self):
745745
if self._columnwise_data is not None:
746746
byte_shape = self._columnwise_data.shape
747747
return torch.Size(byte_shape[1:-1] + (byte_shape[-1] * 2, byte_shape[0]))
748-
raise RuntimeError("NVFP4Tensor has no data!")
748+
return torch.Tensor.size(self)
749749

750750
@property
751751
def is_cuda(self):

0 commit comments

Comments
 (0)