Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
# GroupedTensor unit tests. The report name follows the pytest_<test file>.xml pattern used by the neighbouring entries.
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor

import pytest
import torch
Expand Down
100 changes: 44 additions & 56 deletions tests/pytorch/test_grouped_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
Quantizer,
Float8Quantizer,
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_basic_construction_all_same_shape(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
Expand All @@ -147,7 +147,7 @@ def test_basic_construction_varying_first_dim(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
Expand All @@ -170,14 +170,18 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)

# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# GroupedTensor is a wrapper; use backing storage buffer pointer.
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()

# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
Expand Down Expand Up @@ -207,13 +211,18 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=quantizer,
device="cuda",
dtype=torch.float32,
)

# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# GroupedTensor is a wrapper; use backing storage buffer pointer.
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()

# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
Expand All @@ -236,13 +245,17 @@ def test_split_varying_shapes(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)

original_data_ptr = grouped_tensor.data.data_ptr()
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()
tensors = grouped_tensor.split_into_quantized_tensors()

assert len(tensors) == num_tensors
Expand All @@ -264,13 +277,18 @@ def test_quantize_inplace(self, quantization: str) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=quantizer,
device="cuda",
dtype=torch.float32,
)

# Get original data pointers before quantization
original_data_ptr = grouped_tensor.data.data_ptr()
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()
original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
original_scale_ptr = (
grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
Expand All @@ -283,7 +301,7 @@ def test_quantize_inplace(self, quantization: str) -> None:
quantized_tensors = grouped_tensor.quantize(input_tensors)

# Verify data pointers haven't changed (in-place operation)
assert grouped_tensor.data.data_ptr() == original_data_ptr
assert storage.data_ptr() == original_data_ptr
assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
if original_scale_ptr is not None:
assert grouped_tensor.scale.data_ptr() == original_scale_ptr
Expand All @@ -304,13 +322,18 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=quantizer,
device="cuda",
dtype=torch.float32,
)

# Get original data pointers
original_data_ptr = grouped_tensor.data.data_ptr()
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()

# Create input tensors with varying shapes
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
Expand All @@ -319,7 +342,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:
quantized_tensors = grouped_tensor.quantize(input_tensors)

# Verify data pointer hasn't changed
assert grouped_tensor.data.data_ptr() == original_data_ptr
assert storage.data_ptr() == original_data_ptr

# Verify each tensor points to correct location
cumulative_numel = 0
Expand All @@ -329,38 +352,6 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
cumulative_numel += tensor_shape[0] * tensor_shape[1]

@pytest.mark.parametrize("quantization", _quantization_params)
def test_static_quantize_method(self, quantization: str) -> None:
    """Exercise ``GroupedTensor.create_and_quantize`` on equally shaped inputs.

    Three 512x512 float32 CUDA tensors are quantized in one call. The test
    then checks that the grouped tensor reports the expected tensor count,
    holds data, exposes one quantized tensor per input, and that each
    member's rowwise data pointer equals the grouped buffer pointer plus the
    byte offset returned by ``_rowwise_offset_bytes`` for its element offset.
    """
    num_tensors = 3
    shape = [(512, 512)] * num_tensors
    quantizer = make_quantizer(quantization, num_tensors, shape)

    # Random inputs, one per requested shape.
    input_tensors = [torch.randn(dims, dtype=torch.float32, device="cuda") for dims in shape]

    # Allocate and quantize in a single call.
    grouped_tensor = GroupedTensor.create_and_quantize(
        tensors=input_tensors,
        quantizer=quantizer,
        device="cuda",
    )

    # The grouped container must be populated.
    assert grouped_tensor.num_tensors == num_tensors
    assert grouped_tensor.has_data()

    # One quantized member per input tensor.
    assert grouped_tensor.quantized_tensors is not None
    assert len(grouped_tensor.quantized_tensors) == num_tensors

    # Each member must sit at its expected byte offset within the grouped buffer.
    base_ptr = grouped_tensor.data.data_ptr()
    for index, member in enumerate(grouped_tensor.quantized_tensors):
        member_rowwise = _get_rowwise_data_tensor(member, quantization)
        rows, cols = shape[index]
        elements_before = index * (rows * cols)
        byte_offset = _rowwise_offset_bytes(elements_before, quantization)
        assert member_rowwise.data_ptr() == base_ptr + byte_offset

@pytest.mark.parametrize(
"shape",
[[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]],
Expand All @@ -374,9 +365,6 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None:

# Create BF16 input tensors and pack into a 2D tensor
input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape]
quantized_tensors = [
MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors
]
grouped_input = torch.cat(input_tensors, dim=0)

# Create MXFP8 output grouped tensor (rowwise only for easier validation)
Expand Down Expand Up @@ -406,7 +394,7 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None:
expected_data = torch.cat(expected_data)
expected_scale_inv = torch.cat(expected_scale_inv)

assert torch.equal(grouped_output.data, expected_data)
assert torch.equal(grouped_output.rowwise_data, expected_data)
assert torch.equal(grouped_output.scale_inv, expected_scale_inv)

@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
Expand Down Expand Up @@ -451,7 +439,7 @@ def test_group_quantize_cudagraph_capturable(self) -> None:
torch.cuda.synchronize()

expected = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims)
assert torch.equal(static_output.data, expected.data)
assert torch.equal(static_output.rowwise_data, expected.rowwise_data)
assert torch.equal(static_output.scale_inv, expected.scale_inv)

def test_clear(self) -> None:
Expand All @@ -461,7 +449,7 @@ def test_clear(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
Expand All @@ -474,5 +462,5 @@ def test_clear(self) -> None:

assert not grouped_tensor.has_data()
assert grouped_tensor.num_tensors == 0
assert grouped_tensor.data is None
assert grouped_tensor.rowwise_data is None
assert grouped_tensor.logical_shape == (0, 0)
Loading
Loading