Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7553e6a
Stage 1&2: Python containers + quantize/gemm dispatch/unwrap
negvet Mar 31, 2026
19acc5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2026
3f33f40
Merge branch 'main' into hybrid_quantization
negvet Apr 6, 2026
f80f5d0
Enable quantized_model_init
negvet Apr 16, 2026
2185b30
FSDP support
negvet Apr 17, 2026
f22a395
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2026
c9754e9
Merge branch 'main' into hybrid_quantization
negvet Apr 17, 2026
103fffe
Enable CPU offloading
negvet Apr 22, 2026
16fb371
Activation recomputation
negvet Apr 24, 2026
a50fd63
TP/SP
negvet Apr 24, 2026
2214843
Resolve comments: hybrid uniform list, make_empty try, __repr__, etc.
negvet Apr 24, 2026
88fe467
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2026
f649cc9
Merge branch 'main' into hybrid_quantization
negvet Apr 29, 2026
4858491
Respect usage
negvet Apr 29, 2026
ef31a9a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
e7136fd
Merge main
negvet May 13, 2026
c7da5b2
Misc minor fixes: comments, tests, etc.
negvet May 20, 2026
a164cd3
Towards MLM integration
negvet May 21, 2026
62e7668
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2026
4b0d18c
Merge main
negvet Jun 1, 2026
7316516
Resolve comments: improve fsdp/tp/sp tests + amax reduction fix
negvet Jun 3, 2026
5892a74
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
ec7b84c
Enable FSDP2 hybrid protocol for Float8Block tensor
negvet Jun 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,618 changes: 1,618 additions & 0 deletions tests/pytorch/test_hybrid_quantization.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;

size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if (tile_scales_inv_c != nullptr) {
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}

if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
size_t row_idx = tile_id_x;
size_t col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
Expand All @@ -189,7 +191,9 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length);
if (output_c != nullptr) {
tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length);
}
}

// Step 4: store transpose into shared memory
Expand Down Expand Up @@ -388,13 +392,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;

size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if (tile_scales_inv_c != nullptr) {
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}

if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
size_t row_idx = tile_id_x;
size_t col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
Expand Down Expand Up @@ -433,8 +439,10 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
if (output_c != nullptr) {
tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
}
}

if constexpr (kReturnTranspose) {
Expand Down Expand Up @@ -492,19 +500,26 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
"with MXFP8, which requires using power of two scaling factors.");
}

NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const bool return_identity = output.dptr != nullptr;
if (return_identity) {
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
}
NVTE_CHECK(return_identity || return_transpose,
"At least one of rowwise or columnwise output must be requested.");
const size_t row_length = input.shape.size() > 0 ? input.shape.back() : 1;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
}

NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions.");

size_t scale_k = scale_inv.shape[1];

const size_t scale_stride_x = 1;
const size_t scale_stride_y = scale_k;
size_t scale_k = 0;
const size_t scale_stride_x = return_identity ? 1 : 0;
size_t scale_stride_y = 0;
if (return_identity) {
NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions.");
scale_k = scale_inv.shape[1];
scale_stride_y = scale_k;
}

size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
Expand All @@ -522,22 +537,26 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
") and output_t (shape=", output_t.shape, ") have incompatible dims.");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
if (return_identity) {
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
}

NVTE_CHECK(scale_inv_t.shape.size() == 2, "scale_inv_t must have 2 dimensions.");

scale_t_stride_x = 1;
scale_t_stride_y = scale_inv_t.shape[1];
}

const auto out_dtype = return_identity ? output.dtype : output_t.dtype;

const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);

TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output.dtype, OutputType,
out_dtype, OutputType,

TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@
from transformer_engine.pytorch.tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor import NVFP4Tensor
from transformer_engine.pytorch.tensor import HybridQuantizer
from transformer_engine.pytorch.tensor import HybridQuantizedTensorStorage
from transformer_engine.pytorch.tensor import HybridQuantizedTensor

try:
torch._dynamo.config.error_on_nested_jit_trace = False
Expand Down
37 changes: 37 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..quantized_tensor import Quantizer
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_custom
from ..tensor.storage.hybrid_tensor_storage import HybridQuantizedTensorStorage
from ..custom_recipes.gemm import custom_gemm
from ...debug.pytorch.debug_quantization import DebugQuantizer

Expand Down Expand Up @@ -69,6 +70,36 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
return 0.0


def _unwrap_hybrid_A(tensor, layout):
    """Pick the native sub-storage of a hybrid tensor used as GEMM operand A.

    The selection is driven by A's transpose flag, i.e. the first character
    of ``layout``:
        "T"           -> ``tensor.rowwise_sub_storage``
        anything else -> ``tensor.columnwise_sub_storage``

    A ``tensor`` that is not a ``HybridQuantizedTensorStorage`` is returned
    unchanged, and ``layout`` is not inspected in that case.
    """
    if isinstance(tensor, HybridQuantizedTensorStorage):
        a_is_transposed = layout[0] == "T"
        if a_is_transposed:
            return tensor.rowwise_sub_storage
        return tensor.columnwise_sub_storage
    # Non-hybrid operands pass straight through.
    return tensor


def _unwrap_hybrid_B(tensor, layout):
    """Pick the native sub-storage of a hybrid tensor used as GEMM operand B.

    The selection is driven by B's transpose flag, i.e. the second character
    of ``layout``:
        "N"           -> ``tensor.rowwise_sub_storage``
        anything else -> ``tensor.columnwise_sub_storage``

    A ``tensor`` that is not a ``HybridQuantizedTensorStorage`` is returned
    unchanged, and ``layout`` is not inspected in that case.
    """
    if isinstance(tensor, HybridQuantizedTensorStorage):
        b_not_transposed = layout[1] == "N"
        if b_not_transposed:
            return tensor.rowwise_sub_storage
        return tensor.columnwise_sub_storage
    # Non-hybrid operands pass straight through.
    return tensor


def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
Expand All @@ -95,6 +126,9 @@ def general_gemm(
transa = layout[0] == "T"
transb = layout[1] == "T"

A = _unwrap_hybrid_A(A, layout)
B = _unwrap_hybrid_B(B, layout)

alpha = validate_gemm_scale(alpha, True)
beta = validate_gemm_scale(beta, accumulate)
workspace = get_cublas_workspace(A.device.index, ub is not None, False)
Expand Down Expand Up @@ -204,6 +238,9 @@ def general_grouped_gemm(
"""
num_gemms = len(A)

A = [_unwrap_hybrid_A(a, layout) for a in A]
B = [_unwrap_hybrid_B(b, layout) for b in B]

transa = layout[0] == "T"
transb = layout[1] == "T"

Expand Down
6 changes: 4 additions & 2 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.hybrid_tensor import HybridQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
Expand Down Expand Up @@ -1258,8 +1259,9 @@ def grad_output_preprocess(
):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else:
if isinstance(quantizer, Float8BlockQuantizer):
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
if isinstance(quantizer, (Float8BlockQuantizer, HybridQuantizer)):
# Float8BlockQuantizer: unfused until cast_transpose + dgrad is ready.
# HybridQuantizer: tex.bgrad_quantize doesn't recognize hybrid quantizers.
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
Expand Down
66 changes: 63 additions & 3 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,50 @@
prepare_for_saving,
restore_from_saved,
)
from ..tensor.hybrid_tensor import HybridQuantizer
from ...debug.pytorch.debug_quantization import DebugQuantizer
from ...debug.pytorch.debug_state import TEDebugState


def _has_hybrid_quantizer(quantizers):
    """Return True if at least one non-None entry of ``quantizers`` is a HybridQuantizer."""
    for candidate in quantizers:
        # None placeholders are skipped before the type check, as in the
        # generator-based form.
        if candidate is not None and isinstance(candidate, HybridQuantizer):
            return True
    return False


def _hybrid_split_quantize(tensor, m_splits, quantizers):
    """Grouped split+quantize for a list of HybridQuantizers.

    Calls ``tex.split_quantize`` twice on the same input -- once with every
    quantizer's ``rowwise_quantizer`` and once with every
    ``columnwise_quantizer`` -- and pairs the i-th result of each call into a
    single ``HybridQuantizedTensorStorage``.

    Args:
        tensor: Input forwarded unchanged to ``tex.split_quantize``. Its
            ``dtype`` is recorded as ``fake_dtype`` on every returned storage.
        m_splits: Split sizes forwarded unchanged to ``tex.split_quantize``.
        quantizers: Sequence of ``HybridQuantizer`` objects. Every entry must
            be a ``HybridQuantizer``: there is no fallback path for other
            quantizer types or ``None``.
            NOTE(review): presumably one quantizer per entry of ``m_splits``;
            that correspondence is not checked here.

    Returns:
        A list of ``HybridQuantizedTensorStorage``, one per zipped
        (rowwise result, columnwise result, quantizer) triple.

    Raises:
        TypeError: If any entry of ``quantizers`` is not a ``HybridQuantizer``.
    """
    from ..tensor.storage.hybrid_tensor_storage import HybridQuantizedTensorStorage as HybridStorage

    # Callers dispatch here when *any* quantizer is hybrid, so a mixed list can
    # reach this function. Fail with an explicit message instead of an opaque
    # AttributeError on `.rowwise_quantizer` below.
    non_hybrid = [idx for idx, q in enumerate(quantizers) if not isinstance(q, HybridQuantizer)]
    if non_hybrid:
        raise TypeError(
            "_hybrid_split_quantize requires every quantizer to be a HybridQuantizer; "
            f"entries at indices {non_hybrid} are not."
        )

    row_quantizers = [q.rowwise_quantizer for q in quantizers]
    col_quantizers = [q.columnwise_quantizer for q in quantizers]

    # One grouped kernel call per direction, each using the native sub-quantizers.
    row_results = tex.split_quantize(tensor, m_splits, row_quantizers)
    col_results = tex.split_quantize(tensor, m_splits, col_quantizers)

    return [
        HybridStorage(
            rowwise_storage=row,
            columnwise_storage=col,
            rowwise_quantizer=rq,
            columnwise_quantizer=cq,
            quantizer=q,
            fake_dtype=tensor.dtype,
        )
        for row, col, rq, cq, q in zip(
            row_results,
            col_results,
            row_quantizers,
            col_quantizers,
            quantizers,
        )
    ]


__all__ = ["GroupedLinear"]


Expand Down Expand Up @@ -144,7 +185,8 @@ def forward(
)
inp_view = inp.reshape(-1, in_features)
inputmats: list
if fp8 and not debug:
hybrid = _has_hybrid_quantizer(input_quantizers)
if fp8 and not debug and not hybrid:
# Disable bulk allocation when CPU offloading is active: offloading skips small
# tensors (like scales), but bulk allocation shares storage across all tensors,
# so if scales can't be offloaded, nothing in the group can be offloaded.
Expand All @@ -154,6 +196,8 @@ def forward(
input_quantizers,
disable_bulk_allocation=cpu_offloading,
)
elif fp8 and hybrid:
inputmats = _hybrid_split_quantize(inp_view, m_splits, input_quantizers)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, input_quantizers, m_splits, activation_dtype
Expand Down Expand Up @@ -338,7 +382,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
grad_output = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms
if ctx.fp8 and not ctx.debug:
grad_output_hybrid = _has_hybrid_quantizer(ctx.grad_output_quantizers)
if ctx.fp8 and not ctx.debug and not grad_output_hybrid:
if ctx.use_bias:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
recipe = ctx.fp8_recipe
Expand All @@ -365,6 +410,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
ctx.m_splits,
ctx.grad_output_quantizers,
)
elif ctx.fp8 and grad_output_hybrid:
if ctx.use_bias:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output = _hybrid_split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
elif ctx.debug:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
for i in range(ctx.num_gemms):
Expand Down Expand Up @@ -451,8 +506,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
else:
input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmats: list
if ctx.fp8 and not ctx.debug:
input_hybrid = _has_hybrid_quantizer(ctx.input_quantizers)
if ctx.fp8 and not ctx.debug and not input_hybrid:
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
elif ctx.fp8 and input_hybrid:
inputmats = _hybrid_split_quantize(
inp_view, ctx.m_splits, ctx.input_quantizers
)
elif ctx.debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view,
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
)
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.hybrid_tensor import HybridQuantizer
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
Expand Down Expand Up @@ -206,12 +207,14 @@ def forward(
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
custom = is_custom(input_quantizer)
hybrid = isinstance(input_quantizer, HybridQuantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
and not hybrid
)

# Apply normalization
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.hybrid_tensor import HybridQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import (
is_cpu_offload_enabled,
Expand Down Expand Up @@ -390,12 +391,14 @@ def _forward(
# for debug: : layernorm output = High precision to enable processing of this norm

custom = is_custom(fc1_input_quantizer)
hybrid = isinstance(fc1_input_quantizer, HybridQuantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not custom
and not hybrid
)

# Apply normalization
Expand Down
Loading
Loading