Skip to content
39 changes: 37 additions & 2 deletions pymllm/executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,20 +600,46 @@ def load_model(self) -> None:
quant_config = self._resolve_quant_config()

device_str = f"cuda:{self.gpu_id}" if self.device == "cuda" else self.device
requires_cpu_first_loading = bool(
getattr(model_cls, "requires_cpu_first_weight_loading", False)
)
instantiate_device_str = "cpu" if requires_cpu_first_loading else device_str
if instantiate_device_str != device_str:
logger.info(
"Using CPU-first model instantiation for %s before weight loading. "
"runtime_device=%s",
model_cls.__name__,
device_str,
)

# Use set_default_dtype so parameters created without explicit dtype
# get the target dtype, while parameters with explicit dtype=torch.float32
# (e.g. A_log, dt_bias in GDN layers) stay in float32.
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.dtype)
try:
with torch.device(device_str):
with torch.device(instantiate_device_str):
if quant_config is not None:
self.model = model_cls(hf_config, quant_config=quant_config)
else:
self.model = model_cls(hf_config)
finally:
torch.set_default_dtype(old_dtype)
self.model.load_weights(self._iter_weights(model_path))
use_model_path_loader = hasattr(self.model, "load_weights_from_model_path")
if use_model_path_loader:
loader_flag = getattr(self.model, "use_model_path_weight_loader", True)
if callable(loader_flag):
loader_flag = loader_flag()
use_model_path_loader = bool(loader_flag)

if use_model_path_loader:
logger.info(
"Using model-specific weight loader: %s.load_weights_from_model_path",
type(self.model).__name__,
)
self.model.load_weights_from_model_path(model_path)
else:
self.model.load_weights(self._iter_weights(model_path))

# Post-load processing: let each quantization method repack/transform
# weights from checkpoint format to runtime format (e.g. AWQ → Marlin,
Expand All @@ -623,6 +649,14 @@ def load_model(self) -> None:
if quant_method is not None and hasattr(quant_method, "process_weights_after_loading"):
quant_method.process_weights_after_loading(module)

move_compute_modules = getattr(self.model, "move_compute_modules_to_device", None)
if callable(move_compute_modules) and self.device == "cuda":
logger.info(
"Moving model compute modules to runtime device after weight loading: %s",
device_str,
)
move_compute_modules(torch.device(device_str))

self.model.eval()

after_mem = get_available_gpu_memory(self.device, self.gpu_id)
Expand Down Expand Up @@ -893,6 +927,7 @@ def init_attention_backend(self) -> None:
req_to_token=self.req_to_token_pool.req_to_token,
device=torch.device(self.device),
max_req_pool_size=self.req_to_token_pool.size,
sliding_window_size=self.sliding_window_size,
)

if self.gdn_pool is not None:
Expand Down
8 changes: 8 additions & 0 deletions pymllm/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from pymllm.layers.linear import ColumnParallelLinear, Linear, RowParallelLinear
from pymllm.layers.mlp import MLP, ParallelMLP
from pymllm.layers.rms_norm import GemmaRMSNorm, RMSNorm
from pymllm.layers.gemma3n import (
Gemma3nMLP,
Gemma3nRMSNorm,
Gemma3nRMSNormNoWeight,
)
from pymllm.layers.rms_norm_gated import RMSNormGated
from pymllm.layers.gated_delta_net import GatedDeltaNet
from pymllm.layers.rope import (
Expand Down Expand Up @@ -62,4 +67,7 @@
"top_k_renorm_probs",
"top_k_mask_logits",
"chain_speculative_sampling",
"Gemma3nMLP",
"Gemma3nRMSNorm",
"Gemma3nRMSNormNoWeight",
]
128 changes: 128 additions & 0 deletions pymllm/layers/gemma3n.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter

from pymllm.layers.base import MllmBaseLayer
from pymllm.layers.linear import Linear
from pymllm.layers.utils import set_weight_attrs


def _get_gemma3n_hidden_act_fn(name: str):
name = (name or "silu").lower()
if name in ("silu", "swish"):
return F.silu
if name == "relu":
return F.relu
if name == "gelu":
return F.gelu
if name in ("gelu_tanh", "gelu_pytorch_tanh"):
return lambda x: F.gelu(x, approximate="tanh")
raise ValueError(f"Unsupported Gemma3n activation: {name}")


class Gemma3nRMSNorm(MllmBaseLayer):
"""Gemma3n RMSNorm used by the native text-only implementation.

This intentionally preserves the numerics of the verified Gemma3n path.
It is not replaced by the generic ``GemmaRMSNorm`` because the generic
FlashInfer-backed layer has different behavior for this checkpoint path.
"""

def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.weight = Parameter(torch.empty(hidden_size))
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})

def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
):
if x.shape[-1] != self.hidden_size:
raise ValueError(
f"Expected last dim == hidden_size ({self.hidden_size}), "
f"but got input shape {tuple(x.shape)}"
)

if residual is not None:
residual = residual + x
x = residual

out = self._norm(x.float()) * self.weight.float()
out = out.to(dtype=x.dtype)

if residual is not None:
return out, residual
return out


class Gemma3nRMSNormNoWeight(MllmBaseLayer):
"""Weight-free RMSNorm used by Gemma3n value normalization."""

def __init__(self, eps: float = 1e-6):
super().__init__()
self.eps = eps

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x_fp32 = x.float()
var = x_fp32.pow(2).mean(dim=-1, keepdim=True)
out = x_fp32 * torch.rsqrt(var + self.eps)
return out.to(x_dtype)


class Gemma3nMLP(MllmBaseLayer):
"""Gemma3n feed-forward block.

Uses pymllm ``Linear`` projections while preserving Gemma3n's optional
activation sparsity branch, which is not implemented by the generic
``pymllm.layers.MLP``.
"""

def __init__(
self,
hidden_size: int,
intermediate_size: int,
activation: str,
activation_sparsity: float = 0.0,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.activation_name = activation
self.act = _get_gemma3n_hidden_act_fn(activation)
self.activation_sparsity = float(activation_sparsity)

self.gate_proj = Linear(hidden_size, intermediate_size, bias=False)
self.up_proj = Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = Linear(intermediate_size, hidden_size, bias=False)

def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
target_sparsity_tensor = torch.tensor(
self.activation_sparsity,
dtype=torch.float32,
device=inputs.device,
)
normal_dist = torch.distributions.normal.Normal(0, 1)
std_multiplier = normal_dist.icdf(target_sparsity_tensor).to(dtype=inputs.dtype)
inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
cutoff_x = inputs_mean + inputs_std * std_multiplier
return F.relu(inputs - cutoff_x)

def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
if self.activation_sparsity > 0.0:
gate = self._gaussian_topk(gate)
gate = self.act(gate)
up = self.up_proj(x)
hidden = gate * up
return self.down_proj(hidden)
8 changes: 8 additions & 0 deletions pymllm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
"pymllm.models.qwen3_5",
"Qwen3_5ForConditionalGeneration",
),
"Gemma3nForCausalLM": (
"pymllm.models.gemma3n",
"Gemma3nForCausalLM",
),
"Gemma3nForConditionalGeneration": (
"pymllm.models.gemma3n",
"Gemma3nForConditionalGeneration",
),
}


Expand Down
Loading