[Performance, Hardware] MoE weights padding to AMD MI300x GPUs (#1836)
HaiShaw authored Oct 30, 2024
1 parent: 4e2af03 · commit: 5f65e2b
Showing 2 changed files with 32 additions and 3 deletions.
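
Summary of the change: on ROCm builds, and only when the MOE_PADDING environment variable is set to a nonzero integer, the fused-MoE expert weights (w13 and w2) are zero-padded by 128 elements along their last dimension after loading; the in-code comment gives the motivation as minimizing memory-channel contention on MI300x. The kernel invocation and shape checks then subtract padding_size so the logical sizes are unchanged. Below is a minimal sketch of the toggle and its effect on a tensor's shape; the weight shape is a made-up example, while the env var, padding_size, and the F.pad call mirror the diff.

# Sketch only: the env-var toggle and F.pad call are as in this commit;
# the weight shape is hypothetical.
import os
import torch
import torch.nn.functional as F

padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

w = torch.randn(4, 1024, 4096, dtype=torch.float16)  # hypothetical [num_experts, rows, cols]
w_padded = F.pad(w, (0, padding_size), "constant", 0)  # zero-pad the innermost dim only
print(tuple(w.shape), "->", tuple(w_padded.shape))  # (4, 1024, 4096) -> (4, 1024, 4224) when MOE_PADDING=1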
python/sglang/srt/layers/fused_moe/fused_moe.py (7 changes: 4 additions & 3 deletions)
@@ -14,6 +14,7 @@
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
 
 @triton.jit
@@ -263,7 +264,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padding_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -464,7 +465,7 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
 ):
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -481,7 +482,7 @@
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
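
The kernel-side edits above keep the arithmetic on the unpadded sizes: invoke_fused_moe_kernel now passes B.shape[2] - padding_size as the reduction dimension, and fused_experts checks the activation hidden size against w1.shape[2] - padding_size. A short illustration with assumed sizes (only the subtraction and the assert mirror the diff):

# Assumed sizes for illustration; the point is that the physical last dim of a
# padded weight exceeds the logical one by exactly padding_size.
import torch
import torch.nn.functional as F

padding_size = 128
hidden_size = 512
w1 = torch.randn(4, 2 * 1024, hidden_size, dtype=torch.float16)  # [E, N, K]
w1 = F.pad(w1, (0, padding_size), "constant", 0)  # physical K is now 640

hidden_states = torch.randn(16, hidden_size, dtype=torch.float16)
assert hidden_states.shape[1] == w1.shape[2] - padding_size  # as in the updated check
logical_k = w1.shape[2] - padding_size  # what the kernel receives instead of w1.shape[2]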
python/sglang/srt/layers/fused_moe/layer.py (28 changes: 28 additions & 0 deletions)
@@ -1,9 +1,11 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
+import os
 from abc import abstractmethod
 from typing import List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -18,6 +20,7 @@
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.layers.fused_moe.fused_moe import padding_size
 from sglang.srt.utils import is_hip
 
 logger = init_logger(__name__)
@@ -506,6 +509,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
            return
 
        # If checkpoint is fp8, we need to handle that the
@@ -572,6 +588,18 @@ def process_weights_after_loading(self, layer: Module) -> None:
                    start += shard_size
 
            layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
            return
 
    def apply(
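
Both weight-processing paths in layer.py guard the padding on the same condition, a ROCm build (is_hip()) plus MOE_PADDING enabled, and call torch.cuda.empty_cache() after each re-wrap to release the unpadded tensors. For illustration only, the repeated guard could be factored into a helper; maybe_pad_moe_weight is hypothetical and not part of this commit:

# Hypothetical helper, not in the commit: same condition and padding as the
# two blocks added in process_weights_after_loading above.
import os
import torch
import torch.nn.functional as F
from sglang.srt.layers.fused_moe.fused_moe import padding_size
from sglang.srt.utils import is_hip

def maybe_pad_moe_weight(weight: torch.nn.Parameter) -> torch.nn.Parameter:
    # Pad only on ROCm and only when MOE_PADDING is explicitly enabled.
    if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
        weight = torch.nn.Parameter(
            F.pad(weight.data, (0, padding_size), "constant", 0),
            requires_grad=False,
        )
        torch.cuda.empty_cache()  # drop the unpadded copy right away
    return weight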
