Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
658 changes: 649 additions & 9 deletions tests/pytorch/test_permutation.py

Large diffs are not rendered by default.

38 changes: 29 additions & 9 deletions transformer_engine/common/triton/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def _permute_kernel(
probs_ptr,
scale_ptr,
permuted_scale_ptr,
pad_offsets_ptr,
# sizes
scale_hidden_dim,
# strides
Expand All @@ -224,8 +225,11 @@ def _permute_kernel(
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
FUSION_PAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
expert_idx = 0

pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
Expand All @@ -246,18 +250,22 @@ def _permute_kernel(
dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
if FUSION_PAD or PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_PAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE:
permuted_scale_off = (
dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off)
if pid_h == 0:
Expand Down Expand Up @@ -297,6 +305,7 @@ def _unpermute_kernel(
row_id_map_ptr,
merging_probs_ptr,
permuted_probs_ptr,
pad_offsets_ptr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
Expand All @@ -318,10 +327,12 @@ def _unpermute_kernel(
PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = input_ptr.dtype.element_ty
compute_type = tl.float32
expert_idx = 0

pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
Expand All @@ -348,15 +359,19 @@ def _unpermute_kernel(
src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
if FUSION_UNPAD or WITH_MERGING_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
src_row = src_row + pad_off
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
merging_prob_off = (
pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
Expand Down Expand Up @@ -407,6 +422,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
fwd_input_ptr,
merging_probs_ptr,
row_id_map_ptr,
pad_offsets_ptr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
Expand All @@ -427,6 +443,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = fwd_output_grad_ptr.dtype.element_ty
Expand All @@ -450,6 +467,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
+ pid * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
current_start = 0
while current_start < hidden_size:
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from transformer_engine.pytorch.permutation import (
moe_permute,
moe_permute_with_probs,
moe_permute_and_pad_with_probs,
moe_unpermute,
moe_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs,
Expand Down
39 changes: 28 additions & 11 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload

from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
Expand Down Expand Up @@ -143,7 +144,12 @@ def forward(
inp_view = inp.reshape(-1, in_features)
inputmats: list
if fp8 and not debug:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
if isinstance(inp_view, Float8BlockwiseQTensor):
inputmats = inp_view.split_scaling_aware_fp8_transpose(
m_splits, input_quantizers
)
else:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, input_quantizers, m_splits, activation_dtype
Expand Down Expand Up @@ -343,18 +349,28 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
# Unfused bias grad and multi-tensor quantize
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
if isinstance(grad_output_view, Float8BlockwiseQTensor):
grad_output = grad_output_view.split_scaling_aware_fp8_transpose(
ctx.m_splits, ctx.grad_output_quantizers
)
else:
grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
else:
# Multi-tensor quantize
if isinstance(grad_output_view, Float8BlockwiseQTensor):
grad_output = grad_output_view.split_scaling_aware_fp8_transpose(
ctx.m_splits, ctx.grad_output_quantizers
)
else:
grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
else:
# Multi-tensor quantize
grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
elif ctx.debug:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
for i in range(ctx.num_gemms):
Expand Down Expand Up @@ -781,9 +797,10 @@ def forward(
"""
debug = self.is_debug_iter()

assert not isinstance(
inp, QuantizedTensorStorage
), "GroupedLinear doesn't support input tensor in FP8."
if not isinstance(inp, Float8BlockwiseQTensor):
assert not isinstance(
inp, QuantizedTensorStorage
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

is_grad_enabled = torch.is_grad_enabled()
Expand Down
Loading
Loading