Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,6 +2417,14 @@ def fused_apply_rotary_pos_emb_thd(
fused_sort_chunks_by_index_with_probs = None
fused_unpermute = None

# Optional fused MoE permute-and-pad kernel: bind the alias when the installed
# Transformer Engine exposes it, otherwise fall back to None.
try:
    from transformer_engine.pytorch.permutation import moe_permute_and_pad_with_probs
except ImportError:
    fused_permute_and_pad_with_probs = None
else:
    fused_permute_and_pad_with_probs = moe_permute_and_pad_with_probs

try:
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

Expand Down
66 changes: 56 additions & 10 deletions megatron/core/transformer/moe/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.router_replay import RouterReplay
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import internal_api
from megatron.core.utils import internal_api, is_te_min_version

try:
import transformer_engine as te # pylint: disable=unused-import
Expand All @@ -26,6 +26,7 @@
fused_compute_score_for_moe_aux_loss,
fused_moe_aux_loss,
fused_permute,
fused_permute_and_pad_with_probs,
fused_permute_with_probs,
fused_sort_chunks_by_index,
fused_sort_chunks_by_index_with_probs,
Expand Down Expand Up @@ -295,7 +296,15 @@ def permute(
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
tokens_per_expert: Optional[torch.Tensor] = None,
align_size: int = -1,
) -> Tuple[
torch.Tensor,
Optional[torch.Tensor],
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
The shape of mask is [tokens, num_experts], it indicates which experts were selected
Expand All @@ -304,6 +313,9 @@ def permute(
When drop_and_pad=True, the number of non-zeros in each column of routing_map equals the
expert capacity. This function exploits that property to use ops that support CUDA graphs.

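If fused=True, probs is provided, and the fused permute-and-pad kernel is available, the
tokens are padded to align_size and the function returns the padded permuted tokens, the
permuted probs, the sorted indices, pad_offsets, and the padded tokens per expert. On every
other path, pad_offsets is None and tokens_per_expert is returned unchanged.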

Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
Expand All @@ -315,25 +327,47 @@ def permute(
and pads the number of tokens to the expert capacity.
If set to true, routing_map has a fixed number of non-zeros
in each column.
tokens_per_expert (torch.Tensor, optional): Tensor of shape `[num_experts]` containing
actual token counts per expert.
align_size (int, optional): The alignment size for the input tensor for fp8 or fp4.
Forwarded to the fused permute-and-pad kernel when that kernel is used. Defaults to -1.

Returns:
Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
The permuted tokens, permuted probs, and sorted indices.
Tuple[
torch.Tensor,
Optional[torch.Tensor],
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
The permuted tokens, (optional) permuted probs, sorted indices,
(optional) pad_offsets, (optional) padded_tokens_per_expert.
"""
if fused and probs is None:
if not HAVE_TE or fused_permute is None:
raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
permuted_input, sorted_indices = fused_permute(
tokens, routing_map, num_out_tokens=num_out_tokens
)
return permuted_input, None, sorted_indices
return permuted_input, None, sorted_indices, None, tokens_per_expert

if fused and probs is not None:
if not HAVE_TE or fused_permute_with_probs is None:
if not HAVE_TE or (
fused_permute_and_pad_with_probs is None and fused_permute_with_probs is None
):
raise ValueError(
"fused_permute_with_probs is not available. Please install TE >= 2.1.0."
"Transformer Engine (TE) fused kernel is not available. "
"fused_permute_with_probs typically requires TE >= 2.1.0, and "
"fused_permute_and_pad_with_probs` typically requires TE >= 2.12.0. "
)
if fused_permute_and_pad_with_probs is not None:
return fused_permute_and_pad_with_probs(
tokens, probs, routing_map, tokens_per_expert, align_size
)
else:
output, permuted_probs, row_id_map = fused_permute_with_probs(
tokens, probs, routing_map, num_out_tokens=num_out_tokens
)
return fused_permute_with_probs(tokens, probs, routing_map, num_out_tokens=num_out_tokens)
return output, permuted_probs, row_id_map, None, tokens_per_expert

num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]
Expand Down Expand Up @@ -376,7 +410,7 @@ def permute(
# use the mapping to permute the tokens
permuted_input = tokens.index_select(0, sorted_indices)

return permuted_input, permuted_probs, sorted_indices
return permuted_input, permuted_probs, sorted_indices, None, tokens_per_expert


def unpermute(
Expand All @@ -387,6 +421,7 @@ def unpermute(
routing_map: Optional[torch.Tensor] = None,
fused: bool = False,
drop_and_pad: bool = False,
pad_offsets: Optional[torch.Tensor] = None,
Comment thread
xiaoxi-wangfj marked this conversation as resolved.
) -> torch.Tensor:
"""
Restore the original order of tokens after permutation. If probs are provided, it
Expand All @@ -408,15 +443,26 @@ def unpermute(
fused (bool, optional): Whether use the fused unpermute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
pad_offsets (torch.Tensor, optional):
Tensor of per-expert cumulative padding offsets used to remove padding added
during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
and is required when unpermuting padded outputs. Defaults to None.

Returns:
torch.Tensor: The tokens restored to their original order.
"""
if fused:
if not HAVE_TE or fused_unpermute is None:
raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
extra_kwargs = {}
if is_te_min_version("2.12.0"):
extra_kwargs["pad_offsets"] = pad_offsets
return fused_unpermute(
permuted_tokens, sorted_indices, merging_probs=probs, restore_shape=restore_shape
permuted_tokens,
sorted_indices,
merging_probs=probs,
restore_shape=restore_shape,
**extra_kwargs,
)

_, hidden = restore_shape
Expand Down
25 changes: 19 additions & 6 deletions megatron/core/transformer/moe/token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,13 @@ def dispatch_postprocess(self, hidden_states, probs):

tokens_per_expert = self.local_map.sum(dim=0).long().cpu()

(permuted_local_hidden_states, _, self.reversed_local_input_permutation_mapping) = permute(
hidden_states,
self.local_map,
num_out_tokens=tokens_per_expert.sum().item(),
fused=self.config.moe_permute_fusion,
permuted_local_hidden_states, _, self.reversed_local_input_permutation_mapping, _, _ = (
permute(
Comment thread
xiaoxi-wangfj marked this conversation as resolved.
hidden_states,
self.local_map,
num_out_tokens=tokens_per_expert.sum().item(),
fused=self.config.moe_permute_fusion,
)
)

self.local_probs = self.local_probs.T.contiguous().masked_select(
Expand Down Expand Up @@ -634,6 +636,8 @@ def dispatch_preprocess(
permutated_local_input_tokens,
permuted_probs,
self.reversed_local_input_permutation_mapping,
_,
_,
) = permute(
hidden_states,
self.routing_map,
Expand Down Expand Up @@ -1295,12 +1299,20 @@ def get_permuted_hidden_states_by_experts(self, hidden_states: torch.Tensor) ->

self.hidden_shape_before_permute = hidden_states.shape
assert self.dispatched_probs.dtype == torch.float32, "DeepEP only supports float32 probs"
hidden_states, permuted_probs, self.reversed_mapping_for_combine = permute(
(
hidden_states,
permuted_probs,
self.reversed_mapping_for_combine,
self.pad_offsets,
self.tokens_per_expert,
) = permute(
hidden_states,
self.dispatched_routing_map,
probs=self.dispatched_probs,
num_out_tokens=self.tokens_per_expert.sum().item(),
fused=self.permute_fusion,
tokens_per_expert=self.tokens_per_expert,
align_size=get_align_size_for_quantization(self.config),
)
if self.router_dtype == "fp64":
permuted_probs = permuted_probs.to(torch.float64)
Expand All @@ -1313,6 +1325,7 @@ def get_restored_hidden_states_by_experts(self, hidden_states: torch.Tensor) ->
restore_shape=self.hidden_shape_before_permute,
routing_map=self.dispatched_routing_map,
fused=self.permute_fusion,
pad_offsets=self.pad_offsets,
)
return hidden_states

Expand Down
Loading