diff --git a/megatron/core/fp8_utils.py b/megatron/core/fp8_utils.py index fa6be91dfbf..50b9d1e35c8 100644 --- a/megatron/core/fp8_utils.py +++ b/megatron/core/fp8_utils.py @@ -783,3 +783,78 @@ def prepare_model_for_fp8_inference(model): "prepare_model_for_fp8_inference requires Transformer Engine to be installed. " "Please install transformer-engine to use FP8 inference." ) + + +if HAVE_TE: + from functools import lru_cache + + import transformer_engine_torch as tex + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, + ) + + @lru_cache(maxsize=None) + def _get_fp8_quantizer(recipe, all_gather_usage=False): + if recipe == Fp8Recipe.blockwise: + return Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=False, + amax_epsilon=1e-10, + force_pow_2_scales=True, + block_scaling_dim=1, + all_gather_usage=all_gather_usage, + ) + + return None + + def fp8_quantize(recipe: Fp8Recipe, x: torch.Tensor, all_gather_usage=False): + q = _get_fp8_quantizer(recipe, all_gather_usage) + if q is None: + return x + + quantized_tensor = q(x) + if recipe == Fp8Recipe.blockwise: + if quantized_tensor._data_format == tex.Float8BlockScaleTensorFormat.COMPACT: + return (quantized_tensor._rowwise_data, quantized_tensor._rowwise_scale_inv) + else: + return ( + quantized_tensor._rowwise_data, + quantized_tensor._rowwise_scale_inv.T.contiguous(), + ) + + return x + + def make_fp8_tensor(recipe: Fp8Recipe, x: torch.Tensor, x_scale: torch.Tensor): + q = _get_fp8_quantizer(recipe) + if q is None: + return None + + if recipe == Fp8Recipe.blockwise: + # To accelerate fp8flow and reduce redundant cases of T.contiguous() + # scale_inv use COMPACT + return Float8BlockwiseQTensor( + shape=x.shape, + dtype=torch.bfloat16, + rowwise_data=x.view(torch.uint8), + rowwise_scale_inv=x_scale, + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=tex.DType.kFloat8E4M3, + quantizer=q, + 
is_2D_scaled=False, + requires_grad=x.requires_grad, + data_format=tex.Float8BlockScaleTensorFormat.GEMM_READY, + ) + return None + +else: + + def fp8_quantize(recipe: Fp8Recipe, x: torch.Tensor, all_gather_usage=False): + """Transformer Engine not available: passthrough tensor.""" + return x + + def make_fp8_tensor(recipe: Fp8Recipe, x: torch.Tensor, x_scale: torch.Tensor): + """Transformer Engine not available: no FP8 tensor wrapper.""" + return None diff --git a/megatron/core/fusions/fused_bias_swiglu.py b/megatron/core/fusions/fused_bias_swiglu.py index 632470876c9..73cf9ed001c 100644 --- a/megatron/core/fusions/fused_bias_swiglu.py +++ b/megatron/core/fusions/fused_bias_swiglu.py @@ -6,6 +6,10 @@ import torch import torch.nn.functional as F +from megatron.core.fusions.fused_weighted_swiglu_quant import ( + fused_weighted_swiglu_quant, + fused_weighted_swiglu_quant_back, +) from megatron.core.jit import jit_fuser from megatron.core.utils import nvtx_decorator @@ -191,19 +195,38 @@ def backward(ctx, grad_output): class WeightedSwiGLUFunction(torch.autograd.Function): @staticmethod # bias is an optional argument - def forward(ctx, input, weights, fp8_input_store): + def forward(ctx, input, weights, fp8_input_store, config=None): input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input + if config is not None and config.moe_fp8_flow: # None config => unquantized path + out_data, out_scale = fused_weighted_swiglu_quant(input, weights) + + from megatron.core.fp8_utils import make_fp8_tensor + + weighted_swiglu_out = make_fp8_tensor(config.fp8_recipe, out_data, out_scale) + else: + weighted_swiglu_out = weighted_swiglu(input, weights) ctx.save_for_backward(input_for_backward, weights) ctx.ori_input_dtype = input.dtype ctx.fp8_input_store = fp8_input_store - return weighted_swiglu(input, weights) + ctx.fp8_recipe = config.fp8_recipe if config is not None else None + ctx.moe_fp8_flow = config.moe_fp8_flow if config is not None else False + return weighted_swiglu_out @staticmethod def backward(ctx, grad_output): input, weights = ctx.saved_tensors input = 
input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input - tmp, wgrad = weighted_swiglu_back(grad_output, input, weights) - return tmp, wgrad, None + if ctx.moe_fp8_flow: + input_grad_data, input_grad_scale, wgrad = fused_weighted_swiglu_quant_back( + grad_output, input, weights + ) + + from megatron.core.fp8_utils import make_fp8_tensor + + tmp = make_fp8_tensor(ctx.fp8_recipe, input_grad_data, input_grad_scale) + else: + tmp, wgrad = weighted_swiglu_back(grad_output, input, weights) + return tmp, wgrad, None, None def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False): @@ -236,7 +259,7 @@ def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) -def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False): +def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False, config=None): """ Token-wise-weighted bias swiglu fusion. 
""" @@ -246,7 +269,7 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False): if bias is not None: raise NotImplementedError("Bias is not supported for weighted swiglu fusion") else: - output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store) + output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store, config) return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) diff --git a/megatron/core/fusions/fused_weighted_swiglu_quant.py b/megatron/core/fusions/fused_weighted_swiglu_quant.py new file mode 100644 index 00000000000..f31219b62c6 --- /dev/null +++ b/megatron/core/fusions/fused_weighted_swiglu_quant.py @@ -0,0 +1,370 @@ +import torch +import triton +import triton.language as tl + +SCALE_MIN_THRES = 1e-10 +FP8_MAX_VALUE = {torch.float8_e4m3fn: 448.0, torch.float8_e5m2: 57344.0} + + +@triton.heuristics({"BLOCK_SN": lambda args: args["BLOCK_N"] // args["block_size"]}) +@triton.jit +def fused_weighted_swiglu_quant_kernel( + inp_ptr, + w_ptr, + out_data_ptr, + out_scale_ptr, + M, + H, + SN, + block_size: tl.constexpr, + fp8_max, + inp_stride_0, + inp_stride_1, + w_stride_0, + w_stride_1, + out_data_stride_0, + out_data_stride_1, + out_scale_stride_0, + out_scale_stride_1, + force_pow_2_scales: tl.constexpr, + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): + pid_dim0 = tl.program_id(0) + pid_dim1 = tl.program_id(1) + + # split a and b, a: first BLOCK_N cols, b: next BLOCK_N cols + inp_block_ptr_a = tl.make_block_ptr( + base=inp_ptr, + shape=(M, 2 * H), + strides=(inp_stride_0, inp_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + inp_block_ptr_b = tl.make_block_ptr( + base=inp_ptr, + shape=(M, 2 * H), + strides=(inp_stride_0, inp_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N + H), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + 
w_block_ptr = tl.make_block_ptr( + base=w_ptr, + shape=(M, 1), + strides=(w_stride_0, w_stride_1), + offsets=(pid_dim0 * BLOCK_M, 0), + block_shape=(BLOCK_M, 1), + order=(1, 0), + ) + + a = tl.load(inp_block_ptr_a, boundary_check=(0, 1)).to(tl.float32) + b = tl.load(inp_block_ptr_b, boundary_check=(0, 1)).to(tl.float32) + w = tl.load(w_block_ptr, boundary_check=(0, 1)).to(tl.float32) + + # weighted swiglu = silu * b * w, silu(a) = a * sigmoid(a) + sig = tl.sigmoid(a) + silu = a * sig + data = silu * b * w + data = data.to(inp_ptr.type.element_ty).to(tl.float32) + + # Scale + data = tl.reshape(data, (BLOCK_M, BLOCK_SN, block_size)) + amax = tl.max(tl.abs(data), axis=2) + amax = tl.maximum(amax, SCALE_MIN_THRES) + scale = tl.fdiv(amax, fp8_max) + if force_pow_2_scales: + # scale = tl.exp2(tl.ceil(tl.log2(scale))) + s_bits = tl.cast(scale, tl.uint32, bitcast=True) + scale_exp = ((s_bits >> 23) & 0xFF) + tl.cast((s_bits & 0x7FFFFF) != 0, tl.uint32) + scale = tl.cast(scale_exp << 23, tl.float32, bitcast=True) + scale = tl.reshape(scale, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + data_q = tl.fdiv(data, scale) + + data_q = data_q.to(out_data_ptr.type.element_ty) + data_q = tl.reshape(data_q, (BLOCK_M, BLOCK_N)) + scale = scale.to(out_scale_ptr.type.element_ty) + scale = tl.reshape(scale, (BLOCK_M, BLOCK_SN)) + + out_data_block_ptr = tl.make_block_ptr( + base=out_data_ptr, + shape=(M, H), + strides=(out_data_stride_0, out_data_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_block_ptr = tl.make_block_ptr( + base=out_scale_ptr, + shape=(M, SN), + strides=(out_scale_stride_0, out_scale_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(out_data_block_ptr, data_q, boundary_check=(0, 1)) + tl.store(scale_block_ptr, scale, boundary_check=(0, 1)) + + +def fused_weighted_swiglu_quant( + x: torch.Tensor, weights: torch.Tensor, 
block_size: int = 128, fp8type=torch.float8_e4m3fn +): + assert x.dim() == 2 and weights.dim() == 2 and weights.shape[1] == 1, "weights's shape mismatch" + M, N2 = x.shape + if N2 % 2 != 0: + raise ValueError("Last dim of input must be even (2*H)") + H = N2 // 2 + + SN = (H + block_size - 1) // block_size + out_data = torch.empty((M, H), dtype=fp8type, device=x.device) + out_scale = torch.empty((M, SN), dtype=torch.float32, device=x.device) + + BLOCK_M = 32 + BLOCK_N = block_size + assert (H % BLOCK_N) == 0, "H must be divisible by BLOCK_N for this fixed setting" + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(H, BLOCK_N)) + fused_weighted_swiglu_quant_kernel[grid]( + x, + weights, + out_data, + out_scale, + M, + H, + SN, + block_size, + FP8_MAX_VALUE[fp8type], + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + out_data.stride(0), + out_data.stride(1), + out_scale.stride(0), + out_scale.stride(1), + force_pow_2_scales=True, + SCALE_MIN_THRES=SCALE_MIN_THRES, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + return out_data, out_scale + + +@triton.heuristics({"BLOCK_SN": lambda args: args["BLOCK_N"] // args["block_size"]}) +@triton.jit +def fused_weighted_swiglu_quant_back_kernel( + gradout_ptr, + inp_ptr, + w_ptr, + out_input_grad_q_ptr, + out_input_grad_scale_ptr, + out_wgrad_ptr, + M, + H, + SN, + g_stride_0, + g_stride_1, + inp_stride_0, + inp_stride_1, + w_stride_0, + w_stride_1, + out_input_grad_q_stride_0, + out_input_grad_q_stride_1, + out_input_grad_scale_stride_0, + out_input_grad_scale_stride_1, + gw_stride_0: tl.constexpr, + block_size: tl.constexpr, + fp8_max: tl.constexpr, + force_pow_2_scales: tl.constexpr, + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): + pid_dim0 = tl.program_id(0) + pid_dim1 = tl.program_id(1) + + a_ptr = tl.make_block_ptr( + base=inp_ptr, + shape=(M, 2 * H), + strides=(inp_stride_0, inp_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + 
block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + b_ptr = tl.make_block_ptr( + base=inp_ptr, + shape=(M, 2 * H), + strides=(inp_stride_0, inp_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N + H), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + gradout_ptr_blk = tl.make_block_ptr( + base=gradout_ptr, + shape=(M, H), + strides=(g_stride_0, g_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + w_ptr_blk = tl.make_block_ptr( + base=w_ptr, + shape=(M, 1), + strides=(w_stride_0, w_stride_1), + offsets=(pid_dim0 * BLOCK_M, 0), + block_shape=(BLOCK_M, 1), + order=(1, 0), + ) + + a = tl.load(a_ptr, boundary_check=(0, 1)).to(tl.float32) + b = tl.load(b_ptr, boundary_check=(0, 1)).to(tl.float32) + g = tl.load(gradout_ptr_blk, boundary_check=(0, 1)).to(tl.float32) + w = tl.load(w_ptr_blk, boundary_check=(0, 1)).to(tl.float32) + + sig = tl.sigmoid(a) + silu = a * sig + dsilu = sig * (1.0 + a * (1.0 - sig)) + g_eff = g * w + dy1 = g_eff * dsilu * b + dy2 = g_eff * silu + dy1 = dy1.to(inp_ptr.type.element_ty).to(tl.float32) + dy2 = dy2.to(inp_ptr.type.element_ty).to(tl.float32) + + # grad_w + contrib = (silu * b) * g + part_sum = tl.sum(contrib, axis=1) + rows = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M) + mask_rows = rows < M + tl.atomic_add(out_wgrad_ptr + rows * gw_stride_0, part_sum, mask=mask_rows) + + # Quantize + dy1 = tl.reshape(dy1, (BLOCK_M, BLOCK_SN, block_size)) + max1 = tl.max(tl.abs(dy1), axis=2) + max1 = tl.maximum(max1, SCALE_MIN_THRES) + scale1 = tl.fdiv(max1, fp8_max) + if force_pow_2_scales: + # scale1 = tl.exp2(tl.ceil(tl.log2(scale1))) + s_bits = tl.cast(scale1, tl.uint32, bitcast=True) + scale_exp1 = ((s_bits >> 23) & 0xFF) + tl.cast((s_bits & 0x7FFFFF) != 0, tl.uint32) + scale1 = tl.cast(scale_exp1 << 23, tl.float32, bitcast=True) + + dy1_q = tl.fdiv(dy1, tl.reshape(scale1, (BLOCK_M, BLOCK_SN, 1))) + + dy2 = tl.reshape(dy2, (BLOCK_M, BLOCK_SN, block_size)) + max2 = 
tl.max(tl.abs(dy2), axis=2) + max2 = tl.maximum(max2, SCALE_MIN_THRES) + scale2 = tl.fdiv(max2, fp8_max) + if force_pow_2_scales: + # scale2 = tl.exp2(tl.ceil(tl.log2(scale2))) + s_bits = tl.cast(scale2, tl.uint32, bitcast=True) + scale_exp2 = ((s_bits >> 23) & 0xFF) + tl.cast((s_bits & 0x7FFFFF) != 0, tl.uint32) + scale2 = tl.cast(scale_exp2 << 23, tl.float32, bitcast=True) + + dy2_q = tl.fdiv(dy2, tl.reshape(scale2, (BLOCK_M, BLOCK_SN, 1))) + + dy1_q = dy1_q.to(out_input_grad_q_ptr.type.element_ty) + dy1_q = tl.reshape(dy1_q, (BLOCK_M, BLOCK_N)) + dy2_q = dy2_q.to(out_input_grad_q_ptr.type.element_ty) + dy2_q = tl.reshape(dy2_q, (BLOCK_M, BLOCK_N)) + scale1 = scale1.to(out_input_grad_scale_ptr.type.element_ty) + scale2 = scale2.to(out_input_grad_scale_ptr.type.element_ty) + + gy1_ptr = tl.make_block_ptr( + base=out_input_grad_q_ptr, + shape=(M, 2 * H), + strides=(out_input_grad_q_stride_0, out_input_grad_q_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + gy2_ptr = tl.make_block_ptr( + base=out_input_grad_q_ptr, + shape=(M, 2 * H), + strides=(out_input_grad_q_stride_0, out_input_grad_q_stride_1), + offsets=(pid_dim0 * BLOCK_M, H + pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + tl.store(gy1_ptr, dy1_q, boundary_check=(0, 1)) + tl.store(gy2_ptr, dy2_q, boundary_check=(0, 1)) + + s1_ptr = tl.make_block_ptr( + base=out_input_grad_scale_ptr, + shape=(M, 2 * SN), + strides=(out_input_grad_scale_stride_0, out_input_grad_scale_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + s2_ptr = tl.make_block_ptr( + base=out_input_grad_scale_ptr, + shape=(M, 2 * SN), + strides=(out_input_grad_scale_stride_0, out_input_grad_scale_stride_1), + offsets=(pid_dim0 * BLOCK_M, SN + pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(s1_ptr, scale1, boundary_check=(0, 1)) + 
tl.store(s2_ptr, scale2, boundary_check=(0, 1)) + + +def fused_weighted_swiglu_quant_back( + grad_output: torch.Tensor, + input: torch.Tensor, + weights: torch.Tensor, + block_size: int = 128, + fp8type=torch.float8_e4m3fn, +): + assert ( + input.shape[:-1] == grad_output.shape[:-1] and input.shape[-1] == 2 * grad_output.shape[-1] + ), "shape mismatch" + assert weights.shape[0] == grad_output.shape[0], "shape mismatch" + + device = grad_output.device + M, H = grad_output.shape + SN = (H + block_size - 1) // block_size + + out_input_grad_q = torch.empty((M, 2 * H), device=device, dtype=fp8type) + out_input_grad_scale = torch.empty((M, 2 * SN), device=device, dtype=torch.float32) + out_wgrad = torch.zeros((M,), device=device, dtype=torch.float32) + + BLOCK_M = 32 + BLOCK_N = block_size + assert (H % BLOCK_N) == 0, "H must be divisible by BLOCK_N for this fixed setting" + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(H, BLOCK_N)) + fused_weighted_swiglu_quant_back_kernel[grid]( + grad_output, + input, + weights, + out_input_grad_q, + out_input_grad_scale, + out_wgrad, + M, + H, + SN, + grad_output.stride(0), + grad_output.stride(1), + input.stride(0), + input.stride(1), + weights.stride(0), + weights.stride(1), + out_input_grad_q.stride(0), + out_input_grad_q.stride(1), + out_input_grad_scale.stride(0), + out_input_grad_scale.stride(1), + 1, + block_size, + FP8_MAX_VALUE[fp8type], + force_pow_2_scales=True, + SCALE_MIN_THRES=SCALE_MIN_THRES, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + out_wgrad = out_wgrad.reshape(M, 1).to(weights.dtype) + return out_input_grad_q, out_input_grad_scale, out_wgrad diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 8a19fef87ec..e17f1d624fb 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -270,6 +270,7 @@ def forward( bias_parallel, per_token_scale.unsqueeze(-1), self.config.activation_func_fp8_input_store, + self.config, ) elif self.activation_func == quick_gelu 
and self.config.gated_linear_unit: intermediate_parallel = weighted_bias_quick_geglu_impl( diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 34e9fb17a02..e3e817df2ae 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -287,6 +287,7 @@ def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs): bias_parallel, permuted_probs, self.config.activation_func_fp8_input_store, + self.config, ) elif self.activation_func == quick_gelu and self.config.gated_linear_unit: intermediate_parallel = weighted_bias_quick_geglu_impl( diff --git a/megatron/core/transformer/moe/fused_a2a.py b/megatron/core/transformer/moe/fused_a2a.py index 39f50a4a670..5e8f7faf6c7 100644 --- a/megatron/core/transformer/moe/fused_a2a.py +++ b/megatron/core/transformer/moe/fused_a2a.py @@ -15,6 +15,8 @@ import torch +from megatron.core.fp8_utils import fp8_quantize, make_fp8_tensor + _buffer = None @@ -27,7 +29,8 @@ def get_hidden_bytes(x: torch.Tensor) -> int: Returns: int: Number of hidden bytes """ - return x.size(1) * max(x.element_size(), 2) + t = x[0] if isinstance(x, tuple) else x + return t.size(1) * max(t.element_size(), 2) def get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int): @@ -79,8 +82,13 @@ def forward( group, async_finish=False, allocate_on_comm_stream=False, + config=None, ): """Forward pass of fused dispatch.""" + # fp8 flow : quantize before dispatch + if config is not None and config.moe_fp8_flow: + x = fp8_quantize(config.fp8_recipe, x) + previous_event = None if async_finish: previous_event = EventOverlap(EventHandle()) @@ -134,6 +142,10 @@ def forward( ctx.allocate_on_comm_stream = allocate_on_comm_stream tokens_per_expert = torch.tensor(num_recv_tokens_per_expert_list) + # fp8 flow : create qtensor directly into expert up-projection + if config is not None and config.moe_fp8_flow: + recv_x = make_fp8_tensor(config.fp8_recipe, *recv_x) + 
return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle) @staticmethod @@ -157,14 +169,16 @@ def backward( # Make sure current stream is synchronized if ctx.async_finish: after_event.current_stream_wait() - return grad_x, None, grad_token_probs, None, None, None, None + return grad_x, None, grad_token_probs, None, None, None, None, None class FusedCombine(torch.autograd.Function): """Fused combine operation for MoE output combining computation and communication.""" @staticmethod - def forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False): + def forward( + ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False, config=None + ): """Forward pass of fused combine.""" previous_event = None if async_finish: @@ -185,17 +199,23 @@ def forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=F ctx.group = group ctx.async_finish = async_finish ctx.allocate_on_comm_stream = allocate_on_comm_stream + ctx.moe_fp8_flow = config.moe_fp8_flow if config is not None else False + ctx.fp8_recipe = config.fp8_recipe if config is not None else None return combined_x, None @staticmethod def backward(ctx, grad_output, previous_event=None): """Backward pass of fused combine.""" + # fp8 flow : quantize before dispatch + if ctx.moe_fp8_flow: + grad_output = fp8_quantize(ctx.fp8_recipe, grad_output) + previous_event = None if ctx.async_finish: previous_event = EventOverlap(EventHandle()) buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) grad_x, _, _, _, _, after_event = buffer.dispatch( - grad_output.contiguous(), + grad_output.contiguous() if isinstance(grad_output, torch.Tensor) else grad_output, handle=ctx.handle, previous_event=previous_event, async_finish=ctx.async_finish, @@ -204,7 +224,12 @@ def backward(ctx, grad_output, previous_event=None): # Make sure current stream is synchronized if ctx.async_finish: after_event.current_stream_wait() - return grad_x, None, None, None, None + + # 
fp8 flow : re-wrap the dispatched (data, scale) gradient as a qtensor + if ctx.moe_fp8_flow: + grad_x = make_fp8_tensor(ctx.fp8_recipe, *grad_x) + + return grad_x, None, None, None, None, None if HAVE_DEEP_EP: @@ -217,6 +242,7 @@ def fused_dispatch( group, async_finish=False, allocate_on_comm_stream=False, + config=None, ): """Perform fused dispatch operation if deep_ep is available. @@ -239,9 +265,12 @@ def fused_dispatch( group, async_finish, allocate_on_comm_stream, + config, ) - def fused_combine(x, group, handle, async_finish=False, allocate_on_comm_stream=False): + def fused_combine( + x, group, handle, async_finish=False, allocate_on_comm_stream=False, config=None + ): """Perform fused combine operation if deep_ep is available. Args: @@ -253,7 +282,7 @@ def fused_combine(x, group, handle, async_finish=False, allocate_on_comm_stream= Returns: Result of FusedCombine """ - return FusedCombine.apply(x, group, handle, async_finish, allocate_on_comm_stream) + return FusedCombine.apply(x, group, handle, async_finish, allocate_on_comm_stream, config) def set_deepep_num_sms(num_sms): """Sets the number of SMs to use for DeepEP""" diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 35b567679fe..45e2f1ccb8f 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -227,6 +227,9 @@ def __init__( and "moe" in config.recompute_modules and config.cuda_graph_impl != 'local' ) + self.moe_expert_recompute = ( + config.recompute_granularity == 'selective' and "moe_expert" in config.recompute_modules + ) self.shared_experts_recompute = ( config.recompute_granularity == 'selective' and "shared_experts" in config.recompute_modules diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 2466ffc0825..a93f3a64179 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ 
b/megatron/core/transformer/moe/token_dispatcher.py @@ -1191,6 +1191,7 @@ def dispatch( self.group, async_finish=async_finish, allocate_on_comm_stream=allocate_on_comm_stream, + config=self.config, ) ) self.handle = handle @@ -1248,6 +1249,7 @@ def combine( self.handle, async_finish=async_finish, allocate_on_comm_stream=allocate_on_comm_stream, + config=self.config, ) # Release the handle after combine operation self.handle = None diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 975f971fbc9..a7e036b2837 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -667,6 +667,10 @@ class TransformerConfig(ModelParallelConfig): """[Compatibility alias for moe_router_padding_for_quantization] Enabling this will also enable moe_router_padding_for_quantization.""" + moe_fp8_flow: Optional[bool] = False + """Whether to quantize activations to FP8 before DeepEP token dispatch to reduce + communication bandwidth and feed them directly into expert up-projection (GroupedMLP/GEMM).""" + moe_router_num_groups: Optional[int] = None """Number of groups to divide experts into for group-limited routing. When using group-limited routing: @@ -1369,6 +1373,7 @@ def __post_init__(self): "mla_up_proj", "mlp", "moe", + "moe_expert", "shared_experts", } invalid_modules = set(self.recompute_modules) - allowed_modules @@ -1419,6 +1424,15 @@ def __post_init__(self): f"but your version is {get_te_version()}." 
) + if "moe" in self.recompute_modules and "moe_expert" in self.recompute_modules: + raise ValueError( + "moe in recompute_modules is not supported with moe_expert in recompute_modules" + ) + if "moe_expert" in self.recompute_modules and ( + not self.fp8 or self.fp8_recipe != 'blockwise' + ): + raise ValueError("moe_expert in recompute_modules is only supported with fp8 blockwise recipe") + if self.moe_layer_recompute: warnings.warn( "--moe-layer-recompute is deprecated. " @@ -1859,6 +1873,10 @@ def __post_init__(self): "moe_router_padding_for_quantization." ) + if self.moe_fp8_flow: + if not self.fp8 or self.fp8_recipe != "blockwise": + raise ValueError("moe_fp8_flow requires fp8 with the blockwise recipe.") + if ( self.moe_router_topk == 1 and self.moe_router_score_function == "softmax" diff --git a/tests/unit_tests/fusions/test_swiglu_fusion.py b/tests/unit_tests/fusions/test_swiglu_fusion.py index c72679cd047..47b19f680d5 100644 --- a/tests/unit_tests/fusions/test_swiglu_fusion.py +++ b/tests/unit_tests/fusions/test_swiglu_fusion.py @@ -1,7 +1,18 @@ import pytest import torch -from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl, weighted_bias_swiglu_impl +from megatron.core.enums import Fp8Recipe +from megatron.core.fp8_utils import fp8_quantize +from megatron.core.fusions.fused_bias_swiglu import ( + bias_swiglu_impl, + weighted_bias_swiglu_impl, + weighted_swiglu, + weighted_swiglu_back, +) +from megatron.core.fusions.fused_weighted_swiglu_quant import ( + fused_weighted_swiglu_quant, + fused_weighted_swiglu_quant_back, +) @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) @@ -39,3 +50,90 @@ def test_weighted_bias_swiglu(input_dtype): assert weights_2.grad.dtype == weights.grad.dtype if input_dtype == torch.float32: assert torch.allclose(weights.grad, weights_2.grad, **tols) + + +def _test_fused_weighted_bias_swiglu_quant( + num_tokens, topK, MOE_INTERMEDIATE_SIZE, BENCHMARK=False +): + 
num_out_tokens = num_tokens * topK + x = torch.rand((num_out_tokens, 2 * MOE_INTERMEDIATE_SIZE), dtype=torch.float16).cuda() + weights = torch.rand((num_out_tokens, 1), dtype=torch.float32).cuda() + grad_output = torch.rand((num_out_tokens, MOE_INTERMEDIATE_SIZE), dtype=x.dtype).cuda() + tols = dict(rtol=2.0e-2, atol=1.0e-3) + + # Forward: fused kernel vs non-fused weighted_swiglu + TE blockwise quant. + fused_data, fused_scale = fused_weighted_swiglu_quant(x, weights) + + ref_out = weighted_swiglu(x, weights) + ref_data, ref_scale = fp8_quantize(Fp8Recipe.blockwise, ref_out) + + torch.testing.assert_close(fused_data.view(torch.uint8), ref_data, rtol=0, atol=1) + torch.testing.assert_close(fused_scale, ref_scale) + + # Backward: fused kernel vs non-fused weighted_swiglu_back + TE blockwise quant. + fused_dgrad_data, fused_dgrad_scale, fused_wgrad = fused_weighted_swiglu_quant_back( + grad_output, x, weights + ) + + ref_dgrad, ref_wgrad = weighted_swiglu_back(grad_output, x, weights) + ref_dgrad_data, ref_dgrad_scale = fp8_quantize(Fp8Recipe.blockwise, ref_dgrad) + + torch.testing.assert_close(fused_dgrad_data.view(torch.uint8), ref_dgrad_data, rtol=0, atol=1) + torch.testing.assert_close(fused_dgrad_scale, ref_dgrad_scale) + torch.testing.assert_close(fused_wgrad, ref_wgrad) + + if BENCHMARK: + + def _run_fused_fwd(): + _ = fused_weighted_swiglu_quant(x, weights) + + def _run_ref_fwd(): + ref_out_ = weighted_swiglu(x, weights) + _ = fp8_quantize(Fp8Recipe.blockwise, ref_out_) + + def _run_fused_bwd(): + _ = fused_weighted_swiglu_quant_back(grad_output, x, weights) + + def _run_ref_bwd(): + ref_dgrad_, _ = weighted_swiglu_back(grad_output, x, weights) + _ = fp8_quantize(Fp8Recipe.blockwise, ref_dgrad_) + + def _benchmark(fn, warmup=20, iters=100): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + 
torch.cuda.synchronize() + return start.elapsed_time(end) / iters + + fused_fwd_ms = _benchmark(_run_fused_fwd) + ref_fwd_ms = _benchmark(_run_ref_fwd) + fused_bwd_ms = _benchmark(_run_fused_bwd) + ref_bwd_ms = _benchmark(_run_ref_bwd) + + print( + "[perf] fused_weighted_swiglu_quant " + f"fwd: fused={fused_fwd_ms:.3f} ms, ref={ref_fwd_ms:.3f} ms, speedup={ref_fwd_ms/fused_fwd_ms:.3f}, " + f"bwd: fused={fused_bwd_ms:.3f} ms, ref={ref_bwd_ms:.3f} ms, speedup={ref_bwd_ms/fused_bwd_ms:.3f}" + ) + assert fused_fwd_ms > 0.0 and ref_fwd_ms > 0.0 + assert fused_bwd_ms > 0.0 and ref_bwd_ms > 0.0 + + +@pytest.mark.parametrize( + "num_tokens, topK, MOE_INTERMEDIATE_SIZE", [(4096, 7, 256), (4096, 6, 1408)] +) +def test_fused_weighted_bias_swiglu_quant(num_tokens, topK, MOE_INTERMEDIATE_SIZE): + BENCHMARK = False + + _test_fused_weighted_bias_swiglu_quant( + num_tokens=num_tokens, + topK=topK, + MOE_INTERMEDIATE_SIZE=MOE_INTERMEDIATE_SIZE, + BENCHMARK=BENCHMARK, + ) diff --git a/tests/unit_tests/transformer/moe/test_moe_layer.py b/tests/unit_tests/transformer/moe/test_moe_layer.py index 0004b7fef98..b71f2172267 100644 --- a/tests/unit_tests/transformer/moe/test_moe_layer.py +++ b/tests/unit_tests/transformer/moe/test_moe_layer.py @@ -385,3 +385,107 @@ def test_moe_layer_recompute_forward_backward( def teardown_method(self, method): Utils.destroy_model_parallel() + + +class TestMoELayerFP8Flow: + """Test MoE fp8_flow with DeepEP dispatcher and moe_expert recompute.""" + + @staticmethod + def _is_deep_ep_available(): + from megatron.core.transformer.moe.fused_a2a import HAVE_DEEP_EP + + return HAVE_DEEP_EP + + def setup_method(self, method): + pass + + @pytest.mark.parametrize("num_moe_experts", [2, 4]) + @pytest.mark.parametrize("tp_size,ep_size", [(1, 2)]) + def test_moe_layer_fp8_flow_with_deepep_and_moe_expert_recompute( + self, num_moe_experts, tp_size, ep_size + ): + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if not 
self._is_deep_ep_available(): + pytest.skip("DeepEP is not available") + + from contextlib import nullcontext + + from megatron.core.fp8_utils import get_fp8_context + + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size + ) + _set_random_seed(seed_=123, data_parallel_random_init=False) + + hidden_size = 64 + sequence_length = 32 + micro_batch_size = 2 + + transformer_config = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + num_attention_heads=4, + num_moe_experts=num_moe_experts, + use_cpu_initialization=False, + moe_token_dispatcher_type="flex", + moe_flex_dispatcher_backend="deepep", + moe_router_load_balancing_type="aux_loss", + moe_router_topk=2, + moe_aux_loss_coeff=0.01, + moe_grouped_gemm=False, + moe_ffn_hidden_size=256, + add_bias_linear=False, + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + sequence_parallel=tp_size > 1, + # Enable fp8 flow + moe expert recompute. + fp8="e4m3", + fp8_recipe="blockwise", + moe_fp8_flow=True, + moe_permute_fusion=True, + recompute_granularity="selective", + recompute_modules=["moe_expert"], + bf16=True, + params_dtype=torch.bfloat16, + ) + + transformer_layer_submodules = get_gpt_layer_with_transformer_engine_submodules( + num_experts=num_moe_experts, moe_grouped_gemm=True + ) + moe_layer = MoELayer(transformer_config, transformer_layer_submodules.mlp.submodules).cuda() + + hidden_states = torch.randn( + sequence_length, + micro_batch_size, + hidden_size, + device=torch.cuda.current_device(), + dtype=torch.bfloat16, + requires_grad=True, + ) + + fp8_context = ( + get_fp8_context(transformer_config, 0) if transformer_config.fp8 else nullcontext() + ) + with fp8_context: + output, _ = moe_layer(hidden_states) + assert output.dtype == torch.bfloat16, f"Expected bf16 output, got {output.dtype}" + assert output.shape == hidden_states.shape, "Output shape mismatch" + + loss = output.sum() + loss.backward() + + assert 
hidden_states.grad is not None, "Input gradients should exist" + assert ( + hidden_states.grad.dtype == torch.bfloat16 + ), f"Expected bf16 gradients, got {hidden_states.grad.dtype}" + + for name, param in moe_layer.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"Gradient for {name} should exist" + + Utils.destroy_model_parallel() + + def teardown_method(self, method): + Utils.destroy_model_parallel()