Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_executable(test_operator
test_act.cu
test_normalization.cu
test_normalization_mxfp8.cu
test_cumsum.cu
test_memset.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
Expand Down
92 changes: 92 additions & 0 deletions tests/cpp/operator/test_cumsum.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

using namespace transformer_engine;

namespace {

// Host-side reference for the cumulative sum with a leading zero.
//
// Returns a vector of input.size() + 1 elements: element 0 is 0 and element k
// (k >= 1) is the sum of the first k values of `input`.
std::vector<int64_t> reference_cumsum_with_leading_zero(const std::vector<int64_t> &input) {
  std::vector<int64_t> prefix;
  prefix.reserve(input.size() + 1);

  int64_t running_total = 0;
  prefix.push_back(running_total);
  for (const int64_t value : input) {
    running_total += value;
    prefix.push_back(running_total);
  }
  return prefix;
}

// Runs nvte_cumsum on a device copy of `h_input` and compares every output
// element against the host reference.
//
// The output buffer holds h_input.size() + 1 elements, the same length that
// reference_cumsum_with_leading_zero produces.
void run_cumsum_test(const std::vector<int64_t> &h_input) {
  const size_t num_inputs = h_input.size();
  const size_t num_outputs = num_inputs + 1;
  const size_t input_bytes = num_inputs * sizeof(int64_t);
  const size_t output_bytes = num_outputs * sizeof(int64_t);

  const std::vector<int64_t> h_expected = reference_cumsum_with_leading_zero(h_input);
  std::vector<int64_t> h_output(num_outputs, 0);

  // Device-side input and output buffers.
  int64_t *dev_in = nullptr;
  int64_t *dev_out = nullptr;
  NVTE_CHECK_CUDA(cudaMalloc(&dev_in, input_bytes));
  NVTE_CHECK_CUDA(cudaMalloc(&dev_out, output_bytes));

  // Upload the input, run the operator, then download the result.
  NVTE_CHECK_CUDA(cudaMemcpy(dev_in, h_input.data(), input_bytes, cudaMemcpyHostToDevice));
  nvte_cumsum(dev_in, dev_out, num_inputs, 0 /* stream */);
  NVTE_CHECK_CUDA(cudaMemcpy(h_output.data(), dev_out, output_bytes, cudaMemcpyDeviceToHost));
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());

  // Device memory is released before the results are inspected.
  NVTE_CHECK_CUDA(cudaFree(dev_in));
  NVTE_CHECK_CUDA(cudaFree(dev_out));

  // A size mismatch is fatal; element mismatches are all reported.
  ASSERT_EQ(h_output.size(), h_expected.size());
  for (size_t i = 0; i < h_output.size(); ++i) {
    EXPECT_EQ(h_output[i], h_expected[i]) << "Mismatch at output index " << i;
  }
}

// Builds a deterministic input of length `n`.
//
// Element i is (i % 7) - 3, so the values cycle through -3, -2, ..., 3.
std::vector<int64_t> make_input(size_t n) {
  std::vector<int64_t> values;
  values.reserve(n);
  for (size_t pos = 0; pos != n; ++pos) {
    const int64_t residue = static_cast<int64_t>(pos % 7);
    values.push_back(residue - 3);
  }
  return values;
}

// Input lengths used to instantiate the size-parameterized cumsum suite:
// 1, 2, 17, and lengths equal to or one past 256, 512 and 1024.
// Declared const because the list is only read (by ::testing::ValuesIn) in
// this file, so no test can modify the shared parameter set.
const std::vector<size_t> cumsum_test_sizes = {
    1,
    2,
    17,
    256,
    257,
    513,
    1024,
};

} // namespace

// Checks nvte_cumsum against the host reference on a small hand-written input
// that mixes positive, negative and zero values.
TEST(CumsumTest, KnownValues) {
  run_cumsum_test({3, -1, 4, 0, -5});
}

// Value-parameterized fixture; the parameter is the number of input elements.
class CumsumSizeTestSuite : public ::testing::TestWithParam<size_t> {};

// Runs the cumsum comparison on a generated input of the parameterized length.
TEST_P(CumsumSizeTestSuite, TestCumsumBySize) {
  const size_t num_elements = GetParam();
  const std::vector<int64_t> input = make_input(num_elements);
  run_cumsum_test(input);
}

// One test instance per entry of cumsum_test_sizes, named "N<size>".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, CumsumSizeTestSuite, ::testing::ValuesIn(cumsum_test_sizes),
    [](const testing::TestParamInfo<CumsumSizeTestSuite::ParamType> &info) {
      std::string test_name = "N";
      test_name += std::to_string(info.param);
      return test_name;
    });
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor

import pytest
import torch
Expand Down
75 changes: 23 additions & 52 deletions tests/pytorch/test_grouped_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
Quantizer,
Float8Quantizer,
Expand Down Expand Up @@ -125,8 +125,8 @@ def test_basic_construction_all_same_shape(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
shapes=shape,
quantizers=None,
device="cuda",
dtype=torch.float32,
)
Expand All @@ -147,8 +147,8 @@ def test_basic_construction_varying_first_dim(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
shapes=shape,
quantizers=None,
device="cuda",
dtype=torch.float32,
)
Expand All @@ -170,8 +170,8 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
shapes=shape,
quantizers=None,
device="cuda",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each per-tensor quantizer constructed with full-group num_tensors

Each list entry calls make_quantizer(quantization, num_tensors, shape) with num_tensors=3, meaning each quantizer's internal buffers (e.g., FP8 amax/scale tensors) are sized for the entire group of 3 tensors, not for a single tensor. While this doesn't break correctness today (only index 0 of the per-quantizer buffers is used), it inflates memory usage and diverges from production use, where each per-tensor quantizer should be sized for one tensor.

Consider constructing each quantizer for num_tensors=1:

quantizers = [make_quantizer(quantization, 1, shape) for _ in range(num_tensors)]

dtype=torch.float32,
)
Expand Down Expand Up @@ -203,13 +203,14 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None
"""Test split_into_quantized_tensors for quantized tensors"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizer = make_quantizer(quantization, num_tensors, shape)
quantizers = [make_quantizer(quantization, num_tensors, shape) for _ in range(num_tensors)]

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizer,
shapes=shape,
quantizers=quantizers,
device="cuda",
dtype=torch.float32,
)

# Get the original data pointer
Expand All @@ -236,8 +237,8 @@ def test_split_varying_shapes(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
shapes=shape,
quantizers=None,
device="cuda",
dtype=torch.float32,
)
Expand All @@ -260,13 +261,14 @@ def test_quantize_inplace(self, quantization: str) -> None:
"""Test that quantize is done in-place for all recipes"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizer = make_quantizer(quantization, num_tensors, shape)
quantizers = [make_quantizer(quantization, num_tensors, shape) for _ in range(num_tensors)]

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizer,
shapes=shape,
quantizers=quantizers,
device="cuda",
dtype=torch.float32,
)

# Get original data pointers before quantization
Expand Down Expand Up @@ -300,13 +302,14 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:
"""Test quantize with varying shapes"""
num_tensors = 3
shape = [(256, 512), (512, 512), (768, 512)]
quantizer = make_quantizer(quantization, num_tensors, shape)
quantizers = [make_quantizer(quantization, num_tensors, shape) for _ in range(num_tensors)]

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizer,
shapes=shape,
quantizers=quantizers,
device="cuda",
dtype=torch.float32,
)

# Get original data pointers
Expand All @@ -329,38 +332,6 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
cumulative_numel += tensor_shape[0] * tensor_shape[1]

@pytest.mark.parametrize("quantization", _quantization_params)
def test_static_quantize_method(self, quantization: str) -> None:
"""Test the static quantize method"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizer = make_quantizer(quantization, num_tensors, shape)

# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

# Use static quantize method
grouped_tensor = GroupedTensor.create_and_quantize(
tensors=input_tensors,
quantizer=quantizer,
device="cuda",
)

# Verify the grouped tensor was created correctly
assert grouped_tensor.num_tensors == num_tensors
assert grouped_tensor.has_data()

# Verify quantized_tensors were created and point to same storage
assert grouped_tensor.quantized_tensors is not None
assert len(grouped_tensor.quantized_tensors) == num_tensors

original_data_ptr = grouped_tensor.data.data_ptr()
for i, qtensor in enumerate(grouped_tensor.quantized_tensors):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

@pytest.mark.parametrize(
"shape",
[[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]],
Expand Down Expand Up @@ -461,8 +432,8 @@ def test_clear(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
shapes=shape,
quantizers=None,
device="cuda",
dtype=torch.float32,
)
Expand Down
Loading
Loading