Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 56 additions & 72 deletions transformer_engine/plugin/core/backends/flagos/flagos.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,23 @@
rmsnorm_bwd_fl,
multi_tensor_scale_fl,
multi_tensor_adam_fl,
multi_tensor_adam_param_remainder_fl,
multi_tensor_l2_norm_fl,
generic_gemm_fl,
gelu_fl,
geglu_fl,
qgelu_fl,
qgeglu_fl,
relu_fl,
reglu_fl,
moe_permute_fwd_fl,
moe_unpermute_bwd_fl,
moe_unpermute_fwd_fl,
moe_permute_bwd_fl,
)


def _check_flagos_available() -> bool:
return True


class FlagOSBackend(TEFLBackendBase):
@staticmethod
def check_available() -> bool:
Expand All @@ -35,7 +42,6 @@ def is_available(self) -> bool:
def get_attention_backend(self, attention_params=None):
from packaging.version import Version as PkgVersion
from ...logger_manager import get_logger

logger = get_logger()

# Read environment variables to determine which backends to enable
Expand Down Expand Up @@ -65,7 +71,7 @@ def get_attention_backend(self, attention_params=None):
available_backends,
)

##### transformer_engine/pytorch/csrc/extensions/pybind.cpp #####
##### transformer_engine/pytorch/csrc/extensions/pybind.cpp #####
def generic_gemm(
self,
A: Any,
Expand All @@ -92,28 +98,10 @@ def generic_gemm(
beta: Optional[float] = None,
) -> List[Any]:
return generic_gemm_fl(
A,
transA,
B,
transB,
D,
quantizer,
output_dtype,
bias,
bias_type,
gelu,
gelu_in,
grad,
workspace,
workspace_size,
accumulate,
use_split_accumulator,
comm_overlap,
comm_type,
extra_output,
bulk_overlap,
alpha,
beta,
A, transA, B, transB, D, quantizer, output_dtype,
bias, bias_type, gelu, gelu_in, grad, workspace, workspace_size,
accumulate, use_split_accumulator, comm_overlap, comm_type,
extra_output, bulk_overlap, alpha, beta
)

# Other granular functions
Expand All @@ -129,16 +117,10 @@ def rmsnorm_fwd(
zero_centered_gamma: bool,
) -> List[Any]:
return rmsnorm_fwd_fl(
input=input,
weight=weight,
eps=eps,
ln_out=ln_out,
quantizer=quantizer,
odtype=otype,
sm_margin=sm_margin,
zero_centered_gamma=zero_centered_gamma,
input=input, weight=weight, eps=eps, ln_out=ln_out,
quantizer=quantizer, odtype=otype,
sm_margin=sm_margin, zero_centered_gamma=zero_centered_gamma,
)

def rmsnorm_bwd(
self,
dz: torch.Tensor,
Expand All @@ -149,14 +131,9 @@ def rmsnorm_bwd(
zero_centered_gamma: bool,
) -> List[Any]:
return rmsnorm_bwd_fl(
dy=dz,
x=x,
rsigma=rsigma,
gamma=gamma,
sm_margin=sm_margin,
zero_centered_gamma=zero_centered_gamma,
dy=dz, x=x, rsigma=rsigma, gamma=gamma,
sm_margin=sm_margin, zero_centered_gamma=zero_centered_gamma
)

def get_fused_attn_backend(self, *args, **kwargs) -> int:
    """Fused-attention kernels are not provided by FlagOS.

    All arguments are accepted for interface parity and ignored; the
    no-backend sentinel is always reported.
    """
    backend = NVTE_Fused_Attn_Backend.NVTE_No_Backend
    return backend

Expand All @@ -169,7 +146,6 @@ def multi_tensor_scale(
scale: float,
) -> None:
return multi_tensor_scale_fl(chunk_size, noop_flag, tensor_lists, scale)

def multi_tensor_l2norm(
self,
chunk_size: int,
Expand All @@ -178,7 +154,6 @@ def multi_tensor_l2norm(
per_tensor: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
return multi_tensor_l2_norm_fl(chunk_size, noop_flag, tensor_lists, per_tensor)

def multi_tensor_adam(
self,
chunk_size: int,
Expand All @@ -194,19 +169,9 @@ def multi_tensor_adam(
weight_decay: float,
) -> None:
return multi_tensor_adam_fl(
chunk_size,
noop_flag,
tensor_lists,
lr,
beta1,
beta2,
epsilon,
step,
mode,
bias_correction,
weight_decay,
chunk_size, noop_flag, tensor_lists, lr, beta1, beta2, epsilon,
step, mode, bias_correction, weight_decay,
)

def multi_tensor_adam_param_remainder(
self,
chunk_size: int,
Expand All @@ -222,31 +187,50 @@ def multi_tensor_adam_param_remainder(
weight_decay: float,
) -> None:
return multi_tensor_adam_param_remainder_fl(
chunk_size,
noop_flag,
tensor_lists,
lr,
beta1,
beta2,
epsilon,
step,
mode,
bias_correction,
weight_decay,
chunk_size, noop_flag, tensor_lists,
lr, beta1, beta2, epsilon,
step, mode, bias_correction, weight_decay,
)

# Misc
def get_cublasLt_version(self) -> int:
    """Return a fixed cuBLASLt version stub for compatibility checks.

    FlagOS does not use cuBLASLt; 110000 is reported so version gates
    elsewhere in the plugin pass.
    """
    STUB_CUBLASLT_VERSION = 110000
    return STUB_CUBLASLT_VERSION

def get_cudnn_version(self) -> int:
    """Return a fixed cuDNN version stub for compatibility checks.

    FlagOS does not use cuDNN; 90000 is reported so version gates
    elsewhere in the plugin pass.
    """
    STUB_CUDNN_VERSION = 90000
    return STUB_CUDNN_VERSION

def get_num_cublas_streams(self) -> int:
    """Report how many cuBLAS streams this backend manages (none)."""
    return 0

############## class func #################################
############## class func #################################
def get_flash_attention_class(self):
    """Return the FlagOS flash-attention implementation class.

    The import is deferred to call time — presumably to avoid a circular
    import at module load; confirm against the attention package layout.
    """
    from .attention.dot_product_attention.backends import FlashAttentionFL
    return FlashAttentionFL

def gelu(self, input: torch.Tensor, quantizer: Any) -> Any:
    """GELU activation, delegated to the FlagOS kernel."""
    result = gelu_fl(input, quantizer)
    return result

def geglu(self, input: torch.Tensor, quantizer: Any) -> Any:
    """Gated GELU (GeGLU) activation, delegated to the FlagOS kernel."""
    result = geglu_fl(input, quantizer)
    return result

def qgelu(self, input: torch.Tensor, quantizer: Any) -> Any:
    """Quick-GELU activation, delegated to the FlagOS kernel."""
    result = qgelu_fl(input, quantizer)
    return result

def qgeglu(self, input: torch.Tensor, quantizer: Any) -> Any:
    """Gated quick-GELU activation, delegated to the FlagOS kernel."""
    result = qgeglu_fl(input, quantizer)
    return result

def relu(self, input: torch.Tensor, quantizer: Any) -> Any:
    """ReLU activation, delegated to the FlagOS kernel."""
    result = relu_fl(input, quantizer)
    return result

def reglu(self, input: torch.Tensor, quantizer: Any) -> Any:
    """Gated ReLU (ReGLU) activation, delegated to the FlagOS kernel."""
    result = reglu_fl(input, quantizer)
    return result

def moe_permute_fwd(self, *args, **kwargs) -> Any:
    """Forward MoE token permutation; pass-through to the FlagOS kernel."""
    result = moe_permute_fwd_fl(*args, **kwargs)
    return result

def moe_unpermute_bwd(self, *args, **kwargs) -> Any:
    """Backward MoE token un-permutation; pass-through to the FlagOS kernel."""
    result = moe_unpermute_bwd_fl(*args, **kwargs)
    return result

def moe_unpermute_fwd(self, *args, **kwargs) -> Any:
    """Forward MoE token un-permutation; pass-through to the FlagOS kernel."""
    result = moe_unpermute_fwd_fl(*args, **kwargs)
    return result

def moe_permute_bwd(self, *args, **kwargs) -> Any:
    """Backward MoE token permutation; pass-through to the FlagOS kernel."""
    result = moe_permute_bwd_fl(*args, **kwargs)
    return result
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
from .rmsnorm import *
from .fused_adam import *
from .multi_tensor import *
from .activation import *
from .moe_permute import *
30 changes: 30 additions & 0 deletions transformer_engine/plugin/core/backends/flagos/impl/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
from typing import Any
import flag_gems


def gelu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
    """Tanh-approximate GELU computed by the flag_gems kernel.

    NOTE(review): ``quantizer`` is accepted for interface parity but unused
    — no quantized output path here; confirm that is intended.
    """
    return flag_gems.gelu(input, approximate="tanh")


def geglu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
    """GeGLU: tanh-approximate GELU of the first half, gated by the second.

    The last dimension of ``input`` is split in two; the activated first
    half multiplies the second. ``quantizer`` is accepted but unused.
    """
    gate, value = torch.chunk(input, 2, dim=-1)
    activated = flag_gems.gelu(gate, approximate="tanh")
    return activated * value


def qgelu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
    """Quick-GELU: x * sigmoid(1.702 * x). ``quantizer`` is unused."""
    scaled = 1.702 * input
    return input * flag_gems.sigmoid(scaled)


def qgeglu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
    """Gated quick-GELU: quick-GELU of the first half times the second half.

    The last dimension of ``input`` is split in two. ``quantizer`` is
    accepted but unused.
    """
    gate, value = torch.chunk(input, 2, dim=-1)
    return gate * flag_gems.sigmoid(1.702 * gate) * value


def relu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
    """ReLU computed by the flag_gems kernel. ``quantizer`` is unused."""
    out = flag_gems.relu(input)
    return out


def reglu_fl(input: torch.Tensor, quantizer: Any) -> torch.Tensor:
    """ReGLU: ReLU of the first half, gated by the second half.

    The last dimension of ``input`` is split in two. ``quantizer`` is
    accepted but unused.
    """
    gate, value = torch.chunk(input, 2, dim=-1)
    return flag_gems.relu(gate) * value
Loading
Loading