Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2388,6 +2388,14 @@ def fused_apply_rotary_pos_emb_thd(
fused_sort_chunks_by_index_with_probs = None
fused_unpermute = None

try:
from transformer_engine.pytorch.permutation import moe_permute_and_pad_with_probs

fused_permute_and_pad_with_probs = moe_permute_and_pad_with_probs

except ImportError:
fused_permute_and_pad_with_probs = None

try:
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

Expand Down
41 changes: 36 additions & 5 deletions megatron/core/transformer/moe/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from megatron.core.transformer.cuda_graphs import is_graph_capturing
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import internal_api
from megatron.core.utils import internal_api, is_te_min_version

try:
import transformer_engine as te # pylint: disable=unused-import
Expand All @@ -24,6 +24,7 @@
fused_compute_score_for_moe_aux_loss,
fused_moe_aux_loss,
fused_permute,
fused_permute_and_pad_with_probs,
fused_permute_with_probs,
fused_sort_chunks_by_index,
fused_sort_chunks_by_index_with_probs,
Expand Down Expand Up @@ -232,6 +233,8 @@ def permute(
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
tokens_per_expert: Optional[torch.Tensor] = None,
align_size: int = -1,
):
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
Expand All @@ -252,6 +255,10 @@ def permute(
and pads the number of tokens to the expert capacity.
If set to true, routing_map has a fixed number of non-zeros
in each column.
tokens_per_expert : torch.Tensor
Tensor of shape `[num_experts]` containing actual token counts per expert.
align_size : int
the alignment size for the input tensor for fp8 or fp4.
"""
if fused and probs is None:
if not HAVE_TE or fused_permute is None:
Expand All @@ -262,11 +269,23 @@ def permute(
return permuted_input, None, sorted_indices

if fused and probs is not None:
if not HAVE_TE or fused_permute_with_probs is None:
if not HAVE_TE or (
fused_permute_and_pad_with_probs is None and fused_permute_with_probs is None
):
raise ValueError(
"fused_permute_with_probs is not available. Please install TE >= 2.1.0."
"Transformer Engine (TE) fused kernel is not available. "
"fused_permute_with_probs typically requires TE >= 2.1.0, and "
"fused_permute_and_pad_with_probs` typically requires TE >= 2.12.0. "
)
if fused_permute_and_pad_with_probs is not None:
return fused_permute_and_pad_with_probs(
tokens, probs, routing_map, tokens_per_expert, align_size
)
else:
output, permuted_probs, row_id_map = fused_permute_with_probs(
tokens, probs, routing_map, num_out_tokens=num_out_tokens
)
return fused_permute_with_probs(tokens, probs, routing_map, num_out_tokens=num_out_tokens)
return output, permuted_probs, row_id_map, None, tokens_per_expert

num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]
Expand Down Expand Up @@ -320,6 +339,7 @@ def unpermute(
routing_map: Optional[torch.Tensor] = None,
fused: bool = False,
drop_and_pad: bool = False,
pad_offsets: Optional[torch.Tensor] = None,
Comment thread
xiaoxi-wangfj marked this conversation as resolved.
):
"""
Restore the original order of tokens after permutation. If probs are provided, it
Expand All @@ -341,15 +361,26 @@ def unpermute(
fused (bool, optional): Whether use the fused unpermute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
pad_offsets : torch.Tensor, default = None
Tensor of per-expert cumulative padding offsets used to remove padding added
during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
and is required when unpermuting padded outputs.

Returns:
torch.Tensor: The tokens restored to their original order.
"""
if fused:
if not HAVE_TE or fused_unpermute is None:
raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
extra_kwargs = {}
if is_te_min_version("2.12.0"):
extra_kwargs["pad_offsets"] = pad_offsets
return fused_unpermute(
permuted_tokens, sorted_indices, merging_probs=probs, restore_shape=restore_shape
permuted_tokens,
sorted_indices,
merging_probs=probs,
restore_shape=restore_shape,
**extra_kwargs,
)

_, hidden = restore_shape
Expand Down
11 changes: 10 additions & 1 deletion megatron/core/transformer/moe/token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,12 +1295,20 @@ def get_permuted_hidden_states_by_experts(self, hidden_states: torch.Tensor) ->

self.hidden_shape_before_permute = hidden_states.shape
assert self.dispatched_probs.dtype == torch.float32, "DeepEP only supports float32 probs"
hidden_states, permuted_probs, self.reversed_mapping_for_combine = permute(
(
hidden_states,
permuted_probs,
self.reversed_mapping_for_combine,
self.pad_offsets,
self.tokens_per_expert,
) = permute(
hidden_states,
self.dispatched_routing_map,
probs=self.dispatched_probs,
num_out_tokens=self.tokens_per_expert.sum().item(),
fused=self.permute_fusion,
tokens_per_expert=self.tokens_per_expert,
align_size=get_align_size_for_quantization(self.config),
)
if self.router_dtype == "fp64":
permuted_probs = permuted_probs.to(torch.float64)
Expand All @@ -1313,6 +1321,7 @@ def get_restored_hidden_states_by_experts(self, hidden_states: torch.Tensor) ->
restore_shape=self.hidden_shape_before_permute,
routing_map=self.dispatched_routing_map,
fused=self.permute_fusion,
pad_offsets=self.pad_offsets,
)
return hidden_states

Expand Down
Loading