Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2385,6 +2385,14 @@ def fused_apply_rotary_pos_emb_thd(
fused_sort_chunks_by_index_with_probs = None
fused_unpermute = None

try:
from transformer_engine.pytorch.permutation import moe_permute_and_pad_with_probs

fused_permute_and_pad_with_probs = moe_permute_and_pad_with_probs

except ImportError:
fused_permute_and_pad_with_probs = None

try:
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

Expand Down
74 changes: 74 additions & 0 deletions megatron/core/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,3 +783,77 @@ def prepare_model_for_fp8_inference(model):
"prepare_model_for_fp8_inference requires Transformer Engine to be installed. "
"Please install transformer-engine to use FP8 inference."
)


if HAVE_TE:
from functools import lru_cache
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)

@lru_cache(maxsize=None)
def _get_fp8_quantizer(recipe, all_gather_usage=False):
if recipe == Fp8Recipe.blockwise:
return Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=False,
amax_epsilon=1e-10,
force_pow_2_scales=True,
block_scaling_dim=1,
all_gather_usage=all_gather_usage,
)

return None

def fp8_quantize(recipe: Fp8Recipe, x: torch.Tensor, all_gather_usage=False):
q = _get_fp8_quantizer(recipe, all_gather_usage)
if q is None:
return x

quantized_tensor = q(x)
if recipe == Fp8Recipe.blockwise:
if quantized_tensor._data_format == tex.Float8BlockScaleTensorFormat.COMPACT:
return (quantized_tensor._rowwise_data, quantized_tensor._rowwise_scale_inv)
else:
return (
quantized_tensor._rowwise_data,
quantized_tensor._rowwise_scale_inv.T.contiguous(),
)

return x

def make_fp8_tensor(recipe: Fp8Recipe, x: torch.Tensor, x_scale: torch.Tensor):
q = _get_fp8_quantizer(recipe)
if q is None:
return None

if recipe == Fp8Recipe.blockwise:
# To accelerate fp8flow and reduce redundant cases of T.contiguous()
# scale_inv use COMPACT
return Float8BlockwiseQTensor(
shape=x.shape,
dtype=torch.bfloat16,
rowwise_data=x.view(torch.uint8),
rowwise_scale_inv=x_scale,
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=tex.DType.kFloat8E4M3,
quantizer=q,
is_2D_scaled=False,
requires_grad=x.requires_grad,
data_format=tex.Float8BlockScaleTensorFormat.GEMM_READY,
)
return None

else:

def fp8_quantize(recipe: Fp8Recipe, x: torch.Tensor, all_gather_usage=False):
"""Transformer Engine not available: passthrough tensor."""
return x

def make_fp8_tensor(recipe: Fp8Recipe, x: torch.Tensor, x_scale: torch.Tensor):
"""Transformer Engine not available: no FP8 tensor wrapper."""
return None
35 changes: 29 additions & 6 deletions megatron/core/fusions/fused_bias_swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

from megatron.core.jit import jit_fuser
from megatron.core.utils import nvtx_decorator
from megatron.core.fusions.fused_weighted_swiglu_quant import (
fused_weighted_swiglu_quant,
fused_weighted_swiglu_quant_back,
)

###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################

Expand Down Expand Up @@ -191,19 +195,38 @@ def backward(ctx, grad_output):
class WeightedSwiGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, weights, fp8_input_store):
def forward(ctx, input, weights, fp8_input_store, config):
input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
if config.moe_fp8_flow:
out_data, out_scale = fused_weighted_swiglu_quant(input, weights)

from megatron.core.fp8_utils import make_fp8_tensor

weighted_swiglu_out = make_fp8_tensor(config.fp8_recipe, out_data, out_scale)
else:
weighted_swiglu_out = weighted_swiglu(input, weights)
ctx.save_for_backward(input_for_backward, weights)
ctx.ori_input_dtype = input.dtype
ctx.fp8_input_store = fp8_input_store
return weighted_swiglu(input, weights)
ctx.fp8_recipe = config.fp8_recipe
ctx.moe_fp8_flow = config.moe_fp8_flow
return weighted_swiglu_out

@staticmethod
def backward(ctx, grad_output):
input, weights = ctx.saved_tensors
input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
tmp, wgrad = weighted_swiglu_back(grad_output, input, weights)
return tmp, wgrad, None
if ctx.moe_fp8_flow:
input_grad_data, input_grad_scale, wgrad = fused_weighted_swiglu_quant_back(
grad_output, input, weights
)

from megatron.core.fp8_utils import make_fp8_tensor

tmp = make_fp8_tensor(ctx.fp8_recipe, input_grad_data, input_grad_scale)
else:
tmp, wgrad = weighted_swiglu_back(grad_output, input, weights)
return tmp, wgrad, None, None


def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False):
Expand Down Expand Up @@ -236,7 +259,7 @@ def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False
return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)


def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False):
def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False, config=None):
"""
Token-wise-weighted bias swiglu fusion.
"""
Expand All @@ -246,7 +269,7 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False):
if bias is not None:
raise NotImplementedError("Bias is not supported for weighted swiglu fusion")
else:
output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store)
output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store, config)

return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)

Expand Down
Loading