[BugFix] Fix parameter names and process_after_weight_loading for W4A16 MoE Group Act Order #11528

Merged · 8 commits · Jan 23, 2025
94 changes: 60 additions & 34 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):

@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError

@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
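
The rename from intermediate_size to intermediate_size_per_partition makes explicit that each tensor-parallel rank only allocates its own shard of the expert MLP. As a rough sketch (sizes are illustrative, not taken from this PR), the per-rank shapes work out as follows:

```python
# Sketch: per-rank shapes for the unquantized fused-MoE weights, assuming
# tensor parallelism splits intermediate_size evenly across ranks.
num_experts = 8
hidden_size = 4096
intermediate_size = 14336                                        # full, pre-sharding
tp_size = 4
intermediate_size_per_partition = intermediate_size // tp_size   # 3584

# gate_up_proj is column-parallel: each rank holds its slice of both w1 and w3.
w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
# down_proj is row-parallel: the sharded dimension is the input (last) dim.
w2_shape = (num_experts, hidden_size, intermediate_size_per_partition)

assert w13_shape == (8, 7168, 4096)
assert w2_shape == (8, 4096, 3584)
```
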
@@ -289,13 +291,20 @@ def __init__(
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None

self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
moe_quant_params = {
"num_experts": num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"):
moe_quant_params["intermediate_size_full"] = intermediate_size

self.quant_method.create_weights(layer=self, **moe_quant_params)
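
Collecting the arguments in a dict makes it easy to add a key for a single quant method: only the WNA16 MoE path needs the full, pre-sharding intermediate size (its g_idx and group scales are defined over the unsharded dimension), and passing an unexpected keyword to any other create_weights() would fail. A minimal, self-contained sketch of that dispatch pattern (the class below is a stand-in, not the real implementation):

```python
# Stand-in for the real quant method, only to illustrate the conditional kwarg.
class CompressedTensorsWNA16MoEMethod:
    def create_weights(self, layer, num_experts, hidden_size,
                       intermediate_size_per_partition, params_dtype,
                       intermediate_size_full=None, **extra_weight_attrs):
        # The real method sizes g_idx / scales from intermediate_size_full
        # before sharding; here we just report whether it was received.
        return "full" if intermediate_size_full is not None else "sharded"

quant_method = CompressedTensorsWNA16MoEMethod()
intermediate_size, tp_size = 14336, 4

moe_quant_params = {
    "num_experts": 8,
    "hidden_size": 4096,
    "intermediate_size_per_partition": intermediate_size // tp_size,
    "params_dtype": None,
}
# Other quant methods do not accept this keyword, hence the class-name check.
if quant_method.__class__.__name__ == "CompressedTensorsWNA16MoEMethod":
    moe_quant_params["intermediate_size_full"] = intermediate_size

assert quant_method.create_weights(layer=None, **moe_quant_params) == "full"
```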

def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
@@ -312,19 +321,30 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
elif shard_id == "w2":
param_data[expert_id] = loaded_weight

def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
tp_rank: int,
load_full_w2: bool = False):
"""
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
# In the case where we have actorder/g_idx, we do not partition the
# w2 scales, as indicated by `load_full` argument, for all tp cases
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_rank=tp_rank,
load_full=load_full_w2)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
Expand Down Expand Up @@ -364,15 +384,21 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)

def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
def _load_w2(self,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):

# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
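
With act order / g_idx, the w2 group scales are kept whole on every TP rank, so load_full=True skips the narrow along the input dimension. A standalone sketch of the two code paths (names and shapes are illustrative):

```python
import torch

def load_w2_slice(expert_data: torch.Tensor, loaded_weight: torch.Tensor,
                  shard_dim: int, tp_rank: int, load_full: bool = False) -> None:
    """Copy either this rank's shard of the checkpoint tensor (default) or,
    when load_full is set (act-order w2 scales), the whole tensor."""
    shard_size = expert_data.shape[shard_dim]
    if not load_full:
        # Row-parallel down_proj: shard along the input dim, offset by tp_rank.
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
    expert_data.copy_(loaded_weight)

# Usage: 2 TP ranks over an intermediate dim of 8 -> rank 1 copies columns 4..7.
full = torch.arange(16, dtype=torch.float32).reshape(2, 8)
shard = torch.empty(2, 4)
load_w2_slice(shard, full, shard_dim=1, tp_rank=1)
assert torch.equal(shard, full[:, 4:8])
```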

@@ -387,8 +413,7 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):

if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
@@ -416,19 +441,19 @@ def weight_loader(self, param: torch.nn.Parameter,
]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size is used.
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()

# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
# should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = ~shard_dim
shard_dim = int(not shard_dim)

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
Expand Down Expand Up @@ -480,7 +505,8 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_rank=tp_rank,
load_full_w2=getattr(param, "load_full_w2", False))
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
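
The new load_full_w2 kwarg is driven by an attribute on the parameter itself, read with a safe default, so individual quant methods can opt specific tensors into full (unsharded) loading. Presumably the WNA16 MoE method attaches load_full_w2=True to the w2 scale parameters via set_weight_attrs when act order is enabled; that wiring is an assumption here, not shown in this diff. A small sketch of the opt-in:

```python
import torch

# Hypothetical wiring: only the w2 group scales are marked "load full".
w2_weight_scale = torch.nn.Parameter(torch.empty(2, 4, 8), requires_grad=False)
w2_weight_scale.load_full_w2 = True      # what set_weight_attrs would attach

w13_weight = torch.nn.Parameter(torch.empty(2, 8, 4), requires_grad=False)

# weight_loader reads the flag with a default, so unmarked params stay sharded.
assert getattr(w2_weight_scale, "load_full_w2", False) is True
assert getattr(w13_weight, "load_full_w2", False) is False
```
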
35 changes: 19 additions & 16 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -303,7 +303,7 @@ def __init__(self, quant_config: AWQMarlinConfig):
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
extra_weight_attrs.update({
"is_transposed":
@@ -312,17 +312,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
FusedMoeWeightScaleSupported.GROUP.value,
})

w13_qweight = Parameter(torch.empty(num_experts,
hidden_size,
2 * intermediate_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
w13_qweight = Parameter(
torch.empty(num_experts,
hidden_size,
2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)

w2_qweight = Parameter(torch.empty(num_experts,
intermediate_size,
intermediate_size_per_partition,
hidden_size //
self.quant_config.pack_factor,
dtype=torch.int32),
@@ -331,13 +332,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
set_weight_attrs(w2_qweight, extra_weight_attrs)

num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size // self.quant_config.group_size
num_groups_w2 = (intermediate_size_per_partition //
self.quant_config.group_size)

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_scales = Parameter(torch.empty(num_experts,
num_groups_w13,
intermediate_size * 2,
intermediate_size_per_partition * 2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
@@ -353,12 +355,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,

# WEIGHT_ZERO_POINT
# Allocate 2 zero points for w1 and w3 respectively.
w13_qzeros = Parameter(torch.empty(num_experts,
num_groups_w13,
2 * intermediate_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
w13_qzeros = Parameter(
torch.empty(num_experts,
num_groups_w13,
2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
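
The zero points use the same packing as the weights: one 4-bit zero per (group, output column), packed eight to an int32, so w13_qzeros carries one packed row per weight-scale group. Continuing the illustrative sizes from the sketch above, restated so this snippet stands alone:

```python
# Packed w13 zero points, with the same illustrative sizes as above.
num_experts, hidden_size = 8, 4096
intermediate_size_per_partition, group_size, pack_factor = 3584, 128, 8

w13_qzeros_shape = (num_experts,
                    hidden_size // group_size,
                    2 * intermediate_size_per_partition // pack_factor)
assert w13_qzeros_shape == (8, 32, 896)
```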
