From 35dce028a6acc70fb7388b287bf49c9bf9125f5d Mon Sep 17 00:00:00 2001 From: grape203 Date: Fri, 1 May 2026 13:55:28 +0800 Subject: [PATCH 01/12] Add Gemma3n text-only native server path --- pymllm/executor/model_runner.py | 39 +- pymllm/models/__init__.py | 8 + pymllm/models/gemma3n.py | 1165 +++++++++++++++++++ pymllm/orchestrator/model_runner_process.py | 10 +- 4 files changed, 1219 insertions(+), 3 deletions(-) create mode 100644 pymllm/models/gemma3n.py diff --git a/pymllm/executor/model_runner.py b/pymllm/executor/model_runner.py index 2178afa9..5289f78a 100644 --- a/pymllm/executor/model_runner.py +++ b/pymllm/executor/model_runner.py @@ -600,20 +600,55 @@ def load_model(self) -> None: quant_config = self._resolve_quant_config() device_str = f"cuda:{self.gpu_id}" if self.device == "cuda" else self.device + requires_cpu_first_loading = bool( + getattr(model_cls, "requires_cpu_first_weight_loading", False) + ) + if not requires_cpu_first_loading: + import os + + gemma3n_native_mode = ( + architecture in {"Gemma3nForCausalLM", "Gemma3nForConditionalGeneration"} + and os.environ.get("MLLM_GEMMA3N_NATIVE", "0") == "1" + ) + requires_cpu_first_loading = gemma3n_native_mode + + instantiate_device_str = "cpu" if requires_cpu_first_loading else device_str + if instantiate_device_str != device_str: + logger.info( + "Using CPU-first model instantiation for %s before weight loading. " + "runtime_device=%s", + model_cls.__name__, + device_str, + ) + # Use set_default_dtype so parameters created without explicit dtype # get the target dtype, while parameters with explicit dtype=torch.float32 # (e.g. A_log, dt_bias in GDN layers) stay in float32. old_dtype = torch.get_default_dtype() torch.set_default_dtype(self.dtype) try: - with torch.device(device_str): + with torch.device(instantiate_device_str): if quant_config is not None: self.model = model_cls(hf_config, quant_config=quant_config) else: self.model = model_cls(hf_config) finally: torch.set_default_dtype(old_dtype) - self.model.load_weights(self._iter_weights(model_path)) + use_model_path_loader = hasattr(self.model, "load_weights_from_model_path") + if use_model_path_loader: + loader_flag = getattr(self.model, "use_model_path_weight_loader", True) + if callable(loader_flag): + loader_flag = loader_flag() + use_model_path_loader = bool(loader_flag) + + if use_model_path_loader: + logger.info( + "Using model-specific weight loader: %s.load_weights_from_model_path", + type(self.model).__name__, + ) + self.model.load_weights_from_model_path(model_path) + else: + self.model.load_weights(self._iter_weights(model_path)) # Post-load processing: let each quantization method repack/transform # weights from checkpoint format to runtime format (e.g. 
AWQ → Marlin, diff --git a/pymllm/models/__init__.py b/pymllm/models/__init__.py index 7751b309..3b207ad4 100644 --- a/pymllm/models/__init__.py +++ b/pymllm/models/__init__.py @@ -30,6 +30,14 @@ "pymllm.models.qwen3_5", "Qwen3_5ForConditionalGeneration", ), + "Gemma3nForCausalLM": ( + "pymllm.models.gemma3n", + "Gemma3nForCausalLM", + ), + "Gemma3nForConditionalGeneration": ( + "pymllm.models.gemma3n", + "Gemma3nForConditionalGeneration", + ), } diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py new file mode 100644 index 00000000..5c685809 --- /dev/null +++ b/pymllm/models/gemma3n.py @@ -0,0 +1,1165 @@ +from __future__ import annotations + +import logging +from typing import Any, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + + +def _get_text_config(config): + """Extract text config from multimodal config, or return config as-is.""" + return getattr(config, "text_config", config) + + +def _get_layer_types(config) -> List[str]: + """Return per-layer type list for Gemma 3n.""" + tc = _get_text_config(config) + if hasattr(tc, "layer_types") and tc.layer_types is not None: + return tc.layer_types + + n_layers = tc.num_hidden_layers + return [ + "full_attention" if (i + 1) % 5 == 0 else "sliding_attention" + for i in range(n_layers) + ] + + +def _get_hidden_act_fn(name: str): + if name == "silu": + return F.silu + if name == "gelu": + return lambda x: F.gelu(x, approximate="none") + if name in ("gelu_tanh", "gelu_pytorch_tanh"): + return lambda x: F.gelu(x, approximate="tanh") + raise ValueError(f"Unsupported hidden_act: {name}") + + +def _get_intermediate_size(config, layer_id: int) -> int: + tc = _get_text_config(config) + intermediate_size = tc.intermediate_size + if isinstance(intermediate_size, int): + return intermediate_size + return int(intermediate_size[layer_id]) + + +class SimpleGemmaRMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ): + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + + if residual is not None: + residual = residual + x + x = residual + + out = self._norm(x.float()) * self.weight.float() + out = out.to(dtype=x.dtype) + + if residual is not None: + return out, residual + return out + + + +class SimpleRMSNormNoWeight(nn.Module): + def __init__(self, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x_fp32 = x.float() + var = x_fp32.pow(2).mean(dim=-1, keepdim=True) + out = x_fp32 * torch.rsqrt(var + self.eps) + return out.to(x_dtype) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _build_rope_cos_sin( + positions: torch.Tensor, + head_dim: int, + base: float, + device, + dtype, +): + # positions: [B, S] + inv_freq = 1.0 / ( + base ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim) + ) + freqs = positions[:, :, None].float() * inv_freq[None, None, :] + emb = 
torch.cat([freqs, freqs], dim=-1) + cos = emb.cos().to(dtype) + sin = emb.sin().to(dtype) + return cos, sin + + +def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + # x: [B, H, S, D], cos/sin: [B, S, D] + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + return (x * cos) + (_rotate_half(x) * sin) + + +class SimpleMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + activation: str, + activation_sparsity: float = 0.0, + ): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.activation_name = activation + self.act = _get_hidden_act_fn(activation) + self.activation_sparsity = float(activation_sparsity) + + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor: + target_sparsity_tensor = torch.tensor( + self.activation_sparsity, + dtype=torch.float32, + device=inputs.device, + ) + normal_dist = torch.distributions.normal.Normal(0, 1) + std_multiplier = normal_dist.icdf(target_sparsity_tensor).to(dtype=inputs.dtype) + inputs_mean = torch.mean(inputs, dim=-1, keepdim=True) + inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False) + cutoff_x = inputs_mean + inputs_std * std_multiplier + return F.relu(inputs - cutoff_x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = self.gate_proj(x) + if self.activation_sparsity > 0.0: + gate = self._gaussian_topk(gate) + gate = self.act(gate) + up = self.up_proj(x) + hidden = gate * up + return self.down_proj(hidden) + + + +class SimpleScaledWordEmbedding(nn.Embedding): + def __init__(self, num_embeddings: int, embedding_dim: int, embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.scalar_embed_scale = float(embed_scale) + self.register_buffer("embed_scale", torch.tensor(float(embed_scale)), persistent=False) + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) + + +class SimpleLaurelBlock(nn.Module): + def __init__(self, config): + super().__init__() + tc = _get_text_config(config) + laurel_rank = getattr(tc, "laurel_rank", 8) + self.linear_left = nn.Linear(tc.hidden_size, laurel_rank, bias=False) + self.linear_right = nn.Linear(laurel_rank, tc.hidden_size, bias=False) + self.post_laurel_norm = SimpleGemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + laurel_hidden_states = self.linear_left(hidden_states) + laurel_hidden_states = self.linear_right(laurel_hidden_states) + normed = self.post_laurel_norm(laurel_hidden_states) + return hidden_states + normed + + +class SimpleAltUp(nn.Module): + def __init__(self, config): + super().__init__() + tc = _get_text_config(config) + self.config = tc + self.altup_num_inputs = getattr(tc, "altup_num_inputs", 2) + self.altup_active_idx = getattr(tc, "altup_active_idx", 0) + self.correct_output_scale = nn.Parameter(torch.zeros(tc.hidden_size)) + self.correction_coefs = nn.Linear(self.altup_num_inputs, self.altup_num_inputs, bias=False) + self.prediction_coefs = nn.Linear(self.altup_num_inputs, self.altup_num_inputs ** 2, bias=False) + self.modality_router = nn.Linear(tc.hidden_size, self.altup_num_inputs, bias=False) + self.router_norm = SimpleGemmaRMSNorm(tc.hidden_size, 
eps=tc.rms_norm_eps) + self.register_buffer("router_input_scale", torch.tensor(tc.hidden_size ** -1.0), persistent=False) + + def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: + router_inputs = self.router_norm(x) * self.router_input_scale.to(dtype=x.dtype, device=x.device) + routed = self.modality_router(router_inputs) + return torch.tanh(routed.float()).to(dtype=x.dtype, device=x.device) + + def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: + # hidden_states: [P, B, S, H] + modalities = self.compute_router_modalities(hidden_states[self.altup_active_idx]) + all_coefs = self.prediction_coefs(modalities).reshape( + *modalities.shape[:-1], self.altup_num_inputs, self.altup_num_inputs + ) + all_coefs = all_coefs.permute(0, 1, 3, 2) # [B, S, P, P] + predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs) + predictions = predictions.permute(3, 0, 1, 2).contiguous() + predictions = predictions + hidden_states + return predictions.to(dtype=hidden_states.dtype) + + def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor: + # predictions: [P, B, S, H], activated: [B, S, H] + modalities = self.compute_router_modalities(activated) + innovation = activated - predictions[self.altup_active_idx] + innovation = innovation.repeat(self.altup_num_inputs, 1, 1, 1) + + all_coefs = self.correction_coefs(modalities) + 1.0 + all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1) # [P, B, S, 1] + + corrected = innovation * all_coefs + corrected = corrected + predictions + return corrected.to(dtype=activated.dtype) + + def forward(self, corrected: torch.Tensor) -> torch.Tensor: + return (corrected.to(dtype=self.correct_output_scale.dtype) * self.correct_output_scale).to(dtype=corrected.dtype) + + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: + return self.forward(corrected) + + +class Gemma3nAttention(nn.Module): + """Text-only Gemma 3n attention v1. + + Implements the text attention path used by the native Gemma3n + development mode, including layer-specific RoPE parameters and + full/sliding causal attention masks. 
+ """ + + def __init__(self, config, layer_id: int): + super().__init__() + tc = _get_text_config(config) + self.layer_id = layer_id + self.layer_type = _get_layer_types(config)[layer_id] + + self.hidden_size = tc.hidden_size + self.num_heads = tc.num_attention_heads + self.num_kv_heads = tc.num_key_value_heads + self.head_dim = getattr(tc, "head_dim", self.hidden_size // self.num_heads) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.q_proj = nn.Linear(self.hidden_size, self.q_size, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.kv_size, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.kv_size, bias=False) + self.o_proj = nn.Linear(self.q_size, self.hidden_size, bias=False) + + self.q_norm = SimpleGemmaRMSNorm(self.head_dim) + self.k_norm = SimpleGemmaRMSNorm(self.head_dim) + self.v_norm = SimpleRMSNormNoWeight(eps=getattr(tc, "rms_norm_eps", 1e-6)) + + self.sliding_window = getattr(tc, "sliding_window", None) + + num_kv_shared_layers = int(getattr(tc, "num_kv_shared_layers", 0)) + first_kv_shared_layer_idx = int(tc.num_hidden_layers) - num_kv_shared_layers + layer_types = _get_layer_types(config) + prev_layers = layer_types[:first_kv_shared_layer_idx] + + self.is_kv_shared_layer = layer_id >= first_kv_shared_layer_idx > 0 + if self.is_kv_shared_layer: + self.kv_shared_layer_index = ( + len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type) + ) + self.store_full_length_kv = False + else: + self.kv_shared_layer_index = None + if prev_layers and self.layer_type in prev_layers: + self.store_full_length_kv = ( + layer_id == len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type) + ) + else: + self.store_full_length_kv = False + + rope_parameters = getattr(tc, "rope_parameters", None) + rope_theta = None + if isinstance(rope_parameters, dict): + layer_rope = rope_parameters.get(self.layer_type, {}) + if isinstance(layer_rope, dict): + rope_theta = layer_rope.get("rope_theta", None) + + if rope_theta is None: + if self.layer_type == "sliding_attention": + rope_theta = getattr(tc, "rope_local_base_freq", 10000.0) + else: + rope_theta = getattr(tc, "rope_theta", 10000.0) + + self.rope_theta = float(rope_theta) + + def _build_attention_mask( + self, + positions: torch.Tensor, + device, + ) -> torch.Tensor: + # positions: [B, S] + q_pos = positions[:, :, None] # [B, S, 1] + k_pos = positions[:, None, :] # [B, 1, S] + + # causal + mask = k_pos <= q_pos + + # local sliding window for sliding_attention layers + if self.layer_type == "sliding_attention" and self.sliding_window is not None: + lower_bound = q_pos - (int(self.sliding_window) - 1) + mask = mask & (k_pos >= lower_bound) + + return mask + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: Any, + ) -> torch.Tensor: + positions = positions.to(device=hidden_states.device, non_blocking=True) + batch_size, seq_len, _ = hidden_states.shape + + q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim) + q = self.q_norm(q).transpose(1, 2) # [B, H, S, D] + + cos, sin = _build_rope_cos_sin( + positions=positions, + head_dim=self.head_dim, + base=self.rope_theta, + device=hidden_states.device, + dtype=q.dtype, + ) + q = _apply_rope(q, cos, sin) + + shared_kv_cache = None + if isinstance(forward_batch, dict): + shared_kv_cache = forward_batch.get("kv_shared_cache") + elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"): + shared_kv_cache = 
getattr(forward_batch, "kv_shared_cache") + + if ( + self.is_kv_shared_layer + and shared_kv_cache is not None + and self.kv_shared_layer_index in shared_kv_cache + ): + k, v = shared_kv_cache[self.kv_shared_layer_index] + k = k.to(device=q.device, dtype=q.dtype, non_blocking=True) + v = v.to(device=q.device, dtype=q.dtype, non_blocking=True) + else: + k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + k = self.k_norm(k).transpose(1, 2) # [B, KV, S, D] + k = _apply_rope(k, cos, sin) + + v = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = self.v_norm(v).transpose(1, 2) # [B, KV, S, D] + + if shared_kv_cache is not None and self.store_full_length_kv: + shared_kv_cache[self.layer_id] = (k, v) + + if self.num_heads != self.num_kv_heads: + n_rep = self.num_heads // self.num_kv_heads + k = k[:, :, None, :, :].expand(batch_size, self.num_kv_heads, n_rep, seq_len, self.head_dim) + v = v[:, :, None, :, :].expand(batch_size, self.num_kv_heads, n_rep, seq_len, self.head_dim) + k = k.reshape(batch_size, self.num_heads, seq_len, self.head_dim) + v = v.reshape(batch_size, self.num_heads, seq_len, self.head_dim) + + attn_scores = torch.matmul(q, k.transpose(-2, -1)) + + attn_mask = self._build_attention_mask(positions, hidden_states.device) # [B, S, S] + mask_value = torch.finfo(attn_scores.dtype).min + attn_scores = attn_scores.masked_fill(~attn_mask[:, None, :, :], mask_value) + + attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(attn_scores.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, seq_len, self.q_size) + + return self.o_proj(attn_output) + + +class Gemma3nDecoderLayer(nn.Module): + def __init__(self, config, layer_id: int): + super().__init__() + tc = _get_text_config(config) + + self.self_attn = Gemma3nAttention(config, layer_id) + activation_sparsity_pattern = getattr(tc, "activation_sparsity_pattern", None) + if isinstance(activation_sparsity_pattern, (list, tuple)): + layer_activation_sparsity = float(activation_sparsity_pattern[layer_id]) + else: + layer_activation_sparsity = float(getattr(tc, "activation_sparsity", 0.0)) + + self.mlp = SimpleMLP( + hidden_size=tc.hidden_size, + intermediate_size=_get_intermediate_size(config, layer_id), + activation=getattr(tc, "hidden_activation", getattr(tc, "hidden_act", "silu")), + activation_sparsity=layer_activation_sparsity, + ) + self.input_layernorm = SimpleGemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + self.post_attention_layernorm = SimpleGemmaRMSNorm( + tc.hidden_size, eps=tc.rms_norm_eps + ) + self.pre_feedforward_layernorm = SimpleGemmaRMSNorm( + tc.hidden_size, eps=tc.rms_norm_eps + ) + self.post_feedforward_layernorm = SimpleGemmaRMSNorm( + tc.hidden_size, eps=tc.rms_norm_eps + ) + + self.altup = SimpleAltUp(config) + self.laurel = SimpleLaurelBlock(config) + self.hidden_size_per_layer_input = getattr(tc, "hidden_size_per_layer_input", 8) + self.per_layer_input_gate = nn.Linear( + tc.hidden_size, self.hidden_size_per_layer_input, bias=False + ) + self.per_layer_projection = nn.Linear( + self.hidden_size_per_layer_input, tc.hidden_size, bias=False + ) + self.post_per_layer_input_norm = SimpleGemmaRMSNorm( + tc.hidden_size, eps=tc.rms_norm_eps + ) + self.per_layer_input_act = _get_hidden_act_fn( + getattr(tc, "hidden_activation", getattr(tc, "hidden_act", "silu")) + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, # 
[P, B, S, H] + residual: Optional[torch.Tensor], + forward_batch: Any, + per_layer_input: Optional[torch.Tensor] = None, # [B, S, Pdim] + ): + del residual + predictions = self.altup.predict(hidden_states) + active_prediction = predictions[self.altup.altup_active_idx] # [B, S, H] + + active_prediction_normed = self.input_layernorm(active_prediction) + laurel_output = self.laurel(active_prediction_normed) + + attn_output = self.self_attn(positions, active_prediction_normed, forward_batch) + attn_output = self.post_attention_layernorm(attn_output) + + attn_gated = active_prediction + attn_output + attn_laurel = (attn_gated + laurel_output) / (2.0 ** 0.5) + + attn_norm = self.pre_feedforward_layernorm(attn_laurel) + attn_ffw = self.mlp(attn_norm) + attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw) + attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm + + corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) + + first_prediction = corrected_predictions[self.altup.altup_active_idx].clone() + if getattr(self.altup.config, "altup_correct_scale", False): + first_prediction = self.altup.scale_corrected_output(first_prediction) + + if per_layer_input is not None: + first_prediction = self.per_layer_input_gate(first_prediction) + first_prediction = self.per_layer_input_act(first_prediction) + first_prediction = first_prediction * per_layer_input + first_prediction = self.per_layer_projection(first_prediction) + first_prediction = self.post_per_layer_input_norm(first_prediction) + if corrected_predictions.shape[0] > 1: + corrected_predictions[1:] = corrected_predictions[1:] + first_prediction.unsqueeze(0) + + return corrected_predictions + + +class Gemma3nModel(nn.Module): + def __init__(self, config): + super().__init__() + tc = _get_text_config(config) + + self.config = config + self.embed_tokens = SimpleScaledWordEmbedding(tc.vocab_size, tc.hidden_size, embed_scale=tc.hidden_size ** 0.5) + self.layers = nn.ModuleList( + [Gemma3nDecoderLayer(config, i) for i in range(tc.num_hidden_layers)] + ) + self.norm = SimpleGemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + + # Text-only modules whose names match the official Gemma3n checkpoint. + self.hidden_size_per_layer_input = getattr(tc, "hidden_size_per_layer_input", 8) + self.vocab_size_per_layer_input = getattr(tc, "vocab_size_per_layer_input", tc.vocab_size) + self.altup_num_inputs = getattr(tc, "altup_num_inputs", 2) + + self.embed_tokens_per_layer = SimpleScaledWordEmbedding( + self.vocab_size_per_layer_input, + tc.num_hidden_layers * self.hidden_size_per_layer_input, + embed_scale=self.hidden_size_per_layer_input ** 0.5, + ) + self.per_layer_model_projection = nn.Linear( + tc.hidden_size, + tc.num_hidden_layers * self.hidden_size_per_layer_input, + bias=False, + ) + self.per_layer_projection_norm = SimpleGemmaRMSNorm( + self.hidden_size_per_layer_input, eps=tc.rms_norm_eps + ) + + self.altup_projections = nn.ModuleList( + [nn.Linear(tc.hidden_size, tc.hidden_size, bias=False) for _ in range(1, self.altup_num_inputs)] + ) + self.altup_unembed_projections = nn.ModuleList( + [nn.Linear(tc.hidden_size, tc.hidden_size, bias=False) for _ in range(1, self.altup_num_inputs)] + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: Any = None, + ): + model_device = self.norm.weight.device + # Native Gemma3nModel accepts either flat mllm tensors [T] + # or batched tensors [B, T]. Normalize them once here. 
+ if input_ids.dim() == 1: + input_ids_hf = input_ids.unsqueeze(0) + else: + input_ids_hf = input_ids + + if positions is None: + positions = torch.arange( + input_ids_hf.shape[1], + dtype=torch.long, + device=input_ids_hf.device, + ).unsqueeze(0) + elif positions.dim() == 1: + positions = positions.unsqueeze(0) + + # Server/scheduler tensors may arrive on CUDA even when Gemma3n + # is intentionally instantiated on CPU for memory reasons. + # Keep position tensors on the same device as the native model. + positions = positions.to(device=model_device, non_blocking=True) + + embed_input_ids = input_ids_hf.to( + device=self.embed_tokens.weight.device, + non_blocking=True, + ) + hidden_states_0 = self.embed_tokens(embed_input_ids).to( + device=model_device, + non_blocking=True, + ) + batch_size, seq_len, _ = hidden_states_0.shape + num_layers = len(self.layers) + + per_layer_input_ids = input_ids_hf.to( + device=self.embed_tokens_per_layer.weight.device, + non_blocking=True, + ) + per_layer_inputs = self.embed_tokens_per_layer(per_layer_input_ids) + per_layer_inputs = per_layer_inputs.to( + device=hidden_states_0.device, + dtype=hidden_states_0.dtype, + non_blocking=True, + ).reshape( + batch_size, + seq_len, + num_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_model_projection(hidden_states_0) + per_layer_projection = per_layer_projection * (hidden_states_0.shape[-1] ** -0.5) + per_layer_projection = per_layer_projection.reshape( + batch_size, + seq_len, + num_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + per_layer_inputs = (per_layer_projection + per_layer_inputs) * (2.0 ** -0.5) + + target_magnitude = torch.mean(hidden_states_0 ** 2, dim=-1, keepdim=True).clamp_min(1e-5).sqrt() + + temp_hidden_states = [hidden_states_0] + for i in range(1, self.altup_num_inputs): + altup_proj = self.altup_projections[i - 1](hidden_states_0) + current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=hidden_states_0.device) + new_magnitude = torch.mean(current_hidden_state ** 2, dim=-1, keepdim=True).clamp_min(1e-5).sqrt() + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude + temp_hidden_states.append(current_hidden_state) + + hidden_states = torch.stack(temp_hidden_states, dim=0) # [P, B, S, H] + + native_forward_batch = {"kv_shared_cache": {}} + + for layer_idx, layer in enumerate(self.layers): + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + residual=None, + forward_batch=native_forward_batch, + per_layer_input=per_layer_inputs[:, :, layer_idx, :], + ) + + target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True).clamp_min(1e-5).sqrt() + temp_hidden_states = [hidden_states[0]] + for i in range(1, self.altup_num_inputs): + altup_unemb = self.altup_unembed_projections[i - 1](hidden_states[i]) + current_hidden_state = altup_unemb.to(dtype=hidden_states_0.dtype, device=hidden_states_0.device) + new_magnitude = torch.mean(current_hidden_state ** 2, dim=-1, keepdim=True).clamp_min(1e-5).sqrt() + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude + temp_hidden_states.append(current_hidden_state) + + hidden_states = torch.stack(temp_hidden_states, dim=0) + hidden_states = torch.mean(hidden_states, dim=0) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Gemma3nForCausalLM(nn.Module): + """Text-only Gemma3n Causal LM wrapper. 
+ + The default server path uses the official HF text model and keeps the main + token embedding and lm_head on CUDA, while offloading the large per-layer + embedding table to CPU during weight loading. Native mode can be enabled + for development with MLLM_GEMMA3N_NATIVE=1. + """ + + requires_cpu_first_weight_loading = False + + def __init__(self, config, quant_config=None, prefix: str = ""): + super().__init__() + text_config = _get_text_config(config) + text_config._attn_implementation = "eager" + + import os + self.use_native = os.environ.get("MLLM_GEMMA3N_NATIVE", "0") == "1" + self.use_model_path_weight_loader = self.use_native + if self.use_native: + self.config = text_config + self.quant_config = quant_config + self.prefix = prefix + self.model = Gemma3nModel(config) + self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False) + self._hf_past_key_values = None + # Minimal batch=1 decode cache for pymllm-server native smoke tests. + # During prefill the server passes the full prompt; during decode it + # passes only the latest token. Until native attention is integrated + # with mllm's KV cache, keep the full token/position history here and + # recompute the full context, returning only the last-token logits. + self._native_cached_input_ids = None + self._native_cached_positions = None + return + + from transformers.models.gemma3n.modeling_gemma3n import ( + Gemma3nForCausalLM as HFGemma3nForCausalLM, + ) + + self.config = text_config + self.quant_config = quant_config + self.prefix = prefix + + hf_model = HFGemma3nForCausalLM(text_config) + self.model = hf_model.model + self.lm_head = hf_model.lm_head + + # Minimal HF cache for current text-only server path. + # mllm supplies only the newest token during decode, so the HF wrapper + # must retain and reuse past_key_values across decode steps. + self._hf_past_key_values = None + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: Any = None, + ): + # mllm passes flat tensors for extend/prefill, e.g. input_ids=[T] and positions=[T]. + # HF Gemma3n expects batched text tensors: + # input_ids: [B, T] + # position_ids: [B, T] + # inputs_embeds: [B, T, H] + # Current server path only supports batch_size=1 for Gemma3n text-only. + forward_batch_obj = forward_batch + del forward_batch + + if getattr(self, "use_native", False): + if input_ids.dim() == 1: + input_ids_hf = input_ids.unsqueeze(0) + else: + input_ids_hf = input_ids + + if positions.dim() == 1: + position_ids_hf = positions.unsqueeze(0) + else: + position_ids_hf = positions + + # Native text-only path currently keeps large Gemma3n weights on CPU. + # Keep the recomputed full-context forward on CPU as well; otherwise + # server decode can pass CUDA token tensors while embeddings/lm_head + # remain on CPU, producing a different path from the verified local + # stepwise generation. + # Important: do a blocking CPU copy here. In server decode, the + # CUDA input tensor can be reused/cleared by the runtime after the + # forward call is scheduled; a non_blocking CUDA->CPU copy may then + # observe zeros or stale values. + input_ids_hf = input_ids_hf.detach().to(device=torch.device("cpu"), non_blocking=False).clone() + position_ids_hf = position_ids_hf.detach().to(device=torch.device("cpu"), non_blocking=False).clone() + + # pymllm-server prefill/extend is the start of a new request + # context. 
Prefer the server's forward mode over sequence length so + # one-token prompts do not get appended to a previous request cache. + is_extend_mode = False + if forward_batch_obj is not None: + forward_mode = getattr(forward_batch_obj, "forward_mode", None) + is_extend = getattr(forward_mode, "is_extend", None) + if callable(is_extend): + is_extend_mode = bool(is_extend()) + elif isinstance(forward_batch_obj, dict): + is_extend_mode = forward_batch_obj.get("forward_mode") == "extend" + + # pymllm-server prefill sends the full prompt, while decode sends + # only the latest token. For the current text-only native path, + # recompute from the full cached sequence so multi-token server + # decode matches direct greedy generation. + is_prefill = ( + is_extend_mode + or input_ids_hf.shape[1] > 1 + or self._native_cached_input_ids is None + or self._native_cached_positions is None + ) + if is_prefill: + full_input_ids = input_ids_hf + else: + full_input_ids = torch.cat( + [ + self._native_cached_input_ids.to(device=input_ids_hf.device), + input_ids_hf, + ], + dim=1, + ) + + # Recompute contiguous full-context positions, matching direct greedy + # generation. Decode-time positions from the server correspond to + # the submitted token batch, not necessarily to the recomputed full + # native context. + full_positions = torch.arange( + full_input_ids.shape[1], + dtype=torch.long, + device=full_input_ids.device, + ).unsqueeze(0).expand(full_input_ids.shape[0], -1) + + self._native_cached_input_ids = full_input_ids.detach().cpu() + self._native_cached_positions = full_positions.detach().cpu() + + hidden_states = self.model(input_ids=full_input_ids, positions=full_positions) + output_device = hidden_states.device + logits = self.lm_head(hidden_states.to(device=self.lm_head.weight.device)) + + final_logit_softcapping = getattr(self.config, "final_logit_softcapping", None) + if final_logit_softcapping is not None: + logits = logits / final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * final_logit_softcapping + + # For decode, the server expects logits for the token(s) just + # submitted in this forward call, not the whole recomputed prefix. + logits = logits[:, -input_ids_hf.shape[1]:, :] + + return logits.to(device=output_device, non_blocking=True) + + if input_ids.dim() == 1: + input_ids_hf = input_ids.unsqueeze(0) + else: + input_ids_hf = input_ids + + if positions.dim() == 1: + position_ids_hf = positions.unsqueeze(0) + else: + position_ids_hf = positions + + # For batch_size=1 text-only serving: + # - prefill has sequence length > 1, so reset HF cache; + # - decode has sequence length == 1, so reuse stored HF cache. + is_prefill = input_ids_hf.shape[1] > 1 or self._hf_past_key_values is None + past_key_values = None if is_prefill else self._hf_past_key_values + + model_device = self.model.norm.weight.device + embed_dev = self.model.embed_tokens.weight.device + per_layer_dev = self.model.embed_tokens_per_layer.weight.device + + # If both embedding tables already live on the same device as the HF text model, + # use the native HF text-model path directly. This is the closest apples-to-apples + # comparison with the official implementation. 
+ if embed_dev == model_device and per_layer_dev == model_device: + outputs = self.model( + input_ids=input_ids_hf.to(device=model_device, non_blocking=True), + per_layer_inputs=None, + attention_mask=None, + position_ids=position_ids_hf.to(device=model_device, non_blocking=True), + past_key_values=past_key_values, + inputs_embeds=None, + use_cache=True, + ) + else: + embed_input_ids = input_ids_hf.to( + device=embed_dev, + non_blocking=True, + ) + inputs_embeds = self.model.embed_tokens(embed_input_ids).to( + device=model_device, + non_blocking=True, + ) + + per_layer_input_ids = input_ids_hf.to( + device=per_layer_dev, + non_blocking=True, + ) + per_layer_inputs = self.model.get_per_layer_inputs(per_layer_input_ids).to( + device=model_device, + dtype=inputs_embeds.dtype, + non_blocking=True, + ) + + outputs = self.model( + input_ids=None, + per_layer_inputs=per_layer_inputs, + attention_mask=None, + position_ids=position_ids_hf.to(device=model_device, non_blocking=True), + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=True, + ) + + if hasattr(outputs, "past_key_values"): + self._hf_past_key_values = outputs.past_key_values + + hidden_states = outputs.last_hidden_state + output_device = hidden_states.device + hidden_states_for_head = hidden_states.to( + device=self.lm_head.weight.device, + non_blocking=True, + ) + logits = self.lm_head(hidden_states_for_head) + + final_logit_softcapping = getattr(self.config, "final_logit_softcapping", None) + if final_logit_softcapping is not None: + logits = logits / final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * final_logit_softcapping + + return logits.to(device=output_device, non_blocking=True) + + + def load_weights_from_model_path(self, model_path, chunk_bytes: int = 256 * 1024 * 1024): + """Load Gemma3n text-only weights directly from a local HF checkpoint. + + This path avoids materializing full safetensors shards or the full + iterator in memory. It is needed because Gemma3n has very large text + embedding tables, especially embed_tokens_per_layer.weight. + """ + import gc + import json + from pathlib import Path + + from safetensors import safe_open + + model_path = Path(model_path) + own_state = self.state_dict() + + def _normalize_name(name: str): + if ( + name.startswith("model.audio_tower.") + or name.startswith("model.vision_tower.") + or name.startswith("model.embed_audio.") + or name.startswith("model.embed_vision.") + ): + return None, "ignored_multimodal" + + if name.startswith("model.language_model."): + name = "model." + name[len("model.language_model."):] + elif name.startswith("language_model."): + name = "model." 
+ name[len("language_model."):] + + return name, None + + def _tensor_nbytes_from_shape(shape, element_size: int): + numel = 1 + for dim in shape: + numel *= int(dim) + return numel * element_size + + def _copy_tensor_streaming(safe_file, raw_name: str, target: torch.Tensor): + tensor_slice = safe_file.get_slice(raw_name) + shape = tuple(tensor_slice.get_shape()) + + if tuple(target.shape) != shape: + raise RuntimeError( + f"shape mismatch for {raw_name}: " + f"target={tuple(target.shape)} checkpoint={shape}" + ) + + estimated_bytes = _tensor_nbytes_from_shape(shape, target.element_size()) + + if len(shape) >= 2 and estimated_bytes >= chunk_bytes: + row_numel = 1 + for dim in shape[1:]: + row_numel *= int(dim) + + rows_per_chunk = max(1, chunk_bytes // (row_numel * target.element_size())) + total_rows = int(shape[0]) + + logger.info( + "Gemma3n streaming large tensor %s shape=%s rows_per_chunk=%d", + raw_name, + shape, + rows_per_chunk, + ) + + with torch.no_grad(): + for start in range(0, total_rows, rows_per_chunk): + end = min(total_rows, start + rows_per_chunk) + piece = tensor_slice[start:end] + target[start:end].copy_( + piece.to(dtype=target.dtype, device=target.device) + ) + del piece + gc.collect() + else: + piece = safe_file.get_tensor(raw_name) + with torch.no_grad(): + target.copy_(piece.to(dtype=target.dtype, device=target.device)) + del piece + + index_path = model_path / "model.safetensors.index.json" + if index_path.exists(): + index = json.loads(index_path.read_text()) + weight_map = index.get("weight_map", {}) + by_file = {} + for raw_name, filename in weight_map.items(): + by_file.setdefault(filename, []).append(raw_name) + else: + st_files = sorted(model_path.glob("*.safetensors")) + if not st_files: + logger.info( + "Gemma3n streaming loader found no safetensors under %s; " + "falling back to regular load_weights path.", + model_path, + ) + raise FileNotFoundError( + f"No safetensors checkpoint shards found under {model_path}. " + "Gemma3n native loading currently expects safetensors weights." + ) + + by_file = {} + for path in st_files: + with safe_open(path, framework="pt", device="cpu") as safe_file: + by_file[path.name] = list(safe_file.keys()) + + loaded = [] + skipped = [] + normalized_weight_names = set() + + logger.info( + "Gemma3n streaming load begin. 
model_path=%s shards=%d", + model_path, + len(by_file), + ) + + for filename in sorted(by_file.keys()): + shard_path = model_path / filename + raw_names = by_file[filename] + logger.info( + "Gemma3n streaming shard begin: %s tensors=%d", + filename, + len(raw_names), + ) + + with safe_open(shard_path, framework="pt", device="cpu") as safe_file: + for raw_name in raw_names: + name, skip_reason = _normalize_name(raw_name) + + if name is None: + skipped.append((raw_name, skip_reason)) + continue + + normalized_weight_names.add(name) + + if name not in own_state: + skipped.append((raw_name, f"missing_in_model -> {name}")) + continue + + target = own_state[name] + shape = tuple(safe_file.get_slice(raw_name).get_shape()) + + if tuple(target.shape) != shape: + skipped.append( + ( + raw_name, + f"shape_mismatch mapped={name} " + f"model={tuple(target.shape)} ckpt={shape}", + ) + ) + continue + + _copy_tensor_streaming(safe_file, raw_name, target) + loaded.append((raw_name, name)) + + logger.info( + "Gemma3n streaming shard end: %s loaded_total=%d skipped_total=%d", + filename, + len(loaded), + len(skipped), + ) + gc.collect() + + if "lm_head.weight" in own_state and "model.embed_tokens.weight" in own_state: + with torch.no_grad(): + own_state["lm_head.weight"].copy_(own_state["model.embed_tokens.weight"]) + loaded.append(("always_tied_from_embed_tokens", "lm_head.weight")) + normalized_weight_names.add("lm_head.weight") + + missing_in_ckpt = [ + name for name in own_state.keys() + if name not in normalized_weight_names + ] + + logger.info( + "Gemma3n streaming load end: loaded=%d skipped=%d missing_in_ckpt=%d", + len(loaded), + len(skipped), + len(missing_in_ckpt), + ) + + if skipped: + logger.info("Gemma3n streaming load first_skipped=%s", skipped[:20]) + + return { + "loaded": loaded, + "skipped": skipped, + "missing_in_ckpt": missing_in_ckpt, + } + + def load_weights(self, weights): + if hasattr(weights, "state_dict"): + weights = weights.state_dict() + + if isinstance(weights, dict): + weight_items = weights.items() + else: + try: + weight_items = iter(weights) + except TypeError: + raise TypeError( + f"weights must be a dict-like state_dict, a module with state_dict(), " + f"or an iterable of (name, tensor), got {type(weights)}" + ) + + def _normalize_name(name: str): + # Ignore clearly multimodal-only weights for current text-only path. + if ( + name.startswith("model.audio_tower.") + or name.startswith("model.vision_tower.") + or name.startswith("model.embed_audio.") + or name.startswith("model.embed_vision.") + ): + return None, "ignored_multimodal" + + # HF Gemma3n outer model wraps text weights under model.language_model.* + if name.startswith("model.language_model."): + name = "model." + name[len("model.language_model."):] + elif name.startswith("language_model."): + name = "model." + name[len("language_model."):] + + return name, None + + # Keep the main token embedding path on GPU for output quality. + # Only offload the very large per-layer embedding table to CPU. 
+ self.model.embed_tokens_per_layer = self.model.embed_tokens_per_layer.to("cpu") + + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + logger.info("Gemma3n devices: embed_tokens=%s embed_tokens_per_layer=%s lm_head=%s", self.model.embed_tokens.weight.device, self.model.embed_tokens_per_layer.weight.device, self.lm_head.weight.device) + + own_state = self.state_dict() + loaded = [] + skipped = [] + normalized_weight_names = set() + + for raw_name, tensor in weight_items: + name, skip_reason = _normalize_name(raw_name) + + if name is None: + skipped.append((raw_name, skip_reason)) + continue + + normalized_weight_names.add(name) + + if name not in own_state: + skipped.append((raw_name, f"missing_in_model -> {name}")) + continue + + if own_state[name].shape != tensor.shape: + skipped.append( + ( + raw_name, + f"shape_mismatch mapped={name} " + f"model={tuple(own_state[name].shape)} ckpt={tuple(tensor.shape)}", + ) + ) + continue + + own_state[name].copy_( + tensor.to(dtype=own_state[name].dtype, device=own_state[name].device) + ) + loaded.append((raw_name, name)) + + # Always tie lm_head to embed_tokens for text-only Gemma3n, + # matching the official tied-weights behavior. + if "lm_head.weight" in own_state and "model.embed_tokens.weight" in own_state: + own_state["lm_head.weight"].copy_(own_state["model.embed_tokens.weight"]) + loaded.append(("always_tied_from_embed_tokens", "lm_head.weight")) + normalized_weight_names.add("lm_head.weight") + + missing_in_ckpt = [name for name in own_state.keys() if name not in normalized_weight_names] + + logger.info( + "Gemma3nForCausalLM.load_weights: loaded=%d skipped=%d missing_in_ckpt=%d", + len(loaded), + len(skipped), + len(missing_in_ckpt), + ) + + if loaded: + logger.info("Gemma3nForCausalLM.load_weights: first_loaded=%s", loaded[:20]) + if skipped: + logger.info("Gemma3nForCausalLM.load_weights: first_skipped=%s", skipped[:20]) + + return { + "loaded": loaded, + "skipped": skipped, + "missing_in_ckpt": missing_in_ckpt, + } + + +class Gemma3nForConditionalGeneration(Gemma3nForCausalLM): + """Text-only compatibility entry point for Gemma3n checkpoints. + + The official checkpoint architecture resolves to this class, while the + current implementation supports the language-model path. + """ + pass diff --git a/pymllm/orchestrator/model_runner_process.py b/pymllm/orchestrator/model_runner_process.py index a514ac2e..245ee46a 100644 --- a/pymllm/orchestrator/model_runner_process.py +++ b/pymllm/orchestrator/model_runner_process.py @@ -498,6 +498,11 @@ def _insert_into_radix_cache(self, requests_meta: List[Dict[str, Any]]) -> None: if cache is None: return + # When radix cache is disabled, the runner uses ChunkCache rather than + # RadixCache. ChunkCache should not enter radix insertion logic. + if not hasattr(cache, "page_size"): + return + runner = self._runner gdn_pool = getattr(runner, "gdn_pool", None) @@ -999,7 +1004,10 @@ def _free_rid_resources(self, rid: str) -> None: # and the eviction callback; here we just remove the rid mapping. self._rid_to_gdn_track_slot.pop(rid, None) - cache_enabled = cache is not None + # ChunkCache is used when radix cache is disabled. It is still stored + # in self._radix_cache for the shared prefix-cache interface, but it + # must not enter RadixCache-specific cleanup/insert logic here. + cache_enabled = cache is not None and hasattr(cache, "page_size") # ---------------------------------------------------------- # Phase 1: Read all KV indices BEFORE freeing anything. 
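Note on patch 01: the KV-sharing index arithmetic in Gemma3nAttention.__init__ is easy to misread. The expression len(prev_layers) - 1 - prev_layers[::-1].index(layer_type) selects the last pre-shared layer of the same attention type, whose KV the shared layer then reuses. A toy standalone walkthrough of that logic; the layer counts below are invented for illustration, not real Gemma3n config values:

    # Toy walkthrough of the KV-sharing indices computed in
    # Gemma3nAttention.__init__ (sizes are made up: 10 layers, 4 KV-shared).
    num_hidden_layers = 10
    num_kv_shared_layers = 4

    layer_types = [
        "full_attention" if (i + 1) % 5 == 0 else "sliding_attention"
        for i in range(num_hidden_layers)
    ]
    first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers  # 6
    prev_layers = layer_types[:first_kv_shared_layer_idx]

    for layer_id, layer_type in enumerate(layer_types):
        if layer_id >= first_kv_shared_layer_idx > 0:
            # Shared layer: reuse KV from the last earlier layer of the same type.
            donor = len(prev_layers) - 1 - prev_layers[::-1].index(layer_type)
            print(f"layer {layer_id:2d} {layer_type:18s} reuses KV of layer {donor}")
        else:
            # Non-shared layer: it publishes full-length KV only if it is the
            # last layer of its type before the shared region.
            stores = layer_id == len(prev_layers) - 1 - prev_layers[::-1].index(layer_type)
            print(f"layer {layer_id:2d} {layer_type:18s} store_full_length_kv={stores}")

In this toy setup, layers 6-8 (sliding) reuse layer 5's KV and layer 9 (full) reuses layer 4's, which is exactly the pair that store_full_length_kv marks to publish into kv_shared_cache.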
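The _gaussian_topk sparsity in the MLP is also easier to see numerically: for a target sparsity p, the cutoff sits at mean + icdf(p) * std, so roughly a (1 - p) fraction of the gate activations survives the ReLU before the GELU is applied. A quick check on random data; p = 0.95 and the shape are arbitrary test values:

    # Quick numeric check of the _gaussian_topk cutoff.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    gate = torch.randn(4, 16384)  # stand-in for a gate_proj output

    p = torch.tensor(0.95)
    std_multiplier = torch.distributions.Normal(0, 1).icdf(p)  # ~1.6449
    cutoff = (
        gate.mean(dim=-1, keepdim=True)
        + gate.std(dim=-1, keepdim=True, unbiased=False) * std_multiplier
    )
    sparse = F.relu(gate - cutoff)

    print(float(std_multiplier))                 # ~1.645
    print(float((sparse == 0).float().mean()))   # ~0.95 of entries zeroed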
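The chunked get_slice copy in load_weights_from_model_path keeps peak memory near chunk_bytes instead of the full tensor size, which is the point for the very large embed_tokens_per_layer table. A minimal standalone sketch of the same safetensors pattern; the file name and tensor key are placeholders, and a 2-D bfloat16 tensor is assumed:

    # Minimal sketch of the chunked safetensors copy pattern.
    import torch
    from safetensors import safe_open

    chunk_bytes = 256 * 1024 * 1024

    with safe_open("shard-00001.safetensors", framework="pt", device="cpu") as f:
        name = "model.embed_tokens_per_layer.weight"  # hypothetical key
        tensor_slice = f.get_slice(name)
        rows, cols = tensor_slice.get_shape()         # assumes ndim == 2
        target = torch.empty(rows, cols, dtype=torch.bfloat16)

        rows_per_chunk = max(1, chunk_bytes // (cols * target.element_size()))
        with torch.no_grad():
            for start in range(0, rows, rows_per_chunk):
                end = min(rows, start + rows_per_chunk)
                # Only this row range is materialized; the rest of the shard
                # stays on disk behind the mmap.
                target[start:end].copy_(tensor_slice[start:end].to(target.dtype))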
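Finally, _build_attention_mask reduces to two comparisons: causal everywhere, intersected with a window on sliding_attention layers. A tiny check with an invented window size of 3:

    # Banded causal mask as built by _build_attention_mask.
    import torch

    positions = torch.arange(6).unsqueeze(0)       # [B=1, S=6]
    q_pos = positions[:, :, None]
    k_pos = positions[:, None, :]

    causal = k_pos <= q_pos
    sliding = causal & (k_pos >= q_pos - (3 - 1))  # sliding_attention layers
    print(sliding[0].int())  # lower-triangular band of width 3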
From f30d54905d2e11424911331ea35ada14411313e9 Mon Sep 17 00:00:00 2001 From: grape203 Date: Tue, 5 May 2026 23:37:56 +0800 Subject: [PATCH 02/12] Refactor Gemma3n native wrapper --- pymllm/executor/model_runner.py | 9 - pymllm/models/gemma3n.py | 308 +++++++++++--------------------- 2 files changed, 106 insertions(+), 211 deletions(-) diff --git a/pymllm/executor/model_runner.py b/pymllm/executor/model_runner.py index 5289f78a..1ced172f 100644 --- a/pymllm/executor/model_runner.py +++ b/pymllm/executor/model_runner.py @@ -603,15 +603,6 @@ def load_model(self) -> None: requires_cpu_first_loading = bool( getattr(model_cls, "requires_cpu_first_weight_loading", False) ) - if not requires_cpu_first_loading: - import os - - gemma3n_native_mode = ( - architecture in {"Gemma3nForCausalLM", "Gemma3nForConditionalGeneration"} - and os.environ.get("MLLM_GEMMA3N_NATIVE", "0") == "1" - ) - requires_cpu_first_loading = gemma3n_native_mode - instantiate_device_str = "cpu" if requires_cpu_first_loading else device_str if instantiate_device_str != device_str: logger.info( diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py index 5c685809..148776e8 100644 --- a/pymllm/models/gemma3n.py +++ b/pymllm/models/gemma3n.py @@ -46,7 +46,7 @@ def _get_intermediate_size(config, layer_id: int) -> int: return int(intermediate_size[layer_id]) -class SimpleGemmaRMSNorm(nn.Module): +class Gemma3nRMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.hidden_size = hidden_size @@ -80,7 +80,7 @@ def forward( -class SimpleRMSNormNoWeight(nn.Module): +class Gemma3nRMSNormNoWeight(nn.Module): def __init__(self, eps: float = 1e-6): super().__init__() self.eps = eps @@ -124,7 +124,7 @@ def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch. 
return (x * cos) + (_rotate_half(x) * sin) -class SimpleMLP(nn.Module): +class Gemma3nMLP(nn.Module): def __init__( self, hidden_size: int, @@ -184,7 +184,7 @@ def __init__(self, config): laurel_rank = getattr(tc, "laurel_rank", 8) self.linear_left = nn.Linear(tc.hidden_size, laurel_rank, bias=False) self.linear_right = nn.Linear(laurel_rank, tc.hidden_size, bias=False) - self.post_laurel_norm = SimpleGemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + self.post_laurel_norm = Gemma3nRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: laurel_hidden_states = self.linear_left(hidden_states) @@ -204,7 +204,7 @@ def __init__(self, config): self.correction_coefs = nn.Linear(self.altup_num_inputs, self.altup_num_inputs, bias=False) self.prediction_coefs = nn.Linear(self.altup_num_inputs, self.altup_num_inputs ** 2, bias=False) self.modality_router = nn.Linear(tc.hidden_size, self.altup_num_inputs, bias=False) - self.router_norm = SimpleGemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + self.router_norm = Gemma3nRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) self.register_buffer("router_input_scale", torch.tensor(tc.hidden_size ** -1.0), persistent=False) def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: @@ -271,9 +271,9 @@ def __init__(self, config, layer_id: int): self.v_proj = nn.Linear(self.hidden_size, self.kv_size, bias=False) self.o_proj = nn.Linear(self.q_size, self.hidden_size, bias=False) - self.q_norm = SimpleGemmaRMSNorm(self.head_dim) - self.k_norm = SimpleGemmaRMSNorm(self.head_dim) - self.v_norm = SimpleRMSNormNoWeight(eps=getattr(tc, "rms_norm_eps", 1e-6)) + self.q_norm = Gemma3nRMSNorm(self.head_dim) + self.k_norm = Gemma3nRMSNorm(self.head_dim) + self.v_norm = Gemma3nRMSNormNoWeight(eps=getattr(tc, "rms_norm_eps", 1e-6)) self.sliding_window = getattr(tc, "sliding_window", None) @@ -409,20 +409,20 @@ def __init__(self, config, layer_id: int): else: layer_activation_sparsity = float(getattr(tc, "activation_sparsity", 0.0)) - self.mlp = SimpleMLP( + self.mlp = Gemma3nMLP( hidden_size=tc.hidden_size, intermediate_size=_get_intermediate_size(config, layer_id), activation=getattr(tc, "hidden_activation", getattr(tc, "hidden_act", "silu")), activation_sparsity=layer_activation_sparsity, ) - self.input_layernorm = SimpleGemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) - self.post_attention_layernorm = SimpleGemmaRMSNorm( + self.input_layernorm = Gemma3nRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + self.post_attention_layernorm = Gemma3nRMSNorm( tc.hidden_size, eps=tc.rms_norm_eps ) - self.pre_feedforward_layernorm = SimpleGemmaRMSNorm( + self.pre_feedforward_layernorm = Gemma3nRMSNorm( tc.hidden_size, eps=tc.rms_norm_eps ) - self.post_feedforward_layernorm = SimpleGemmaRMSNorm( + self.post_feedforward_layernorm = Gemma3nRMSNorm( tc.hidden_size, eps=tc.rms_norm_eps ) @@ -435,7 +435,7 @@ def __init__(self, config, layer_id: int): self.per_layer_projection = nn.Linear( self.hidden_size_per_layer_input, tc.hidden_size, bias=False ) - self.post_per_layer_input_norm = SimpleGemmaRMSNorm( + self.post_per_layer_input_norm = Gemma3nRMSNorm( tc.hidden_size, eps=tc.rms_norm_eps ) self.per_layer_input_act = _get_hidden_act_fn( @@ -496,7 +496,7 @@ def __init__(self, config): self.layers = nn.ModuleList( [Gemma3nDecoderLayer(config, i) for i in range(tc.num_hidden_layers)] ) - self.norm = SimpleGemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + self.norm = Gemma3nRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) # Text-only 
modules whose names match the official Gemma3n checkpoint. self.hidden_size_per_layer_input = getattr(tc, "hidden_size_per_layer_input", 8) @@ -513,7 +513,7 @@ def __init__(self, config): tc.num_hidden_layers * self.hidden_size_per_layer_input, bias=False, ) - self.per_layer_projection_norm = SimpleGemmaRMSNorm( + self.per_layer_projection_norm = Gemma3nRMSNorm( self.hidden_size_per_layer_input, eps=tc.rms_norm_eps ) @@ -630,54 +630,33 @@ def forward( class Gemma3nForCausalLM(nn.Module): """Text-only Gemma3n Causal LM wrapper. - The default server path uses the official HF text model and keeps the main - token embedding and lm_head on CUDA, while offloading the large per-layer - embedding table to CPU during weight loading. Native mode can be enabled - for development with MLLM_GEMMA3N_NATIVE=1. + This implementation uses the native Gemma3n text-only path directly. + It keeps CPU-first instantiation and model-path streaming weight loading + for the large Gemma3n text checkpoint until the attention path is fully + integrated with pymllm.layers and RadixAttention. """ - requires_cpu_first_weight_loading = False + requires_cpu_first_weight_loading = True def __init__(self, config, quant_config=None, prefix: str = ""): super().__init__() text_config = _get_text_config(config) text_config._attn_implementation = "eager" - import os - self.use_native = os.environ.get("MLLM_GEMMA3N_NATIVE", "0") == "1" - self.use_model_path_weight_loader = self.use_native - if self.use_native: - self.config = text_config - self.quant_config = quant_config - self.prefix = prefix - self.model = Gemma3nModel(config) - self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False) - self._hf_past_key_values = None - # Minimal batch=1 decode cache for pymllm-server native smoke tests. - # During prefill the server passes the full prompt; during decode it - # passes only the latest token. Until native attention is integrated - # with mllm's KV cache, keep the full token/position history here and - # recompute the full context, returning only the last-token logits. - self._native_cached_input_ids = None - self._native_cached_positions = None - return - - from transformers.models.gemma3n.modeling_gemma3n import ( - Gemma3nForCausalLM as HFGemma3nForCausalLM, - ) - self.config = text_config self.quant_config = quant_config self.prefix = prefix - - hf_model = HFGemma3nForCausalLM(text_config) - self.model = hf_model.model - self.lm_head = hf_model.lm_head - - # Minimal HF cache for current text-only server path. - # mllm supplies only the newest token during decode, so the HF wrapper - # must retain and reuse past_key_values across decode steps. - self._hf_past_key_values = None + self.use_model_path_weight_loader = True + self.model = Gemma3nModel(config) + self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False) + + # Minimal batch=1 decode cache for pymllm-server native smoke tests. + # During prefill the server passes the full prompt; during decode it + # passes only the latest token. Until native attention is integrated + # with mllm's KV cache, keep the full token/position history here and + # recompute the full context, returning only the last-token logits. + self._native_cached_input_ids = None + self._native_cached_positions = None def forward( self, @@ -685,173 +664,95 @@ def forward( positions: torch.Tensor, forward_batch: Any = None, ): - # mllm passes flat tensors for extend/prefill, e.g. input_ids=[T] and positions=[T]. 
-        # HF Gemma3n expects batched text tensors:
-        #   input_ids: [B, T]
-        #   position_ids: [B, T]
-        #   inputs_embeds: [B, T, H]
-        # Current server path only supports batch_size=1 for Gemma3n text-only.
+        # mllm passes flat tensors for extend/prefill, e.g. input_ids=[T]
+        # and positions=[T]. Native Gemma3nModel expects batched tensors.
         forward_batch_obj = forward_batch
         del forward_batch

-        if getattr(self, "use_native", False):
-            if input_ids.dim() == 1:
-                input_ids_hf = input_ids.unsqueeze(0)
-            else:
-                input_ids_hf = input_ids
-
-            if positions.dim() == 1:
-                position_ids_hf = positions.unsqueeze(0)
-            else:
-                position_ids_hf = positions
-
-            # Native text-only path currently keeps large Gemma3n weights on CPU.
-            # Keep the recomputed full-context forward on CPU as well; otherwise
-            # server decode can pass CUDA token tensors while embeddings/lm_head
-            # remain on CPU, producing a different path from the verified local
-            # stepwise generation.
-            # Important: do a blocking CPU copy here. In server decode, the
-            # CUDA input tensor can be reused/cleared by the runtime after the
-            # forward call is scheduled; a non_blocking CUDA->CPU copy may then
-            # observe zeros or stale values.
-            input_ids_hf = input_ids_hf.detach().to(device=torch.device("cpu"), non_blocking=False).clone()
-            position_ids_hf = position_ids_hf.detach().to(device=torch.device("cpu"), non_blocking=False).clone()
-
-            # pymllm-server prefill/extend is the start of a new request
-            # context. Prefer the server's forward mode over sequence length so
-            # one-token prompts do not get appended to a previous request cache.
-            is_extend_mode = False
-            if forward_batch_obj is not None:
-                forward_mode = getattr(forward_batch_obj, "forward_mode", None)
-                is_extend = getattr(forward_mode, "is_extend", None)
-                if callable(is_extend):
-                    is_extend_mode = bool(is_extend())
-                elif isinstance(forward_batch_obj, dict):
-                    is_extend_mode = forward_batch_obj.get("forward_mode") == "extend"
-
-            # pymllm-server prefill sends the full prompt, while decode sends
-            # only the latest token. For the current text-only native path,
-            # recompute from the full cached sequence so multi-token server
-            # decode matches direct greedy generation.
-            is_prefill = (
-                is_extend_mode
-                or input_ids_hf.shape[1] > 1
-                or self._native_cached_input_ids is None
-                or self._native_cached_positions is None
-            )
-            if is_prefill:
-                full_input_ids = input_ids_hf
-            else:
-                full_input_ids = torch.cat(
-                    [
-                        self._native_cached_input_ids.to(device=input_ids_hf.device),
-                        input_ids_hf,
-                    ],
-                    dim=1,
-                )
-
-            # Recompute contiguous full-context positions, matching direct greedy
-            # generation. Decode-time positions from the server correspond to
-            # the submitted token batch, not necessarily to the recomputed full
-            # native context.
-            full_positions = torch.arange(
-                full_input_ids.shape[1],
-                dtype=torch.long,
-                device=full_input_ids.device,
-            ).unsqueeze(0).expand(full_input_ids.shape[0], -1)
-
-            self._native_cached_input_ids = full_input_ids.detach().cpu()
-            self._native_cached_positions = full_positions.detach().cpu()
-
-            hidden_states = self.model(input_ids=full_input_ids, positions=full_positions)
-            output_device = hidden_states.device
-            logits = self.lm_head(hidden_states.to(device=self.lm_head.weight.device))
-
-            final_logit_softcapping = getattr(self.config, "final_logit_softcapping", None)
-            if final_logit_softcapping is not None:
-                logits = logits / final_logit_softcapping
-                logits = torch.tanh(logits)
-                logits = logits * final_logit_softcapping
-
-            # For decode, the server expects logits for the token(s) just
-            # submitted in this forward call, not the whole recomputed prefix.
-            logits = logits[:, -input_ids_hf.shape[1]:, :]
-
-            return logits.to(device=output_device, non_blocking=True)
-
         if input_ids.dim() == 1:
             input_ids_hf = input_ids.unsqueeze(0)
         else:
             input_ids_hf = input_ids

-        if positions.dim() == 1:
+        if positions is None:
+            position_ids_hf = torch.arange(
+                input_ids_hf.shape[1],
+                dtype=torch.long,
+                device=input_ids_hf.device,
+            ).unsqueeze(0).expand(input_ids_hf.shape[0], -1)
+        elif positions.dim() == 1:
             position_ids_hf = positions.unsqueeze(0)
         else:
             position_ids_hf = positions

-        # For batch_size=1 text-only serving:
-        #   - prefill has sequence length > 1, so reset HF cache;
-        #   - decode has sequence length == 1, so reuse stored HF cache.
-        is_prefill = input_ids_hf.shape[1] > 1 or self._hf_past_key_values is None
-        past_key_values = None if is_prefill else self._hf_past_key_values
-
-        model_device = self.model.norm.weight.device
-        embed_dev = self.model.embed_tokens.weight.device
-        per_layer_dev = self.model.embed_tokens_per_layer.weight.device
-
-        # If both embedding tables already live on the same device as the HF text model,
-        # use the native HF text-model path directly. This is the closest apples-to-apples
-        # comparison with the official implementation.
-        if embed_dev == model_device and per_layer_dev == model_device:
-            outputs = self.model(
-                input_ids=input_ids_hf.to(device=model_device, non_blocking=True),
-                per_layer_inputs=None,
-                attention_mask=None,
-                position_ids=position_ids_hf.to(device=model_device, non_blocking=True),
-                past_key_values=past_key_values,
-                inputs_embeds=None,
-                use_cache=True,
-            )
+        # Native text-only path currently keeps large Gemma3n weights on CPU.
+        # Keep the recomputed full-context forward on CPU as well; otherwise
+        # server decode can pass CUDA token tensors while embeddings/lm_head
+        # remain on CPU, producing a different path from the verified local
+        # stepwise generation.
+        #
+        # Important: do a blocking CPU copy here. In server decode, the
+        # CUDA input tensor can be reused/cleared by the runtime after the
+        # forward call is scheduled; a non_blocking CUDA->CPU copy may then
+        # observe zeros or stale values.
+        input_ids_hf = input_ids_hf.detach().to(
+            device=torch.device("cpu"),
+            non_blocking=False,
+        ).clone()
+        position_ids_hf = position_ids_hf.detach().to(
+            device=torch.device("cpu"),
+            non_blocking=False,
+        ).clone()
+
+        # pymllm-server prefill/extend is the start of a new request context.
+        # Prefer the server's forward mode over sequence length so one-token
+        # prompts do not get appended to a previous request cache.
+        is_extend_mode = False
+        if forward_batch_obj is not None:
+            forward_mode = getattr(forward_batch_obj, "forward_mode", None)
+            is_extend = getattr(forward_mode, "is_extend", None)
+            if callable(is_extend):
+                is_extend_mode = bool(is_extend())
+            elif isinstance(forward_batch_obj, dict):
+                is_extend_mode = forward_batch_obj.get("forward_mode") == "extend"
+
+        # pymllm-server prefill sends the full prompt, while decode sends
+        # only the latest token. For the current text-only native path,
+        # recompute from the full cached sequence so multi-token server
+        # decode matches direct greedy generation.
+        is_prefill = (
+            is_extend_mode
+            or input_ids_hf.shape[1] > 1
+            or self._native_cached_input_ids is None
+            or self._native_cached_positions is None
+        )
+        if is_prefill:
+            full_input_ids = input_ids_hf
         else:
-            embed_input_ids = input_ids_hf.to(
-                device=embed_dev,
-                non_blocking=True,
-            )
-            inputs_embeds = self.model.embed_tokens(embed_input_ids).to(
-                device=model_device,
-                non_blocking=True,
-            )
-
-            per_layer_input_ids = input_ids_hf.to(
-                device=per_layer_dev,
-                non_blocking=True,
-            )
-            per_layer_inputs = self.model.get_per_layer_inputs(per_layer_input_ids).to(
-                device=model_device,
-                dtype=inputs_embeds.dtype,
-                non_blocking=True,
+            full_input_ids = torch.cat(
+                [
+                    self._native_cached_input_ids.to(device=input_ids_hf.device),
+                    input_ids_hf,
+                ],
+                dim=1,
             )
-            outputs = self.model(
-                input_ids=None,
-                per_layer_inputs=per_layer_inputs,
-                attention_mask=None,
-                position_ids=position_ids_hf.to(device=model_device, non_blocking=True),
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=True,
-            )

+        # Recompute contiguous full-context positions, matching direct greedy
+        # generation. Decode-time positions from the server correspond to the
+        # submitted token batch, not necessarily to the recomputed full native
+        # context.
+        full_positions = torch.arange(
+            full_input_ids.shape[1],
+            dtype=torch.long,
+            device=full_input_ids.device,
+        ).unsqueeze(0).expand(full_input_ids.shape[0], -1)

-        if hasattr(outputs, "past_key_values"):
-            self._hf_past_key_values = outputs.past_key_values
+        self._native_cached_input_ids = full_input_ids.detach().cpu()
+        self._native_cached_positions = full_positions.detach().cpu()

-        hidden_states = outputs.last_hidden_state
+        hidden_states = self.model(input_ids=full_input_ids, positions=full_positions)
         output_device = hidden_states.device
-        hidden_states_for_head = hidden_states.to(
-            device=self.lm_head.weight.device,
-            non_blocking=True,
-        )
-        logits = self.lm_head(hidden_states_for_head)
+        logits = self.lm_head(hidden_states.to(device=self.lm_head.weight.device))

         final_logit_softcapping = getattr(self.config, "final_logit_softcapping", None)
         if final_logit_softcapping is not None:
@@ -859,6 +760,9 @@ def forward(
             logits = torch.tanh(logits)
             logits = logits * final_logit_softcapping

+        # For decode, the server expects logits for the token(s) just submitted
+        # in this forward call, not the whole recomputed prefix.
+        logits = logits[:, -input_ids_hf.shape[1]:, :]

         return logits.to(device=output_device, non_blocking=True)
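A note on the decode contract in the patch above: prefill recomputes the full cached prefix, but the server consumes only the logits for the tokens submitted in the current call. The slice-from-the-end rule can be checked in isolation with a minimal sketch (`sliced_logits` is a hypothetical helper, not pymllm API):

    import torch

    def sliced_logits(full_logits: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
        # full_logits: [B, S_full, V]; keep only the positions that correspond
        # to the tokens submitted in this forward call.
        return full_logits[:, -num_new_tokens:, :]

    full = torch.randn(1, 10, 32)   # logits for the recomputed full context
    step = sliced_logits(full, 1)   # decode submits a single new token
    assert step.shape == (1, 1, 32)
    assert torch.equal(step[0, 0], full[0, -1])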
From 9737d40a89f477abdb74102a69020076d34c31c2 Mon Sep 17 00:00:00 2001
From: grape203
Date: Tue, 5 May 2026 23:42:59 +0800
Subject: [PATCH 03/12] Use pymllm Linear layers in Gemma3n

---
 pymllm/models/gemma3n.py | 50 ++++++++++++++++++++++++--------------------
 1 file changed, 26 insertions(+), 24 deletions(-)

diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py
index 148776e8..6dc21898 100644
--- a/pymllm/models/gemma3n.py
+++ b/pymllm/models/gemma3n.py
@@ -7,6 +7,8 @@
 import torch.nn as nn
 import torch.nn.functional as F

+from pymllm.layers.linear import Linear
+
 logger = logging.getLogger(__name__)

@@ -139,9 +141,9 @@ def __init__(
         self.act = _get_hidden_act_fn(activation)
         self.activation_sparsity = float(activation_sparsity)

-        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.gate_proj = Linear(hidden_size, intermediate_size, bias=False)
+        self.up_proj = Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = Linear(intermediate_size, hidden_size, bias=False)

     def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
         target_sparsity_tensor = torch.tensor(
@@ -167,7 +169,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

-class SimpleScaledWordEmbedding(nn.Embedding):
+class Gemma3nScaledWordEmbedding(nn.Embedding):
     def __init__(self, num_embeddings: int, embedding_dim: int, embed_scale: float = 1.0):
         super().__init__(num_embeddings, embedding_dim)
         self.scalar_embed_scale = float(embed_scale)
@@ -177,13 +179,13 @@
     def forward(self, input_ids: torch.Tensor):
         return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)

-class SimpleLaurelBlock(nn.Module):
+class Gemma3nLaurelBlock(nn.Module):
     def __init__(self, config):
         super().__init__()
         tc = _get_text_config(config)
         laurel_rank = getattr(tc, "laurel_rank", 8)
-        self.linear_left = nn.Linear(tc.hidden_size, laurel_rank, bias=False)
-        self.linear_right = nn.Linear(laurel_rank, tc.hidden_size, bias=False)
+        self.linear_left = Linear(tc.hidden_size, laurel_rank, bias=False)
+        self.linear_right = Linear(laurel_rank, tc.hidden_size, bias=False)
         self.post_laurel_norm = Gemma3nRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -193,7 +195,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states + normed

-class SimpleAltUp(nn.Module):
+class Gemma3nAltUp(nn.Module):
     def __init__(self, config):
         super().__init__()
         tc = _get_text_config(config)
@@ -201,9 +203,9 @@ def __init__(self, config):
         self.altup_num_inputs = getattr(tc, "altup_num_inputs", 2)
         self.altup_active_idx = getattr(tc, "altup_active_idx", 0)
         self.correct_output_scale = nn.Parameter(torch.zeros(tc.hidden_size))
-        self.correction_coefs = nn.Linear(self.altup_num_inputs, self.altup_num_inputs, bias=False)
-        self.prediction_coefs = nn.Linear(self.altup_num_inputs, self.altup_num_inputs ** 2, bias=False)
-        self.modality_router = nn.Linear(tc.hidden_size, self.altup_num_inputs, bias=False)
+        self.correction_coefs = Linear(self.altup_num_inputs, self.altup_num_inputs, bias=False)
+        self.prediction_coefs = Linear(self.altup_num_inputs, self.altup_num_inputs ** 2, bias=False)
+        self.modality_router = Linear(tc.hidden_size, self.altup_num_inputs, bias=False)
         self.router_norm = Gemma3nRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps)
         self.register_buffer("router_input_scale", torch.tensor(tc.hidden_size ** -1.0), persistent=False)

@@ -266,10 +268,10 @@ def __init__(self, config, layer_id: int):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim

-        self.q_proj = nn.Linear(self.hidden_size, self.q_size, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, self.kv_size, bias=False)
-        self.v_proj = nn.Linear(self.hidden_size, self.kv_size, bias=False)
-        self.o_proj = nn.Linear(self.q_size, self.hidden_size, bias=False)
+        self.q_proj = Linear(self.hidden_size, self.q_size, bias=False)
+        self.k_proj = Linear(self.hidden_size, self.kv_size, bias=False)
+        self.v_proj = Linear(self.hidden_size, self.kv_size, bias=False)
+        self.o_proj = Linear(self.q_size, self.hidden_size, bias=False)

         self.q_norm = Gemma3nRMSNorm(self.head_dim)
         self.k_norm = Gemma3nRMSNorm(self.head_dim)
@@ -426,13 +428,13 @@ def __init__(self, config, layer_id: int):
             tc.hidden_size, eps=tc.rms_norm_eps
         )

-        self.altup = SimpleAltUp(config)
-        self.laurel = SimpleLaurelBlock(config)
+        self.altup = Gemma3nAltUp(config)
+        self.laurel = Gemma3nLaurelBlock(config)
         self.hidden_size_per_layer_input = getattr(tc, "hidden_size_per_layer_input", 8)
-        self.per_layer_input_gate = nn.Linear(
+        self.per_layer_input_gate = Linear(
             tc.hidden_size, self.hidden_size_per_layer_input, bias=False
         )
-        self.per_layer_projection = nn.Linear(
+        self.per_layer_projection = Linear(
             self.hidden_size_per_layer_input, tc.hidden_size, bias=False
         )
         self.post_per_layer_input_norm = Gemma3nRMSNorm(
@@ -492,7 +494,7 @@ def __init__(self, config):
         tc = _get_text_config(config)
         self.config = config

-        self.embed_tokens = SimpleScaledWordEmbedding(tc.vocab_size, tc.hidden_size, embed_scale=tc.hidden_size ** 0.5)
+        self.embed_tokens = Gemma3nScaledWordEmbedding(tc.vocab_size, tc.hidden_size, embed_scale=tc.hidden_size ** 0.5)
         self.layers = nn.ModuleList(
             [Gemma3nDecoderLayer(config, i) for i in range(tc.num_hidden_layers)]
         )
@@ -503,12 +505,12 @@ def __init__(self, config):
         self.vocab_size_per_layer_input = getattr(tc, "vocab_size_per_layer_input", tc.vocab_size)
         self.altup_num_inputs = getattr(tc, "altup_num_inputs", 2)

-        self.embed_tokens_per_layer = SimpleScaledWordEmbedding(
+        self.embed_tokens_per_layer = Gemma3nScaledWordEmbedding(
             self.vocab_size_per_layer_input,
             tc.num_hidden_layers * self.hidden_size_per_layer_input,
             embed_scale=self.hidden_size_per_layer_input ** 0.5,
         )
-        self.per_layer_model_projection = nn.Linear(
+        self.per_layer_model_projection = Linear(
             tc.hidden_size,
             tc.num_hidden_layers * self.hidden_size_per_layer_input,
             bias=False,
@@ -518,10 +520,10 @@ def __init__(self, config):
         )

         self.altup_projections = nn.ModuleList(
-            [nn.Linear(tc.hidden_size, tc.hidden_size, bias=False) for _ in range(1, self.altup_num_inputs)]
+            [Linear(tc.hidden_size, tc.hidden_size, bias=False) for _ in range(1, self.altup_num_inputs)]
         )
         self.altup_unembed_projections = nn.ModuleList(
-            [nn.Linear(tc.hidden_size, tc.hidden_size, bias=False) for _ in range(1, self.altup_num_inputs)]
+            [Linear(tc.hidden_size, tc.hidden_size, bias=False) for _ in range(1, self.altup_num_inputs)]
         )

     def forward(
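Because the swap from `nn.Linear` to the pymllm `Linear` keeps every attribute name (`gate_proj`, `q_proj`, ...) and weight shape, checkpoint keys are unchanged. A refactor like this can be smoke-tested with a generic helper (a sketch, assuming both layer types expose a standard `state_dict()`):

    import torch.nn as nn

    def load_compatible(old: nn.Module, new: nn.Module) -> bool:
        # The swap is checkpoint-compatible if parameter names and shapes agree.
        old_sd, new_sd = old.state_dict(), new.state_dict()
        return old_sd.keys() == new_sd.keys() and all(
            old_sd[k].shape == new_sd[k].shape for k in old_sd
        )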
From 3a962ccb22575b57c370f29ed04f404482ef4fcd Mon Sep 17 00:00:00 2001
From: grape203
Date: Tue, 5 May 2026 23:53:37 +0800
Subject: [PATCH 04/12] Add RadixAttention metadata to Gemma3n attention

---
 pymllm/models/gemma3n.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py
index 6dc21898..48898950 100644
--- a/pymllm/models/gemma3n.py
+++ b/pymllm/models/gemma3n.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F

 from pymllm.layers.linear import Linear
+from pymllm.layers.attention.radix_attention import RadixAttention

 logger = logging.getLogger(__name__)

@@ -278,6 +279,27 @@ def __init__(self, config, layer_id: int):
         self.v_norm = Gemma3nRMSNormNoWeight(eps=getattr(tc, "rms_norm_eps", 1e-6))

         self.sliding_window = getattr(tc, "sliding_window", None)
+        self.sliding_window_size = (
+            int(self.sliding_window)
+            if self.layer_type == "sliding_attention" and self.sliding_window is not None
+            else -1
+        )
+        query_pre_attn_scalar = getattr(tc, "query_pre_attn_scalar", self.head_dim)
+        self.scaling = float(query_pre_attn_scalar) ** -0.5
+
+        # RadixAttention is the pymllm-native attention layer. Keep the eager
+        # full-context path below as a correctness fallback for the current
+        # CPU-first Gemma3n text-only implementation, but configure RadixAttention
+        # here so Gemma3n layers carry the correct per-layer SWA metadata:
+        # sliding layers use tc.sliding_window, full layers use -1.
+        self.attn = RadixAttention(
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            scaling=self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            sliding_window_size=self.sliding_window_size,
+        )

         num_kv_shared_layers = int(getattr(tc, "num_kv_shared_layers", 0))
         first_kv_shared_layer_idx = int(tc.num_hidden_layers) - num_kv_shared_layers

From a7b66e4420746b3999f923a44d208d6d9314e475 Mon Sep 17 00:00:00 2001
From: grape203
Date: Tue, 5 May 2026 23:56:20 +0800
Subject: [PATCH 05/12] Enable sliding-window attention backend initialization

---
 pymllm/executor/model_runner.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pymllm/executor/model_runner.py b/pymllm/executor/model_runner.py
index 1ced172f..b9202215 100644
--- a/pymllm/executor/model_runner.py
+++ b/pymllm/executor/model_runner.py
@@ -919,6 +919,7 @@ def init_attention_backend(self) -> None:
             req_to_token=self.req_to_token_pool.req_to_token,
             device=torch.device(self.device),
             max_req_pool_size=self.req_to_token_pool.size,
+            sliding_window_size=self.sliding_window_size,
         )

         if self.gdn_pool is not None:
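The per-layer SWA metadata configured above follows Gemma 3n's repeating pattern of sliding-window layers punctuated by full-attention layers. A small sketch of the mapping (hypothetical window size; the real value comes from `tc.sliding_window` in the checkpoint config):

    SLIDING_WINDOW = 512  # hypothetical config value

    def swa_size(layer_type: str) -> int:
        # Sliding layers carry the window size; full-attention layers use -1,
        # which the attention backend treats as "no windowing".
        return SLIDING_WINDOW if layer_type == "sliding_attention" else -1

    layer_types = ["sliding_attention"] * 4 + ["full_attention"]
    assert [swa_size(t) for t in layer_types] == [512, 512, 512, 512, -1]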
+ """ + if self.is_kv_shared_layer: + return False + if forward_batch is None: + return False + if getattr(forward_batch, "attn_backend", None) is None: + return False + if getattr(forward_batch, "token_to_kv_pool", None) is None: + return False + if getattr(forward_batch, "out_cache_loc", None) is None: + return False + return bool(hidden_states.is_cuda) + def _build_attention_mask( self, positions: torch.Tensor, @@ -401,6 +429,20 @@ def forward( if shared_kv_cache is not None and self.store_full_length_kv: shared_kv_cache[self.layer_id] = (k, v) + if self._can_use_radix_attention(hidden_states, forward_batch): + q_flat = q.transpose(1, 2).contiguous().reshape( + batch_size * seq_len, self.q_size + ) + k_flat = k.transpose(1, 2).contiguous().reshape( + batch_size * seq_len, self.kv_size + ) + v_flat = v.transpose(1, 2).contiguous().reshape( + batch_size * seq_len, self.kv_size + ) + attn_output = self.attn(q_flat, k_flat, v_flat, forward_batch) + attn_output = attn_output.view(batch_size, seq_len, self.q_size) + return self.o_proj(attn_output) + if self.num_heads != self.num_kv_heads: n_rep = self.num_heads // self.num_kv_heads k = k[:, :, None, :, :].expand(batch_size, self.num_kv_heads, n_rep, seq_len, self.head_dim) From 485591b89a884caf5ccd99914fd70cbb914a0db5 Mon Sep 17 00:00:00 2001 From: grape203 Date: Wed, 6 May 2026 00:16:25 +0800 Subject: [PATCH 07/12] Move Gemma3n compute modules to runtime device --- pymllm/executor/model_runner.py | 8 ++++++ pymllm/models/gemma3n.py | 49 +++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/pymllm/executor/model_runner.py b/pymllm/executor/model_runner.py index b9202215..0644a3cc 100644 --- a/pymllm/executor/model_runner.py +++ b/pymllm/executor/model_runner.py @@ -649,6 +649,14 @@ def load_model(self) -> None: if quant_method is not None and hasattr(quant_method, "process_weights_after_loading"): quant_method.process_weights_after_loading(module) + move_compute_modules = getattr(self.model, "move_compute_modules_to_device", None) + if callable(move_compute_modules) and self.device == "cuda": + logger.info( + "Moving model compute modules to runtime device after weight loading: %s", + device_str, + ) + move_compute_modules(torch.device(device_str)) + self.model.eval() after_mem = get_available_gpu_memory(self.device, self.gpu_id) diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py index 819edc7b..ddad229d 100644 --- a/pymllm/models/gemma3n.py +++ b/pymllm/models/gemma3n.py @@ -589,6 +589,33 @@ def __init__(self, config): self.altup_unembed_projections = nn.ModuleList( [Linear(tc.hidden_size, tc.hidden_size, bias=False) for _ in range(1, self.altup_num_inputs)] ) + def move_compute_modules_to_device(self, device): + """Move Gemma3n compute modules to the runtime device. + + The two large embedding tables intentionally remain on CPU. This keeps + memory usage manageable while allowing decoder computation to run on + CUDA and making the guarded RadixAttention path reachable in the server + path once ``forward_batch`` is passed through. 
+ """ + device = torch.device(device) + + self.layers.to(device) + self.norm.to(device) + self.per_layer_model_projection.to(device) + self.per_layer_projection_norm.to(device) + self.altup_projections.to(device) + self.altup_unembed_projections.to(device) + + logger.info( + "Gemma3n compute modules moved to %s; " + "embed_tokens=%s embed_tokens_per_layer=%s norm=%s", + device, + self.embed_tokens.weight.device, + self.embed_tokens_per_layer.weight.device, + self.norm.weight.device, + ) + return self + def forward( self, @@ -723,6 +750,28 @@ def __init__(self, config, quant_config=None, prefix: str = ""): # recompute the full context, returning only the last-token logits. self._native_cached_input_ids = None self._native_cached_positions = None + def move_compute_modules_to_device(self, device): + """Move Gemma3n decoder compute modules to the runtime device. + + Keep ``lm_head`` on CPU because it is tied from ``embed_tokens`` and is + very large. ``forward`` moves hidden states to the lm_head device for + logits and then returns logits on the decoder output device. + """ + device = torch.device(device) + + self.model.move_compute_modules_to_device(device) + self.lm_head.to(torch.device("cpu")) + + logger.info( + "Gemma3n device split: embed_tokens=%s embed_tokens_per_layer=%s " + "decoder_norm=%s lm_head=%s", + self.model.embed_tokens.weight.device, + self.model.embed_tokens_per_layer.weight.device, + self.model.norm.weight.device, + self.lm_head.weight.device, + ) + return self + def forward( self, From 5f64a1f7a6211252483381950603f0dfe8350351 Mon Sep 17 00:00:00 2001 From: grape203 Date: Wed, 6 May 2026 00:23:24 +0800 Subject: [PATCH 08/12] Route Gemma3n prefill through guarded RadixAttention --- pymllm/models/gemma3n.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py index ddad229d..8b2e7b44 100644 --- a/pymllm/models/gemma3n.py +++ b/pymllm/models/gemma3n.py @@ -300,6 +300,7 @@ def __init__(self, config, layer_id: int): layer_id=layer_id, sliding_window_size=self.sliding_window_size, ) + self._radix_path_logged = False num_kv_shared_layers = int(getattr(tc, "num_kv_shared_layers", 0)) first_kv_shared_layer_idx = int(tc.num_hidden_layers) - num_kv_shared_layers @@ -439,6 +440,16 @@ def forward( v_flat = v.transpose(1, 2).contiguous().reshape( batch_size * seq_len, self.kv_size ) + if not self._radix_path_logged: + logger.info( + "Gemma3n RadixAttention path active: layer=%d type=%s " + "sliding_window_size=%s tokens=%d", + self.layer_id, + self.layer_type, + self.sliding_window_size, + batch_size * seq_len, + ) + self._radix_path_logged = True attn_output = self.attn(q_flat, k_flat, v_flat, forward_batch) attn_output = attn_output.view(batch_size, seq_len, self.q_size) return self.o_proj(attn_output) @@ -694,7 +705,14 @@ def forward( hidden_states = torch.stack(temp_hidden_states, dim=0) # [P, B, S, H] - native_forward_batch = {"kv_shared_cache": {}} + if forward_batch is None: + native_forward_batch = {"kv_shared_cache": {}} + elif isinstance(forward_batch, dict): + native_forward_batch = forward_batch + native_forward_batch.setdefault("kv_shared_cache", {}) + else: + native_forward_batch = forward_batch + setattr(native_forward_batch, "kv_shared_cache", {}) for layer_idx, layer in enumerate(self.layers): hidden_states = layer( @@ -865,7 +883,16 @@ def forward( self._native_cached_input_ids = full_input_ids.detach().cpu() self._native_cached_positions = full_positions.detach().cpu() 
From 5f64a1f7a6211252483381950603f0dfe8350351 Mon Sep 17 00:00:00 2001
From: grape203
Date: Wed, 6 May 2026 00:23:24 +0800
Subject: [PATCH 08/12] Route Gemma3n prefill through guarded RadixAttention

---
 pymllm/models/gemma3n.py | 31 +++++++++++++++++++++++++++++--
 1 file changed, 29 insertions(+), 2 deletions(-)

diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py
index ddad229d..8b2e7b44 100644
--- a/pymllm/models/gemma3n.py
+++ b/pymllm/models/gemma3n.py
@@ -300,6 +300,7 @@ def __init__(self, config, layer_id: int):
             layer_id=layer_id,
             sliding_window_size=self.sliding_window_size,
         )
+        self._radix_path_logged = False

         num_kv_shared_layers = int(getattr(tc, "num_kv_shared_layers", 0))
         first_kv_shared_layer_idx = int(tc.num_hidden_layers) - num_kv_shared_layers
@@ -439,6 +440,16 @@ def forward(
             v_flat = v.transpose(1, 2).contiguous().reshape(
                 batch_size * seq_len, self.kv_size
             )
+            if not self._radix_path_logged:
+                logger.info(
+                    "Gemma3n RadixAttention path active: layer=%d type=%s "
+                    "sliding_window_size=%s tokens=%d",
+                    self.layer_id,
+                    self.layer_type,
+                    self.sliding_window_size,
+                    batch_size * seq_len,
+                )
+                self._radix_path_logged = True
             attn_output = self.attn(q_flat, k_flat, v_flat, forward_batch)
             attn_output = attn_output.view(batch_size, seq_len, self.q_size)
             return self.o_proj(attn_output)
@@ -694,7 +705,14 @@ def forward(

         hidden_states = torch.stack(temp_hidden_states, dim=0)  # [P, B, S, H]

-        native_forward_batch = {"kv_shared_cache": {}}
+        if forward_batch is None:
+            native_forward_batch = {"kv_shared_cache": {}}
+        elif isinstance(forward_batch, dict):
+            native_forward_batch = forward_batch
+            native_forward_batch.setdefault("kv_shared_cache", {})
+        else:
+            native_forward_batch = forward_batch
+            setattr(native_forward_batch, "kv_shared_cache", {})

         for layer_idx, layer in enumerate(self.layers):
             hidden_states = layer(
@@ -865,7 +883,16 @@ def forward(
         self._native_cached_input_ids = full_input_ids.detach().cpu()
         self._native_cached_positions = full_positions.detach().cpu()

-        hidden_states = self.model(input_ids=full_input_ids, positions=full_positions)
+        if is_extend_mode and forward_batch_obj is not None:
+            model_forward_batch = forward_batch_obj
+        else:
+            model_forward_batch = {"kv_shared_cache": {}}
+
+        hidden_states = self.model(
+            input_ids=full_input_ids,
+            positions=full_positions,
+            forward_batch=model_forward_batch,
+        )
         output_device = hidden_states.device
         logits = self.lm_head(hidden_states.to(device=self.lm_head.weight.device))

From a7259f7d3e3e74c9a715b029b412fe0905acc7 Mon Sep 17 00:00:00 2001
From: grape203
Date: Wed, 6 May 2026 00:44:23 +0800
Subject: [PATCH 09/12] Match Gemma3n RadixAttention scaling to eager path

---
 pymllm/models/gemma3n.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py
index 8b2e7b44..26836104 100644
--- a/pymllm/models/gemma3n.py
+++ b/pymllm/models/gemma3n.py
@@ -295,7 +295,9 @@ def __init__(self, config, layer_id: int):
         self.attn = RadixAttention(
             num_heads=self.num_heads,
             head_dim=self.head_dim,
-            scaling=self.scaling,
+            # Match the currently verified eager Gemma3n attention path, which
+            # applies softmax to q @ k^T without an additional scale factor.
+            scaling=1.0,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             sliding_window_size=self.sliding_window_size,

From 71a6f5b32701ed3f7c5f94b24ab8f5756ef38118 Mon Sep 17 00:00:00 2001
From: grape203
Date: Wed, 6 May 2026 00:49:53 +0800
Subject: [PATCH 10/12] Clean up Gemma3n RadixAttention logging

---
 pymllm/models/gemma3n.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py
index 26836104..50bfdaf3 100644
--- a/pymllm/models/gemma3n.py
+++ b/pymllm/models/gemma3n.py
@@ -443,7 +443,7 @@ def forward(
                 batch_size * seq_len, self.kv_size
             )
             if not self._radix_path_logged:
-                logger.info(
+                logger.debug(
                     "Gemma3n RadixAttention path active: layer=%d type=%s "
                     "sliding_window_size=%s tokens=%d",
                     self.layer_id,
@@ -602,6 +602,7 @@ def __init__(self, config):
         self.altup_unembed_projections = nn.ModuleList(
             [Linear(tc.hidden_size, tc.hidden_size, bias=False) for _ in range(1, self.altup_num_inputs)]
         )
+
     def move_compute_modules_to_device(self, device):
         """Move Gemma3n compute modules to the runtime device.

@@ -770,6 +771,7 @@ def __init__(self, config, quant_config=None, prefix: str = ""):
         # recompute the full context, returning only the last-token logits.
         self._native_cached_input_ids = None
         self._native_cached_positions = None
+
     def move_compute_modules_to_device(self, device):
         """Move Gemma3n decoder compute modules to the runtime device.
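Patch 09's `scaling=1.0` matters because the backend applies its scale inside the softmax: if the verified eager path computes `softmax(q @ k^T)` with no extra factor, any other scale changes the attention distribution. A toy check of that equivalence (plain scaled-dot-product semantics, ignoring masking and KV cache):

    import torch

    q, k = torch.randn(2, 4), torch.randn(3, 4)
    eager = torch.softmax(q @ k.T, dim=-1)            # reference: no extra scale
    backend = torch.softmax(1.0 * (q @ k.T), dim=-1)  # scaling=1.0 reproduces it
    assert torch.allclose(eager, backend)             # any other scaling diverges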
From f309835cec6591582a2da9bacf08b8f6effefeec Mon Sep 17 00:00:00 2001
From: grape203
Date: Wed, 6 May 2026 01:03:16 +0800
Subject: [PATCH 11/12] Move Gemma3n-specific layers into pymllm layers

---
 pymllm/layers/__init__.py |   8 +++
 pymllm/layers/gemma3n.py  | 128 ++++++++++++++++++++++++++++++++++++++
 pymllm/models/gemma3n.py  |  87 ++-------------------------------------
 3 files changed, 141 insertions(+), 82 deletions(-)
 create mode 100644 pymllm/layers/gemma3n.py

diff --git a/pymllm/layers/__init__.py b/pymllm/layers/__init__.py
index 2ecb1396..d8257f13 100644
--- a/pymllm/layers/__init__.py
+++ b/pymllm/layers/__init__.py
@@ -6,6 +6,11 @@
 from pymllm.layers.linear import ColumnParallelLinear, Linear, RowParallelLinear
 from pymllm.layers.mlp import MLP, ParallelMLP
 from pymllm.layers.rms_norm import GemmaRMSNorm, RMSNorm
+from pymllm.layers.gemma3n import (
+    Gemma3nMLP,
+    Gemma3nRMSNorm,
+    Gemma3nRMSNormNoWeight,
+)
 from pymllm.layers.rms_norm_gated import RMSNormGated
 from pymllm.layers.gated_delta_net import GatedDeltaNet
 from pymllm.layers.rope import (
@@ -62,4 +67,7 @@
     "top_k_renorm_probs",
     "top_k_mask_logits",
     "chain_speculative_sampling",
+    "Gemma3nMLP",
+    "Gemma3nRMSNorm",
+    "Gemma3nRMSNormNoWeight",
 ]
diff --git a/pymllm/layers/gemma3n.py b/pymllm/layers/gemma3n.py
new file mode 100644
index 00000000..f9e7ace4
--- /dev/null
+++ b/pymllm/layers/gemma3n.py
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Parameter
+
+from pymllm.layers.base import MllmBaseLayer
+from pymllm.layers.linear import Linear
+from pymllm.layers.utils import set_weight_attrs
+
+
+def _get_gemma3n_hidden_act_fn(name: str):
+    name = (name or "silu").lower()
+    if name in ("silu", "swish"):
+        return F.silu
+    if name == "relu":
+        return F.relu
+    if name == "gelu":
+        return F.gelu
+    if name in ("gelu_tanh", "gelu_pytorch_tanh"):
+        return lambda x: F.gelu(x, approximate="tanh")
+    raise ValueError(f"Unsupported Gemma3n activation: {name}")
+
+
+class Gemma3nRMSNorm(MllmBaseLayer):
+    """Gemma3n RMSNorm used by the native text-only implementation.
+
+    This intentionally preserves the numerics of the verified Gemma3n path.
+    It is not replaced by the generic ``GemmaRMSNorm`` because the generic
+    FlashInfer-backed layer has different behavior for this checkpoint path.
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.weight = Parameter(torch.empty(hidden_size))
+        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor | None = None,
+    ):
+        if x.shape[-1] != self.hidden_size:
+            raise ValueError(
+                f"Expected last dim == hidden_size ({self.hidden_size}), "
+                f"but got input shape {tuple(x.shape)}"
+            )
+
+        if residual is not None:
+            residual = residual + x
+            x = residual
+
+        out = self._norm(x.float()) * self.weight.float()
+        out = out.to(dtype=x.dtype)
+
+        if residual is not None:
+            return out, residual
+        return out
+
+
+class Gemma3nRMSNormNoWeight(MllmBaseLayer):
+    """Weight-free RMSNorm used by Gemma3n value normalization."""
+
+    def __init__(self, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_dtype = x.dtype
+        x_fp32 = x.float()
+        var = x_fp32.pow(2).mean(dim=-1, keepdim=True)
+        out = x_fp32 * torch.rsqrt(var + self.eps)
+        return out.to(x_dtype)
+
+
+class Gemma3nMLP(MllmBaseLayer):
+    """Gemma3n feed-forward block.
+
+    Uses pymllm ``Linear`` projections while preserving Gemma3n's optional
+    activation sparsity branch, which is not implemented by the generic
+    ``pymllm.layers.MLP``.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        activation: str,
+        activation_sparsity: float = 0.0,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.activation_name = activation
+        self.act = _get_gemma3n_hidden_act_fn(activation)
+        self.activation_sparsity = float(activation_sparsity)
+
+        self.gate_proj = Linear(hidden_size, intermediate_size, bias=False)
+        self.up_proj = Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = Linear(intermediate_size, hidden_size, bias=False)
+
+    def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
+        target_sparsity_tensor = torch.tensor(
+            self.activation_sparsity,
+            dtype=torch.float32,
+            device=inputs.device,
+        )
+        normal_dist = torch.distributions.normal.Normal(0, 1)
+        std_multiplier = normal_dist.icdf(target_sparsity_tensor).to(dtype=inputs.dtype)
+        inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
+        inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
+        cutoff_x = inputs_mean + inputs_std * std_multiplier
+        return F.relu(inputs - cutoff_x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate = self.gate_proj(x)
+        if self.activation_sparsity > 0.0:
+            gate = self._gaussian_topk(gate)
+        gate = self.act(gate)
+        up = self.up_proj(x)
+        hidden = gate * up
+        return self.down_proj(hidden)
diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py
index 50bfdaf3..10260015 100644
--- a/pymllm/models/gemma3n.py
+++ b/pymllm/models/gemma3n.py
@@ -9,6 +9,11 @@
 from pymllm.layers.linear import Linear
 from pymllm.layers.attention.radix_attention import RadixAttention
+from pymllm.layers.gemma3n import (
+    Gemma3nMLP,
+    Gemma3nRMSNorm,
+    Gemma3nRMSNormNoWeight,
+)

 logger = logging.getLogger(__name__)

@@ -49,53 +54,11 @@ def _get_intermediate_size(config, layer_id: int) -> int:
     return int(intermediate_size[layer_id])


-class Gemma3nRMSNorm(nn.Module):
-    def __init__(self, hidden_size: int, eps: float = 1e-6):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-
-    def _norm(self, x: torch.Tensor) -> torch.Tensor:
-        return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ):
-        if x.shape[-1] != self.hidden_size:
-            raise ValueError(
-                f"Expected last dim == hidden_size ({self.hidden_size}), "
-                f"but got input shape {tuple(x.shape)}"
-            )
-
-        if residual is not None:
-            residual = residual + x
-            x = residual
-
-        out = self._norm(x.float()) * self.weight.float()
-        out = out.to(dtype=x.dtype)
-
-        if residual is not None:
-            return out, residual
-        return out
-
-
-class Gemma3nRMSNormNoWeight(nn.Module):
-    def __init__(self, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_dtype = x.dtype
-        x_fp32 = x.float()
-        var = x_fp32.pow(2).mean(dim=-1, keepdim=True)
-        out = x_fp32 * torch.rsqrt(var + self.eps)
-        return out.to(x_dtype)
-
-
 def _rotate_half(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
@@ -127,46 +90,6 @@ def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.
     return (x * cos) + (_rotate_half(x) * sin)


-class Gemma3nMLP(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        activation: str,
-        activation_sparsity: float = 0.0,
-    ):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.activation_name = activation
-        self.act = _get_hidden_act_fn(activation)
-        self.activation_sparsity = float(activation_sparsity)
-
-        self.gate_proj = Linear(hidden_size, intermediate_size, bias=False)
-        self.up_proj = Linear(hidden_size, intermediate_size, bias=False)
-        self.down_proj = Linear(intermediate_size, hidden_size, bias=False)
-
-    def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
-        target_sparsity_tensor = torch.tensor(
-            self.activation_sparsity,
-            dtype=torch.float32,
-            device=inputs.device,
-        )
-        normal_dist = torch.distributions.normal.Normal(0, 1)
-        std_multiplier = normal_dist.icdf(target_sparsity_tensor).to(dtype=inputs.dtype)
-        inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
-        inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
-        cutoff_x = inputs_mean + inputs_std * std_multiplier
-        return F.relu(inputs - cutoff_x)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        gate = self.gate_proj(x)
-        if self.activation_sparsity > 0.0:
-            gate = self._gaussian_topk(gate)
-        gate = self.act(gate)
-        up = self.up_proj(x)
-        hidden = gate * up
-        return self.down_proj(hidden)
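The activation-sparsity branch carried over above thresholds each gate row at `mean + icdf(p) * std`, so for roughly Gaussian activations about a `1 - p` fraction of units survives the ReLU. A quick empirical check of that property:

    import torch
    import torch.nn.functional as F

    p = 0.95  # target sparsity
    x = torch.randn(4, 100_000)
    z = torch.distributions.Normal(0.0, 1.0).icdf(torch.tensor(p))
    cutoff = x.mean(-1, keepdim=True) + x.std(-1, unbiased=False, keepdim=True) * z
    kept = (F.relu(x - cutoff) > 0).float().mean().item()
    print(f"fraction kept ~ {kept:.3f}")  # close to 1 - p = 0.05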
From 9d2a6e481c74d38eca88ee0c412741f0a8bd747c Mon Sep 17 00:00:00 2001
From: grape203
Date: Wed, 6 May 2026 01:24:32 +0800
Subject: [PATCH 12/12] Address CodeRabbit comments for Gemma3n layers

---
 pymllm/layers/gemma3n.py | 18 +++++++++++++-----
 pymllm/models/gemma3n.py | 13 ++++---------
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/pymllm/layers/gemma3n.py b/pymllm/layers/gemma3n.py
index f9e7ace4..ce962931 100644
--- a/pymllm/layers/gemma3n.py
+++ b/pymllm/layers/gemma3n.py
@@ -101,18 +101,26 @@ def __init__(
         self.act = _get_gemma3n_hidden_act_fn(activation)
         self.activation_sparsity = float(activation_sparsity)

+        if self.activation_sparsity > 0.0:
+            normal_dist = torch.distributions.normal.Normal(0.0, 1.0)
+            std_multiplier = normal_dist.icdf(
+                torch.tensor(self.activation_sparsity, dtype=torch.float32)
+            )
+            self.register_buffer(
+                "_std_multiplier",
+                std_multiplier,
+                persistent=False,
+            )
+
         self.gate_proj = Linear(hidden_size, intermediate_size, bias=False)
         self.up_proj = Linear(hidden_size, intermediate_size, bias=False)
         self.down_proj = Linear(intermediate_size, hidden_size, bias=False)

     def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
-        target_sparsity_tensor = torch.tensor(
-            self.activation_sparsity,
-            dtype=torch.float32,
+        std_multiplier = self._std_multiplier.to(
             device=inputs.device,
+            dtype=inputs.dtype,
         )
-        normal_dist = torch.distributions.normal.Normal(0, 1)
-        std_multiplier = normal_dist.icdf(target_sparsity_tensor).to(dtype=inputs.dtype)
         inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
         inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
         cutoff_x = inputs_mean + inputs_std * std_multiplier
diff --git a/pymllm/models/gemma3n.py b/pymllm/models/gemma3n.py
index 10260015..0c23056e 100644
--- a/pymllm/models/gemma3n.py
+++ b/pymllm/models/gemma3n.py
@@ -334,7 +334,7 @@ def forward(
         if isinstance(forward_batch, dict):
             shared_kv_cache = forward_batch.get("kv_shared_cache")
         elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
-            shared_kv_cache = getattr(forward_batch, "kv_shared_cache")
+            shared_kv_cache = forward_batch.kv_shared_cache

         if (
             self.is_kv_shared_layer
@@ -638,7 +638,7 @@ def forward(
             native_forward_batch.setdefault("kv_shared_cache", {})
         else:
             native_forward_batch = forward_batch
-            setattr(native_forward_batch, "kv_shared_cache", {})
+            native_forward_batch.kv_shared_cache = {}

         for layer_idx, layer in enumerate(self.layers):
             hidden_states = layer(
@@ -925,11 +925,6 @@ def _copy_tensor_streaming(safe_file, raw_name: str, target: torch.Tensor):
         else:
             st_files = sorted(model_path.glob("*.safetensors"))
             if not st_files:
-                logger.info(
-                    "Gemma3n streaming loader found no safetensors under %s; "
-                    "falling back to regular load_weights path.",
-                    model_path,
-                )
                 raise FileNotFoundError(
                     f"No safetensors checkpoint shards found under {model_path}. "
                     "Gemma3n native loading currently expects safetensors weights."
@@ -1033,11 +1028,11 @@ def load_weights(self, weights):
         else:
             try:
                 weight_items = iter(weights)
-            except TypeError:
+            except TypeError as err:
                 raise TypeError(
                     f"weights must be a dict-like state_dict, a module with state_dict(), "
                     f"or an iterable of (name, tensor), got {type(weights)}"
-                )
+                ) from err

         def _normalize_name(name: str):
             # Ignore clearly multimodal-only weights for current text-only path.