diff --git a/internlm/model/moe/gshard_layer.py b/internlm/model/moe/gshard_layer.py
index 3aba8d1a3..62b8b0ede 100644
--- a/internlm/model/moe/gshard_layer.py
+++ b/internlm/model/moe/gshard_layer.py
@@ -4,7 +4,7 @@
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-
+from collections import namedtuple
 from typing import Callable, Dict, Optional, Tuple
 
 import torch
@@ -25,10 +25,26 @@
 # global llm logger
 logger = get_logger(__file__)
 
+try:
+    # To enable Tutel MoE optimizations:
+    #   python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.3.x
+    from tutel import moe as tutel_moe
+
+    TUTEL_INSTALLED = True
+except (ModuleNotFoundError, ImportError):
+    # Tutel is an optional dependency: warn once at import time instead of
+    # raising, so the default cumsum path still works without it
+    TUTEL_INSTALLED = False
+    logger.warning("Tutel is not installed; MoE gating falls back to torch.cumsum")
+
 uniform_map: Dict[torch.device, Callable] = {}
 gumbel_map: Dict[torch.device, Callable] = {}
 exp_selection_uniform_map: Dict[torch.device, Callable] = {}
 
+GatingTokenRearrangeInfo = namedtuple(
+    "GatingTokenRearrangeInfo", ["token_rearranged_ec_idx", "token_exp_weights", "expert_select_token_idx"]
+)
+
 
 def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
     """
@@ -223,7 +239,7 @@ def top1gating(
 
     dispatch_mask = combine_weights.bool()
 
-    return l_aux, combine_weights, dispatch_mask, exp_counts
+    return l_aux, combine_weights, dispatch_mask
 
 
 def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -253,7 +269,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     locations2 += torch.sum(mask1, dim=0, keepdim=True)
 
     # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")
+    # exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")
 
     # Compute l_aux
     me = torch.mean(gates, dim=0)
@@ -289,7 +305,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()
 
-    return l_aux, combine_weights, dispatch_mask, exp_counts
+    return l_aux, combine_weights, dispatch_mask
 
 
 def fused_topkgating(
@@ -297,6 +313,8 @@ def fused_topkgating(
     k: int,
     capacity_factor: float,
     min_capacity: int,
+    enable_token_rearrange_opt: bool = True,
+    use_tutel: bool = True,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements TopKGating on logits."""
     # everything is in fp32 in this function
@@ -306,19 +324,21 @@ def fused_topkgating(
     capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity))
 
     # Create a mask by top-k experts
-    indices_s = torch.topk(gates, k, dim=1).indices
-    indices_s = indices_s.permute(1, 0).reshape(-1)
-    masks = F.one_hot(indices_s, num_classes=num_experts)
+    indices_s = torch.topk(gates, k, dim=1).indices.t()
+    masks = F.one_hot(indices_s.reshape(-1), num_classes=num_experts)
 
     # Compute locations in capacity buffer
-    locations = torch.cumsum(masks, dim=0) - 1
+    if use_tutel and TUTEL_INSTALLED:
+        locations = tutel_moe.fast_cumsum_sub_one(masks)
+    else:
+        locations = torch.cumsum(masks, dim=0) - 1
 
     # reshape (s,e) to (k,s,e)
     masks = masks.reshape(-1, gates.shape[0], num_experts)
     locations = locations.reshape(-1, gates.shape[0], num_experts)
 
     # gating decisions
-    exp_counts = torch.sum(masks[0], dim=0).detach().to("cpu")
+    # exp_counts = torch.sum(masks[0], dim=0).detach()
 
     # Compute l_aux
     me = torch.mean(gates, dim=0)
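
For reviewers, a minimal standalone sketch of the masking and capacity-slot logic this hunk changes (toy shapes and a fixed seed, stock PyTorch only, not part of the patch). It shows why tutel's `fast_cumsum_sub_one(masks)` is a drop-in for `torch.cumsum(masks, dim=0) - 1`: each token's slot is the running count of earlier tokens routed to the same expert.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts, k = 6, 4, 2

gates = F.softmax(torch.randn(num_tokens, num_experts), dim=1)

# Top-k expert ids per token, flattened k-major as in the patch: (s, k) -> (k, s) -> (k*s,)
indices_s = torch.topk(gates, k, dim=1).indices.t()
masks = F.one_hot(indices_s.reshape(-1), num_classes=num_experts)

# Slot of each (token, expert) pair inside that expert's capacity buffer:
# the running count of earlier tokens routed to the same expert.
# tutel_moe.fast_cumsum_sub_one(masks) computes exactly this, just faster.
locations = torch.cumsum(masks, dim=0) - 1

print(locations[masks.bool()])  # per expert column: 0, 1, 2, ...
```
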
@@ -333,20 +353,39 @@ def fused_topkgating(
 
     # Normalize gate probabilities
     mask_float = masks.type_as(logits)
-    gate_s = einsum("se,kse->ks", gates, mask_float)
+    # gate_s = einsum("se,kse->ks", gates, mask_float)
+    gate_s, indices_s = torch.max(gates * mask_float, dim=2)
     denom_s = torch.sum(gate_s, dim=0)
     # Avoid divide-by-zero
     denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
     gate_s /= denom_s
 
-    # Calculate combine_weights and dispatch_mask
-    gate_all = einsum("ks,kse->kse", gate_s, mask_float)
-    locations_sc = F.one_hot(locations_s, num_classes=capacity).type_as(logits)
-    combine_sec = einsum("kse,ksc->ksec", gate_all, locations_sc)
-    combine_weights = torch.sum(combine_sec, dim=0)
-    dispatch_mask = combine_weights.bool()
+    if enable_token_rearrange_opt:
+        token_rearranged_ec_idx = indices_s.int() * capacity + locations_s.int()
+        # shape: [S, E] -> [C, E] -> [E, C] -> [E*C]
+        token_sel_exp_int_mask = masks * torch.arange(k, 0, -1, device=masks.device).reshape(k, 1, 1)
+        expert_sel_top_c_token_idx = torch.topk(
+            torch.sum(token_sel_exp_int_mask, dim=0), k=capacity, dim=0, sorted=True
+        )[1]
+        expert_select_token_idx = expert_sel_top_c_token_idx.t().reshape(num_experts * capacity)
+        token_rearranged_ec_idx = token_rearranged_ec_idx.reshape(-1)
+        token_exp_weights = gate_s.reshape(-1)
+
+        topk_gating_token_infos = GatingTokenRearrangeInfo(
+            token_rearranged_ec_idx=token_rearranged_ec_idx,
+            token_exp_weights=token_exp_weights,
+            expert_select_token_idx=expert_select_token_idx,
+        )
+        return l_aux, topk_gating_token_infos
+    else:
+        # Calculate combine_weights and dispatch_mask
+        gate_all = einsum("ks,kse->kse", gate_s, mask_float)
+        locations_sc = F.one_hot(locations_s, num_classes=capacity).type_as(logits)
+        combine_sec = einsum("kse,ksc->ksec", gate_all, locations_sc)
+        combine_weights = torch.sum(combine_sec, dim=0)
+        dispatch_mask = combine_weights.bool()
 
-    return l_aux, combine_weights, dispatch_mask, exp_counts
+        return l_aux, combine_weights, dispatch_mask
 
 
 class TopKGate(Module):
@@ -378,10 +417,13 @@ def __init__(
         noisy_gate_policy: Optional[str] = None,
         drop_tokens: bool = True,
         use_rts: bool = True,
-        use_fused_gating: bool = False,
+        use_fused_gating: bool = True,
+        enable_token_rearrange_opt: bool = True,
+        use_tutel: bool = True,
     ) -> None:
         super().__init__()
-        # alway use fp32
+
+        # DeepSpeed's mechanism: always use fp32
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
         self.k = topk
         self.capacity_factor = capacity_factor
@@ -393,6 +435,8 @@ def __init__(
         self.drop_tokens = drop_tokens
         self.use_rts = use_rts
         self.use_fused_gating = use_fused_gating
+        self.enable_token_rearrange_opt = enable_token_rearrange_opt
+        self.use_tutel = use_tutel
 
     def forward(
         self, inputs: torch.Tensor, used_token: torch.Tensor = None
@@ -408,7 +452,12 @@ def forward(
         if self.use_fused_gating or self.k > 2:
             assert self.noisy_gate_policy != "RSample", "RSample noisy is not supported by fused_gating policy"
             gate_output = fused_topkgating(
-                logits, self.k, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
+                logits,
+                self.k,
+                self.capacity_factor if self.training else self.eval_capacity_factor,
+                self.min_capacity,
+                self.enable_token_rearrange_opt,
+                self.use_tutel,
             )
         # deepspeed-style code
         elif self.k == 1:
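
The two index tensors carried in `GatingTokenRearrangeInfo` can be hard to picture from the diff alone. A toy reconstruction follows; it assumes no capacity overflow, rebuilds `locations_s` from the masks (that computation is outside the visible hunks), and uses unnormalized gate weights, whereas the patch normalizes them over k.

```python
from collections import namedtuple

import torch
import torch.nn.functional as F

GatingTokenRearrangeInfo = namedtuple(
    "GatingTokenRearrangeInfo", ["token_rearranged_ec_idx", "token_exp_weights", "expert_select_token_idx"]
)

torch.manual_seed(0)
s, e, k, capacity = 4, 2, 2, 4  # capacity large enough that no token overflows
gates = F.softmax(torch.randn(s, e), dim=1)

indices_s = torch.topk(gates, k, dim=1).indices.t()        # (k, s)
masks = F.one_hot(indices_s.reshape(-1), num_classes=e)    # (k*s, e)
locations = (torch.cumsum(masks, dim=0) - 1).reshape(k, s, e)
masks = masks.reshape(k, s, e)
locations_s = torch.sum(locations * masks, dim=2)          # (k, s): capacity slot per choice

# Flat (E*C) address of each token's expert output.
token_rearranged_ec_idx = (indices_s * capacity + locations_s).reshape(-1)

# Tokens each expert gathers, ranked so first choices beat second choices.
priority = masks * torch.arange(k, 0, -1).reshape(k, 1, 1)
expert_select_token_idx = torch.topk(priority.sum(dim=0), k=capacity, dim=0, sorted=True)[1].t().reshape(e * capacity)

# Per-choice gate weights (the patch additionally normalizes these over k).
token_exp_weights = gates.gather(1, indices_s.t()).t().reshape(-1)

token_infos = GatingTokenRearrangeInfo(token_rearranged_ec_idx, token_exp_weights, expert_select_token_idx)
```
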
@@ -437,11 +486,11 @@ def forward(
 
 
 class GShardMoELayer(BaseMoELayer):
-    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
+    """MoELayer module which implements MixtureOfExperts as described in Gshard_.
 
     ::
 
         gate = TopKGate(model_dim, num_experts)
-        moe = MOELayer(gate, expert)
+        moe = MoELayer(gate, expert)
         output = moe(inputs)
         l_aux = moe.l_aux
@@ -475,6 +524,8 @@ def __init__(
         drop_tokens: bool = True,
         use_rts: bool = True,
         use_fused_gating: bool = True,
+        enable_token_rearrange_opt: bool = True,
+        use_tutel: bool = True,
         use_grouped_mlp: bool = True,
     ) -> None:
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
@@ -483,6 +534,12 @@ def __init__(
         assert (
             num_experts % ep_size == 0
         ), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
+
+        if enable_token_rearrange_opt:
+            assert (
+                use_fused_gating or top_k > 2
+            ), "enable_token_rearrange_opt can only be used when use_fused_gating is enabled or top_k > 2"
+
         if use_grouped_mlp:
             experts = new_feed_forward(
                 in_features,
@@ -529,6 +586,8 @@ def __init__(
                 drop_tokens,
                 use_rts,
                 use_fused_gating,
+                enable_token_rearrange_opt,
+                use_tutel,
             ),
             experts,
             ep_group,
@@ -542,6 +601,9 @@ def __init__(
         self.time_salltoall = 0.0
         self.time_moe = 0.0
         self.wall_clock_breakdown = False
+        self.enable_token_rearrange_opt = enable_token_rearrange_opt
+        self.num_experts = num_experts
+        self.topk = top_k
 
     def forward(self, *inputs: Tensor) -> Tensor:
         if self.wall_clock_breakdown:
@@ -555,11 +617,24 @@ def forward(self, *inputs: Tensor) -> Tensor:
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
         reshaped_inputs = inputs[0].reshape(-1, d_model)
 
-        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
-        dispatched_inputs = einsum(
-            "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
-        )  # TODO: heavy memory usage due to long sequence length
-
+        if not self.enable_token_rearrange_opt:
+            self.l_aux, combine_weights, dispatch_mask = self.gate(reshaped_inputs, inputs[1])
+            dispatched_inputs = einsum(
+                "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
+            )  # TODO: heavy memory usage due to long sequence length
+        else:
+            self.l_aux, token_rearrange_infos = self.gate(reshaped_inputs)
+            org_dtype = reshaped_inputs.dtype
+            if org_dtype == torch.bfloat16:  # avoid precision loss
+                rearranged_input = torch.index_select(
+                    reshaped_inputs.to(torch.float32), dim=0, index=token_rearrange_infos.expert_select_token_idx
+                ).to(org_dtype)
+            else:
+                rearranged_input = torch.index_select(
+                    reshaped_inputs, dim=0, index=token_rearrange_infos.expert_select_token_idx
+                )
+            capacity = token_rearrange_infos.expert_select_token_idx.size(0) // self.num_experts
+            dispatched_inputs = rearranged_input.reshape(self.num_experts, capacity, d_model).contiguous()
 
         if self.wall_clock_breakdown:
             timer("falltoall").start()
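
A sketch of the rearrange-based dispatch that replaces the `"sec,sm->ecm"` einsum: a single `index_select` builds the `(E, C, M)` expert batches without materializing the dense `(S, E, C)` dispatch mask. The permutation index here is a stand-in for what the gate returns.

```python
import torch

S, M, E, C = 8, 16, 2, 4  # toy sizes with E * C == S, i.e. every token fits
reshaped_inputs = torch.randn(S, M)

# In the layer this index comes from the gate; a stand-in permutation here.
expert_select_token_idx = torch.randperm(S)  # (E*C,)

# One gather replaces the dense einsum dispatch.
rearranged_input = torch.index_select(reshaped_inputs, dim=0, index=expert_select_token_idx)
dispatched_inputs = rearranged_input.reshape(E, C, M).contiguous()
assert dispatched_inputs.shape == (E, C, M)
```
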
@@ -600,7 +675,25 @@ def forward(self, *inputs: Tensor) -> Tensor:
         # Re-shape back: gecm -> ecm
         expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
 
-        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)
+        if not self.enable_token_rearrange_opt:
+            combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)
+        else:
+            E, C, M = expert_output.shape
+            org_dtype = expert_output.dtype
+            if org_dtype == torch.bfloat16:
+                valid_expert_out = torch.index_select(
+                    expert_output.view(E * C, M).to(torch.float32),
+                    dim=0,
+                    index=token_rearrange_infos.token_rearranged_ec_idx,
+                ).to(org_dtype)
+            else:
+                valid_expert_out = torch.index_select(
+                    expert_output.view(E * C, M), dim=0, index=token_rearrange_infos.token_rearranged_ec_idx
+                )
+            combined_output = valid_expert_out * token_rearrange_infos.token_exp_weights.unsqueeze(1).type_as(inputs[0])
+            if self.topk > 1:
+                combined_output = combined_output.reshape(self.topk, -1, M)
+                combined_output = torch.sum(combined_output, dim=0)
 
         out = combined_output.reshape(inputs[0].shape)
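
And the matching combine step, again with stand-in gate outputs: gather each token's k expert rows from the flat `(E*C, M)` buffer, weight them, and sum over the k choices to recover one output row per token.

```python
import torch

E, C, M, S, k = 2, 4, 16, 4, 2
expert_output = torch.randn(E, C, M)

# Stand-ins for the gate outputs carried through the layer.
token_rearranged_ec_idx = torch.randint(0, E * C, (k * S,))
token_exp_weights = torch.rand(k * S)

# Gather each token's k expert outputs from the flat (E*C, M) buffer,
# weight them, then sum over the k choices.
valid_expert_out = torch.index_select(expert_output.view(E * C, M), dim=0, index=token_rearranged_ec_idx)
combined_output = (valid_expert_out * token_exp_weights.unsqueeze(1)).reshape(k, S, M).sum(dim=0)
assert combined_output.shape == (S, M)
```
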