diff --git a/transformer_engine/plugin/core/backends/flagos/flagos.py b/transformer_engine/plugin/core/backends/flagos/flagos.py
index fd8a61f492..9af8b8596a 100644
--- a/transformer_engine/plugin/core/backends/flagos/flagos.py
+++ b/transformer_engine/plugin/core/backends/flagos/flagos.py
@@ -14,16 +14,24 @@
     rmsnorm_bwd_fl,
     multi_tensor_scale_fl,
     multi_tensor_adam_fl,
     multi_tensor_adam_param_remainder_fl,
     multi_tensor_l2_norm_fl,
     generic_gemm_fl,
+    gelu_fl,
+    geglu_fl,
+    qgelu_fl,
+    qgeglu_fl,
+    relu_fl,
+    reglu_fl,
+    moe_permute_fwd_fl,
+    moe_unpermute_bwd_fl,
+    moe_unpermute_fwd_fl,
+    moe_permute_bwd_fl,
 )
-
 def _check_flagos_available() -> bool:
     return True
-
 class FlagOSBackend(TEFLBackendBase):
     @staticmethod
     def check_available() -> bool:
@@ -35,7 +42,6 @@ def is_available(self) -> bool:
     def get_attention_backend(self, attention_params=None):
         from packaging.version import Version as PkgVersion
         from ...logger_manager import get_logger
-
         logger = get_logger()

         # Read environment variables to determine which backends to enable
@@ -65,7 +71,7 @@ def get_attention_backend(self, attention_params=None):
             available_backends,
         )

-    ##### transformer_engine/pytorch/csrc/extensions/pybind.cpp #####
+##### transformer_engine/pytorch/csrc/extensions/pybind.cpp #####
     def generic_gemm(
         self,
         A: Any,
@@ -92,28 +98,10 @@ def generic_gemm(
         beta: Optional[float] = None,
     ) -> List[Any]:
         return generic_gemm_fl(
-            A,
-            transA,
-            B,
-            transB,
-            D,
-            quantizer,
-            output_dtype,
-            bias,
-            bias_type,
-            gelu,
-            gelu_in,
-            grad,
-            workspace,
-            workspace_size,
-            accumulate,
-            use_split_accumulator,
-            comm_overlap,
-            comm_type,
-            extra_output,
-            bulk_overlap,
-            alpha,
-            beta,
+            A, transA, B, transB, D, quantizer, output_dtype,
+            bias, bias_type, gelu, gelu_in, grad, workspace, workspace_size,
+            accumulate, use_split_accumulator, comm_overlap, comm_type,
+            extra_output, bulk_overlap, alpha, beta
         )

     # Other granular functions
@@ -129,16 +117,10 @@ def rmsnorm_fwd(
         zero_centered_gamma: bool,
     ) -> List[Any]:
         return rmsnorm_fwd_fl(
-            input=input,
-            weight=weight,
-            eps=eps,
-            ln_out=ln_out,
-            quantizer=quantizer,
-            odtype=otype,
-            sm_margin=sm_margin,
-            zero_centered_gamma=zero_centered_gamma,
+            input=input, weight=weight, eps=eps, ln_out=ln_out,
+            quantizer=quantizer, odtype=otype,
+            sm_margin=sm_margin, zero_centered_gamma=zero_centered_gamma,
         )
-
     def rmsnorm_bwd(
         self,
         dz: torch.Tensor,
@@ -149,14 +131,9 @@ def rmsnorm_bwd(
         x: torch.Tensor,
         rsigma: torch.Tensor,
         gamma: torch.Tensor,
         sm_margin: int,
         zero_centered_gamma: bool,
     ) -> List[Any]:
         return rmsnorm_bwd_fl(
-            dy=dz,
-            x=x,
-            rsigma=rsigma,
-            gamma=gamma,
-            sm_margin=sm_margin,
-            zero_centered_gamma=zero_centered_gamma,
+            dy=dz, x=x, rsigma=rsigma, gamma=gamma,
+            sm_margin=sm_margin, zero_centered_gamma=zero_centered_gamma
         )
-
     def get_fused_attn_backend(self, *args, **kwargs) -> int:
         return NVTE_Fused_Attn_Backend.NVTE_No_Backend

@@ -169,7 +146,6 @@ def multi_tensor_scale(
         scale: float,
     ) -> None:
         return multi_tensor_scale_fl(chunk_size, noop_flag, tensor_lists, scale)
-
     def multi_tensor_l2norm(
         self,
         chunk_size: int,
@@ -178,7 +154,6 @@ def multi_tensor_l2norm(
         per_tensor: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         return multi_tensor_l2_norm_fl(chunk_size, noop_flag, tensor_lists, per_tensor)
-
     def multi_tensor_adam(
         self,
         chunk_size: int,
@@ -194,19 +169,9 @@ def multi_tensor_adam(
         weight_decay: float,
     ) -> None:
         return multi_tensor_adam_fl(
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            lr,
-            beta1,
-            beta2,
-            epsilon,
-            step,
-            mode,
-            bias_correction,
-            weight_decay,
+            chunk_size, noop_flag, tensor_lists, lr, beta1, beta2, epsilon,
+            step, mode, bias_correction, weight_decay,
         )
-
     def multi_tensor_adam_param_remainder(
         self,
         chunk_size: int,
@@ -222,31 +187,51 @@ def multi_tensor_adam_param_remainder(
         weight_decay: float,
     ) -> None:
         return multi_tensor_adam_param_remainder_fl(
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            lr,
-            beta1,
-            beta2,
-            epsilon,
-            step,
-            mode,
-            bias_correction,
-            weight_decay,
+            chunk_size, noop_flag, tensor_lists,
+            lr, beta1, beta2, epsilon,
+            step, mode, bias_correction, weight_decay,
         )

     # Misc
     def get_cublasLt_version(self) -> int:
         return 110000
-
     def get_cudnn_version(self) -> int:
         return 90000
-
     def get_num_cublas_streams(self) -> int:
         return 0

-    ############## class func #################################
+############## class func #################################
     def get_flash_attention_class(self):
         from .attention.dot_product_attention.backends import FlashAttentionFL
         return FlashAttentionFL
+
+    def gelu(self, input: torch.Tensor, quantizer: Any) -> Any:
+        return gelu_fl(input, quantizer)
+
+    def geglu(self, input: torch.Tensor, quantizer: Any) -> Any:
+        return geglu_fl(input, quantizer)
+
+    def qgelu(self, input: torch.Tensor, quantizer: Any) -> Any:
+        return qgelu_fl(input, quantizer)
+
+    def qgeglu(self, input: torch.Tensor, quantizer: Any) -> Any:
+        return qgeglu_fl(input, quantizer)
+
+    def relu(self, input: torch.Tensor, quantizer: Any) -> Any:
+        return relu_fl(input, quantizer)
+
+    def reglu(self, input: torch.Tensor, quantizer: Any) -> Any:
+        return reglu_fl(input, quantizer)
+
+    def moe_permute_fwd(self, *args, **kwargs) -> Any:
+        return moe_permute_fwd_fl(*args, **kwargs)
+
+    def moe_unpermute_bwd(self, *args, **kwargs) -> Any:
+        return moe_unpermute_bwd_fl(*args, **kwargs)
+
+    def moe_unpermute_fwd(self, *args, **kwargs) -> Any:
+        return moe_unpermute_fwd_fl(*args, **kwargs)
+
+    def moe_permute_bwd(self, *args, **kwargs) -> Any:
+        return moe_permute_bwd_fl(*args, **kwargs)
diff --git a/transformer_engine/plugin/core/backends/flagos/impl/__init__.py b/transformer_engine/plugin/core/backends/flagos/impl/__init__.py
index f17b38c9e6..e64e8f4913 100644
--- a/transformer_engine/plugin/core/backends/flagos/impl/__init__.py
+++ b/transformer_engine/plugin/core/backends/flagos/impl/__init__.py
@@ -6,3 +6,5 @@
 from .rmsnorm import *
 from .fused_adam import *
 from .multi_tensor import *
+from .activation import *
+from .moe_permute import *
diff --git a/transformer_engine/plugin/core/backends/flagos/impl/activation.py b/transformer_engine/plugin/core/backends/flagos/impl/activation.py
new file mode 100644
index 0000000000..69cbc1dbfb
--- /dev/null
+++ b/transformer_engine/plugin/core/backends/flagos/impl/activation.py
@@ -0,0 +1,30 @@
+import torch
+from typing import Any
+import flag_gems
+
+
+def gelu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
+    return flag_gems.gelu(input, approximate="tanh")
+
+
+def geglu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
+    a, b = input.chunk(2, dim=-1)
+    return flag_gems.gelu(a, approximate="tanh") * b
+
+
+def qgelu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
+    return input * flag_gems.sigmoid(1.702 * input)
+
+
+def qgeglu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
+    a, b = input.chunk(2, dim=-1)
+    return a * flag_gems.sigmoid(1.702 * a) * b
+
+
+def relu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
+    return flag_gems.relu(input)
+
+
+def reglu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
+    a, b = input.chunk(2, dim=-1)
+    return flag_gems.relu(a) * b
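The wrappers above ignore the `quantizer` argument and defer to flag_gems. A quick parity check against the PyTorch reference (a sketch, assuming `flag_gems` is installed and a CUDA device is available; `None` stands in for the unused quantizer):

```python
import torch
import torch.nn.functional as F
from transformer_engine.plugin.core.backends.flagos.impl.activation import gelu_fl, geglu_fl

x = torch.randn(8, 64, device="cuda")

# gelu_fl uses the tanh approximation, so compare against F.gelu(..., approximate="tanh")
torch.testing.assert_close(gelu_fl(x, None), F.gelu(x, approximate="tanh"), rtol=1e-3, atol=1e-3)

# geglu splits the last dimension in half and gates the GELU branch with the second half
a, b = x.chunk(2, dim=-1)
torch.testing.assert_close(geglu_fl(x, None), F.gelu(a, approximate="tanh") * b, rtol=1e-3, atol=1e-3)
```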
diff --git a/transformer_engine/plugin/core/backends/flagos/impl/moe_permute.py b/transformer_engine/plugin/core/backends/flagos/impl/moe_permute.py
new file mode 100644
index 0000000000..b49e3289b0
--- /dev/null
+++ b/transformer_engine/plugin/core/backends/flagos/impl/moe_permute.py
@@ -0,0 +1,328 @@
+import torch
+from typing import Tuple, List, Optional
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def moe_unpermute_kernel_triton(
+    input_ptr,  # [total_expert_tokens, num_cols]
+    output_ptr,  # [num_tokens, num_cols]
+    row_id_map_ptr,  # [topK, num_rows] or [topK * num_rows]
+    prob_ptr,  # [num_rows, topK]
+    num_rows: tl.constexpr,  # num_tokens
+    num_cols: tl.constexpr,
+    topK: tl.constexpr,
+    HAS_PROB: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # One program per destination token (blockIdx.x in the CUDA original).
+    source_token = tl.program_id(0)
+
+    # Traverse along the hidden dimension in BLOCK_SIZE chunks.
+    for col_offset in range(0, num_cols, BLOCK_SIZE):
+        cols = col_offset + tl.arange(0, BLOCK_SIZE)
+        col_mask = cols < num_cols
+
+        # Accumulate the (optionally prob-weighted) sum over the top-K expert copies.
+        acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+        for k in range(topK):
+            # row_id_map is laid out [topK, num_rows].
+            map_idx = k * num_rows + source_token
+            source_row = tl.load(row_id_map_ptr + map_idx)
+
+            # A map entry of -1 marks a dropped token; skip it.
+            if source_row != -1:
+                in_ptrs = input_ptr + source_row * num_cols + cols
+                val = tl.load(in_ptrs, mask=col_mask, other=0.0)
+
+                if HAS_PROB:
+                    # prob is [num_rows, topK].
+                    prob_val = tl.load(prob_ptr + source_token * topK + k)
+                    val = val * prob_val
+
+                acc += val
+
+        # Store the unpermuted output row.
+        out_ptrs = output_ptr + source_token * num_cols + cols
+        tl.store(out_ptrs, acc.to(output_ptr.dtype.element_ty), mask=col_mask)
+
+
+@triton.jit
+def _kernel_moe_permute(
+    input_bwd_ptr,
+    input_fwd_ptr,
+    act_grad_ptr,
+    prob_ptr,
+    prob_grad_ptr,
+    row_id_map_ptr,
+    num_rows,
+    num_cols,
+    topk,
+    BLOCK_SIZE: tl.constexpr,
+    HAS_PROB: tl.constexpr,
+    TOPK_P2: tl.constexpr,  # topK padded to the next power of 2
+):
+    pid = tl.program_id(0)
+
+    source_row_start_ptr = input_bwd_ptr + pid * num_cols
+
+    # Accumulator for prob_grad, shape [TOPK_P2], initialized with zeros.
+    prob_grad_acc = tl.zeros((TOPK_P2,), dtype=tl.float32)
+
+    for col_offset in range(0, num_cols, BLOCK_SIZE):
+        cols = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = cols < num_cols
+
+        # Load the source chunk once; it is reused across all K experts.
+        source_vec = tl.load(source_row_start_ptr + cols, mask=mask, other=0.0)
+
+        # Loop over experts.
+        for k in range(topk):
+            # row_id_map is laid out [topK, num_rows], so the stride is num_rows.
+            map_idx = k * num_rows + pid
+            dest_row = tl.load(row_id_map_ptr + map_idx)
+
+            if dest_row != -1:
+                dest_ptr_base = act_grad_ptr + dest_row * num_cols
+                dest_ptr = dest_ptr_base + cols
+
+                if HAS_PROB:
+                    # Load prob: [num_rows, topK] -> ptr + pid * topk + k.
+                    p = tl.load(prob_ptr + pid * topk + k)
+
+                    # 1. Compute act_grad.
+                    val = source_vec * p
+                    tl.store(dest_ptr, val, mask=mask)
+
+                    # 2. Compute prob_grad (accumulate the dot product).
+                    fwd_vec = tl.load(
+                        input_fwd_ptr + dest_row * num_cols + cols, mask=mask, other=0.0
+                    )
+                    partial_dot = tl.sum(source_vec * fwd_vec)
+
+                    # Update only the k-th element of the accumulator via a mask.
+                    k_mask = tl.arange(0, TOPK_P2) == k
+                    prob_grad_acc = tl.where(k_mask, prob_grad_acc + partial_dot, prob_grad_acc)
+                else:
+                    tl.store(dest_ptr, source_vec, mask=mask)
+
+    if HAS_PROB:
+        # Store prob_grad; prob_grad_ptr is [num_rows, topK], so mask off the
+        # power-of-2 padding.
+        out_ptr = prob_grad_ptr + pid * topk + tl.arange(0, TOPK_P2)
+        mask_store = tl.arange(0, TOPK_P2) < topk
+        tl.store(out_ptr, prob_grad_acc, mask=mask_store)
+
+
+@triton.jit
+def moe_permute_row_map_kernel(
+    sorted_row_id_ptr,  # const int *sorted_row_id
+    row_id_map_ptr,  # int *row_id_map
+    num_rows,  # const int num_rows
+    topk,  # const int topK
+    num_out_tokens,  # const int num_out_tokens
+    n_elements,  # num_rows * topk
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    source_row = tl.load(sorted_row_id_ptr + offsets, mask=mask, other=0)
+
+    source_token_id = source_row // topk
+    source_topK_id = source_row % topk
+
+    dest_offset = source_topK_id * num_rows + source_token_id
+    dest_ptr = row_id_map_ptr + dest_offset
+
+    # Rows beyond num_out_tokens are dropped and marked with -1.
+    value_to_write = tl.where(offsets < num_out_tokens, offsets, -1)
+    value_to_write = value_to_write.to(tl.int32)
+
+    tl.store(dest_ptr, value_to_write, mask=mask)
+
+
+def moe_permute_row_map(sorted_row_id, num_rows, topk, num_out_tokens):
+    """Build the [topK, num_rows] row-id map from the expert-sorted row ids."""
+    sorted_row_id = sorted_row_id.contiguous()
+
+    row_id_map = torch.empty(topk * num_rows, device=sorted_row_id.device, dtype=torch.int32)
+    n_elements = num_rows * topk
+
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+    moe_permute_row_map_kernel[grid](
+        sorted_row_id, row_id_map, num_rows, topk, num_out_tokens, n_elements, BLOCK_SIZE=1024
+    )
+    return row_id_map
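To make the map layout concrete, here is a pure-PyTorch mirror of the sort-and-scatter logic above on a toy input (an illustrative sketch only; the sizes and expert indices are made up):

```python
import torch

num_tokens, topk, num_out_tokens = 2, 2, 3
indices = torch.tensor([[1, 0], [0, 1]], dtype=torch.int32)  # expert id per (token, k)

source_row_ids = torch.arange(num_tokens * topk)
_, perm = torch.sort(indices.view(-1), stable=True)
sorted_row_id = source_row_ids[perm]  # token*topk+k ids, grouped by expert

# Same scatter as moe_permute_row_map_kernel: [topK, num_rows] layout,
# destination rows beyond num_out_tokens marked -1 (dropped).
row_id_map = torch.empty(topk * num_tokens, dtype=torch.int32)
token_id, k_id = sorted_row_id // topk, sorted_row_id % topk
vals = torch.where(source_row_ids < num_out_tokens, source_row_ids, torch.tensor(-1))
row_id_map[k_id * num_tokens + token_id] = vals.to(torch.int32)

print(row_id_map.view(topk, num_tokens))  # tensor([[ 2,  1], [ 0, -1]], dtype=torch.int32)
```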
+def moe_permute_fwd_fl(
+    inp: torch.Tensor,
+    dtype: torch.dtype,
+    indices: torch.Tensor,
+    num_out_tokens: int,
+    workspace: Optional[List[torch.Tensor]] = None,
+    max_token_num: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+    num_tokens, num_cols = inp.shape
+    topk = indices.shape[1]
+    device = inp.device
+
+    num_out_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topk
+
+    assert inp.is_cuda, "compute needs CUDA."
+
+    # Stable torch.sort stands in for nvte_device_radix_sort_pairs in the CUDA backend.
+    source_row_ids = torch.arange(num_tokens * topk, dtype=torch.int32, device=device)
+    keys_flat = indices.view(-1).to(torch.int32)
+    sorted_expert_indices, permutation_idx = torch.sort(keys_flat, stable=True)
+    sorted_row_id = source_row_ids[permutation_idx]
+
+    # Build the [topK, num_rows] destination map.
+    row_id_map = moe_permute_row_map(sorted_row_id, num_tokens, topk, num_out_tokens)
+
+    permuted_tokens = torch.empty((num_out_tokens, num_cols), dtype=inp.dtype, device=device)
+
+    grid = (num_tokens,)
+    _kernel_moe_permute[grid](
+        input_bwd_ptr=inp,
+        input_fwd_ptr=None,  # unused when HAS_PROB is False
+        act_grad_ptr=permuted_tokens,
+        prob_ptr=None,  # unused when HAS_PROB is False
+        prob_grad_ptr=None,  # unused when HAS_PROB is False
+        row_id_map_ptr=row_id_map,
+        num_rows=num_tokens,
+        num_cols=num_cols,
+        topk=topk,
+        BLOCK_SIZE=1024,
+        HAS_PROB=False,
+        TOPK_P2=triton.next_power_of_2(topk),
+    )
+
+    return permuted_tokens, row_id_map, (workspace if workspace is not None else [])
+
+
+def moe_unpermute_bwd_fl(
+    input_bwd: torch.Tensor,
+    input_fwd: torch.Tensor,
+    dtype: torch.dtype,
+    row_id_map: torch.Tensor,
+    prob: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Guard against prob being None as well as empty; prob.numel() alone would
+    # raise on the default argument.
+    has_prob = prob is not None and prob.numel() > 0
+    if has_prob:
+        topk = prob.size(1)
+        num_tokens = prob.size(0)
+    else:
+        topk = 1
+        num_tokens = row_id_map.size(0)
+    num_cols = input_bwd.size(1)
+    device = input_bwd.device
+
+    act_grad = torch.empty((input_fwd.size(0), num_cols), dtype=input_bwd.dtype, device=device)
+    prob_grad = torch.zeros((num_tokens, topk), dtype=torch.float32, device=device)
+
+    grid = (num_tokens,)
+    _kernel_moe_permute[grid](
+        input_bwd_ptr=input_bwd,
+        input_fwd_ptr=input_fwd,
+        act_grad_ptr=act_grad,
+        prob_ptr=prob if has_prob else None,
+        prob_grad_ptr=prob_grad if has_prob else None,
+        row_id_map_ptr=row_id_map,
+        num_rows=num_tokens,
+        num_cols=num_cols,
+        topk=topk,
+        BLOCK_SIZE=1024,
+        HAS_PROB=has_prob,
+        TOPK_P2=triton.next_power_of_2(topk),
+    )
+
+    return act_grad, prob_grad
+
+
+def moe_permute_bwd_fl(
+    input_bwd: torch.Tensor,
+    dtype: torch.dtype,
+    row_id_map: torch.Tensor,
+    prob: Optional[torch.Tensor] = None,
+    num_tokens: Optional[int] = None,
+    topK: Optional[int] = None,
+) -> torch.Tensor:
+    # The backward of permute is exactly the forward of unpermute.
+    return moe_unpermute_fwd_fl(input_bwd, dtype, row_id_map, prob, num_tokens, topK)
+
+
+def moe_unpermute_fwd_fl(
+    input_fwd: torch.Tensor,
+    dtype: torch.dtype,
+    row_id_map: torch.Tensor,
+    prob: Optional[torch.Tensor] = None,
+    num_tokens: Optional[int] = None,
+    topK: Optional[int] = None,
+) -> torch.Tensor:
+    num_cols = input_fwd.size(1)
+    device = input_fwd.device
+    unpermuted_output = torch.empty((num_tokens, num_cols), dtype=input_fwd.dtype, device=device)
+
+    BLOCK_SIZE = min(triton.next_power_of_2(num_cols), 1024)
+    grid = (num_tokens,)
+
+    has_prob = prob is not None and prob.numel() > 0
+    moe_unpermute_kernel_triton[grid](
+        input_fwd,
+        unpermuted_output,
+        row_id_map,
+        prob if has_prob else None,
+        num_rows=num_tokens,
+        num_cols=num_cols,
+        topK=topK,
+        HAS_PROB=has_prob,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return unpermuted_output
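As a smoke test of the two halves, a round trip with no dropped tokens and probabilities that sum to one per row should reproduce the input (a sketch, assuming a CUDA device with Triton available):

```python
import torch
from transformer_engine.plugin.core.backends.flagos.impl.moe_permute import (
    moe_permute_fwd_fl,
    moe_unpermute_fwd_fl,
)

num_tokens, num_cols, topk = 4, 64, 2
x = torch.randn(num_tokens, num_cols, device="cuda", dtype=torch.float16)
indices = torch.randint(0, 4, (num_tokens, topk), dtype=torch.int32, device="cuda")
prob = torch.full((num_tokens, topk), 1.0 / topk, device="cuda", dtype=torch.float32)

# num_out_tokens <= 0 keeps all num_tokens * topk expanded rows (nothing dropped)
permuted, row_id_map, _ = moe_permute_fwd_fl(x, x.dtype, indices, num_out_tokens=0)

# Each output row is sum_k prob[t, k] * x[t] = x[t] when the probs sum to 1
y = moe_unpermute_fwd_fl(permuted, x.dtype, row_id_map, prob, num_tokens, topk)
torch.testing.assert_close(y, x, rtol=1e-2, atol=1e-3)
```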
diff --git a/transformer_engine/plugin/core/backends/flagos/register_ops.py b/transformer_engine/plugin/core/backends/flagos/register_ops.py
index 0136b6a983..2f865f4c42 100644
--- a/transformer_engine/plugin/core/backends/flagos/register_ops.py
+++ b/transformer_engine/plugin/core/backends/flagos/register_ops.py
@@ -17,11 +17,9 @@
 def _bind_is_available(fn, is_available_fn):
     """Wrap a function and bind _is_available attribute for OpImpl.is_available() check."""
-
     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
         return fn(*args, **kwargs)
-
     wrapper._is_available = is_available_fn
     return wrapper

@@ -42,88 +40,23 @@ def register_builtins(registry) -> None:
     is_avail = backend.is_available

     impls = [
-        OpImpl(
-            op_name="rmsnorm_fwd",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.rmsnorm_fwd, is_avail),
-            vendor=None,
-            priority=150,
-        ),
-        OpImpl(
-            op_name="rmsnorm_bwd",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.rmsnorm_bwd, is_avail),
-            vendor=None,
-            priority=150,
-        ),
-        OpImpl(
-            op_name="generic_gemm",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.generic_gemm, is_avail),
-            vendor=None,
-            priority=150,
-        ),
-        OpImpl(
-            op_name="multi_tensor_scale",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.multi_tensor_scale, is_avail),
-            vendor=None,
-            priority=150,
-        ),
-        OpImpl(
-            op_name="multi_tensor_adam",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.multi_tensor_adam, is_avail),
-            vendor=None,
-            priority=150,
-        ),
-        OpImpl(
-            op_name="multi_tensor_adam_param_remainder",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail),
-            vendor=None,
-            priority=150,
-        ),
-        OpImpl(
-            op_name="multi_tensor_l2norm",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail),
-            vendor=None,
-            priority=150,
-        ),
+        OpImpl(op_name="rmsnorm_fwd", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.rmsnorm_fwd, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="rmsnorm_bwd", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.rmsnorm_bwd, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="generic_gemm", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.generic_gemm, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="multi_tensor_scale", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.multi_tensor_scale, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="multi_tensor_adam", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.multi_tensor_adam, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="multi_tensor_adam_param_remainder", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="multi_tensor_l2norm", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail), vendor=None, priority=150),
+        # MOE operations
+        OpImpl(op_name="moe_permute_fwd", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.moe_permute_fwd, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="moe_unpermute_bwd", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.moe_unpermute_bwd, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="moe_permute_bwd", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.moe_permute_bwd, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="moe_unpermute_fwd", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.moe_unpermute_fwd, is_avail), vendor=None, priority=150),
         # FlashAttention class getter
-        OpImpl(
-            op_name="get_flash_attention_class",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.get_flash_attention_class, is_avail),
-            vendor=None,
-            priority=150,
-        ),
+        OpImpl(op_name="get_flash_attention_class", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.get_flash_attention_class, is_avail), vendor=None, priority=150),
         # Attention backend selection
-        OpImpl(
-            op_name="get_attention_backend",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.get_attention_backend, is_avail),
-            vendor=None,
-            priority=150,
-        ),
-        OpImpl(
-            op_name="get_fused_attn_backend",
-            impl_id="default.flagos",
-            kind=BackendImplKind.DEFAULT,
-            fn=_bind_is_available(backend.get_fused_attn_backend, is_avail),
-            vendor=None,
-            priority=150,
-        ),
+        OpImpl(op_name="get_attention_backend", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.get_attention_backend, is_avail), vendor=None, priority=150),
+        OpImpl(op_name="get_fused_attn_backend", impl_id="default.flagos", kind=BackendImplKind.DEFAULT, fn=_bind_is_available(backend.get_fused_attn_backend, is_avail), vendor=None, priority=150),
     ]

     registry.register_many(impls)
diff --git a/transformer_engine/plugin/tests/test_moe_permute.py b/transformer_engine/plugin/tests/test_moe_permute.py
new file mode 100644
index 0000000000..446dca7da0
--- /dev/null
+++ b/transformer_engine/plugin/tests/test_moe_permute.py
@@ -0,0 +1,340 @@
+# Copyright (c) 2025, BAAI. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+from transformer_engine.plugin.test_utils import (
+    get_available_backends,
+    get_backend,
+    TestCase,
+    generate_random_tensor,
+)
+
+
+def _torch_dtype_to_te(dtype):
+    """Map a torch dtype to the matching transformer_engine DType enum value."""
+    import transformer_engine_torch_nv as te
+
+    mapping = {
+        torch.float16: te.DType.kFloat16,
+        torch.float32: te.DType.kFloat32,
+        torch.bfloat16: te.DType.kBFloat16,
+    }
+    if dtype not in mapping:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+    return mapping[dtype]
+
+
+class MoeTests(TestCase):
+    def __init__(self, device="cpu"):
+        super().__init__(
+            "Moe permute Operations",
+            "Test correctness of all moe permute operations across backends",
+        )
+        self.backends = get_available_backends()
+        self.device = device
+
+    def test_moe_permute_fw_basic(self, num_tokens=8, num_cols=256, topK=4, num_out_tokens=-1):
+        print(
+            f"\nTesting moe_permute_fw_basic (tokens={num_tokens}, cols={num_cols},"
+            f" topK={topK}, num_out_tokens={num_out_tokens})"
+        )
+        import transformer_engine_torch_nv as te
+
+        input_tensor = generate_random_tensor(
+            (num_tokens, num_cols), dtype=torch.float16, device=self.device
+        )
+        indices = torch.randint(0, 8, (num_tokens, topK), dtype=torch.int32, device=self.device)
+        dtype = _torch_dtype_to_te(input_tensor.dtype)
+
+        workspace = []
+        max_expanded_token_num = num_tokens * topK + 10
+
+        reference_permuted, reference_row_id_map, reference_workspace = te.moe_permute_fwd(
+            input_tensor, dtype, indices, num_out_tokens, workspace, max_expanded_token_num
+        )
+        for backend_name in self.backends:
+            backend = get_backend(backend_name)
+            print("backend:", backend)
+            try:
+                permuted_output, row_id_map, workspace_out = backend.moe_permute_fwd(
+                    input_tensor, dtype, indices, num_out_tokens, workspace, max_expanded_token_num
+                )
+                self.assert_close(
+                    permuted_output,
+                    reference_permuted,
+                    rtol=1e-2,
+                    atol=1e-3,
+                    msg=f"moe_permute_fwd mismatch for {backend_name}",
+                )
+                self.assert_close(
+                    row_id_map,
+                    reference_row_id_map,
+                    rtol=1e-2,
+                    atol=1e-3,
+                    msg=f"moe_permute_fwd row_id_map mismatch for {backend_name}",
+                )
+
+                print(f"  ✓ {backend_name}")
+            except NotImplementedError:
+                self.skipped += 1
+                print(f"  ⊘ {backend_name} (not implemented)")
+            except Exception as e:
+                self.failed += 1
+                print(f"  ✗ Test failed: {e}")
+
+    def test_moe_permute_fw_out_token(self, num_tokens=8, num_cols=256, topK=4, num_out_tokens=8):
+        print(
+            f"\nTesting moe_permute_fw_out_token (tokens={num_tokens}, cols={num_cols},"
+            f" topK={topK}, num_out_tokens={num_out_tokens})"
+        )
+        import transformer_engine_torch_nv as te
+
+        input_tensor = generate_random_tensor(
+            (num_tokens, num_cols), dtype=torch.float16, device=self.device
+        )
+        indices = torch.randint(0, 8, (num_tokens, topK), dtype=torch.int32, device=self.device)
+        dtype = _torch_dtype_to_te(input_tensor.dtype)
+
+        workspace = []
+        max_expanded_token_num = num_tokens * topK + 10
+
+        reference_permuted, reference_row_id_map, reference_workspace = te.moe_permute_fwd(
+            input_tensor, dtype, indices, num_out_tokens, workspace, max_expanded_token_num
+        )
+        for backend_name in self.backends:
+            backend = get_backend(backend_name)
+            print("backend:", backend)
+            try:
+                permuted_output, row_id_map, workspace_out = backend.moe_permute_fwd(
+                    input_tensor, dtype, indices, num_out_tokens, workspace, max_expanded_token_num
+                )
+                self.assert_close(
+                    permuted_output,
+                    reference_permuted,
+                    rtol=1e-2,
+                    atol=1e-3,
+                    msg=f"moe_permute_fwd_out_token mismatch for {backend_name}",
+                )
+                self.assert_close(
+                    row_id_map,
+                    reference_row_id_map,
+                    rtol=1e-2,
+                    atol=1e-3,
+                    msg=f"moe_permute_fwd_out_token row_id_map mismatch for {backend_name}",
+                )
+
+                print(f"  ✓ {backend_name}")
+            except NotImplementedError:
+                self.skipped += 1
+                print(f"  ⊘ {backend_name} (not implemented)")
+            except Exception as e:
+                self.failed += 1
+                print(f"  ✗ Test failed: {e}")
+
+    def test_moe_unpermute_bw_basic(self, num_tokens=8, num_cols=256, topK=4, num_out_tokens=8):
+        print(
+            f"\nTesting moe_unpermute_bw_basic (tokens={num_tokens}, cols={num_cols},"
+            f" topK={topK}, num_out_tokens={num_out_tokens})"
+        )
+        import transformer_engine_torch_nv as te
+
+        input_fwd = generate_random_tensor(
+            (num_tokens, num_cols), dtype=torch.float16, device=self.device
+        )
+        indices = torch.randint(0, 8, (num_tokens, topK), dtype=torch.int32, device=self.device)
+
+        prob = torch.rand((num_tokens, topK), dtype=torch.float32, device=self.device)
+        prob = prob / prob.sum(dim=1, keepdim=True)
+
+        dtype = _torch_dtype_to_te(input_fwd.dtype)
+
+        workspace = []
+        max_expanded_token_num = num_tokens * topK + 10
+
+        _, row_id_map, _ = te.moe_permute_fwd(
+            input_fwd, dtype, indices, num_out_tokens, workspace, max_expanded_token_num
+        )
+
+        input_bwd = generate_random_tensor(
+            (num_out_tokens, num_cols), dtype=input_fwd.dtype, device=self.device
+        )
+
+        reference_act_grad, reference_prob_grad = te.moe_unpermute_bwd(
+            input_bwd, input_fwd, dtype, row_id_map, prob
+        )
+        for backend_name in self.backends:
+            backend = get_backend(backend_name)
+            print("backend:", backend)
+            try:
+                act_grad, prob_grad = backend.moe_unpermute_bwd(
+                    input_bwd, input_fwd, dtype, row_id_map, prob
+                )
+                self.assert_close(
+                    act_grad,
+                    reference_act_grad,
+                    rtol=1e-2,
+                    atol=1e-3,
+                    msg=f"moe_unpermute_bw_basic act_grad mismatch for {backend_name}",
+                )
+                self.assert_close(
+                    prob_grad,
+                    reference_prob_grad,
+                    rtol=1e-2,
+                    atol=1e-3,
+                    msg=f"moe_unpermute_bw_basic prob_grad mismatch for {backend_name}",
+                )
+
+                print(f"  ✓ {backend_name}")
+            except NotImplementedError:
+                self.skipped += 1
+                print(f"  ⊘ {backend_name} (not implemented)")
+            except Exception as e:
+                self.failed += 1
+                print(f"  ✗ Test failed: {e}")
+    def test_moe_unpermute_fw_basic(self, num_tokens=4, num_cols=8, topK=1, num_out_tokens=2):
+        print(
+            f"\nTesting moe_unpermute_fw_basic (tokens={num_tokens}, cols={num_cols},"
+            f" topK={topK}, num_out_tokens={num_out_tokens})"
+        )
+        import transformer_engine_torch_nv as te
+
+        input_tensor = generate_random_tensor(
+            (num_tokens, num_cols), dtype=torch.float16, device=self.device
+        )
+        probs = torch.rand((num_tokens, topK), device=self.device, dtype=torch.float32)
+        probs = probs / probs.sum(dim=1, keepdim=True)
+
+        indices = torch.randint(0, 8, (num_tokens, topK), dtype=torch.int32, device=self.device)
+        dtype = _torch_dtype_to_te(input_tensor.dtype)
+
+        workspace = []
+        max_expanded_token_num = num_tokens * topK + 10
+
+        permuted, row_id_map, _ = te.moe_permute_fwd(
+            input_tensor, dtype, indices, num_out_tokens, workspace, max_expanded_token_num
+        )
+        reference_output = te.moe_unpermute_fwd(
+            permuted, dtype, row_id_map, probs, num_tokens, topK
+        )
+        for backend_name in self.backends:
+            backend = get_backend(backend_name)
+            try:
+                moe_output = backend.moe_unpermute_fwd(
+                    permuted, dtype, row_id_map, probs, num_tokens, topK
+                )
+                self.assert_close(
+                    moe_output,
+                    reference_output,
+                    rtol=1e-2,
+                    atol=1e-3,
+                    msg=f"moe_unpermute_fw_basic mismatch for {backend_name}",
+                )
+                print(f"  ✓ {backend_name}")
+            except NotImplementedError:
+                self.skipped += 1
+                print(f"  ⊘ {backend_name} (not implemented)")
+            except Exception as e:
+                self.failed += 1
+                print(f"  ✗ Test failed: {e}")
+
+    def test_moe_permute_bw_basic(self, num_tokens=8, num_cols=256, topK=4, num_out_tokens=8):
+        print(
+            f"\nTesting moe_permute_bw_basic (tokens={num_tokens}, cols={num_cols},"
+            f" topK={topK}, num_out_tokens={num_out_tokens})"
+        )
+        import transformer_engine_torch_nv as te
+
+        input_tensor = generate_random_tensor(
+            (num_tokens, num_cols), dtype=torch.float16, device=self.device
+        )
+        # The CUDA kernel requires float32 probs.
+        probs = generate_random_tensor((num_tokens, topK), dtype=torch.float32, device=self.device)
+        probs = probs / probs.sum(dim=1, keepdim=True)
+
+        indices = torch.randint(0, 8, (num_tokens, topK), dtype=torch.int32, device=self.device)
+        dtype = _torch_dtype_to_te(input_tensor.dtype)
+
+        workspace = []
+        max_expanded_token_num = num_tokens * topK + 10
+
+        permuted, row_id_map, _ = te.moe_permute_fwd(
+            input_tensor, dtype, indices, num_out_tokens, workspace, max_expanded_token_num
+        )
+        reference_output = te.moe_permute_bwd(permuted, dtype, row_id_map, probs, num_tokens, topK)
+
+        for backend_name in self.backends:
+            backend = get_backend(backend_name)
+            try:
+                moe_output = backend.moe_permute_bwd(
+                    permuted, dtype, row_id_map, probs, num_tokens, topK
+                )
+                self.assert_close(
+                    moe_output,
+                    reference_output,
+                    rtol=1e-2,
+                    atol=1e-3,
+                    msg=f"moe_permute_bw_basic mismatch for {backend_name}",
+                )
+                print(f"  ✓ {backend_name}")
+            except NotImplementedError:
+                self.skipped += 1
+                print(f"  ⊘ {backend_name} (not implemented)")
+            except Exception as e:
+                self.failed += 1
+                print(f"  ✗ Test failed: {e}")
+
+    def run_all_tests(self):
+        print("\n" + "=" * 60)
+        print("Moe permute operation tests")
+        print("=" * 60)
+        print(f"Available backends: {', '.join(self.backends)}")
+
+        self.test_moe_permute_fw_basic()
+        self.test_moe_permute_fw_out_token()
+        self.test_moe_permute_bw_basic()
+        self.test_moe_unpermute_fw_basic()
+        self.test_moe_unpermute_bw_basic()
+
+        return self.report()
+
+
+def main():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {device}")
+    test_suite = MoeTests(device=device)
+    success = test_suite.run_all_tests()
+    return 0 if success else 1
+
+
+if __name__ == "__main__":
+    exit(main())
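The suite is executable directly via the main() guard above; the device is chosen automatically through torch.cuda.is_available(), though note the Triton permute path asserts a CUDA input:

```
python transformer_engine/plugin/tests/test_moe_permute.py
```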