diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 3d822fc0c7f..da0ce1885db 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError
 
@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -289,13 +291,20 @@ def __init__(
             self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
 
-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader)
+        moe_quant_params = {
+            "num_experts": num_experts,
+            "hidden_size": hidden_size,
+            "intermediate_size_per_partition":
+            self.intermediate_size_per_partition,
+            "params_dtype": params_dtype,
+            "weight_loader": self.weight_loader,
+        }
+        # WNA16 act order needs the full intermediate size before TP sharding
+        if (self.quant_method.__class__.__name__ ==
+                "CompressedTensorsWNA16MoEMethod"):
+            moe_quant_params["intermediate_size_full"] = intermediate_size
+
+        self.quant_method.create_weights(layer=self, **moe_quant_params)
 
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
@@ -312,19 +321,30 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
+    def _load_model_weight_or_group_weight_scale(self,
+                                                 shard_dim: int,
                                                  expert_data: torch.Tensor,
                                                  shard_id: str,
                                                  loaded_weight: torch.Tensor,
-                                                 tp_rank: int):
-        # Load grouped weight scales for group quantization
-        # or model weights
+                                                 tp_rank: int,
+                                                 load_full_w2: bool = False):
+        """
+        Load grouped weight scales for group quantization or model weights
+            :param shard_dim: dimension to shard
+            :param expert_data: parameter for a particular expert
+            :param shard_id: either w1, w2, or w3
+            :param loaded_weight: checkpoint weight to load into the param
+            :param tp_rank: tensor parallel rank
+            :param load_full_w2: whether to load w2 in full (no TP sharding)
+        """
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            # With actorder/g_idx the w2 scales are not partitioned across
+            # TP ranks; `load_full` signals that the full tensor is loaded
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_rank=tp_rank,
+                          load_full=load_full_w2)
         elif shard_id in ("w1", "w3"):
             self._load_w13(shard_id=shard_id,
                            shard_dim=shard_dim,
@@ -364,15 +384,21 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
             expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
         expert_data.copy_(loaded_weight)
 
-    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
 
@@ -387,8 +413,7 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                     shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
 
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)
@@ -416,7 +441,7 @@ def weight_loader(self, param: torch.nn.Parameter,
         ]
         # Fetch the dim to shard the parameter/loaded weight
         # based on the shard id. This will be whatever
-        # dimension intermediate_size is used.
+        # dimension intermediate_size_per_partition is used.
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
@@ -424,11 +449,11 @@ def weight_loader(self, param: torch.nn.Parameter,
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
-        # should be whatever dimension intermediate_size is
+        # should be whatever dimension intermediate_size_per_partition is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -480,7 +505,8 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_rank=tp_rank,
+                    load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                    param=param,
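
For context on the `shard_dim = int(not shard_dim)` change in `weight_loader`: bitwise negation gives `~0 == -1` and `~1 == -2`, which only selects the "other" axis through negative indexing on a 2-D tensor, while `int(not shard_dim)` flips 0 and 1 explicitly. A minimal standalone sketch, not part of the patch:

```python
# The transposed-layout flip: old bitwise form vs. the explicit 0 <-> 1 flip.
for shard_dim in (0, 1):
    flipped_old = ~shard_dim          # -1 or -2: relies on negative indexing
    flipped_new = int(not shard_dim)  # 1 or 0: always a valid positive dim
    print(shard_dim, flipped_old, flipped_new)
```
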
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index c28fd0c6737..0c3c9816878 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -303,7 +303,7 @@ def __init__(self, quant_config: AWQMarlinConfig):
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         extra_weight_attrs.update({
             "is_transposed":
@@ -312,17 +312,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             FusedMoeWeightScaleSupported.GROUP.value,
         })
 
-        w13_qweight = Parameter(torch.empty(num_experts,
-                                            hidden_size,
-                                            2 * intermediate_size //
-                                            self.quant_config.pack_factor,
-                                            dtype=torch.int32),
-                                requires_grad=False)
+        w13_qweight = Parameter(
+            torch.empty(num_experts,
+                        hidden_size,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qweight", w13_qweight)
         set_weight_attrs(w13_qweight, extra_weight_attrs)
 
         w2_qweight = Parameter(torch.empty(num_experts,
-                                           intermediate_size,
+                                           intermediate_size_per_partition,
                                            hidden_size //
                                            self.quant_config.pack_factor,
                                            dtype=torch.int32),
@@ -331,13 +332,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_qweight, extra_weight_attrs)
 
         num_groups_w13 = hidden_size // self.quant_config.group_size
-        num_groups_w2 = intermediate_size // self.quant_config.group_size
+        num_groups_w2 = (intermediate_size_per_partition //
+                         self.quant_config.group_size)
 
         # WEIGHT_SCALES
         # Allocate 2 scales for w1 and w3 respectively.
         w13_scales = Parameter(torch.empty(num_experts,
                                            num_groups_w13,
-                                           intermediate_size * 2,
+                                           intermediate_size_per_partition * 2,
                                            dtype=params_dtype),
                                requires_grad=False)
         layer.register_parameter("w13_scales", w13_scales)
@@ -353,12 +355,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
 
         # WEIGHT_ZERO_POINT
         # Allocate 2 zero points for w1 and w3 respectively.
-        w13_qzeros = Parameter(torch.empty(num_experts,
-                                           num_groups_w13,
-                                           2 * intermediate_size //
-                                           self.quant_config.pack_factor,
-                                           dtype=torch.int32),
-                               requires_grad=False)
+        w13_qzeros = Parameter(
+            torch.empty(num_experts,
+                        num_groups_w13,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qzeros", w13_qzeros)
         set_weight_attrs(w13_qzeros, extra_weight_attrs)
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 4fb8fd84e92..e1c45f4e42e 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -13,6 +13,7 @@
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     WNA16_SUPPORTED_BITS)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -75,24 +76,26 @@ def __init__(
         self.static_input_scales = not self.input_quant.dynamic
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -254,6 +257,7 @@ def __init__(
         self.packed_factor = 32 // config.num_bits
         self.strategy = config.strategy
         self.group_size = config.group_size
+        self.actorder = config.actorder
         assert config.symmetric, (
             "Only symmetric quantization is supported for MoE")
 
@@ -266,9 +270,16 @@ def __init__(
                              f"{WNA16_SUPPORTED_BITS}")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
+        assert params_dtype == torch.float16, (
+            "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
+        )
+
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
@@ -276,35 +287,45 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             "is_transposed": True,
             "quant_method": self.strategy
         })
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    hidden_size //
-                                                    self.packed_factor,
-                                                    2 * intermediate_size,
-                                                    dtype=torch.int32),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
                                         requires_grad=False)
         layer.register_parameter("w13_weight_packed", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   intermediate_size //
-                                                   self.packed_factor,
-                                                   hidden_size,
-                                                   dtype=torch.int32),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
                                        requires_grad=False)
         layer.register_parameter("w2_weight_packed", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+        # With actorder/g_idx the w2 scales are
+        # not partitioned across TP ranks
+        load_full_w2 = self.actorder and self.group_size != -1
+        w2_scales_size = (intermediate_size_full
+                          if load_full_w2 else intermediate_size_per_partition)
+
+        self.is_k_full = (not self.actorder) or (
+            intermediate_size_per_partition == intermediate_size_full)
+
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
             self.group_size = -1
         else:
-            num_groups_w2 = intermediate_size // self.group_size
+            num_groups_w2 = w2_scales_size // self.group_size
             num_groups_w13 = hidden_size // self.group_size
 
-        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                  num_groups_w13,
-                                                  2 * intermediate_size,
-                                                  dtype=params_dtype),
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w13_weight_scale", w13_scale)
         set_weight_attrs(w13_scale, extra_weight_attrs)
@@ -316,6 +337,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                       requires_grad=False)
         layer.register_parameter("w2_weight_scale", w2_scale)
         set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
 
         w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)
@@ -335,18 +357,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             ),
             requires_grad=False,
         )
-        layer.register_parameter("w13_g_idx", w13_g_idx)
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
         set_weight_attrs(w13_g_idx, extra_weight_attrs)
 
         w2_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
         )
-        layer.register_parameter("w2_g_idx", w2_g_idx)
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
         set_weight_attrs(w2_g_idx, extra_weight_attrs)
 
         w13_g_idx_sort_indices = torch.nn.Parameter(
@@ -364,7 +386,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         w2_g_idx_sort_indices = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -422,24 +444,55 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
         size_k2 = layer.w2_weight_packed.shape[2]
         size_k13 = layer.w13_weight_packed.shape[2]
 
-        num_experts = layer.w13_g_idx.shape[0]
-        device = layer.w13_g_idx.device
-        layer.w13_g_idx = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w2_g_idx = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
+        num_experts = layer.w13_weight_g_idx.shape[0]
+        device = layer.w13_weight_g_idx.device
+
+        # when running models with grouped act order, sort and use the
+        # g_idx values provided in the checkpoint
+        if self.actorder == "group":
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)
+
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(
+                    layer.w13_weight_g_idx[e]).to(torch.int32)
+                w2_g_idx_sort_indices[e] = torch.argsort(
+                    layer.w2_weight_g_idx[e]).to(torch.int32)
+                w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
+                    w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][
+                    w2_g_idx_sort_indices[e]]
+
+            replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices",
+                              w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices",
+                              w2_g_idx_sort_indices)
+
+        else:
+            layer.w13_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
 
         marlin_w13_qweight = ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
@@ -511,9 +564,9 @@ def apply(
             router_logits,
             topk_weights,
             topk_ids,
-            g_idx1=layer.w13_g_idx,
-            g_idx2=layer.w2_g_idx,
+            g_idx1=layer.w13_weight_g_idx,
+            g_idx2=layer.w2_weight_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
-        )
+            is_k_full=self.is_k_full)
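
A condensed sketch of how the act-order flags introduced in `CompressedTensorsWNA16MoEMethod` interact; the values below are illustrative, and in the real method they come from the quant config and the TP sharding:

```python
actorder, group_size = "group", 128
intermediate_size_full = 14336
intermediate_size_per_partition = intermediate_size_full // 2  # TP size 2

# With grouped act order the w2 scales stay at the full, unsharded size.
load_full_w2 = bool(actorder) and group_size != -1             # True here
w2_scales_size = (intermediate_size_full
                  if load_full_w2 else intermediate_size_per_partition)

# is_k_full tells the Marlin kernel whether the reduction dim is unsharded.
is_k_full = (not actorder) or (
    intermediate_size_per_partition == intermediate_size_full)  # False here
print(load_full_w2, w2_scales_size, is_k_full)
```
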
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
index 61d1c911cd1..2e1b5e3c2d3 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -62,7 +62,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        **kwargs):
 
         assert params_dtype == torch.float16, (
-            "float16 is required for marlin24 compressd models. Set dtype=torch.float16"  # noqa: E501
+            "float16 is required for marlin24 compressed models. Set dtype=torch.float16"  # noqa: E501
         )
 
         pack_factor = 32 // self.quant_type.size_bits
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 209f12c6dfe..100cbfa4c95 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -52,7 +52,7 @@ def __init__(self, quant_config: ExpertsInt8Config):
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
         int8_dtype = torch.int8
@@ -64,26 +64,29 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         extra_weight_attrs['weight_loader'] = wrapped_weight_loader
 
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=int8_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=int8_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=int8_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=int8_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-        w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
-                                                   2 * intermediate_size,
-                                                   dtype=torch.float32),
+        w13_scale = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            dtype=torch.float32),
                                        requires_grad=False)
         layer.register_parameter("w13_scale", w13_scale)
 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 4969ee55952..5af9ec400bc 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -377,8 +377,8 @@ def __init__(self, quant_config: Fp8Config):
         self.block_quant = self.quant_config.weight_block_size is not None
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
-                       intermediate_size: int, params_dtype: torch.dtype,
-                       **extra_weight_attrs):
+                       intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -393,30 +393,34 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             # scales, the output_size of the weights for both the gate and up
             # layers must be divisible by block_n.
             # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
+            if intermediate_size_per_partition % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_n = {block_n}.")
-            if (tp_size > 1 and intermediate_size % block_k != 0):
+            if (tp_size > 1
+                    and intermediate_size_per_partition % block_k != 0):
                 # Required by row parallel
-                raise ValueError(f"The input_size of down's weight = "
-                                 f"{intermediate_size} is not divisible by "
-                                 f"weight quantization block_k = {block_k}.")
+                raise ValueError(
+                    f"The input_size of down's weight = "
+                    f"{intermediate_size_per_partition} is not divisible by "
+                    f"weight quantization block_k = {block_k}.")
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -437,7 +441,8 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
                     num_experts,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    2 * ((intermediate_size_per_partition + block_n - 1) //
+                         block_n),
                     (hidden_size + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
@@ -447,7 +452,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                 torch.ones(
                     num_experts,
                     (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
+                    (intermediate_size_per_partition + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
                 requires_grad=False,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 2dbfca9b076..4dc4b052b04 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -317,7 +317,7 @@ def create_weights(
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -326,7 +326,8 @@ def create_weights(
         # Supports only sym for now (no zp)
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = intermediate_size // self.quant_config.group_size
+            scales_size2 = (intermediate_size_per_partition //
+                            self.quant_config.group_size)
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
             scales_size13 = 1
@@ -342,7 +343,7 @@ def create_weights(
             torch.empty(
                 num_experts,
                 hidden_size // self.quant_config.pack_factor,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -353,7 +354,8 @@ def create_weights(
         w2_qweight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size // self.quant_config.pack_factor,
+                intermediate_size_per_partition //
+                self.quant_config.pack_factor,
                 hidden_size,
                 dtype=torch.int32,
             ),
@@ -365,7 +367,7 @@ def create_weights(
         w13_scales = torch.nn.Parameter(
             torch.empty(num_experts,
                         scales_size13,
-                        2 * intermediate_size,
+                        2 * intermediate_size_per_partition,
                         dtype=torch.half),
             requires_grad=False,
         )
@@ -385,7 +387,8 @@ def create_weights(
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,
                         scales_size13,
-                        2 * intermediate_size // self.quant_config.pack_factor,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
                         dtype=params_dtype),
             requires_grad=False,
         )
@@ -414,7 +417,7 @@ def create_weights(
         w2_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -435,7 +438,7 @@ def create_weights(
         w2_g_idx_sort_indices = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 3e192473008..68a39545407 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -60,24 +60,26 @@ def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
         self.static_input_scales = not self.input_quant.get("is_dynamic")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
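
Finally, the layout that every `create_weights` touched by this diff shares, collected in one place; the sizes are illustrative and the tensors are plain placeholders rather than registered parameters:

```python
import torch

num_experts, hidden_size = 4, 1024
intermediate_size_per_partition = 2048

# w13 fuses gate_proj and up_proj along the output dim (column parallel),
# w2 is down_proj (row parallel); the intermediate dim is per TP partition.
w13_weight = torch.empty(num_experts,
                         2 * intermediate_size_per_partition,
                         hidden_size)
w2_weight = torch.empty(num_experts,
                        hidden_size,
                        intermediate_size_per_partition)
print(w13_weight.shape, w2_weight.shape)
```
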